Restrict ScalarBuffer types
tustvold committed Jan 11, 2022
1 parent 5e6bb01 commit 48b6d62
Showing 5 changed files with 91 additions and 58 deletions.
93 changes: 63 additions & 30 deletions parquet/src/arrow/array_reader.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::cmp::{max, min};
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
@@ -62,12 +63,13 @@ use crate::arrow::converter::{
IntervalYearMonthConverter, LargeBinaryArrayConverter, LargeBinaryConverter,
LargeUtf8ArrayConverter, LargeUtf8Converter,
};
use crate::arrow::record_reader::RecordReader;
use crate::arrow::record_reader::buffer::{ScalarValue, ValuesBuffer};
use crate::arrow::record_reader::{GenericRecordReader, RecordReader};
use crate::arrow::schema::parquet_to_arrow_field;
use crate::basic::{ConvertedType, Repetition, Type as PhysicalType};
use crate::column::page::PageIterator;
use crate::column::reader::decoder::ColumnValueDecoder;
use crate::column::reader::ColumnReaderImpl;
use crate::data_type::private::ScalarDataType;
use crate::data_type::{
BoolType, ByteArrayType, DataType, DoubleType, FixedLenByteArrayType, FloatType,
Int32Type, Int64Type, Int96Type,
@@ -78,7 +80,6 @@ use crate::schema::types::{
ColumnDescPtr, ColumnDescriptor, ColumnPath, SchemaDescPtr, Type, TypePtr,
};
use crate::schema::visitor::TypeVisitor;
use std::any::Any;

/// Array reader reads parquet data into arrow array.
pub trait ArrayReader {
@@ -105,11 +106,15 @@ pub trait ArrayReader {
///
/// Returns the number of records read, which can be less than batch_size if
/// pages is exhausted.
fn read_records<T: ScalarDataType>(
record_reader: &mut RecordReader<T>,
fn read_records<V, CV>(
record_reader: &mut GenericRecordReader<V, CV>,
pages: &mut dyn PageIterator,
batch_size: usize,
) -> Result<usize> {
) -> Result<usize>
where
V: ValuesBuffer + Default,
CV: ColumnValueDecoder<Slice = V::Slice>,
{
let mut records_read = 0usize;
while records_read < batch_size {
let records_to_read = batch_size - records_read;
@@ -133,7 +138,11 @@ fn read_records<T: ScalarDataType>(

/// A NullArrayReader reads Parquet columns stored as null int32s with an Arrow
/// NullArray type.
pub struct NullArrayReader<T: ScalarDataType> {
pub struct NullArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
{
data_type: ArrowType,
pages: Box<dyn PageIterator>,
def_levels_buffer: Option<Buffer>,
@@ -143,7 +152,11 @@ pub struct NullArrayReader<T: ScalarDataType> {
_type_marker: PhantomData<T>,
}

impl<T: ScalarDataType> NullArrayReader<T> {
impl<T> NullArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
{
/// Construct null array reader.
pub fn new(pages: Box<dyn PageIterator>, column_desc: ColumnDescPtr) -> Result<Self> {
let record_reader = RecordReader::<T>::new(column_desc.clone());
@@ -161,7 +174,11 @@ impl<T: ScalarDataType> NullArrayReader<T> {
}

/// Implementation of primitive array reader.
impl<T: ScalarDataType> ArrayReader for NullArrayReader<T> {
impl<T> ArrayReader for NullArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
{
fn as_any(&self) -> &dyn Any {
self
}
@@ -201,7 +218,11 @@ impl<T: ScalarDataType> ArrayReader for NullArrayReader<T> {

/// Primitive array readers are leaves of array reader tree. They accept page iterator
/// and read them into primitive arrays.
pub struct PrimitiveArrayReader<T: ScalarDataType> {
pub struct PrimitiveArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
{
data_type: ArrowType,
pages: Box<dyn PageIterator>,
def_levels_buffer: Option<Buffer>,
@@ -210,7 +231,11 @@ pub struct PrimitiveArrayReader<T: ScalarDataType> {
record_reader: RecordReader<T>,
}

impl<T: ScalarDataType> PrimitiveArrayReader<T> {
impl<T> PrimitiveArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
{
/// Construct primitive array reader.
pub fn new(
pages: Box<dyn PageIterator>,
@@ -239,7 +264,11 @@ impl<T: ScalarDataType> PrimitiveArrayReader<T> {
}

/// Implementation of primitive array reader.
impl<T: ScalarDataType> ArrayReader for PrimitiveArrayReader<T> {
impl<T> ArrayReader for PrimitiveArrayReader<T>
where
T: DataType,
T::T: ScalarValue,
{
fn as_any(&self) -> &dyn Any {
self
}
@@ -1907,7 +1936,26 @@ impl<'a> ArrayReaderBuilder {

#[cfg(test)]
mod tests {
use super::*;
use std::any::Any;
use std::collections::VecDeque;
use std::sync::Arc;

use rand::distributions::uniform::SampleUniform;
use rand::{thread_rng, Rng};

use arrow::array::{
Array, ArrayRef, LargeListArray, ListArray, PrimitiveArray, StringArray,
StructArray,
};
use arrow::datatypes::{
ArrowPrimitiveType, DataType as ArrowType, Date32Type as ArrowDate32, Field,
Int32Type as ArrowInt32, Int64Type as ArrowInt64,
Time32MillisecondType as ArrowTime32MillisecondArray,
Time64MicrosecondType as ArrowTime64MicrosecondArray,
TimestampMicrosecondType as ArrowTimestampMicrosecondType,
TimestampMillisecondType as ArrowTimestampMillisecondType,
};

use crate::arrow::converter::{Utf8ArrayConverter, Utf8Converter};
use crate::arrow::schema::parquet_to_arrow_schema;
use crate::basic::{Encoding, Type as PhysicalType};
@@ -1921,23 +1969,8 @@ mod tests {
DataPageBuilder, DataPageBuilderImpl, InMemoryPageIterator,
};
use crate::util::test_common::{get_test_file, make_pages};
use arrow::array::{
Array, ArrayRef, LargeListArray, ListArray, PrimitiveArray, StringArray,
StructArray,
};
use arrow::datatypes::{
ArrowPrimitiveType, DataType as ArrowType, Date32Type as ArrowDate32, Field,
Int32Type as ArrowInt32, Int64Type as ArrowInt64,
Time32MillisecondType as ArrowTime32MillisecondArray,
Time64MicrosecondType as ArrowTime64MicrosecondArray,
TimestampMicrosecondType as ArrowTimestampMicrosecondType,
TimestampMillisecondType as ArrowTimestampMillisecondType,
};
use rand::distributions::uniform::SampleUniform;
use rand::{thread_rng, Rng};
use std::any::Any;
use std::collections::VecDeque;
use std::sync::Arc;

use super::*;

fn make_column_chunks<T: DataType>(
column_desc: ColumnDescPtr,
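For illustration, here is a minimal, self-contained sketch of the batched read loop that the now-generic read_records helper implements. The RecordSource trait is a hypothetical stand-in for the GenericRecordReader/PageIterator pair, not the parquet crate's API:

trait RecordSource {
    /// Reads up to `n` records, returning how many were actually read;
    /// Ok(0) signals that the underlying pages are exhausted.
    fn read(&mut self, n: usize) -> Result<usize, String>;
}

fn read_records(source: &mut dyn RecordSource, batch_size: usize) -> Result<usize, String> {
    let mut records_read = 0usize;
    while records_read < batch_size {
        let records_to_read = batch_size - records_read;
        match source.read(records_to_read)? {
            // Source exhausted: return what we have, possibly < batch_size.
            0 => break,
            n => records_read += n,
        }
    }
    Ok(records_read)
}

// A toy source that serves at most 3 records per call, for demonstration.
struct Fixed(usize);

impl RecordSource for Fixed {
    fn read(&mut self, n: usize) -> Result<usize, String> {
        let read = self.0.min(n).min(3);
        self.0 -= read;
        Ok(read)
    }
}

fn main() {
    let mut src = Fixed(7);
    // Fewer records than batch_size, matching the doc comment above.
    assert_eq!(read_records(&mut src, 10).unwrap(), 7);
}

As the doc comment in the diff states, the loop only returns fewer than batch_size records when the source reports exhaustion.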
8 changes: 4 additions & 4 deletions parquet/src/arrow/record_reader.rs
@@ -21,7 +21,7 @@ use arrow::bitmap::Bitmap;
use arrow::buffer::Buffer;

use crate::arrow::record_reader::{
buffer::{BufferQueue, TypedBuffer, ValuesBuffer},
buffer::{BufferQueue, ScalarBuffer, ValuesBuffer},
definition_levels::{DefinitionLevelBuffer, DefinitionLevelDecoder},
};
use crate::column::{
@@ -42,7 +42,7 @@ const MIN_BATCH_SIZE: usize = 1024;

/// A `RecordReader` is a stateful column reader that delimits semantic records.
pub type RecordReader<T> =
GenericRecordReader<TypedBuffer<<T as DataType>::T>, ColumnValueDecoderImpl<T>>;
GenericRecordReader<ScalarBuffer<<T as DataType>::T>, ColumnValueDecoderImpl<T>>;

#[doc(hidden)]
/// A generic stateful column reader that delimits semantic records
Expand All @@ -55,7 +55,7 @@ pub struct GenericRecordReader<V, CV> {

records: V,
def_levels: Option<DefinitionLevelBuffer>,
rep_levels: Option<TypedBuffer<i16>>,
rep_levels: Option<ScalarBuffer<i16>>,
column_reader:
Option<GenericColumnReader<ColumnLevelDecoderImpl, DefinitionLevelDecoder, CV>>,

@@ -77,7 +77,7 @@
let def_levels =
(desc.max_def_level() > 0).then(|| DefinitionLevelBuffer::new(&desc));

let rep_levels = (desc.max_rep_level() > 0).then(TypedBuffer::new);
let rep_levels = (desc.max_rep_level() > 0).then(ScalarBuffer::new);

Self {
records: Default::default(),
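The alias above is the key compatibility trick: RecordReader<T> keeps its public shape while the machinery becomes generic over the values buffer and decoder. A compact sketch of the same pattern, using placeholder types that only mirror the shapes in the diff (hypothetical, not the crate's definitions):

use std::marker::PhantomData;

trait DataType {
    type T; // the physical value type, e.g. i32 for Int32Type
}

struct ScalarBuffer<T>(Vec<T>);
struct ColumnValueDecoderImpl<D: DataType>(PhantomData<D>);
struct GenericRecordReader<V, CV> {
    records: V,
    decoder: CV,
}

// Pin the generic reader's buffer to ScalarBuffer over <D as DataType>::T,
// just as the diff's RecordReader alias does.
type RecordReader<D> =
    GenericRecordReader<ScalarBuffer<<D as DataType>::T>, ColumnValueDecoderImpl<D>>;

struct Int32Type;
impl DataType for Int32Type {
    type T = i32;
}

fn main() {
    let reader: RecordReader<Int32Type> = GenericRecordReader {
        records: ScalarBuffer(Vec::new()),
        decoder: ColumnValueDecoderImpl(PhantomData),
    };
    assert_eq!(reader.records.0.len(), 0);
    let _ = &reader.decoder;
}

Existing callers keep constructing RecordReader<T> unchanged; only code that needs a different buffer implementation reaches for GenericRecordReader directly.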
26 changes: 21 additions & 5 deletions parquet/src/arrow/record_reader/buffer.rs
@@ -64,8 +64,24 @@ pub trait BufferQueue: Sized {
fn set_len(&mut self, len: usize);
}

/// A marker trait for [scalar] types
///
/// This means that a `[Self]` slice of length `len`, with every element equal to
/// `Self::default()`, can be safely created from a zero-initialized `[u8]` with
/// length `len * std::mem::size_of::<Self>()` and alignment of
/// `std::mem::size_of::<Self>()`
///
/// [scalar]: https://doc.rust-lang.org/book/ch03-02-data-types.html#scalar-types
///
pub trait ScalarValue {}
impl ScalarValue for bool {}
impl ScalarValue for i16 {}
impl ScalarValue for i32 {}
impl ScalarValue for i64 {}
impl ScalarValue for f32 {}
impl ScalarValue for f64 {}

/// A typed buffer similar to [`Vec<T>`] but using [`MutableBuffer`] for storage
pub struct TypedBuffer<T> {
pub struct ScalarBuffer<T: ScalarValue> {
buffer: MutableBuffer,

/// Length in elements of size T
@@ -75,13 +91,13 @@
_phantom: PhantomData<*mut T>,
}

impl<T> Default for TypedBuffer<T> {
impl<T: ScalarValue> Default for ScalarBuffer<T> {
fn default() -> Self {
Self::new()
}
}

impl<T> TypedBuffer<T> {
impl<T: ScalarValue> ScalarBuffer<T> {
pub fn new() -> Self {
Self {
buffer: MutableBuffer::new(0),
@@ -114,7 +130,7 @@
}
}

impl<T> BufferQueue for TypedBuffer<T> {
impl<T: ScalarValue> BufferQueue for ScalarBuffer<T> {
type Output = Buffer;

type Slice = [T];
@@ -177,7 +193,7 @@ pub trait ValuesBuffer: BufferQueue {
);
}

impl<T> ValuesBuffer for TypedBuffer<T> {
impl<T: ScalarValue> ValuesBuffer for ScalarBuffer<T> {
fn pad_nulls(
&mut self,
range: Range<usize>,
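To make the new restriction concrete, here is a hedged, self-contained sketch of what the ScalarValue bound buys: a buffer that grows by default-initialized elements can now only be instantiated for plain scalar types. The names mirror the diff, but the Vec-backed storage below is a deliberate simplification of the real MutableBuffer-backed implementation:

/// Marker: an all-zeroes bit pattern is a valid `Self`.
trait ScalarValue: Copy + Default {}
impl ScalarValue for bool {}
impl ScalarValue for i16 {}
impl ScalarValue for i32 {}
impl ScalarValue for i64 {}
impl ScalarValue for f32 {}
impl ScalarValue for f64 {}

struct ScalarBuffer<T: ScalarValue> {
    // The real ScalarBuffer wraps arrow's MutableBuffer (raw zeroed bytes);
    // Vec<T> keeps this sketch safe and short.
    values: Vec<T>,
}

impl<T: ScalarValue> ScalarBuffer<T> {
    fn new() -> Self {
        Self { values: Vec::new() }
    }

    /// Grows by `n` default-initialized slots -- the operation that is only
    /// sound to back with zeroed bytes when T is a plain scalar.
    fn pad(&mut self, n: usize) {
        let new_len = self.values.len() + n;
        self.values.resize(new_len, T::default());
    }
}

fn main() {
    let mut buf = ScalarBuffer::<i64>::new();
    buf.pad(8);
    assert_eq!(buf.values.len(), 8);
    // ScalarBuffer::<String>::new() no longer compiles: String is not a
    // ScalarValue, which is exactly the restriction this commit introduces.
}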
6 changes: 3 additions & 3 deletions parquet/src/arrow/record_reader/definition_levels.rs
@@ -24,12 +24,12 @@ use crate::column::reader::decoder::ColumnLevelDecoderImpl;
use crate::schema::types::ColumnDescPtr;

use super::{
buffer::{BufferQueue, TypedBuffer},
buffer::{BufferQueue, ScalarBuffer},
MIN_BATCH_SIZE,
};

pub struct DefinitionLevelBuffer {
buffer: TypedBuffer<i16>,
buffer: ScalarBuffer<i16>,
builder: BooleanBufferBuilder,
max_level: i16,
}
@@ -62,7 +62,7 @@ impl BufferQueue for DefinitionLevelBuffer {
impl DefinitionLevelBuffer {
pub fn new(desc: &ColumnDescPtr) -> Self {
Self {
buffer: TypedBuffer::new(),
buffer: ScalarBuffer::new(),
builder: BooleanBufferBuilder::new(0),
max_level: desc.max_def_level(),
}
16 changes: 0 additions & 16 deletions parquet/src/data_type.rs
@@ -572,7 +572,6 @@ impl AsBytes for str {
}

pub(crate) mod private {
use super::*;
use crate::encodings::decoding::PlainDecoderDetails;
use crate::util::bit_util::{round_upto_power_of_2, BitReader, BitWriter};
use crate::util::memory::ByteBufferPtr;
@@ -1033,21 +1032,6 @@ pub(crate) mod private {
self
}
}

/// A marker trait for [`DataType`] with a [scalar] physical type
///
/// This means that a `[Self::T::default()]` of length `len` can be safely created from a
/// zero-initialized `[u8]` with length `len * Self::get_type_size()` and
/// alignment of `Self::get_type_size()`
///
/// [scalar]: https://doc.rust-lang.org/book/ch03-02-data-types.html#scalar-types
///
pub trait ScalarDataType: DataType {}
impl ScalarDataType for BoolType {}
impl ScalarDataType for Int32Type {}
impl ScalarDataType for Int64Type {}
impl ScalarDataType for FloatType {}
impl ScalarDataType for DoubleType {}
}

/// Contains the Parquet physical type information as well as the Rust primitive type
