Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Improved Union
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Dec 18, 2022
1 parent 12d955d commit e60b886
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 71 deletions.
4 changes: 2 additions & 2 deletions src/array/growable/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ impl<'a> Growable<'a> for GrowableUnion<'a> {
/// No-op: the body is empty and `_additional` is ignored.
// NOTE(review): presumably because a union array carries no validity bitmap of its own — confirm.
fn extend_validity(&mut self, _additional: usize) {}

fn as_arc(&mut self) -> Arc<dyn Array> {
Arc::new(self.to())
self.to().arced()
}

fn as_box(&mut self) -> Box<dyn Array> {
Box::new(self.to())
self.to().boxed()
}
}

Expand Down
152 changes: 98 additions & 54 deletions src/array/union/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use ahash::AHashMap;

use crate::{
bitmap::Bitmap,
buffer::Buffer,
Expand All @@ -14,7 +12,6 @@ mod ffi;
pub(super) mod fmt;
mod iterator;

type FieldEntry = (usize, Box<dyn Array>);
type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode);

/// [`UnionArray`] represents an array whose each slot can contain different values.
Expand All @@ -29,10 +26,13 @@ type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode);
// ```
#[derive(Clone)]
pub struct UnionArray {
// Invariant: every item in `types` is `>= 0 && < fields.len()`
types: Buffer<i8>,
// None represents when there is no typeid
fields_hash: Option<AHashMap<i8, FieldEntry>>,
// Invariant: `map.len() == fields.len()`
// Invariant: every item in `map` is `< fields.len()`
map: Option<Vec<usize>>,
fields: Vec<Box<dyn Array>>,
// Invariant: when set, `offsets.len() == types.len()`
offsets: Option<Buffer<i32>>,
data_type: DataType,
offset: usize,
Expand All @@ -44,6 +44,7 @@ impl UnionArray {
/// This function errors iff:
/// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Union`].
/// * the length of `fields` is different from the number of `data_type`'s children
/// * The number of `fields` is larger than `i8::MAX`
/// * any of the `fields`' data types is different from the data type of its corresponding child
pub fn try_new(
data_type: DataType,
Expand All @@ -58,6 +59,10 @@ impl UnionArray {
"The number of `fields` must equal the number of children fields in DataType::Union",
));
};
let number_of_fields: i8 = fields
.len()
.try_into()
.map_err(|_| Error::oos("The number of `fields` cannot be larger than i8::MAX"))?;

f
.iter().map(|a| a.data_type())
Expand All @@ -74,27 +79,52 @@ impl UnionArray {
}
})?;

if let Some(offsets) = &offsets {
if offsets.len() != types.len() {
return Err(Error::oos(
"In a UnionArray, the offsets' length must be equal to the number of types",
));
}
}
if offsets.is_none() != mode.is_sparse() {
return Err(Error::oos(
"The offsets must be set when the Union is dense and vice-versa",
"In a sparse UnionArray, the offsets must be set (and vice-versa)",
));
}

// build hash
let map = if let Some(&ids) = ids.as_ref() {
if ids.len() != fields.len() {
return Err(Error::oos(
"In a union, when the ids are set, their length must be equal to the number of fields",
));
}
ids.iter().map(|&id| {
if id < 0 || id >= fields.len() as i32 {
return Err(Error::oos("In a union, when the ids are set, each id must be smaller than the number of fields."));
}
Ok(id as usize)
}).collect::<Result<Vec<_>, Error>>().map(Some)?
} else {
None
};

// Safety: every type in types is smaller than number of fields
let mut is_valid = true;
for &type_ in types.iter() {
if type_ < 0 || type_ >= number_of_fields {
is_valid = false
}
}
if !is_valid {
return Err(Error::oos(
"Every type in `types` must be larger than 0 and smaller than the number of fields.",
));
}

let fields_hash = ids.as_ref().map(|ids| {
ids.iter()
.map(|x| *x as i8)
.enumerate()
.zip(fields.iter().cloned())
.map(|((i, type_), field)| (type_, (i, field)))
.collect()
});

