Skip to content

Commit

Permalink
Port to new Arrow2 casting support.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed May 12, 2023
1 parent 6011c52 commit 00c7a0b
Show file tree
Hide file tree
Showing 12 changed files with 206 additions and 108 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ prettytable-rs = "^0.10"
rand = "^0.8"

[dependencies.arrow2]
branch = "clark/expand-casting-support"
features = ["compute", "io_ipc"]
version = "0.17.0"
git = "https://github.com/Eventual-Inc/arrow2"
package = "arrow2"
version = "0.17"

[dependencies.bincode]
version = "1.3.3"
Expand All @@ -15,6 +18,9 @@ version = "1.3.3"
features = ["serde"]
version = "1.9.2"

[dependencies.lazy_static]
version = "1.4.0"

[dependencies.num-traits]
version = "0.2"

Expand Down
2 changes: 2 additions & 0 deletions daft/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
metadata = arrow_type.__arrow_ext_serialize__().decode()
except AttributeError:
metadata = None
if metadata == "":
metadata = None
return cls.extension(
name,
cls.from_arrow_type(arrow_type.storage_type),
Expand Down
22 changes: 19 additions & 3 deletions src/array/ops/arrow2/sort/primitive/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ mod tests {
use super::*;

use arrow2::array::ord;
use arrow2::array::Array;
use arrow2::array::PrimitiveArray;
use arrow2::datatypes::DataType;

Expand All @@ -177,13 +178,28 @@ mod tests {
) where
T: NativeType + std::cmp::Ord,
{
let input = PrimitiveArray::<T>::from(data).to(data_type.clone());
let expected = PrimitiveArray::<T>::from(expected_data).to(data_type.clone());
let input = PrimitiveArray::<T>::from(data)
.to(data_type.clone())
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap()
.clone();
let expected = PrimitiveArray::<T>::from(expected_data)
.to(data_type.clone())
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap()
.clone();
let output = sort_by(&input, ord::total_cmp, &options, None);
assert_eq!(expected, output);

// with limit
let expected = PrimitiveArray::<T>::from(&expected_data[..3]).to(data_type);
let expected = PrimitiveArray::<T>::from(&expected_data[..3])
.to(data_type)
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap()
.clone();
let output = sort_by(&input, ord::total_cmp, &options, Some(3));
assert_eq!(expected, output)
}
Expand Down
4 changes: 4 additions & 0 deletions src/array/pseudo_arrow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,4 +313,8 @@ impl<T: Send + Sync + Clone + 'static> Array for PseudoArrowArray<T> {
.map(|x| x.unset_bits())
.unwrap_or(0)
}

/// `Array::to_type` (arrow2 0.17+) is supposed to return this array reinterpreted
/// under a new logical `DataType`.
///
/// `PseudoArrowArray` is generic over an arbitrary `T` (not a real Arrow buffer
/// layout), so there is no meaningful alternate Arrow logical type to cast to;
/// any call to this method is a programming error.
///
/// # Panics
/// Always panics — this operation is intentionally unsupported.
fn to_type(&self, _: DataType) -> Box<dyn Array> {
    // `unimplemented!` is the idiomatic stub macro for an unsupported trait
    // method (vs. a bare `panic!` with an ad-hoc message).
    unimplemented!("PseudoArrowArray does not support to_type")
}
}
12 changes: 8 additions & 4 deletions src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ use pyo3::types::PyList;
use pyo3::{PyAny, PyObject, PyResult, Python};

use crate::{
error::DaftResult, schema::SchemaRef, series::Series, table::Table,
utils::arrow::cast_array_if_needed,
error::DaftResult,
schema::SchemaRef,
series::Series,
table::Table,
utils::arrow::{cast_array_for_daft_if_needed, cast_array_from_daft_if_needed},
};

