This repository has been archived by the owner on Feb 18, 2024. It is now read-only.
commit 840c8a3 (1 parent: 83c47fc)
Showing 7 changed files with 288 additions and 3 deletions.
@@ -0,0 +1,99 @@
use stdext::function_name;

use arrow2::array::{Array, BinaryArray, BooleanArray, Int32Array, Utf8Array};
use arrow2::chunk::Chunk;
use arrow2::datatypes::{DataType, Field};
use arrow2::error::Result;
use arrow2::io::odbc::api::{Connection, Cursor};
use arrow2::io::odbc::write::{buffer_from_description, infer_descriptions, serialize};

use super::read::read;
use super::{setup_empty_table, ENV, MSSQL};

fn test(
    expected: Chunk<Box<dyn Array>>,
    fields: Vec<Field>,
    type_: &str,
    table_name: &str,
) -> Result<()> {
    let connection = ENV.connect_with_connection_string(MSSQL).unwrap();
    setup_empty_table(&connection, table_name, &[type_]).unwrap();

    let query = &format!("INSERT INTO {table_name} (a) VALUES (?)");
    let mut a = connection.prepare(query).unwrap();

    let mut buffer = buffer_from_description(infer_descriptions(&fields)?, expected.len());

    // write
    buffer.set_num_rows(expected.len());
    let array = &expected.columns()[0];

    serialize(array.as_ref(), &mut buffer.column_mut(0))?;

    a.execute(&buffer).unwrap();

    // read
    let query = format!("SELECT a FROM {table_name} ORDER BY id");
    let chunks = read(&connection, &query)?.1;

    assert_eq!(chunks[0], expected);
    Ok(())
}

#[test]
fn int() -> Result<()> {
    let table_name = function_name!().rsplit_once(':').unwrap().1;
    let table_name = format!("write_{}", table_name);
    let expected = Chunk::new(vec![Box::new(Int32Array::from_slice([1])) as _]);

    test(
        expected,
        vec![Field::new("a", DataType::Int32, false)],
        "INT",
        &table_name,
    )
}

#[test]
fn int_nullable() -> Result<()> {
    let table_name = function_name!().rsplit_once(':').unwrap().1;
    let table_name = format!("write_{}", table_name);
    let expected = Chunk::new(vec![Box::new(Int32Array::from([Some(1), None])) as _]);

    test(
        expected,
        vec![Field::new("a", DataType::Int32, true)],
        "INT",
        &table_name,
    )
}

#[test]
fn bool() -> Result<()> {
    let table_name = function_name!().rsplit_once(':').unwrap().1;
    let table_name = format!("write_{}", table_name);
    let expected = Chunk::new(vec![Box::new(BooleanArray::from_slice([true, false])) as _]);

    test(
        expected,
        vec![Field::new("a", DataType::Boolean, false)],
        "BIT",
        &table_name,
    )
}

#[test]
fn bool_nullable() -> Result<()> {
    let table_name = function_name!().rsplit_once(':').unwrap().1;
    let table_name = format!("write_{}", table_name);
    let expected = Chunk::new(vec![
        Box::new(BooleanArray::from([Some(true), Some(false), None])) as _,
    ]);

    test(
        expected,
        vec![Field::new("a", DataType::Boolean, true)],
        "BIT",
        &table_name,
    )
}
@@ -2,4 +2,4 @@
 pub use odbc_api as api;

 pub mod read;
-//pub mod write;
+pub mod write;
@@ -0,0 +1,24 @@
//! APIs to write to ODBC
mod schema;
mod serialize;

use super::api;
pub use schema::infer_descriptions;
pub use serialize::serialize;

/// Creates a [`api::buffers::ColumnarBuffer`] from [`api::ColumnDescription`]s.
pub fn buffer_from_description(
    descriptions: Vec<api::ColumnDescription>,
    max_batch_size: usize,
) -> api::buffers::ColumnarBuffer<api::buffers::AnyColumnBuffer> {
    let descs = descriptions
        .into_iter()
        .map(|description| api::buffers::BufferDescription {
            nullable: description.could_be_nullable(),
            kind: api::buffers::BufferKind::from_data_type(description.data_type).unwrap(),
        });

    let mut buffer = api::buffers::buffer_from_description(max_batch_size, descs);
    buffer.set_num_rows(max_batch_size);
    buffer
}
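
The buffer created above is the bridge between Arrow arrays and the ODBC parameter bindings exercised by the test at the top of this commit. As a rough sketch of how the pieces might fit together outside the test harness (the connection string, table name, and data below are placeholders, not part of this commit):

use arrow2::array::{Array, Int32Array};
use arrow2::chunk::Chunk;
use arrow2::datatypes::{DataType, Field};
use arrow2::error::Result;
use arrow2::io::odbc::api;
use arrow2::io::odbc::write::{buffer_from_description, infer_descriptions, serialize};

fn write_one_batch() -> Result<()> {
    // hypothetical data source; any valid ODBC connection string would do
    let env = api::Environment::new().unwrap();
    let connection = env
        .connect_with_connection_string("Driver={...};Server=...;UID=...;PWD=...")
        .unwrap();

    let fields = vec![Field::new("a", DataType::Int32, true)];
    let chunk = Chunk::new(vec![Box::new(Int32Array::from([Some(1), None])) as Box<dyn Array>]);

    // one buffer per batch, one bound parameter per column
    let mut buffer = buffer_from_description(infer_descriptions(&fields)?, chunk.len());
    buffer.set_num_rows(chunk.len());
    for (i, column) in chunk.columns().iter().enumerate() {
        serialize(column.as_ref(), &mut buffer.column_mut(i))?;
    }

    // `my_table` is a placeholder; the statement mirrors the test above
    let mut prepared = connection
        .prepare("INSERT INTO my_table (a) VALUES (?)")
        .unwrap();
    prepared.execute(&buffer).unwrap();
    Ok(())
}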
@@ -0,0 +1,36 @@
use super::super::api;

use crate::datatypes::{DataType, Field};
use crate::error::{ArrowError, Result};

/// Infers the [`ColumnDescription`] from the fields
pub fn infer_descriptions(fields: &[Field]) -> Result<Vec<api::ColumnDescription>> {
    fields
        .iter()
        .map(|field| {
            let nullability = if field.is_nullable {
                api::Nullability::Nullable
            } else {
                api::Nullability::NoNulls
            };
            let data_type = data_type_to(field.data_type())?;
            Ok(api::ColumnDescription {
                name: api::U16String::from_str(&field.name).into_vec(),
                nullability,
                data_type,
            })
        })
        .collect()
}

fn data_type_to(data_type: &DataType) -> Result<api::DataType> {
    Ok(match data_type {
        DataType::Boolean => api::DataType::Bit,
        DataType::Int16 => api::DataType::SmallInt,
        DataType::Int32 => api::DataType::Integer,
        DataType::Float32 => api::DataType::Float { precision: 24 },
        DataType::Float64 => api::DataType::Float { precision: 53 },
        DataType::FixedSizeBinary(length) => api::DataType::Varbinary { length: *length },
        other => return Err(ArrowError::nyi(format!("{other:?} to ODBC"))),
    })
}
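
To make the mapping concrete, here is a hedged sketch of what infer_descriptions produces for a non-nullable Boolean and a nullable Float64 field; the field names and assertions are illustrative, assuming the types shown in this commit:

use arrow2::datatypes::{DataType, Field};
use arrow2::error::Result;
use arrow2::io::odbc::api;
use arrow2::io::odbc::write::infer_descriptions;

fn mapping_example() -> Result<()> {
    let fields = vec![
        Field::new("flag", DataType::Boolean, false),
        Field::new("value", DataType::Float64, true),
    ];
    let descriptions = infer_descriptions(&fields)?;

    // Boolean, NOT NULL -> BIT with Nullability::NoNulls
    assert!(matches!(descriptions[0].data_type, api::DataType::Bit));
    assert!(!descriptions[0].could_be_nullable());

    // Float64, nullable -> FLOAT(53) with Nullability::Nullable
    assert!(matches!(
        descriptions[1].data_type,
        api::DataType::Float { precision: 53 }
    ));
    assert!(descriptions[1].could_be_nullable());
    Ok(())
}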
@@ -0,0 +1,126 @@
use api::buffers::BinColumnWriter;

use crate::array::{Array, BooleanArray, FixedSizeBinaryArray, PrimitiveArray};
use crate::bitmap::Bitmap;
use crate::datatypes::DataType;
use crate::error::{ArrowError, Result};
use crate::types::NativeType;

use super::super::api;
use super::super::api::buffers::NullableSliceMut;

/// Serializes an [`Array`] to [`api::buffers::AnyColumnViewMut`].
/// This operation is CPU-bound.
pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut) -> Result<()> {
    match array.data_type() {
        DataType::Boolean => {
            if let api::buffers::AnyColumnViewMut::Bit(values) = column {
                Ok(bool(array.as_any().downcast_ref().unwrap(), values))
            } else if let api::buffers::AnyColumnViewMut::NullableBit(values) = column {
                Ok(bool_optional(
                    array.as_any().downcast_ref().unwrap(),
                    values,
                ))
            } else {
                Err(ArrowError::nyi("serialize bool to non-bool ODBC"))
            }
        }
        DataType::Int16 => {
            if let api::buffers::AnyColumnViewMut::I16(values) = column {
                Ok(primitive(array.as_any().downcast_ref().unwrap(), values))
            } else if let api::buffers::AnyColumnViewMut::NullableI16(values) = column {
                Ok(primitive_optional(
                    array.as_any().downcast_ref().unwrap(),
                    values,
                ))
            } else {
                Err(ArrowError::nyi("serialize i16 to non-i16 ODBC"))
            }
        }
        DataType::Int32 => {
            if let api::buffers::AnyColumnViewMut::I32(values) = column {
                Ok(primitive(array.as_any().downcast_ref().unwrap(), values))
            } else if let api::buffers::AnyColumnViewMut::NullableI32(values) = column {
                Ok(primitive_optional(
                    array.as_any().downcast_ref().unwrap(),
                    values,
                ))
            } else {
                Err(ArrowError::nyi("serialize i32 to non-i32 ODBC"))
            }
        }
        DataType::Float32 => {
            if let api::buffers::AnyColumnViewMut::F32(values) = column {
                Ok(primitive(array.as_any().downcast_ref().unwrap(), values))
            } else if let api::buffers::AnyColumnViewMut::NullableF32(values) = column {
                Ok(primitive_optional(
                    array.as_any().downcast_ref().unwrap(),
                    values,
                ))
            } else {
                Err(ArrowError::nyi("serialize f32 to non-f32 ODBC"))
            }
        }
        DataType::Float64 => {
            if let api::buffers::AnyColumnViewMut::F64(values) = column {
                Ok(primitive(array.as_any().downcast_ref().unwrap(), values))
            } else if let api::buffers::AnyColumnViewMut::NullableF64(values) = column {
                Ok(primitive_optional(
                    array.as_any().downcast_ref().unwrap(),
                    values,
                ))
            } else {
                Err(ArrowError::nyi("serialize f64 to non-f64 ODBC"))
            }
        }
        DataType::FixedSizeBinary(_) => {
            if let api::buffers::AnyColumnViewMut::Binary(values) = column {
                Ok(binary(array.as_any().downcast_ref().unwrap(), values))
            } else {
                Err(ArrowError::nyi("serialize fixed-size binary to non-binary ODBC"))
            }
        }
        other => Err(ArrowError::nyi(format!("{other:?} to ODBC"))),
    }
}

fn bool(array: &BooleanArray, values: &mut [api::Bit]) {
    array
        .values()
        .iter()
        .zip(values.iter_mut())
        .for_each(|(from, to)| *to = api::Bit(from as u8));
}

fn bool_optional(array: &BooleanArray, values: &mut NullableSliceMut<api::Bit>) {
    array
        .values()
        .iter()
        .zip(values.values().iter_mut())
        .for_each(|(from, to)| *to = api::Bit(from as u8));
    write_validity(array.validity(), values.indicators());
}

fn primitive<T: NativeType>(array: &PrimitiveArray<T>, values: &mut [T]) {
    values.copy_from_slice(array.values())
}

/// Writes the array's validity as ODBC indicators: 0 for a present value, -1 for a NULL.
fn write_validity(validity: Option<&Bitmap>, indicators: &mut [isize]) {
    if let Some(validity) = validity {
        indicators
            .iter_mut()
            .zip(validity.iter())
            .for_each(|(indicator, is_valid)| *indicator = if is_valid { 0 } else { -1 })
    } else {
        indicators.iter_mut().for_each(|x| *x = 0)
    }
}

fn primitive_optional<T: NativeType>(array: &PrimitiveArray<T>, values: &mut NullableSliceMut<T>) {
    values.values().copy_from_slice(array.values());
    write_validity(array.validity(), values.indicators());
}

fn binary(array: &FixedSizeBinaryArray, writer: &mut BinColumnWriter) {
    writer.write(array.iter())
}
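
The indicator convention in write_validity follows ODBC: 0 marks a present value and -1 (SQL_NULL_DATA) marks a NULL. A minimal sketch of serializing a nullable Int32 column, assuming the buffer helpers added in this commit (the values are made up for the example):

use arrow2::array::Int32Array;
use arrow2::datatypes::{DataType, Field};
use arrow2::error::Result;
use arrow2::io::odbc::write::{buffer_from_description, infer_descriptions, serialize};

fn serialize_example() -> Result<()> {
    let array = Int32Array::from([Some(7), None, Some(9)]);
    let fields = vec![Field::new("a", DataType::Int32, true)];

    // a nullable INTEGER column is backed by a NullableI32 buffer,
    // so `serialize` takes the `primitive_optional` path above
    let mut buffer = buffer_from_description(infer_descriptions(&fields)?, array.len());
    buffer.set_num_rows(array.len());
    serialize(&array, &mut buffer.column_mut(0))?;

    // rows 0 and 2 carry 7 and 9 with indicator 0; row 1 gets indicator -1 (NULL)
    Ok(())
}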