Skip to content

Commit

Permalink
remove connection state from reader
Browse files Browse the repository at this point in the history
  • Loading branch information
pacman82 committed Sep 21, 2024
1 parent a890d99 commit 165c56e
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 59 deletions.
5 changes: 1 addition & 4 deletions python/arrow_odbc/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ def to_bytes_and_len(value: Optional[str]) -> Tuple[bytes, int]:


class ConnectionRaii:
def __init__(
self,
handle: Any
) -> None:
def __init__(self, handle: Any) -> None:
self.handle = handle

def _arrow_odbc_connection(self) -> Any:
Expand Down
18 changes: 5 additions & 13 deletions python/arrow_odbc/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pyarrow.cffi import ffi as arrow_ffi # type: ignore
from pyarrow import RecordBatch, Schema, Array # type: ignore

from arrow_odbc.connect import to_bytes_and_len, connect # type: ignore
from arrow_odbc.connect import to_bytes_and_len, connect, ConnectionRaii # type: ignore

from .arrow_odbc import ffi, lib # type: ignore
from .error import raise_on_error
Expand Down Expand Up @@ -68,19 +68,9 @@ def into_concurrent(self):
error = lib.arrow_odbc_reader_into_concurrent(self.handle)
raise_on_error(error)

def connect(
self,
connection_string: str,
user: Optional[str],
password: Optional[str],
login_timeout_sec: Optional[int],
packet_size: Optional[int],
):
connection = connect(connection_string, user, password, login_timeout_sec, packet_size)
lib.arrow_odbc_reader_set_connection(self.handle, connection._arrow_odbc_connection())

