Skip to content

Commit

Permalink
introduce raii connection object
Browse files Browse the repository at this point in the history
  • Loading branch information
pacman82 committed Jul 14, 2024
1 parent 93378d6 commit 172c6eb
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 75 deletions.
99 changes: 57 additions & 42 deletions python/arrow_odbc/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,61 @@ def to_bytes_and_len(value: Optional[str]) -> Tuple[bytes, int]:
return (value_bytes, value_len)


def connect_to_database(
connection_string: str,
user: Optional[str],
password: Optional[str],
login_timeout_sec: Optional[int],
packet_size: Optional[int],
) -> Any:
connection_string_bytes = connection_string.encode("utf-8")

(user_bytes, user_len) = to_bytes_and_len(user)
(password_bytes, password_len) = to_bytes_and_len(password)

# We use a pointer to pass the login time, so NULL can represent None
if login_timeout_sec is None:
login_timeout_sec_ptr = FFI.NULL
else:
login_timeout_sec_ptr = ffi.new("uint32_t *")
login_timeout_sec_ptr[0] = login_timeout_sec
class ConnectionRaii:
def __init__(
self,
connection_string: str,
user: Optional[str],
password: Optional[str],
login_timeout_sec: Optional[int],
packet_size: Optional[int],
) -> None:
connection_string_bytes = connection_string.encode("utf-8")

if packet_size is None:
packet_size_ptr = FFI.NULL
else:
packet_size_ptr = ffi.new("uint32_t *")
packet_size_ptr[0] = packet_size

connection_out = ffi.new("OdbcConnection **")

# Open connection to ODBC Data Source
error = lib.arrow_odbc_connect_with_connection_string(
connection_string_bytes,
len(connection_string_bytes),
user_bytes,
user_len,
password_bytes,
password_len,
login_timeout_sec_ptr,
packet_size_ptr,
connection_out,
)
# See if we connected successfully and return an error if not
raise_on_error(error)
# Dereference output pointer. This gives us an `OdbcConnection *`
return connection_out[0]
(user_bytes, user_len) = to_bytes_and_len(user)
(password_bytes, password_len) = to_bytes_and_len(password)

# We use a pointer to pass the login time, so NULL can represent None
if login_timeout_sec is None:
login_timeout_sec_ptr = FFI.NULL
else:
login_timeout_sec_ptr = ffi.new("uint32_t *")
login_timeout_sec_ptr[0] = login_timeout_sec

if packet_size is None:
packet_size_ptr = FFI.NULL
else:
packet_size_ptr = ffi.new("uint32_t *")
packet_size_ptr[0] = packet_size

connection_out = ffi.new("ArrowOdbcConnection **")

# Open connection to ODBC Data Source
error = lib.arrow_odbc_connection_make(
connection_string_bytes,
len(connection_string_bytes),
user_bytes,
user_len,
password_bytes,
password_len,
login_timeout_sec_ptr,
packet_size_ptr,
connection_out,
)
# See if we connected successfully and return an error if not
raise_on_error(error)
# Dereference output pointer. This gives us an `ArrowOdbcConnection *`. We take ownership of
# the ArrowOdbcConnection and must take care to free it.
self.handle = connection_out[0]


def _arrow_odbc_connection(self) -> Any:
"""
Give access to the inner ArrowOdbcConnection handle
"""
return self.handle


def __del__(self):
# Free the resources associated with this handle.
lib.arrow_odbc_connection_free(self.handle)
14 changes: 3 additions & 11 deletions python/arrow_odbc/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pyarrow.cffi import ffi as arrow_ffi # type: ignore
from pyarrow import RecordBatch, Schema, Array # type: ignore

from arrow_odbc.connect import to_bytes_and_len, connect_to_database # type: ignore
from arrow_odbc.connect import to_bytes_and_len, ConnectionRaii # type: ignore

