From e52f2d8bca5dbc011317f8ec1dd184c636034438 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 26 Dec 2021 18:12:52 +0000 Subject: [PATCH 1/6] Moved dict_id to IPC-specific IO --- examples/extension.rs | 4 +- examples/ipc_file_write.rs | 4 +- examples/json_read.rs | 2 +- src/columns.rs | 83 ++ src/datatypes/field.rs | 34 - src/io/flight/mod.rs | 45 +- src/io/ipc/convert.rs | 915 ------------------ src/io/ipc/mod.rs | 18 +- src/io/ipc/read/array/dictionary.rs | 17 +- src/io/ipc/read/array/fixed_size_list.rs | 11 +- src/io/ipc/read/array/list.rs | 11 +- src/io/ipc/read/array/map.rs | 11 +- src/io/ipc/read/array/struct_.rs | 14 +- src/io/ipc/read/array/union.rs | 14 +- src/io/ipc/read/common.rs | 111 +-- src/io/ipc/read/deserialize.rs | 16 +- src/io/ipc/read/mod.rs | 9 + src/io/ipc/read/reader.rs | 55 +- src/io/ipc/read/schema.rs | 334 +++++++ src/io/ipc/read/stream.rs | 50 +- src/io/ipc/write/common.rs | 96 +- src/io/ipc/write/mod.rs | 46 + src/io/ipc/write/schema.rs | 671 ++++++++++++- src/io/ipc/write/stream.rs | 57 +- src/io/ipc/write/stream_async.rs | 51 +- src/io/ipc/write/writer.rs | 69 +- src/io/json_integration/mod.rs | 82 +- .../{read.rs => read/array.rs} | 70 +- src/io/json_integration/read/mod.rs | 4 + src/io/json_integration/{ => read}/schema.rs | 410 +++----- .../{write.rs => write/array.rs} | 3 +- src/io/json_integration/write/mod.rs | 4 + src/io/json_integration/write/schema.rs | 173 ++++ src/io/parquet/read/schema/metadata.rs | 2 +- src/io/parquet/write/schema.rs | 3 +- src/lib.rs | 1 + src/record_batch.rs | 6 + tests/it/io/ipc/common.rs | 33 +- tests/it/io/ipc/read/file.rs | 2 +- tests/it/io/ipc/read/stream.rs | 5 +- tests/it/io/ipc/write/file.rs | 70 +- tests/it/io/ipc/write/stream.rs | 39 +- tests/it/io/ipc/write_async.rs | 56 +- tests/it/io/parquet/mod.rs | 2 +- 44 files changed, 1957 insertions(+), 1756 deletions(-) create mode 100644 src/columns.rs delete mode 100644 src/io/ipc/convert.rs create mode 100644 src/io/ipc/read/schema.rs rename src/io/json_integration/{read.rs => read/array.rs} (86%) create mode 100644 src/io/json_integration/read/mod.rs rename src/io/json_integration/{ => read}/schema.rs (51%) rename src/io/json_integration/{write.rs => write/array.rs} (93%) create mode 100644 src/io/json_integration/write/mod.rs create mode 100644 src/io/json_integration/write/schema.rs diff --git a/examples/extension.rs b/examples/extension.rs index 53b0ed2ad55..a9659a821ae 100644 --- a/examples/extension.rs +++ b/examples/extension.rs @@ -38,11 +38,11 @@ fn write_ipc(writer: W, array: impl Array + 'static) -> Result< let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]); let options = write::WriteOptions { compression: None }; - let mut writer = write::FileWriter::try_new(writer, &schema, options)?; + let mut writer = write::FileWriter::try_new(writer, &schema, None, options)?; let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)])?; - writer.write(&batch)?; + writer.write(&batch, None)?; Ok(writer.into_inner()) } diff --git a/examples/ipc_file_write.rs b/examples/ipc_file_write.rs index 5ae5b26ba84..757622b8ced 100644 --- a/examples/ipc_file_write.rs +++ b/examples/ipc_file_write.rs @@ -11,10 +11,10 @@ fn write_batches(path: &str, schema: &Schema, batches: &[RecordBatch]) -> Result let file = File::create(path)?; let options = write::WriteOptions { compression: None }; - let mut writer = write::FileWriter::try_new(file, schema, options)?; + let mut writer = write::FileWriter::try_new(file, schema, None, 
options)?; for batch in batches { - writer.write(batch)? + writer.write(batch, None)? } writer.finish() } diff --git a/examples/json_read.rs b/examples/json_read.rs index ddc4ba5f9be..9c975f2067a 100644 --- a/examples/json_read.rs +++ b/examples/json_read.rs @@ -31,7 +31,7 @@ fn read_path(path: &str, projection: Option>) -> Result { // deserialize `rows` into a `RecordBatch`. This is CPU-intensive, has no IO, // and can be performed on a different thread pool via a channel. - read::deserialize(&rows, fields) + read::deserialize(rows, fields) } fn main() -> Result<()> { diff --git a/src/columns.rs b/src/columns.rs new file mode 100644 index 00000000000..653f3448f7c --- /dev/null +++ b/src/columns.rs @@ -0,0 +1,83 @@ +//! Contains [`Columns`], a container [`Array`] where all arrays have the +//! same length. +use std::sync::Arc; + +use crate::array::Array; +use crate::error::{ArrowError, Result}; +use crate::record_batch::RecordBatch; + +/// A vector of [`Array`] where every array has the same length. +#[derive(Debug, Clone, PartialEq)] +pub struct Columns> { + arrays: Vec, +} + +impl> Columns { + /// Creates a new [`Columns`]. + /// # Panic + /// Iff the arrays do not have the same length + pub fn new(arrays: Vec) -> Self { + Self::try_new(arrays).unwrap() + } + + /// Creates a new [`Columns`]. + /// # Error + /// Iff the arrays do not have the same length + pub fn try_new(arrays: Vec) -> Result { + if !arrays.is_empty() { + let len = arrays.first().unwrap().as_ref().len(); + if arrays + .iter() + .map(|array| array.as_ref()) + .any(|array| array.len() != len) + { + return Err(ArrowError::InvalidArgumentError( + "Columns require all its arrays to have an equal number of rows".to_string(), + )); + } + } + Ok(Self { arrays }) + } + + /// returns the [`Array`]s in [`Columns`]. + pub fn arrays(&self) -> &[A] { + &self.arrays + } + + /// returns the length (number of rows) + pub fn len(&self) -> usize { + self.arrays + .first() + .map(|x| x.as_ref().len()) + .unwrap_or_default() + } + + /// Consumes [`Columns`] into its underlying arrays. + /// The arrays are guaranteed to have the same length + pub fn into_arrays(self) -> Vec { + self.arrays + } +} + +impl> From> for Vec { + fn from(c: Columns) -> Self { + c.into_arrays() + } +} + +impl> std::ops::Deref for Columns { + type Target = [A]; + + #[inline] + fn deref(&self) -> &[A] { + self.arrays() + } +} + +impl From for Columns> { + fn from(batch: RecordBatch) -> Self { + Self { + arrays: batch.into_inner().0, + } + } +} diff --git a/src/datatypes/field.rs b/src/datatypes/field.rs index aa0d58f1b71..fc40864b1fd 100644 --- a/src/datatypes/field.rs +++ b/src/datatypes/field.rs @@ -31,8 +31,6 @@ pub struct Field { pub data_type: DataType, /// Whether its values can be null or not pub nullable: bool, - /// The dictionary id of this field (currently un-used) - pub dict_id: i64, /// A map of key-value pairs containing additional custom meta data. 
pub metadata: Option>, } @@ -62,23 +60,6 @@ impl Field { name: name.into(), data_type, nullable, - dict_id: 0, - metadata: None, - } - } - - /// Creates a new field - pub fn new_dict>( - name: T, - data_type: DataType, - nullable: bool, - dict_id: i64, - ) -> Self { - Field { - name: name.into(), - data_type, - nullable, - dict_id, metadata: None, } } @@ -90,7 +71,6 @@ impl Field { name: self.name, data_type: self.data_type, nullable: self.nullable, - dict_id: self.dict_id, metadata: Some(metadata), } } @@ -131,15 +111,6 @@ impl Field { self.nullable } - /// Returns the dictionary ID, if this is a dictionary type. - #[inline] - pub const fn dict_id(&self) -> Option { - match self.data_type { - DataType::Dictionary(_, _, _) => Some(self.dict_id), - _ => None, - } - } - /// Merge field into self if it is compatible. Struct will be merged recursively. /// NOTE: `self` may be updated to unexpected state in case of merge failure. /// @@ -175,11 +146,6 @@ impl Field { } _ => {} } - if from.dict_id != self.dict_id { - return Err(ArrowError::InvalidArgumentError( - "Fail to merge schema Field due to conflicting dict_id".to_string(), - )); - } match &mut self.data_type { DataType::Struct(nested_fields) => match &from.data_type { DataType::Struct(from_nested_fields) => { diff --git a/src/io/flight/mod.rs b/src/io/flight/mod.rs index 8ff4bcda5c3..70720a180b0 100644 --- a/src/io/flight/mod.rs +++ b/src/io/flight/mod.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::convert::TryFrom; use std::sync::Arc; @@ -6,26 +5,28 @@ use arrow_format::flight::data::{FlightData, SchemaResult}; use arrow_format::ipc; use crate::{ - array::*, datatypes::*, error::{ArrowError, Result}, - io::ipc::fb_to_schema, - io::ipc::read::read_record_batch, + io::ipc::read, io::ipc::write, - io::ipc::write::common::{encoded_batch, DictionaryTracker, EncodedData, WriteOptions}, + io::ipc::write::common::{encode_columns, DictionaryTracker, EncodedData, WriteOptions}, record_batch::RecordBatch, }; +use super::ipc::{IpcField, IpcSchema}; + /// Serializes a [`RecordBatch`] to a vector of [`FlightData`] representing the serialized dictionaries /// and a [`FlightData`] representing the batch. pub fn serialize_batch( batch: &RecordBatch, + fields: &[IpcField], options: &WriteOptions, ) -> (Vec, FlightData) { let mut dictionary_tracker = DictionaryTracker::new(false); + let columns = batch.clone().into(); let (encoded_dictionaries, encoded_batch) = - encoded_batch(batch, &mut dictionary_tracker, options) + encode_columns(&columns, fields, &mut dictionary_tracker, options) .expect("DictionaryTracker configured above to not error on replacement"); let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); @@ -45,38 +46,38 @@ impl From for FlightData { } /// Serializes a [`Schema`] to [`SchemaResult`]. -pub fn serialize_schema_to_result(schema: &Schema) -> SchemaResult { +pub fn serialize_schema_to_result(schema: &Schema, ipc_fields: &[IpcField]) -> SchemaResult { SchemaResult { - schema: schema_as_flatbuffer(schema), + schema: schema_as_flatbuffer(schema, ipc_fields), } } /// Serializes a [`Schema`] to [`FlightData`]. 
-pub fn serialize_schema(schema: &Schema) -> FlightData { - let data_header = schema_as_flatbuffer(schema); +pub fn serialize_schema(schema: &Schema, ipc_fields: &[IpcField]) -> FlightData { + let data_header = schema_as_flatbuffer(schema, ipc_fields); FlightData { data_header, ..Default::default() } } -/// Convert a [`Schema`] to bytes in the format expected in [`arrow_format::flight::FlightInfo`]. -pub fn serialize_schema_to_info(schema: &Schema) -> Result> { - let encoded_data = schema_as_encoded_data(schema); +/// Convert a [`Schema`] to bytes in the format expected in [`arrow_format::flight::data::FlightInfo`]. +pub fn serialize_schema_to_info(schema: &Schema, ipc_fields: &[IpcField]) -> Result> { + let encoded_data = schema_as_encoded_data(schema, ipc_fields); let mut schema = vec![]; write::common_sync::write_message(&mut schema, encoded_data)?; Ok(schema) } -fn schema_as_flatbuffer(schema: &Schema) -> Vec { - let encoded_data = schema_as_encoded_data(schema); +fn schema_as_flatbuffer(schema: &Schema, ipc_fields: &[IpcField]) -> Vec { + let encoded_data = schema_as_encoded_data(schema, ipc_fields); encoded_data.ipc_message } -fn schema_as_encoded_data(schema: &Schema) -> EncodedData { +fn schema_as_encoded_data(schema: &Schema, ipc_fields: &[IpcField]) -> EncodedData { EncodedData { - ipc_message: write::schema_to_bytes(schema), + ipc_message: write::schema_to_bytes(schema, ipc_fields), arrow_data: vec![], } } @@ -84,7 +85,7 @@ fn schema_as_encoded_data(schema: &Schema) -> EncodedData { /// Deserialize an IPC message into a schema fn schema_from_bytes(bytes: &[u8]) -> Result { if let Ok(ipc) = ipc::Message::root_as_message(bytes) { - if let Some((schema, _)) = ipc.header_as_schema().map(fb_to_schema) { + if let Some((schema, _)) = ipc.header_as_schema().map(read::fb_to_schema) { Ok(schema) } else { Err(ArrowError::OutOfSpec( @@ -126,8 +127,8 @@ impl TryFrom<&SchemaResult> for Schema { pub fn deserialize_batch( data: &FlightData, schema: Arc, - is_little_endian: bool, - dictionaries: &HashMap>, + ipc_schema: &IpcSchema, + dictionaries: &read::Dictionaries, ) -> Result { // check that the data_header is a record batch message let message = ipc::Message::root_as_message(&data.data_header[..]).map_err(|err| { @@ -144,11 +145,11 @@ pub fn deserialize_batch( ) }) .map(|batch| { - read_record_batch( + read::read_record_batch( batch, schema.clone(), + ipc_schema, None, - is_little_endian, dictionaries, ipc::Schema::MetadataVersion::V5, &mut reader, diff --git a/src/io/ipc/convert.rs b/src/io/ipc/convert.rs deleted file mode 100644 index 2c7e0974d1b..00000000000 --- a/src/io/ipc/convert.rs +++ /dev/null @@ -1,915 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Utilities for converting between IPC types and native Arrow types - -use arrow_format::ipc::flatbuffers::{ - FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, WIPOffset, -}; -use std::collections::{BTreeMap, HashMap}; -mod ipc { - pub use arrow_format::ipc::File::*; - pub use arrow_format::ipc::Message::*; - pub use arrow_format::ipc::Schema::*; -} - -use crate::datatypes::{ - get_extension, DataType, Extension, Field, IntegerType, IntervalUnit, Metadata, Schema, - TimeUnit, UnionMode, -}; -use crate::io::ipc::endianess::is_native_little_endian; - -pub fn schema_to_fb_offset<'a>( - fbb: &mut FlatBufferBuilder<'a>, - schema: &Schema, -) -> WIPOffset> { - let mut fields = vec![]; - for field in schema.fields() { - let fb_field = build_field(fbb, field); - fields.push(fb_field); - } - - let mut custom_metadata = vec![]; - for (k, v) in schema.metadata() { - let fb_key_name = fbb.create_string(k.as_str()); - let fb_val_name = fbb.create_string(v.as_str()); - - let mut kv_builder = ipc::KeyValueBuilder::new(fbb); - kv_builder.add_key(fb_key_name); - kv_builder.add_value(fb_val_name); - custom_metadata.push(kv_builder.finish()); - } - - let fb_field_list = fbb.create_vector(&fields); - let fb_metadata_list = fbb.create_vector(&custom_metadata); - - let mut builder = ipc::SchemaBuilder::new(fbb); - builder.add_fields(fb_field_list); - builder.add_custom_metadata(fb_metadata_list); - builder.add_endianness(if is_native_little_endian() { - ipc::Endianness::Little - } else { - ipc::Endianness::Big - }); - builder.finish() -} - -fn read_metadata(field: &ipc::Field) -> Metadata { - if let Some(list) = field.custom_metadata() { - let mut metadata_map = BTreeMap::default(); - for kv in list { - if let (Some(k), Some(v)) = (kv.key(), kv.value()) { - metadata_map.insert(k.to_string(), v.to_string()); - } - } - Some(metadata_map) - } else { - None - } -} - -/// Convert an IPC Field to Arrow Field -impl<'a> From> for Field { - fn from(field: ipc::Field) -> Field { - let metadata = read_metadata(&field); - - let extension = get_extension(&metadata); - - let data_type = get_data_type(field, extension, true); - - let mut arrow_field = if let Some(dictionary) = field.dictionary() { - Field::new_dict( - field.name().unwrap(), - data_type, - field.nullable(), - dictionary.id(), - ) - } else { - Field::new(field.name().unwrap(), data_type, field.nullable()) - }; - - arrow_field.set_metadata(metadata); - arrow_field - } -} - -/// Deserialize a Schema table from IPC format to Schema data type -pub fn fb_to_schema(fb: ipc::Schema) -> (Schema, bool) { - let mut fields: Vec = vec![]; - let c_fields = fb.fields().unwrap(); - let len = c_fields.len(); - for i in 0..len { - let c_field: ipc::Field = c_fields.get(i); - fields.push(c_field.into()); - } - - let is_little_endian = fb.endianness().variant_name().unwrap_or("Little") == "Little"; - - let mut metadata: HashMap = HashMap::default(); - if let Some(md_fields) = fb.custom_metadata() { - let len = md_fields.len(); - for i in 0..len { - let kv = md_fields.get(i); - let k_str = kv.key(); - let v_str = kv.value(); - if let Some(k) = k_str { - if let Some(v) = v_str { - metadata.insert(k.to_string(), v.to_string()); - } - } - } - } - (Schema::new_from(fields, metadata), is_little_endian) -} - -/// Get the Arrow data type from the flatbuffer Field table -fn get_data_type(field: ipc::Field, extension: Extension, may_be_dictionary: bool) -> DataType { - if let Some(dictionary) = field.dictionary() { - if may_be_dictionary { - let int = 
dictionary.indexType().unwrap(); - let index_type = match (int.bitWidth(), int.is_signed()) { - (8, true) => IntegerType::Int8, - (8, false) => IntegerType::UInt8, - (16, true) => IntegerType::Int16, - (16, false) => IntegerType::UInt16, - (32, true) => IntegerType::Int32, - (32, false) => IntegerType::UInt32, - (64, true) => IntegerType::Int64, - (64, false) => IntegerType::UInt64, - _ => panic!("Unexpected bitwidth and signed"), - }; - return DataType::Dictionary( - index_type, - Box::new(get_data_type(field, extension, false)), - dictionary.isOrdered(), - ); - } - } - - if let Some(extension) = extension { - let (name, metadata) = extension; - let data_type = get_data_type(field, None, false); - return DataType::Extension(name, Box::new(data_type), metadata); - } - - match field.type_type() { - ipc::Type::Null => DataType::Null, - ipc::Type::Bool => DataType::Boolean, - ipc::Type::Int => { - let int = field.type_as_int().unwrap(); - match (int.bitWidth(), int.is_signed()) { - (8, true) => DataType::Int8, - (8, false) => DataType::UInt8, - (16, true) => DataType::Int16, - (16, false) => DataType::UInt16, - (32, true) => DataType::Int32, - (32, false) => DataType::UInt32, - (64, true) => DataType::Int64, - (64, false) => DataType::UInt64, - z => panic!( - "Int type with bit width of {} and signed of {} not supported", - z.0, z.1 - ), - } - } - ipc::Type::Binary => DataType::Binary, - ipc::Type::LargeBinary => DataType::LargeBinary, - ipc::Type::Utf8 => DataType::Utf8, - ipc::Type::LargeUtf8 => DataType::LargeUtf8, - ipc::Type::FixedSizeBinary => { - let fsb = field.type_as_fixed_size_binary().unwrap(); - DataType::FixedSizeBinary(fsb.byteWidth() as usize) - } - ipc::Type::FloatingPoint => { - let float = field.type_as_floating_point().unwrap(); - match float.precision() { - ipc::Precision::HALF => DataType::Float16, - ipc::Precision::SINGLE => DataType::Float32, - ipc::Precision::DOUBLE => DataType::Float64, - z => panic!("FloatingPoint type with precision of {:?} not supported", z), - } - } - ipc::Type::Date => { - let date = field.type_as_date().unwrap(); - match date.unit() { - ipc::DateUnit::DAY => DataType::Date32, - ipc::DateUnit::MILLISECOND => DataType::Date64, - z => panic!("Date type with unit of {:?} not supported", z), - } - } - ipc::Type::Time => { - let time = field.type_as_time().unwrap(); - match (time.bitWidth(), time.unit()) { - (32, ipc::TimeUnit::SECOND) => DataType::Time32(TimeUnit::Second), - (32, ipc::TimeUnit::MILLISECOND) => DataType::Time32(TimeUnit::Millisecond), - (64, ipc::TimeUnit::MICROSECOND) => DataType::Time64(TimeUnit::Microsecond), - (64, ipc::TimeUnit::NANOSECOND) => DataType::Time64(TimeUnit::Nanosecond), - z => panic!( - "Time type with bit width of {} and unit of {:?} not supported", - z.0, z.1 - ), - } - } - ipc::Type::Timestamp => { - let timestamp = field.type_as_timestamp().unwrap(); - let timezone: Option = timestamp.timezone().map(|tz| tz.to_string()); - match timestamp.unit() { - ipc::TimeUnit::SECOND => DataType::Timestamp(TimeUnit::Second, timezone), - ipc::TimeUnit::MILLISECOND => DataType::Timestamp(TimeUnit::Millisecond, timezone), - ipc::TimeUnit::MICROSECOND => DataType::Timestamp(TimeUnit::Microsecond, timezone), - ipc::TimeUnit::NANOSECOND => DataType::Timestamp(TimeUnit::Nanosecond, timezone), - z => panic!("Timestamp type with unit of {:?} not supported", z), - } - } - ipc::Type::Interval => { - let interval = field.type_as_interval().unwrap(); - match interval.unit() { - ipc::IntervalUnit::YEAR_MONTH => 
DataType::Interval(IntervalUnit::YearMonth), - ipc::IntervalUnit::DAY_TIME => DataType::Interval(IntervalUnit::DayTime), - ipc::IntervalUnit::MONTH_DAY_NANO => DataType::Interval(IntervalUnit::MonthDayNano), - z => panic!("Interval type with unit of {:?} unsupported", z), - } - } - ipc::Type::Duration => { - let duration = field.type_as_duration().unwrap(); - match duration.unit() { - ipc::TimeUnit::SECOND => DataType::Duration(TimeUnit::Second), - ipc::TimeUnit::MILLISECOND => DataType::Duration(TimeUnit::Millisecond), - ipc::TimeUnit::MICROSECOND => DataType::Duration(TimeUnit::Microsecond), - ipc::TimeUnit::NANOSECOND => DataType::Duration(TimeUnit::Nanosecond), - z => panic!("Duration type with unit of {:?} unsupported", z), - } - } - ipc::Type::List => { - let children = field.children().unwrap(); - if children.len() != 1 { - panic!("expect a list to have one child") - } - DataType::List(Box::new(children.get(0).into())) - } - ipc::Type::LargeList => { - let children = field.children().unwrap(); - if children.len() != 1 { - panic!("expect a large list to have one child") - } - DataType::LargeList(Box::new(children.get(0).into())) - } - ipc::Type::FixedSizeList => { - let children = field.children().unwrap(); - if children.len() != 1 { - panic!("expect a list to have one child") - } - let fsl = field.type_as_fixed_size_list().unwrap(); - DataType::FixedSizeList(Box::new(children.get(0).into()), fsl.listSize() as usize) - } - ipc::Type::Struct_ => { - let mut fields = vec![]; - if let Some(children) = field.children() { - for i in 0..children.len() { - fields.push(children.get(i).into()); - } - }; - - DataType::Struct(fields) - } - ipc::Type::Decimal => { - let fsb = field.type_as_decimal().unwrap(); - DataType::Decimal(fsb.precision() as usize, fsb.scale() as usize) - } - ipc::Type::Union => { - let type_ = field.type_as_union().unwrap(); - - let mode = UnionMode::sparse(type_.mode() == ipc::UnionMode::Sparse); - - let ids = type_.typeIds().map(|x| x.iter().collect()); - - let fields = if let Some(children) = field.children() { - (0..children.len()) - .map(|i| children.get(i).into()) - .collect() - } else { - vec![] - }; - DataType::Union(fields, ids, mode) - } - ipc::Type::Map => { - let map = field.type_as_map().unwrap(); - let children = field.children().unwrap(); - if children.len() != 1 { - panic!("expect a map to have one child") - } - DataType::Map(Box::new(children.get(0).into()), map.keysSorted()) - } - t => unimplemented!("Type {:?} not supported", t), - } -} - -pub(crate) struct FbFieldType<'b> { - pub(crate) type_type: ipc::Type, - pub(crate) type_: WIPOffset, - pub(crate) children: Option>>>>, -} - -fn write_metadata<'a>( - fbb: &mut FlatBufferBuilder<'a>, - metadata: &BTreeMap, - kv_vec: &mut Vec>>, -) { - for (k, v) in metadata { - if k != "ARROW:extension:name" && k != "ARROW:extension:metadata" { - let kv_args = ipc::KeyValueArgs { - key: Some(fbb.create_string(k.as_str())), - value: Some(fbb.create_string(v.as_str())), - }; - kv_vec.push(ipc::KeyValue::create(fbb, &kv_args)); - } - } -} - -fn write_extension<'a>( - fbb: &mut FlatBufferBuilder<'a>, - name: &str, - metadata: &Option, - kv_vec: &mut Vec>>, -) { - // metadata - if let Some(metadata) = metadata { - let kv_args = ipc::KeyValueArgs { - key: Some(fbb.create_string("ARROW:extension:metadata")), - value: Some(fbb.create_string(metadata.as_str())), - }; - kv_vec.push(ipc::KeyValue::create(fbb, &kv_args)); - } - - // name - let kv_args = ipc::KeyValueArgs { - key: Some(fbb.create_string("ARROW:extension:name")), 
- value: Some(fbb.create_string(name)), - }; - kv_vec.push(ipc::KeyValue::create(fbb, &kv_args)); -} - -/// Create an IPC Field from an Arrow Field -pub(crate) fn build_field<'a>( - fbb: &mut FlatBufferBuilder<'a>, - field: &Field, -) -> WIPOffset> { - // custom metadata. - let mut kv_vec = vec![]; - if let DataType::Extension(name, _, metadata) = field.data_type() { - write_extension(fbb, name, metadata, &mut kv_vec); - } - - let fb_field_name = fbb.create_string(field.name().as_str()); - let field_type = get_fb_field_type(field.data_type(), field.is_nullable(), fbb); - - let fb_dictionary = - if let DataType::Dictionary(index_type, inner, is_ordered) = field.data_type() { - if let DataType::Extension(name, _, metadata) = inner.as_ref() { - write_extension(fbb, name, metadata, &mut kv_vec); - } - Some(get_fb_dictionary( - index_type, - field - .dict_id() - .expect("All Dictionary types have `dict_id`"), - *is_ordered, - fbb, - )) - } else { - None - }; - - if let Some(metadata) = field.metadata() { - if !metadata.is_empty() { - write_metadata(fbb, metadata, &mut kv_vec); - } - }; - let fb_metadata = if !kv_vec.is_empty() { - Some(fbb.create_vector(&kv_vec)) - } else { - None - }; - - let mut field_builder = ipc::FieldBuilder::new(fbb); - field_builder.add_name(fb_field_name); - if let Some(dictionary) = fb_dictionary { - field_builder.add_dictionary(dictionary) - } - field_builder.add_type_type(field_type.type_type); - field_builder.add_nullable(field.is_nullable()); - match field_type.children { - None => {} - Some(children) => field_builder.add_children(children), - }; - field_builder.add_type_(field_type.type_); - - if let Some(fb_metadata) = fb_metadata { - field_builder.add_custom_metadata(fb_metadata); - } - - field_builder.finish() -} - -fn type_to_field_type(data_type: &DataType) -> ipc::Type { - use DataType::*; - match data_type { - Null => ipc::Type::Null, - Boolean => ipc::Type::Bool, - UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 => ipc::Type::Int, - Float16 | Float32 | Float64 => ipc::Type::FloatingPoint, - Decimal(_, _) => ipc::Type::Decimal, - Binary => ipc::Type::Binary, - LargeBinary => ipc::Type::LargeBinary, - Utf8 => ipc::Type::Utf8, - LargeUtf8 => ipc::Type::LargeUtf8, - FixedSizeBinary(_) => ipc::Type::FixedSizeBinary, - Date32 | Date64 => ipc::Type::Date, - Duration(_) => ipc::Type::Duration, - Time32(_) | Time64(_) => ipc::Type::Time, - Timestamp(_, _) => ipc::Type::Timestamp, - Interval(_) => ipc::Type::Interval, - List(_) => ipc::Type::List, - LargeList(_) => ipc::Type::LargeList, - FixedSizeList(_, _) => ipc::Type::FixedSizeList, - Union(_, _, _) => ipc::Type::Union, - Map(_, _) => ipc::Type::Map, - Struct(_) => ipc::Type::Struct_, - Dictionary(_, v, _) => type_to_field_type(v), - Extension(_, v, _) => type_to_field_type(v), - } -} - -/// Get the IPC type of a data type -pub(crate) fn get_fb_field_type<'a>( - data_type: &DataType, - is_nullable: bool, - fbb: &mut FlatBufferBuilder<'a>, -) -> FbFieldType<'a> { - use DataType::*; - let type_type = type_to_field_type(data_type); - - // some IPC implementations expect an empty list for child data, instead of a null value. 
- // An empty field list is thus returned for primitive types - let empty_fields: Vec> = vec![]; - match data_type { - Null => FbFieldType { - type_type, - type_: ipc::NullBuilder::new(fbb).finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - }, - Boolean => FbFieldType { - type_type, - type_: ipc::BoolBuilder::new(fbb).finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - }, - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => { - let children = fbb.create_vector(&empty_fields[..]); - let mut builder = ipc::IntBuilder::new(fbb); - if matches!(data_type, UInt8 | UInt16 | UInt32 | UInt64) { - builder.add_is_signed(false); - } else { - builder.add_is_signed(true); - } - match data_type { - Int8 | UInt8 => builder.add_bitWidth(8), - Int16 | UInt16 => builder.add_bitWidth(16), - Int32 | UInt32 => builder.add_bitWidth(32), - Int64 | UInt64 => builder.add_bitWidth(64), - _ => {} - }; - FbFieldType { - type_type, - type_: builder.finish().as_union_value(), - children: Some(children), - } - } - Float16 | Float32 | Float64 => { - let children = fbb.create_vector(&empty_fields[..]); - let mut builder = ipc::FloatingPointBuilder::new(fbb); - match data_type { - Float16 => builder.add_precision(ipc::Precision::HALF), - Float32 => builder.add_precision(ipc::Precision::SINGLE), - Float64 => builder.add_precision(ipc::Precision::DOUBLE), - _ => {} - }; - FbFieldType { - type_type, - type_: builder.finish().as_union_value(), - children: Some(children), - } - } - Binary => FbFieldType { - type_type, - type_: ipc::BinaryBuilder::new(fbb).finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - }, - LargeBinary => FbFieldType { - type_type, - type_: ipc::LargeBinaryBuilder::new(fbb).finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - }, - Utf8 => FbFieldType { - type_type, - type_: ipc::Utf8Builder::new(fbb).finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - }, - LargeUtf8 => FbFieldType { - type_type, - type_: ipc::LargeUtf8Builder::new(fbb).finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - }, - FixedSizeBinary(len) => { - let mut builder = ipc::FixedSizeBinaryBuilder::new(fbb); - builder.add_byteWidth(*len as i32); - FbFieldType { - type_type, - type_: builder.finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - } - } - Date32 => { - let mut builder = ipc::DateBuilder::new(fbb); - builder.add_unit(ipc::DateUnit::DAY); - FbFieldType { - type_type, - type_: builder.finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - } - } - Date64 => { - let mut builder = ipc::DateBuilder::new(fbb); - builder.add_unit(ipc::DateUnit::MILLISECOND); - FbFieldType { - type_type, - type_: builder.finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - } - } - Time32(unit) | Time64(unit) => { - let mut builder = ipc::TimeBuilder::new(fbb); - match unit { - TimeUnit::Second => { - builder.add_bitWidth(32); - builder.add_unit(ipc::TimeUnit::SECOND); - } - TimeUnit::Millisecond => { - builder.add_bitWidth(32); - builder.add_unit(ipc::TimeUnit::MILLISECOND); - } - TimeUnit::Microsecond => { - builder.add_bitWidth(64); - builder.add_unit(ipc::TimeUnit::MICROSECOND); - } - TimeUnit::Nanosecond => { - builder.add_bitWidth(64); - builder.add_unit(ipc::TimeUnit::NANOSECOND); - } - } - FbFieldType { - type_type, - type_: 
builder.finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - } - } - Timestamp(unit, tz) => { - let tz = tz.clone().unwrap_or_else(String::new); - let tz_str = fbb.create_string(tz.as_str()); - let mut builder = ipc::TimestampBuilder::new(fbb); - let time_unit = match unit { - TimeUnit::Second => ipc::TimeUnit::SECOND, - TimeUnit::Millisecond => ipc::TimeUnit::MILLISECOND, - TimeUnit::Microsecond => ipc::TimeUnit::MICROSECOND, - TimeUnit::Nanosecond => ipc::TimeUnit::NANOSECOND, - }; - builder.add_unit(time_unit); - if !tz.is_empty() { - builder.add_timezone(tz_str); - } - FbFieldType { - type_type, - type_: builder.finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - } - } - Interval(unit) => { - let mut builder = ipc::IntervalBuilder::new(fbb); - let interval_unit = match unit { - IntervalUnit::YearMonth => ipc::IntervalUnit::YEAR_MONTH, - IntervalUnit::DayTime => ipc::IntervalUnit::DAY_TIME, - IntervalUnit::MonthDayNano => ipc::IntervalUnit::MONTH_DAY_NANO, - }; - builder.add_unit(interval_unit); - FbFieldType { - type_type, - type_: builder.finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - } - } - Duration(unit) => { - let mut builder = ipc::DurationBuilder::new(fbb); - let time_unit = match unit { - TimeUnit::Second => ipc::TimeUnit::SECOND, - TimeUnit::Millisecond => ipc::TimeUnit::MILLISECOND, - TimeUnit::Microsecond => ipc::TimeUnit::MICROSECOND, - TimeUnit::Nanosecond => ipc::TimeUnit::NANOSECOND, - }; - builder.add_unit(time_unit); - FbFieldType { - type_type, - type_: builder.finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - } - } - List(ref list_type) => { - let child = build_field(fbb, list_type); - FbFieldType { - type_type, - type_: ipc::ListBuilder::new(fbb).finish().as_union_value(), - children: Some(fbb.create_vector(&[child])), - } - } - LargeList(ref list_type) => { - let child = build_field(fbb, list_type); - FbFieldType { - type_type, - type_: ipc::LargeListBuilder::new(fbb).finish().as_union_value(), - children: Some(fbb.create_vector(&[child])), - } - } - FixedSizeList(ref list_type, len) => { - let child = build_field(fbb, list_type); - let mut builder = ipc::FixedSizeListBuilder::new(fbb); - builder.add_listSize(*len as i32); - FbFieldType { - type_type, - type_: builder.finish().as_union_value(), - children: Some(fbb.create_vector(&[child])), - } - } - Struct(fields) => { - let children: Vec<_> = fields.iter().map(|field| build_field(fbb, field)).collect(); - - FbFieldType { - type_type, - type_: ipc::Struct_Builder::new(fbb).finish().as_union_value(), - children: Some(fbb.create_vector(&children[..])), - } - } - Dictionary(_, value_type, _) => { - // In this library, the dictionary "type" is a logical construct. 
Here we - // pass through to the value type, as we've already captured the index - // type in the DictionaryEncoding metadata in the parent field - get_fb_field_type(value_type, is_nullable, fbb) - } - Extension(_, value_type, _) => get_fb_field_type(value_type, is_nullable, fbb), - Decimal(precision, scale) => { - let mut builder = ipc::DecimalBuilder::new(fbb); - builder.add_precision(*precision as i32); - builder.add_scale(*scale as i32); - builder.add_bitWidth(128); - FbFieldType { - type_type, - type_: builder.finish().as_union_value(), - children: Some(fbb.create_vector(&empty_fields[..])), - } - } - Union(fields, ids, mode) => { - let children: Vec<_> = fields.iter().map(|field| build_field(fbb, field)).collect(); - - let ids = ids.as_ref().map(|ids| fbb.create_vector(ids)); - - let mut builder = ipc::UnionBuilder::new(fbb); - builder.add_mode(if mode.is_sparse() { - ipc::UnionMode::Sparse - } else { - ipc::UnionMode::Dense - }); - - if let Some(ids) = ids { - builder.add_typeIds(ids); - } - FbFieldType { - type_type, - type_: builder.finish().as_union_value(), - children: Some(fbb.create_vector(&children)), - } - } - Map(field, keys_sorted) => { - let child = build_field(fbb, field); - let mut field_type = ipc::MapBuilder::new(fbb); - field_type.add_keysSorted(*keys_sorted); - FbFieldType { - type_type: ipc::Type::Map, - type_: field_type.finish().as_union_value(), - children: Some(fbb.create_vector(&[child])), - } - } - } -} - -/// Create an IPC dictionary encoding -pub(crate) fn get_fb_dictionary<'a>( - index_type: &IntegerType, - dict_id: i64, - dict_is_ordered: bool, - fbb: &mut FlatBufferBuilder<'a>, -) -> WIPOffset> { - use IntegerType::*; - // We assume that the dictionary index type (as an integer) has already been - // validated elsewhere, and can safely assume we are dealing with integers - let mut index_builder = ipc::IntBuilder::new(fbb); - - match index_type { - Int8 | Int16 | Int32 | Int64 => index_builder.add_is_signed(true), - UInt8 | UInt16 | UInt32 | UInt64 => index_builder.add_is_signed(false), - } - - match index_type { - Int8 | UInt8 => index_builder.add_bitWidth(8), - Int16 | UInt16 => index_builder.add_bitWidth(16), - Int32 | UInt32 => index_builder.add_bitWidth(32), - Int64 | UInt64 => index_builder.add_bitWidth(64), - } - - let index_builder = index_builder.finish(); - - let mut builder = ipc::DictionaryEncodingBuilder::new(fbb); - builder.add_id(dict_id); - builder.add_indexType(index_builder); - builder.add_isOrdered(dict_is_ordered); - - builder.finish() -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::datatypes::{DataType, Field, Schema}; - - /// Serialize a schema in IPC format - fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder { - let mut fbb = FlatBufferBuilder::new(); - - let root = schema_to_fb_offset(&mut fbb, schema); - - fbb.finish(root, None); - - fbb - } - - #[test] - fn convert_schema_round_trip() { - let md: HashMap = [("Key".to_string(), "value".to_string())] - .iter() - .cloned() - .collect(); - let field_md: BTreeMap = [("k".to_string(), "v".to_string())] - .iter() - .cloned() - .collect(); - let schema = Schema::new_from( - vec![ - { - let mut f = Field::new("uint8", DataType::UInt8, false); - f.set_metadata(Some(field_md)); - f - }, - Field::new("uint16", DataType::UInt16, true), - Field::new("uint32", DataType::UInt32, false), - Field::new("uint64", DataType::UInt64, true), - Field::new("int8", DataType::Int8, true), - Field::new("int16", DataType::Int16, false), - Field::new("int32", DataType::Int32, true), - 
Field::new("int64", DataType::Int64, false), - Field::new("float16", DataType::Float16, true), - Field::new("float32", DataType::Float32, false), - Field::new("float64", DataType::Float64, true), - Field::new("null", DataType::Null, false), - Field::new("bool", DataType::Boolean, false), - Field::new("date32", DataType::Date32, false), - Field::new("date64", DataType::Date64, true), - Field::new("time32[s]", DataType::Time32(TimeUnit::Second), true), - Field::new("time32[ms]", DataType::Time32(TimeUnit::Millisecond), false), - Field::new("time64[us]", DataType::Time64(TimeUnit::Microsecond), false), - Field::new("time64[ns]", DataType::Time64(TimeUnit::Nanosecond), true), - Field::new( - "timestamp[s]", - DataType::Timestamp(TimeUnit::Second, None), - false, - ), - Field::new( - "timestamp[ms]", - DataType::Timestamp(TimeUnit::Millisecond, None), - true, - ), - Field::new( - "timestamp[us]", - DataType::Timestamp( - TimeUnit::Microsecond, - Some("Africa/Johannesburg".to_string()), - ), - false, - ), - Field::new( - "timestamp[ns]", - DataType::Timestamp(TimeUnit::Nanosecond, None), - true, - ), - Field::new( - "interval[ym]", - DataType::Interval(IntervalUnit::YearMonth), - true, - ), - Field::new( - "interval[dt]", - DataType::Interval(IntervalUnit::DayTime), - true, - ), - Field::new("utf8", DataType::Utf8, false), - Field::new("binary", DataType::Binary, false), - Field::new( - "list[u8]", - DataType::List(Box::new(Field::new("item", DataType::UInt8, false))), - true, - ), - Field::new( - "list[struct]", - DataType::List(Box::new(Field::new( - "struct", - DataType::Struct(vec![ - Field::new("float32", DataType::UInt8, false), - Field::new("int32", DataType::Int32, true), - Field::new("bool", DataType::Boolean, true), - ]), - true, - ))), - false, - ), - Field::new( - "struct]>]>", - DataType::Struct(vec![ - Field::new("int64", DataType::Int64, true), - Field::new( - "list[struct]>]", - DataType::List(Box::new(Field::new( - "struct", - DataType::Struct(vec![ - Field::new("date32", DataType::Date32, true), - Field::new( - "list[struct<>]", - DataType::List(Box::new(Field::new( - "struct", - DataType::Struct(vec![]), - false, - ))), - false, - ), - ]), - false, - ))), - false, - ), - ]), - false, - ), - Field::new("struct<>", DataType::Struct(vec![]), true), - Field::new_dict( - "dictionary", - DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8), true), - true, - 123, - ), - Field::new_dict( - "dictionary", - DataType::Dictionary(IntegerType::UInt8, Box::new(DataType::UInt32), true), - true, - 123, - ), - Field::new("decimal", DataType::Decimal(10, 6), false), - ], - md, - ); - - let fb = schema_to_fb(&schema); - - // read back fields - let ipc = ipc::root_as_schema(fb.finished_data()).unwrap(); - let (schema2, _) = fb_to_schema(ipc); - assert_eq!(schema, schema2); - } -} diff --git a/src/io/ipc/mod.rs b/src/io/ipc/mod.rs index 6450763cde4..bcc3e9259f7 100644 --- a/src/io/ipc/mod.rs +++ b/src/io/ipc/mod.rs @@ -79,12 +79,26 @@ //! [3](https://github.com/jorgecarleitao/arrow2/tree/main/examples/ipc_pyarrow)). mod compression; -mod convert; mod endianess; -pub use convert::fb_to_schema; pub mod read; pub mod write; const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; + +/// Struct containing `dictionary_id` and nested `IpcField`, allowing users +/// to specify the dictionary ids of the IPC fields when writing to IPC. 
+#[derive(Debug, Clone, PartialEq, Default)] +pub struct IpcField { + // optional children + pub fields: Vec, + // dictionary id + pub dictionary_id: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct IpcSchema { + pub fields: Vec, + pub is_little_endian: bool, +} diff --git a/src/io/ipc/read/array/dictionary.rs b/src/io/ipc/read/array/dictionary.rs index 6cd61adfc32..c7d1a04170d 100644 --- a/src/io/ipc/read/array/dictionary.rs +++ b/src/io/ipc/read/array/dictionary.rs @@ -1,24 +1,23 @@ -use std::collections::{HashMap, HashSet, VecDeque}; +use std::collections::{HashSet, VecDeque}; use std::convert::TryInto; use std::io::{Read, Seek}; -use std::sync::Arc; use arrow_format::ipc; -use crate::array::{Array, DictionaryArray, DictionaryKey}; -use crate::datatypes::Field; +use crate::array::{DictionaryArray, DictionaryKey}; use crate::error::{ArrowError, Result}; use super::super::deserialize::Node; +use super::super::Dictionaries; use super::{read_primitive, skip_primitive}; #[allow(clippy::too_many_arguments)] pub fn read_dictionary( field_nodes: &mut VecDeque, - field: &Field, + id: Option, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, - dictionaries: &HashMap>, + dictionaries: &Dictionaries, block_offset: u64, compression: Option, is_little_endian: bool, @@ -26,7 +25,11 @@ pub fn read_dictionary( where Vec: TryInto, { - let id = field.dict_id().unwrap() as usize; + let id = if let Some(id) = id { + id + } else { + return Err(ArrowError::OutOfSpec("Dictionary has no id.".to_string())); + }; let values = dictionaries .get(&id) .ok_or_else(|| { diff --git a/src/io/ipc/read/array/fixed_size_list.rs b/src/io/ipc/read/array/fixed_size_list.rs index 274b50fe490..b5db90fd752 100644 --- a/src/io/ipc/read/array/fixed_size_list.rs +++ b/src/io/ipc/read/array/fixed_size_list.rs @@ -1,23 +1,25 @@ -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::io::{Read, Seek}; -use std::sync::Arc; use arrow_format::ipc; -use crate::array::{Array, FixedSizeListArray}; +use crate::array::FixedSizeListArray; use crate::datatypes::DataType; use crate::error::Result; +use super::super::super::IpcField; use super::super::deserialize::{read, skip, Node}; use super::super::read_basic::*; +use super::super::Dictionaries; #[allow(clippy::too_many_arguments)] pub fn read_fixed_size_list( field_nodes: &mut VecDeque, data_type: DataType, + ipc_field: &IpcField, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, - dictionaries: &HashMap>, + dictionaries: &Dictionaries, block_offset: u64, is_little_endian: bool, compression: Option, @@ -39,6 +41,7 @@ pub fn read_fixed_size_list( let values = read( field_nodes, field, + &ipc_field.fields[0], buffers, reader, dictionaries, diff --git a/src/io/ipc/read/array/list.rs b/src/io/ipc/read/array/list.rs index d01f8a1ce2e..cf8ddd41bdf 100644 --- a/src/io/ipc/read/array/list.rs +++ b/src/io/ipc/read/array/list.rs @@ -1,25 +1,27 @@ -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::convert::TryInto; use std::io::{Read, Seek}; -use std::sync::Arc; use arrow_format::ipc; -use crate::array::{Array, ListArray, Offset}; +use crate::array::{ListArray, Offset}; use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::error::Result; +use super::super::super::IpcField; use super::super::deserialize::{read, skip, Node}; use super::super::read_basic::*; +use super::super::Dictionaries; #[allow(clippy::too_many_arguments)] pub fn read_list( field_nodes: &mut VecDeque, data_type: 
DataType, + ipc_field: &IpcField, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, - dictionaries: &HashMap>, + dictionaries: &Dictionaries, block_offset: u64, is_little_endian: bool, compression: Option, @@ -55,6 +57,7 @@ where let values = read( field_nodes, field, + &ipc_field.fields[0], buffers, reader, dictionaries, diff --git a/src/io/ipc/read/array/map.rs b/src/io/ipc/read/array/map.rs index c1cd0670bfc..e61887cc5ba 100644 --- a/src/io/ipc/read/array/map.rs +++ b/src/io/ipc/read/array/map.rs @@ -1,24 +1,26 @@ -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::io::{Read, Seek}; -use std::sync::Arc; use arrow_format::ipc; -use crate::array::{Array, MapArray}; +use crate::array::MapArray; use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::error::Result; +use super::super::super::IpcField; use super::super::deserialize::{read, skip, Node}; use super::super::read_basic::*; +use super::super::Dictionaries; #[allow(clippy::too_many_arguments)] pub fn read_map( field_nodes: &mut VecDeque, data_type: DataType, + ipc_field: &IpcField, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, - dictionaries: &HashMap>, + dictionaries: &Dictionaries, block_offset: u64, is_little_endian: bool, compression: Option, @@ -51,6 +53,7 @@ pub fn read_map( let field = read( field_nodes, field, + &ipc_field.fields[0], buffers, reader, dictionaries, diff --git a/src/io/ipc/read/array/struct_.rs b/src/io/ipc/read/array/struct_.rs index 775774291d3..30c47f654f2 100644 --- a/src/io/ipc/read/array/struct_.rs +++ b/src/io/ipc/read/array/struct_.rs @@ -1,23 +1,25 @@ -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::io::{Read, Seek}; -use std::sync::Arc; use arrow_format::ipc; -use crate::array::{Array, StructArray}; +use crate::array::StructArray; use crate::datatypes::DataType; use crate::error::Result; +use super::super::super::IpcField; use super::super::deserialize::{read, skip, Node}; use super::super::read_basic::*; +use super::super::Dictionaries; #[allow(clippy::too_many_arguments)] pub fn read_struct( field_nodes: &mut VecDeque, data_type: DataType, + ipc_field: &IpcField, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, - dictionaries: &HashMap>, + dictionaries: &Dictionaries, block_offset: u64, is_little_endian: bool, compression: Option, @@ -38,10 +40,12 @@ pub fn read_struct( let values = fields .iter() - .map(|field| { + .zip(ipc_field.fields.iter()) + .map(|(field, ipc_field)| { read( field_nodes, field, + ipc_field, buffers, reader, dictionaries, diff --git a/src/io/ipc/read/array/union.rs b/src/io/ipc/read/array/union.rs index 87afdfb582e..ce3fb0fb79b 100644 --- a/src/io/ipc/read/array/union.rs +++ b/src/io/ipc/read/array/union.rs @@ -1,24 +1,26 @@ -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::io::{Read, Seek}; -use std::sync::Arc; use arrow_format::ipc; -use crate::array::{Array, UnionArray}; +use crate::array::UnionArray; use crate::datatypes::DataType; use crate::datatypes::UnionMode::Dense; use crate::error::Result; +use super::super::super::IpcField; use super::super::deserialize::{read, skip, Node}; use super::super::read_basic::*; +use super::super::Dictionaries; #[allow(clippy::too_many_arguments)] pub fn read_union( field_nodes: &mut VecDeque, data_type: DataType, + ipc_field: &IpcField, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, - dictionaries: &HashMap>, + dictionaries: &Dictionaries, block_offset: u64, 
is_little_endian: bool, compression: Option, @@ -60,10 +62,12 @@ pub fn read_union( let fields = fields .iter() - .map(|field| { + .zip(ipc_field.fields.iter()) + .map(|(field, ipc_field)| { read( field_nodes, field, + ipc_field, buffers, reader, dictionaries, diff --git a/src/io/ipc/read/common.rs b/src/io/ipc/read/common.rs index aa81e64b9b6..1bc9b87ac31 100644 --- a/src/io/ipc/read/common.rs +++ b/src/io/ipc/read/common.rs @@ -1,20 +1,3 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - use std::collections::{HashMap, VecDeque}; use std::io::{Read, Seek}; use std::sync::Arc; @@ -25,9 +8,11 @@ use arrow_format::ipc::Schema::MetadataVersion; use crate::array::*; use crate::datatypes::{DataType, Field, Schema}; use crate::error::{ArrowError, Result}; +use crate::io::ipc::{IpcField, IpcSchema}; use crate::record_batch::RecordBatch; use super::deserialize::{read, skip}; +use super::Dictionaries; type ArrayRef = Arc; @@ -96,13 +81,14 @@ impl<'a, A, I: Iterator> Iterator for ProjectionIter<'a, A, I> { pub fn read_record_batch( batch: ipc::Message::RecordBatch, schema: Arc, + ipc_schema: &IpcSchema, projection: Option<(&[usize], Arc)>, - is_little_endian: bool, - dictionaries: &HashMap>, + dictionaries: &Dictionaries, version: MetadataVersion, reader: &mut R, block_offset: u64, ) -> Result { + assert_eq!(schema.fields().len(), ipc_schema.fields.len()); let buffers = batch.buffers().ok_or_else(|| { ArrowError::OutOfSpec("Unable to get buffers from IPC RecordBatch".to_string()) })?; @@ -116,22 +102,26 @@ pub fn read_record_batch( let (schema, columns) = if let Some(projection) = projection { let projected_schema = projection.1.clone(); - let projection = ProjectionIter::new(projection.0, schema.fields().iter()); + let projection = ProjectionIter::new( + projection.0, + schema.fields().iter().zip(ipc_schema.fields.iter()), + ); let arrays = projection .map(|maybe_field| match maybe_field { - ProjectionResult::Selected(field) => Some(read( + ProjectionResult::Selected((field, ipc_field)) => Some(read( &mut field_nodes, field, + ipc_field, &mut buffers, reader, dictionaries, block_offset, - is_little_endian, + ipc_schema.is_little_endian, batch.compression(), version, )), - ProjectionResult::NotSelected(field) => { + ProjectionResult::NotSelected((field, _)) => { skip(&mut field_nodes, field.data_type(), &mut buffers); None } @@ -143,15 +133,17 @@ pub fn read_record_batch( let arrays = schema .fields() .iter() - .map(|field| { + .zip(ipc_schema.fields.iter()) + .map(|(field, ipc_field)| { read( &mut field_nodes, field, + ipc_field, &mut buffers, reader, dictionaries, block_offset, - is_little_endian, + ipc_schema.is_little_endian, batch.compression(), version, ) @@ -162,25 +154,20 @@ pub fn read_record_batch( RecordBatch::try_new(schema, columns) } 
-fn find_first_dict_field_d(id: usize, data_type: &DataType) -> Option<&Field> { +fn find_first_dict_field_d<'a>( + id: i64, + data_type: &'a DataType, + ipc_field: &'a IpcField, +) -> Option<(&'a Field, &'a IpcField)> { use DataType::*; match data_type { - Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref()), - Map(field, _) => find_first_dict_field(id, field.as_ref()), - List(field) => find_first_dict_field(id, field.as_ref()), - LargeList(field) => find_first_dict_field(id, field.as_ref()), - FixedSizeList(field, _) => find_first_dict_field(id, field.as_ref()), - Union(fields, _, _) => { - for field in fields { - if let Some(f) = find_first_dict_field(id, field) { - return Some(f); - } - } - None + Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref(), ipc_field), + List(field) | LargeList(field) | FixedSizeList(field, ..) | Map(field, ..) => { + find_first_dict_field(id, field.as_ref(), &ipc_field.fields[0]) } - Struct(fields) => { - for field in fields { - if let Some(f) = find_first_dict_field(id, field) { + Union(fields, ..) | Struct(fields) => { + for (field, ipc_field) in fields.iter().zip(ipc_field.fields.iter()) { + if let Some(f) = find_first_dict_field(id, field, ipc_field) { return Some(f); } } @@ -190,18 +177,27 @@ fn find_first_dict_field_d(id: usize, data_type: &DataType) -> Option<&Field> { } } -fn find_first_dict_field(id: usize, field: &Field) -> Option<&Field> { - if let DataType::Dictionary(_, _, _) = &field.data_type { - if field.dict_id as usize == id { - return Some(field); +fn find_first_dict_field<'a>( + id: i64, + field: &'a Field, + ipc_field: &'a IpcField, +) -> Option<(&'a Field, &'a IpcField)> { + if let Some(field_id) = ipc_field.dictionary_id { + if id == field_id { + return Some((field, ipc_field)); } } - find_first_dict_field_d(id, &field.data_type) + find_first_dict_field_d(id, &field.data_type, ipc_field) } -fn first_dict_field(id: usize, fields: &[Field]) -> Result<&Field> { - for field in fields { - if let Some(field) = find_first_dict_field(id, field) { +fn first_dict_field<'a>( + id: i64, + fields: &'a [Field], + ipc_fields: &'a [IpcField], +) -> Result<(&'a Field, &'a IpcField)> { + assert_eq!(fields.len(), ipc_fields.len()); + for (field, ipc_field) in fields.iter().zip(ipc_fields.iter()) { + if let Some(field) = find_first_dict_field(id, field, ipc_field) { return Ok(field); } } @@ -215,9 +211,9 @@ fn first_dict_field(id: usize, fields: &[Field]) -> Result<&Field> { /// updating the `dictionaries` with the resulting dictionary pub fn read_dictionary( batch: ipc::Message::DictionaryBatch, - schema: &Schema, - is_little_endian: bool, - dictionaries: &mut HashMap>, + fields: &[Field], + ipc_schema: &IpcSchema, + dictionaries: &mut Dictionaries, reader: &mut R, block_offset: u64, ) -> Result<()> { @@ -228,7 +224,7 @@ pub fn read_dictionary( } let id = batch.id(); - let first_field = first_dict_field(id as usize, &schema.fields)?; + let (first_field, first_ipc_field) = first_dict_field(id, fields, &ipc_schema.fields)?; // As the dictionary batch does not contain the type of the // values array, we need to retrieve this from the schema. 
@@ -240,12 +236,17 @@ pub fn read_dictionary( fields: vec![Field::new("", value_type.as_ref().clone(), false)], metadata: HashMap::new(), }); + let ipc_schema = IpcSchema { + fields: vec![first_ipc_field.clone()], + is_little_endian: ipc_schema.is_little_endian, + }; + assert_eq!(ipc_schema.fields.len(), schema.fields().len()); // Read a single column let record_batch = read_record_batch( batch.data().unwrap(), schema, + &ipc_schema, None, - is_little_endian, dictionaries, MetadataVersion::V5, reader, @@ -259,7 +260,7 @@ pub fn read_dictionary( ArrowError::InvalidArgumentError("dictionary id not found in schema".to_string()) })?; - dictionaries.insert(id as usize, dictionary_values); + dictionaries.insert(id, dictionary_values); Ok(()) } diff --git a/src/io/ipc/read/deserialize.rs b/src/io/ipc/read/deserialize.rs index f533a63af40..824f1ff4218 100644 --- a/src/io/ipc/read/deserialize.rs +++ b/src/io/ipc/read/deserialize.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::{ io::{Read, Seek}, sync::Arc, @@ -10,8 +10,9 @@ use arrow_format::ipc::{Message::BodyCompression, Schema::MetadataVersion}; use crate::array::*; use crate::datatypes::{DataType, Field, PhysicalType}; use crate::error::Result; +use crate::io::ipc::IpcField; -use super::array::*; +use super::{array::*, Dictionaries}; pub type Node<'a> = &'a ipc::Message::FieldNode; @@ -19,9 +20,10 @@ pub type Node<'a> = &'a ipc::Message::FieldNode; pub fn read( field_nodes: &mut VecDeque, field: &Field, + ipc_field: &IpcField, buffers: &mut VecDeque<&ipc::Schema::Buffer>, reader: &mut R, - dictionaries: &HashMap>, + dictionaries: &Dictionaries, block_offset: u64, is_little_endian: bool, compression: Option, @@ -120,6 +122,7 @@ pub fn read( List => read_list::( field_nodes, data_type, + ipc_field, buffers, reader, dictionaries, @@ -132,6 +135,7 @@ pub fn read( LargeList => read_list::( field_nodes, data_type, + ipc_field, buffers, reader, dictionaries, @@ -144,6 +148,7 @@ pub fn read( FixedSizeList => read_fixed_size_list( field_nodes, data_type, + ipc_field, buffers, reader, dictionaries, @@ -156,6 +161,7 @@ pub fn read( Struct => read_struct( field_nodes, data_type, + ipc_field, buffers, reader, dictionaries, @@ -169,7 +175,7 @@ pub fn read( match_integer_type!(key_type, |$T| { read_dictionary::<$T, _>( field_nodes, - field, + ipc_field.dictionary_id, buffers, reader, dictionaries, @@ -183,6 +189,7 @@ pub fn read( Union => read_union( field_nodes, data_type, + ipc_field, buffers, reader, dictionaries, @@ -195,6 +202,7 @@ pub fn read( Map => read_map( field_nodes, data_type, + ipc_field, buffers, reader, dictionaries, diff --git a/src/io/ipc/read/mod.rs b/src/io/ipc/read/mod.rs index adc2b41e477..f36fbb6283e 100644 --- a/src/io/ipc/read/mod.rs +++ b/src/io/ipc/read/mod.rs @@ -4,14 +4,23 @@ //! which provides arbitrary access to any of its messages, and the //! [`StreamReader`](stream::StreamReader), which only supports reading //! data in the order it was written in. 
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use crate::array::Array;
 mod array;
 mod common;
 mod deserialize;
 mod read_basic;
 mod reader;
+mod schema;
 mod stream;
 pub use common::{read_dictionary, read_record_batch};
 pub use reader::{read_file_metadata, FileMetadata, FileReader};
+pub use schema::fb_to_schema;
 pub use stream::{read_stream_metadata, StreamMetadata, StreamReader, StreamState};
+
+// how dictionaries are tracked in this crate
+pub type Dictionaries = HashMap<i64, Arc<dyn Array>>;
diff --git a/src/io/ipc/read/reader.rs b/src/io/ipc/read/reader.rs
index d8f09c0e9bf..af980b7ef75 100644
--- a/src/io/ipc/read/reader.rs
+++ b/src/io/ipc/read/reader.rs
@@ -1,21 +1,3 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::collections::HashMap;
 use std::io::{Read, Seek, SeekFrom};
 use std::sync::Arc;
@@ -23,19 +5,23 @@ use arrow_format::ipc;
 use arrow_format::ipc::flatbuffers::VerifierOptions;
 use arrow_format::ipc::File::Block;
-use crate::array::*;
 use crate::datatypes::Schema;
 use crate::error::{ArrowError, Result};
+use crate::io::ipc::IpcSchema;
 use crate::record_batch::RecordBatch;
-use super::super::convert;
 use super::super::{ARROW_MAGIC, CONTINUATION_MARKER};
 use super::common::*;
+use super::schema::fb_to_schema;
+use super::Dictionaries;
 #[derive(Debug, Clone)]
 pub struct FileMetadata {
-    /// The schema that is read from the file header
-    schema: Arc<Schema>,
+    /// The schema that is read from the file footer
+    pub schema: Arc<Schema>,
+
+    /// The file's [`IpcSchema`]
+    pub ipc_schema: IpcSchema,
     /// The blocks in the file
     ///
@@ -43,12 +29,10 @@ pub struct FileMetadata {
     blocks: Vec<Block>,
     /// Dictionaries associated to each dict_id
-    dictionaries: HashMap<usize, Arc<dyn Array>>,
+    dictionaries: Dictionaries,
     /// FileMetadata version
     version: ipc::Schema::MetadataVersion,
-
-    is_little_endian: bool,
 }
 impl FileMetadata {
@@ -91,9 +75,9 @@ fn read_dictionary_message(
 fn read_dictionaries<R: Read + Seek>(
     reader: &mut R,
     schema: &Schema,
-    is_little_endian: bool,
+    ipc_schema: &IpcSchema,
     blocks: &[Block],
-) -> Result<HashMap<usize, Arc<dyn Array>>> {
+) -> Result<Dictionaries> {
     let mut dictionaries = Default::default();
     let mut data = vec![];
@@ -112,8 +96,8 @@ fn read_dictionaries(
         let batch = message.header_as_dictionary_batch().unwrap();
         read_dictionary(
             batch,
-            schema,
-            is_little_endian,
+            schema.fields(),
+            ipc_schema,
             &mut dictionaries,
             reader,
             block_offset,
@@ -178,20 +162,20 @@ pub fn read_file_metadata(reader: &mut R) -> Result(
     read_record_batch(
         batch,
         metadata.schema.clone(),
+        &metadata.ipc_schema,
         projection,
-        metadata.is_little_endian,
         &metadata.dictionaries,
         metadata.version,
         reader,
@@ -297,6 +281,11 @@ impl FileReader {
             .unwrap_or(&self.metadata.schema)
     }
+    /// Returns the [`FileMetadata`]
+    pub fn metadata(&self) -> &FileMetadata {
+        &self.metadata
+    }
+
     /// Consumes this FileReader, returning the 
underlying reader pub fn into_inner(self) -> R { self.reader diff --git a/src/io/ipc/read/schema.rs b/src/io/ipc/read/schema.rs new file mode 100644 index 00000000000..380116351e5 --- /dev/null +++ b/src/io/ipc/read/schema.rs @@ -0,0 +1,334 @@ +use std::collections::{BTreeMap, HashMap}; + +mod ipc { + pub use arrow_format::ipc::File::*; + pub use arrow_format::ipc::Message::*; + pub use arrow_format::ipc::Schema::*; +} + +use crate::datatypes::{ + get_extension, DataType, Extension, Field, IntegerType, IntervalUnit, Metadata, Schema, + TimeUnit, UnionMode, +}; + +use super::super::{IpcField, IpcSchema}; + +fn deserialize_field(ipc_field: ipc::Field) -> (Field, IpcField) { + let metadata = read_metadata(&ipc_field); + + let extension = get_extension(&metadata); + + let (data_type, ipc_field_) = get_data_type(ipc_field, extension, true); + + let field = Field { + name: ipc_field.name().unwrap().to_string(), + data_type, + nullable: ipc_field.nullable(), + metadata, + }; + + (field, ipc_field_) +} + +fn read_metadata(field: &ipc::Field) -> Metadata { + if let Some(list) = field.custom_metadata() { + let mut metadata_map = BTreeMap::default(); + for kv in list { + if let (Some(k), Some(v)) = (kv.key(), kv.value()) { + metadata_map.insert(k.to_string(), v.to_string()); + } + } + Some(metadata_map) + } else { + None + } +} + +/// Get the Arrow data type from the flatbuffer Field table +fn get_data_type( + field: ipc::Field, + extension: Extension, + may_be_dictionary: bool, +) -> (DataType, IpcField) { + if let Some(dictionary) = field.dictionary() { + if may_be_dictionary { + let int = dictionary.indexType().unwrap(); + let index_type = match (int.bitWidth(), int.is_signed()) { + (8, true) => IntegerType::Int8, + (8, false) => IntegerType::UInt8, + (16, true) => IntegerType::Int16, + (16, false) => IntegerType::UInt16, + (32, true) => IntegerType::Int32, + (32, false) => IntegerType::UInt32, + (64, true) => IntegerType::Int64, + (64, false) => IntegerType::UInt64, + _ => panic!("Unexpected bitwidth and signed"), + }; + let (inner, mut ipc_field) = get_data_type(field, extension, false); + ipc_field.dictionary_id = Some(dictionary.id()); + return ( + DataType::Dictionary(index_type, Box::new(inner), dictionary.isOrdered()), + ipc_field, + ); + } + } + + if let Some(extension) = extension { + let (name, metadata) = extension; + let (data_type, fields) = get_data_type(field, None, false); + return ( + DataType::Extension(name, Box::new(data_type), metadata), + fields, + ); + } + + match field.type_type() { + ipc::Type::Null => (DataType::Null, IpcField::default()), + ipc::Type::Bool => (DataType::Boolean, IpcField::default()), + ipc::Type::Int => { + let int = field.type_as_int().unwrap(); + let data_type = match (int.bitWidth(), int.is_signed()) { + (8, true) => DataType::Int8, + (8, false) => DataType::UInt8, + (16, true) => DataType::Int16, + (16, false) => DataType::UInt16, + (32, true) => DataType::Int32, + (32, false) => DataType::UInt32, + (64, true) => DataType::Int64, + (64, false) => DataType::UInt64, + z => panic!( + "Int type with bit width of {} and signed of {} not supported", + z.0, z.1 + ), + }; + (data_type, IpcField::default()) + } + ipc::Type::Binary => (DataType::Binary, IpcField::default()), + ipc::Type::LargeBinary => (DataType::LargeBinary, IpcField::default()), + ipc::Type::Utf8 => (DataType::Utf8, IpcField::default()), + ipc::Type::LargeUtf8 => (DataType::LargeUtf8, IpcField::default()), + ipc::Type::FixedSizeBinary => { + let fsb = 
field.type_as_fixed_size_binary().unwrap(); + ( + DataType::FixedSizeBinary(fsb.byteWidth() as usize), + IpcField::default(), + ) + } + ipc::Type::FloatingPoint => { + let float = field.type_as_floating_point().unwrap(); + let data_type = match float.precision() { + ipc::Precision::HALF => DataType::Float16, + ipc::Precision::SINGLE => DataType::Float32, + ipc::Precision::DOUBLE => DataType::Float64, + z => panic!("FloatingPoint type with precision of {:?} not supported", z), + }; + (data_type, IpcField::default()) + } + ipc::Type::Date => { + let date = field.type_as_date().unwrap(); + let data_type = match date.unit() { + ipc::DateUnit::DAY => DataType::Date32, + ipc::DateUnit::MILLISECOND => DataType::Date64, + z => panic!("Date type with unit of {:?} not supported", z), + }; + (data_type, IpcField::default()) + } + ipc::Type::Time => { + let time = field.type_as_time().unwrap(); + let data_type = match (time.bitWidth(), time.unit()) { + (32, ipc::TimeUnit::SECOND) => DataType::Time32(TimeUnit::Second), + (32, ipc::TimeUnit::MILLISECOND) => DataType::Time32(TimeUnit::Millisecond), + (64, ipc::TimeUnit::MICROSECOND) => DataType::Time64(TimeUnit::Microsecond), + (64, ipc::TimeUnit::NANOSECOND) => DataType::Time64(TimeUnit::Nanosecond), + z => panic!( + "Time type with bit width of {} and unit of {:?} not supported", + z.0, z.1 + ), + }; + (data_type, IpcField::default()) + } + ipc::Type::Timestamp => { + let timestamp = field.type_as_timestamp().unwrap(); + let timezone: Option = timestamp.timezone().map(|tz| tz.to_string()); + let data_type = match timestamp.unit() { + ipc::TimeUnit::SECOND => DataType::Timestamp(TimeUnit::Second, timezone), + ipc::TimeUnit::MILLISECOND => DataType::Timestamp(TimeUnit::Millisecond, timezone), + ipc::TimeUnit::MICROSECOND => DataType::Timestamp(TimeUnit::Microsecond, timezone), + ipc::TimeUnit::NANOSECOND => DataType::Timestamp(TimeUnit::Nanosecond, timezone), + z => panic!("Timestamp type with unit of {:?} not supported", z), + }; + (data_type, IpcField::default()) + } + ipc::Type::Interval => { + let interval = field.type_as_interval().unwrap(); + let data_type = match interval.unit() { + ipc::IntervalUnit::YEAR_MONTH => DataType::Interval(IntervalUnit::YearMonth), + ipc::IntervalUnit::DAY_TIME => DataType::Interval(IntervalUnit::DayTime), + ipc::IntervalUnit::MONTH_DAY_NANO => DataType::Interval(IntervalUnit::MonthDayNano), + z => panic!("Interval type with unit of {:?} unsupported", z), + }; + (data_type, IpcField::default()) + } + ipc::Type::Duration => { + let duration = field.type_as_duration().unwrap(); + let data_type = match duration.unit() { + ipc::TimeUnit::SECOND => DataType::Duration(TimeUnit::Second), + ipc::TimeUnit::MILLISECOND => DataType::Duration(TimeUnit::Millisecond), + ipc::TimeUnit::MICROSECOND => DataType::Duration(TimeUnit::Microsecond), + ipc::TimeUnit::NANOSECOND => DataType::Duration(TimeUnit::Nanosecond), + z => panic!("Duration type with unit of {:?} unsupported", z), + }; + (data_type, IpcField::default()) + } + ipc::Type::Decimal => { + let fsb = field.type_as_decimal().unwrap(); + let data_type = DataType::Decimal(fsb.precision() as usize, fsb.scale() as usize); + (data_type, IpcField::default()) + } + ipc::Type::List => { + let children = field.children().unwrap(); + if children.len() != 1 { + panic!("expect a list to have one child") + } + let (field, ipc_field) = deserialize_field(children.get(0)); + + ( + DataType::List(Box::new(field)), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + ) + } + 
ipc::Type::LargeList => { + let children = field.children().unwrap(); + if children.len() != 1 { + panic!("expect a large list to have one child") + } + let (field, ipc_field) = deserialize_field(children.get(0)); + + ( + DataType::LargeList(Box::new(field)), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + ) + } + ipc::Type::FixedSizeList => { + let fsl = field.type_as_fixed_size_list().unwrap(); + let size = fsl.listSize() as usize; + let children = field.children().unwrap(); + if children.len() != 1 { + panic!("expect a list to have one child") + } + let (field, ipc_field) = deserialize_field(children.get(0)); + + ( + DataType::FixedSizeList(Box::new(field), size), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + ) + } + ipc::Type::Struct_ => { + let fields = field.children().unwrap(); + if fields.is_empty() { + panic!("expect a struct to have at least one child") + } + let (fields, ipc_fields): (Vec<_>, Vec<_>) = (0..fields.len()) + .map(|field| { + let field = fields.get(field); + let (field, fields) = deserialize_field(field); + (field, fields) + }) + .unzip(); + let ipc_field = IpcField { + fields: ipc_fields, + dictionary_id: None, + }; + (DataType::Struct(fields), ipc_field) + } + ipc::Type::Union => { + let type_ = field.type_as_union().unwrap(); + let mode = UnionMode::sparse(type_.mode() == ipc::UnionMode::Sparse); + let ids = type_.typeIds().map(|x| x.iter().collect()); + + let fields = field.children().unwrap(); + if fields.is_empty() { + panic!("expect a struct to have at least one child") + } + + let (fields, ipc_fields): (Vec<_>, Vec<_>) = (0..fields.len()) + .map(|field| { + let field = fields.get(field); + let (field, fields) = deserialize_field(field); + (field, fields) + }) + .unzip(); + let ipc_field = IpcField { + fields: ipc_fields, + dictionary_id: None, + }; + (DataType::Union(fields, ids, mode), ipc_field) + } + ipc::Type::Map => { + let map = field.type_as_map().unwrap(); + let is_sorted = map.keysSorted(); + + let children = field.children().unwrap(); + if children.len() != 1 { + panic!("expect a list to have one child") + } + let (field, ipc_field) = deserialize_field(children.get(0)); + + let data_type = DataType::Map(Box::new(field), is_sorted); + ( + data_type, + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + ) + } + t => unimplemented!("Type {:?} not supported", t), + } +} + +/// Deserialize the raw Schema table from IPC format to Schema data type +pub fn fb_to_schema(fb: ipc::Schema) -> (Schema, IpcSchema) { + let fields = fb.fields().unwrap(); + let (fields, ipc_fields): (Vec<_>, Vec<_>) = (0..fields.len()) + .map(|field| { + let field = fields.get(field); + let (field, fields) = deserialize_field(field); + (field, fields) + }) + .unzip(); + + let is_little_endian = fb.endianness().variant_name().unwrap_or("Little") == "Little"; + + let mut metadata: HashMap = HashMap::default(); + if let Some(md_fields) = fb.custom_metadata() { + let len = md_fields.len(); + for i in 0..len { + let kv = md_fields.get(i); + let k_str = kv.key(); + let v_str = kv.value(); + if let Some(k) = k_str { + if let Some(v) = v_str { + metadata.insert(k.to_string(), v.to_string()); + } + } + } + } + + ( + Schema { fields, metadata }, + IpcSchema { + fields: ipc_fields, + is_little_endian, + }, + ) +} diff --git a/src/io/ipc/read/stream.rs b/src/io/ipc/read/stream.rs index f9e5e34c2b4..ebc6a83c4f1 100644 --- a/src/io/ipc/read/stream.rs +++ b/src/io/ipc/read/stream.rs @@ -1,45 +1,27 @@ -// Licensed to the Apache Software 
Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::collections::HashMap; use std::io::Read; use std::sync::Arc; use arrow_format::ipc; use arrow_format::ipc::Schema::MetadataVersion; -use crate::array::*; use crate::datatypes::Schema; use crate::error::{ArrowError, Result}; +use crate::io::ipc::IpcSchema; use crate::record_batch::RecordBatch; -use super::super::convert; use super::super::CONTINUATION_MARKER; use super::common::*; +use super::schema::fb_to_schema; +use super::Dictionaries; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct StreamMetadata { /// The schema that is read from the stream's first message - schema: Arc, + pub schema: Arc, - version: MetadataVersion, + pub version: MetadataVersion, - /// Whether the incoming stream is little-endian - is_little_endian: bool, + pub ipc_schema: IpcSchema, } /// Reads the metadata of the stream @@ -67,13 +49,13 @@ pub fn read_stream_metadata(reader: &mut R) -> Result { let ipc_schema: ipc::Schema::Schema = message .header_as_schema() .ok_or_else(|| ArrowError::OutOfSpec("Unable to read IPC message as schema".to_string()))?; - let (schema, is_little_endian) = convert::fb_to_schema(ipc_schema); + let (schema, ipc_schema) = fb_to_schema(ipc_schema); let schema = Arc::new(schema); Ok(StreamMetadata { schema, version, - is_little_endian, + ipc_schema, }) } @@ -113,7 +95,7 @@ impl StreamState { fn read_next( reader: &mut R, metadata: &StreamMetadata, - dictionaries: &mut HashMap>, + dictionaries: &mut Dictionaries, message_buffer: &mut Vec, data_buffer: &mut Vec, ) -> Result> { @@ -174,8 +156,8 @@ fn read_next( read_record_batch( batch, metadata.schema.clone(), + &metadata.ipc_schema, None, - metadata.is_little_endian, dictionaries, metadata.version, &mut reader, @@ -195,8 +177,8 @@ fn read_next( read_dictionary( batch, - &metadata.schema, - metadata.is_little_endian, + metadata.schema.fields(), + &metadata.ipc_schema, dictionaries, &mut dict_reader, 0, @@ -222,7 +204,7 @@ fn read_next( pub struct StreamReader { reader: R, metadata: StreamMetadata, - dictionaries: HashMap>, + dictionaries: Dictionaries, finished: bool, data_buffer: Vec, message_buffer: Vec, @@ -246,8 +228,8 @@ impl StreamReader { } /// Return the schema of the stream - pub fn schema(&self) -> &Arc { - &self.metadata.schema + pub fn metadata(&self) -> &StreamMetadata { + &self.metadata } /// Check if the stream is finished diff --git a/src/io/ipc/write/common.rs b/src/io/ipc/write/common.rs index 238ba412c1f..20890c93538 100644 --- a/src/io/ipc/write/common.rs +++ b/src/io/ipc/write/common.rs @@ -1,15 +1,17 @@ -use std::{collections::HashMap, sync::Arc}; +use std::sync::Arc; use arrow_format::ipc; use arrow_format::ipc::flatbuffers::FlatBufferBuilder; use arrow_format::ipc::Message::CompressionType; use crate::array::*; +use crate::columns::Columns; 
use crate::datatypes::*; use crate::error::{ArrowError, Result}; use crate::io::ipc::endianess::is_native_little_endian; -use crate::record_batch::RecordBatch; +use crate::io::ipc::read::Dictionaries; +use super::super::IpcField; use super::{write, write_dictionary}; /// Compression codec @@ -30,7 +32,7 @@ pub struct WriteOptions { } fn encode_dictionary( - field: &Field, + field: &IpcField, array: &Arc, options: &WriteOptions, dictionary_tracker: &mut DictionaryTracker, @@ -41,14 +43,11 @@ fn encode_dictionary( Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null | FixedSizeBinary => Ok(()), Dictionary(key_type) => match_integer_type!(key_type, |$T| { - let dict_id = field - .dict_id() - .expect("All Dictionary types have `dict_id`"); + let dict_id = field.dictionary_id + .ok_or_else(|| ArrowError::InvalidArgumentError("Dictionaries must have an associated id".to_string()))?; let values = array.as_any().downcast_ref::>().unwrap().values(); - // todo: this is won't work for Dict>; - let field = Field::new("item", values.data_type().clone(), true); - encode_dictionary(&field, + encode_dictionary(field, values, options, dictionary_tracker, @@ -68,19 +67,16 @@ fn encode_dictionary( Ok(()) }), Struct => { - let values = array - .as_any() - .downcast_ref::() - .unwrap() - .values(); - let fields = if let DataType::Struct(fields) = array.data_type() { - fields - } else { - unreachable!() - }; + let array = array.as_any().downcast_ref::().unwrap(); + let fields = field.fields.as_slice(); + if array.fields().len() != fields.len() { + return Err(ArrowError::InvalidArgumentError( + "The number of fields in a struct must equal the number of children in IpcField".to_string(), + )); + } fields .iter() - .zip(values.iter()) + .zip(array.values().iter()) .try_for_each(|(field, values)| { encode_dictionary( field, @@ -97,11 +93,7 @@ fn encode_dictionary( .downcast_ref::>() .unwrap() .values(); - let field = if let DataType::List(field) = field.data_type() { - field.as_ref() - } else { - unreachable!() - }; + let field = &field.fields[0]; // todo: error instead encode_dictionary( field, values, @@ -116,11 +108,7 @@ fn encode_dictionary( .downcast_ref::>() .unwrap() .values(); - let field = if let DataType::LargeList(field) = field.data_type() { - field.as_ref() - } else { - unreachable!() - }; + let field = &field.fields[0]; // todo: error instead encode_dictionary( field, values, @@ -135,11 +123,7 @@ fn encode_dictionary( .downcast_ref::() .unwrap() .values(); - let field = if let DataType::FixedSizeList(field, _) = field.data_type() { - field.as_ref() - } else { - unreachable!() - }; + let field = &field.fields[0]; // todo: error instead encode_dictionary( field, values, @@ -154,11 +138,13 @@ fn encode_dictionary( .downcast_ref::() .unwrap() .fields(); - let fields = if let DataType::Union(fields, _, _) = field.data_type() { - fields - } else { - unreachable!() - }; + let fields = &field.fields[..]; // todo: error instead + if values.len() != fields.len() { + return Err(ArrowError::InvalidArgumentError( + "The number of fields in a union must equal the number of children in IpcField" + .to_string(), + )); + } fields .iter() .zip(values.iter()) @@ -174,11 +160,7 @@ fn encode_dictionary( } Map => { let values = array.as_any().downcast_ref::().unwrap().field(); - let field = if let DataType::Map(field, _) = field.data_type() { - field.as_ref() - } else { - unreachable!() - }; + let field = &field.fields[0]; // todo: error instead encode_dictionary( field, values, @@ -190,39 +172,39 @@ fn 
encode_dictionary(
     }
 }
-pub fn encoded_batch(
-    batch: &RecordBatch,
+pub fn encode_columns(
+    columns: &Columns<Arc<dyn Array>>,
+    fields: &[IpcField],
     dictionary_tracker: &mut DictionaryTracker,
     options: &WriteOptions,
 ) -> Result<(Vec<EncodedData>, EncodedData)> {
-    let schema = batch.schema();
-    let mut encoded_dictionaries = Vec::with_capacity(schema.fields().len());
+    let mut encoded_dictionaries = vec![];
-    for (field, column) in schema.fields().iter().zip(batch.columns()) {
+    for (field, array) in fields.iter().zip(columns.as_ref()) {
         encode_dictionary(
             field,
-            column,
+            array,
             options,
             dictionary_tracker,
             &mut encoded_dictionaries,
         )?;
     }
-    let encoded_message = record_batch_to_bytes(batch, options);
+    let encoded_message = columns_to_bytes(columns, options);
     Ok((encoded_dictionaries, encoded_message))
 }
 /// Write a `RecordBatch` into two sets of bytes, one for the header (ipc::Schema::Message) and the
 /// other for the batch's data
-fn record_batch_to_bytes(batch: &RecordBatch, options: &WriteOptions) -> EncodedData {
+fn columns_to_bytes(columns: &Columns<Arc<dyn Array>>, options: &WriteOptions) -> EncodedData {
     let mut fbb = FlatBufferBuilder::new();
     let mut nodes: Vec<ipc::Message::FieldNode> = vec![];
     let mut buffers: Vec<ipc::Schema::Buffer> = vec![];
     let mut arrow_data: Vec<u8> = vec![];
     let mut offset = 0;
-    for array in batch.columns() {
+    for array in columns.arrays() {
         write(
             array.as_ref(),
             &mut buffers,
@@ -252,7 +234,7 @@ fn record_batch_to_bytes(batch: &RecordBatch, options: &WriteOptions) -> Encoded
     let root = {
         let mut batch_builder = ipc::Message::RecordBatchBuilder::new(&mut fbb);
-        batch_builder.add_length(batch.num_rows() as i64);
+        batch_builder.add_length(columns.len() as i64);
         batch_builder.add_nodes(nodes);
         batch_builder.add_buffers(buffers);
         if let Some(compression) = compression {
@@ -358,14 +340,14 @@ fn dictionary_batch_to_bytes(
 /// multiple times. Can optionally error if an update to an existing dictionary is attempted, which
 /// isn't allowed in the `FileWriter`.
 pub struct DictionaryTracker {
-    written: HashMap>,
+    written: Dictionaries,
     error_on_replacement: bool,
 }
 impl DictionaryTracker {
     pub fn new(error_on_replacement: bool) -> Self {
         Self {
-            written: HashMap::new(),
+            written: Dictionaries::new(),
             error_on_replacement,
         }
     }
diff --git a/src/io/ipc/write/mod.rs b/src/io/ipc/write/mod.rs
index 5e75961884a..6331e9dc9a0 100644
--- a/src/io/ipc/write/mod.rs
+++ b/src/io/ipc/write/mod.rs
@@ -18,3 +18,49 @@ mod common_async;
 #[cfg(feature = "io_ipc_write_async")]
 #[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))]
 pub mod stream_async;
+
+use crate::datatypes::{DataType, Field};
+
+use super::IpcField;
+
+fn default_ipc_field(data_type: &DataType, current_id: &mut i64) -> IpcField {
+    use crate::datatypes::DataType::*;
+    match data_type.to_logical_type() {
+        // single child => recurse
+        Map(inner, ..) | FixedSizeList(inner, _) | LargeList(inner) | List(inner) => IpcField {
+            fields: vec![default_ipc_field(inner.data_type(), current_id)],
+            dictionary_id: None,
+        },
+        // multiple children => recurse
+        Union(fields, ..) | Struct(fields) => IpcField {
+            fields: fields
+                .iter()
+                .map(|f| default_ipc_field(f.data_type(), current_id))
+                .collect(),
+            dictionary_id: None,
+        },
+        // dictionary => current_id
+        Dictionary(_, data_type, _) => {
+            let dictionary_id = Some(*current_id);
+            *current_id += 1;
+            IpcField {
+                fields: vec![default_ipc_field(data_type, current_id)],
+                dictionary_id,
+            }
+        }
+        // no children => do nothing
+        _ => IpcField {
+            fields: vec![],
+            dictionary_id: None,
+        },
+    }
+}
+
+/// Assigns every dictionary field a unique ID
+pub fn default_ipc_fields(fields: &[Field]) -> Vec<IpcField> {
+    let mut dictionary_id = 0i64;
+    fields
+        .iter()
+        .map(|field| default_ipc_field(field.data_type().to_logical_type(), &mut dictionary_id))
+        .collect()
+}
diff --git a/src/io/ipc/write/schema.rs b/src/io/ipc/write/schema.rs
index b5400ef67e7..2d7950c9534 100644
--- a/src/io/ipc/write/schema.rs
+++ b/src/io/ipc/write/schema.rs
@@ -1,21 +1,29 @@
-use arrow_format::ipc;
-use arrow_format::ipc::flatbuffers::FlatBufferBuilder;
+use arrow_format::ipc::flatbuffers::{
+    FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, WIPOffset,
+};
+use std::collections::BTreeMap;
+mod ipc {
+    pub use arrow_format::ipc::File::*;
+    pub use arrow_format::ipc::Message::*;
+    pub use arrow_format::ipc::Schema::*;
+}
-use crate::datatypes::*;
+use crate::datatypes::{DataType, Field, IntegerType, IntervalUnit, Schema, TimeUnit};
+use crate::io::ipc::endianess::is_native_little_endian;
-use super::super::convert;
+use super::super::IpcField;
 /// Converts
-pub fn schema_to_bytes(schema: &Schema) -> Vec<u8> {
+pub fn schema_to_bytes(schema: &Schema, ipc_fields: &[IpcField]) -> Vec<u8> {
     let mut fbb = FlatBufferBuilder::new();
     let schema = {
-        let fb = convert::schema_to_fb_offset(&mut fbb, schema);
+        let fb = schema_to_fb_offset(&mut fbb, schema, ipc_fields);
         fb.as_union_value()
     };
-    let mut message = ipc::Message::MessageBuilder::new(&mut fbb);
-    message.add_version(ipc::Schema::MetadataVersion::V5);
-    message.add_header_type(ipc::Message::MessageHeader::Schema);
+    let mut message = ipc::MessageBuilder::new(&mut fbb);
+    message.add_version(ipc::MetadataVersion::V5);
+    message.add_header_type(ipc::MessageHeader::Schema);
     message.add_bodyLength(0);
     message.add_header(schema);
     // TODO: custom metadata
@@ -24,3 +32,648 @@ pub fn schema_to_bytes(schema: &Schema) -> Vec {
     fbb.finished_data().to_vec()
 }
+
+pub fn schema_to_fb_offset<'a>(
+    fbb: &mut FlatBufferBuilder<'a>,
+    schema: &Schema,
+    ipc_fields: &[IpcField],
+) -> WIPOffset<ipc::Schema<'a>> {
+    let fields = schema
+        .fields()
+        .iter()
+        .zip(ipc_fields.iter())
+        .map(|(field, ipc_field)| build_field(fbb, field, ipc_field))
+        .collect::<Vec<_>>();
+
+    let mut custom_metadata = vec![];
+    for (k, v) in schema.metadata() {
+        let fb_key_name = fbb.create_string(k.as_str());
+        let fb_val_name = fbb.create_string(v.as_str());
+
+        let mut kv_builder = ipc::KeyValueBuilder::new(fbb);
+        kv_builder.add_key(fb_key_name);
+        kv_builder.add_value(fb_val_name);
+        custom_metadata.push(kv_builder.finish());
+    }
+
+    let fb_field_list = fbb.create_vector(&fields);
+    let fb_metadata_list = fbb.create_vector(&custom_metadata);
+
+    let mut builder = ipc::SchemaBuilder::new(fbb);
+    builder.add_fields(fb_field_list);
+    builder.add_custom_metadata(fb_metadata_list);
+    builder.add_endianness(if is_native_little_endian() {
+        ipc::Endianness::Little
+    } else {
+        ipc::Endianness::Big
+    });
+    builder.finish()
+}
+
+pub(crate) struct FbFieldType<'b> {
+    pub(crate) type_type: ipc::Type,
+    pub(crate) type_: WIPOffset<UnionWIPOffset>,
+    pub(crate) children: 
Option>>>>, +} + +fn write_metadata<'a>( + fbb: &mut FlatBufferBuilder<'a>, + metadata: &BTreeMap, + kv_vec: &mut Vec>>, +) { + for (k, v) in metadata { + if k != "ARROW:extension:name" && k != "ARROW:extension:metadata" { + let kv_args = ipc::KeyValueArgs { + key: Some(fbb.create_string(k.as_str())), + value: Some(fbb.create_string(v.as_str())), + }; + kv_vec.push(ipc::KeyValue::create(fbb, &kv_args)); + } + } +} + +fn write_extension<'a>( + fbb: &mut FlatBufferBuilder<'a>, + name: &str, + metadata: &Option, + kv_vec: &mut Vec>>, +) { + // metadata + if let Some(metadata) = metadata { + let kv_args = ipc::KeyValueArgs { + key: Some(fbb.create_string("ARROW:extension:metadata")), + value: Some(fbb.create_string(metadata.as_str())), + }; + kv_vec.push(ipc::KeyValue::create(fbb, &kv_args)); + } + + // name + let kv_args = ipc::KeyValueArgs { + key: Some(fbb.create_string("ARROW:extension:name")), + value: Some(fbb.create_string(name)), + }; + kv_vec.push(ipc::KeyValue::create(fbb, &kv_args)); +} + +/// Create an IPC Field from an Arrow Field +pub(crate) fn build_field<'a>( + fbb: &mut FlatBufferBuilder<'a>, + field: &Field, + ipc_field: &IpcField, +) -> WIPOffset> { + // custom metadata. + let mut kv_vec = vec![]; + if let DataType::Extension(name, _, metadata) = field.data_type() { + write_extension(fbb, name, metadata, &mut kv_vec); + } + + let fb_field_name = fbb.create_string(field.name().as_str()); + let field_type = get_fb_field_type(field.data_type(), ipc_field, field.is_nullable(), fbb); + + let fb_dictionary = + if let DataType::Dictionary(index_type, inner, is_ordered) = field.data_type() { + if let DataType::Extension(name, _, metadata) = inner.as_ref() { + write_extension(fbb, name, metadata, &mut kv_vec); + } + Some(get_fb_dictionary( + index_type, + ipc_field + .dictionary_id + .expect("All Dictionary types have `dict_id`"), + *is_ordered, + fbb, + )) + } else { + None + }; + + if let Some(metadata) = field.metadata() { + if !metadata.is_empty() { + write_metadata(fbb, metadata, &mut kv_vec); + } + }; + let fb_metadata = if !kv_vec.is_empty() { + Some(fbb.create_vector(&kv_vec)) + } else { + None + }; + + let mut field_builder = ipc::FieldBuilder::new(fbb); + field_builder.add_name(fb_field_name); + if let Some(dictionary) = fb_dictionary { + field_builder.add_dictionary(dictionary) + } + field_builder.add_type_type(field_type.type_type); + field_builder.add_nullable(field.is_nullable()); + match field_type.children { + None => {} + Some(children) => field_builder.add_children(children), + }; + field_builder.add_type_(field_type.type_); + + if let Some(fb_metadata) = fb_metadata { + field_builder.add_custom_metadata(fb_metadata); + } + + field_builder.finish() +} + +fn type_to_field_type(data_type: &DataType) -> ipc::Type { + use DataType::*; + match data_type { + Null => ipc::Type::Null, + Boolean => ipc::Type::Bool, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 => ipc::Type::Int, + Float16 | Float32 | Float64 => ipc::Type::FloatingPoint, + Decimal(_, _) => ipc::Type::Decimal, + Binary => ipc::Type::Binary, + LargeBinary => ipc::Type::LargeBinary, + Utf8 => ipc::Type::Utf8, + LargeUtf8 => ipc::Type::LargeUtf8, + FixedSizeBinary(_) => ipc::Type::FixedSizeBinary, + Date32 | Date64 => ipc::Type::Date, + Duration(_) => ipc::Type::Duration, + Time32(_) | Time64(_) => ipc::Type::Time, + Timestamp(_, _) => ipc::Type::Timestamp, + Interval(_) => ipc::Type::Interval, + List(_) => ipc::Type::List, + LargeList(_) => ipc::Type::LargeList, + FixedSizeList(_, _) => 
ipc::Type::FixedSizeList, + Union(_, _, _) => ipc::Type::Union, + Map(_, _) => ipc::Type::Map, + Struct(_) => ipc::Type::Struct_, + Dictionary(_, v, _) => type_to_field_type(v), + Extension(_, v, _) => type_to_field_type(v), + } +} + +/// Get the IPC type of a data type +pub(crate) fn get_fb_field_type<'a>( + data_type: &DataType, + ipc_field: &IpcField, + is_nullable: bool, + fbb: &mut FlatBufferBuilder<'a>, +) -> FbFieldType<'a> { + use DataType::*; + let type_type = type_to_field_type(data_type); + + // some IPC implementations expect an empty list for child data, instead of a null value. + // An empty field list is thus returned for primitive types + let empty_fields: Vec> = vec![]; + match data_type { + Null => FbFieldType { + type_type, + type_: ipc::NullBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + Boolean => FbFieldType { + type_type, + type_: ipc::BoolBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => { + let children = fbb.create_vector(&empty_fields[..]); + let mut builder = ipc::IntBuilder::new(fbb); + if matches!(data_type, UInt8 | UInt16 | UInt32 | UInt64) { + builder.add_is_signed(false); + } else { + builder.add_is_signed(true); + } + match data_type { + Int8 | UInt8 => builder.add_bitWidth(8), + Int16 | UInt16 => builder.add_bitWidth(16), + Int32 | UInt32 => builder.add_bitWidth(32), + Int64 | UInt64 => builder.add_bitWidth(64), + _ => {} + }; + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: Some(children), + } + } + Float16 | Float32 | Float64 => { + let children = fbb.create_vector(&empty_fields[..]); + let mut builder = ipc::FloatingPointBuilder::new(fbb); + match data_type { + Float16 => builder.add_precision(ipc::Precision::HALF), + Float32 => builder.add_precision(ipc::Precision::SINGLE), + Float64 => builder.add_precision(ipc::Precision::DOUBLE), + _ => {} + }; + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: Some(children), + } + } + Binary => FbFieldType { + type_type, + type_: ipc::BinaryBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + LargeBinary => FbFieldType { + type_type, + type_: ipc::LargeBinaryBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + Utf8 => FbFieldType { + type_type, + type_: ipc::Utf8Builder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + LargeUtf8 => FbFieldType { + type_type, + type_: ipc::LargeUtf8Builder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + }, + FixedSizeBinary(len) => { + let mut builder = ipc::FixedSizeBinaryBuilder::new(fbb); + builder.add_byteWidth(*len as i32); + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Date32 => { + let mut builder = ipc::DateBuilder::new(fbb); + builder.add_unit(ipc::DateUnit::DAY); + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Date64 => { + let mut builder = ipc::DateBuilder::new(fbb); + builder.add_unit(ipc::DateUnit::MILLISECOND); + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: 
Some(fbb.create_vector(&empty_fields[..])), + } + } + Time32(unit) | Time64(unit) => { + let mut builder = ipc::TimeBuilder::new(fbb); + match unit { + TimeUnit::Second => { + builder.add_bitWidth(32); + builder.add_unit(ipc::TimeUnit::SECOND); + } + TimeUnit::Millisecond => { + builder.add_bitWidth(32); + builder.add_unit(ipc::TimeUnit::MILLISECOND); + } + TimeUnit::Microsecond => { + builder.add_bitWidth(64); + builder.add_unit(ipc::TimeUnit::MICROSECOND); + } + TimeUnit::Nanosecond => { + builder.add_bitWidth(64); + builder.add_unit(ipc::TimeUnit::NANOSECOND); + } + } + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Timestamp(unit, tz) => { + let tz = tz.clone().unwrap_or_else(String::new); + let tz_str = fbb.create_string(tz.as_str()); + let mut builder = ipc::TimestampBuilder::new(fbb); + let time_unit = match unit { + TimeUnit::Second => ipc::TimeUnit::SECOND, + TimeUnit::Millisecond => ipc::TimeUnit::MILLISECOND, + TimeUnit::Microsecond => ipc::TimeUnit::MICROSECOND, + TimeUnit::Nanosecond => ipc::TimeUnit::NANOSECOND, + }; + builder.add_unit(time_unit); + if !tz.is_empty() { + builder.add_timezone(tz_str); + } + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Interval(unit) => { + let mut builder = ipc::IntervalBuilder::new(fbb); + let interval_unit = match unit { + IntervalUnit::YearMonth => ipc::IntervalUnit::YEAR_MONTH, + IntervalUnit::DayTime => ipc::IntervalUnit::DAY_TIME, + IntervalUnit::MonthDayNano => ipc::IntervalUnit::MONTH_DAY_NANO, + }; + builder.add_unit(interval_unit); + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Duration(unit) => { + let mut builder = ipc::DurationBuilder::new(fbb); + let time_unit = match unit { + TimeUnit::Second => ipc::TimeUnit::SECOND, + TimeUnit::Millisecond => ipc::TimeUnit::MILLISECOND, + TimeUnit::Microsecond => ipc::TimeUnit::MICROSECOND, + TimeUnit::Nanosecond => ipc::TimeUnit::NANOSECOND, + }; + builder.add_unit(time_unit); + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + List(ref list_type) => { + let child = build_field(fbb, list_type, &ipc_field.fields[0]); + FbFieldType { + type_type, + type_: ipc::ListBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&[child])), + } + } + LargeList(ref list_type) => { + let child = build_field(fbb, list_type, &ipc_field.fields[0]); + FbFieldType { + type_type, + type_: ipc::LargeListBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&[child])), + } + } + FixedSizeList(ref list_type, len) => { + let child = build_field(fbb, list_type, &ipc_field.fields[0]); + let mut builder = ipc::FixedSizeListBuilder::new(fbb); + builder.add_listSize(*len as i32); + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&[child])), + } + } + Struct(fields) => { + let children: Vec<_> = fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc_field)| build_field(fbb, field, ipc_field)) + .collect(); + + FbFieldType { + type_type, + type_: ipc::Struct_Builder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&children[..])), + } + } + Dictionary(_, value_type, _) => { + // In this library, the dictionary "type" is a 
logical construct. Here we + // pass through to the value type, as we've already captured the index + // type in the DictionaryEncoding metadata in the parent field + get_fb_field_type(value_type, ipc_field, is_nullable, fbb) + } + Extension(_, value_type, _) => get_fb_field_type(value_type, ipc_field, is_nullable, fbb), + Decimal(precision, scale) => { + let mut builder = ipc::DecimalBuilder::new(fbb); + builder.add_precision(*precision as i32); + builder.add_scale(*scale as i32); + builder.add_bitWidth(128); + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Union(fields, ids, mode) => { + let children: Vec<_> = fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc_field)| build_field(fbb, field, ipc_field)) + .collect(); + + let ids = ids.as_ref().map(|ids| fbb.create_vector(ids)); + + let mut builder = ipc::UnionBuilder::new(fbb); + builder.add_mode(if mode.is_sparse() { + ipc::UnionMode::Sparse + } else { + ipc::UnionMode::Dense + }); + + if let Some(ids) = ids { + builder.add_typeIds(ids); + } + FbFieldType { + type_type, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&children)), + } + } + Map(field, keys_sorted) => { + let child = build_field(fbb, field, &ipc_field.fields[0]); + let mut field_type = ipc::MapBuilder::new(fbb); + field_type.add_keysSorted(*keys_sorted); + FbFieldType { + type_type: ipc::Type::Map, + type_: field_type.finish().as_union_value(), + children: Some(fbb.create_vector(&[child])), + } + } + } +} + +/// Create an IPC dictionary encoding +pub(crate) fn get_fb_dictionary<'a>( + index_type: &IntegerType, + dict_id: i64, + dict_is_ordered: bool, + fbb: &mut FlatBufferBuilder<'a>, +) -> WIPOffset> { + use IntegerType::*; + // We assume that the dictionary index type (as an integer) has already been + // validated elsewhere, and can safely assume we are dealing with integers + let mut index_builder = ipc::IntBuilder::new(fbb); + + match index_type { + Int8 | Int16 | Int32 | Int64 => index_builder.add_is_signed(true), + UInt8 | UInt16 | UInt32 | UInt64 => index_builder.add_is_signed(false), + } + + match index_type { + Int8 | UInt8 => index_builder.add_bitWidth(8), + Int16 | UInt16 => index_builder.add_bitWidth(16), + Int32 | UInt32 => index_builder.add_bitWidth(32), + Int64 | UInt64 => index_builder.add_bitWidth(64), + } + + let index_builder = index_builder.finish(); + + let mut builder = ipc::DictionaryEncodingBuilder::new(fbb); + builder.add_id(dict_id); + builder.add_indexType(index_builder); + builder.add_isOrdered(dict_is_ordered); + + builder.finish() +} + +/* +#[cfg(test)] +mod tests { + use super::*; + use crate::datatypes::{DataType, Field, Schema}; + + /// Serialize a schema in IPC format + fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder { + let mut fbb = FlatBufferBuilder::new(); + + let root = schema_to_fb_offset(&mut fbb, schema); + + fbb.finish(root, None); + + fbb + } + + #[test] + fn convert_schema_round_trip() { + let md: HashMap = [("Key".to_string(), "value".to_string())] + .iter() + .cloned() + .collect(); + let field_md: BTreeMap = [("k".to_string(), "v".to_string())] + .iter() + .cloned() + .collect(); + let schema = Schema::new_from( + vec![ + { + let mut f = Field::new("uint8", DataType::UInt8, false); + f.set_metadata(Some(field_md)); + f + }, + Field::new("uint16", DataType::UInt16, true), + Field::new("uint32", DataType::UInt32, false), + Field::new("uint64", DataType::UInt64, true), + 
Field::new("int8", DataType::Int8, true), + Field::new("int16", DataType::Int16, false), + Field::new("int32", DataType::Int32, true), + Field::new("int64", DataType::Int64, false), + Field::new("float16", DataType::Float16, true), + Field::new("float32", DataType::Float32, false), + Field::new("float64", DataType::Float64, true), + Field::new("null", DataType::Null, false), + Field::new("bool", DataType::Boolean, false), + Field::new("date32", DataType::Date32, false), + Field::new("date64", DataType::Date64, true), + Field::new("time32[s]", DataType::Time32(TimeUnit::Second), true), + Field::new("time32[ms]", DataType::Time32(TimeUnit::Millisecond), false), + Field::new("time64[us]", DataType::Time64(TimeUnit::Microsecond), false), + Field::new("time64[ns]", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new( + "timestamp[s]", + DataType::Timestamp(TimeUnit::Second, None), + false, + ), + Field::new( + "timestamp[ms]", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "timestamp[us]", + DataType::Timestamp( + TimeUnit::Microsecond, + Some("Africa/Johannesburg".to_string()), + ), + false, + ), + Field::new( + "timestamp[ns]", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), + Field::new( + "interval[ym]", + DataType::Interval(IntervalUnit::YearMonth), + true, + ), + Field::new( + "interval[dt]", + DataType::Interval(IntervalUnit::DayTime), + true, + ), + Field::new("utf8", DataType::Utf8, false), + Field::new("binary", DataType::Binary, false), + Field::new( + "list[u8]", + DataType::List(Box::new(Field::new("item", DataType::UInt8, false))), + true, + ), + Field::new( + "list[struct]", + DataType::List(Box::new(Field::new( + "struct", + DataType::Struct(vec![ + Field::new("float32", DataType::UInt8, false), + Field::new("int32", DataType::Int32, true), + Field::new("bool", DataType::Boolean, true), + ]), + true, + ))), + false, + ), + Field::new( + "struct]>]>", + DataType::Struct(vec![ + Field::new("int64", DataType::Int64, true), + Field::new( + "list[struct]>]", + DataType::List(Box::new(Field::new( + "struct", + DataType::Struct(vec![ + Field::new("date32", DataType::Date32, true), + Field::new( + "list[struct<>]", + DataType::List(Box::new(Field::new( + "struct", + DataType::Struct(vec![]), + false, + ))), + false, + ), + ]), + false, + ))), + false, + ), + ]), + false, + ), + Field::new("struct<>", DataType::Struct(vec![]), true), + Field::new_dict( + "dictionary", + DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8), true), + true, + 123, + ), + Field::new_dict( + "dictionary", + DataType::Dictionary(IntegerType::UInt8, Box::new(DataType::UInt32), true), + true, + 123, + ), + Field::new("decimal", DataType::Decimal(10, 6), false), + ], + md, + ); + + let fb = schema_to_fb(&schema); + + // read back fields + let ipc = ipc::root_as_schema(fb.finished_data()).unwrap(); + let (schema2, _) = fb_to_schema(ipc); + assert_eq!(schema, schema2); + } +} + */ diff --git a/src/io/ipc/write/stream.rs b/src/io/ipc/write/stream.rs index bdc09a44951..b19a76feeb8 100644 --- a/src/io/ipc/write/stream.rs +++ b/src/io/ipc/write/stream.rs @@ -1,20 +1,3 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - //! Arrow IPC File and Stream Writers //! //! The `FileWriter` and `StreamWriter` have similar interfaces, @@ -22,7 +5,8 @@ use std::io::Write; -use super::common::{encoded_batch, DictionaryTracker, EncodedData, WriteOptions}; +use super::super::IpcField; +use super::common::{encode_columns, DictionaryTracker, EncodedData, WriteOptions}; use super::common_sync::{write_continuation, write_message}; use super::schema_to_bytes; @@ -48,24 +32,28 @@ pub struct StreamWriter { } impl StreamWriter { - /// Try create a new writer, with the schema written as part of the header - pub fn try_new(mut writer: W, schema: &Schema, write_options: WriteOptions) -> Result { - // write the schema, set the written bytes to the schema - let encoded_message = EncodedData { - ipc_message: schema_to_bytes(schema), - arrow_data: vec![], - }; - write_message(&mut writer, encoded_message)?; - Ok(Self { + /// Creates a new [`StreamWriter`] + pub fn new(writer: W, write_options: WriteOptions) -> Self { + Self { writer, write_options, finished: false, dictionary_tracker: DictionaryTracker::new(false), - }) + } + } + + /// Starts the stream + pub fn start(&mut self, schema: &Schema, ipc_fields: &[IpcField]) -> Result<()> { + let encoded_message = EncodedData { + ipc_message: schema_to_bytes(schema, ipc_fields), + arrow_data: vec![], + }; + write_message(&mut self.writer, encoded_message)?; + Ok(()) } - /// Write a record batch to the stream - pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + /// Writes [`RecordBatch`] to the stream + pub fn write(&mut self, batch: &RecordBatch, fields: &[IpcField]) -> Result<()> { if self.finished { return Err(ArrowError::Io(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, @@ -73,8 +61,13 @@ impl StreamWriter { ))); } - let (encoded_dictionaries, encoded_message) = - encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options)?; + let columns = batch.clone().into(); + let (encoded_dictionaries, encoded_message) = encode_columns( + &columns, + fields, + &mut self.dictionary_tracker, + &self.write_options, + )?; for encoded_dictionary in encoded_dictionaries { write_message(&mut self.writer, encoded_dictionary)?; diff --git a/src/io/ipc/write/stream_async.rs b/src/io/ipc/write/stream_async.rs index 6af81fb6fce..e84df52342c 100644 --- a/src/io/ipc/write/stream_async.rs +++ b/src/io/ipc/write/stream_async.rs @@ -1,10 +1,12 @@ //! 
`async` writing of arrow streams + use futures::AsyncWrite; +use super::super::IpcField; pub use super::common::WriteOptions; -use super::common::{encoded_batch, DictionaryTracker, EncodedData}; +use super::common::{encode_columns, DictionaryTracker, EncodedData}; use super::common_async::{write_continuation, write_message}; -use super::schema_to_bytes; +use super::{default_ipc_fields, schema_to_bytes}; use crate::datatypes::*; use crate::error::{ArrowError, Result}; @@ -34,17 +36,29 @@ impl StreamWriter { } /// Starts the stream - pub async fn start(&mut self, schema: &Schema) -> Result<()> { - let encoded_message = EncodedData { - ipc_message: schema_to_bytes(schema), - arrow_data: vec![], + pub async fn start(&mut self, schema: &Schema, ipc_fields: Option<&[IpcField]>) -> Result<()> { + let encoded_message = if let Some(ipc_fields) = ipc_fields { + EncodedData { + ipc_message: schema_to_bytes(schema, ipc_fields), + arrow_data: vec![], + } + } else { + let ipc_fields = default_ipc_fields(schema.fields()); + EncodedData { + ipc_message: schema_to_bytes(schema, &ipc_fields), + arrow_data: vec![], + } }; write_message(&mut self.writer, encoded_message).await?; Ok(()) } - /// Writes a [`RecordBatch`] to the stream - pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> { + /// Writes [`RecordBatch`] to the stream + pub async fn write( + &mut self, + batch: &RecordBatch, + ipc_fields: Option<&[IpcField]>, + ) -> Result<()> { if self.finished { return Err(ArrowError::Io(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, @@ -52,9 +66,24 @@ impl StreamWriter { ))); } - // todo: move this out of the `async` since this is blocking. - let (encoded_dictionaries, encoded_message) = - encoded_batch(batch, &mut self.dictionary_tracker, &self.write_options)?; + let (encoded_dictionaries, encoded_message) = if let Some(ipc_fields) = ipc_fields { + let columns = batch.clone().into(); + encode_columns( + &columns, + ipc_fields, + &mut self.dictionary_tracker, + &self.write_options, + )? + } else { + let ipc_fields = default_ipc_fields(batch.schema().fields()); + let columns = batch.clone().into(); + encode_columns( + &columns, + &ipc_fields, + &mut self.dictionary_tracker, + &self.write_options, + )? + }; for encoded_dictionary in encoded_dictionaries { write_message(&mut self.writer, encoded_dictionary).await?; diff --git a/src/io/ipc/write/writer.rs b/src/io/ipc/write/writer.rs index 6a66fb3a1a8..4591f9a3953 100644 --- a/src/io/ipc/write/writer.rs +++ b/src/io/ipc/write/writer.rs @@ -1,36 +1,14 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Arrow IPC File and Stream Writers -//! -//! The `FileWriter` and `StreamWriter` have similar interfaces, -//! 
however the `FileWriter` expects a reader that supports `Seek`ing - use std::io::Write; use arrow_format::ipc; use arrow_format::ipc::flatbuffers::FlatBufferBuilder; -use super::super::ARROW_MAGIC; use super::{ - super::convert, - common::{encoded_batch, DictionaryTracker, EncodedData, WriteOptions}, + super::IpcField, + super::ARROW_MAGIC, + common::{encode_columns, DictionaryTracker, EncodedData, WriteOptions}, common_sync::{write_continuation, write_message}, - schema_to_bytes, + default_ipc_fields, schema, schema_to_bytes, }; use crate::datatypes::*; @@ -45,6 +23,7 @@ pub struct FileWriter { options: WriteOptions, /// A reference to the schema, used in validating record batches schema: Schema, + ipc_fields: Vec, /// The number of bytes between each block of bytes, as an offset for random access block_offsets: usize, /// Dictionary blocks that will be written as part of the IPC footer @@ -59,21 +38,34 @@ pub struct FileWriter { impl FileWriter { /// Try create a new writer, with the schema written as part of the header - pub fn try_new(mut writer: W, schema: &Schema, options: WriteOptions) -> Result { + pub fn try_new( + mut writer: W, + schema: &Schema, + ipc_fields: Option>, + options: WriteOptions, + ) -> Result { // write magic to header writer.write_all(&ARROW_MAGIC[..])?; // create an 8-byte boundary after the header writer.write_all(&[0, 0])?; // write the schema, set the written bytes to the schema + + let ipc_fields = if let Some(ipc_fields) = ipc_fields { + ipc_fields + } else { + default_ipc_fields(schema.fields()) + }; let encoded_message = EncodedData { - ipc_message: schema_to_bytes(schema), + ipc_message: schema_to_bytes(schema, &ipc_fields), arrow_data: vec![], }; + let (meta, data) = write_message(&mut writer, encoded_message)?; Ok(Self { writer, options, schema: schema.clone(), + ipc_fields, block_offsets: meta + data + 8, dictionary_blocks: vec![], record_blocks: vec![], @@ -86,8 +78,8 @@ impl FileWriter { self.writer } - /// Write a record batch to the file - pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { + /// Writes [`RecordBatch`] to the file + pub fn write(&mut self, batch: &RecordBatch, ipc_fields: Option<&[IpcField]>) -> Result<()> { if self.finished { return Err(ArrowError::Io(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, @@ -95,8 +87,19 @@ impl FileWriter { ))); } - let (encoded_dictionaries, encoded_message) = - encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?; + let ipc_fields = if let Some(ipc_fields) = ipc_fields { + ipc_fields + } else { + self.ipc_fields.as_ref() + }; + + let columns = batch.clone().into(); + let (encoded_dictionaries, encoded_message) = encode_columns( + &columns, + ipc_fields, + &mut self.dictionary_tracker, + &self.options, + )?; for encoded_dictionary in encoded_dictionaries { let (meta, data) = write_message(&mut self.writer, encoded_dictionary)?; @@ -126,7 +129,7 @@ impl FileWriter { let mut fbb = FlatBufferBuilder::new(); let dictionaries = fbb.create_vector(&self.dictionary_blocks); let record_batches = fbb.create_vector(&self.record_blocks); - let schema = convert::schema_to_fb_offset(&mut fbb, &self.schema); + let schema = schema::schema_to_fb_offset(&mut fbb, &self.schema, &self.ipc_fields); let root = { let mut footer_builder = ipc::File::FooterBuilder::new(&mut fbb); diff --git a/src/io/json_integration/mod.rs b/src/io/json_integration/mod.rs index ccb9b868d68..8c5f3e2556e 100644 --- a/src/io/json_integration/mod.rs +++ b/src/io/json_integration/mod.rs @@ -1,35 +1,12 @@ -// Licensed 
to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - //! Utils for JSON integration testing //! //! These utilities define structs that read the integration JSON format for integration testing purposes. use serde_derive::{Deserialize, Serialize}; -use serde_json::{Map, Value}; - -use crate::datatypes::*; +use serde_json::Value; -mod schema; -use schema::ToJson; -mod read; -mod write; -pub use read::to_record_batch; -pub use write::from_record_batch; +pub mod read; +pub mod write; /// A struct that represents an Arrow file with a schema and record batches #[derive(Deserialize, Serialize, Debug)] @@ -64,59 +41,6 @@ pub struct ArrowJsonField { pub metadata: Option, } -impl From<&Field> for ArrowJsonField { - fn from(field: &Field) -> Self { - let metadata_value = match field.metadata() { - Some(kv_list) => { - let mut array = Vec::new(); - for (k, v) in kv_list { - let mut kv_map = Map::new(); - kv_map.insert(k.clone(), Value::String(v.clone())); - array.push(Value::Object(kv_map)); - } - if !array.is_empty() { - Some(Value::Array(array)) - } else { - None - } - } - _ => None, - }; - - let dictionary = if let DataType::Dictionary(key_type, _, is_ordered) = &field.data_type { - use crate::datatypes::IntegerType::*; - Some(ArrowJsonFieldDictionary { - id: field.dict_id, - index_type: IntegerType { - name: "".to_string(), - bit_width: match key_type { - Int8 | UInt8 => 8, - Int16 | UInt16 => 16, - Int32 | UInt32 => 32, - Int64 | UInt64 => 64, - }, - is_signed: match key_type { - Int8 | Int16 | Int32 | Int64 => true, - UInt8 | UInt16 | UInt32 | UInt64 => false, - }, - }, - is_ordered: *is_ordered, - }) - } else { - None - }; - - Self { - name: field.name().to_string(), - field_type: field.data_type().to_json(), - nullable: field.is_nullable(), - children: vec![], - dictionary, - metadata: metadata_value, - } - } -} - #[derive(Deserialize, Serialize, Debug)] pub struct ArrowJsonFieldDictionary { pub id: i64, diff --git a/src/io/json_integration/read.rs b/src/io/json_integration/read/array.rs similarity index 86% rename from src/io/json_integration/read.rs rename to src/io/json_integration/read/array.rs index e29e1a72375..6579d7f815c 100644 --- a/src/io/json_integration/read.rs +++ b/src/io/json_integration/read/array.rs @@ -1,20 +1,3 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - use std::{collections::HashMap, sync::Arc}; use num_traits::NumCast; @@ -26,11 +9,12 @@ use crate::{ buffer::Buffer, datatypes::{DataType, PhysicalType, PrimitiveType, Schema}, error::{ArrowError, Result}, + io::ipc::IpcField, record_batch::RecordBatch, types::{days_ms, months_days_ns, NativeType}, }; -use super::{ArrowJsonBatch, ArrowJsonColumn, ArrowJsonDictionaryBatch}; +use super::super::{ArrowJsonBatch, ArrowJsonColumn, ArrowJsonDictionaryBatch}; fn to_validity(validity: &Option>) -> Option { validity.as_ref().and_then(|x| { @@ -209,6 +193,7 @@ fn to_utf8(json_col: &ArrowJsonColumn, data_type: DataType) -> Arc( json_col: &ArrowJsonColumn, data_type: DataType, + field: &IpcField, dictionaries: &HashMap, ) -> Result> { let validity = to_validity(&json_col.validity); @@ -217,7 +202,7 @@ fn to_list( let children = &json_col.children.as_ref().unwrap()[0]; let values = to_array( child_field.data_type().clone(), - child_field.dict_id(), + &field.fields[0], children, dictionaries, )?; @@ -230,6 +215,7 @@ fn to_list( fn to_map( json_col: &ArrowJsonColumn, data_type: DataType, + field: &IpcField, dictionaries: &HashMap, ) -> Result> { let validity = to_validity(&json_col.validity); @@ -238,7 +224,7 @@ fn to_map( let children = &json_col.children.as_ref().unwrap()[0]; let field = to_array( child_field.data_type().clone(), - child_field.dict_id(), + &field.fields[0], children, dictionaries, )?; @@ -250,22 +236,22 @@ fn to_map( fn to_dictionary( data_type: DataType, - dict_id: i64, + field: &IpcField, json_col: &ArrowJsonColumn, dictionaries: &HashMap, ) -> Result> { // find dictionary + let dict_id = field.dictionary_id.unwrap(); let dictionary = dictionaries.get(&dict_id).ok_or_else(|| { ArrowError::OutOfSpec(format!("Unable to find any dictionary id {}", dict_id)) })?; let keys = to_primitive(json_col, K::PRIMITIVE.into()); - // todo: make DataType::Dictionary hold a Field so that it can hold dictionary_id let inner_data_type = DictionaryArray::::get_child(&data_type); let values = to_array( inner_data_type.clone(), - None, // this should not be None: we need to propagate the id as dicts can be nested. 
+ field, &dictionary.data.columns[0], dictionaries, )?; @@ -276,7 +262,7 @@ fn to_dictionary( /// Construct an [`Array`] from the JSON integration format pub fn to_array( data_type: DataType, - dict_id: Option, + field: &IpcField, json_col: &ArrowJsonColumn, dictionaries: &HashMap, ) -> Result> { @@ -330,8 +316,8 @@ pub fn to_array( data_type, values, validity, ))) } - List => to_list::(json_col, data_type, dictionaries), - LargeList => to_list::(json_col, data_type, dictionaries), + List => to_list::(json_col, data_type, field, dictionaries), + LargeList => to_list::(json_col, data_type, field, dictionaries), FixedSizeList => { let validity = to_validity(&json_col.validity); @@ -340,7 +326,7 @@ pub fn to_array( let children = &json_col.children.as_ref().unwrap()[0]; let values = to_array( child_field.data_type().clone(), - child_field.dict_id(), + &field.fields[0], children, dictionaries, )?; @@ -357,13 +343,9 @@ pub fn to_array( let values = fields .iter() .zip(json_col.children.as_ref().unwrap()) - .map(|(field, col)| { - to_array( - field.data_type().clone(), - field.dict_id(), - col, - dictionaries, - ) + .zip(field.fields.iter()) + .map(|((field, col), ipc_field)| { + to_array(field.data_type().clone(), ipc_field, col, dictionaries) }) .collect::>>()?; @@ -372,7 +354,7 @@ pub fn to_array( } Dictionary(key_type) => { match_integer_type!(key_type, |$T| { - to_dictionary::<$T>(data_type, dict_id.unwrap(), json_col, dictionaries) + to_dictionary::<$T>(data_type, field, json_col, dictionaries) }) } Union => { @@ -380,13 +362,9 @@ pub fn to_array( let fields = fields .iter() .zip(json_col.children.as_ref().unwrap()) - .map(|(field, col)| { - to_array( - field.data_type().clone(), - field.dict_id(), - col, - dictionaries, - ) + .zip(field.fields.iter()) + .map(|((field, col), ipc_field)| { + to_array(field.data_type().clone(), ipc_field, col, dictionaries) }) .collect::>>()?; @@ -428,12 +406,13 @@ pub fn to_array( let array = UnionArray::from_data(data_type, types, fields, offsets); Ok(Arc::new(array)) } - Map => to_map(json_col, data_type, dictionaries), + Map => to_map(json_col, data_type, field, dictionaries), } } pub fn to_record_batch( schema: &Schema, + ipc_fields: &[IpcField], json_batch: &ArrowJsonBatch, json_dictionaries: &HashMap, ) -> Result { @@ -441,10 +420,11 @@ pub fn to_record_batch( .fields() .iter() .zip(&json_batch.columns) - .map(|(field, json_col)| { + .zip(ipc_fields.iter()) + .map(|((field, json_col), ipc_field)| { to_array( field.data_type().clone(), - field.dict_id(), + ipc_field, json_col, json_dictionaries, ) diff --git a/src/io/json_integration/read/mod.rs b/src/io/json_integration/read/mod.rs new file mode 100644 index 00000000000..9a4e5318639 --- /dev/null +++ b/src/io/json_integration/read/mod.rs @@ -0,0 +1,4 @@ +mod array; +pub use array::*; +mod schema; +pub use schema::*; diff --git a/src/io/json_integration/schema.rs b/src/io/json_integration/read/schema.rs similarity index 51% rename from src/io/json_integration/schema.rs rename to src/io/json_integration/read/schema.rs index 607aca9474e..06cb0f1fd47 100644 --- a/src/io/json_integration/schema.rs +++ b/src/io/json_integration/read/schema.rs @@ -1,165 +1,18 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{ - collections::{BTreeMap, HashMap}, - convert::TryFrom, -}; +use std::collections::{BTreeMap, HashMap}; use serde_derive::Deserialize; -use serde_json::{json, Value}; +use serde_json::Value; use crate::{ datatypes::UnionMode, error::{ArrowError, Result}, + io::ipc::IpcField, }; use crate::datatypes::{ get_extension, DataType, Field, IntegerType, IntervalUnit, Schema, TimeUnit, }; -pub trait ToJson { - /// Generate a JSON representation - fn to_json(&self) -> Value; -} - -impl ToJson for DataType { - fn to_json(&self) -> Value { - match self { - DataType::Null => json!({"name": "null"}), - DataType::Boolean => json!({"name": "bool"}), - DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}), - DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}), - DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}), - DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}), - DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}), - DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}), - DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}), - DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}), - DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}), - DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}), - DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}), - DataType::Utf8 => json!({"name": "utf8"}), - DataType::LargeUtf8 => json!({"name": "largeutf8"}), - DataType::Binary => json!({"name": "binary"}), - DataType::LargeBinary => json!({"name": "largebinary"}), - DataType::FixedSizeBinary(byte_width) => { - json!({"name": "fixedsizebinary", "byteWidth": byte_width}) - } - DataType::Struct(_) => json!({"name": "struct"}), - DataType::Union(_, _, _) => json!({"name": "union"}), - DataType::Map(_, _) => json!({"name": "map"}), - DataType::List(_) => json!({ "name": "list"}), - DataType::LargeList(_) => json!({ "name": "largelist"}), - DataType::FixedSizeList(_, length) => { - json!({"name":"fixedsizelist", "listSize": length}) - } - DataType::Time32(unit) => { - json!({"name": "time", "bitWidth": 32, "unit": match unit { - TimeUnit::Second => "SECOND", - TimeUnit::Millisecond => "MILLISECOND", - TimeUnit::Microsecond => "MICROSECOND", - TimeUnit::Nanosecond => "NANOSECOND", - }}) - } - DataType::Time64(unit) => { - json!({"name": "time", "bitWidth": 64, "unit": match unit { - TimeUnit::Second => "SECOND", - TimeUnit::Millisecond => "MILLISECOND", - TimeUnit::Microsecond => "MICROSECOND", - TimeUnit::Nanosecond => "NANOSECOND", - }}) - } - DataType::Date32 => { - json!({"name": "date", "unit": "DAY"}) - } - DataType::Date64 => { - json!({"name": "date", "unit": "MILLISECOND"}) - } - DataType::Timestamp(unit, None) => { - json!({"name": "timestamp", "unit": match unit 
{ - TimeUnit::Second => "SECOND", - TimeUnit::Millisecond => "MILLISECOND", - TimeUnit::Microsecond => "MICROSECOND", - TimeUnit::Nanosecond => "NANOSECOND", - }}) - } - DataType::Timestamp(unit, Some(tz)) => { - json!({"name": "timestamp", "unit": match unit { - TimeUnit::Second => "SECOND", - TimeUnit::Millisecond => "MILLISECOND", - TimeUnit::Microsecond => "MICROSECOND", - TimeUnit::Nanosecond => "NANOSECOND", - }, "timezone": tz}) - } - DataType::Interval(unit) => json!({"name": "interval", "unit": match unit { - IntervalUnit::YearMonth => "YEAR_MONTH", - IntervalUnit::DayTime => "DAY_TIME", - IntervalUnit::MonthDayNano => "MONTH_DAY_NANO", - }}), - DataType::Duration(unit) => json!({"name": "duration", "unit": match unit { - TimeUnit::Second => "SECOND", - TimeUnit::Millisecond => "MILLISECOND", - TimeUnit::Microsecond => "MICROSECOND", - TimeUnit::Nanosecond => "NANOSECOND", - }}), - DataType::Dictionary(_, _, _) => json!({ "name": "dictionary"}), - DataType::Decimal(precision, scale) => { - json!({"name": "decimal", "precision": precision, "scale": scale}) - } - DataType::Extension(_, inner_data_type, _) => inner_data_type.to_json(), - } - } -} - -impl ToJson for Field { - fn to_json(&self) -> Value { - let children: Vec = match self.data_type() { - DataType::Struct(fields) => fields.iter().map(|f| f.to_json()).collect(), - DataType::List(field) => vec![field.to_json()], - DataType::LargeList(field) => vec![field.to_json()], - DataType::FixedSizeList(field, _) => vec![field.to_json()], - _ => vec![], - }; - match self.data_type() { - DataType::Dictionary(ref index_type, ref value_type, is_ordered) => { - let index_type: DataType = (*index_type).into(); - json!({ - "name": self.name(), - "nullable": self.is_nullable(), - "type": value_type.to_json(), - "children": children, - "dictionary": { - "id": self.dict_id(), - "indexType": index_type.to_json(), - "isOrdered": is_ordered - } - }) - } - _ => json!({ - "name": self.name(), - "nullable": self.is_nullable(), - "type": self.data_type().to_json(), - "children": children - }), - } - } -} - fn to_time_unit(item: Option<&Value>) -> Result { match item { Some(p) if p == "SECOND" => Ok(TimeUnit::Second), @@ -218,13 +71,13 @@ fn to_int(item: &Value) -> Result { }) } -fn children(children: Option<&Value>) -> Result> { +fn deserialize_fields(children: Option<&Value>) -> Result> { children .map(|x| { if let Value::Array(values) = x { values .iter() - .map(Field::try_from) + .map(deserialize_field) .collect::>>() } else { Err(ArrowError::OutOfSpec( @@ -440,103 +293,124 @@ fn to_data_type(item: &Value, mut children: Vec) -> Result { }) } -impl TryFrom<&Value> for Field { - type Error = ArrowError; +fn deserialize_ipc_field(value: &Value) -> Result { + let map = if let Value::Object(map) = value { + map + } else { + return Err(ArrowError::OutOfSpec( + "Invalid json value type for field".to_string(), + )); + }; - fn try_from(value: &Value) -> Result { - match *value { - Value::Object(ref map) => { - let name = match map.get("name") { - Some(&Value::String(ref name)) => name.to_string(), - _ => { - return Err(ArrowError::OutOfSpec( - "Field missing 'name' attribute".to_string(), - )); - } - }; - let nullable = match map.get("nullable") { - Some(&Value::Bool(b)) => b, - _ => { - return Err(ArrowError::OutOfSpec( - "Field missing 'nullable' attribute".to_string(), - )); - } - }; + let fields = map + .get("children") + .map(|x| { + if let Value::Array(values) = x { + values + .iter() + .map(deserialize_ipc_field) + .collect::>>() + } else { + 
Err(ArrowError::OutOfSpec( + "children must be an array".to_string(), + )) + } + }) + .unwrap_or_else(|| Ok(vec![]))?; - let children = children(map.get("children"))?; + let dictionary_id = if let Some(dictionary) = map.get("dictionary") { + match dictionary.get("id") { + Some(Value::Number(n)) => Some(n.as_i64().unwrap()), + _ => { + return Err(ArrowError::OutOfSpec( + "Field missing 'id' attribute".to_string(), + )); + } + } + } else { + None + }; + Ok(IpcField { + fields, + dictionary_id, + }) +} - let metadata = if let Some(metadata) = map.get("metadata") { - Some(read_metadata(metadata)?) - } else { - None - }; +fn deserialize_field(value: &Value) -> Result { + let map = if let Value::Object(map) = value { + map + } else { + return Err(ArrowError::OutOfSpec( + "Invalid json value type for field".to_string(), + )); + }; + + let name = match map.get("name") { + Some(&Value::String(ref name)) => name.to_string(), + _ => { + return Err(ArrowError::OutOfSpec( + "Field missing 'name' attribute".to_string(), + )); + } + }; + let nullable = match map.get("nullable") { + Some(&Value::Bool(b)) => b, + _ => { + return Err(ArrowError::OutOfSpec( + "Field missing 'nullable' attribute".to_string(), + )); + } + }; - let extension = get_extension(&metadata); + let metadata = if let Some(metadata) = map.get("metadata") { + Some(read_metadata(metadata)?) + } else { + None + }; - let type_ = map - .get("type") - .ok_or_else(|| ArrowError::OutOfSpec("type missing".to_string()))?; + let extension = get_extension(&metadata); - let data_type = to_data_type(type_, children)?; + let type_ = map + .get("type") + .ok_or_else(|| ArrowError::OutOfSpec("type missing".to_string()))?; - let data_type = if let Some((name, metadata)) = extension { - DataType::Extension(name, Box::new(data_type), metadata) - } else { - data_type - }; + let children = deserialize_fields(map.get("children"))?; + let data_type = to_data_type(type_, children)?; - let data_type = if let Some(dictionary) = map.get("dictionary") { - let index_type = match dictionary.get("indexType") { - Some(t) => to_int(t)?, - _ => { - return Err(ArrowError::OutOfSpec( - "Field missing 'indexType' attribute".to_string(), - )); - } - }; - let is_ordered = match dictionary.get("isOrdered") { - Some(&Value::Bool(n)) => n, - _ => { - return Err(ArrowError::OutOfSpec( - "Field missing 'isOrdered' attribute".to_string(), - )); - } - }; - DataType::Dictionary(index_type, Box::new(data_type), is_ordered) - } else { - data_type - }; + let data_type = if let Some((name, metadata)) = extension { + DataType::Extension(name, Box::new(data_type), metadata) + } else { + data_type + }; - let dict_id = if let Some(dictionary) = map.get("dictionary") { - match dictionary.get("id") { - Some(Value::Number(n)) => n.as_i64().unwrap(), - _ => { - return Err(ArrowError::OutOfSpec( - "Field missing 'id' attribute".to_string(), - )); - } - } - } else { - 0 - }; - let mut f = Field::new_dict(&name, data_type, nullable, dict_id); - f.set_metadata(metadata); - Ok(f) + let data_type = if let Some(dictionary) = map.get("dictionary") { + let index_type = match dictionary.get("indexType") { + Some(t) => to_int(t)?, + _ => { + return Err(ArrowError::OutOfSpec( + "Field missing 'indexType' attribute".to_string(), + )); } - _ => Err(ArrowError::OutOfSpec( - "Invalid json value type for field".to_string(), - )), - } - } -} + }; + let is_ordered = match dictionary.get("isOrdered") { + Some(&Value::Bool(n)) => n, + _ => { + return Err(ArrowError::OutOfSpec( + "Field missing 'isOrdered' 
attribute".to_string(), + )); + } + }; + DataType::Dictionary(index_type, Box::new(data_type), is_ordered) + } else { + data_type + }; -impl ToJson for Schema { - fn to_json(&self) -> Value { - json!({ - "fields": self.fields.iter().map(|field| field.to_json()).collect::>(), - "metadata": serde_json::to_value(&self.metadata).unwrap(), - }) - } + Ok(Field { + name, + data_type, + nullable, + metadata, + }) } #[derive(Deserialize)] @@ -575,31 +449,43 @@ fn from_metadata(json: &Value) -> Result> { } } -impl TryFrom<&Value> for Schema { - type Error = ArrowError; +/// Deserializes a [`Value`] +pub fn deserialize_schema(value: &Value) -> Result<(Schema, Vec)> { + let schema = if let Value::Object(schema) = value { + schema + } else { + return Err(ArrowError::OutOfSpec( + "Invalid json value type for schema".to_string(), + )); + }; + + let fields = if let Some(Value::Array(fields)) = schema.get("fields") { + fields + .iter() + .map(deserialize_field) + .collect::>()? + } else { + return Err(ArrowError::OutOfSpec( + "Schema fields should be an array".to_string(), + )); + }; - fn try_from(json: &Value) -> Result { - match *json { - Value::Object(ref schema) => { - let fields = if let Some(Value::Array(fields)) = schema.get("fields") { - fields.iter().map(Field::try_from).collect::>()? - } else { - return Err(ArrowError::OutOfSpec( - "Schema fields should be an array".to_string(), - )); - }; + let ipc_fields = if let Some(Value::Array(fields)) = schema.get("fields") { + fields + .iter() + .map(deserialize_ipc_field) + .collect::>()? + } else { + return Err(ArrowError::OutOfSpec( + "Schema fields should be an array".to_string(), + )); + }; - let metadata = if let Some(value) = schema.get("metadata") { - from_metadata(value)? - } else { - HashMap::default() - }; + let metadata = if let Some(value) = schema.get("metadata") { + from_metadata(value)? + } else { + HashMap::default() + }; - Ok(Self { fields, metadata }) - } - _ => Err(ArrowError::OutOfSpec( - "Invalid json value type for schema".to_string(), - )), - } - } + Ok((Schema { fields, metadata }, ipc_fields)) } diff --git a/src/io/json_integration/write.rs b/src/io/json_integration/write/array.rs similarity index 93% rename from src/io/json_integration/write.rs rename to src/io/json_integration/write/array.rs index fcb1ebffec3..ea0bc1c7669 100644 --- a/src/io/json_integration/write.rs +++ b/src/io/json_integration/write/array.rs @@ -1,8 +1,9 @@ use crate::record_batch::RecordBatch; use crate::{array::PrimitiveArray, datatypes::DataType}; -use super::{ArrowJsonBatch, ArrowJsonColumn}; +use super::super::{ArrowJsonBatch, ArrowJsonColumn}; +/// Serializes a [`RecordBatch`] to [`ArrowJsonBatch`]. 
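+/// The result can be fed back through `read::to_record_batch`; a minimal sketch
+/// (the input names are illustrative, not part of this module):
+/// ```ignore
+/// let json_batch = from_record_batch(&batch);
+/// let roundtrip = read::to_record_batch(&schema, &ipc_fields, &json_batch, &dictionaries)?;
+/// ```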
pub fn from_record_batch(batch: &RecordBatch) -> ArrowJsonBatch { let mut json_batch = ArrowJsonBatch { count: batch.num_rows(), diff --git a/src/io/json_integration/write/mod.rs b/src/io/json_integration/write/mod.rs new file mode 100644 index 00000000000..9a4e5318639 --- /dev/null +++ b/src/io/json_integration/write/mod.rs @@ -0,0 +1,4 @@ +mod array; +pub use array::*; +mod schema; +pub use schema::*; diff --git a/src/io/json_integration/write/schema.rs b/src/io/json_integration/write/schema.rs new file mode 100644 index 00000000000..a9a73d86100 --- /dev/null +++ b/src/io/json_integration/write/schema.rs @@ -0,0 +1,173 @@ +use serde_json::{json, Map, Value}; + +use crate::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; +use crate::io::ipc::IpcField; +use crate::io::json_integration::ArrowJsonSchema; + +use super::super::{ArrowJsonField, ArrowJsonFieldDictionary, IntegerType}; + +fn serialize_data_type(data_type: &DataType) -> Value { + match data_type { + DataType::Null => json!({"name": "null"}), + DataType::Boolean => json!({"name": "bool"}), + DataType::Int8 => json!({"name": "int", "bitWidth": 8, "isSigned": true}), + DataType::Int16 => json!({"name": "int", "bitWidth": 16, "isSigned": true}), + DataType::Int32 => json!({"name": "int", "bitWidth": 32, "isSigned": true}), + DataType::Int64 => json!({"name": "int", "bitWidth": 64, "isSigned": true}), + DataType::UInt8 => json!({"name": "int", "bitWidth": 8, "isSigned": false}), + DataType::UInt16 => json!({"name": "int", "bitWidth": 16, "isSigned": false}), + DataType::UInt32 => json!({"name": "int", "bitWidth": 32, "isSigned": false}), + DataType::UInt64 => json!({"name": "int", "bitWidth": 64, "isSigned": false}), + DataType::Float16 => json!({"name": "floatingpoint", "precision": "HALF"}), + DataType::Float32 => json!({"name": "floatingpoint", "precision": "SINGLE"}), + DataType::Float64 => json!({"name": "floatingpoint", "precision": "DOUBLE"}), + DataType::Utf8 => json!({"name": "utf8"}), + DataType::LargeUtf8 => json!({"name": "largeutf8"}), + DataType::Binary => json!({"name": "binary"}), + DataType::LargeBinary => json!({"name": "largebinary"}), + DataType::FixedSizeBinary(byte_width) => { + json!({"name": "fixedsizebinary", "byteWidth": byte_width}) + } + DataType::Struct(_) => json!({"name": "struct"}), + DataType::Union(_, _, _) => json!({"name": "union"}), + DataType::Map(_, _) => json!({"name": "map"}), + DataType::List(_) => json!({ "name": "list"}), + DataType::LargeList(_) => json!({ "name": "largelist"}), + DataType::FixedSizeList(_, length) => { + json!({"name":"fixedsizelist", "listSize": length}) + } + DataType::Time32(unit) => { + json!({"name": "time", "bitWidth": 32, "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}) + } + DataType::Time64(unit) => { + json!({"name": "time", "bitWidth": 64, "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}) + } + DataType::Date32 => { + json!({"name": "date", "unit": "DAY"}) + } + DataType::Date64 => { + json!({"name": "date", "unit": "MILLISECOND"}) + } + DataType::Timestamp(unit, None) => { + json!({"name": "timestamp", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}) + } + 
DataType::Timestamp(unit, Some(tz)) => { + json!({"name": "timestamp", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }, "timezone": tz}) + } + DataType::Interval(unit) => json!({"name": "interval", "unit": match unit { + IntervalUnit::YearMonth => "YEAR_MONTH", + IntervalUnit::DayTime => "DAY_TIME", + IntervalUnit::MonthDayNano => "MONTH_DAY_NANO", + }}), + DataType::Duration(unit) => json!({"name": "duration", "unit": match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }}), + DataType::Dictionary(_, _, _) => json!({ "name": "dictionary"}), + DataType::Decimal(precision, scale) => { + json!({"name": "decimal", "precision": precision, "scale": scale}) + } + DataType::Extension(_, inner_data_type, _) => serialize_data_type(inner_data_type), + } +} + +fn serialize_field(field: &Field, ipc_field: &IpcField) -> ArrowJsonField { + let children = match field.data_type() { + DataType::Union(fields, ..) | DataType::Struct(fields) => fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc_field)| serialize_field(field, ipc_field)) + .collect(), + DataType::Map(field, ..) + | DataType::FixedSizeList(field, _) + | DataType::LargeList(field) + | DataType::List(field) => { + vec![serialize_field(field, &ipc_field.fields[0])] + } + _ => vec![], + }; + let metadata = serialize_metadata(field); + + let dictionary = if let DataType::Dictionary(key_type, _, is_ordered) = field.data_type() { + use crate::datatypes::IntegerType::*; + Some(ArrowJsonFieldDictionary { + id: ipc_field.dictionary_id.unwrap(), + index_type: IntegerType { + name: "".to_string(), + bit_width: match key_type { + Int8 | UInt8 => 8, + Int16 | UInt16 => 16, + Int32 | UInt32 => 32, + Int64 | UInt64 => 64, + }, + is_signed: match key_type { + Int8 | Int16 | Int32 | Int64 => true, + UInt8 | UInt16 | UInt32 | UInt64 => false, + }, + }, + is_ordered: *is_ordered, + }) + } else { + None + }; + + ArrowJsonField { + name: field.name().to_string(), + field_type: serialize_data_type(field.data_type()), + nullable: field.is_nullable(), + children, + dictionary, + metadata, + } +} + +/// Serializes a [`Schema`] and associated [`IpcField`] to [`ArrowJsonSchema`]. 
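+///
+/// A minimal usage sketch, assuming `default_ipc_fields` from `io::ipc::write`
+/// is used when no dictionary ids have to be preserved:
+/// ```ignore
+/// let ipc_fields = default_ipc_fields(schema.fields());
+/// let json_schema = serialize_schema(&schema, &ipc_fields);
+/// ```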
+pub fn serialize_schema(schema: &Schema, ipc_fields: &[IpcField]) -> ArrowJsonSchema { + ArrowJsonSchema { + fields: schema + .fields + .iter() + .zip(ipc_fields.iter()) + .map(|(field, ipc_field)| serialize_field(field, ipc_field)) + .collect(), + metadata: Some(serde_json::to_value(&schema.metadata).unwrap()), + } +} + +fn serialize_metadata(field: &Field) -> Option { + field.metadata().as_ref().and_then(|kv_list| { + let mut array = Vec::new(); + for (k, v) in kv_list { + let mut kv_map = Map::new(); + kv_map.insert(k.clone(), Value::String(v.clone())); + array.push(Value::Object(kv_map)); + } + if !array.is_empty() { + Some(Value::Array(array)) + } else { + None + } + }) +} diff --git a/src/io/parquet/read/schema/metadata.rs b/src/io/parquet/read/schema/metadata.rs index c7cd2370305..2b5825bb8df 100644 --- a/src/io/parquet/read/schema/metadata.rs +++ b/src/io/parquet/read/schema/metadata.rs @@ -6,7 +6,7 @@ pub use parquet2::metadata::KeyValue; use crate::datatypes::Schema; use crate::error::{ArrowError, Result}; -use crate::io::ipc::fb_to_schema; +use crate::io::ipc::read::fb_to_schema; use super::super::super::ARROW_SCHEMA_META_KEY; diff --git a/src/io/parquet/write/schema.rs b/src/io/parquet/write/schema.rs index 521ec53aa88..3db64579aad 100644 --- a/src/io/parquet/write/schema.rs +++ b/src/io/parquet/write/schema.rs @@ -12,6 +12,7 @@ use parquet2::{ use crate::{ datatypes::{DataType, Field, Schema, TimeUnit}, error::{ArrowError, Result}, + io::ipc::write::default_ipc_fields, io::ipc::write::schema_to_bytes, io::parquet::write::decimal_length_from_precision, }; @@ -19,7 +20,7 @@ use crate::{ use super::super::ARROW_SCHEMA_META_KEY; pub fn schema_to_metadata_key(schema: &Schema) -> KeyValue { - let serialized_schema = schema_to_bytes(schema); + let serialized_schema = schema_to_bytes(schema, &default_ipc_fields(schema.fields())); // manually prepending the length to the schema as arrow uses the legacy IPC format // TODO: change after addressing ARROW-9777 diff --git a/src/lib.rs b/src/lib.rs index fe3d42c4e40..d215753026f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub mod array; pub mod bitmap; pub mod buffer; +pub(crate) mod columns; pub mod error; pub mod scalar; pub mod trusted_len; diff --git a/src/record_batch.rs b/src/record_batch.rs index 05457f253d2..369e74a0481 100644 --- a/src/record_batch.rs +++ b/src/record_batch.rs @@ -291,6 +291,12 @@ impl RecordBatch { let schema = Arc::new(Schema::new(fields)); RecordBatch::try_new(schema, columns) } + + /// Deconstructs itself into its internal components + pub fn into_inner(self) -> (Vec>, Arc) { + let Self { columns, schema } = self; + (columns, schema) + } } /// Options that control the behaviour used when creating a [`RecordBatch`]. 
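With the IPC-specific metadata split out of `Field`, the integration-JSON schema readers and writers above are plain functions that return or accept the `IpcField`s alongside the `Schema`. A minimal sketch of both directions, assuming the `read::deserialize_schema` and `write::serialize_schema` re-exports from the new modules; the JSON value and the function name are illustrative only:

    use serde_json::json;
    use arrow2::error::Result;
    use arrow2::io::json_integration::{read, write};

    fn roundtrip_schema() -> Result<()> {
        // integration JSON -> (Schema, Vec<IpcField>): dictionary ids come back
        // next to the Schema instead of inside each Field
        let value = json!({
            "fields": [{
                "name": "a",
                "nullable": true,
                "type": {"name": "int", "bitWidth": 32, "isSigned": true},
                "children": []
            }]
        });
        let (schema, ipc_fields) = read::deserialize_schema(&value)?;

        // (Schema, &[IpcField]) -> integration JSON schema
        let _json_schema = write::serialize_schema(&schema, &ipc_fields);
        Ok(())
    }
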
diff --git a/tests/it/io/ipc/common.rs b/tests/it/io/ipc/common.rs index 8eff523789e..52a75e0052a 100644 --- a/tests/it/io/ipc/common.rs +++ b/tests/it/io/ipc/common.rs @@ -1,18 +1,18 @@ -use std::{collections::HashMap, convert::TryFrom, fs::File, io::Read}; +use std::{collections::HashMap, fs::File, io::Read}; use arrow2::{ - datatypes::Schema, - error::Result, - io::ipc::read::read_stream_metadata, - io::ipc::read::StreamReader, - io::json_integration::{to_record_batch, ArrowJson}, - record_batch::RecordBatch, + datatypes::Schema, error::Result, io::ipc::read::read_stream_metadata, + io::ipc::read::StreamReader, io::ipc::IpcField, io::json_integration::read, + io::json_integration::ArrowJson, record_batch::RecordBatch, }; use flate2::read::GzDecoder; /// Read gzipped JSON file -pub fn read_gzip_json(version: &str, file_name: &str) -> Result<(Schema, Vec)> { +pub fn read_gzip_json( + version: &str, + file_name: &str, +) -> Result<(Schema, Vec, Vec)> { let testdata = crate::test_util::arrow_test_data(); let file = File::open(format!( "{}/arrow-ipc-stream/integration/{}/{}.json.gz", @@ -27,7 +27,7 @@ pub fn read_gzip_json(version: &str, file_name: &str) -> Result<(Schema, Vec Result<(Schema, Vec>>()?; - Ok((schema, batches)) + Ok((schema, ipc_fields, batches)) } -pub fn read_arrow_stream(version: &str, file_name: &str) -> (Schema, Vec) { +pub fn read_arrow_stream( + version: &str, + file_name: &str, +) -> (Schema, Vec, Vec) { let testdata = crate::test_util::arrow_test_data(); let mut file = File::open(format!( "{}/arrow-ipc-stream/integration/{}/{}.stream", @@ -58,10 +61,12 @@ pub fn read_arrow_stream(version: &str, file_name: &str) -> (Schema, Vec>() diff --git a/tests/it/io/ipc/read/file.rs b/tests/it/io/ipc/read/file.rs index c9f067ccc14..8cce26b81c5 100644 --- a/tests/it/io/ipc/read/file.rs +++ b/tests/it/io/ipc/read/file.rs @@ -13,7 +13,7 @@ fn test_file(version: &str, file_name: &str) -> Result<()> { ))?; // read expected JSON output - let (schema, batches) = read_gzip_json(version, file_name)?; + let (schema, _, batches) = read_gzip_json(version, file_name)?; let metadata = read_file_metadata(&mut file)?; let reader = FileReader::new(file, metadata, None); diff --git a/tests/it/io/ipc/read/stream.rs b/tests/it/io/ipc/read/stream.rs index 7a69e191785..3c3fcf095d6 100644 --- a/tests/it/io/ipc/read/stream.rs +++ b/tests/it/io/ipc/read/stream.rs @@ -16,9 +16,10 @@ fn test_file(version: &str, file_name: &str) -> Result<()> { let reader = StreamReader::new(file, metadata); // read expected JSON output - let (schema, batches) = read_gzip_json(version, file_name)?; + let (schema, ipc_fields, batches) = read_gzip_json(version, file_name)?; - assert_eq!(&schema, reader.schema().as_ref()); + assert_eq!(&schema, reader.metadata().schema.as_ref()); + assert_eq!(&ipc_fields, &reader.metadata().ipc_schema.fields); batches .iter() diff --git a/tests/it/io/ipc/write/file.rs b/tests/it/io/ipc/write/file.rs index 8dcbc7423cb..979d970c150 100644 --- a/tests/it/io/ipc/write/file.rs +++ b/tests/it/io/ipc/write/file.rs @@ -1,35 +1,41 @@ use std::io::Cursor; use arrow2::array::*; +use arrow2::datatypes::Schema; use arrow2::error::Result; use arrow2::io::ipc::read::{read_file_metadata, FileReader}; -use arrow2::io::ipc::write::*; +use arrow2::io::ipc::{write::*, IpcField}; use arrow2::record_batch::RecordBatch; use crate::io::ipc::common::read_gzip_json; -fn round_trip(batch: RecordBatch) -> Result<()> { - let result = Vec::::new(); - - // write IPC version 5 - let written_result = { - let options = 
WriteOptions { - compression: Some(Compression::LZ4), - }; - let mut writer = FileWriter::try_new(result, batch.schema(), options)?; - writer.write(&batch)?; - writer.finish()?; - writer.into_inner() - }; - let mut reader = Cursor::new(written_result); +fn write_( + batches: &[RecordBatch], + schema: &Schema, + ipc_fields: Option>, + compression: Option, +) -> Result> { + let result = vec![]; + let options = WriteOptions { compression }; + let mut writer = FileWriter::try_new(result, schema, ipc_fields.clone(), options)?; + for batch in batches { + writer.write(batch, ipc_fields.as_ref().map(|x| x.as_ref()))?; + } + writer.finish()?; + Ok(writer.into_inner()) +} + +fn round_trip(batch: RecordBatch, ipc_fields: Option>) -> Result<()> { + let (expected_schema, expected_batches) = (batch.schema().clone(), vec![batch.clone()]); + + let schema = batch.schema().clone(); + let result = write_(&[batch], &schema, ipc_fields, Some(Compression::ZSTD))?; + let mut reader = Cursor::new(result); let metadata = read_file_metadata(&mut reader)?; let schema = metadata.schema().clone(); let reader = FileReader::new(reader, metadata, None); - // read expected JSON output - let (expected_schema, expected_batches) = (batch.schema().clone(), vec![batch]); - assert_eq!(schema.as_ref(), expected_schema.as_ref()); let batches = reader.collect::>>()?; @@ -39,9 +45,7 @@ fn round_trip(batch: RecordBatch) -> Result<()> { } fn test_file(version: &str, file_name: &str, compressed: bool) -> Result<()> { - let (schema, batches) = read_gzip_json(version, file_name)?; - - let result = Vec::::new(); + let (schema, ipc_fields, batches) = read_gzip_json(version, file_name)?; let compression = if compressed { Some(Compression::ZSTD) @@ -49,26 +53,20 @@ fn test_file(version: &str, file_name: &str, compressed: bool) -> Result<()> { None }; - // write IPC version 5 - let written_result = { - let options = WriteOptions { compression }; - let mut writer = FileWriter::try_new(result, &schema, options)?; - for batch in batches { - writer.write(&batch)?; - } - writer.finish()?; - writer.into_inner() - }; - let mut reader = Cursor::new(written_result); + let result = write_(&batches, &schema, Some(ipc_fields), compression)?; + let mut reader = Cursor::new(result); let metadata = read_file_metadata(&mut reader)?; let schema = metadata.schema().clone(); + let ipc_fields = metadata.ipc_schema.fields.clone(); let reader = FileReader::new(reader, metadata, None); // read expected JSON output - let (expected_schema, expected_batches) = read_gzip_json(version, file_name)?; + let (expected_schema, expected_ipc_fields, expected_batches) = + read_gzip_json(version, file_name)?; assert_eq!(schema.as_ref(), &expected_schema); + assert_eq!(ipc_fields, expected_ipc_fields); let batches = reader.collect::>>()?; @@ -332,7 +330,7 @@ fn write_boolean() -> Result<()> { Some(true), ])) as Arc; let batch = RecordBatch::try_from_iter(vec![("a", array)])?; - round_trip(batch) + round_trip(batch, None) } #[test] @@ -341,7 +339,7 @@ fn write_sliced_utf8() -> Result<()> { use std::sync::Arc; let array = Arc::new(Utf8Array::::from_slice(["aa", "bb"]).slice(1, 1)) as Arc; let batch = RecordBatch::try_from_iter(vec![("a", array)])?; - round_trip(batch) + round_trip(batch, None) } #[test] @@ -357,5 +355,5 @@ fn write_sliced_list() -> Result<()> { array.try_extend(data).unwrap(); let array = array.into_arc().slice(1, 2).into(); let batch = RecordBatch::try_from_iter(vec![("a", array)]).unwrap(); - round_trip(batch) + round_trip(batch, None) } diff --git 
a/tests/it/io/ipc/write/stream.rs b/tests/it/io/ipc/write/stream.rs index 54693a22fdb..5384de7377f 100644 --- a/tests/it/io/ipc/write/stream.rs +++ b/tests/it/io/ipc/write/stream.rs @@ -1,38 +1,47 @@ use std::io::Cursor; +use arrow2::datatypes::Schema; use arrow2::error::Result; use arrow2::io::ipc::read::read_stream_metadata; use arrow2::io::ipc::read::StreamReader; use arrow2::io::ipc::write::{StreamWriter, WriteOptions}; +use arrow2::io::ipc::IpcField; +use arrow2::record_batch::RecordBatch; use crate::io::ipc::common::read_arrow_stream; use crate::io::ipc::common::read_gzip_json; -fn test_file(version: &str, file_name: &str) { - let (schema, batches) = read_arrow_stream(version, file_name); - - let mut result = Vec::::new(); - - // write IPC version 5 - { - let options = WriteOptions { compression: None }; - let mut writer = StreamWriter::try_new(&mut result, &schema, options).unwrap(); - for batch in batches { - writer.write(&batch).unwrap(); - } - writer.finish().unwrap(); +fn write_(schema: &Schema, ipc_fields: &[IpcField], batches: &[RecordBatch]) -> Vec { + let mut result = vec![]; + + let options = WriteOptions { compression: None }; + let mut writer = StreamWriter::new(&mut result, options); + writer.start(&schema, ipc_fields).unwrap(); + for batch in batches { + writer.write(batch, ipc_fields).unwrap(); } + writer.finish().unwrap(); + result +} + +fn test_file(version: &str, file_name: &str) { + let (schema, ipc_fields, batches) = read_arrow_stream(version, file_name); + + let result = write_(&schema, &ipc_fields, &batches); let mut reader = Cursor::new(result); let metadata = read_stream_metadata(&mut reader).unwrap(); let reader = StreamReader::new(reader, metadata); - let schema = reader.schema().clone(); + let schema = reader.metadata().schema.clone(); + let ipc_fields = reader.metadata().ipc_schema.fields.clone(); // read expected JSON output - let (expected_schema, expected_batches) = read_gzip_json(version, file_name).unwrap(); + let (expected_schema, expected_ipc_fields, expected_batches) = + read_gzip_json(version, file_name).unwrap(); assert_eq!(schema.as_ref(), &expected_schema); + assert_eq!(ipc_fields, expected_ipc_fields); let batches = reader .map(|x| x.map(|x| x.unwrap())) diff --git a/tests/it/io/ipc/write_async.rs b/tests/it/io/ipc/write_async.rs index 31f19d37d1e..18110bbe3bd 100644 --- a/tests/it/io/ipc/write_async.rs +++ b/tests/it/io/ipc/write_async.rs @@ -1,41 +1,51 @@ use std::io::Cursor; +use arrow2::datatypes::Schema; use arrow2::error::Result; -use arrow2::io::ipc::read::read_stream_metadata; -use arrow2::io::ipc::read::StreamReader; -use arrow2::io::ipc::write::stream_async::{StreamWriter, WriteOptions}; +use arrow2::io::ipc::read; +use arrow2::io::ipc::write::stream_async; +use arrow2::io::ipc::IpcField; +use arrow2::record_batch::RecordBatch; use futures::io::Cursor as AsyncCursor; use crate::io::ipc::common::read_arrow_stream; use crate::io::ipc::common::read_gzip_json; -async fn test_file(version: &str, file_name: &str) -> Result<()> { - let (schema, batches) = read_arrow_stream(version, file_name); - - let mut result = AsyncCursor::new(Vec::::new()); - - // write IPC version 5 - { - let options = WriteOptions { compression: None }; - let mut writer = StreamWriter::new(&mut result, options); - writer.start(&schema).await?; - for batch in batches { - writer.write(&batch).await?; - } - writer.finish().await?; +async fn write_( + schema: &Schema, + ipc_fields: &[IpcField], + batches: &[RecordBatch], +) -> Result> { + let mut result = 
AsyncCursor::new(vec![]); + + let options = stream_async::WriteOptions { compression: None }; + let mut writer = stream_async::StreamWriter::new(&mut result, options); + writer.start(&schema, Some(&ipc_fields)).await?; + for batch in batches { + writer.write(batch, Some(&ipc_fields)).await?; } - let result = result.into_inner(); + writer.finish().await?; + Ok(result.into_inner()) +} + +async fn test_file(version: &str, file_name: &str) -> Result<()> { + let (schema, ipc_fields, batches) = read_arrow_stream(version, file_name); + + let result = write_(&schema, &ipc_fields, &batches).await?; let mut reader = Cursor::new(result); - let metadata = read_stream_metadata(&mut reader)?; - let reader = StreamReader::new(reader, metadata); + let metadata = read::read_stream_metadata(&mut reader)?; + let reader = read::StreamReader::new(reader, metadata); - let schema = reader.schema().clone(); + let schema = reader.metadata().schema.as_ref(); + let ipc_fields = reader.metadata().ipc_schema.fields.clone(); // read expected JSON output - let (expected_schema, expected_batches) = read_gzip_json(version, file_name).unwrap(); + let (expected_schema, expected_ipc_fields, expected_batches) = + read_gzip_json(version, file_name).unwrap(); - assert_eq!(schema.as_ref(), &expected_schema); + assert_eq!(schema, &expected_schema); + assert_eq!(ipc_fields, expected_ipc_fields); let batches = reader .map(|x| x.map(|x| x.unwrap())) diff --git a/tests/it/io/parquet/mod.rs b/tests/it/io/parquet/mod.rs index 9a722802571..ada15b32615 100644 --- a/tests/it/io/parquet/mod.rs +++ b/tests/it/io/parquet/mod.rs @@ -678,7 +678,7 @@ fn integration_read(data: &[u8]) -> Result<(Arc, Vec)> { } fn test_file(version: &str, file_name: &str) -> Result<()> { - let (schema, batches) = read_gzip_json(version, file_name)?; + let (schema, _, batches) = read_gzip_json(version, file_name)?; let data = integration_write(&schema, &batches)?; From 02d0201770427556eccd39f2ffe9509ffb607242 Mon Sep 17 00:00:00 2001 From: "Jorge C. 
Leitao" Date: Sun, 26 Dec 2021 22:15:39 +0000 Subject: [PATCH 2/6] Migrated integration tests --- .../src/bin/arrow-file-to-stream.rs | 9 ++- .../src/bin/arrow-json-integration-test.rs | 42 +++-------- .../src/bin/arrow-stream-to-file.rs | 12 ++- .../integration_test.rs | 75 ++++++++++++------- .../integration_test.rs | 49 +++++++----- integration-testing/src/lib.rs | 11 +-- src/io/flight/mod.rs | 34 ++------- 7 files changed, 114 insertions(+), 118 deletions(-) diff --git a/integration-testing/src/bin/arrow-file-to-stream.rs b/integration-testing/src/bin/arrow-file-to-stream.rs index b3b024abb1d..b5411bebfa4 100644 --- a/integration-testing/src/bin/arrow-file-to-stream.rs +++ b/integration-testing/src/bin/arrow-file-to-stream.rs @@ -27,15 +27,16 @@ fn main() -> Result<()> { let filename = &args[1]; let mut f = File::open(filename)?; let metadata = read::read_file_metadata(&mut f)?; - let mut reader = read::FileReader::new(f, metadata, None); - let schema = reader.schema(); + let mut reader = read::FileReader::new(f, metadata.clone(), None); let options = write::WriteOptions { compression: None }; - let mut writer = write::StreamWriter::try_new(std::io::stdout(), schema, options)?; + let mut writer = write::StreamWriter::new(std::io::stdout(), options); + + writer.start(&metadata.schema, &metadata.ipc_schema.fields)?; reader.try_for_each(|batch| { let batch = batch?; - writer.write(&batch) + writer.write(&batch, &metadata.ipc_schema.fields) })?; writer.finish()?; diff --git a/integration-testing/src/bin/arrow-json-integration-test.rs b/integration-testing/src/bin/arrow-json-integration-test.rs index becb663f5ce..180896a9418 100644 --- a/integration-testing/src/bin/arrow-json-integration-test.rs +++ b/integration-testing/src/bin/arrow-json-integration-test.rs @@ -1,29 +1,13 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- use std::fs::File; +use arrow2::io::json_integration::ArrowJson; use clap::{App, Arg}; use arrow2::io::ipc::read; use arrow2::io::ipc::write; use arrow2::{ error::{ArrowError, Result}, - io::json_integration::*, + io::json_integration::write as json_write, }; use arrow_integration_testing::read_json_file; @@ -82,10 +66,15 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()> let arrow_file = File::create(arrow_name)?; let options = write::WriteOptions { compression: None }; - let mut writer = write::FileWriter::try_new(arrow_file, &json_file.schema, options)?; + let mut writer = write::FileWriter::try_new( + arrow_file, + &json_file.schema, + Some(json_file.fields), + options, + )?; for b in json_file.batches { - writer.write(&b)?; + writer.write(&b, None)?; } writer.finish()?; @@ -100,19 +89,12 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> let mut arrow_file = File::open(arrow_name)?; let metadata = read::read_file_metadata(&mut arrow_file)?; - let reader = read::FileReader::new(arrow_file, metadata, None); + let reader = read::FileReader::new(arrow_file, metadata.clone(), None); - let mut fields: Vec = vec![]; - for f in reader.schema().fields() { - fields.push(ArrowJsonField::from(f)); - } - let schema = ArrowJsonSchema { - fields, - metadata: None, - }; + let schema = json_write::serialize_schema(&metadata.schema, &metadata.ipc_schema.fields); let batches = reader - .map(|batch| Ok(from_record_batch(&batch?))) + .map(|batch| Ok(json_write::from_record_batch(&batch?))) .collect::>>()?; let arrow_json = ArrowJson { diff --git a/integration-testing/src/bin/arrow-stream-to-file.rs b/integration-testing/src/bin/arrow-stream-to-file.rs index fe41fe8e45c..ab0855bf677 100644 --- a/integration-testing/src/bin/arrow-stream-to-file.rs +++ b/integration-testing/src/bin/arrow-stream-to-file.rs @@ -24,15 +24,19 @@ use arrow2::io::ipc::write; fn main() -> Result<()> { let mut reader = io::stdin(); let metadata = read::read_stream_metadata(&mut reader)?; - let mut arrow_stream_reader = read::StreamReader::new(reader, metadata); - let schema = arrow_stream_reader.schema(); + let mut arrow_stream_reader = read::StreamReader::new(reader, metadata.clone()); let writer = io::stdout(); let options = write::WriteOptions { compression: None }; - let mut writer = write::FileWriter::try_new(writer, schema, options)?; + let mut writer = write::FileWriter::try_new( + writer, + &metadata.schema, + Some(metadata.ipc_schema.fields), + options, + )?; - arrow_stream_reader.try_for_each(|batch| writer.write(&batch?.unwrap()))?; + arrow_stream_reader.try_for_each(|batch| writer.write(&batch?.unwrap(), None))?; writer.finish()?; Ok(()) diff --git a/integration-testing/src/flight_client_scenarios/integration_test.rs b/integration-testing/src/flight_client_scenarios/integration_test.rs index dd65352df5b..461659069dc 100644 --- a/integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/integration-testing/src/flight_client_scenarios/integration_test.rs @@ -18,10 +18,15 @@ use crate::{read_json_file, ArrowFile}; use arrow2::{ - array::*, datatypes::*, - io::flight::{self, deserialize_batch, serialize_batch}, - io::ipc::{read, write}, + io::ipc::{ + read::{self, Dictionaries}, + write, IpcSchema, + }, + io::{ + flight::{self, deserialize_batch, serialize_batch}, + ipc::IpcField, + }, record_batch::RecordBatch, }; use arrow_format::flight::data::{ @@ -33,10 +38,7 @@ use arrow_format::ipc::Message::MessageHeader; use futures::{channel::mpsc, 
sink::SinkExt, stream, StreamExt}; use tonic::{Request, Streaming}; -use std::{collections::HashMap, sync::Arc}; - -type ArrayRef = Arc; -type SchemaRef = Arc; +use std::sync::Arc; type Error = Box; type Result = std::result::Result; @@ -49,8 +51,15 @@ pub async fn run_scenario(host: &str, port: &str, path: &str) -> Result { let client = FlightServiceClient::connect(url).await?; let ArrowFile { - schema, batches, .. + schema, + batches, + fields, + .. } = read_json_file(path)?; + let ipc_schema = IpcSchema { + fields, + is_little_endian: true, + }; let schema = Arc::new(schema); @@ -60,19 +69,21 @@ pub async fn run_scenario(host: &str, port: &str, path: &str) -> Result { upload_data( client.clone(), - schema.clone(), + &schema, + &ipc_schema.fields, descriptor.clone(), batches.clone(), ) .await?; - verify_data(client, descriptor, schema, &batches).await?; + verify_data(client, descriptor, schema, &ipc_schema, &batches).await?; Ok(()) } async fn upload_data( mut client: Client, - schema: SchemaRef, + schema: &Schema, + fields: &[IpcField], descriptor: FlightDescriptor, original_data: Vec, ) -> Result { @@ -80,7 +91,7 @@ async fn upload_data( let options = write::WriteOptions { compression: None }; - let mut schema = flight::serialize_schema(&schema); + let mut schema = flight::serialize_schema(schema, fields); schema.flight_descriptor = Some(descriptor.clone()); upload_tx.send(schema).await?; @@ -89,7 +100,7 @@ async fn upload_data( if let Some((counter, first_batch)) = original_data_iter.next() { let metadata = counter.to_string().into_bytes(); // Preload the first batch into the channel before starting the request - send_batch(&mut upload_tx, &metadata, first_batch, &options).await?; + send_batch(&mut upload_tx, &metadata, first_batch, fields, &options).await?; let outer = client.do_put(Request::new(upload_rx)).await?; let mut inner = outer.into_inner(); @@ -104,7 +115,7 @@ async fn upload_data( // Stream the rest of the batches for (counter, batch) in original_data_iter { let metadata = counter.to_string().into_bytes(); - send_batch(&mut upload_tx, &metadata, batch, &options).await?; + send_batch(&mut upload_tx, &metadata, batch, fields, &options).await?; let r = inner .next() @@ -130,9 +141,10 @@ async fn send_batch( upload_tx: &mut mpsc::Sender, metadata: &[u8], batch: &RecordBatch, + fields: &[IpcField], options: &write::WriteOptions, ) -> Result { - let (dictionary_flight_data, mut batch_flight_data) = serialize_batch(batch, options); + let (dictionary_flight_data, mut batch_flight_data) = serialize_batch(batch, fields, options); upload_tx .send_all(&mut stream::iter(dictionary_flight_data).map(Ok)) @@ -148,6 +160,7 @@ async fn verify_data( mut client: Client, descriptor: FlightDescriptor, expected_schema: SchemaRef, + ipc_schema: &IpcSchema, expected_data: &[RecordBatch], ) -> Result { let resp = client.get_flight_info(Request::new(descriptor)).await?; @@ -172,6 +185,7 @@ async fn verify_data( ticket.clone(), expected_data, expected_schema.clone(), + ipc_schema, ) .await?; } @@ -185,6 +199,7 @@ async fn consume_flight_location( ticket: Ticket, expected_data: &[RecordBatch], schema: SchemaRef, + ipc_schema: &IpcSchema, ) -> Result { let mut location = location; // The other Flight implementations use the `grpc+tcp` scheme, but the Rust http libs @@ -202,20 +217,21 @@ async fn consume_flight_location( let mut dictionaries = Default::default(); for (counter, expected_batch) in expected_data.iter().enumerate() { - let data = receive_batch_flight_data(&mut resp, schema.clone(), &mut 
dictionaries) - .await - .unwrap_or_else(|| { - panic!( - "Got fewer batches than expected, received so far: {} expected: {}", - counter, - expected_data.len(), - ) - }); + let data = + receive_batch_flight_data(&mut resp, schema.fields(), ipc_schema, &mut dictionaries) + .await + .unwrap_or_else(|| { + panic!( + "Got fewer batches than expected, received so far: {} expected: {}", + counter, + expected_data.len(), + ) + }); let metadata = counter.to_string().into_bytes(); assert_eq!(metadata, data.app_metadata); - let actual_batch = deserialize_batch(&data, schema.clone(), true, &dictionaries) + let actual_batch = deserialize_batch(&data, schema.clone(), ipc_schema, &dictionaries) .expect("Unable to convert flight data to Arrow batch"); assert_eq!(expected_batch.schema(), actual_batch.schema()); @@ -244,8 +260,9 @@ async fn consume_flight_location( async fn receive_batch_flight_data( resp: &mut Streaming, - schema: SchemaRef, - dictionaries: &mut HashMap>, + fields: &[Field], + ipc_schema: &IpcSchema, + dictionaries: &mut Dictionaries, ) -> Option { let mut data = resp.next().await?.ok()?; let mut message = @@ -257,8 +274,8 @@ async fn receive_batch_flight_data( message .header_as_dictionary_batch() .expect("Error parsing dictionary"), - &schema, - true, + fields, + ipc_schema, dictionaries, &mut reader, 0, diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs index a91e2b7348e..82e99f503d7 100644 --- a/integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/integration-testing/src/flight_server_scenarios/integration_test.rs @@ -16,11 +16,12 @@ // under the License. use std::collections::HashMap; -use std::convert::TryFrom; use std::pin::Pin; use std::sync::Arc; -use arrow2::io::flight::{serialize_batch, serialize_schema}; +use arrow2::io::flight::{deserialize_schemas, serialize_batch, serialize_schema}; +use arrow2::io::ipc::read::Dictionaries; +use arrow2::io::ipc::IpcSchema; use arrow_format::flight::data::flight_descriptor::*; use arrow_format::flight::data::*; use arrow_format::flight::service::flight_service_server::*; @@ -28,8 +29,7 @@ use arrow_format::ipc::Message::{root_as_message, Message, MessageHeader}; use arrow_format::ipc::Schema as ArrowSchema; use arrow2::{ - array::Array, datatypes::*, io::flight::serialize_schema_to_info, io::ipc, - record_batch::RecordBatch, + datatypes::*, io::flight::serialize_schema_to_info, io::ipc, record_batch::RecordBatch, }; use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; @@ -61,6 +61,7 @@ pub async fn scenario_setup(port: &str) -> Result { #[derive(Debug, Clone)] struct IntegrationDataset { schema: Schema, + ipc_schema: IpcSchema, chunks: Vec, } @@ -110,7 +111,10 @@ impl FlightService for FlightServiceImpl { let options = ipc::write::WriteOptions { compression: None }; - let schema = std::iter::once(Ok(serialize_schema(&flight.schema))); + let schema = std::iter::once(Ok(serialize_schema( + &flight.schema, + &flight.ipc_schema.fields, + ))); let batches = flight .chunks @@ -118,7 +122,7 @@ impl FlightService for FlightServiceImpl { .enumerate() .flat_map(|(counter, batch)| { let (dictionary_flight_data, mut batch_flight_data) = - serialize_batch(batch, &options); + serialize_batch(batch, &flight.ipc_schema.fields, &options); // Only the record batch's FlightData gets app_metadata let metadata = counter.to_string().into_bytes(); @@ -171,10 +175,11 @@ impl FlightService for FlightServiceImpl { let total_records: 
usize = flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); - let schema = serialize_schema_to_info(&flight.schema).expect( - "Could not generate schema bytes from schema stored by a DoPut; \ + let schema = serialize_schema_to_info(&flight.schema, &flight.ipc_schema.fields) + .expect( + "Could not generate schema bytes from schema stored by a DoPut; \ this should be impossible", - ); + ); let info = FlightInfo { schema, @@ -211,7 +216,7 @@ impl FlightService for FlightServiceImpl { let key = descriptor.path[0].clone(); - let schema = Schema::try_from(&flight_data) + let (schema, ipc_schema) = deserialize_schemas(&flight_data.data_header) .map_err(|e| Status::invalid_argument(format!("Invalid schema: {:?}", e)))?; let schema_ref = Arc::new(schema.clone()); @@ -224,6 +229,7 @@ impl FlightService for FlightServiceImpl { if let Err(e) = save_uploaded_chunks( uploaded_chunks, schema_ref, + ipc_schema, input_stream, response_tx, schema, @@ -275,7 +281,8 @@ async fn record_batch_from_message( message: Message<'_>, data_body: &[u8], schema_ref: Arc, - dictionaries: &mut HashMap>, + ipc_schema: &IpcSchema, + dictionaries: &mut Dictionaries, ) -> Result { let ipc_batch = message .header_as_record_batch() @@ -286,8 +293,8 @@ async fn record_batch_from_message( let arrow_batch_result = ipc::read::read_record_batch( ipc_batch, schema_ref, + ipc_schema, None, - true, dictionaries, ArrowSchema::MetadataVersion::V5, &mut reader, @@ -301,8 +308,9 @@ async fn record_batch_from_message( async fn dictionary_from_message( message: Message<'_>, data_body: &[u8], - schema_ref: Arc, - dictionaries: &mut HashMap>, + fields: &[Field], + ipc_schema: &IpcSchema, + dictionaries: &mut Dictionaries, ) -> Result<(), Status> { let ipc_batch = message .header_as_dictionary_batch() @@ -311,7 +319,7 @@ async fn dictionary_from_message( let mut reader = std::io::Cursor::new(data_body); let dictionary_batch_result = - ipc::read::read_dictionary(ipc_batch, &schema_ref, true, dictionaries, &mut reader, 0); + ipc::read::read_dictionary(ipc_batch, fields, ipc_schema, dictionaries, &mut reader, 0); dictionary_batch_result .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {:?}", e))) } @@ -319,6 +327,7 @@ async fn dictionary_from_message( async fn save_uploaded_chunks( uploaded_chunks: Arc>>, schema_ref: Arc, + ipc_schema: IpcSchema, mut input_stream: Streaming, mut response_tx: mpsc::Sender>, schema: Schema, @@ -346,6 +355,7 @@ async fn save_uploaded_chunks( message, &data.data_body, schema_ref.clone(), + &ipc_schema, &mut dictionaries, ) .await?; @@ -356,7 +366,8 @@ async fn save_uploaded_chunks( dictionary_from_message( message, &data.data_body, - schema_ref.clone(), + schema_ref.fields(), + &ipc_schema, &mut dictionaries, ) .await?; @@ -371,7 +382,11 @@ async fn save_uploaded_chunks( } } - let dataset = IntegrationDataset { schema, chunks }; + let dataset = IntegrationDataset { + schema, + chunks, + ipc_schema, + }; uploaded_chunks.insert(key, dataset); Ok(()) diff --git a/integration-testing/src/lib.rs b/integration-testing/src/lib.rs index 7b4942af4a2..6da3ff17081 100644 --- a/integration-testing/src/lib.rs +++ b/integration-testing/src/lib.rs @@ -17,13 +17,12 @@ //! 
Common code used in the integration test binaries -use std::convert::TryFrom; - +use arrow2::io::ipc::IpcField; use serde_json::Value; use arrow2::datatypes::*; use arrow2::error::Result; -use arrow2::io::json_integration::*; +use arrow2::io::json_integration::{read, ArrowJsonBatch, ArrowJsonDictionaryBatch}; use arrow2::record_batch::RecordBatch; use std::collections::HashMap; @@ -40,6 +39,7 @@ pub mod flight_server_scenarios; pub struct ArrowFile { pub schema: Schema, + pub fields: Vec, // we can evolve this into a concrete Arrow type // this is temporarily not being read from pub _dictionaries: HashMap, @@ -51,7 +51,7 @@ pub fn read_json_file(json_name: &str) -> Result { let reader = BufReader::new(json_file); let arrow_json: Value = serde_json::from_reader(reader).unwrap(); - let schema = Schema::try_from(&arrow_json["schema"])?; + let (schema, fields) = read::deserialize_schema(&arrow_json["schema"])?; // read dictionaries let mut dictionaries = HashMap::new(); if let Some(dicts) = arrow_json.get("dictionaries") { @@ -69,11 +69,12 @@ pub fn read_json_file(json_name: &str) -> Result { let mut batches = vec![]; for b in arrow_json["batches"].as_array().unwrap() { let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); - let batch = to_record_batch(&schema, &json_batch, &dictionaries)?; + let batch = read::to_record_batch(&schema, &fields, &json_batch, &dictionaries)?; batches.push(batch); } Ok(ArrowFile { schema, + fields, _dictionaries: dictionaries, batches, }) diff --git a/src/io/flight/mod.rs b/src/io/flight/mod.rs index 70720a180b0..21ef52dc572 100644 --- a/src/io/flight/mod.rs +++ b/src/io/flight/mod.rs @@ -1,4 +1,3 @@ -use std::convert::TryFrom; use std::sync::Arc; use arrow_format::flight::data::{FlightData, SchemaResult}; @@ -82,11 +81,12 @@ fn schema_as_encoded_data(schema: &Schema, ipc_fields: &[IpcField]) -> EncodedDa } } -/// Deserialize an IPC message into a schema -fn schema_from_bytes(bytes: &[u8]) -> Result { +/// Deserialize an IPC message into [`Schema`], [`IpcSchema`]. +/// Use to deserialize [`FlightData::data_header`] and [`SchemaResult::schema`]. +pub fn deserialize_schemas(bytes: &[u8]) -> Result<(Schema, IpcSchema)> { if let Ok(ipc) = ipc::Message::root_as_message(bytes) { - if let Some((schema, _)) = ipc.header_as_schema().map(read::fb_to_schema) { - Ok(schema) + if let Some(schemas) = ipc.header_as_schema().map(read::fb_to_schema) { + Ok(schemas) } else { Err(ArrowError::OutOfSpec( "Unable to get head as schema".to_string(), @@ -99,30 +99,6 @@ fn schema_from_bytes(bytes: &[u8]) -> Result { } } -impl TryFrom<&FlightData> for Schema { - type Error = ArrowError; - fn try_from(data: &FlightData) -> Result { - schema_from_bytes(&data.data_header[..]).map_err(|err| { - ArrowError::OutOfSpec(format!( - "Unable to convert flight data to Arrow schema: {}", - err - )) - }) - } -} - -impl TryFrom<&SchemaResult> for Schema { - type Error = ArrowError; - fn try_from(data: &SchemaResult) -> Result { - schema_from_bytes(&data.schema[..]).map_err(|err| { - ArrowError::OutOfSpec(format!( - "Unable to convert schema result to Arrow schema: {}", - err - )) - }) - } -} - /// Deserializes [`FlightData`] to a [`RecordBatch`]. pub fn deserialize_batch( data: &FlightData, From e73f94d6d6984bcd3ff6d960caabc1c096143fa8 Mon Sep 17 00:00:00 2001 From: "Jorge C. 
Leitao" Date: Sun, 26 Dec 2021 22:21:41 +0000 Subject: [PATCH 3/6] Updated example --- src/io/ipc/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/io/ipc/mod.rs b/src/io/ipc/mod.rs index bcc3e9259f7..dd4b93590af 100644 --- a/src/io/ipc/mod.rs +++ b/src/io/ipc/mod.rs @@ -44,7 +44,7 @@ //! let y_coord = Field::new("y", DataType::Int32, false); //! let schema = Schema::new(vec![x_coord, y_coord]); //! let options = WriteOptions {compression: None}; -//! let mut writer = FileWriter::try_new(file, &schema, options)?; +//! let mut writer = FileWriter::try_new(file, &schema, None, options)?; //! //! // Setup the data //! let x_data = Int32Array::from_slice([-1i32, 1]); @@ -56,7 +56,7 @@ //! //! // Write the messages and finalize the stream //! for _ in 0..5 { -//! writer.write(&batch); +//! writer.write(&batch, None); //! } //! writer.finish(); //! From 8ebc2a851d13ea027faa8324f1e6517eed868af5 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 26 Dec 2021 22:37:31 +0000 Subject: [PATCH 4/6] chore --- arrow-parquet-integration-testing/src/main.rs | 16 +++++++++------- tests/it/io/ipc/write_async.rs | 4 ++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/arrow-parquet-integration-testing/src/main.rs b/arrow-parquet-integration-testing/src/main.rs index 28f161095ce..ba60e7b8462 100644 --- a/arrow-parquet-integration-testing/src/main.rs +++ b/arrow-parquet-integration-testing/src/main.rs @@ -1,12 +1,14 @@ use std::fs::File; use std::sync::Arc; -use std::{collections::HashMap, convert::TryFrom, io::Read}; +use std::{collections::HashMap, io::Read}; +use arrow2::io::ipc::IpcField; use arrow2::{ datatypes::{DataType, Schema}, error::Result, io::{ - json_integration::{to_record_batch, ArrowJson}, + json_integration::read, + json_integration::ArrowJson, parquet::write::{ write_file, Compression, Encoding, RowGroupIterator, Version, WriteOptions, }, @@ -19,7 +21,7 @@ use clap::{App, Arg}; use flate2::read::GzDecoder; /// Read gzipped JSON file -fn read_gzip_json(version: &str, file_name: &str) -> (Schema, Vec) { +fn read_gzip_json(version: &str, file_name: &str) -> (Schema, Vec, Vec) { let path = format!( "../testing/arrow-testing/data/arrow-ipc-stream/integration/{}/{}.json.gz", version, file_name @@ -32,7 +34,7 @@ fn read_gzip_json(version: &str, file_name: &str) -> (Schema, Vec) let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap(); let schema = serde_json::to_value(arrow_json.schema).unwrap(); - let schema = Schema::try_from(&schema).unwrap(); + let (schema, ipc_fields) = read::deserialize_schema(&schema).unwrap(); // read dictionaries let mut dictionaries = HashMap::new(); @@ -46,11 +48,11 @@ fn read_gzip_json(version: &str, file_name: &str) -> (Schema, Vec) let batches = arrow_json .batches .iter() - .map(|batch| to_record_batch(&schema, batch, &dictionaries)) + .map(|batch| read::to_record_batch(&schema, &ipc_fields, batch, &dictionaries)) .collect::>>() .unwrap(); - (schema, batches) + (schema, ipc_fields, batches) } fn main() -> Result<()> { @@ -106,7 +108,7 @@ fn main() -> Result<()> { .collect::>() }); - let (schema, batches) = read_gzip_json("1.0.0-littleendian", json_file); + let (schema, _, batches) = read_gzip_json("1.0.0-littleendian", json_file); let schema = if let Some(projection) = &projection { let fields = schema diff --git a/tests/it/io/ipc/write_async.rs b/tests/it/io/ipc/write_async.rs index 18110bbe3bd..e77558ce510 100644 --- a/tests/it/io/ipc/write_async.rs +++ b/tests/it/io/ipc/write_async.rs @@ -20,9 +20,9 @@ 
async fn write_( let options = stream_async::WriteOptions { compression: None }; let mut writer = stream_async::StreamWriter::new(&mut result, options); - writer.start(&schema, Some(&ipc_fields)).await?; + writer.start(schema, Some(ipc_fields)).await?; for batch in batches { - writer.write(batch, Some(&ipc_fields)).await?; + writer.write(batch, Some(ipc_fields)).await?; } writer.finish().await?; Ok(result.into_inner()) From 45f2fa96ceb6b8bd321659121abf4064b36f4b9f Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 26 Dec 2021 22:47:39 +0000 Subject: [PATCH 5/6] Chore --- tests/it/io/ipc/write/stream.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/it/io/ipc/write/stream.rs b/tests/it/io/ipc/write/stream.rs index 5384de7377f..80fe5f505fe 100644 --- a/tests/it/io/ipc/write/stream.rs +++ b/tests/it/io/ipc/write/stream.rs @@ -16,7 +16,7 @@ fn write_(schema: &Schema, ipc_fields: &[IpcField], batches: &[RecordBatch]) -> let options = WriteOptions { compression: None }; let mut writer = StreamWriter::new(&mut result, options); - writer.start(&schema, ipc_fields).unwrap(); + writer.start(schema, ipc_fields).unwrap(); for batch in batches { writer.write(batch, ipc_fields).unwrap(); } From 78faa366c4598111bb8b2a9693bc94ef79397263 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Mon, 27 Dec 2021 06:44:23 +0000 Subject: [PATCH 6/6] Simplified reading --- src/io/ipc/read/schema.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/io/ipc/read/schema.rs b/src/io/ipc/read/schema.rs index 380116351e5..23e7b25767d 100644 --- a/src/io/ipc/read/schema.rs +++ b/src/io/ipc/read/schema.rs @@ -32,13 +32,17 @@ fn deserialize_field(ipc_field: ipc::Field) -> (Field, IpcField) { fn read_metadata(field: &ipc::Field) -> Metadata { if let Some(list) = field.custom_metadata() { - let mut metadata_map = BTreeMap::default(); - for kv in list { - if let (Some(k), Some(v)) = (kv.key(), kv.value()) { - metadata_map.insert(k.to_string(), v.to_string()); + if !list.is_empty() { + let mut metadata_map = BTreeMap::default(); + for kv in list { + if let (Some(k), Some(v)) = (kv.key(), kv.value()) { + metadata_map.insert(k.to_string(), v.to_string()); + } } + Some(metadata_map) + } else { + None } - Some(metadata_map) } else { None }
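
Note on the API change exercised throughout these hunks: below is a minimal sketch of the call sites after the series is applied, assuming the signatures exactly as they appear in the updated examples and doc tests in the patches (the extra `None` argument carries the optional IPC fields, and `deserialize_schemas` replaces the removed `TryFrom` impls). The helper names `write_ipc` and `flight_schema` are illustrative only and are not part of the patch.

use std::sync::Arc;

use arrow2::array::Int32Array;
use arrow2::datatypes::{DataType, Field, Schema};
use arrow2::error::Result;
use arrow2::io::flight::deserialize_schemas;
use arrow2::io::ipc::{write, IpcSchema};
use arrow2::record_batch::RecordBatch;
use arrow_format::flight::data::FlightData;

// IPC file writing: `try_new` and `write` now take the optional IPC fields;
// passing `None` derives defaults from the schema, as in the updated examples.
fn write_ipc<W: std::io::Write>(writer: W) -> Result<W> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let options = write::WriteOptions { compression: None };
    let mut writer = write::FileWriter::try_new(writer, &schema, None, options)?;

    let array = Int32Array::from_slice([-1i32, 1]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)])?;
    writer.write(&batch, None)?;
    writer.finish()?;
    Ok(writer.into_inner())
}

// Flight: the `TryFrom<&FlightData> for Schema` impl is removed; the Arrow
// schema and the IPC-specific metadata are now returned together.
fn flight_schema(data: &FlightData) -> Result<(Schema, IpcSchema)> {
    deserialize_schemas(&data.data_header)
}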