diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index 0604f71d3294..e31594d4b073 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -18,12 +18,34 @@ use crate::{make_array, Array, ArrayRef, RecordBatch}; use arrow_buffer::{buffer_bin_or, Buffer, NullBuffer}; use arrow_data::ArrayData; -use arrow_schema::{ArrowError, DataType, Field, SchemaBuilder}; +use arrow_schema::{ArrowError, DataType, Field, Fields, SchemaBuilder}; use std::sync::Arc; use std::{any::Any, ops::Index}; /// A nested array type where each child (called *field*) is represented by a separate /// array. +/// +/// +/// # Comparison with [RecordBatch] +/// +/// Both [`RecordBatch`] and [`StructArray`] represent a collection of columns / arrays with the +/// same length. +/// +/// However, there are a couple of key differences: +/// +/// * [`StructArray`] can be nested within other [`Array`], including itself +/// * [`RecordBatch`] can contain top-level metadata on its associated [`Schema`][arrow_schema::Schema] +/// * [`StructArray`] can contain top-level nulls, i.e. `null` +/// * [`RecordBatch`] can only represent nulls in its child columns, i.e. `{"field": null}` +/// +/// [`StructArray`] is therefore a more general data container than [`RecordBatch`], and as such +/// code that needs to handle both will typically share an implementation in terms of +/// [`StructArray`] and convert to/from [`RecordBatch`] as necessary. +/// +/// [`From`] implementations are provided to facilitate this conversion, however, converting +/// from a [`StructArray`] containing top-level nulls to a [`RecordBatch`] will panic, as there +/// is no way to preserve them. 
+/// /// # Example: Create an array from a vector of fields /// /// ``` @@ -89,6 +111,14 @@ impl StructArray { } } + /// Returns the [`Fields`] of this [`StructArray`] + pub fn fields(&self) -> &Fields { + match self.data_type() { + DataType::Struct(f) => f, + _ => unreachable!(), + } + } + /// Return child array whose field name equals to column_name /// /// Note: A schema can currently have duplicate field names, in which case diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index db4bb1230ca7..081bd55fc650 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -446,23 +446,28 @@ impl Default for RecordBatchOptions { Self::new() } } +impl From<StructArray> for RecordBatch { + fn from(value: StructArray) -> Self { + assert_eq!( + value.null_count(), + 0, + "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation" + ); + let row_count = value.len(); + let schema = Arc::new(Schema::new(value.fields().clone())); + let columns = value.boxed_fields; + + RecordBatch { + schema, + row_count, + columns, + } + } +} + impl From<&StructArray> for RecordBatch { - /// Create a record batch from struct array, where each field of - /// the `StructArray` becomes a `Field` in the schema. 
- /// - /// This currently does not flatten and nested struct types fn from(struct_array: &StructArray) -> Self { - if let DataType::Struct(fields) = struct_array.data_type() { - let schema = Schema::new(fields.clone()); - let columns = struct_array.boxed_fields.clone(); - RecordBatch { - schema: Arc::new(schema), - row_count: struct_array.len(), - columns, - } - } else { - unreachable!("unable to get datatype as struct") - } + struct_array.clone().into() } } @@ -558,7 +563,7 @@ mod tests { BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, }; use arrow_buffer::{Buffer, ToByteSlice}; - use arrow_data::ArrayDataBuilder; + use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::Fields; #[test] @@ -1046,4 +1051,15 @@ mod tests { assert!(!options.match_field_names); assert_eq!(options.row_count.unwrap(), 20) } + + #[test] + #[should_panic(expected = "Cannot convert nullable StructArray to RecordBatch")] + fn test_from_struct() { + let s = StructArray::from(ArrayData::new_null( + // Note child is not nullable + &DataType::Struct(vec![Field::new("foo", DataType::Int32, false)].into()), + 2, + )); + let _ = RecordBatch::from(s); + } } diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs index b1046d142f32..6b3067ab7d75 100644 --- a/arrow/src/ffi_stream.rs +++ b/arrow/src/ffi_stream.rs @@ -373,7 +373,7 @@ impl Iterator for ArrowArrayStreamReader { .to_data() .ok()?; - let record_batch = RecordBatch::from(&StructArray::from(data)); + let record_batch = RecordBatch::from(StructArray::from(data)); Some(Ok(record_batch)) } else { @@ -492,7 +492,7 @@ mod tests { .to_data() .unwrap(); - let record_batch = RecordBatch::from(&StructArray::from(array)); + let record_batch = RecordBatch::from(StructArray::from(array)); produced_batches.push(record_batch); } diff --git a/parquet/src/arrow/array_reader/struct_array.rs b/parquet/src/arrow/array_reader/struct_array.rs index 22724ae3f081..0670701a0375 100644 --- 
a/parquet/src/arrow/array_reader/struct_array.rs +++ b/parquet/src/arrow/array_reader/struct_array.rs @@ -217,6 +217,7 @@ mod tests { use crate::arrow::array_reader::ListArrayReader; use arrow::buffer::Buffer; use arrow::datatypes::Field; + use arrow_array::cast::AsArray; use arrow_array::{Array, Int32Array, ListArray}; use arrow_schema::Fields; @@ -252,7 +253,7 @@ ); let struct_array = struct_array_reader.next_batch(5).unwrap(); - let struct_array = struct_array.as_any().downcast_ref::<StructArray>().unwrap(); + let struct_array = struct_array.as_struct(); assert_eq!(5, struct_array.len()); assert_eq!( @@ -328,7 +329,7 @@ ); let actual = struct_reader.next_batch(1024).unwrap(); - let actual = actual.as_any().downcast_ref::<StructArray>().unwrap(); + let actual = actual.as_struct(); assert_eq!(actual, &expected) } } diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index ba322e29d868..4b88a33f3a25 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -20,7 +20,8 @@ use std::collections::VecDeque; use std::sync::Arc; -use arrow_array::{Array, StructArray}; +use arrow_array::cast::AsArray; +use arrow_array::Array; use arrow_array::{RecordBatch, RecordBatchReader}; use arrow_schema::{ArrowError, DataType as ArrowType, Schema, SchemaRef}; use arrow_select::filter::prep_null_mask_filter; @@ -559,12 +560,11 @@ impl Iterator for ParquetRecordBatchReader { match self.array_reader.consume_batch() { Err(error) => Some(Err(error.into())), Ok(array) => { - let struct_array = - array.as_any().downcast_ref::<StructArray>().ok_or_else(|| { - ArrowError::ParquetError( - "Struct array reader should return struct array".to_string(), - ) - }); + let struct_array = array.as_struct_opt().ok_or_else(|| { + ArrowError::ParquetError( + "Struct array reader should return struct array".to_string(), + ) + }); match struct_array { Err(err) => Some(Err(err)),