pub type ArrayRef = Box<dyn Array>;
Expand Down Expand Up @@ -71,7 +74,7 @@ pub fn record_batches_to_table(
.into_iter()
.enumerate()
.map(|(i, c)| {
let c = cast_array_if_needed(c);
let c = cast_array_for_daft_if_needed(c);
Series::try_from((names.get(i).unwrap().as_str(), c))
})
.collect::<DaftResult<Vec<_>>>()?;
Expand Down Expand Up @@ -110,7 +113,8 @@ pub fn table_to_record_batch(table: &Table, py: Python, pyarrow: &PyModule) -> P
for i in 0..table.num_columns() {
let s = table.get_column_by_index(i)?;
let arrow_array = s.array().data();
let py_array = to_py_array(arrow_array.to_boxed(), py, pyarrow)?;
let arrow_array = cast_array_from_daft_if_needed(arrow_array.to_boxed());
let py_array = to_py_array(arrow_array, py, pyarrow)?;
arrays.push(py_array);
names.push(s.name().to_string());
}
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#![feature(hash_raw_entry)]
#[macro_use]
extern crate lazy_static;

mod array;
mod datatypes;
mod dsl;
Expand Down
5 changes: 3 additions & 2 deletions src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
datatypes::{DataType, Field, PythonType, UInt64Type},
ffi,
series::{self, Series},
utils::arrow::cast_array_if_needed,
utils::arrow::{cast_array_for_daft_if_needed, cast_array_from_daft_if_needed},
};

