diff --git a/Cargo.toml b/Cargo.toml
index 22ed05de4db..6500ca771ba 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -60,6 +60,8 @@ ahash = { version = "0.7", optional = true }
 parquet2 = { version = "0.4", optional = true, default_features = false, features = ["stream"] }
 
+avro-rs = { version = "0.13", optional = true, default_features = false }
+
 # for division/remainder optimization at runtime
 strength_reduce = { version = "0.2", optional = true }
 multiversion = { version = "0.6.1", optional = true }
@@ -86,6 +88,7 @@ full = [
     "io_print",
     "io_parquet",
     "io_parquet_compression",
+    "io_avro",
     "regex",
     "merge_sort",
     "ahash",
@@ -107,6 +110,7 @@ io_parquet_compression = [
     "parquet2/lz4",
     "parquet2/brotli",
 ]
+io_avro = ["avro-rs", "streaming-iterator", "serde_json"]
 # io_json: its dependencies + error handling
 # serde_derive: there is some derive around
 io_json_integration = ["io_json", "serde_derive", "hex"]
diff --git a/src/io/avro/mod.rs b/src/io/avro/mod.rs
new file mode 100644
index 00000000000..119ea6d142b
--- /dev/null
+++ b/src/io/avro/mod.rs
@@ -0,0 +1,11 @@
+//! Read from and write to Apache Avro
+
+pub mod read;
+
+use crate::error::ArrowError;
+
+impl From<avro_rs::SerError> for ArrowError {
+    fn from(error: avro_rs::SerError) -> Self {
+        ArrowError::External("".to_string(), Box::new(error))
+    }
+}
diff --git a/src/io/avro/read/deserialize.rs b/src/io/avro/read/deserialize.rs
new file mode 100644
index 00000000000..8004adae75e
--- /dev/null
+++ b/src/io/avro/read/deserialize.rs
@@ -0,0 +1,138 @@
+use std::convert::TryInto;
+use std::sync::Arc;
+
+use crate::array::*;
+use crate::datatypes::*;
+use crate::error::ArrowError;
+use crate::error::Result;
+use crate::record_batch::RecordBatch;
+
+use super::util;
+
+pub fn deserialize(mut block: &[u8], rows: usize, schema: Arc<Schema>) -> Result<RecordBatch> {
+    // create mutables, one per field
+    let mut arrays: Vec<Box<dyn MutableArray>> = schema
+        .fields()
+        .iter()
+        .map(|field| match field.data_type().to_physical_type() {
+            PhysicalType::Boolean => {
+                Ok(Box::new(MutableBooleanArray::with_capacity(rows)) as Box<dyn MutableArray>)
+            }
+            PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
+                Ok(Box::new(MutablePrimitiveArray::<$T>::with_capacity(rows)) as Box<dyn MutableArray>)
+            }),
+            PhysicalType::Utf8 => {
+                Ok(Box::new(MutableUtf8Array::<i32>::with_capacity(rows)) as Box<dyn MutableArray>)
+            }
+            PhysicalType::Binary => {
+                Ok(Box::new(MutableBinaryArray::<i32>::with_capacity(rows))
+                    as Box<dyn MutableArray>)
+            }
+            other => {
+                return Err(ArrowError::NotYetImplemented(format!(
+                    "Deserializing type {:?} is still not implemented",
+                    other
+                )))
+            }
+        })
+        .collect::<Result<_>>()?;
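+
+    // Avro blocks are row-major: each of the `rows` records stores its fields
+    // back-to-back in schema order; a nullable field (a ["null", T] union) is
+    // prefixed by a zigzag-encoded integer selecting the union variant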
+    // this is _the_ expensive transpose (rows -> columns)
+    for _ in 0..rows {
+        for (array, field) in arrays.iter_mut().zip(schema.fields().iter()) {
+            if field.is_nullable() {
+                // variant 0 is always the null in a union array
+                if util::zigzag_i64(&mut block)? == 0 {
+                    array.push_null();
+                    continue;
+                }
+            }
+
+            match array.data_type().to_physical_type() {
+                PhysicalType::Boolean => {
+                    let is_valid = block[0] == 1;
+                    block = &block[1..];
+                    let array = array
+                        .as_mut_any()
+                        .downcast_mut::<MutableBooleanArray>()
+                        .unwrap();
+                    array.push(Some(is_valid))
+                }
+                PhysicalType::Primitive(primitive) => {
+                    use crate::datatypes::PrimitiveType::*;
+                    match primitive {
+                        Int32 => {
+                            let value = util::zigzag_i64(&mut block)? as i32;
+                            let array = array
+                                .as_mut_any()
+                                .downcast_mut::<MutablePrimitiveArray<i32>>()
+                                .unwrap();
+                            array.push(Some(value))
+                        }
+                        Int64 => {
+                            let value = util::zigzag_i64(&mut block)? as i64;
+                            let array = array
+                                .as_mut_any()
+                                .downcast_mut::<MutablePrimitiveArray<i64>>()
+                                .unwrap();
+                            array.push(Some(value))
+                        }
+                        Float32 => {
+                            let value = f32::from_le_bytes(block[..4].try_into().unwrap());
+                            block = &block[4..];
+                            let array = array
+                                .as_mut_any()
+                                .downcast_mut::<MutablePrimitiveArray<f32>>()
+                                .unwrap();
+                            array.push(Some(value))
+                        }
+                        Float64 => {
+                            let value = f64::from_le_bytes(block[..8].try_into().unwrap());
+                            block = &block[8..];
+                            let array = array
+                                .as_mut_any()
+                                .downcast_mut::<MutablePrimitiveArray<f64>>()
+                                .unwrap();
+                            array.push(Some(value))
+                        }
+                        _ => unreachable!(),
+                    }
+                }
+                PhysicalType::Utf8 => {
+                    let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| {
+                        ArrowError::ExternalFormat(
+                            "Avro format contains a non-usize number of bytes".to_string(),
+                        )
+                    })?;
+                    let data = std::str::from_utf8(&block[..len])?;
+                    block = &block[len..];
+
+                    let array = array
+                        .as_mut_any()
+                        .downcast_mut::<MutableUtf8Array<i32>>()
+                        .unwrap();
+                    array.push(Some(data))
+                }
+                PhysicalType::Binary => {
+                    let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| {
+                        ArrowError::ExternalFormat(
+                            "Avro format contains a non-usize number of bytes".to_string(),
+                        )
+                    })?;
+                    let data = &block[..len];
+                    block = &block[len..];
+
+                    let array = array
+                        .as_mut_any()
+                        .downcast_mut::<MutableBinaryArray<i32>>()
+                        .unwrap();
+                    array.push(Some(data))
+                }
+                _ => todo!(),
+            };
+        }
+    }
+    let columns = arrays.iter_mut().map(|array| array.as_arc()).collect();
+
+    RecordBatch::try_new(schema, columns)
+}
diff --git a/src/io/avro/read/mod.rs b/src/io/avro/read/mod.rs
new file mode 100644
index 00000000000..300310b2fbe
--- /dev/null
+++ b/src/io/avro/read/mod.rs
@@ -0,0 +1,172 @@
+use std::io::Read;
+use std::sync::Arc;
+
+use avro_rs::Codec;
+use streaming_iterator::StreamingIterator;
+
+mod deserialize;
+mod schema;
+mod util;
+
+use crate::datatypes::Schema;
+use crate::error::{ArrowError, Result};
+use crate::record_batch::RecordBatch;
+
+pub fn read_metadata<R: Read>(reader: &mut R) -> Result<(Schema, Codec, [u8; 16])> {
+    let (schema, codec, marker) = util::read_schema(reader)?;
+    Ok((schema::convert_schema(&schema)?, codec, marker))
+}
+
+fn read_size<R: Read>(reader: &mut R) -> Result<(usize, usize)> {
+    let rows = match util::zigzag_i64(reader) {
+        Ok(a) => a,
+        Err(ArrowError::Io(io_err)) => {
+            if let std::io::ErrorKind::UnexpectedEof = io_err.kind() {
+                // end
+                return Ok((0, 0));
+            } else {
+                return Err(ArrowError::Io(io_err));
+            }
+        }
+        Err(other) => return Err(other),
+    };
+    let bytes = util::zigzag_i64(reader)?;
+    Ok((rows as usize, bytes as usize))
+}
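+
+// An Avro file is a header (magic bytes, a map holding the JSON schema and the
+// codec, and a 16-byte sync marker) followed by blocks; each block is a
+// zigzag-encoded row count, a zigzag-encoded byte length, the (possibly
+// compressed) data, and a copy of the sync marker.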
+
+/// Reads a block from the file into `buf`.
+/// # Panics
+/// Panics iff the block's marker does not equal the file's marker
+fn read_block<R: Read>(reader: &mut R, buf: &mut Vec<u8>, file_marker: [u8; 16]) -> Result<usize> {
+    let (rows, bytes) = read_size(reader)?;
+    if rows == 0 {
+        return Ok(0);
+    };
+
+    buf.resize(bytes, 0);
+    reader.read_exact(buf)?;
+
+    let mut marker = [0u8; 16];
+    reader.read_exact(&mut marker)?;
+
+    if marker != file_marker {
+        panic!();
+    }
+    Ok(rows)
+}
+
+fn decompress_block(buf: &mut Vec<u8>, decompress: &mut Vec<u8>, codec: Codec) -> Result<bool> {
+    match codec {
+        Codec::Null => {
+            std::mem::swap(buf, decompress);
+            Ok(false)
+        }
+        Codec::Deflate => {
+            todo!()
+        }
+    }
+}
+
+/// [`StreamingIterator`] of blocks of avro data
+pub struct BlockStreamIterator<'a, R: Read> {
+    buf: (Vec<u8>, usize),
+    reader: &'a mut R,
+    file_marker: [u8; 16],
+}
+
+impl<'a, R: Read> BlockStreamIterator<'a, R> {
+    pub fn new(reader: &'a mut R, file_marker: [u8; 16]) -> Self {
+        Self {
+            reader,
+            file_marker,
+            buf: (vec![], 0),
+        }
+    }
+
+    pub fn buffer(&mut self) -> &mut Vec<u8> {
+        &mut self.buf.0
+    }
+}
+
+impl<'a, R: Read> StreamingIterator for BlockStreamIterator<'a, R> {
+    type Item = (Vec<u8>, usize);
+
+    fn advance(&mut self) {
+        let (buf, rows) = &mut self.buf;
+        // todo: surface this error
+        *rows = read_block(self.reader, buf, self.file_marker).unwrap();
+    }
+
+    fn get(&self) -> Option<&Self::Item> {
+        if self.buf.1 > 0 {
+            Some(&self.buf)
+        } else {
+            None
+        }
+    }
+}
+
+/// [`StreamingIterator`] of blocks of decompressed avro data
+pub struct Decompressor<'a, R: Read> {
+    blocks: BlockStreamIterator<'a, R>,
+    codec: Codec,
+    buf: (Vec<u8>, usize),
+    was_swapped: bool,
+}
+
+impl<'a, R: Read> Decompressor<'a, R> {
+    pub fn new(blocks: BlockStreamIterator<'a, R>, codec: Codec) -> Self {
+        Self {
+            blocks,
+            codec,
+            buf: (vec![], 0),
+            was_swapped: false,
+        }
+    }
+}
+
+impl<'a, R: Read> StreamingIterator for Decompressor<'a, R> {
+    type Item = (Vec<u8>, usize);
+
+    fn advance(&mut self) {
+        if self.was_swapped {
+            std::mem::swap(self.blocks.buffer(), &mut self.buf.0);
+        }
+        self.blocks.advance();
+        self.was_swapped =
+            decompress_block(self.blocks.buffer(), &mut self.buf.0, self.codec).unwrap();
+        self.buf.1 = self.blocks.get().map(|(_, rows)| *rows).unwrap_or_default();
+    }
+
+    fn get(&self) -> Option<&Self::Item> {
+        if self.buf.1 > 0 {
+            Some(&self.buf)
+        } else {
+            None
+        }
+    }
+}
+
+/// Single threaded, blocking reader of Avro files; [`Iterator`] of [`RecordBatch`]es.
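+///
+/// # Example
+/// A sketch of typical usage, mirroring the integration test below (it assumes
+/// `file` is a `Read` over a complete, valid Avro file):
+/// ```ignore
+/// let (schema, codec, marker) = read_metadata(&mut file)?;
+/// let reader = Reader::new(
+///     Decompressor::new(BlockStreamIterator::new(&mut file, marker), codec),
+///     Arc::new(schema),
+/// );
+/// let batches = reader.collect::<Result<Vec<_>>>()?;
+/// ```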
+pub struct Reader<'a, R: Read> {
+    iter: Decompressor<'a, R>,
+    schema: Arc<Schema>,
+}
+
+impl<'a, R: Read> Reader<'a, R> {
+    pub fn new(iter: Decompressor<'a, R>, schema: Arc<Schema>) -> Self {
+        Self { iter, schema }
+    }
+}
+
+impl<'a, R: Read> Iterator for Reader<'a, R> {
+    type Item = Result<RecordBatch>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if let Some((data, rows)) = self.iter.next() {
+            Some(deserialize::deserialize(data, *rows, self.schema.clone()))
+        } else {
+            None
+        }
+    }
+}
diff --git a/src/io/avro/read/schema.rs b/src/io/avro/read/schema.rs
new file mode 100644
index 00000000000..6d4554e9185
--- /dev/null
+++ b/src/io/avro/read/schema.rs
@@ -0,0 +1,199 @@
+use std::collections::BTreeMap;
+
+use avro_rs::schema::Name;
+use avro_rs::types::Value;
+use avro_rs::Schema as AvroSchema;
+
+use crate::datatypes::*;
+use crate::error::{ArrowError, Result};
+
+/// Returns the fully qualified name for a field
+pub fn aliased(name: &str, namespace: Option<&str>, default_namespace: Option<&str>) -> String {
+    if name.contains('.') {
+        name.to_string()
+    } else {
+        let namespace = namespace.as_ref().copied().or(default_namespace);
+
+        match namespace {
+            Some(ref namespace) => format!("{}.{}", namespace, name),
+            None => name.to_string(),
+        }
+    }
+}
+
+fn external_props(schema: &AvroSchema) -> BTreeMap<String, String> {
+    let mut props = BTreeMap::new();
+    match &schema {
+        AvroSchema::Record {
+            doc: Some(ref doc), ..
+        }
+        | AvroSchema::Enum {
+            doc: Some(ref doc), ..
+        } => {
+            props.insert("avro::doc".to_string(), doc.clone());
+        }
+        _ => {}
+    }
+    match &schema {
+        AvroSchema::Record {
+            name:
+                Name {
+                    aliases: Some(aliases),
+                    namespace,
+                    ..
+                },
+            ..
+        }
+        | AvroSchema::Enum {
+            name:
+                Name {
+                    aliases: Some(aliases),
+                    namespace,
+                    ..
+                },
+            ..
+        }
+        | AvroSchema::Fixed {
+            name:
+                Name {
+                    aliases: Some(aliases),
+                    namespace,
+                    ..
+                },
+            ..
+        } => {
+            let aliases: Vec<String> = aliases
+                .iter()
+                .map(|alias| aliased(alias, namespace.as_deref(), None))
+                .collect();
+            props.insert(
+                "avro::aliases".to_string(),
+                format!("[{}]", aliases.join(",")),
+            );
+        }
+        _ => {}
+    }
+    props
+}
+
+pub fn convert_schema(schema: &AvroSchema) -> Result<Schema> {
+    let mut schema_fields = vec![];
+    match schema {
+        AvroSchema::Record { fields, .. } => {
+            for field in fields {
+                schema_fields.push(schema_to_field(
+                    &field.schema,
+                    Some(&field.name),
+                    false,
+                    Some(&external_props(&field.schema)),
+                )?)
+            }
+        }
+        schema => schema_fields.push(schema_to_field(schema, Some(""), false, None)?),
+    }
+
+    let schema = Schema::new(schema_fields);
+    Ok(schema)
+}
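+
+/// Maps an [`AvroSchema`] to a [`Field`]; e.g. the two-variant union
+/// `["null", "string"]` becomes a nullable `Utf8` field.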
+fn schema_to_field(
+    schema: &AvroSchema,
+    name: Option<&str>,
+    mut nullable: bool,
+    props: Option<&BTreeMap<String, String>>,
+) -> Result<Field> {
+    let data_type = match schema {
+        AvroSchema::Null => DataType::Null,
+        AvroSchema::Boolean => DataType::Boolean,
+        AvroSchema::Int => DataType::Int32,
+        AvroSchema::Long => DataType::Int64,
+        AvroSchema::Float => DataType::Float32,
+        AvroSchema::Double => DataType::Float64,
+        AvroSchema::Bytes => DataType::Binary,
+        AvroSchema::String => DataType::Utf8,
+        AvroSchema::Array(item_schema) => {
+            DataType::List(Box::new(schema_to_field(item_schema, None, false, None)?))
+        }
+        AvroSchema::Map(value_schema) => {
+            let value_field = schema_to_field(value_schema, Some("value"), false, None)?;
+            DataType::Dictionary(
+                Box::new(DataType::Utf8),
+                Box::new(value_field.data_type().clone()),
+            )
+        }
+        AvroSchema::Union(us) => {
+            // If there are only two variants and one of them is null, set the other type as the field data type
+            let has_nullable = us.find_schema(&Value::Null).is_some();
+            let sub_schemas = us.variants();
+            if has_nullable && sub_schemas.len() == 2 {
+                nullable = true;
+                if let Some(schema) = sub_schemas
+                    .iter()
+                    .find(|&schema| !matches!(schema, AvroSchema::Null))
+                {
+                    schema_to_field(schema, None, has_nullable, None)?
+                        .data_type()
+                        .clone()
+                } else {
+                    return Err(ArrowError::NotYetImplemented(format!(
+                        "Can't read avro union {:?}",
+                        us
+                    )));
+                }
+            } else {
+                let fields = sub_schemas
+                    .iter()
+                    .map(|s| schema_to_field(s, None, has_nullable, None))
+                    .collect::<Result<Vec<Field>>>()?;
+                DataType::Union(fields, None, false)
+            }
+        }
+        AvroSchema::Record { name, fields, .. } => {
+            let fields: Result<Vec<Field>> = fields
+                .iter()
+                .map(|field| {
+                    let mut props = BTreeMap::new();
+                    if let Some(doc) = &field.doc {
+                        props.insert("avro::doc".to_string(), doc.clone());
+                    }
+                    /*if let Some(aliases) = fields.aliases {
+                        props.insert("aliases", aliases);
+                    }*/
+                    schema_to_field(
+                        &field.schema,
+                        Some(&format!("{}.{}", name.fullname(None), field.name)),
+                        false,
+                        Some(&props),
+                    )
+                })
+                .collect();
+            DataType::Struct(fields?)
+        }
+        AvroSchema::Enum { name, .. } => {
+            return Ok(Field::new_dict(
+                &name.fullname(None),
+                DataType::Dictionary(Box::new(DataType::UInt64), Box::new(DataType::Utf8)),
+                false,
+                0,
+                false,
+            ))
+        }
+        AvroSchema::Fixed { size, .. } => DataType::FixedSizeBinary(*size as i32),
+        AvroSchema::Decimal {
+            precision, scale, ..
+        } => DataType::Decimal(*precision, *scale),
+        AvroSchema::Uuid => DataType::Utf8,
+        AvroSchema::Date => DataType::Date32,
+        AvroSchema::TimeMillis => DataType::Time32(TimeUnit::Millisecond),
+        AvroSchema::TimeMicros => DataType::Time64(TimeUnit::Microsecond),
+        AvroSchema::TimestampMillis => DataType::Timestamp(TimeUnit::Millisecond, None),
+        AvroSchema::TimestampMicros => DataType::Timestamp(TimeUnit::Microsecond, None),
+        AvroSchema::Duration => DataType::Duration(TimeUnit::Millisecond),
+    };
+
+    let name = name.unwrap_or_default();
+
+    let mut field = Field::new(name, data_type, nullable);
+    field.set_metadata(props.cloned());
+    Ok(field)
+}
diff --git a/src/io/avro/read/util.rs b/src/io/avro/read/util.rs
new file mode 100644
index 00000000000..1bdcbd9fd0f
--- /dev/null
+++ b/src/io/avro/read/util.rs
@@ -0,0 +1,94 @@
+use std::io::Read;
+use std::str::FromStr;
+
+use crate::error::Result;
+
+use avro_rs::{from_avro_datum, types::Value, AvroResult, Codec, Error, Schema};
+use serde_json::from_slice;
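+
+/// Decodes a zigzag, variable-length-encoded integer, the encoding Avro shares
+/// with protobuf: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ... so the single byte
+/// `0x02` decodes to `1` and `0x01` decodes to `-1`.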
+pub fn zigzag_i64<R: Read>(reader: &mut R) -> Result<i64> {
+    let z = decode_variable(reader)?;
+    Ok(if z & 0x1 == 0 {
+        (z >> 1) as i64
+    } else {
+        !(z >> 1) as i64
+    })
+}
+
+fn decode_variable<R: Read>(reader: &mut R) -> Result<u64> {
+    let mut i = 0u64;
+    let mut buf = [0u8; 1];
+
+    let mut j = 0;
+    loop {
+        if j > 9 {
+            // if j * 7 > 64
+            panic!()
+        }
+        reader.read_exact(&mut buf[..])?;
+        i |= (u64::from(buf[0] & 0x7F)) << (j * 7);
+        if (buf[0] >> 7) == 0 {
+            break;
+        } else {
+            j += 1;
+        }
+    }
+
+    Ok(i)
+}
+
+fn read_file_marker<R: Read>(reader: &mut R) -> AvroResult<[u8; 16]> {
+    let mut marker = [0u8; 16];
+    reader.read_exact(&mut marker).map_err(Error::ReadMarker)?;
+    Ok(marker)
+}
+
+/// Reads the schema from `reader`, returning the file's [`Schema`] and [`Codec`].
+/// # Errors
+/// This function errors iff the header is not a valid avro file header.
+pub fn read_schema<R: Read>(reader: &mut R) -> AvroResult<(Schema, Codec, [u8; 16])> {
+    let meta_schema = Schema::Map(Box::new(Schema::Bytes));
+
+    let mut buf = [0u8; 4];
+    reader.read_exact(&mut buf).map_err(Error::ReadHeader)?;
+
+    if buf != [b'O', b'b', b'j', 1u8] {
+        return Err(Error::HeaderMagic);
+    }
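+
+    // the header is an avro-encoded map from string to bytes; "avro.schema"
+    // holds the JSON schema and "avro.codec" (when present) the block codec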
+    if let Value::Map(meta) = from_avro_datum(&meta_schema, reader, None)? {
+        // TODO: surface original parse schema errors instead of coalescing them here
+        let json = meta
+            .get("avro.schema")
+            .and_then(|bytes| {
+                if let Value::Bytes(ref bytes) = *bytes {
+                    from_slice(bytes.as_ref()).ok()
+                } else {
+                    None
+                }
+            })
+            .ok_or(Error::GetAvroSchemaFromMap)?;
+        let schema = Schema::parse(&json)?;
+
+        let codec = if let Some(codec) = meta
+            .get("avro.codec")
+            .and_then(|codec| {
+                if let Value::Bytes(ref bytes) = *codec {
+                    std::str::from_utf8(bytes.as_ref()).ok()
+                } else {
+                    None
+                }
+            })
+            .and_then(|codec| Codec::from_str(codec).ok())
+        {
+            codec
+        } else {
+            Codec::Null
+        };
+        let marker = read_file_marker(reader)?;
+
+        Ok((schema, codec, marker))
+    } else {
+        Err(Error::GetHeaderMetadata)
+    }
+}
diff --git a/src/io/mod.rs b/src/io/mod.rs
index 1ad0c3b4598..3bddcbc20a9 100644
--- a/src/io/mod.rs
+++ b/src/io/mod.rs
@@ -19,6 +19,10 @@ pub mod json_integration;
 #[cfg_attr(docsrs, doc(cfg(feature = "io_parquet")))]
 pub mod parquet;
 
+#[cfg(feature = "io_avro")]
+#[cfg_attr(docsrs, doc(cfg(feature = "io_avro")))]
+pub mod avro;
+
 #[cfg(feature = "io_print")]
 #[cfg_attr(docsrs, doc(cfg(feature = "io_print")))]
 pub mod print;
diff --git a/tests/it/array/mod.rs b/tests/it/array/mod.rs
index 9b367ce7c94..6feb6b905c0 100644
--- a/tests/it/array/mod.rs
+++ b/tests/it/array/mod.rs
@@ -13,7 +13,6 @@ mod utf8;
 
 use arrow2::array::{clone, new_empty_array, new_null_array, Array, PrimitiveArray};
 use arrow2::bitmap::Bitmap;
-use arrow2::datatypes::PhysicalType::Primitive;
 use arrow2::datatypes::{DataType, Field};
 
 #[test]
diff --git a/tests/it/io/avro/mod.rs b/tests/it/io/avro/mod.rs
new file mode 100644
index 00000000000..918582b61ee
--- /dev/null
+++ b/tests/it/io/avro/mod.rs
@@ -0,0 +1,3 @@
+//! Read from and write to Apache Avro
+
+mod read;
diff --git a/tests/it/io/avro/read/mod.rs b/tests/it/io/avro/read/mod.rs
new file mode 100644
index 00000000000..d7b4d3db837
--- /dev/null
+++ b/tests/it/io/avro/read/mod.rs
@@ -0,0 +1,99 @@
+use std::sync::Arc;
+
+use avro_rs::types::Record;
+use avro_rs::Schema as AvroSchema;
+use avro_rs::Writer;
+
+use arrow2::array::*;
+use arrow2::datatypes::*;
+use arrow2::error::Result;
+use arrow2::io::avro::read;
+use arrow2::record_batch::RecordBatch;
+
+fn schema() -> (AvroSchema, Schema) {
+    let raw_schema = r#"
+    {
+        "type": "record",
+        "name": "test",
+        "fields": [
+            {"name": "a", "type": "long"},
+            {"name": "b", "type": "string"},
+            {"name": "c", "type": "int"},
+            {"name": "d", "type": "bytes"},
+            {"name": "e", "type": "double"},
+            {"name": "f", "type": "boolean"},
+            {"name": "h", "type": ["null", "string"], "default": null}
+        ]
+    }
+"#;
+
+    let schema = Schema::new(vec![
+        Field::new("a", DataType::Int64, false),
+        Field::new("b", DataType::Utf8, false),
+        Field::new("c", DataType::Int32, false),
+        Field::new("d", DataType::Binary, false),
+        Field::new("e", DataType::Float64, false),
+        Field::new("f", DataType::Boolean, false),
+        Field::new("h", DataType::Utf8, true),
+    ]);
+
+    (AvroSchema::parse_str(raw_schema).unwrap(), schema)
+}
+
+fn write() -> Result<(Vec<u8>, RecordBatch)> {
+    let (avro, schema) = schema();
+    // a writer needs a schema and something to write to
+    let mut writer = Writer::new(&avro, Vec::new());
+
+    // the Record type models our Record schema
+    let mut record = Record::new(writer.schema()).unwrap();
+    record.put("a", 27i64);
+    record.put("b", "foo");
+    record.put("c", 1i32);
+    record.put("d", b"foo".as_ref());
+    record.put("e", 1.0f64);
+    record.put("f", true);
+    record.put("h", Some("foo"));
+    writer.append(record)?;
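+
+    // a second record; fields may be put in any order, since the writer
+    // serializes them in the schema's order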
+    let mut record = Record::new(writer.schema()).unwrap();
+    record.put("b", "bar");
+    record.put("a", 47i64);
+    record.put("c", 1i32);
+    record.put("d", b"bar".as_ref());
+    record.put("e", 2.0f64);
+    record.put("f", false);
+    record.put("h", None::<&str>);
+    writer.append(record)?;
+
+    let columns = vec![
+        Arc::new(Int64Array::from_slice([27, 47])) as Arc<dyn Array>,
+        Arc::new(Utf8Array::<i32>::from_slice(["foo", "bar"])) as Arc<dyn Array>,
+        Arc::new(Int32Array::from_slice([1, 1])) as Arc<dyn Array>,
+        Arc::new(BinaryArray::<i32>::from_slice([b"foo", b"bar"])) as Arc<dyn Array>,
+        Arc::new(PrimitiveArray::<f64>::from_slice([1.0, 2.0])) as Arc<dyn Array>,
+        Arc::new(BooleanArray::from_slice([true, false])) as Arc<dyn Array>,
+        Arc::new(Utf8Array::<i32>::from([Some("foo"), None])) as Arc<dyn Array>,
+    ];
+
+    let expected = RecordBatch::try_new(Arc::new(schema), columns).unwrap();
+
+    Ok((writer.into_inner().unwrap(), expected))
+}
+
+#[test]
+fn read() -> Result<()> {
+    let (data, expected) = write()?;
+
+    let file = &mut &data[..];
+
+    let (schema, codec, file_marker) = read::read_metadata(file)?;
+
+    let mut reader = read::Reader::new(
+        read::Decompressor::new(read::BlockStreamIterator::new(file, file_marker), codec),
+        Arc::new(schema),
+    );
+
+    assert_eq!(reader.next().unwrap().unwrap(), expected);
+    Ok(())
+}
diff --git a/tests/it/io/mod.rs b/tests/it/io/mod.rs
index 9ec0112e906..5009e99015f 100644
--- a/tests/it/io/mod.rs
+++ b/tests/it/io/mod.rs
@@ -10,5 +10,8 @@ mod ipc;
 #[cfg(feature = "io_parquet")]
 mod parquet;
 
+#[cfg(feature = "io_avro")]
+mod avro;
+
 #[cfg(feature = "io_csv")]
 mod csv;