diff --git a/parquet/src/arrow/arrow_reader.rs b/parquet/src/arrow/arrow_reader.rs
index 01eef9a3a3b0..fada28daaae2 100644
--- a/parquet/src/arrow/arrow_reader.rs
+++ b/parquet/src/arrow/arrow_reader.rs
@@ -17,6 +17,13 @@
 //! Contains reader which reads parquet data into arrow array.

+use std::sync::Arc;
+
+use arrow::datatypes::{DataType as ArrowType, Schema, SchemaRef};
+use arrow::error::Result as ArrowResult;
+use arrow::record_batch::{RecordBatch, RecordBatchReader};
+use arrow::{array::StructArray, error::ArrowError};
+
 use crate::arrow::array_reader::{build_array_reader, ArrayReader, StructArrayReader};
 use crate::arrow::schema::parquet_to_arrow_schema;
 use crate::arrow::schema::{
@@ -25,11 +32,6 @@ use crate::arrow::schema::{
 use crate::errors::{ParquetError, Result};
 use crate::file::metadata::ParquetMetaData;
 use crate::file::reader::FileReader;
-use arrow::datatypes::{DataType as ArrowType, Schema, SchemaRef};
-use arrow::error::Result as ArrowResult;
-use arrow::record_batch::{RecordBatch, RecordBatchReader};
-use arrow::{array::StructArray, error::ArrowError};
-use std::sync::Arc;

 /// Arrow reader api.
 /// With this api, user can get arrow schema from parquet file, and read parquet data
@@ -233,13 +235,29 @@ impl ParquetRecordBatchReader {

 #[cfg(test)]
 mod tests {
+    use std::cmp::min;
+    use std::convert::TryFrom;
+    use std::fs::File;
+    use std::io::Seek;
+    use std::path::PathBuf;
+    use std::sync::Arc;
+
+    use rand::{thread_rng, RngCore};
+    use serde_json::json;
+    use serde_json::Value::{Array as JArray, Null as JNull, Object as JObject};
+
+    use arrow::array::*;
+    use arrow::datatypes::{DataType as ArrowDataType, Field};
+    use arrow::error::Result as ArrowResult;
+    use arrow::record_batch::{RecordBatch, RecordBatchReader};
+
     use crate::arrow::arrow_reader::{ArrowReader, ParquetFileArrowReader};
     use crate::arrow::converter::{
         BinaryArrayConverter, Converter, FixedSizeArrayConverter, FromConverter,
         IntervalDayTimeArrayConverter, LargeUtf8ArrayConverter, Utf8ArrayConverter,
     };
     use crate::arrow::schema::add_encoded_arrow_schema_to_metadata;
-    use crate::basic::{ConvertedType, Encoding, Repetition};
+    use crate::basic::{ConvertedType, Encoding, Repetition, Type as PhysicalType};
     use crate::column::writer::get_typed_column_writer_mut;
     use crate::data_type::{
         BoolType, ByteArray, ByteArrayType, DataType, FixedLenByteArray,
@@ -253,18 +271,6 @@ mod tests {
     use crate::schema::types::{Type, TypePtr};
     use crate::util::cursor::SliceableCursor;
     use crate::util::test_common::RandGen;
-    use arrow::array::*;
-    use arrow::datatypes::{DataType as ArrowDataType, Field};
-    use arrow::record_batch::RecordBatchReader;
-    use rand::{thread_rng, RngCore};
-    use serde_json::json;
-    use serde_json::Value::{Array as JArray, Null as JNull, Object as JObject};
-    use std::cmp::min;
-    use std::convert::TryFrom;
-    use std::fs::File;
-    use std::io::Seek;
-    use std::path::PathBuf;
-    use std::sync::Arc;

     #[test]
     fn test_arrow_reader_all_columns() {
@@ -1058,4 +1064,101 @@ mod tests {
             error
         );
     }
+
+    #[test]
+    fn test_dictionary_preservation() {
+        let mut fields = vec![Arc::new(
+            Type::primitive_type_builder("leaf", PhysicalType::BYTE_ARRAY)
+                .with_repetition(Repetition::OPTIONAL)
+                .with_converted_type(ConvertedType::UTF8)
+                .build()
+                .unwrap(),
+        )];
+
+        let schema = Arc::new(
+            Type::group_type_builder("test_schema")
+                .with_fields(&mut fields)
+                .build()
+                .unwrap(),
+        );
+
+        let dict_type = ArrowDataType::Dictionary(
+            Box::new(ArrowDataType::Int32),
+            Box::new(ArrowDataType::Utf8),
+        );
+
+        let arrow_field = Field::new("leaf", dict_type, true);
+
+        let mut file = tempfile::tempfile().unwrap();
+
+        let values = vec![
+            vec![
+                ByteArray::from("hello"),
+                ByteArray::from("a"),
+                ByteArray::from("b"),
+                ByteArray::from("d"),
+            ],
+            vec![
+                ByteArray::from("c"),
+                ByteArray::from("a"),
+                ByteArray::from("b"),
+            ],
+        ];
+
+        let def_levels = vec![
+            vec![1, 0, 0, 1, 0, 0, 1, 1],
+            vec![0, 0, 1, 1, 0, 0, 1, 0, 0],
+        ];
+
+        let opts = TestOptions {
+            encoding: Encoding::RLE_DICTIONARY,
+            ..Default::default()
+        };
+
+        generate_single_column_file_with_data::<ByteArrayType>(
+            &values,
+            Some(&def_levels),
+            file.try_clone().unwrap(), // Cannot use &mut File (#1163)
+            schema,
+            Some(arrow_field),
+            &opts,
+        )
+        .unwrap();
+
+        file.rewind().unwrap();
+
+        let parquet_reader = SerializedFileReader::try_from(file).unwrap();
+        let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(parquet_reader));
+
+        let record_reader = arrow_reader.get_record_reader(3).unwrap();
+
+        let batches = record_reader
+            .collect::<ArrowResult<Vec<RecordBatch>>>()
+            .unwrap();
+
+        assert_eq!(batches.len(), 6);
+        assert!(batches.iter().all(|x| x.num_columns() == 1));
+
+        let row_counts = batches
+            .iter()
+            .map(|x| (x.num_rows(), x.column(0).null_count()))
+            .collect::<Vec<_>>();
+
+        assert_eq!(
+            row_counts,
+            vec![(3, 2), (3, 2), (3, 1), (3, 1), (3, 2), (2, 2)]
+        );
+
+        let get_dict =
+            |batch: &RecordBatch| batch.column(0).data().child_data()[0].clone();
+
+        // First and second batch in same row group -> same dictionary
+        assert_eq!(get_dict(&batches[0]), get_dict(&batches[1]));
+        // Third batch spans row group -> computed dictionary
+        assert_ne!(get_dict(&batches[1]), get_dict(&batches[2]));
+        assert_ne!(get_dict(&batches[2]), get_dict(&batches[3]));
+        // Fourth, fifth and sixth from same row group -> same dictionary
+        assert_eq!(get_dict(&batches[3]), get_dict(&batches[4]));
+        assert_eq!(get_dict(&batches[4]), get_dict(&batches[5]));
+    }
 }
diff --git a/parquet/src/arrow/record_reader.rs b/parquet/src/arrow/record_reader.rs
index 593296270ca0..af4766aa1cd3 100644
--- a/parquet/src/arrow/record_reader.rs
+++ b/parquet/src/arrow/record_reader.rs
@@ -167,7 +167,13 @@ where
                 break;
             }

-            let batch_size = max(num_records - records_read, MIN_BATCH_SIZE);
+            // If repetition levels are present, we don't know how much more to read
+            // in order to read the requested number of records, therefore read at least
+            // MIN_BATCH_SIZE, otherwise read exactly what was requested
+            let batch_size = match &self.rep_levels {
+                Some(_) => max(num_records - records_read, MIN_BATCH_SIZE),
+                None => num_records - records_read,
+            };

             // Try to more value from parquet pages
             let values_read = self.read_one_batch(batch_size)?;
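
For context (not part of the diff above): a minimal sketch of how the high-level Arrow reader exercised by the new test might be driven from application code. The file path and batch size are illustrative assumptions; the reader types and calls (SerializedFileReader, ParquetFileArrowReader, get_record_reader) are the same ones used in test_dictionary_preservation.

// Usage sketch, assuming an existing Parquet file at an illustrative path.
// After this change, dictionary-encoded BYTE_ARRAY columns whose Arrow type
// (from the embedded Arrow schema) is Dictionary(_, _) keep their dictionaries
// within a row group instead of being fully materialized per batch.
use std::fs::File;
use std::sync::Arc;

use parquet::arrow::{ArrowReader, ParquetFileArrowReader};
use parquet::file::reader::SerializedFileReader;

fn print_batches(path: &str) -> Result<(), Box<dyn std::error::Error>> {
    let file = File::open(path)?;
    let file_reader = SerializedFileReader::new(file)?;
    let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(file_reader));

    // Read RecordBatches of up to 1024 rows each (batch size is arbitrary here)
    let record_reader = arrow_reader.get_record_reader(1024)?;
    for batch in record_reader {
        let batch = batch?;
        println!("{} rows x {} columns", batch.num_rows(), batch.num_columns());
    }
    Ok(())
}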