Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Every other arm, limiting pain as much as possible
Browse files Browse the repository at this point in the history
  • Loading branch information
teh-cmc committed Apr 13, 2023
1 parent c1e58ad commit d274275
Show file tree
Hide file tree
Showing 52 changed files with 460 additions and 278 deletions.
4 changes: 2 additions & 2 deletions src/array/dictionary/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::hint::unreachable_unchecked;
use std::{hint::unreachable_unchecked, sync::Arc};

use crate::{
bitmap::{
Expand Down Expand Up @@ -290,7 +290,7 @@ impl<K: DictionaryKey> DictionaryArray<K> {
}

pub(crate) fn default_data_type(values_datatype: DataType) -> DataType {
DataType::Dictionary(K::KEY_TYPE, Box::new(values_datatype), false)
DataType::Dictionary(K::KEY_TYPE, Arc::new(values_datatype), false)
}

/// Slices this [`DictionaryArray`].
Expand Down
4 changes: 2 additions & 2 deletions src/array/dictionary/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl<K: DictionaryKey, M: MutableArray> From<M> for MutableDictionaryArray<K, M>
Self {
data_type: DataType::Dictionary(
K::KEY_TYPE,
Box::new(values.data_type().clone()),
std::sync::Arc::new(values.data_type().clone()),
false,
),
keys: MutablePrimitiveArray::<K>::new(),
Expand All @@ -72,7 +72,7 @@ impl<K: DictionaryKey, M: MutableArray + Default> MutableDictionaryArray<K, M> {
Self {
data_type: DataType::Dictionary(
K::KEY_TYPE,
Box::new(values.data_type().clone()),
std::sync::Arc::new(values.data_type().clone()),
false,
),
keys: MutablePrimitiveArray::<K>::new(),
Expand Down
8 changes: 5 additions & 3 deletions src/array/struct_/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use crate::{
bitmap::Bitmap,
datatypes::{DataType, Field},
Expand Down Expand Up @@ -28,7 +30,7 @@ pub use mutable::*;
/// Field::new("c", DataType::Int32, false),
/// ];
///
/// let array = StructArray::new(DataType::Struct(fields), vec![boolean, int], None);
/// let array = StructArray::new(DataType::Struct(std::sync::Arc::new(fields)), vec![boolean, int], None);
/// ```
#[derive(Clone)]
pub struct StructArray {
Expand Down Expand Up @@ -69,7 +71,7 @@ impl StructArray {
.try_for_each(|(index, (data_type, child))| {
if data_type != child {
Err(Error::oos(format!(
"The children DataTypes of a StructArray must equal the children data types.
"The children DataTypes of a StructArray must equal the children data types.
However, the field {index} has data type {data_type:?} but the value has data type {child:?}"
)))
} else {
Expand Down Expand Up @@ -153,7 +155,7 @@ impl StructArray {
impl StructArray {
/// Deconstructs the [`StructArray`] into its individual components.
#[must_use]
pub fn into_data(self) -> (Vec<Field>, Vec<Box<dyn Array>>, Option<Bitmap>) {
pub fn into_data(self) -> (Arc<Vec<Field>>, Vec<Box<dyn Array>>, Option<Bitmap>) {
let Self {
data_type,
values,
Expand Down
4 changes: 2 additions & 2 deletions src/array/union/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl UnionArray {
.try_for_each(|(index, (data_type, child))| {
if data_type != child {
Err(Error::oos(format!(
"The children DataTypes of a UnionArray must equal the children data types.
"The children DataTypes of a UnionArray must equal the children data types.
However, the field {index} has data type {data_type:?} but the value has data type {child:?}"
)))
} else {
Expand Down Expand Up @@ -352,7 +352,7 @@ impl UnionArray {
fn try_get_all(data_type: &DataType) -> Result<UnionComponents, Error> {
match data_type.to_logical_type() {
DataType::Union(fields, ids, mode) => {
Ok((fields, ids.as_ref().map(|x| x.as_ref()), *mode))
Ok((fields, ids.as_ref().map(|x| x.as_slice()), *mode))
}
_ => Err(Error::oos(
"The UnionArray requires a logical type of DataType::Union",
Expand Down
4 changes: 2 additions & 2 deletions src/compute/cast/dictionary_to.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ where
} else {
let data_type = DataType::Dictionary(
K2::KEY_TYPE,
Box::new(values.data_type().clone()),
std::sync::Arc::new(values.data_type().clone()),
is_ordered,
);
// Safety: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize`
Expand All @@ -116,7 +116,7 @@ where
} else {
let data_type = DataType::Dictionary(
K2::KEY_TYPE,
Box::new(values.data_type().clone()),
std::sync::Arc::new(values.data_type().clone()),
is_ordered,
);
// some of the values may not fit in `usize` and thus this needs to be checked
Expand Down
42 changes: 28 additions & 14 deletions src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,10 @@ pub enum DataType {
/// A list of some logical data type whose offsets are represented as [`i64`].
LargeList(Arc<Field>),
/// A nested [`DataType`] with a given number of [`Field`]s.
Struct(Vec<Field>),
Struct(Arc<Vec<Field>>),
/// A nested datatype that can represent slots of differing types.
/// Third argument represents mode
Union(Vec<Field>, Option<Vec<i32>>, UnionMode),
Union(Arc<Vec<Field>>, Option<Arc<Vec<i32>>>, UnionMode),
/// A nested type that is represented as
///
/// List<entries: Struct<key: K, value: V>>
Expand Down Expand Up @@ -189,7 +189,7 @@ pub enum DataType {
/// arrays or a limited set of primitive types as integers.
///
/// The `bool` value indicates the `Dictionary` is sorted if set to `true`.
Dictionary(IntegerType, Box<DataType>, bool),
Dictionary(IntegerType, Arc<DataType>, bool),
/// Decimal value with precision and scale
/// precision is the number of digits in the number and
/// scale is the number of decimal places.
Expand All @@ -198,7 +198,7 @@ pub enum DataType {
/// Decimal backed by 256 bits
Decimal256(usize, usize),
/// Extension type.
Extension(String, Box<DataType>, Option<String>),
Extension(String, Arc<DataType>, Option<Arc<String>>),
}

#[cfg(feature = "arrow")]
Expand Down Expand Up @@ -239,27 +239,41 @@ impl From<DataType> for arrow_schema::DataType {
DataType::LargeList(f) => {
Self::LargeList(Box::new(Arc::unwrap_or_clone_polyfill(f).into()))
}
DataType::Struct(f) => Self::Struct(f.into_iter().map(Into::into).collect()),
DataType::Struct(f) => Self::Struct(
Arc::unwrap_or_clone_polyfill(f)
.into_iter()
.map(Into::into)
.collect(),
),
DataType::Union(fields, Some(ids), mode) => {
let ids = ids.into_iter().map(|x| x as _).collect();
let fields = fields.into_iter().map(Into::into).collect();
let ids = Arc::unwrap_or_clone_polyfill(ids)
.into_iter()
.map(|x| x as _)
.collect();
let fields = Arc::unwrap_or_clone_polyfill(fields)
.into_iter()
.map(Into::into)
.collect();
Self::Union(fields, ids, mode.into())
}
DataType::Union(fields, None, mode) => {
let ids = (0..fields.len() as i8).collect();
let fields = fields.into_iter().map(Into::into).collect();
let fields = Arc::unwrap_or_clone_polyfill(fields)
.into_iter()
.map(Into::into)
.collect();
Self::Union(fields, ids, mode.into())
}
DataType::Map(f, ordered) => {
Self::Map(Box::new(Arc::unwrap_or_clone_polyfill(f).into()), ordered)
}
DataType::Dictionary(key, value, _) => Self::Dictionary(
Box::new(DataType::from(key).into()),
Box::new((*value).into()),
Box::new(Arc::unwrap_or_clone_polyfill(value).into()),
),
DataType::Decimal(precision, scale) => Self::Decimal128(precision as _, scale as _),
DataType::Decimal256(precision, scale) => Self::Decimal256(precision as _, scale as _),
DataType::Extension(_, d, _) => (*d).into(),
DataType::Extension(_, d, _) => Arc::unwrap_or_clone_polyfill(d).into(),
}
}
}
Expand Down Expand Up @@ -299,10 +313,10 @@ impl From<arrow_schema::DataType> for DataType {
Self::FixedSizeList(Arc::new((*f).into()), size as _)
}
DataType::LargeList(f) => Self::LargeList(Arc::new((*f).into())),
DataType::Struct(f) => Self::Struct(f.into_iter().map(Into::into).collect()),
DataType::Struct(f) => Self::Struct(Arc::new(f.into_iter().map(Into::into).collect())),
DataType::Union(fields, ids, mode) => {
let ids = ids.into_iter().map(|x| x as _).collect();
let fields = fields.into_iter().map(Into::into).collect();
let ids = Arc::new(ids.into_iter().map(|x| x as _).collect());
let fields = Arc::new(fields.into_iter().map(Into::into).collect());
Self::Union(fields, Some(ids), mode.into())
}
DataType::Map(f, ordered) => Self::Map(std::sync::Arc::new((*f).into()), ordered),
Expand All @@ -318,7 +332,7 @@ impl From<arrow_schema::DataType> for DataType {
DataType::UInt64 => IntegerType::UInt64,
d => panic!("illegal dictionary key type: {d}"),
};
Self::Dictionary(key, Box::new((*value).into()), false)
Self::Dictionary(key, Arc::new((*value).into()), false)
}
DataType::Decimal128(precision, scale) => Self::Decimal(precision as _, scale as _),
DataType::Decimal256(precision, scale) => Self::Decimal256(precision as _, scale as _),
Expand Down
30 changes: 17 additions & 13 deletions src/ffi/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl ArrowSchema {
if let Some(extension_metadata) = extension_metadata {
metadata.insert(
"ARROW:extension:metadata".to_string(),
extension_metadata.clone(),
extension_metadata.to_string(),
);
}

Expand Down Expand Up @@ -193,14 +193,18 @@ pub(crate) unsafe fn to_field(schema: &ArrowSchema) -> Result<Field> {
let indices = to_integer_type(schema.format())?;
let values = to_field(dictionary)?;
let is_ordered = schema.flags & 1 == 1;
DataType::Dictionary(indices, Box::new(values.data_type().clone()), is_ordered)
DataType::Dictionary(
indices,
std::sync::Arc::new(values.data_type().clone()),
is_ordered,
)
} else {
to_data_type(schema)?
};
let (metadata, extension) = unsafe { metadata_from_bytes(schema.metadata) };

let data_type = if let Some((name, extension_metadata)) = extension {
DataType::Extension(name, Box::new(data_type), extension_metadata)
DataType::Extension(name, Arc::new(data_type), extension_metadata.map(Arc::new))
} else {
data_type
};
Expand Down Expand Up @@ -276,7 +280,7 @@ unsafe fn to_data_type(schema: &ArrowSchema) -> Result<DataType> {
let children = (0..schema.n_children as usize)
.map(|x| to_field(schema.child(x)))
.collect::<Result<Vec<_>>>()?;
DataType::Struct(children)
DataType::Struct(Arc::new(children))
}
other => {
match other.splitn(2, ':').collect::<Vec<_>>()[..] {
Expand Down Expand Up @@ -378,7 +382,7 @@ unsafe fn to_data_type(schema: &ArrowSchema) -> Result<DataType> {
let fields = (0..schema.n_children as usize)
.map(|x| to_field(schema.child(x)))
.collect::<Result<Vec<_>>>()?;
DataType::Union(fields, Some(type_ids), mode)
DataType::Union(Arc::new(fields), Some(Arc::new(type_ids)), mode)
}
_ => {
return Err(Error::OutOfSpec(format!(
Expand Down Expand Up @@ -576,40 +580,40 @@ mod tests {
DataType::List(Arc::new(Field::new("example", DataType::Boolean, false))),
DataType::FixedSizeList(Arc::new(Field::new("example", DataType::Boolean, false)), 2),
DataType::LargeList(Arc::new(Field::new("example", DataType::Boolean, false))),
DataType::Struct(vec![
DataType::Struct(Arc::new(vec![
Field::new("a", DataType::Int64, true),
Field::new(
"b",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
),
]),
])),
DataType::Map(
std::sync::Arc::new(Field::new("a", DataType::Int64, true)),
true,
),
DataType::Union(
vec![
Arc::new(vec![
Field::new("a", DataType::Int64, true),
Field::new(
"b",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
),
],
Some(vec![1, 2]),
]),
Some(Arc::new(vec![1, 2])),
UnionMode::Dense,
),
DataType::Union(
vec![
Arc::new(vec![
Field::new("a", DataType::Int64, true),
Field::new(
"b",
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
true,
),
],
Some(vec![0, 1]),
]),
Some(Arc::new(vec![0, 1])),
UnionMode::Sparse,
),
];
Expand Down
2 changes: 1 addition & 1 deletion src/io/avro/read/nested.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl FixedItemsUtf8Dictionary {
Self {
data_type: DataType::Dictionary(
IntegerType::Int32,
Box::new(values.data_type().clone()),
std::sync::Arc::new(values.data_type().clone()),
false,
),
keys: MutablePrimitiveArray::<i32>::with_capacity(capacity),
Expand Down
6 changes: 3 additions & 3 deletions src/io/avro/read/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn schema_to_field(schema: &AvroSchema, name: Option<&str>, props: Metadata) ->
.iter()
.map(|s| schema_to_field(s, None, Metadata::default()))
.collect::<Result<Vec<Field>>>()?;
DataType::Union(fields, None, UnionMode::Dense)
DataType::Union(Arc::new(fields), None, UnionMode::Dense)
}
}
AvroSchema::Record(Record { fields, .. }) => {
Expand All @@ -119,12 +119,12 @@ fn schema_to_field(schema: &AvroSchema, name: Option<&str>, props: Metadata) ->
schema_to_field(&field.schema, Some(&field.name), props)
})
.collect::<Result<_>>()?;
DataType::Struct(fields)
DataType::Struct(std::sync::Arc::new(fields))
}
AvroSchema::Enum { .. } => {
return Ok(Field::new(
name.unwrap_or_default(),
DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8), false),
DataType::Dictionary(IntegerType::Int32, Arc::new(DataType::Utf8), false),
false,
))
}
Expand Down
11 changes: 7 additions & 4 deletions src/io/ipc/read/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,10 @@ fn deserialize_union(union_: UnionRef, field: FieldRef) -> Result<(DataType, Ipc
fields: ipc_fields,
dictionary_id: None,
};
Ok((DataType::Union(fields, ids, mode), ipc_field))
Ok((
DataType::Union(Arc::new(fields), ids.map(Arc::new), mode),
ipc_field,
))
}

fn deserialize_map(map: MapRef, field: FieldRef) -> Result<(DataType, IpcField)> {
Expand Down Expand Up @@ -172,7 +175,7 @@ fn deserialize_struct(field: FieldRef) -> Result<(DataType, IpcField)> {
fields: ipc_fields,
dictionary_id: None,
};
Ok((DataType::Struct(fields), ipc_field))
Ok((DataType::Struct(std::sync::Arc::new(fields)), ipc_field))
}

fn deserialize_list(field: FieldRef) -> Result<(DataType, IpcField)> {
Expand Down Expand Up @@ -252,7 +255,7 @@ fn get_data_type(
let (inner, mut ipc_field) = get_data_type(field, extension, false)?;
ipc_field.dictionary_id = Some(dictionary.id()?);
return Ok((
DataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?),
DataType::Dictionary(index_type, Arc::new(inner), dictionary.is_ordered()?),
ipc_field,
));
}
Expand All @@ -262,7 +265,7 @@ fn get_data_type(
let (name, metadata) = extension;
let (data_type, fields) = get_data_type(field, None, false)?;
return Ok((
DataType::Extension(name, Box::new(data_type), metadata),
DataType::Extension(name, Arc::new(data_type), metadata.map(Arc::new)),
fields,
));
}
Expand Down
Loading

0 comments on commit d274275

Please sign in to comment.