From ced382aaa278ca591a9625c9691ac80e83e5a747 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 21 Aug 2021 11:05:55 +0000 Subject: [PATCH] Removed clone requirement in Struct -> RecordBatch. --- src/array/struct_.rs | 14 ++++++++++++++ src/record_batch.rs | 19 ++++++++----------- tests/it/record_batch.rs | 22 +++++++++------------- 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/array/struct_.rs b/src/array/struct_.rs index 86c2b092102..9a6191a5d47 100644 --- a/src/array/struct_.rs +++ b/src/array/struct_.rs @@ -52,6 +52,20 @@ impl StructArray { } } + pub fn into_data(self) -> (Vec<Field>, Vec<Arc<dyn Array>>, Option<Bitmap>) { + let Self { + data_type, + values, + validity, + } = self; + let fields = if let DataType::Struct(fields) = data_type { + fields + } else { + unreachable!() + }; + (fields, values, validity) + } + pub fn slice(&self, offset: usize, length: usize) -> Self { let validity = self.validity.clone().map(|x| x.slice(offset, length)); Self { diff --git a/src/record_batch.rs b/src/record_batch.rs index 07e512cd06c..ddd171af545 100644 --- a/src/record_batch.rs +++ b/src/record_batch.rs @@ -339,17 +339,14 @@ impl Default for RecordBatchOptions { } } -impl From<&StructArray> for RecordBatch { - /// Create a record batch from struct array. - /// - /// This currently does not flatten and nested struct types - fn from(struct_array: &StructArray) -> Self { - if let DataType::Struct(fields) = struct_array.data_type() { - let schema = Arc::new(Schema::new(fields.clone())); - let columns = struct_array.values().to_vec(); - RecordBatch { schema, columns } - } else { - unreachable!("unable to get datatype as struct") +impl From<StructArray> for RecordBatch { + /// # Panics iff the null count of the array is not zero.
+ fn from(array: StructArray) -> Self { + assert!(array.null_count() == 0); + let (fields, values, _) = array.into_data(); + RecordBatch { + schema: Arc::new(Schema::new(fields)), + columns: values, } } } diff --git a/tests/it/record_batch.rs b/tests/it/record_batch.rs index b9ca197c060..8bda0484b88 100644 --- a/tests/it/record_batch.rs +++ b/tests/it/record_batch.rs @@ -91,22 +91,18 @@ fn number_of_fields_mismatch() { fn from_struct_array() { let boolean = Arc::new(BooleanArray::from_slice(&[false, false, true, true])) as ArrayRef; let int = Arc::new(Int32Array::from_slice(&[42, 28, 19, 31])) as ArrayRef; - let struct_array = StructArray::from_data( - vec![ - Field::new("b", DataType::Boolean, false), - Field::new("c", DataType::Int32, false), - ], - vec![boolean.clone(), int.clone()], - None, - ); - let batch = RecordBatch::from(&struct_array); + let fields = vec![ + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Int32, false), + ]; + + let array = StructArray::from_data(fields.clone(), vec![boolean.clone(), int.clone()], None); + + let batch = RecordBatch::from(array); assert_eq!(2, batch.num_columns()); assert_eq!(4, batch.num_rows()); - assert_eq!( - struct_array.data_type(), - &DataType::Struct(batch.schema().fields().to_vec()) - ); + assert_eq!(&fields, batch.schema().fields()); assert_eq!(boolean.as_ref(), batch.column(0).as_ref()); assert_eq!(int.as_ref(), batch.column(1).as_ref()); }