Skip to content

Commit

Permalink
Complete test for avro_to_arrow_reader on alltypes_dictionnary
Browse files Browse the repository at this point in the history
  • Loading branch information
Igosuki committed Aug 20, 2021
1 parent 2c695a4 commit 16e392f
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 6 deletions.
123 changes: 118 additions & 5 deletions datafusion/src/avro_to_arrow/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::avro_to_arrow::arrow_array_reader::AvroArrowArrayReader;
use crate::error::Result;
use arrow::error::Result as ArrowResult;
use avro_rs::Reader as AvroReader;
use std::io::{BufReader, Read, Seek};
use std::io::{Read, Seek};
use std::sync::Arc;

/// Avro file reader builder
Expand Down Expand Up @@ -86,18 +86,19 @@ impl ReaderBuilder {
}

/// Create a new `Reader` from the `ReaderBuilder`
pub fn build<'a, R>(self, source: R) -> Result<Reader<'a, BufReader<R>>>
pub fn build<'a, R>(self, source: R) -> Result<Reader<'a, R>>
where
R: Read + Seek,
{
let mut buf_reader = BufReader::new(source);
let mut source = source;

// check if schema should be inferred
let schema = match self.schema {
Some(schema) => schema,
None => Arc::new(infer_avro_schema_from_reader(&mut buf_reader)?),
None => Arc::new(infer_avro_schema_from_reader(&mut source)?),
};
Reader::try_new(buf_reader, schema, self.batch_size, self.projection)
source.rewind()?;
Reader::try_new(source, schema, self.batch_size, self.projection)
}
}

Expand Down Expand Up @@ -157,3 +158,115 @@ pub fn infer_avro_schema_from_reader<R: Read + Seek>(reader: &mut R) -> Result<S
let schema = avro_reader.writer_schema();
super::to_arrow_schema(schema)
}

#[cfg(test)]
mod tests {
use super::*;
use crate::arrow::array::*;
use crate::arrow::datatypes::{DataType, Field};
use arrow::datatypes::TimeUnit;
use std::fs::File;

fn build_reader(name: &str) -> Reader<File> {
let testdata = crate::test_util::arrow_test_data();
let filename = format!("{}/avro/{}", testdata, name);
let builder = ReaderBuilder::new().infer_schema().with_batch_size(64);
builder.build(File::open(filename).unwrap()).unwrap()
}

fn get_col<'a, T: 'static>(
batch: &'a RecordBatch,
col: (usize, &Field),
) -> Option<&'a T> {
batch.column(col.0).as_any().downcast_ref::<T>()
}

#[test]
fn test_json_basic() {
let mut reader = build_reader("alltypes_dictionary.avro");
let batch = reader.next().unwrap().unwrap();

assert_eq!(11, batch.num_columns());
assert_eq!(2, batch.num_rows());

let schema = reader.schema();
let batch_schema = batch.schema();
assert_eq!(schema, batch_schema);

let id = schema.column_with_name("id").unwrap();
assert_eq!(0, id.0);
assert_eq!(&DataType::Int32, id.1.data_type());
let col = get_col::<Int32Array>(&batch, id).unwrap();
assert_eq!(0, col.value(0));
assert_eq!(1, col.value(1));
let bool_col = schema.column_with_name("bool_col").unwrap();
assert_eq!(1, bool_col.0);
assert_eq!(&DataType::Boolean, bool_col.1.data_type());
let col = get_col::<BooleanArray>(&batch, bool_col).unwrap();
assert_eq!(true, col.value(0));
assert_eq!(false, col.value(1));
let tinyint_col = schema.column_with_name("tinyint_col").unwrap();
assert_eq!(2, tinyint_col.0);
assert_eq!(&DataType::Int32, tinyint_col.1.data_type());
let col = get_col::<Int32Array>(&batch, tinyint_col).unwrap();
assert_eq!(0, col.value(0));
assert_eq!(1, col.value(1));
let smallint_col = schema.column_with_name("smallint_col").unwrap();
assert_eq!(3, smallint_col.0);
assert_eq!(&DataType::Int32, smallint_col.1.data_type());
let col = get_col::<Int32Array>(&batch, smallint_col).unwrap();
assert_eq!(0, col.value(0));
assert_eq!(1, col.value(1));
let int_col = schema.column_with_name("int_col").unwrap();
assert_eq!(4, int_col.0);
let col = get_col::<Int32Array>(&batch, int_col).unwrap();
assert_eq!(0, col.value(0));
assert_eq!(1, col.value(1));
assert_eq!(&DataType::Int32, int_col.1.data_type());
let col = get_col::<Int32Array>(&batch, int_col).unwrap();
assert_eq!(0, col.value(0));
assert_eq!(1, col.value(1));
let bigint_col = schema.column_with_name("bigint_col").unwrap();
assert_eq!(5, bigint_col.0);
let col = get_col::<Int64Array>(&batch, bigint_col).unwrap();
assert_eq!(0, col.value(0));
assert_eq!(10, col.value(1));
assert_eq!(&DataType::Int64, bigint_col.1.data_type());
let float_col = schema.column_with_name("float_col").unwrap();
assert_eq!(6, float_col.0);
let col = get_col::<Float32Array>(&batch, float_col).unwrap();
assert_eq!(0.0, col.value(0));
assert_eq!(1.1, col.value(1));
assert_eq!(&DataType::Float32, float_col.1.data_type());
let col = get_col::<Float32Array>(&batch, float_col).unwrap();
assert_eq!(0.0, col.value(0));
assert_eq!(1.1, col.value(1));
let double_col = schema.column_with_name("double_col").unwrap();
assert_eq!(7, double_col.0);
assert_eq!(&DataType::Float64, double_col.1.data_type());
let col = get_col::<Float64Array>(&batch, double_col).unwrap();
assert_eq!(0.0, col.value(0));
assert_eq!(10.1, col.value(1));
let date_string_col = schema.column_with_name("date_string_col").unwrap();
assert_eq!(8, date_string_col.0);
assert_eq!(&DataType::Binary, date_string_col.1.data_type());
let col = get_col::<BinaryArray>(&batch, date_string_col).unwrap();
assert_eq!("01/01/09".as_bytes(), col.value(0));
assert_eq!("01/01/09".as_bytes(), col.value(1));
let string_col = schema.column_with_name("string_col").unwrap();
assert_eq!(9, string_col.0);
assert_eq!(&DataType::Binary, string_col.1.data_type());
let col = get_col::<BinaryArray>(&batch, string_col).unwrap();
assert_eq!("0".as_bytes(), col.value(0));
assert_eq!("1".as_bytes(), col.value(1));
let timestamp_col = schema.column_with_name("timestamp_col").unwrap();
assert_eq!(10, timestamp_col.0);
assert_eq!(
&DataType::Timestamp(TimeUnit::Microsecond, None),
timestamp_col.1.data_type()
);
let col = get_col::<TimestampMicrosecondArray>(&batch, timestamp_col).unwrap();
assert_eq!(1230768000000000, col.value(0));
assert_eq!(1230768060000000, col.value(1));
}
}

0 comments on commit 16e392f

Please sign in to comment.