// not validated:
// * `offsets` is valid
// * max id < fields.len()
Ok(Self {
data_type,
fields_hash,
map,
fields,
offsets,
types,
Expand Down Expand Up @@ -128,7 +158,7 @@ impl UnionArray {
let offsets = if mode.is_sparse() {
None
} else {
Some((0..length as i32).collect::<Buffer<i32>>())
Some((0..length as i32).collect::<Vec<_>>().into())
};

// all from the same field
Expand All @@ -151,12 +181,12 @@ impl UnionArray {
let offsets = if mode.is_sparse() {
None
} else {
Some(Buffer::new())
Some(Buffer::default())
};

Self {
data_type,
fields_hash: None,
map: None,
fields,
offsets,
types: Buffer::new(),
Expand Down Expand Up @@ -186,17 +216,11 @@ impl UnionArray {
/// This function panics iff `offset + length > self.len()`.
#[inline]
pub fn slice(&self, offset: usize, length: usize) -> Self {
Self {
data_type: self.data_type.clone(),
fields: self.fields.clone(),
fields_hash: self.fields_hash.clone(),
types: self.types.clone().slice(offset, length),
offsets: self
.offsets
.clone()
.map(|offsets| offsets.slice(offset, length)),
offset: self.offset + offset,
}
assert!(
offset + length <= self.len(),
"the offset of the new array cannot exceed the existing length"
);
unsafe { self.slice_unchecked(offset, length) }
}

/// Returns a slice of this [`UnionArray`].
Expand All @@ -206,10 +230,11 @@ impl UnionArray {
/// The caller must ensure that `offset + length <= self.len()`.
#[inline]
pub unsafe fn slice_unchecked(&self, offset: usize, length: usize) -> Self {
debug_assert!(offset + length <= self.len());
Self {
data_type: self.data_type.clone(),
fields: self.fields.clone(),
fields_hash: self.fields_hash.clone(),
map: self.map.clone(),
types: self.types.clone().slice_unchecked(offset, length),
offsets: self
.offsets
Expand Down Expand Up @@ -243,38 +268,57 @@ impl UnionArray {
}

#[inline]
fn field(&self, type_: i8) -> &dyn Array {
self.fields_hash
.as_ref()
.map(|x| x[&type_].1.as_ref())
.unwrap_or_else(|| self.fields[type_ as usize].as_ref())
}

#[inline]
fn field_slot(&self, index: usize) -> usize {
unsafe fn field_slot_unchecked(&self, index: usize) -> usize {
self.offsets()
.as_ref()
.map(|x| x[index] as usize)
.map(|x| *x.get_unchecked(index) as usize)
.unwrap_or(index + self.offset)
}

/// Returns the index and slot of the field to select from `self.fields`.
#[inline]
pub fn index(&self, index: usize) -> (usize, usize) {
let type_ = self.types()[index];
let field_index = self
.fields_hash
assert!(index < self.len());
unsafe { self.index_unchecked(index) }
}

/// Returns the index and slot of the field to select from `self.fields`.
/// The first value is guaranteed to be `< self.fields().len()`
/// # Safety
/// This function is safe iff `index < self.len()`.
#[inline]
pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) {
debug_assert!(index < self.len());
// Safety: assumption of the function
let type_ = unsafe { *self.types.get_unchecked(index) };
// Safety: assumption of the struct
let type_ = self
.map
.as_ref()
.map(|x| x[&type_].0)
.unwrap_or_else(|| type_ as usize);
let index = self.field_slot(index);
(field_index, index)
.map(|map| unsafe { *map.get_unchecked(type_ as usize) })
.unwrap_or(type_ as usize);
// Safety: assumption of the function
let index = self.field_slot_unchecked(index);
(type_, index)
}

/// Returns the slot `index` as a [`Scalar`].
/// # Panics
/// iff `index >= self.len()`
pub fn value(&self, index: usize) -> Box<dyn Scalar> {
let type_ = self.types()[index];
let field = self.field(type_);
let index = self.field_slot(index);
assert!(index < self.len());
unsafe { self.value_unchecked(index) }
}

/// Returns the value stored at slot `index` as a [`Scalar`], without bounds checks.
/// # Safety
/// The caller must ensure that `index < self.len()`.
pub unsafe fn value_unchecked(&self, index: usize) -> Box<dyn Scalar> {
    debug_assert!(index < self.len());
    // Resolve which child array holds the value and the slot inside that child.
    // Safety: `index < self.len()` is the precondition of this function.
    let (field_index, slot) = self.index_unchecked(index);
    // Safety: `index_unchecked` documents that its first value is `< self.fields().len()`.
    debug_assert!(field_index < self.fields.len());
    new_scalar(self.fields.get_unchecked(field_index).as_ref(), slot)
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/compute/sort/row/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -647,9 +647,9 @@ mod tests {
#[test]
fn test_fixed_width() {
let cols = [
Int16Array::from_iter([Some(1), Some(2), None, Some(-5), Some(2), Some(2), Some(0)])
Int16Array::from([Some(1), Some(2), None, Some(-5), Some(2), Some(2), Some(0)])
.to_boxed(),
Float32Array::from_iter([
Float32Array::from([
Some(1.3),
Some(2.5),
None,
Expand Down
6 changes: 3 additions & 3 deletions src/compute/sort/row/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,9 @@ pub fn encode<'a, I: Iterator<Item = Option<&'a [u8]>>>(out: &mut Rows, i: I, op
// Write `2_u8` to demarcate as non-empty, non-null string
to_write[0] = NON_EMPTY_SENTINEL;

let chunks = val.chunks_exact(BLOCK_SIZE);
let remainder = chunks.remainder();
let mut chunks = val.chunks_exact(BLOCK_SIZE);
for (input, output) in chunks
.clone()
.by_ref()
.zip(to_write[1..].chunks_exact_mut(BLOCK_SIZE + 1))
{
let input: &[u8; BLOCK_SIZE] = input.try_into().unwrap();
Expand All @@ -92,6 +91,7 @@ pub fn encode<'a, I: Iterator<Item = Option<&'a [u8]>>>(out: &mut Rows, i: I, op
output[BLOCK_SIZE] = BLOCK_CONTINUATION;
}

let remainder = chunks.remainder();
if !remainder.is_empty() {
let start_offset = 1 + (block_count - 1) * (BLOCK_SIZE + 1);
to_write[start_offset..start_offset + remainder.len()]
Expand Down
3 changes: 2 additions & 1 deletion src/io/json_integration/read/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,8 @@ pub fn to_array(
}
_ => panic!(),
})
.collect(),
.collect::<Vec<_>>()
.into(),
)
})
.unwrap_or_default();
Expand Down
Loading

0 comments on commit e60b886

Please sign in to comment.