from .arrow_odbc import ffi, lib # type: ignore
from .error import raise_on_error
Expand Down Expand Up @@ -76,18 +76,10 @@ def connect(
login_timeout_sec: Optional[int],
packet_size: Optional[int],
):
connection = connect_to_database(
connection = ConnectionRaii(
connection_string, user, password, login_timeout_sec, packet_size
)

# Connecting to the database has been successful. Note that connection does not truly take
# ownership of the connection. If it runs out of scope (e.g. due to a raised exception) the
# connection would not be closed and its associated resources would not be freed.
# However, this is fine since everything from here on out until we call
# arrow_odbc_reader_set_connection is infalliable. arrow_odbc_reader_connection will truly
# take ownership of the connection.

lib.arrow_odbc_reader_set_connection(self.handle, connection)
lib.arrow_odbc_reader_set_connection(self.handle, connection._arrow_odbc_connection())

def query(
self,
Expand Down
6 changes: 3 additions & 3 deletions python/arrow_odbc/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pyarrow import RecordBatchReader
from pyarrow.cffi import ffi as arrow_ffi
from arrow_odbc.connect import connect_to_database
from arrow_odbc.connect import ConnectionRaii

from .arrow_odbc import ffi, lib # type: ignore
from .error import raise_on_error
Expand Down Expand Up @@ -124,7 +124,7 @@ def dataframe_to_table(df):
# Export the schema to the C Data structures.
reader.schema._export_to_c(c_schema_ptr)

connection = connect_to_database(
connection = ConnectionRaii(
connection_string, user, password, login_timeout_sec, packet_size
)

Expand All @@ -136,7 +136,7 @@ def dataframe_to_table(df):

writer_out = ffi.new("ArrowOdbcWriter **")
error = lib.arrow_odbc_writer_make(
connection,
connection._arrow_odbc_connection(),
table_bytes,
len(table_bytes),
chunk_size,
Expand Down
34 changes: 24 additions & 10 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,38 @@
use std::{borrow::Cow, ptr::null_mut, slice, str};
use std::{
borrow::Cow,
ptr::{null_mut, NonNull},
slice, str,
};

use arrow_odbc::odbc_api::{escape_attribute_value, Connection, ConnectionOptions, Environment};
use log::debug;

use crate::{try_, ArrowOdbcError, ENV};

/// Opaque type to transport connection to an ODBC Datasource over language boundry
pub struct OdbcConnection(Connection<'static>);
pub struct ArrowOdbcConnection(Option<Connection<'static>>);

impl OdbcConnection {
impl ArrowOdbcConnection {
pub fn new(connection: Connection<'static>) -> Self {
OdbcConnection(connection)
ArrowOdbcConnection(Some(connection))
}

/// Take the inner connection out of its wrapper
pub fn take(self) -> Connection<'static> {
self.0
pub fn take(&mut self) -> Connection<'static> {
self.0.take().unwrap()
}
}

/// Frees the resources associated with an ArrowOdbcConnection
///
/// # Safety
///
/// `reader` must point to a valid ArrowOdbcConnection.
#[no_mangle]
pub unsafe extern "C" fn arrow_odbc_connection_free(connection: NonNull<ArrowOdbcConnection>) {
drop(Box::from_raw(connection.as_ptr()));
}

/// Allocate and open an ODBC connection using the specified connection string. In case of an error
/// this function returns a NULL pointer.
///
Expand All @@ -28,7 +42,7 @@ impl OdbcConnection {
/// hold the length of text in `connection_string_buf`.
/// `user` and or `password` are optional and are allowed to be `NULL`.
#[no_mangle]
pub unsafe extern "C" fn arrow_odbc_connect_with_connection_string(
pub unsafe extern "C" fn arrow_odbc_connection_make(
connection_string_buf: *const u8,
connection_string_len: usize,
user: *const u8,
Expand All @@ -37,7 +51,7 @@ pub unsafe extern "C" fn arrow_odbc_connect_with_connection_string(
password_len: usize,
login_timeout_sec_ptr: *const u32,
packet_size_ptr: *const u32,
connection_out: *mut *mut OdbcConnection,
connection_out: *mut *mut ArrowOdbcConnection,
) -> *mut ArrowOdbcError {
let env = if let Some(env) = ENV.get() {
// Use existing environment
Expand Down Expand Up @@ -78,7 +92,7 @@ pub unsafe extern "C" fn arrow_odbc_connect_with_connection_string(
let dbms_name = try_!(connection.database_management_system_name());
debug!("Database managment system name as reported by ODBC: {dbms_name}");

*connection_out = Box::into_raw(Box::new(OdbcConnection::new(connection)));
*connection_out = Box::into_raw(Box::new(ArrowOdbcConnection::new(connection)));
null_mut()
}

Expand All @@ -98,4 +112,4 @@ unsafe fn append_attribute(
let text = str::from_utf8(bytes).unwrap();
let escaped = escape_attribute_value(text);
*connection_string = format!("{connection_string}{attribute_name}={escaped};").into()
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::sync::OnceLock;
use arrow_odbc::odbc_api::Environment;

pub use self::{
connection::{arrow_odbc_connect_with_connection_string, OdbcConnection},
connection::{arrow_odbc_connection_make, ArrowOdbcConnection},
error::{arrow_odbc_error_free, arrow_odbc_error_message, ArrowOdbcError},
logging::arrow_odbc_log_to_stderr,
reader::{arrow_odbc_reader_free, arrow_odbc_reader_next, ArrowOdbcReader},
Expand Down
7 changes: 3 additions & 4 deletions src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::{
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use arrow_odbc::OdbcReaderBuilder;

use crate::{parameter::ArrowOdbcParameter, try_, ArrowOdbcError, OdbcConnection};
use crate::{parameter::ArrowOdbcParameter, try_, ArrowOdbcError, ArrowOdbcConnection};

pub use self::arrow_odbc_reader::ArrowOdbcReader;

Expand All @@ -29,10 +29,9 @@ pub use self::arrow_odbc_reader::ArrowOdbcReader;
#[no_mangle]
pub unsafe extern "C" fn arrow_odbc_reader_set_connection(
mut reader: NonNull<ArrowOdbcReader>,
connection: NonNull<OdbcConnection>,
mut connection: NonNull<ArrowOdbcConnection>,
) {
let connection = *Box::from_raw(connection.as_ptr());
reader.as_mut().set_connection(connection.take());
reader.as_mut().set_connection(connection.as_mut().take());
}

/// Creates an Arrow ODBC reader instance.
Expand Down
7 changes: 3 additions & 4 deletions src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use arrow_odbc::{
OdbcWriter,
};

use crate::{try_, ArrowOdbcError, OdbcConnection};
use crate::{try_, ArrowOdbcError, ArrowOdbcConnection};

/// Opaque type holding all the state associated with an ODBC writer implementation in Rust. This
/// type also has ownership of the ODBC Connection handle.
Expand Down Expand Up @@ -43,15 +43,14 @@ pub unsafe extern "C" fn arrow_odbc_writer_free(writer: NonNull<ArrowOdbcWriter>
/// is transferred to the caller.
#[no_mangle]
pub unsafe extern "C" fn arrow_odbc_writer_make(
connection: NonNull<OdbcConnection>,
mut connection: NonNull<ArrowOdbcConnection>,
table_buf: *const u8,
table_len: usize,
chunk_size: usize,
schema: *const c_void,
writer_out: *mut *mut ArrowOdbcWriter,
) -> *mut ArrowOdbcError {
let connection = *Box::from_raw(connection.as_ptr());
let connection = connection.take();
let connection = connection.as_mut().take();

let table = slice::from_raw_parts(table_buf, table_len);
let table = str::from_utf8(table).unwrap();
Expand Down

0 comments on commit 172c6eb

Please sign in to comment.