diff --git a/python/arrow_odbc/connect.py b/python/arrow_odbc/connect.py index 5e776ce..cba9fbe 100644 --- a/python/arrow_odbc/connect.py +++ b/python/arrow_odbc/connect.py @@ -16,46 +16,61 @@ def to_bytes_and_len(value: Optional[str]) -> Tuple[bytes, int]: return (value_bytes, value_len) -def connect_to_database( - connection_string: str, - user: Optional[str], - password: Optional[str], - login_timeout_sec: Optional[int], - packet_size: Optional[int], -) -> Any: - connection_string_bytes = connection_string.encode("utf-8") - - (user_bytes, user_len) = to_bytes_and_len(user) - (password_bytes, password_len) = to_bytes_and_len(password) - - # We use a pointer to pass the login time, so NULL can represent None - if login_timeout_sec is None: - login_timeout_sec_ptr = FFI.NULL - else: - login_timeout_sec_ptr = ffi.new("uint32_t *") - login_timeout_sec_ptr[0] = login_timeout_sec +class ConnectionRaii: + def __init__( + self, + connection_string: str, + user: Optional[str], + password: Optional[str], + login_timeout_sec: Optional[int], + packet_size: Optional[int], + ) -> None: + connection_string_bytes = connection_string.encode("utf-8") - if packet_size is None: - packet_size_ptr = FFI.NULL - else: - packet_size_ptr = ffi.new("uint32_t *") - packet_size_ptr[0] = packet_size - - connection_out = ffi.new("OdbcConnection **") - - # Open connection to ODBC Data Source - error = lib.arrow_odbc_connect_with_connection_string( - connection_string_bytes, - len(connection_string_bytes), - user_bytes, - user_len, - password_bytes, - password_len, - login_timeout_sec_ptr, - packet_size_ptr, - connection_out, - ) - # See if we connected successfully and return an error if not - raise_on_error(error) - # Dereference output pointer. This gives us an `OdbcConnection *` - return connection_out[0] + (user_bytes, user_len) = to_bytes_and_len(user) + (password_bytes, password_len) = to_bytes_and_len(password) + + # We use a pointer to pass the login time, so NULL can represent None + if login_timeout_sec is None: + login_timeout_sec_ptr = FFI.NULL + else: + login_timeout_sec_ptr = ffi.new("uint32_t *") + login_timeout_sec_ptr[0] = login_timeout_sec + + if packet_size is None: + packet_size_ptr = FFI.NULL + else: + packet_size_ptr = ffi.new("uint32_t *") + packet_size_ptr[0] = packet_size + + connection_out = ffi.new("ArrowOdbcConnection **") + + # Open connection to ODBC Data Source + error = lib.arrow_odbc_connection_make( + connection_string_bytes, + len(connection_string_bytes), + user_bytes, + user_len, + password_bytes, + password_len, + login_timeout_sec_ptr, + packet_size_ptr, + connection_out, + ) + # See if we connected successfully and return an error if not + raise_on_error(error) + # Dereference output pointer. This gives us an `ArrowOdbcConnection *`. We take ownership of + # the ArrowOdbcConnection and must take care to free it. + self.handle = connection_out[0] + + + def _arrow_odbc_connection(self) -> Any: + """ + Give access to the inner ArrowOdbcConnection handle + """ + return self.handle + + + def __del__(self): + # Free the resources associated with this handle. + lib.arrow_odbc_connection_free(self.handle) diff --git a/python/arrow_odbc/reader.py b/python/arrow_odbc/reader.py index 966d360..b2f7495 100644 --- a/python/arrow_odbc/reader.py +++ b/python/arrow_odbc/reader.py @@ -6,7 +6,7 @@ from pyarrow.cffi import ffi as arrow_ffi # type: ignore from pyarrow import RecordBatch, Schema, Array # type: ignore -from arrow_odbc.connect import to_bytes_and_len, connect_to_database # type: ignore +from arrow_odbc.connect import to_bytes_and_len, ConnectionRaii # type: ignore from .arrow_odbc import ffi, lib # type: ignore from .error import raise_on_error @@ -76,18 +76,10 @@ def connect( login_timeout_sec: Optional[int], packet_size: Optional[int], ): - connection = connect_to_database( + connection = ConnectionRaii( connection_string, user, password, login_timeout_sec, packet_size ) - - # Connecting to the database has been successful. Note that connection does not truly take - # ownership of the connection. If it runs out of scope (e.g. due to a raised exception) the - # connection would not be closed and its associated resources would not be freed. - # However, this is fine since everything from here on out until we call - # arrow_odbc_reader_set_connection is infalliable. arrow_odbc_reader_connection will truly - # take ownership of the connection. - - lib.arrow_odbc_reader_set_connection(self.handle, connection) + lib.arrow_odbc_reader_set_connection(self.handle, connection._arrow_odbc_connection()) def query( self, diff --git a/python/arrow_odbc/writer.py b/python/arrow_odbc/writer.py index b4d0a7f..d4840bf 100644 --- a/python/arrow_odbc/writer.py +++ b/python/arrow_odbc/writer.py @@ -2,7 +2,7 @@ from pyarrow import RecordBatchReader from pyarrow.cffi import ffi as arrow_ffi -from arrow_odbc.connect import connect_to_database +from arrow_odbc.connect import ConnectionRaii from .arrow_odbc import ffi, lib # type: ignore from .error import raise_on_error @@ -124,7 +124,7 @@ def dataframe_to_table(df): # Export the schema to the C Data structures. reader.schema._export_to_c(c_schema_ptr) - connection = connect_to_database( + connection = ConnectionRaii( connection_string, user, password, login_timeout_sec, packet_size ) @@ -136,7 +136,7 @@ def dataframe_to_table(df): writer_out = ffi.new("ArrowOdbcWriter **") error = lib.arrow_odbc_writer_make( - connection, + connection._arrow_odbc_connection(), table_bytes, len(table_bytes), chunk_size, diff --git a/src/connection.rs b/src/connection.rs index 5437dbe..ce14438 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,4 +1,8 @@ -use std::{borrow::Cow, ptr::null_mut, slice, str}; +use std::{ + borrow::Cow, + ptr::{null_mut, NonNull}, + slice, str, +}; use arrow_odbc::odbc_api::{escape_attribute_value, Connection, ConnectionOptions, Environment}; use log::debug; @@ -6,19 +10,29 @@ use log::debug; use crate::{try_, ArrowOdbcError, ENV}; /// Opaque type to transport connection to an ODBC Datasource over language boundry -pub struct OdbcConnection(Connection<'static>); +pub struct ArrowOdbcConnection(Option>); -impl OdbcConnection { +impl ArrowOdbcConnection { pub fn new(connection: Connection<'static>) -> Self { - OdbcConnection(connection) + ArrowOdbcConnection(Some(connection)) } /// Take the inner connection out of its wrapper - pub fn take(self) -> Connection<'static> { - self.0 + pub fn take(&mut self) -> Connection<'static> { + self.0.take().unwrap() } } +/// Frees the resources associated with an ArrowOdbcConnection +/// +/// # Safety +/// +/// `reader` must point to a valid ArrowOdbcConnection. +#[no_mangle] +pub unsafe extern "C" fn arrow_odbc_connection_free(connection: NonNull) { + drop(Box::from_raw(connection.as_ptr())); +} + /// Allocate and open an ODBC connection using the specified connection string. In case of an error /// this function returns a NULL pointer. /// @@ -28,7 +42,7 @@ impl OdbcConnection { /// hold the length of text in `connection_string_buf`. /// `user` and or `password` are optional and are allowed to be `NULL`. #[no_mangle] -pub unsafe extern "C" fn arrow_odbc_connect_with_connection_string( +pub unsafe extern "C" fn arrow_odbc_connection_make( connection_string_buf: *const u8, connection_string_len: usize, user: *const u8, @@ -37,7 +51,7 @@ pub unsafe extern "C" fn arrow_odbc_connect_with_connection_string( password_len: usize, login_timeout_sec_ptr: *const u32, packet_size_ptr: *const u32, - connection_out: *mut *mut OdbcConnection, + connection_out: *mut *mut ArrowOdbcConnection, ) -> *mut ArrowOdbcError { let env = if let Some(env) = ENV.get() { // Use existing environment @@ -78,7 +92,7 @@ pub unsafe extern "C" fn arrow_odbc_connect_with_connection_string( let dbms_name = try_!(connection.database_management_system_name()); debug!("Database managment system name as reported by ODBC: {dbms_name}"); - *connection_out = Box::into_raw(Box::new(OdbcConnection::new(connection))); + *connection_out = Box::into_raw(Box::new(ArrowOdbcConnection::new(connection))); null_mut() } @@ -98,4 +112,4 @@ unsafe fn append_attribute( let text = str::from_utf8(bytes).unwrap(); let escaped = escape_attribute_value(text); *connection_string = format!("{connection_string}{attribute_name}={escaped};").into() -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index 6c7679a..524a3ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ use std::sync::OnceLock; use arrow_odbc::odbc_api::Environment; pub use self::{ - connection::{arrow_odbc_connect_with_connection_string, OdbcConnection}, + connection::{arrow_odbc_connection_make, ArrowOdbcConnection}, error::{arrow_odbc_error_free, arrow_odbc_error_message, ArrowOdbcError}, logging::arrow_odbc_log_to_stderr, reader::{arrow_odbc_reader_free, arrow_odbc_reader_next, ArrowOdbcReader}, diff --git a/src/reader.rs b/src/reader.rs index d4ea7ce..8cd99cb 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -12,7 +12,7 @@ use std::{ use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema}; use arrow_odbc::OdbcReaderBuilder; -use crate::{parameter::ArrowOdbcParameter, try_, ArrowOdbcError, OdbcConnection}; +use crate::{parameter::ArrowOdbcParameter, try_, ArrowOdbcError, ArrowOdbcConnection}; pub use self::arrow_odbc_reader::ArrowOdbcReader; @@ -29,10 +29,9 @@ pub use self::arrow_odbc_reader::ArrowOdbcReader; #[no_mangle] pub unsafe extern "C" fn arrow_odbc_reader_set_connection( mut reader: NonNull, - connection: NonNull, + mut connection: NonNull, ) { - let connection = *Box::from_raw(connection.as_ptr()); - reader.as_mut().set_connection(connection.take()); + reader.as_mut().set_connection(connection.as_mut().take()); } /// Creates an Arrow ODBC reader instance. diff --git a/src/writer.rs b/src/writer.rs index 03b2713..d001f7d 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -11,7 +11,7 @@ use arrow_odbc::{ OdbcWriter, }; -use crate::{try_, ArrowOdbcError, OdbcConnection}; +use crate::{try_, ArrowOdbcError, ArrowOdbcConnection}; /// Opaque type holding all the state associated with an ODBC writer implementation in Rust. This /// type also has ownership of the ODBC Connection handle. @@ -43,15 +43,14 @@ pub unsafe extern "C" fn arrow_odbc_writer_free(writer: NonNull /// is transferred to the caller. #[no_mangle] pub unsafe extern "C" fn arrow_odbc_writer_make( - connection: NonNull, + mut connection: NonNull, table_buf: *const u8, table_len: usize, chunk_size: usize, schema: *const c_void, writer_out: *mut *mut ArrowOdbcWriter, ) -> *mut ArrowOdbcError { - let connection = *Box::from_raw(connection.as_ptr()); - let connection = connection.take(); + let connection = connection.as_mut().take(); let table = slice::from_raw_parts(table_buf, table_len); let table = str::from_utf8(table).unwrap();