diff --git a/src/compute/take/fixed_size_list.rs b/src/compute/take/fixed_size_list.rs new file mode 100644 index 00000000000..31fc04d65e9 --- /dev/null +++ b/src/compute/take/fixed_size_list.rs @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use crate::array::growable::GrowableFixedSizeList;
+use crate::array::FixedSizeListArray;
+use crate::array::{growable::Growable, PrimitiveArray};
+
+use super::Index;
+
+/// `take` implementation for [`FixedSizeListArray`]s.
+///
+/// Returns an array with one slot per entry of `indices`: slot `i` is
+/// copied from the list at position `indices[i]` of `values`, or is null
+/// when `indices[i]` is null.
+///
+/// `values` is registered as the single source of a
+/// [`GrowableFixedSizeList`] and each list is copied directly out of it.
+/// Compared to slicing `values` once per index, this allocates no
+/// intermediate arrays, never reads the number stored under a null index,
+/// and always hands the growable a source, even when `indices` is empty.
+///
+/// # Panics
+///
+/// NOTE(review): a valid index past the end of `values` is expected to
+/// panic inside [`Growable::extend`] -- confirm against its implementation.
+pub fn take<O: Index>(
+    values: &FixedSizeListArray,
+    indices: &PrimitiveArray<O>,
+) -> FixedSizeListArray {
+    let validity = indices.validity();
+    let use_validity = validity.is_some();
+
+    // `values` is the only source array, hence source index `0` below.
+    let mut growable = GrowableFixedSizeList::new(vec![values], use_validity, indices.len());
+
+    for (slot, index) in indices.values().iter().enumerate() {
+        // Without a validity bitmap every index is valid.
+        let is_valid = match validity {
+            Some(bitmap) => bitmap.get_bit(slot),
+            None => true,
+        };
+        if is_valid {
+            // Copy the one list found at position `index` of `values`.
+            growable.extend(0, index.to_usize(), 1);
+        } else {
+            growable.extend_validity(1);
+        }
+    }
+
+    growable.into()
+}
diff --git a/src/compute/take/mod.rs b/src/compute/take/mod.rs
index b9ae790d0fd..3acf47dc7a1 100644
--- a/src/compute/take/mod.rs
+++ b/src/compute/take/mod.rs
@@ -27,6 +27,7 @@ use crate::{
 mod binary;
 mod boolean;
 mod dict;
+mod fixed_size_list;
 mod generic_binary;
 mod list;
 mod primitive;
@@ -90,6 +91,10 @@ pub fn take(values: &dyn Array, indices: &PrimitiveArray) -> Result
             let array = values.as_any().downcast_ref().unwrap();
             Ok(Box::new(list::take::(array, indices)))
         }
+        FixedSizeList => {
+            let array = values.as_any().downcast_ref().unwrap();
+            Ok(Box::new(fixed_size_list::take::(array, indices)))
+        }
         t => unimplemented!("Take not supported for data type {:?}", t),
     }
 }
@@ -135,6 +140,7 @@ pub fn can_take(data_type: &DataType) -> bool {
             | DataType::Struct(_)
             | DataType::List(_)
             | DataType::LargeList(_)
+            | DataType::FixedSizeList(_, _)
             | DataType::Dictionary(..)
) } diff --git a/tests/it/compute/take.rs b/tests/it/compute/take.rs index 39fcf94da6e..feaa0d82081 100644 --- a/tests/it/compute/take.rs +++ b/tests/it/compute/take.rs @@ -259,6 +259,25 @@ fn list_both_validity() { assert_eq!(expected, result.as_ref()); } +#[test] +fn fixed_size_list_with_no_none() { + let values = Buffer::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + let values = PrimitiveArray::::new(DataType::Int32, values, None); + + let data_type = FixedSizeListArray::default_datatype(DataType::Int32, 2); + let array = FixedSizeListArray::new(data_type, Box::new(values), None); + + let indices = PrimitiveArray::from([Some(4i32), Some(1), Some(3)]); + let result = take(&array, &indices).unwrap(); + + let expected_values = Buffer::from(vec![8, 9, 2, 3, 6, 7]); + let expected_values = PrimitiveArray::::new(DataType::Int32, expected_values, None); + let expected_type = FixedSizeListArray::default_datatype(DataType::Int32, 2); + let expected = FixedSizeListArray::new(expected_type, Box::new(expected_values), None); + + assert_eq!(expected, result.as_ref()); +} + #[test] fn test_nested() { let values = Buffer::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);