use super::datatype::PyDataType;
Expand All @@ -24,7 +24,7 @@ impl PySeries {
#[staticmethod]
pub fn from_arrow(name: &str, pyarrow_array: &PyAny) -> PyResult<Self> {
let arrow_array = ffi::array_to_rust(pyarrow_array)?;
let arrow_array = cast_array_if_needed(arrow_array.to_boxed());
let arrow_array = cast_array_for_daft_if_needed(arrow_array.to_boxed());
let series = series::Series::try_from((name, arrow_array))?;
Ok(series.into())
}
Expand All @@ -51,6 +51,7 @@ impl PySeries {

pub fn to_arrow(&self) -> PyResult<PyObject> {
let arrow_array = self.series.array().data().to_boxed();
let arrow_array = cast_array_from_daft_if_needed(arrow_array);
Python::with_gil(|py| {
let pyarrow = py.import("pyarrow")?;
ffi::to_py_array(arrow_array, py, pyarrow)
Expand Down
201 changes: 117 additions & 84 deletions src/utils/arrow.rs
Original file line number Diff line number Diff line change
@@ -1,105 +1,138 @@
use std::collections::HashMap;
use std::sync::Mutex;

use arrow2::compute::cast;

pub fn cast_array_if_needed(
arrow_array: Box<dyn arrow2::array::Array>,
) -> Box<dyn arrow2::array::Array> {
match arrow_array.data_type() {
arrow2::datatypes::DataType::Utf8 => {
cast::utf8_to_large_utf8(arrow_array.as_any().downcast_ref().unwrap()).boxed()
}
arrow2::datatypes::DataType::Binary => cast::binary_to_large_binary(
arrow_array.as_any().downcast_ref().unwrap(),
arrow2::datatypes::DataType::LargeBinary,
)
.boxed(),
// TODO(Clark): Refactor to GILOnceCell in order to avoid deadlock between the below mutex and the Python GIL.
// Process-global registry mapping an Arrow extension type's *name* to the original
// Extension DataType that was seen before coercion to a Daft-compatible type
// (populated by `coerce_to_daft_compatible_type`, read back by
// `coerce_from_daft_compatible_type` so the exact extension dtype can be restored
// when handing arrays back to Arrow/pyarrow).
lazy_static! {
static ref REGISTRY: Mutex<HashMap<std::string::String, arrow2::datatypes::DataType>> =
Mutex::new(HashMap::new());
}

fn coerce_to_daft_compatible_type(
dtype: &arrow2::datatypes::DataType,
) -> Option<arrow2::datatypes::DataType> {
match dtype {
arrow2::datatypes::DataType::Utf8 => Some(arrow2::datatypes::DataType::LargeUtf8),
arrow2::datatypes::DataType::Binary => Some(arrow2::datatypes::DataType::LargeBinary),
arrow2::datatypes::DataType::List(field) => {
let array = arrow_array
.as_any()
.downcast_ref::<arrow2::array::ListArray<i32>>()
.unwrap();
let new_values = cast_array_if_needed(array.values().clone());
let offsets = array.offsets().into();
arrow2::array::ListArray::<i64>::new(
arrow2::datatypes::DataType::LargeList(Box::new(arrow2::datatypes::Field::new(
field.name.clone(),
new_values.data_type().clone(),
field.is_nullable,
))),
offsets,
new_values,
arrow_array.validity().cloned(),
)
.boxed()
let new_field = match coerce_to_daft_compatible_type(field.data_type()) {
Some(new_inner_dtype) => Box::new(
arrow2::datatypes::Field::new(
field.name.clone(),
new_inner_dtype,
field.is_nullable,
)
.with_metadata(field.metadata.clone()),
),
None => field.clone(),
};
Some(arrow2::datatypes::DataType::LargeList(new_field))
}
arrow2::datatypes::DataType::LargeList(field) => {
// Types nested within LargeList may need casting.
let array = arrow_array
.as_any()
.downcast_ref::<arrow2::array::ListArray<i64>>()
.unwrap();
let new_values = cast_array_if_needed(array.values().clone());
if new_values.data_type() == array.values().data_type() {
return arrow_array;
}
arrow2::array::ListArray::<i64>::new(
arrow2::datatypes::DataType::LargeList(Box::new(arrow2::datatypes::Field::new(
let new_inner_dtype = coerce_to_daft_compatible_type(field.data_type())?;
Some(arrow2::datatypes::DataType::LargeList(Box::new(
arrow2::datatypes::Field::new(
field.name.clone(),
new_values.data_type().clone(),
new_inner_dtype,
field.is_nullable,
))),
array.offsets().clone(),
new_values,
arrow_array.validity().cloned(),
)
.boxed()
)
.with_metadata(field.metadata.clone()),
)))
}
arrow2::datatypes::DataType::FixedSizeList(field, size) => {
// Types nested within FixedSizeList may need casting.
let array = arrow_array
.as_any()
.downcast_ref::<arrow2::array::FixedSizeListArray>()
.unwrap();
let new_values = cast_array_if_needed(array.values().clone());
if new_values.data_type() == array.values().data_type() {
return arrow_array;
}
arrow2::array::FixedSizeListArray::new(
arrow2::datatypes::DataType::FixedSizeList(
Box::new(arrow2::datatypes::Field::new(
let new_inner_dtype = coerce_to_daft_compatible_type(field.data_type())?;
Some(arrow2::datatypes::DataType::FixedSizeList(
Box::new(
arrow2::datatypes::Field::new(
field.name.clone(),
new_values.data_type().clone(),
new_inner_dtype,
field.is_nullable,
)),
*size,
)
.with_metadata(field.metadata.clone()),
),
new_values,
arrow_array.validity().cloned(),
)
.boxed()
*size,
))
}
arrow2::datatypes::DataType::Struct(fields) => {
let new_arrays = arrow_array
.as_any()
.downcast_ref::<arrow2::array::StructArray>()
.unwrap()
.values()
.iter()
.map(|field_arr| cast_array_if_needed(field_arr.clone()))
.collect::<Vec<Box<dyn arrow2::array::Array>>>();
let new_fields = fields
.iter()
.zip(new_arrays.iter().map(|arr| arr.data_type().clone()))
.map(|(field, dtype)| {
arrow2::datatypes::Field::new(field.name.clone(), dtype, field.is_nullable)
})
.collect();
Box::new(arrow2::array::StructArray::new(
arrow2::datatypes::DataType::Struct(new_fields),
new_arrays,
arrow_array.validity().cloned(),
.map(
|field| match coerce_to_daft_compatible_type(field.data_type()) {
Some(new_inner_dtype) => arrow2::datatypes::Field::new(
field.name.clone(),
new_inner_dtype,
field.is_nullable,
)
.with_metadata(field.metadata.clone()),
None => field.clone(),
},
)
.collect::<Vec<arrow2::datatypes::Field>>();
if &new_fields == fields {
None
} else {
Some(arrow2::datatypes::DataType::Struct(new_fields))
}
}
arrow2::datatypes::DataType::Extension(name, inner, metadata) => {
let new_inner_dtype = coerce_to_daft_compatible_type(inner.as_ref())?;
REGISTRY.lock().unwrap().insert(name.clone(), dtype.clone());
Some(arrow2::datatypes::DataType::Extension(
name.clone(),
Box::new(new_inner_dtype),
metadata.clone(),
))
}
_ => arrow_array,
_ => None,
}
}

/// Casts an arbitrary Arrow array into a Daft-compatible one, if a coercion is
/// required; otherwise returns the array unchanged.
///
/// The target type is decided by `coerce_to_daft_compatible_type`; the actual
/// conversion is delegated to arrow2's `cast::cast`.
pub fn cast_array_for_daft_if_needed(
    arrow_array: Box<dyn arrow2::array::Array>,
) -> Box<dyn arrow2::array::Array> {
    // Guard clause: no coercion needed — pass the array through untouched.
    let Some(target_dtype) = coerce_to_daft_compatible_type(arrow_array.data_type()) else {
        return arrow_array;
    };
    // `wrapped: true` so unsupported casts surface as errors rather than silently
    // producing nulls; `partial: false` so the whole array must convert.
    let options = cast::CastOptions {
        wrapped: true,
        partial: false,
    };
    cast::cast(arrow_array.as_ref(), &target_dtype, options).unwrap()
}

/// Maps a Daft-compatible dtype back to the original Arrow dtype it was coerced
/// from, if one was recorded.
///
/// Only Extension types participate: `coerce_to_daft_compatible_type` stores the
/// original Extension `DataType` in `REGISTRY` keyed by extension name, and this
/// function looks it up so the round trip back to Arrow restores the exact dtype.
/// Returns `None` when no coercion was recorded for this dtype.
fn coerce_from_daft_compatible_type(
    dtype: &arrow2::datatypes::DataType,
) -> Option<arrow2::datatypes::DataType> {
    match dtype {
        // Lock once and look up once: `get(..).cloned()` yields `None` when the
        // name is absent, replacing the previous guard-based form that acquired
        // the mutex twice (`contains_key` in the guard, `get` in the arm).
        arrow2::datatypes::DataType::Extension(name, _, _) => {
            REGISTRY.lock().unwrap().get(name).cloned()
        }
        _ => None,
    }
}

/// Casts a Daft-held Arrow array back to its original (pre-coercion) Arrow type,
/// if a coercion was recorded; otherwise returns the array unchanged.
///
/// The target type is decided by `coerce_from_daft_compatible_type`; the actual
/// conversion is delegated to arrow2's `cast::cast`.
pub fn cast_array_from_daft_if_needed(
    arrow_array: Box<dyn arrow2::array::Array>,
) -> Box<dyn arrow2::array::Array> {
    // Guard clause: nothing was recorded for this dtype — hand the array back as-is.
    let Some(original_dtype) = coerce_from_daft_compatible_type(arrow_array.data_type()) else {
        return arrow_array;
    };
    // Same cast options as the forward direction: fail loudly (`wrapped: true`)
    // and require the full array to convert (`partial: false`).
    let options = cast::CastOptions {
        wrapped: true,
        partial: false,
    };
    cast::cast(arrow_array.as_ref(), &original_dtype, options).unwrap()
}

Expand Down
Loading

0 comments on commit 00c7a0b

Please sign in to comment.