diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index e28834b1da2..5a6da2bdf65 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -23,6 +23,14 @@ import arrow_pyarrow_integration_testing +class UuidType(pyarrow.PyExtensionType): + def __init__(self): + super().__init__(pyarrow.binary(16)) + + def __reduce__(self): + return UuidType, () + + class TestCase(unittest.TestCase): def setUp(self): self.old_allocated_rust = ( @@ -179,3 +187,10 @@ def test_field_metadata(self): result = arrow_pyarrow_integration_testing.round_trip_field(field) assert field == result assert field.metadata == result.metadata + + # see https://issues.apache.org/jira/browse/ARROW-13855 + def _test_field_extension(self): + field = pyarrow.field("aa", UuidType()) + result = arrow_pyarrow_integration_testing.round_trip_field(field) + assert field == result + assert field.metadata == result.metadata diff --git a/src/datatypes/field.rs b/src/datatypes/field.rs index d5a66a52507..41ea83d1003 100644 --- a/src/datatypes/field.rs +++ b/src/datatypes/field.rs @@ -277,3 +277,19 @@ impl std::fmt::Display for Field { write!(f, "{:?}", self) } } + +pub(crate) type Metadata = Option>; +pub(crate) type Extension = Option<(String, Option)>; + +pub(crate) fn get_extension(metadata: &Option>) -> Extension { + if let Some(metadata) = metadata { + if let Some(name) = metadata.get("ARROW:extension:name") { + let metadata = metadata.get("ARROW:extension:metadata").cloned(); + Some((name.clone(), metadata)) + } else { + None + } + } else { + None + } +} diff --git a/src/datatypes/mod.rs b/src/datatypes/mod.rs index b7e6a7d3674..eed204cd4a6 100644 --- a/src/datatypes/mod.rs +++ b/src/datatypes/mod.rs @@ -7,6 +7,8 @@ pub use field::Field; pub use physical_type::*; pub use schema::Schema; +pub(crate) use field::{get_extension, Extension, Metadata}; + /// The set of datatypes that are supported by this implementation of Apache Arrow. /// /// The Arrow specification on data types includes some more types. diff --git a/src/ffi/schema.rs b/src/ffi/schema.rs index a3b12a70efc..2ae006e1722 100644 --- a/src/ffi/schema.rs +++ b/src/ffi/schema.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, convert::TryInto, ffi::CStr, ffi::CString, ptr}; use crate::{ - datatypes::{DataType, Field, IntervalUnit, TimeUnit}, + datatypes::{DataType, Extension, Field, IntervalUnit, Metadata, TimeUnit}, error::{ArrowError, Result}, }; @@ -91,7 +91,26 @@ impl Ffi_ArrowSchema { None }; - let metadata = field.metadata().as_ref().map(metadata_to_bytes); + let metadata = field.metadata(); + + let metadata = if let DataType::Extension(name, _, extension_metadata) = field.data_type() { + // append extension information. + let mut metadata = metadata.clone().unwrap_or_default(); + + // metadata + if let Some(extension_metadata) = extension_metadata { + metadata.insert( + "ARROW:extension:metadata".to_string(), + extension_metadata.clone(), + ); + } + + metadata.insert("ARROW:extension:name".to_string(), name.clone()); + + Some(metadata_to_bytes(&metadata)) + } else { + metadata.as_ref().map(metadata_to_bytes) + }; let name = CString::new(name).unwrap(); let format = CString::new(format).unwrap(); @@ -192,7 +211,14 @@ pub fn to_field(schema: &Ffi_ArrowSchema) -> Result { } else { to_data_type(schema)? }; - let metadata = unsafe { metadata_from_bytes(schema.metadata) }; + let (metadata, extension) = unsafe { metadata_from_bytes(schema.metadata) }; + + let data_type = if let Some((name, extension_metadata)) = extension { + DataType::Extension(name, Box::new(data_type), extension_metadata) + } else { + data_type + }; + let mut field = Field::new(schema.name(), data_type, schema.nullable()); field.set_metadata(metadata); Ok(field) @@ -412,17 +438,17 @@ unsafe fn read_bytes(ptr: *const u8, len: usize) -> &'static str { std::str::from_utf8(slice).unwrap() } -unsafe fn metadata_from_bytes( - data: *const ::std::os::raw::c_char, -) -> Option> { +unsafe fn metadata_from_bytes(data: *const ::std::os::raw::c_char) -> (Metadata, Extension) { let mut data = data as *const u8; // u8 = i8 if data.is_null() { - return None; + return (None, None); }; let len = read_ne_i32(data); data = data.add(4); let mut result = BTreeMap::new(); + let mut extension_name = None; + let mut extension_metadata = None; for _ in 0..len { let key_len = read_ne_i32(data) as usize; data = data.add(4); @@ -432,7 +458,18 @@ unsafe fn metadata_from_bytes( data = data.add(4); let value = read_bytes(data, value_len); data = data.add(value_len); - result.insert(key.to_string(), value.to_string()); + match key { + "ARROW:extension:name" => { + extension_name = Some(value.to_string()); + } + "ARROW:extension:metadata" => { + extension_metadata = Some(value.to_string()); + } + _ => { + result.insert(key.to_string(), value.to_string()); + } + }; } - Some(result) + let extension = extension_name.map(|name| (name, extension_metadata)); + (Some(result), extension) } diff --git a/src/io/ipc/convert.rs b/src/io/ipc/convert.rs index ea6c6d98c89..b0765bf5fa8 100644 --- a/src/io/ipc/convert.rs +++ b/src/io/ipc/convert.rs @@ -17,7 +17,9 @@ //! Utilities for converting between IPC types and native Arrow types -use crate::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; +use crate::datatypes::{ + get_extension, DataType, Extension, Field, IntervalUnit, Metadata, Schema, TimeUnit, +}; use crate::endianess::is_native_little_endian; use crate::io::ipc::convert::ipc::UnionMode; @@ -32,9 +34,6 @@ use std::collections::{BTreeMap, HashMap}; use DataType::*; -type Metadata = Option>; -type Extension = Option<(String, Option)>; - pub fn schema_to_fb_offset<'a>( fbb: &mut FlatBufferBuilder<'a>, schema: &Schema, @@ -84,19 +83,6 @@ fn read_metadata(field: &ipc::Field) -> Metadata { } } -pub(crate) fn get_extension(metadata: &Metadata) -> Extension { - if let Some(metadata) = metadata { - if let Some(name) = metadata.get("ARROW:extension:name") { - let metadata = metadata.get("ARROW:extension:metadata").cloned(); - Some((name.clone(), metadata)) - } else { - None - } - } else { - None - } -} - /// Convert an IPC Field to Arrow Field impl<'a> From> for Field { fn from(field: ipc::Field) -> Field { diff --git a/src/io/ipc/mod.rs b/src/io/ipc/mod.rs index 49e22fc18e0..f6b347534c9 100644 --- a/src/io/ipc/mod.rs +++ b/src/io/ipc/mod.rs @@ -13,7 +13,6 @@ mod compression; mod convert; pub use convert::fb_to_schema; -pub(crate) use convert::get_extension; pub use gen::Message::root_as_message; pub mod read; pub mod write; diff --git a/src/io/json_integration/schema.rs b/src/io/json_integration/schema.rs index 09dc46d45ad..12b167401c3 100644 --- a/src/io/json_integration/schema.rs +++ b/src/io/json_integration/schema.rs @@ -25,8 +25,7 @@ use serde_json::{json, Value}; use crate::error::{ArrowError, Result}; -use crate::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; -use crate::io::ipc::get_extension; +use crate::datatypes::{get_extension, DataType, Field, IntervalUnit, Schema, TimeUnit}; pub trait ToJson { /// Generate a JSON representation diff --git a/tests/it/ffi.rs b/tests/it/ffi.rs index 856034c00f9..1bc86f3a638 100644 --- a/tests/it/ffi.rs +++ b/tests/it/ffi.rs @@ -172,3 +172,13 @@ fn schema() -> Result<()> { let field = field.with_metadata(metadata); test_round_trip_schema(field) } + +#[test] +fn extension() -> Result<()> { + let field = Field::new( + "a", + DataType::Extension("a".to_string(), Box::new(DataType::Int32), None), + true, + ); + test_round_trip_schema(field) +}