From 1ef454f46f15c898374d65d3da9f3921e39a32a5 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Thu, 16 Jun 2022 16:56:15 +0000 Subject: [PATCH] Audited all IPC read code --- Cargo.toml | 2 +- integration-testing/Cargo.toml | 2 +- .../integration_test.rs | 5 +- src/io/ipc/mod.rs | 8 -- src/io/ipc/read/array/binary.rs | 9 +- src/io/ipc/read/array/boolean.rs | 8 +- src/io/ipc/read/array/fixed_size_binary.rs | 9 +- src/io/ipc/read/array/list.rs | 9 +- src/io/ipc/read/array/map.rs | 9 +- src/io/ipc/read/array/null.rs | 9 +- src/io/ipc/read/array/primitive.rs | 9 +- src/io/ipc/read/array/union.rs | 11 +- src/io/ipc/read/array/utf8.rs | 9 +- src/io/ipc/read/common.rs | 73 ++++++---- src/io/ipc/read/error.rs | 110 +++++++++++++++ src/io/ipc/read/file_async.rs | 80 +++++++---- src/io/ipc/read/mod.rs | 4 + src/io/ipc/read/read_basic.rs | 81 ++++++++---- src/io/ipc/read/reader.rs | 125 +++++++++++------- src/io/ipc/read/schema.rs | 28 +++- src/io/ipc/read/stream.rs | 51 +++---- src/io/ipc/read/stream_async.rs | 54 ++++---- 22 files changed, 496 insertions(+), 209 deletions(-) create mode 100644 src/io/ipc/read/error.rs diff --git a/Cargo.toml b/Cargo.toml index 44989feb225..e583b6c2246 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ indexmap = { version = "^1.6", optional = true } # used to print columns in a nice columnar format comfy-table = { version = "5.0", optional = true, default-features = false } -arrow-format = { version = "0.6", optional = true, features = ["ipc"] } +arrow-format = { version = "0.7", optional = true, features = ["ipc"] } hex = { version = "^0.4", optional = true } diff --git a/integration-testing/Cargo.toml b/integration-testing/Cargo.toml index 11488b17aa2..4f9c1761bcb 100644 --- a/integration-testing/Cargo.toml +++ b/integration-testing/Cargo.toml @@ -29,7 +29,7 @@ logging = ["tracing-subscriber"] [dependencies] arrow2 = { path = "../", features = ["io_ipc", "io_ipc_compression", "io_flight", "io_json_integration"] } -arrow-format = { version = "0.6", features = ["full"] } +arrow-format = { version = "0.7", features = ["full"] } async-trait = "0.1.41" clap = { version = "^3", features = ["derive"] } futures = "0.3" diff --git a/integration-testing/src/flight_server_scenarios/integration_test.rs b/integration-testing/src/flight_server_scenarios/integration_test.rs index 87d182928ca..7e8a42ce3e0 100644 --- a/integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/integration-testing/src/flight_server_scenarios/integration_test.rs @@ -283,6 +283,7 @@ async fn record_batch_from_message( ipc_schema: &IpcSchema, dictionaries: &mut Dictionaries, ) -> Result>, Status> { + let length = data_body.len(); let mut reader = std::io::Cursor::new(data_body); let arrow_batch_result = ipc::read::read_record_batch( @@ -294,6 +295,7 @@ async fn record_batch_from_message( arrow_format::ipc::MetadataVersion::V5, &mut reader, 0, + length, ); arrow_batch_result.map_err(|e| Status::internal(format!("Could not convert to Chunk: {:?}", e))) @@ -306,10 +308,11 @@ async fn dictionary_from_message( ipc_schema: &IpcSchema, dictionaries: &mut Dictionaries, ) -> Result<(), Status> { + let length = data_body.len(); let mut reader = std::io::Cursor::new(data_body); let dictionary_batch_result = - ipc::read::read_dictionary(dict_batch, fields, ipc_schema, dictionaries, &mut reader, 0); + ipc::read::read_dictionary(dict_batch, fields, ipc_schema, dictionaries, &mut reader, 0, length); dictionary_batch_result .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {:?}", e))) } diff --git a/src/io/ipc/mod.rs b/src/io/ipc/mod.rs index 3fa3eb01a6b..68124c8bfa0 100644 --- a/src/io/ipc/mod.rs +++ b/src/io/ipc/mod.rs @@ -73,8 +73,6 @@ //! [2](https://github.com/jorgecarleitao/arrow2/blob/main/examples/ipc_file_write.rs), //! [3](https://github.com/jorgecarleitao/arrow2/tree/main/examples/ipc_pyarrow)). -use crate::error::Error; - mod compression; mod endianess; @@ -103,9 +101,3 @@ pub struct IpcSchema { /// Endianness of the file pub is_little_endian: bool, } - -impl From for Error { - fn from(error: arrow_format::ipc::planus::Error) -> Self { - Error::OutOfSpec(error.to_string()) - } -} diff --git a/src/io/ipc/read/array/binary.rs b/src/io/ipc/read/array/binary.rs index f963f3b697a..f68471f699d 100644 --- a/src/io/ipc/read/array/binary.rs +++ b/src/io/ipc/read/array/binary.rs @@ -7,7 +7,7 @@ use crate::datatypes::DataType; use crate::error::{Error, Result}; use super::super::read_basic::*; -use super::super::{Compression, IpcBuffer, Node}; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; pub fn read_binary( field_nodes: &mut VecDeque, @@ -34,9 +34,14 @@ pub fn read_binary( compression, )?; + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let offsets: Buffer = read_buffer( buffers, - 1 + field_node.length() as usize, + 1 + length, reader, block_offset, is_little_endian, diff --git a/src/io/ipc/read/array/boolean.rs b/src/io/ipc/read/array/boolean.rs index ecdf240751c..4bfe9063a57 100644 --- a/src/io/ipc/read/array/boolean.rs +++ b/src/io/ipc/read/array/boolean.rs @@ -6,7 +6,7 @@ use crate::datatypes::DataType; use crate::error::{Error, Result}; use super::super::read_basic::*; -use super::super::{Compression, IpcBuffer, Node}; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; pub fn read_boolean( field_nodes: &mut VecDeque, @@ -24,7 +24,11 @@ pub fn read_boolean( )) })?; - let length = field_node.length() as usize; + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let validity = read_validity( buffers, field_node, diff --git a/src/io/ipc/read/array/fixed_size_binary.rs b/src/io/ipc/read/array/fixed_size_binary.rs index f1f7225a409..b2dad926511 100644 --- a/src/io/ipc/read/array/fixed_size_binary.rs +++ b/src/io/ipc/read/array/fixed_size_binary.rs @@ -6,7 +6,7 @@ use crate::datatypes::DataType; use crate::error::{Error, Result}; use super::super::read_basic::*; -use super::super::{Compression, IpcBuffer, Node}; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; pub fn read_fixed_size_binary( field_nodes: &mut VecDeque, @@ -33,7 +33,12 @@ pub fn read_fixed_size_binary( compression, )?; - let length = field_node.length() as usize * FixedSizeBinaryArray::get_size(&data_type); + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let length = length * FixedSizeBinaryArray::get_size(&data_type); let values = read_buffer( buffers, length, diff --git a/src/io/ipc/read/array/list.rs b/src/io/ipc/read/array/list.rs index a9780df4d35..f5d2eb5465e 100644 --- a/src/io/ipc/read/array/list.rs +++ b/src/io/ipc/read/array/list.rs @@ -11,7 +11,7 @@ use super::super::super::IpcField; use super::super::deserialize::{read, skip}; use super::super::read_basic::*; use super::super::Dictionaries; -use super::super::{Compression, IpcBuffer, Node, Version}; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind, Version}; #[allow(clippy::too_many_arguments)] pub fn read_list( @@ -45,9 +45,14 @@ where compression, )?; + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let offsets = read_buffer::( buffers, - 1 + field_node.length() as usize, + 1 + length, reader, block_offset, is_little_endian, diff --git a/src/io/ipc/read/array/map.rs b/src/io/ipc/read/array/map.rs index 9629786401c..594f8a495b1 100644 --- a/src/io/ipc/read/array/map.rs +++ b/src/io/ipc/read/array/map.rs @@ -10,7 +10,7 @@ use super::super::super::IpcField; use super::super::deserialize::{read, skip}; use super::super::read_basic::*; use super::super::Dictionaries; -use super::super::{Compression, IpcBuffer, Node, Version}; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind, Version}; #[allow(clippy::too_many_arguments)] pub fn read_map( @@ -41,9 +41,14 @@ pub fn read_map( compression, )?; + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let offsets = read_buffer::( buffers, - 1 + field_node.length() as usize, + 1 + length, reader, block_offset, is_little_endian, diff --git a/src/io/ipc/read/array/null.rs b/src/io/ipc/read/array/null.rs index 7515e7fd8e8..5cf5184c055 100644 --- a/src/io/ipc/read/array/null.rs +++ b/src/io/ipc/read/array/null.rs @@ -6,7 +6,7 @@ use crate::{ error::{Error, Result}, }; -use super::super::Node; +use super::super::{Node, OutOfSpecKind}; pub fn read_null(field_nodes: &mut VecDeque, data_type: DataType) -> Result { let field_node = field_nodes.pop_front().ok_or_else(|| { @@ -16,7 +16,12 @@ pub fn read_null(field_nodes: &mut VecDeque, data_type: DataType) -> Resul )) })?; - NullArray::try_new(data_type, field_node.length() as usize) + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + NullArray::try_new(data_type, length) } pub fn skip_null(field_nodes: &mut VecDeque) -> Result<()> { diff --git a/src/io/ipc/read/array/primitive.rs b/src/io/ipc/read/array/primitive.rs index 5477801d610..c7c7b780d89 100644 --- a/src/io/ipc/read/array/primitive.rs +++ b/src/io/ipc/read/array/primitive.rs @@ -6,7 +6,7 @@ use crate::error::{Error, Result}; use crate::{array::PrimitiveArray, types::NativeType}; use super::super::read_basic::*; -use super::super::{Compression, IpcBuffer, Node}; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; pub fn read_primitive( field_nodes: &mut VecDeque, @@ -36,9 +36,14 @@ where compression, )?; + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let values = read_buffer( buffers, - field_node.length() as usize, + length, reader, block_offset, is_little_endian, diff --git a/src/io/ipc/read/array/union.rs b/src/io/ipc/read/array/union.rs index 49dde87c44c..7c86f5e30bb 100644 --- a/src/io/ipc/read/array/union.rs +++ b/src/io/ipc/read/array/union.rs @@ -10,7 +10,7 @@ use super::super::super::IpcField; use super::super::deserialize::{read, skip}; use super::super::read_basic::*; use super::super::Dictionaries; -use super::super::{Compression, IpcBuffer, Node, Version}; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind, Version}; #[allow(clippy::too_many_arguments)] pub fn read_union( @@ -38,9 +38,14 @@ pub fn read_union( .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; }; + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let types = read_buffer( buffers, - field_node.length() as usize, + length, reader, block_offset, is_little_endian, @@ -51,7 +56,7 @@ pub fn read_union( if !mode.is_sparse() { Some(read_buffer( buffers, - field_node.length() as usize, + length, reader, block_offset, is_little_endian, diff --git a/src/io/ipc/read/array/utf8.rs b/src/io/ipc/read/array/utf8.rs index d51627c3abb..8424fbb2e73 100644 --- a/src/io/ipc/read/array/utf8.rs +++ b/src/io/ipc/read/array/utf8.rs @@ -7,7 +7,7 @@ use crate::datatypes::DataType; use crate::error::{Error, Result}; use super::super::read_basic::*; -use super::super::{Compression, IpcBuffer, Node}; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; pub fn read_utf8( field_nodes: &mut VecDeque, @@ -34,9 +34,14 @@ pub fn read_utf8( compression, )?; + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let offsets: Buffer = read_buffer( buffers, - 1 + field_node.length() as usize, + 1 + length, reader, block_offset, is_little_endian, diff --git a/src/io/ipc/read/common.rs b/src/io/ipc/read/common.rs index 221fbefb258..f2a9eaa8522 100644 --- a/src/io/ipc/read/common.rs +++ b/src/io/ipc/read/common.rs @@ -7,6 +7,7 @@ use crate::array::*; use crate::chunk::Chunk; use crate::datatypes::{DataType, Field}; use crate::error::{Error, Result}; +use crate::io::ipc::read::OutOfSpecKind; use crate::io::ipc::{IpcField, IpcSchema}; use super::deserialize::{read, skip}; @@ -87,21 +88,33 @@ pub fn read_record_batch( ) -> Result>> { assert_eq!(fields.len(), ipc_schema.fields.len()); let buffers = batch - .buffers()? - .ok_or_else(|| Error::oos("IPC RecordBatch must contain buffers"))?; + .buffers() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBuffers(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageBuffers))?; let mut buffers: VecDeque = buffers.iter().collect(); - for buffer in buffers.iter() { - if buffer.length() as u64 > file_size { - return Err(Error::oos( - "Any buffer's length must be smaller than the size of the file", - )); - } + // check that the sum of the sizes of all buffers is <= than the size of the file + let buffers_size = buffers + .iter() + .map(|buffer| { + let buffer_size: u64 = buffer + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + Ok(buffer_size) + }) + .sum::>()?; + if buffers_size > file_size { + return Err(Error::from(OutOfSpecKind::InvalidBuffersLength { + buffers_size, + file_size, + })); } let field_nodes = batch - .nodes()? - .ok_or_else(|| Error::oos("IPC RecordBatch must contain field nodes"))?; + .nodes() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferNodes(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageNodes))?; let mut field_nodes = field_nodes.iter().collect::>(); let columns = if let Some(projection) = projection { @@ -119,7 +132,9 @@ pub fn read_record_batch( dictionaries, block_offset, ipc_schema.is_little_endian, - batch.compression()?, + batch.compression().map_err(|err| { + Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)) + })?, version, )?)), ProjectionResult::NotSelected((field, _)) => { @@ -143,7 +158,9 @@ pub fn read_record_batch( dictionaries, block_offset, ipc_schema.is_little_endian, - batch.compression()?, + batch.compression().map_err(|err| { + Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)) + })?, version, ) }) @@ -199,10 +216,7 @@ fn first_dict_field<'a>( return Ok(field); } } - Err(Error::OutOfSpec(format!( - "dictionary id {} not found in schema", - id - ))) + Err(Error::from(OutOfSpecKind::InvalidId { requested_id: id })) } /// Read the dictionary from the buffer and provided metadata, @@ -216,23 +230,29 @@ pub fn read_dictionary( block_offset: u64, file_size: u64, ) -> Result<()> { - if batch.is_delta()? { + if batch + .is_delta() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferIsDelta(err)))? + { return Err(Error::NotYetImplemented( "delta dictionary batches not supported".to_string(), )); } - let id = batch.id()?; + let id = batch + .id() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferId(err)))?; let (first_field, first_ipc_field) = first_dict_field(id, fields, &ipc_schema.fields)?; // As the dictionary batch does not contain the type of the // values array, we need to retrieve this from the schema. // Get an array representing this dictionary's values. - let dictionary_values: Box = match &first_field.data_type { + let dictionary_values: Box = match first_field.data_type.to_logical_type() { DataType::Dictionary(_, ref value_type, _) => { let batch = batch - .data()? - .ok_or_else(|| Error::oos("The dictionary batch must have data."))?; + .data() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferData(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingData))?; // Make a fake schema for the dictionary batch. let fields = vec![Field::new("", value_type.as_ref().clone(), false)]; @@ -252,11 +272,14 @@ pub fn read_dictionary( file_size, )?; let mut arrays = columns.into_arrays(); - Some(arrays.pop().unwrap()) + arrays.pop().unwrap() } - _ => None, - } - .ok_or_else(|| Error::InvalidArgumentError("dictionary id not found in schema".to_string()))?; + _ => { + return Err(Error::from(OutOfSpecKind::InvalidIdDataType { + requested_id: id, + })) + } + }; dictionaries.insert(id, dictionary_values); diff --git a/src/io/ipc/read/error.rs b/src/io/ipc/read/error.rs new file mode 100644 index 00000000000..0fdd6a01c94 --- /dev/null +++ b/src/io/ipc/read/error.rs @@ -0,0 +1,110 @@ +use crate::error::Error; + +/// The different types of errors that reading from IPC can cause +#[derive(Debug)] +#[non_exhaustive] +pub enum OutOfSpecKind { + /// The IPC file does not start with [b'A', b'R', b'R', b'O', b'W', b'1'] + InvalidHeader, + /// The IPC file does not end with [b'A', b'R', b'R', b'O', b'W', b'1'] + InvalidFooter, + /// The first 4 bytes of the last 10 bytes is < 0 + NegativeFooterLength, + /// The footer is an invalid flatbuffer + InvalidFlatbufferFooter(arrow_format::ipc::planus::Error), + /// The file's footer does not contain record batches + MissingRecordBatches, + /// The footer's record batches is an invalid flatbuffer + InvalidFlatbufferRecordBatches(arrow_format::ipc::planus::Error), + /// The file's footer does not contain a schema + MissingSchema, + /// The footer's schema is an invalid flatbuffer + InvalidFlatbufferSchema(arrow_format::ipc::planus::Error), + /// The file's schema does not contain fields + MissingFields, + /// The footer's dictionaries is an invalid flatbuffer + InvalidFlatbufferDictionaries(arrow_format::ipc::planus::Error), + /// The block is an invalid flatbuffer + InvalidFlatbufferBlock(arrow_format::ipc::planus::Error), + /// The dictionary message is an invalid flatbuffer + InvalidFlatbufferMessage(arrow_format::ipc::planus::Error), + /// The message does not contain a header + MissingMessageHeader, + /// The message's header is an invalid flatbuffer + InvalidFlatbufferHeader(arrow_format::ipc::planus::Error), + /// Relative positions in the file is < 0 + UnexpectedNegativeInteger, + /// dictionaries can only contain dictionary messages; record batches can only contain records + UnexpectedMessageType, + /// RecordBatch messages do not contain buffers + MissingMessageBuffers, + /// The message's buffers is an invalid flatbuffer + InvalidFlatbufferBuffers(arrow_format::ipc::planus::Error), + /// RecordBatch messages does not contain nodes + MissingMessageNodes, + /// The message's nodes is an invalid flatbuffer + InvalidFlatbufferNodes(arrow_format::ipc::planus::Error), + /// The message's body length is an invalid flatbuffer + InvalidFlatbufferBodyLength(arrow_format::ipc::planus::Error), + /// The message does not contain data + MissingData, + /// The message's data is an invalid flatbuffer + InvalidFlatbufferData(arrow_format::ipc::planus::Error), + /// The version is an invalid flatbuffer + InvalidFlatbufferVersion(arrow_format::ipc::planus::Error), + /// The compression is an invalid flatbuffer + InvalidFlatbufferCompression(arrow_format::ipc::planus::Error), + /// The record contains a number of buffers that does not match the required number by the data type + ExpectedBuffer, + /// A buffer's size is smaller than the required for the number of elements + InvalidBuffer { + /// Declared number of elements in the buffer + length: usize, + /// The name of the `NativeType` + type_name: &'static str, + /// Bytes required for the `length` and `type` + required_number_of_bytes: usize, + /// The size of the IPC buffer + buffer_length: usize, + }, + /// A buffer's size is larger than the file size + InvalidBuffersLength { + /// number of bytes of all buffers in the record + buffers_size: u64, + /// the size of the file + file_size: u64, + }, + /// A bitmap's size is smaller than the required for the number of elements + InvalidBitmap { + /// Declared length of the bitmap + length: usize, + /// Number of bits on the IPC buffer + number_of_bits: usize, + }, + /// The dictionary is_delta is an invalid flatbuffer + InvalidFlatbufferIsDelta(arrow_format::ipc::planus::Error), + /// The dictionary id is an invalid flatbuffer + InvalidFlatbufferId(arrow_format::ipc::planus::Error), + /// Invalid dictionary id + InvalidId { + /// The requested dictionary id + requested_id: i64, + }, + /// Field id is not a dictionary + InvalidIdDataType { + /// The requested dictionary id + requested_id: i64, + }, +} + +impl From for Error { + fn from(kind: OutOfSpecKind) -> Self { + Error::OutOfSpec(format!("{:?}", kind)) + } +} + +impl From for Error { + fn from(error: arrow_format::ipc::planus::Error) -> Self { + Error::OutOfSpec(error.to_string()) + } +} diff --git a/src/io/ipc/read/file_async.rs b/src/io/ipc/read/file_async.rs index 50923a50e61..a4674167e3c 100644 --- a/src/io/ipc/read/file_async.rs +++ b/src/io/ipc/read/file_async.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::io::SeekFrom; -use arrow_format::ipc::{planus::ReadAsRoot, Block, MessageHeaderRef, MessageRef}; +use arrow_format::ipc::{planus::ReadAsRoot, Block, MessageHeaderRef}; use futures::{ stream::BoxStream, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, Stream, StreamExt, }; @@ -17,6 +17,7 @@ use super::common::{apply_projection, prepare_projection, read_dictionary, read_ use super::reader::{deserialize_footer, get_serialized_batch}; use super::Dictionaries; use super::FileMetadata; +use super::OutOfSpecKind; /// Async reader for Arrow IPC files pub struct FileStream<'a> { @@ -124,13 +125,11 @@ async fn read_footer_len(reader: &mut R) -> Re let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); if footer[4..] != ARROW_MAGIC { - return Err(Error::OutOfSpec( - "Arrow file does not contain correct footer".to_string(), - )); + return Err(Error::from(OutOfSpecKind::InvalidFooter)); } footer_len .try_into() - .map_err(|_| Error::oos("The footer's lenght must be a positive number")) + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength)) } /// Read the metadata from an IPC file. @@ -140,7 +139,7 @@ where { let footer_size = read_footer_len(reader).await?; // Read footer - let mut footer = vec![0; footer_size as usize]; + let mut footer = vec![0; footer_size]; reader.seek(SeekFrom::End(-10 - footer_size as i64)).await?; reader.read_exact(&mut footer).await?; @@ -160,22 +159,40 @@ where R: AsyncRead + AsyncSeek + Unpin, { let block = metadata.blocks[block]; - reader.seek(SeekFrom::Start(block.offset as u64)).await?; + + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + reader.seek(SeekFrom::Start(offset)).await?; let mut meta_buf = [0; 4]; reader.read_exact(&mut meta_buf).await?; if meta_buf == CONTINUATION_MARKER { reader.read_exact(&mut meta_buf).await?; } - let meta_len = i32::from_le_bytes(meta_buf) as usize; + + let meta_len = i32::from_le_bytes(meta_buf) + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + meta_buffer.clear(); meta_buffer.resize(meta_len, 0); reader.read_exact(meta_buffer).await?; - let message = MessageRef::read_as_root(&meta_buffer[..]) - .map_err(|err| Error::oos(format!("unable to parse message: {:?}", err)))?; + let message = arrow_format::ipc::MessageRef::read_as_root(meta_buffer) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + let batch = get_serialized_batch(&message)?; + + let block_length: usize = message + .body_length() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + block_buffer.clear(); - block_buffer.resize(message.body_length()? as usize, 0); + block_buffer.resize(block_length, 0); reader.read_exact(block_buffer).await?; let mut cursor = std::io::Cursor::new(block_buffer); @@ -185,7 +202,9 @@ where &metadata.ipc_schema, projection, dictionaries, - message.version()?, + message + .version() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferVersion(err)))?, &mut cursor, 0, metadata.size, @@ -206,15 +225,26 @@ where let mut buffer = vec![]; for block in blocks { - let offset = block.offset as u64; - let length = block.body_length as usize; + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let length: usize = block + .body_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + read_dictionary_message(&mut reader, offset, &mut data).await?; - let message = MessageRef::read_as_root(&data) - .map_err(|err| Error::OutOfSpec(format!("unable to get root as message: {:?}", err)))?; + let message = arrow_format::ipc::MessageRef::read_as_root(&data) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + let header = message - .header()? - .ok_or_else(|| Error::oos("message must have a header"))?; + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + match header { MessageHeaderRef::DictionaryBatch(batch) => { buffer.clear(); @@ -231,12 +261,7 @@ where u64::MAX, )?; } - other => { - return Err(Error::OutOfSpec(format!( - "expected DictionaryBatch in dictionary blocks, found {:?}", - other, - ))) - } + _ => return Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), } } Ok(dictionaries) @@ -253,8 +278,13 @@ where reader.read_exact(&mut message_size).await?; } let footer_size = i32::from_le_bytes(message_size); + + let footer_size: usize = footer_size + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + data.clear(); - data.resize(footer_size as usize, 0); + data.resize(footer_size, 0); reader.read_exact(data).await?; Ok(()) diff --git a/src/io/ipc/read/mod.rs b/src/io/ipc/read/mod.rs index 283df26c53f..5ffe6426e20 100644 --- a/src/io/ipc/read/mod.rs +++ b/src/io/ipc/read/mod.rs @@ -11,10 +11,14 @@ use crate::array::Array; mod array; mod common; mod deserialize; +mod error; mod read_basic; pub(crate) mod reader; mod schema; mod stream; + +pub use error::OutOfSpecKind; + #[cfg(feature = "io_ipc_read_async")] #[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] pub mod stream_async; diff --git a/src/io/ipc/read/read_basic.rs b/src/io/ipc/read/read_basic.rs index 893d62fc4d3..38f6d349eef 100644 --- a/src/io/ipc/read/read_basic.rs +++ b/src/io/ipc/read/read_basic.rs @@ -7,7 +7,7 @@ use crate::{bitmap::Bitmap, types::NativeType}; use super::super::compression; use super::super::endianess::is_native_little_endian; -use super::{Compression, IpcBuffer, Node}; +use super::{Compression, IpcBuffer, Node, OutOfSpecKind}; fn read_swapped( reader: &mut R, @@ -49,8 +49,16 @@ fn read_uncompressed_buffer( length: usize, is_little_endian: bool, ) -> Result> { - let bytes = length * std::mem::size_of::(); - if bytes > buffer_length { + let required_number_of_bytes = length * std::mem::size_of::(); + if required_number_of_bytes > buffer_length { + return Err(Error::from(OutOfSpecKind::InvalidBuffer { + length, + type_name: std::any::type_name::(), + required_number_of_bytes, + buffer_length, + })); + // todo: move this to the error's Display + /* return Err(Error::OutOfSpec( format!("The slots of the array times the physical size must \ be smaller or equal to the length of the IPC buffer. \ @@ -62,6 +70,7 @@ fn read_uncompressed_buffer( buffer_length, ), )); + */ } // it is undefined behavior to call read_exact on un-initialized, https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read @@ -69,7 +78,7 @@ fn read_uncompressed_buffer( let mut buffer = vec![T::default(); length]; if is_native_little_endian() == is_little_endian { - // fast case where we can just copy the contents as is + // fast case where we can just copy the contents let slice = bytemuck::cast_slice_mut(&mut buffer); reader.read_exact(slice)?; } else { @@ -102,16 +111,19 @@ fn read_compressed_buffer( let out_slice = bytemuck::cast_slice_mut(&mut buffer); - match compression.codec()? { + let compression = compression + .codec() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)))?; + + match compression { arrow_format::ipc::CompressionType::Lz4Frame => { compression::decompress_lz4(&slice[8..], out_slice)?; - Ok(buffer) } arrow_format::ipc::CompressionType::Zstd => { compression::decompress_zstd(&slice[8..], out_slice)?; - Ok(buffer) } } + Ok(buffer) } pub fn read_buffer( @@ -124,11 +136,19 @@ pub fn read_buffer( ) -> Result> { let buf = buf .pop_front() - .ok_or_else(|| Error::oos("IPC: unable to fetch a buffer. The file is corrupted."))?; + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; - reader.seek(SeekFrom::Start(block_offset + buf.offset() as u64))?; + let offset: u64 = buf + .offset() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; - let buffer_length = buf.length() as usize; + let buffer_length: usize = buf + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + reader.seek(SeekFrom::Start(block_offset + offset))?; if let Some(compression) = compression { Ok( @@ -146,13 +166,10 @@ fn read_uncompressed_bitmap( reader: &mut R, ) -> Result> { if length > bytes * 8 { - return Err(Error::OutOfSpec(format!( - "An array requires a bitmap with at least the same number of bits as slots. \ - However, this array reports {} slots but the the bitmap in IPC only contains \ - {} bits", + return Err(Error::from(OutOfSpecKind::InvalidBitmap { length, - bytes * 8, - ))); + number_of_bits: bytes * 8, + })); } // it is undefined behavior to call read_exact on un-initialized, https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read // see also https://github.com/MaikKlein/ash/issues/354#issue-781730580 @@ -175,16 +192,19 @@ fn read_compressed_bitmap( let mut slice = vec![0u8; bytes]; reader.read_exact(&mut slice)?; - match compression.codec()? { + let compression = compression + .codec() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)))?; + + match compression { arrow_format::ipc::CompressionType::Lz4Frame => { compression::decompress_lz4(&slice[8..], &mut buffer)?; - Ok(buffer) } arrow_format::ipc::CompressionType::Zstd => { compression::decompress_zstd(&slice[8..], &mut buffer)?; - Ok(buffer) } } + Ok(buffer) } pub fn read_bitmap( @@ -197,11 +217,19 @@ pub fn read_bitmap( ) -> Result { let buf = buf .pop_front() - .ok_or_else(|| Error::oos("IPC: unable to fetch a buffer. The file is corrupted."))?; + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; + + let offset: u64 = buf + .offset() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; - reader.seek(SeekFrom::Start(block_offset + buf.offset() as u64))?; + let bytes: usize = buf + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; - let bytes = buf.length() as usize; + reader.seek(SeekFrom::Start(block_offset + offset))?; let buffer = if let Some(compression) = compression { read_compressed_bitmap(length, bytes, compression, reader) @@ -220,10 +248,15 @@ pub fn read_validity( is_little_endian: bool, compression: Option, ) -> Result> { + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + Ok(if field_node.null_count() > 0 { Some(read_bitmap( buffers, - field_node.length() as usize, + length, reader, block_offset, is_little_endian, @@ -232,7 +265,7 @@ pub fn read_validity( } else { let _ = buffers .pop_front() - .ok_or_else(|| Error::oos("IPC: unable to fetch a buffer. The file is corrupted."))?; + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; None }) } diff --git a/src/io/ipc/read/reader.rs b/src/io/ipc/read/reader.rs index c91efe13fba..1e3f4459f9e 100644 --- a/src/io/ipc/read/reader.rs +++ b/src/io/ipc/read/reader.rs @@ -12,6 +12,7 @@ use super::super::{ARROW_MAGIC, CONTINUATION_MARKER}; use super::common::*; use super::schema::fb_to_schema; use super::Dictionaries; +use super::OutOfSpecKind; use arrow_format::ipc::planus::ReadAsRoot; /// Metadata of an Arrow IPC file, written in the footer of the file. @@ -59,9 +60,13 @@ fn read_dictionary_message( }; let message_length = i32::from_le_bytes(message_size); + let message_length: usize = message_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + // prepare `data` to read the message data.clear(); - data.resize(message_length as usize, 0); + data.resize(message_length, 0); reader.read_exact(data)?; Ok(()) @@ -74,16 +79,23 @@ fn read_dictionary_block( dictionaries: &mut Dictionaries, scratch: &mut Vec, ) -> Result<()> { - let offset = block.offset as u64; - let length = block.meta_data_length as u64; + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + let length: u64 = block + .meta_data_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; read_dictionary_message(reader, offset, scratch)?; let message = arrow_format::ipc::MessageRef::read_as_root(scratch) - .map_err(|err| Error::OutOfSpec(format!("Unable to get root as message: {:?}", err)))?; + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; let header = message - .header()? - .ok_or_else(|| Error::oos("Message must have an header"))?; + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; match header { arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { @@ -96,16 +108,10 @@ fn read_dictionary_block( reader, block_offset, metadata.size, - )?; - } - t => { - return Err(Error::OutOfSpec(format!( - "Expecting DictionaryBatch in dictionary blocks, found {:?}.", - t - ))); + ) } - }; - Ok(()) + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), + } } /// Reads all file's dictionaries, if any @@ -140,41 +146,50 @@ fn read_footer_len(reader: &mut R) -> Result<(u64, usize)> { let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); if footer[4..] != ARROW_MAGIC { - return Err(Error::OutOfSpec( - "Arrow file does not contain correct footer".to_string(), - )); + return Err(Error::from(OutOfSpecKind::InvalidFooter)); } let footer_len = footer_len .try_into() - .map_err(|_| Error::oos("The footer's lenght must be a positive number"))?; + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; Ok((end, footer_len)) } pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> Result { let footer = arrow_format::ipc::FooterRef::read_as_root(footer_data) - .map_err(|err| Error::OutOfSpec(format!("Unable to get root as footer: {:?}", err)))?; + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferFooter(err)))?; let blocks = footer - .record_batches()? - .ok_or_else(|| Error::OutOfSpec("Unable to get record batches from footer".to_string()))?; + .record_batches() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingRecordBatches))?; let blocks = blocks .iter() - .map(|block| Ok(block.try_into()?)) + .map(|block| { + block + .try_into() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferRecordBatches(err))) + }) .collect::>>()?; let ipc_schema = footer - .schema()? - .ok_or_else(|| Error::OutOfSpec("Unable to get the schema from footer".to_string()))?; + .schema() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferSchema(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingSchema))?; let (schema, ipc_schema) = fb_to_schema(ipc_schema)?; let dictionaries = footer - .dictionaries()? + .dictionaries() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferDictionaries(err)))? .map(|dictionaries| { dictionaries .into_iter() - .map(|x| Ok(x.try_into()?)) + .map(|block| { + block.try_into().map_err(|err| { + Error::from(OutOfSpecKind::InvalidFlatbufferRecordBatches(err)) + }) + }) .collect::>>() }) .transpose()?; @@ -188,43 +203,36 @@ pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> Result(reader: &mut R) -> Result { // check if header contain the correct magic bytes let mut magic_buffer: [u8; 6] = [0; 6]; let start = reader.seek(SeekFrom::Current(0))?; reader.read_exact(&mut magic_buffer)?; if magic_buffer != ARROW_MAGIC { - return Err(Error::OutOfSpec( - "Arrow file does not contain correct header".to_string(), - )); + return Err(Error::from(OutOfSpecKind::InvalidHeader)); } let (end, footer_len) = read_footer_len(reader)?; // read footer - let mut footer_data = vec![0; footer_len]; + let mut serialized_footer = vec![0; footer_len]; reader.seek(SeekFrom::End(-10 - footer_len as i64))?; - reader.read_exact(&mut footer_data)?; + reader.read_exact(&mut serialized_footer)?; - deserialize_footer(&footer_data, end - start) + deserialize_footer(&serialized_footer, end - start) } pub(super) fn get_serialized_batch<'a>( message: &'a arrow_format::ipc::MessageRef, ) -> Result> { - let header = message.header()?.ok_or_else(|| { - Error::oos("IPC: unable to fetch the message header. The file or stream is corrupted.") - })?; + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; match header { - arrow_format::ipc::MessageHeaderRef::Schema(_) => Err(Error::OutOfSpec( - "Not expecting a schema when messages are read".to_string(), - )), arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => Ok(batch), - t => Err(Error::OutOfSpec(format!( - "Reading types other than record batches not yet supported, unable to read {:?}", - t - ))), + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), } } @@ -245,34 +253,53 @@ pub fn read_batch( ) -> Result>> { let block = metadata.blocks[index]; + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + // read length - reader.seek(SeekFrom::Start(block.offset as u64))?; + reader.seek(SeekFrom::Start(offset))?; let mut meta_buf = [0; 4]; reader.read_exact(&mut meta_buf)?; if meta_buf == CONTINUATION_MARKER { // continuation marker encountered, read message next reader.read_exact(&mut meta_buf)?; } - let meta_len = i32::from_le_bytes(meta_buf) as usize; + let meta_len = i32::from_le_bytes(meta_buf) + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; stratch.clear(); stratch.resize(meta_len, 0); reader.read_exact(stratch)?; let message = arrow_format::ipc::MessageRef::read_as_root(stratch) - .map_err(|err| Error::oos(format!("Unable parse message: {:?}", err)))?; + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; let batch = get_serialized_batch(&message)?; + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let length: u64 = block + .meta_data_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + read_record_batch( batch, &metadata.schema.fields, &metadata.ipc_schema, projection, dictionaries, - message.version()?, + message + .version() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferVersion(err)))?, reader, - block.offset as u64 + block.meta_data_length as u64, + offset + length, metadata.size, ) } diff --git a/src/io/ipc/read/schema.rs b/src/io/ipc/read/schema.rs index b5925d00d2e..b54292baab8 100644 --- a/src/io/ipc/read/schema.rs +++ b/src/io/ipc/read/schema.rs @@ -10,7 +10,7 @@ use crate::{ use super::{ super::{IpcField, IpcSchema}, - StreamMetadata, + OutOfSpecKind, StreamMetadata, }; fn try_unzip_vec>>(iter: I) -> Result<(Vec, Vec)> { @@ -131,7 +131,12 @@ fn get_data_type( Utf8(_) => (DataType::Utf8, IpcField::default()), LargeUtf8(_) => (DataType::LargeUtf8, IpcField::default()), FixedSizeBinary(fixed) => ( - DataType::FixedSizeBinary(fixed.byte_width()? as usize), + DataType::FixedSizeBinary( + fixed + .byte_width()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?, + ), IpcField::default(), ), FloatingPoint(float) => { @@ -193,8 +198,16 @@ fn get_data_type( (DataType::Duration(time_unit), IpcField::default()) } Decimal(decimal) => { - let data_type = - DataType::Decimal(decimal.precision()? as usize, decimal.scale()? as usize); + let data_type = DataType::Decimal( + decimal + .precision()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?, + decimal + .scale()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?, + ); (data_type, IpcField::default()) } List(_) => { @@ -240,7 +253,10 @@ fn get_data_type( .ok_or_else(|| Error::oos("IPC: FixedSizeList must contain one child"))??; let (field, ipc_field) = deserialize_field(inner)?; - let size = list.list_size()? as usize; + let size = list + .list_size()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; ( DataType::FixedSizeList(Box::new(field), size), @@ -332,7 +348,7 @@ pub fn deserialize_schema(bytes: &[u8]) -> Result<(Schema, IpcSchema)> { pub(super) fn fb_to_schema(schema: arrow_format::ipc::SchemaRef) -> Result<(Schema, IpcSchema)> { let fields = schema .fields()? - .ok_or_else(|| Error::oos("IPC: Schema must contain fields"))?; + .ok_or_else(|| Error::from(OutOfSpecKind::MissingFields))?; let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| { let (field, fields) = deserialize_field(field?)?; Ok((field, fields)) diff --git a/src/io/ipc/read/stream.rs b/src/io/ipc/read/stream.rs index 5d1769a1e64..c49512faaaa 100644 --- a/src/io/ipc/read/stream.rs +++ b/src/io/ipc/read/stream.rs @@ -13,6 +13,7 @@ use super::super::CONTINUATION_MARKER; use super::common::*; use super::schema::deserialize_stream_metadata; use super::Dictionaries; +use super::OutOfSpecKind; /// Metadata of an Arrow IPC stream, written at the start of the stream #[derive(Debug, Clone)] @@ -32,7 +33,7 @@ pub fn read_stream_metadata(reader: &mut R) -> Result { // determine metadata length let mut meta_size: [u8; 4] = [0; 4]; reader.read_exact(&mut meta_size)?; - let meta_len = { + let meta_length = { // If a continuation marker is encountered, skip over it and read // the size from the next four bytes. if meta_size == CONTINUATION_MARKER { @@ -41,7 +42,11 @@ pub fn read_stream_metadata(reader: &mut R) -> Result { i32::from_le_bytes(meta_size) }; - let mut meta_buffer = vec![0; meta_len as usize]; + let meta_length: usize = meta_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let mut meta_buffer = vec![0; meta_length]; reader.read_exact(&mut meta_buffer)?; deserialize_stream_metadata(&meta_buffer) @@ -110,9 +115,13 @@ fn read_next( if meta_length == CONTINUATION_MARKER { reader.read_exact(&mut meta_length)?; } - i32::from_le_bytes(meta_length) as usize + i32::from_le_bytes(meta_length) }; + let meta_length: usize = meta_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + if meta_length == 0 { // the stream has ended, mark the reader as finished return Ok(None); @@ -123,21 +132,23 @@ fn read_next( reader.read_exact(message_buffer)?; let message = arrow_format::ipc::MessageRef::read_as_root(message_buffer) - .map_err(|err| Error::OutOfSpec(format!("Unable to get root as message: {:?}", err)))?; - let header = message.header()?.ok_or_else(|| { - Error::oos("IPC: unable to fetch the message header. The file or stream is corrupted.") - })?; + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + + let block_length: usize = message + .body_length() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; match header { - arrow_format::ipc::MessageHeaderRef::Schema(_) => Err(Error::oos("A stream ")), arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => { - // read the block that makes up the record batch into a buffer - let length: usize = message - .body_length()? - .try_into() - .map_err(|_| Error::oos("The body length of a header must be larger than zero"))?; data_buffer.clear(); - data_buffer.resize(length, 0); + data_buffer.resize(block_length, 0); reader.read_exact(data_buffer)?; let file_size = data_buffer.len() as u64; @@ -158,12 +169,7 @@ fn read_next( .map(|x| Some(StreamState::Some(x))) } arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { - // read the block that makes up the dictionary batch into a buffer - let length: usize = message - .body_length()? - .try_into() - .map_err(|_| Error::oos("The body length of a header must be larger than zero"))?; - let mut buf = vec![0; length]; + let mut buf = vec![0; block_length]; reader.read_exact(&mut buf)?; let mut dict_reader = std::io::Cursor::new(&buf); @@ -181,10 +187,7 @@ fn read_next( // read the next message until we encounter a RecordBatch message read_next(reader, metadata, dictionaries, message_buffer, data_buffer) } - t => Err(Error::OutOfSpec(format!( - "Reading types other than record batches not yet supported, unable to read {:?} ", - t - ))), + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), } } diff --git a/src/io/ipc/read/stream_async.rs b/src/io/ipc/read/stream_async.rs index 5df116ca8ee..1e130b48496 100644 --- a/src/io/ipc/read/stream_async.rs +++ b/src/io/ipc/read/stream_async.rs @@ -4,6 +4,7 @@ use arrow_format::ipc::planus::ReadAsRoot; use futures::future::BoxFuture; use futures::AsyncRead; use futures::AsyncReadExt; +use futures::FutureExt; use futures::Stream; use crate::array::*; @@ -14,6 +15,7 @@ use super::super::CONTINUATION_MARKER; use super::common::{read_dictionary, read_record_batch}; use super::schema::deserialize_stream_metadata; use super::Dictionaries; +use super::OutOfSpecKind; use super::StreamMetadata; /// A (private) state of stream messages @@ -51,7 +53,11 @@ pub async fn read_stream_metadata_async( i32::from_le_bytes(meta_size) }; - let mut meta_buffer = vec![0; meta_len as usize]; + let meta_len: usize = meta_len + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let mut meta_buffer = vec![0; meta_len]; reader.read_exact(&mut meta_buffer).await?; deserialize_stream_metadata(&meta_buffer) @@ -85,9 +91,13 @@ async fn maybe_next( if meta_length == CONTINUATION_MARKER { state.reader.read_exact(&mut meta_length).await?; } - i32::from_le_bytes(meta_length) as usize + i32::from_le_bytes(meta_length) }; + let meta_length: usize = meta_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + if meta_length == 0 { // the stream has ended, mark the reader as finished return Ok(None); @@ -98,23 +108,23 @@ async fn maybe_next( state.reader.read_exact(&mut state.message_buffer).await?; let message = arrow_format::ipc::MessageRef::read_as_root(&state.message_buffer) - .map_err(|err| Error::OutOfSpec(format!("Unable to get root as message: {:?}", err)))?; - let header = message.header()?.ok_or_else(|| { - Error::oos("IPC: unable to fetch the message header. The file or stream is corrupted.") - })?; + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + + let block_length: usize = message + .body_length() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; match header { - arrow_format::ipc::MessageHeaderRef::Schema(_) => { - Err(Error::oos("A stream cannot contain a schema message")) - } arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => { - // read the block that makes up the record batch into a buffer - let length: usize = message - .body_length()? - .try_into() - .map_err(|_| Error::oos("The body length of a header must be larger than zero"))?; state.data_buffer.clear(); - state.data_buffer.resize(length, 0); + state.data_buffer.resize(block_length, 0); state.reader.read_exact(&mut state.data_buffer).await?; read_record_batch( @@ -131,12 +141,7 @@ async fn maybe_next( .map(|chunk| Some(StreamState::Some((state, chunk)))) } arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { - // read the block that makes up the dictionary batch into a buffer - let length: usize = message - .body_length()? - .try_into() - .map_err(|_| Error::oos("The body length of a header must be larger than zero"))?; - let mut body = vec![0; length]; + let mut body = vec![0; block_length]; state.reader.read_exact(&mut body).await?; let file_size = body.len() as u64; @@ -156,10 +161,7 @@ async fn maybe_next( // read the next message until we encounter a Chunk> message Ok(Some(StreamState::Waiting(state))) } - t => Err(Error::OutOfSpec(format!( - "Reading types other than record batches not yet supported, unable to read {:?} ", - t - ))), + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), } } @@ -179,7 +181,7 @@ impl<'a, R: AsyncRead + Unpin + Send + 'a> AsyncStreamReader<'a, R> { data_buffer: Default::default(), message_buffer: Default::default(), }; - let future = Some(Box::pin(maybe_next(state)) as _); + let future = Some(maybe_next(state).boxed()); Self { metadata, future } }