diff --git a/src/compression.rs b/src/compression.rs
index 2602f5c5a..efeb1b807 100644
--- a/src/compression.rs
+++ b/src/compression.rs
@@ -188,6 +188,20 @@ pub fn decompress(compression: Compression, input_buf: &[u8], output_buf: &mut [
             crate::error::Feature::Lz4,
             "decompress with lz4".to_string(),
         )),
+
+        #[cfg(any(feature = "lz4_flex", feature = "lz4"))]
+        Compression::Lz4 => try_decompress_hadoop(input_buf, output_buf).or_else(|_| {
+            lz4_decompress_to_buffer(input_buf, Some(output_buf.len() as i32), output_buf)
+                .map(|_| {})
+        }
+        ),
+
+        #[cfg(all(not(feature = "lz4_flex"), not(feature = "lz4")))]
+        Compression::Lz4 => Err(Error::FeatureNotActive(
+            crate::error::Feature::Lz4,
+            "decompress with legacy lz4".to_string(),
+        )),
+
         #[cfg(feature = "zstd")]
         Compression::Zstd => {
             use std::io::Read;
@@ -209,6 +223,92 @@ pub fn decompress(compression: Compression, input_buf: &[u8], output_buf: &mut [
         }
     }
 }
+/// Try to decompress the buffer as if it was compressed with the Hadoop Lz4Codec.
+/// Translated from the apache arrow c++ function [TryDecompressHadoop](https://github.com/apache/arrow/blob/bf18e6e4b5bb6180706b1ba0d597a65a4ce5ca48/cpp/src/arrow/util/compression_lz4.cc#L474).
+/// Returns an error if decompression fails.
+#[cfg(any(feature = "lz4", feature = "lz4_flex"))]
+fn try_decompress_hadoop(input_buf: &[u8], output_buf: &mut [u8]) -> Result<()> {
+    // Parquet files written with the Hadoop Lz4Codec use their own framing.
+    // The input buffer can contain an arbitrary number of "frames", each
+    // with the following structure:
+    // - bytes 0..3: big-endian uint32_t representing the frame decompressed size
+    // - bytes 4..7: big-endian uint32_t representing the frame compressed size
+    // - bytes 8...: frame compressed data
+    //
+    // The Hadoop Lz4Codec source code can be found here:
+    // https://github.com/apache/hadoop/blob/trunk/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-nativetask/src/main/native/src/codec/Lz4Codec.cc
+
+    const SIZE_U32: usize = std::mem::size_of::<u32>();
+    const PREFIX_LEN: usize = SIZE_U32 * 2;
+    let mut input_len = input_buf.len();
+    let mut input = input_buf;
+    let mut output_len = output_buf.len();
+    let mut output: &mut [u8] = output_buf;
+    while input_len >= PREFIX_LEN {
+        let mut bytes = [0; SIZE_U32];
+        bytes.copy_from_slice(&input[0..4]);
+        let expected_decompressed_size = u32::from_be_bytes(bytes);
+        let mut bytes = [0; SIZE_U32];
+        bytes.copy_from_slice(&input[4..8]);
+        let expected_compressed_size = u32::from_be_bytes(bytes);
+        input = &input[PREFIX_LEN..];
+        input_len -= PREFIX_LEN;
+
+        if input_len < expected_compressed_size as usize {
+            return Err(general_err!("Not enough bytes for Hadoop frame".to_owned()));
+        }
+
+        if output_len < expected_decompressed_size as usize {
+            return Err(general_err!(
+                "Not enough bytes to hold advertised output".to_owned()
+            ));
+        }
+        let decompressed_size = lz4_decompress_to_buffer(
+            &input[..expected_compressed_size as usize],
+            Some(output_len as i32),
+            output,
+        )?;
+        if decompressed_size != expected_decompressed_size as usize {
+            return Err(general_err!("unexpected decompressed size"));
+        }
+        input_len -= expected_compressed_size as usize;
+        output_len -= expected_decompressed_size as usize;
+        if input_len > expected_compressed_size as usize {
+            input = &input[expected_compressed_size as usize..];
+            output = &mut output[expected_decompressed_size as usize..];
+        } else {
+            break;
+        }
+    }
+    if input_len == 0 {
+        Ok(())
+    } else {
+        Err(general_err!("Not all input are consumed"))
+    }
+}
+
+#[cfg(all(feature = "lz4", not(feature = "lz4_flex")))]
+#[inline]
+fn lz4_decompress_to_buffer(
+    src: &[u8],
+    uncompressed_size: Option<i32>,
+    buffer: &mut [u8],
+) -> Result<usize> {
+    let size = lz4::block::decompress_to_buffer(src, uncompressed_size, buffer)?;
+    Ok(size)
+}
+
+#[cfg(all(feature = "lz4_flex", not(feature = "lz4")))]
+#[inline]
+fn lz4_decompress_to_buffer(
+    src: &[u8],
+    _uncompressed_size: Option<i32>,
+    buffer: &mut [u8],
+) -> Result<usize> {
+    let size = lz4_flex::block::decompress_into(src, buffer)?;
+    Ok(size)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/tests/it/read/lz4_legacy.rs b/tests/it/read/lz4_legacy.rs
new file mode 100644
index 000000000..c840d83bd
--- /dev/null
+++ b/tests/it/read/lz4_legacy.rs
@@ -0,0 +1,67 @@
+use crate::get_path;
+use crate::read::get_column;
+use crate::Array;
+use parquet2::error::Result;
+
+fn verify_column_data(column: &str) -> Array {
+    match column {
+        "c0" => {
+            let expected = vec![1593604800, 1593604800, 1593604801, 1593604801];
+            let expected = expected.into_iter().map(Some).collect::<Vec<_>>();
+            Array::Int64(expected)
+        }
+        "c1" => {
+            let expected = vec!["abc", "def", "abc", "def"];
+            let expected = expected
+                .into_iter()
+                .map(|v| Some(v.as_bytes().to_vec()))
+                .collect::<Vec<_>>();
+            Array::Binary(expected)
+        }
+        "v11" => {
+            let expected = vec![42_f64, 7.7, 42.125, 7.7];
+            let expected = expected.into_iter().map(Some).collect::<Vec<_>>();
+            Array::Float64(expected)
+        }
+        _ => unreachable!(),
+    }
+}
+
+#[test]
+fn test_lz4_inference() -> Result<()> {
+
+    // - file "hadoop_lz4_compressed.parquet" is compressed using the hadoop Lz4Codec
+    // - file "non_hadoop_lz4_compressed.parquet" is "the LZ4 block format without the custom Hadoop header".
+    // see https://github.com/apache/parquet-testing/pull/14
+
+    // Both files are marked as Lz4-compressed; the decompressor should
+    // be able to distinguish them from each other.
+
+    let files = ["hadoop_lz4_compressed.parquet", "non_hadoop_lz4_compressed.parquet"];
+    let columns = ["c0", "c1", "v11"];
+    for file in files {
+        let mut path = get_path();
+        path.push(file);
+        let path = path.to_str().unwrap();
+        for column in columns {
+            let (result, _statistics) = get_column(path, column)?;
+            assert_eq!(result, verify_column_data(column), "of file {}", file);
+        }
+    }
+    Ok(())
+}
+
+#[test]
+fn test_lz4_large_file() -> Result<()> {
+
+    // File "hadoop_lz4_compressed_larger.parquet" is compressed using the hadoop Lz4Codec
+    // and contains 10000 rows.
+
+    let mut path = get_path();
+    let file = "hadoop_lz4_compressed_larger.parquet";
+    path.push(file);
+    let path = path.to_str().unwrap();
+    let (result, _statistics) = get_column(path, "a")?;
+    assert_eq!(result.len(), 10000);
+    Ok(())
+}
diff --git a/tests/it/read/mod.rs b/tests/it/read/mod.rs
index c8a268a5b..349c5c59c 100644
--- a/tests/it/read/mod.rs
+++ b/tests/it/read/mod.rs
@@ -11,6 +11,9 @@
 mod primitive_nested;
 mod struct_;
 mod utils;
+#[cfg(any(feature = "lz4", feature = "lz4_flex"))]
+mod lz4_legacy;
+
 use std::fs::File;
 
 use futures::StreamExt;
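
Note (not part of the patch): a minimal sketch of how the Hadoop Lz4Codec framing handled by try_decompress_hadoop could be exercised in a unit test. It assumes the lz4_flex feature, that the code sits next to try_decompress_hadoop in src/compression.rs, and a hypothetical helper build_hadoop_frame; the integration tests above instead rely on the real parquet-testing files.

    // Builds a single Hadoop-style LZ4 "frame": an 8-byte big-endian prefix
    // (decompressed size, then compressed size) followed by a raw LZ4 block.
    // Hypothetical helper; assumes lz4_flex::block::compress for the block body.
    #[cfg(all(test, feature = "lz4_flex"))]
    fn build_hadoop_frame(raw: &[u8]) -> Vec<u8> {
        let block = lz4_flex::block::compress(raw);
        let mut frame = Vec::with_capacity(8 + block.len());
        frame.extend_from_slice(&(raw.len() as u32).to_be_bytes());
        frame.extend_from_slice(&(block.len() as u32).to_be_bytes());
        frame.extend_from_slice(&block);
        frame
    }

    #[cfg(all(test, feature = "lz4_flex"))]
    #[test]
    fn hadoop_frame_roundtrip() {
        let raw = b"hadoop lz4 framing example";
        let framed = build_hadoop_frame(raw);
        let mut output = vec![0u8; raw.len()];
        // The Hadoop path should consume the whole frame and fill the output buffer exactly.
        try_decompress_hadoop(&framed, &mut output).unwrap();
        assert_eq!(&output[..], &raw[..]);
    }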