Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Aligns MutableDictionaryArray's with MutablePrimitiveArrays with TryPush #981

Merged
merged 3 commits into from
May 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions src/array/dictionary/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ use std::{collections::hash_map::DefaultHasher, sync::Arc};

use hash_hasher::HashedMap;

use crate::array::TryExtend;
use crate::{
array::{primitive::MutablePrimitiveArray, Array, MutableArray},
array::{primitive::MutablePrimitiveArray, Array, MutableArray, TryExtend, TryPush},
bitmap::MutableBitmap,
datatypes::DataType,
error::{ArrowError, Result},
Expand All @@ -14,6 +13,20 @@ use crate::{
use super::{DictionaryArray, DictionaryKey};

/// A mutable, strongly-typed version of [`DictionaryArray`].
///
/// # Example
/// Building a UTF8 dictionary with `i32` keys.
/// ```
/// # use arrow2::array::{MutableDictionaryArray, MutableUtf8Array, TryPush};
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut array: MutableDictionaryArray<i32, MutableUtf8Array<i32>> = MutableDictionaryArray::new();
/// array.try_push(Some("A"))?;
/// array.try_push(Some("B"))?;
/// array.push_null();
/// array.try_push(Some("C"))?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct MutableDictionaryArray<K: DictionaryKey, M: MutableArray> {
data_type: DataType,
Expand Down Expand Up @@ -68,7 +81,7 @@ impl<K: DictionaryKey, M: MutableArray + Default> Default for MutableDictionaryA

impl<K: DictionaryKey, M: MutableArray> MutableDictionaryArray<K, M> {
/// Returns whether the value should be pushed to the values or not
pub fn try_push_valid<T: Hash>(&mut self, value: &T) -> Result<bool> {
fn try_push_valid<T: Hash>(&mut self, value: &T) -> Result<bool> {
let mut hasher = DefaultHasher::new();
value.hash(&mut hasher);
let hash = hasher.finish();
Expand Down Expand Up @@ -118,6 +131,16 @@ impl<K: DictionaryKey, M: MutableArray> MutableDictionaryArray<K, M> {
self.values.shrink_to_fit();
self.keys.shrink_to_fit();
}

/// Returns a shared reference to the internal dictionary map.
///
/// The map is keyed by a `u64` hash and stores dictionary keys of type `K`.
/// NOTE(review): the accompanying tests expect each entry to pair the
/// `DefaultHasher` hash of a distinct value with the key assigned to it —
/// confirm against `try_push_valid`.
pub fn map(&self) -> &HashedMap<u64, K> {
&self.map
}

/// Returns a shared reference to the dictionary keys.
///
/// The keys are stored as a [`MutablePrimitiveArray`] of `K`.
/// NOTE(review): the accompanying tests expect one slot per pushed item,
/// with a null slot for every pushed `None` — confirm against the push paths.
pub fn keys(&self) -> &MutablePrimitiveArray<K> {
&self.keys
}
}

impl<K: DictionaryKey, M: 'static + MutableArray> MutableArray for MutableDictionaryArray<K, M> {
Expand Down Expand Up @@ -181,3 +204,23 @@ where
Ok(())
}
}

/// Pushing of optional values, mirroring the `TryPush` interface offered by
/// the other mutable arrays.
impl<K, M, T> TryPush<Option<T>> for MutableDictionaryArray<K, M>
where
    K: DictionaryKey,
    M: MutableArray + TryPush<Option<T>>,
    T: Hash,
{
    /// Appends `item` to the array.
    ///
    /// * `None` is recorded through `push_null`.
    /// * `Some(value)` is first handed to `try_push_valid`; the value itself
    ///   is forwarded to the values array only when that call reports that it
    ///   should be pushed.
    ///
    /// # Errors
    /// Propagates any error returned by `try_push_valid` or by the values
    /// array's own `try_push`.
    fn try_push(&mut self, item: Option<T>) -> Result<()> {
        match item {
            None => {
                self.push_null();
                Ok(())
            }
            Some(value) => {
                let needs_insert = self.try_push_valid(&value)?;
                if !needs_insert {
                    // Nothing to add to the values array for this item.
                    return Ok(());
                }
                self.values.try_push(Some(value))
            }
        }
    }
}
36 changes: 36 additions & 0 deletions tests/it/array/dictionary/mutable.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use arrow2::array::*;
use arrow2::error::Result;
use hash_hasher::HashedMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

#[test]
fn primitive() -> Result<()> {
Expand Down Expand Up @@ -38,3 +41,36 @@ fn binary_natural() -> Result<()> {
assert_eq!(a.values().len(), 2);
Ok(())
}

/// Pushes a mix of repeated values and a null through `try_push` and checks
/// the resulting values, keys and hash map of the dictionary.
#[test]
fn push_utf8() {
    let mut array = MutableDictionaryArray::<i32, MutableUtf8Array<i32>>::new();

    let inputs = [Some("A"), Some("B"), None, Some("C"), Some("A"), Some("B")];
    for item in inputs {
        array.try_push(item).unwrap();
    }

    // Only the distinct values, in first-seen order, end up in the values array.
    let expected_values = MutableUtf8Array::<i32>::from_iter_values(["A", "B", "C"].into_iter());
    assert_eq!(array.values().values(), expected_values.values());

    // One key per pushed item; the pushed `None` shows up as a null key.
    let mut expected_keys = MutablePrimitiveArray::<i32>::from_slice(&[0, 1]);
    for key in [None, Some(2), Some(0), Some(1)] {
        expected_keys.push(key);
    }
    assert_eq!(*array.keys(), expected_keys);

    // The map pairs the `DefaultHasher` hash of each distinct value with its key.
    let mut expected_map = HashedMap::default();
    for (key, text) in (0i32..).zip(["A", "B", "C"]) {
        let mut state = DefaultHasher::new();
        text.hash(&mut state);
        expected_map.insert(state.finish(), key);
    }
    assert_eq!(*array.map(), expected_map);
}