def query(
self,
connection: ConnectionRaii,
query: str,
parameters: Optional[List[Optional[str]]],
):
Expand Down Expand Up @@ -111,6 +101,7 @@ def query(

error = lib.arrow_odbc_reader_query(
self.handle,
connection._arrow_odbc_connection(),
query_bytes,
len(query_bytes),
parameters_array,
Expand Down Expand Up @@ -486,7 +477,7 @@ def read_arrow_batches_from_odbc(
"""
reader = _BatchReaderRaii()

reader.connect(
connection = connect(
connection_string=connection_string,
user=user,
password=password,
Expand All @@ -495,6 +486,7 @@ def read_arrow_batches_from_odbc(
)

reader.query(
connection=connection,
query=query,
parameters=parameters,
)
Expand Down
21 changes: 3 additions & 18 deletions src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,6 @@ use crate::{parameter::ArrowOdbcParameter, try_, ArrowOdbcError, ArrowOdbcConnec
pub use self::arrow_odbc_reader::ArrowOdbcReader;


/// Set a reader into connection state
///
/// # Safety
///
/// * `reader` must point to a valid reader in empty state.
/// * `connection` must point to a valid OdbcConnection. This function takes ownership of the
/// connection, even in case of an error. So the connection must not be freed explicitly
/// afterwards.
#[no_mangle]
pub unsafe extern "C" fn arrow_odbc_reader_set_connection(
mut reader: NonNull<ArrowOdbcReader>,
mut connection: NonNull<ArrowOdbcConnection>,
) {
reader.as_mut().set_connection(connection.as_mut().take());
}

/// Creates an Arrow ODBC reader instance.
///
/// Executes the SQL Query and moves the reader into cursor state.
Expand Down Expand Up @@ -65,11 +48,13 @@ pub unsafe extern "C" fn arrow_odbc_reader_set_connection(
#[no_mangle]
pub unsafe extern "C" fn arrow_odbc_reader_query(
mut reader: NonNull<ArrowOdbcReader>,
mut connection: NonNull<ArrowOdbcConnection>,
query_buf: *const u8,
query_len: usize,
parameters: *const *mut ArrowOdbcParameter,
parameters_len: usize,
) -> *mut ArrowOdbcError {
let connection = connection.as_mut().take();
// Translate C args into more idiomatic Rust representations
let query = slice::from_raw_parts(query_buf, query_len);
let query = str::from_utf8(query).unwrap();
Expand All @@ -83,7 +68,7 @@ pub unsafe extern "C" fn arrow_odbc_reader_query(
.collect()
};

try_!(reader.as_mut().promote_to_cursor(query, &parameters[..]));
try_!(reader.as_mut().promote_to_cursor(connection, query, &parameters[..]));

null_mut() // Ok(())
}
Expand Down
30 changes: 6 additions & 24 deletions src/reader/arrow_odbc_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ use arrow_odbc::{
pub enum ArrowOdbcReader {
/// Either a freshly created instance, or all our resources have been moved to another instance.
Empty,
/// Either a connection freshly generated, or we have fully consumed the result set, and there
/// is nothing more to fetch.
Connection(Connection<'static>),
/// We can not read batches in cursor state yet. We still would need to figure out the schema
/// of the source data and (usually) infer the arrow schema from it. So in earlier versions
/// we created everything directly in `Reader` state. However, if we want the user to be able
Expand All @@ -54,7 +51,7 @@ impl ArrowOdbcReader {
&mut self,
) -> Result<Option<(FFI_ArrowArray, FFI_ArrowSchema)>, ArrowOdbcError> {
let next = match self {
ArrowOdbcReader::Empty | ArrowOdbcReader::Connection(_) => None,
ArrowOdbcReader::Empty => None,
ArrowOdbcReader::Cursor(_) => {
unreachable!("Python code must not allow to call next_batch from cursor state")
}
Expand Down Expand Up @@ -85,7 +82,7 @@ impl ArrowOdbcReader {
let cursor = match tmp_self {
// In case there has been a query without a result set, we could be in an empty state.
// Let's just keep it, there is simply nothing to bind a buffer to.
ArrowOdbcReader::Empty | ArrowOdbcReader::Connection(_) => return Ok(()),
ArrowOdbcReader::Empty => return Ok(()),
ArrowOdbcReader::Cursor(cursor) => cursor,
ArrowOdbcReader::Reader(_) | ArrowOdbcReader::ConcurrentReader(_) => {
unreachable!("Python part must ensure to only promote cursors to readers.")
Expand All @@ -97,40 +94,25 @@ impl ArrowOdbcReader {
Ok(())
}

/// Take ownership of `connection` and set reader to connection state. All resources which
/// previously might have been associated with `self` would be deallocated.
pub fn set_connection(&mut self, connection: Connection<'static>) {
*self = ArrowOdbcReader::Connection(connection)
}

/// Execute `query` on `conn` and promote the reader to cursor state. If this operation fails,
/// the reader is left in empty state.
pub fn promote_to_cursor(
&mut self,
conn: Connection<'static>,
query: &str,
params: impl ParameterCollectionRef,
) -> Result<(), ArrowOdbcError> {
// Move self into a temporary instance we own, in order to take ownership of the inner
// reader and move it to a different state.
let mut tmp_self = ArrowOdbcReader::Empty;
swap(self, &mut tmp_self);
let conn = match tmp_self {
ArrowOdbcReader::Connection(conn) => conn,
ArrowOdbcReader::Empty
| ArrowOdbcReader::Cursor(_)
| ArrowOdbcReader::Reader(_)
| ArrowOdbcReader::ConcurrentReader(_) => {
unreachable!("Python part must ensure to only connections to cursors.")
}
};

match conn.into_cursor(query, params) {
Ok(None) => (),
Ok(Some(cursor)) => {
*self = ArrowOdbcReader::Cursor(cursor);
}
Err(error) => {
*self = ArrowOdbcReader::Connection(error.connection);
return Err(error.error.into());
}
}
Expand All @@ -145,7 +127,7 @@ impl ArrowOdbcReader {
let mut tmp_self = ArrowOdbcReader::Empty;
swap(self, &mut tmp_self);
let cursor = match tmp_self {
ArrowOdbcReader::Empty | ArrowOdbcReader::Connection(_) => return Ok(false),
ArrowOdbcReader::Empty => return Ok(false),
ArrowOdbcReader::Cursor(cursor) => cursor,
ArrowOdbcReader::Reader(inner) => inner.into_cursor()?,
ArrowOdbcReader::ConcurrentReader(inner) => inner.into_cursor()?,
Expand All @@ -161,7 +143,7 @@ impl ArrowOdbcReader {

pub fn schema(&mut self) -> Result<FFI_ArrowSchema, ArrowOdbcError> {
let schema_ffi = match self {
ArrowOdbcReader::Empty | ArrowOdbcReader::Connection(_) => {
ArrowOdbcReader::Empty => {
// A schema with no columns. Different from FFI_ArrowSchema empty, which can not be
// imported into pyarrow
let schema = Schema::empty();
Expand Down Expand Up @@ -197,7 +179,7 @@ impl ArrowOdbcReader {

*self = match tmp_self {
// Nothing to do. There is nothing left to fetch.
reader @ (ArrowOdbcReader::Empty | ArrowOdbcReader::Connection(_)) => reader,
ArrowOdbcReader::Empty => ArrowOdbcReader::Empty,
ArrowOdbcReader::Cursor(_) => {
unreachable!("Python code must not allow to call into_concurrent from cursor state")
}
Expand Down

0 comments on commit 165c56e

Please sign in to comment.