Added support to write to ODBC
jorgecarleitao committed Feb 18, 2022
1 parent 83c47fc commit 840c8a3
Showing 7 changed files with 288 additions and 3 deletions.
2 changes: 1 addition & 1 deletion arrow-odbc-integration-testing/src/lib.rs
@@ -1,7 +1,7 @@
#![cfg(test)]

mod read;
-//mod write;
+mod write;

use arrow2::io::odbc::api::{Connection, Environment, Error as OdbcError};
use lazy_static::lazy_static;
99 changes: 99 additions & 0 deletions arrow-odbc-integration-testing/src/write.rs
@@ -0,0 +1,99 @@
use stdext::function_name;

use arrow2::array::{Array, BinaryArray, BooleanArray, Int32Array, Utf8Array};
use arrow2::chunk::Chunk;
use arrow2::datatypes::{DataType, Field};
use arrow2::error::Result;
use arrow2::io::odbc::api::{Connection, Cursor};
use arrow2::io::odbc::write::{buffer_from_description, infer_descriptions, serialize};

use super::read::read;
use super::{setup_empty_table, ENV, MSSQL};

fn test(
    expected: Chunk<Box<dyn Array>>,
    fields: Vec<Field>,
    type_: &str,
    table_name: &str,
) -> Result<()> {
    let connection = ENV.connect_with_connection_string(MSSQL).unwrap();
    setup_empty_table(&connection, table_name, &[type_]).unwrap();

    let query = &format!("INSERT INTO {table_name} (a) VALUES (?)");
    let mut a = connection.prepare(query).unwrap();

    let mut buffer = buffer_from_description(infer_descriptions(&fields)?, expected.len());

    // write
    buffer.set_num_rows(expected.len());
    let array = &expected.columns()[0];

    serialize(array.as_ref(), &mut buffer.column_mut(0))?;

    a.execute(&buffer).unwrap();

    // read
    let query = format!("SELECT a FROM {table_name} ORDER BY id");
    let chunks = read(&connection, &query)?.1;

    assert_eq!(chunks[0], expected);
    Ok(())
}

#[test]
fn int() -> Result<()> {
    let table_name = function_name!().rsplit_once(':').unwrap().1;
    let table_name = format!("write_{}", table_name);
    let expected = Chunk::new(vec![Box::new(Int32Array::from_slice([1])) as _]);

    test(
        expected,
        vec![Field::new("a", DataType::Int32, false)],
        "INT",
        &table_name,
    )
}

#[test]
fn int_nullable() -> Result<()> {
    let table_name = function_name!().rsplit_once(':').unwrap().1;
    let table_name = format!("write_{}", table_name);
    let expected = Chunk::new(vec![Box::new(Int32Array::from([Some(1), None])) as _]);

    test(
        expected,
        vec![Field::new("a", DataType::Int32, true)],
        "INT",
        &table_name,
    )
}

#[test]
fn bool() -> Result<()> {
    let table_name = function_name!().rsplit_once(':').unwrap().1;
    let table_name = format!("write_{}", table_name);
    let expected = Chunk::new(vec![Box::new(BooleanArray::from_slice([true, false])) as _]);

    test(
        expected,
        vec![Field::new("a", DataType::Boolean, false)],
        "BIT",
        &table_name,
    )
}

#[test]
fn bool_nullable() -> Result<()> {
    let table_name = function_name!().rsplit_once(':').unwrap().1;
    let table_name = format!("write_{}", table_name);
    let expected = Chunk::new(vec![
        Box::new(BooleanArray::from([Some(true), Some(false), None])) as _,
    ]);

    test(
        expected,
        vec![Field::new("a", DataType::Boolean, true)],
        "BIT",
        &table_name,
    )
}
2 changes: 1 addition & 1 deletion src/io/odbc/mod.rs
@@ -2,4 +2,4 @@
pub use odbc_api as api;

pub mod read;
-//pub mod write;
+pub mod write;
2 changes: 1 addition & 1 deletion src/io/odbc/read/schema.rs
@@ -1,5 +1,5 @@
use crate::datatypes::{DataType, Field, TimeUnit};
-use crate::error::{ArrowError, Result};
+use crate::error::Result;

use super::super::api;
use super::super::api::ResultSetMetadata;
24 changes: 24 additions & 0 deletions src/io/odbc/write/mod.rs
@@ -0,0 +1,24 @@
//! APIs to write to ODBC
mod schema;
mod serialize;

use super::api;
pub use schema::infer_descriptions;
pub use serialize::serialize;

/// Creates a [`api::buffers::ColumnarBuffer`] from [`api::ColumnDescription`]s.
pub fn buffer_from_description(
    descriptions: Vec<api::ColumnDescription>,
    max_batch_size: usize,
) -> api::buffers::ColumnarBuffer<api::buffers::AnyColumnBuffer> {
    let descs = descriptions
        .into_iter()
        .map(|description| api::buffers::BufferDescription {
            nullable: description.could_be_nullable(),
            kind: api::buffers::BufferKind::from_data_type(description.data_type).unwrap(),
        });

    let mut buffer = api::buffers::buffer_from_description(max_batch_size, descs);
    buffer.set_num_rows(max_batch_size);
    buffer
}
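For context, a minimal end-to-end sketch of how the new write API composes, modeled on the integration test added in this commit; the connection string, table name (`my_table`), and column name are placeholders rather than part of the commit, and error handling is simplified with `unwrap`:

use arrow2::array::{Array, Int32Array};
use arrow2::chunk::Chunk;
use arrow2::datatypes::{DataType, Field};
use arrow2::error::Result;
use arrow2::io::odbc::api::Environment;
use arrow2::io::odbc::write::{buffer_from_description, infer_descriptions, serialize};

fn write_chunk_example() -> Result<()> {
    // placeholder connection string; any ODBC data source works
    let env = Environment::new().unwrap();
    let connection = env
        .connect_with_connection_string("Driver={...};Server=...;")
        .unwrap();

    let fields = vec![Field::new("a", DataType::Int32, false)];
    let chunk = Chunk::new(vec![
        Box::new(Int32Array::from_slice([1, 2, 3])) as Box<dyn Array>,
    ]);

    // one parameter placeholder per Arrow column
    let mut statement = connection
        .prepare("INSERT INTO my_table (a) VALUES (?)")
        .unwrap();

    // allocate an ODBC columnar buffer sized for this chunk
    let mut buffer = buffer_from_description(infer_descriptions(&fields)?, chunk.len());
    buffer.set_num_rows(chunk.len());

    // serialize each Arrow column into its ODBC column buffer
    for (i, column) in chunk.columns().iter().enumerate() {
        serialize(column.as_ref(), &mut buffer.column_mut(i))?;
    }

    // execute the prepared insert with the bound buffer
    statement.execute(&buffer).unwrap();
    Ok(())
}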
36 changes: 36 additions & 0 deletions src/io/odbc/write/schema.rs
@@ -0,0 +1,36 @@
use super::super::api;

use crate::datatypes::{DataType, Field};
use crate::error::{ArrowError, Result};

/// Infers the [`ColumnDescription`] from the fields
pub fn infer_descriptions(fields: &[Field]) -> Result<Vec<api::ColumnDescription>> {
    fields
        .iter()
        .map(|field| {
            let nullability = if field.is_nullable {
                api::Nullability::Nullable
            } else {
                api::Nullability::NoNulls
            };
            let data_type = data_type_to(field.data_type())?;
            Ok(api::ColumnDescription {
                name: api::U16String::from_str(&field.name).into_vec(),
                nullability,
                data_type,
            })
        })
        .collect()
}

fn data_type_to(data_type: &DataType) -> Result<api::DataType> {
    Ok(match data_type {
        DataType::Boolean => api::DataType::Bit,
        DataType::Int16 => api::DataType::SmallInt,
        DataType::Int32 => api::DataType::Integer,
        DataType::Float32 => api::DataType::Float { precision: 24 },
        DataType::Float64 => api::DataType::Float { precision: 53 },
        DataType::FixedSizeBinary(length) => api::DataType::Varbinary { length: *length },
        other => return Err(ArrowError::nyi(format!("{other:?} to ODBC"))),
    })
}
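As a quick illustration of the mapping above, a sketch (not part of the commit) of what `infer_descriptions` yields for a nullable `Int32` field; the expected variants follow directly from `data_type_to`:

use arrow2::datatypes::{DataType, Field};
use arrow2::io::odbc::api;
use arrow2::io::odbc::write::infer_descriptions;

fn infer_example() -> arrow2::error::Result<()> {
    let fields = vec![Field::new("a", DataType::Int32, true)];
    let descriptions = infer_descriptions(&fields)?;

    // a nullable Arrow Int32 field becomes an ODBC `Integer` column that allows NULLs
    assert!(matches!(descriptions[0].data_type, api::DataType::Integer));
    assert!(matches!(descriptions[0].nullability, api::Nullability::Nullable));
    Ok(())
}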
126 changes: 126 additions & 0 deletions src/io/odbc/write/serialize.rs
@@ -0,0 +1,126 @@
use api::buffers::BinColumnWriter;

use crate::array::{Array, BooleanArray, FixedSizeBinaryArray, PrimitiveArray};
use crate::bitmap::Bitmap;
use crate::datatypes::DataType;
use crate::error::{ArrowError, Result};
use crate::types::NativeType;

use super::super::api;
use super::super::api::buffers::NullableSliceMut;

/// Serializes an [`Array`] to [`api::buffers::AnyColumnViewMut`]
/// This operation is CPU-bound
pub fn serialize(array: &dyn Array, column: &mut api::buffers::AnyColumnViewMut) -> Result<()> {
    match array.data_type() {
        DataType::Boolean => {
            if let api::buffers::AnyColumnViewMut::Bit(values) = column {
                Ok(bool(array.as_any().downcast_ref().unwrap(), values))
            } else if let api::buffers::AnyColumnViewMut::NullableBit(values) = column {
                Ok(bool_optional(
                    array.as_any().downcast_ref().unwrap(),
                    values,
                ))
            } else {
                Err(ArrowError::nyi("serialize bool to non-bool ODBC"))
            }
        }
        DataType::Int16 => {
            if let api::buffers::AnyColumnViewMut::I16(values) = column {
                Ok(primitive(array.as_any().downcast_ref().unwrap(), values))
            } else if let api::buffers::AnyColumnViewMut::NullableI16(values) = column {
                Ok(primitive_optional(
                    array.as_any().downcast_ref().unwrap(),
                    values,
                ))
            } else {
                Err(ArrowError::nyi("serialize i16 to non-i16 ODBC"))
            }
        }
        DataType::Int32 => {
            if let api::buffers::AnyColumnViewMut::I32(values) = column {
                Ok(primitive(array.as_any().downcast_ref().unwrap(), values))
            } else if let api::buffers::AnyColumnViewMut::NullableI32(values) = column {
                Ok(primitive_optional(
                    array.as_any().downcast_ref().unwrap(),
                    values,
                ))
            } else {
                Err(ArrowError::nyi("serialize i32 to non-i32 ODBC"))
            }
        }
        DataType::Float32 => {
            if let api::buffers::AnyColumnViewMut::F32(values) = column {
                Ok(primitive(array.as_any().downcast_ref().unwrap(), values))
            } else if let api::buffers::AnyColumnViewMut::NullableF32(values) = column {
                Ok(primitive_optional(
                    array.as_any().downcast_ref().unwrap(),
                    values,
                ))
            } else {
                Err(ArrowError::nyi("serialize f32 to non-f32 ODBC"))
            }
        }
        DataType::Float64 => {
            if let api::buffers::AnyColumnViewMut::F64(values) = column {
                Ok(primitive(array.as_any().downcast_ref().unwrap(), values))
            } else if let api::buffers::AnyColumnViewMut::NullableF64(values) = column {
                Ok(primitive_optional(
                    array.as_any().downcast_ref().unwrap(),
                    values,
                ))
            } else {
                Err(ArrowError::nyi("serialize f64 to non-f64 ODBC"))
            }
        }
        DataType::FixedSizeBinary(_) => {
            if let api::buffers::AnyColumnViewMut::Binary(values) = column {
                Ok(binary(array.as_any().downcast_ref().unwrap(), values))
            } else {
                Err(ArrowError::nyi("serialize fixed-size binary to non-binary ODBC"))
            }
        }
        other => Err(ArrowError::nyi(format!("{other:?} to ODBC"))),
    }
}

fn bool(array: &BooleanArray, values: &mut [api::Bit]) {
    array
        .values()
        .iter()
        .zip(values.iter_mut())
        .for_each(|(from, to)| *to = api::Bit(from as u8));
}

fn bool_optional(array: &BooleanArray, values: &mut NullableSliceMut<api::Bit>) {
    array
        .values()
        .iter()
        .zip(values.values().iter_mut())
        .for_each(|(from, to)| *to = api::Bit(from as u8));
    write_validity(array.validity(), values.indicators());
}

fn primitive<T: NativeType>(array: &PrimitiveArray<T>, values: &mut [T]) {
    values.copy_from_slice(array.values())
}

fn write_validity(validity: Option<&Bitmap>, indicators: &mut [isize]) {
    if let Some(validity) = validity {
        indicators
            .iter_mut()
            .zip(validity.iter())
            .for_each(|(indicator, is_valid)| *indicator = if is_valid { 0 } else { -1 })
    } else {
        indicators.iter_mut().for_each(|x| *x = 0)
    }
}

fn primitive_optional<T: NativeType>(array: &PrimitiveArray<T>, values: &mut NullableSliceMut<T>) {
    values.values().copy_from_slice(array.values());
    write_validity(array.validity(), values.indicators());
}

fn binary(array: &FixedSizeBinaryArray, writer: &mut BinColumnWriter) {
    writer.write(array.iter())
}
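
For reference, a small sketch (not part of the commit) of the indicator convention implemented by `write_validity` above: 0 marks a present value and -1 marks SQL NULL; the input array is hypothetical:

use arrow2::array::Int32Array;

fn indicators_example() {
    // hypothetical input: Some(1), None, Some(3)
    let array = Int32Array::from([Some(1), None, Some(3)]);

    // same rule as `write_validity`: 0 for a valid slot, -1 for a null slot
    let indicators: Vec<isize> = match array.validity() {
        Some(validity) => validity
            .iter()
            .map(|is_valid| if is_valid { 0 } else { -1 })
            .collect(),
        None => vec![0; array.len()],
    };

    assert_eq!(indicators, vec![0isize, -1, 0]);
}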
