diff --git a/src/io/parquet/read/primitive/basic.rs b/src/io/parquet/read/primitive/basic.rs
index 968945a1771..a783f001a4a 100644
--- a/src/io/parquet/read/primitive/basic.rs
+++ b/src/io/parquet/read/primitive/basic.rs
@@ -72,6 +72,41 @@ fn read_dict_buffer_optional(
     }
 }
 
+/// Reads a dictionary-encoded page of a required column into `values`.
+///
+/// The first byte of `indices_buffer` is the bit width of the dictionary indices and the
+/// remaining bytes are the RLE/bit-packed indices. Every decoded index is looked up in `dict`,
+/// converted with `op` and appended to `values`; `validity` is extended with `additional`
+/// bits set to `true`.
+///
+/// # Panics
+/// Panics if `indices_buffer` is empty or if a decoded index is out of bounds of `dict`.
+fn read_dict_buffer_required<T, A, F>(
+    indices_buffer: &[u8],
+    additional: usize,
+    dict: &PrimitivePageDict<T>,
+    values: &mut MutableBuffer<A>,
+    validity: &mut MutableBitmap,
+    op: F,
+) where
+    T: NativeType,
+    A: ArrowNativeType,
+    F: Fn(T) -> A,
+{
+    let dict_values = dict.values();
+
+    // SPEC: Data page format: the bit width used to encode the entry ids stored as 1 byte (max bit width = 32),
+    // SPEC: followed by the values encoded using RLE/Bit packed described above (with the given bit width).
+    let bit_width = indices_buffer[0];
+    let indices_buffer = &indices_buffer[1..];
+
+    let indices = hybrid_rle::HybridRleDecoder::new(indices_buffer, bit_width as u32, additional);
+
+    values.extend(indices.map(|index| op(dict_values[index as usize])));
+
+    validity.extend_constant(additional, true);
+}
+
 fn read_nullable<T, A, F>(
     validity_buffer: &[u8],
     values_buffer: &[u8],
@@ -170,6 +205,16 @@ where
                 op,
             )
         }
+        (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false) => {
+            read_dict_buffer_required(
+                values_buffer,
+                additional,
+                dict.as_any().downcast_ref().unwrap(),
+                values,
+                validity,
+                op,
+            )
+        }
         // it can happen that there is a dictionary but the encoding is plain because
         // it falled back.
         (Encoding::Plain, _, true) => read_nullable(
diff --git a/tests/it/io/parquet/read.rs b/tests/it/io/parquet/read.rs
index 61d855edf79..35ddcf3b5ce 100644
--- a/tests/it/io/parquet/read.rs
+++ b/tests/it/io/parquet/read.rs
@@ -112,6 +112,16 @@ fn v1_int64_nullable_dict() -> Result<()> {
     test_pyarrow_integration(0, 1, "basic", true, false)
 }
 
+#[test]
+fn v2_int64_required_dict() -> Result<()> {
+    test_pyarrow_integration(0, 2, "basic", true, true)
+}
+
+#[test]
+fn v1_int64_required_dict() -> Result<()> {
+    test_pyarrow_integration(0, 1, "basic", true, true)
+}
+
 #[test]
 fn v2_utf8_nullable() -> Result<()> {
     test_pyarrow_integration(2, 2, "basic", false, false)