From 5aeee0e0a5545250cb90727ae5e072cb1111a60d Mon Sep 17 00:00:00 2001
From: "Jorge C. Leitao"
Date: Mon, 13 Jun 2022 15:02:23 +0000
Subject: [PATCH] IPC panic free

---
 Cargo.toml                      |  2 ++
 README.md                       | 10 +++++++--
 src/io/flight/mod.rs            |  2 ++
 src/io/ipc/read/common.rs       | 19 ++++++++++++++---
 src/io/ipc/read/file_async.rs   | 13 +++++++++--
 src/io/ipc/read/reader.rs       | 26 +++++++++++++++-------
 src/io/ipc/read/stream.rs       | 18 +++++++++++++---
 src/io/ipc/read/stream_async.rs | 24 ++++++++++++++++-----
 tests/it/io/ipc/read/file.rs    | 38 +++++++++++++++++++++++++++++++++
 9 files changed, 129 insertions(+), 23 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 549965c2c2c..44989feb225 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -101,6 +101,8 @@ tokio-util = { version = "0.6", features = ["compat"] }
 # used to run formal property testing
 proptest = { version = "1", default_features = false, features = ["std"] }
 avro-rs = { version = "0.13", features = ["snappy"] }
+# used for flaky testing
+rand = "0.8"
 
 [package.metadata.docs.rs]
 features = ["full"]
diff --git a/README.md b/README.md
index 07ebc960de5..1bb4b25ccca 100644
--- a/README.md
+++ b/README.md
@@ -69,8 +69,14 @@ Most uses of `unsafe` fall into 3 categories:
 We actively monitor for vulnerabilities in Rust's advisory and either patch or mitigate
 them (see e.g. `.cargo/audit.yaml` and `.github/workflows/security.yaml`).
 
-Reading parquet and IPC currently `panic!` when they receive invalid. We are
-actively addressing this.
+Reading from untrusted data currently _may_ `panic!` on the following formats:
+
+* parquet
+* avro
+* Arrow IPC streams
+* compressed Arrow IPC files and streams
+
+We are actively addressing this.
 
 ## Integration tests
 
diff --git a/src/io/flight/mod.rs b/src/io/flight/mod.rs
index 524b3c140e0..4b848e3adf2 100644
--- a/src/io/flight/mod.rs
+++ b/src/io/flight/mod.rs
@@ -118,6 +118,7 @@ pub fn deserialize_batch(
     let message = arrow_format::ipc::MessageRef::read_as_root(&data.data_header)
         .map_err(|err| Error::OutOfSpec(format!("Unable to get root as message: {:?}", err)))?;
 
+    let length = data.data_body.len();
     let mut reader = std::io::Cursor::new(&data.data_body);
 
     match message.header()?.ok_or_else(|| {
@@ -132,6 +133,7 @@
             message.version()?,
             &mut reader,
             0,
+            length as u64,
         ),
         _ => Err(Error::nyi(
             "flight currently only supports reading RecordBatch messages",
diff --git a/src/io/ipc/read/common.rs b/src/io/ipc/read/common.rs
index 1190b158879..221fbefb258 100644
--- a/src/io/ipc/read/common.rs
+++ b/src/io/ipc/read/common.rs
@@ -83,6 +83,7 @@ pub fn read_record_batch(
     version: arrow_format::ipc::MetadataVersion,
     reader: &mut R,
     block_offset: u64,
+    file_size: u64,
 ) -> Result>> {
     assert_eq!(fields.len(), ipc_schema.fields.len());
     let buffers = batch
@@ -90,6 +91,14 @@ pub fn read_record_batch(
         .ok_or_else(|| Error::oos("IPC RecordBatch must contain buffers"))?;
     let mut buffers: VecDeque = buffers.iter().collect();
 
+    for buffer in buffers.iter() {
+        if buffer.length() as u64 > file_size {
+            return Err(Error::oos(
+                "Any buffer's length must be smaller than the size of the file",
+            ));
+        }
+    }
+
     let field_nodes = batch
         .nodes()?
         .ok_or_else(|| Error::oos("IPC RecordBatch must contain field nodes"))?;
@@ -205,6 +214,7 @@ pub fn read_dictionary(
     dictionaries: &mut Dictionaries,
     reader: &mut R,
     block_offset: u64,
+    file_size: u64,
 ) -> Result<()> {
     if batch.is_delta()? {
         return Err(Error::NotYetImplemented(
@@ -220,6 +230,10 @@ pub fn read_dictionary(
     // Get an array representing this dictionary's values.
     let dictionary_values: Box<dyn Array> = match &first_field.data_type {
         DataType::Dictionary(_, ref value_type, _) => {
+            let batch = batch
+                .data()?
+                .ok_or_else(|| Error::oos("The dictionary batch must have data."))?;
+
             // Make a fake schema for the dictionary batch.
             let fields = vec![Field::new("", value_type.as_ref().clone(), false)];
             let ipc_schema = IpcSchema {
@@ -227,9 +241,7 @@ pub fn read_dictionary(
                 is_little_endian: ipc_schema.is_little_endian,
             };
             let columns = read_record_batch(
-                batch
-                    .data()?
-                    .ok_or_else(|| Error::oos("The dictionary batch must have data."))?,
+                batch,
                 &fields,
                 &ipc_schema,
                 None,
@@ -237,6 +249,7 @@ pub fn read_dictionary(
                 arrow_format::ipc::MetadataVersion::V5,
                 reader,
                 block_offset,
+                file_size,
             )?;
             let mut arrays = columns.into_arrays();
             Some(arrays.pop().unwrap())
diff --git a/src/io/ipc/read/file_async.rs b/src/io/ipc/read/file_async.rs
index 90f90cadc6b..50923a50e61 100644
--- a/src/io/ipc/read/file_async.rs
+++ b/src/io/ipc/read/file_async.rs
@@ -144,7 +144,7 @@ where
     reader.seek(SeekFrom::End(-10 - footer_size as i64)).await?;
     reader.read_exact(&mut footer).await?;
 
-    deserialize_footer(&footer)
+    deserialize_footer(&footer, u64::MAX)
 }
 
 async fn read_batch(
@@ -188,6 +188,7 @@ where
         message.version()?,
         &mut cursor,
         0,
+        metadata.size,
     )
 }
 
@@ -220,7 +221,15 @@ where
             buffer.resize(length, 0);
             reader.read_exact(&mut buffer).await?;
             let mut cursor = std::io::Cursor::new(&mut buffer);
-            read_dictionary(batch, fields, ipc_schema, &mut dictionaries, &mut cursor, 0)?;
+            read_dictionary(
+                batch,
+                fields,
+                ipc_schema,
+                &mut dictionaries,
+                &mut cursor,
+                0,
+                u64::MAX,
+            )?;
         }
         other => {
             return Err(Error::OutOfSpec(format!(
diff --git a/src/io/ipc/read/reader.rs b/src/io/ipc/read/reader.rs
index 3bd82591d7f..c91efe13fba 100644
--- a/src/io/ipc/read/reader.rs
+++ b/src/io/ipc/read/reader.rs
@@ -30,6 +30,9 @@ pub struct FileMetadata {
 
     /// Dictionaries associated to each dict_id
     pub(crate) dictionaries: Option>,
+
+    /// The total size of the file in bytes
+    pub(crate) size: u64,
 }
 
 /// Arrow File reader
@@ -92,6 +95,7 @@ fn read_dictionary_block(
                 dictionaries,
                 reader,
                 block_offset,
+                metadata.size,
             )?;
         }
         t => {
@@ -126,9 +130,10 @@ pub fn read_file_dictionaries(
 }
 
 /// Reads the footer's length and magic number in footer
-fn read_footer_len(reader: &mut R) -> Result {
+fn read_footer_len(reader: &mut R) -> Result<(u64, usize)> {
     // read footer length and magic number in footer
-    reader.seek(SeekFrom::End(-10))?;
+    let end = reader.seek(SeekFrom::End(-10))? + 10;
+
     let mut footer: [u8; 10] = [0; 10];
     reader.read_exact(&mut footer)?;
 
@@ -139,12 +144,14 @@ fn read_footer_len(reader: &mut R) -> Result {
             "Arrow file does not contain correct footer".to_string(),
         ));
     }
-    footer_len
-        .try_into()
-        .map_err(|_| Error::oos("The footer's lenght must be a positive number"))
+    let footer_len = footer_len
+        .try_into()
+        .map_err(|_| Error::oos("The footer's length must be a positive number"))?;
+
+    Ok((end, footer_len))
 }
 
-pub(super) fn deserialize_footer(footer_data: &[u8]) -> Result {
+pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> Result {
     let footer = arrow_format::ipc::FooterRef::read_as_root(footer_data)
         .map_err(|err| Error::OutOfSpec(format!("Unable to get root as footer: {:?}", err)))?;
 
@@ -177,6 +184,7 @@ pub(super) fn deserialize_footer(footer_data: &[u8]) -> Result {
         ipc_schema,
         blocks,
         dictionaries,
+        size,
     })
 }
 
@@ -184,6 +192,7 @@ pub(super) fn deserialize_footer(footer_data: &[u8]) -> Result {
 pub fn read_file_metadata(reader: &mut R) -> Result {
     // check if header contain the correct magic bytes
     let mut magic_buffer: [u8; 6] = [0; 6];
+    let start = reader.seek(SeekFrom::Current(0))?;
     reader.read_exact(&mut magic_buffer)?;
     if magic_buffer != ARROW_MAGIC {
         return Err(Error::OutOfSpec(
@@ -191,14 +200,14 @@ pub fn read_file_metadata(reader: &mut R) -> Result(
@@ -264,6 +273,7 @@ pub fn read_batch(
         message.version()?,
         reader,
         block.offset as u64 + block.meta_data_length as u64,
+        metadata.size,
     )
 }
 
diff --git a/src/io/ipc/read/stream.rs b/src/io/ipc/read/stream.rs
index cc81148ebe6..5d1769a1e64 100644
--- a/src/io/ipc/read/stream.rs
+++ b/src/io/ipc/read/stream.rs
@@ -132,10 +132,16 @@ fn read_next(
         arrow_format::ipc::MessageHeaderRef::Schema(_) => Err(Error::oos("A stream ")),
         arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => {
             // read the block that makes up the record batch into a buffer
+            let length: usize = message
+                .body_length()?
+                .try_into()
+                .map_err(|_| Error::oos("The body length of a header must be larger than zero"))?;
             data_buffer.clear();
-            data_buffer.resize(message.body_length()? as usize, 0);
+            data_buffer.resize(length, 0);
             reader.read_exact(data_buffer)?;
 
+            let file_size = data_buffer.len() as u64;
+
             let mut reader = std::io::Cursor::new(data_buffer);
 
             read_record_batch(
@@ -147,15 +153,20 @@
                 metadata.version,
                 &mut reader,
                 0,
+                file_size,
             )
             .map(|x| Some(StreamState::Some(x)))
         }
         arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => {
             // read the block that makes up the dictionary batch into a buffer
-            let mut buf = vec![0; message.body_length()? as usize];
+            let length: usize = message
+                .body_length()?
+                .try_into()
+                .map_err(|_| Error::oos("The body length of a header must be larger than zero"))?;
+            let mut buf = vec![0; length];
             reader.read_exact(&mut buf)?;
 
-            let mut dict_reader = std::io::Cursor::new(buf);
+            let mut dict_reader = std::io::Cursor::new(&buf);
 
             read_dictionary(
                 batch,
@@ -164,6 +175,7 @@
                 dictionaries,
                 &mut dict_reader,
                 0,
+                buf.len() as u64,
             )?;
 
             // read the next message until we encounter a RecordBatch message
diff --git a/src/io/ipc/read/stream_async.rs b/src/io/ipc/read/stream_async.rs
index b59bdc612eb..5df116ca8ee 100644
--- a/src/io/ipc/read/stream_async.rs
+++ b/src/io/ipc/read/stream_async.rs
@@ -104,11 +104,17 @@ async fn maybe_next(
     })?;
 
     match header {
-        arrow_format::ipc::MessageHeaderRef::Schema(_) => Err(Error::oos("A stream ")),
+        arrow_format::ipc::MessageHeaderRef::Schema(_) => {
+            Err(Error::oos("A stream cannot contain a schema message"))
+        }
         arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => {
             // read the block that makes up the record batch into a buffer
+            let length: usize = message
+                .body_length()?
+                .try_into()
+                .map_err(|_| Error::oos("The body length of a header must be larger than zero"))?;
             state.data_buffer.clear();
-            state.data_buffer.resize(message.body_length()? as usize, 0);
+            state.data_buffer.resize(length, 0);
             state.reader.read_exact(&mut state.data_buffer).await?;
 
             read_record_batch(
@@ -120,15 +126,22 @@
                 state.metadata.version,
                 &mut std::io::Cursor::new(&state.data_buffer),
                 0,
+                state.data_buffer.len() as u64,
             )
             .map(|chunk| Some(StreamState::Some((state, chunk))))
         }
         arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => {
             // read the block that makes up the dictionary batch into a buffer
-            let mut buf = vec![0; message.body_length()? as usize];
-            state.reader.read_exact(&mut buf).await?;
+            let length: usize = message
+                .body_length()?
+                .try_into()
+                .map_err(|_| Error::oos("The body length of a header must be larger than zero"))?;
+            let mut body = vec![0; length];
+            state.reader.read_exact(&mut body).await?;
+
+            let file_size = body.len() as u64;
 
-            let mut dict_reader = std::io::Cursor::new(buf);
+            let mut dict_reader = std::io::Cursor::new(body);
 
             read_dictionary(
                 batch,
@@ -137,6 +150,7 @@
                 &mut state.dictionaries,
                 &mut dict_reader,
                 0,
+                file_size,
             )?;
 
             // read the next message until we encounter a Chunk> message
diff --git a/tests/it/io/ipc/read/file.rs b/tests/it/io/ipc/read/file.rs
index 75dc9655d73..52253bafed6 100644
--- a/tests/it/io/ipc/read/file.rs
+++ b/tests/it/io/ipc/read/file.rs
@@ -191,3 +191,41 @@ fn read_projected() -> Result<()> {
     test_projection("1.0.0-littleendian", "generated_primitive", vec![2, 1])
 }
+
+fn read_corrupted_ipc(data: Vec<u8>) -> Result<()> {
+    let mut file = std::io::Cursor::new(data);
+
+    let metadata = read_file_metadata(&mut file)?;
+    let mut reader = FileReader::new(file, metadata, None);
+
+    reader.try_for_each(|rhs| {
+        rhs?;
+        Result::Ok(())
+    })?;
+
+    Ok(())
+}
+
+#[test]
+fn test_does_not_panic() {
+    use rand::Rng; // 0.8.0
+
+    let version = "1.0.0-littleendian";
+    let file_name = "generated_primitive";
+    let testdata = crate::test_util::arrow_test_data();
+    let path = format!(
+        "{}/arrow-ipc-stream/integration/{}/{}.arrow_file",
+        testdata, version, file_name
+    );
+    let original = std::fs::read(path).unwrap();
+
+    for errors in 0..1000 {
+        let mut data = original.clone();
+        for _ in 0..errors {
+            let position: usize = rand::thread_rng().gen_range(0..data.len());
+            let new_byte: u8 = rand::thread_rng().gen_range(0..u8::MAX);
+            data[position] = new_byte;
+        }
+        let _ = read_corrupted_ipc(data);
+    }
+}
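
Beyond the new test, the change is visible to downstream users: a corrupted or truncated IPC file now surfaces as an `Err` from `read_file_metadata` or from the `FileReader` iterator rather than a `panic!`. Below is a minimal sketch of such a caller, assuming the public `arrow2::io::ipc::read` API exercised by the test above (`read_file_metadata` plus `FileReader::new(reader, metadata, None)`) and the crate's `arrow2::error::Result` alias; the function name `try_read_all` is illustrative only.

use std::io::Cursor;

use arrow2::error::Result;
use arrow2::io::ipc::read::{read_file_metadata, FileReader};

// Attempts to decode every record batch of an IPC file held in memory.
// With this patch, malformed bytes are reported as `Err(...)` here instead of aborting the process.
fn try_read_all(data: Vec<u8>) -> Result<()> {
    let mut cursor = Cursor::new(data);

    // parse the footer; a damaged footer fails cleanly
    let metadata = read_file_metadata(&mut cursor)?;

    // iterate over the record batches; every item is a `Result`
    for maybe_chunk in FileReader::new(cursor, metadata, None) {
        let _chunk = maybe_chunk?;
    }
    Ok(())
}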