From 64fece9c47bdfe1a67d0ebfefe7022f484f4335d Mon Sep 17 00:00:00 2001 From: haixuanTao Date: Mon, 15 Jan 2024 22:14:16 +0100 Subject: [PATCH 01/14] Generate the right size of default value for array --- .../ros2-bridge/python/src/typed/mod.rs | 92 ++++++++++--------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/libraries/extensions/ros2-bridge/python/src/typed/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/mod.rs index 150897fdb..7fc4ddcd5 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/mod.rs @@ -1,10 +1,10 @@ use arrow::{ array::{ make_array, Array, ArrayData, BooleanArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, StringArray, StructArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + Int32Array, Int64Array, Int8Array, ListArray, StringArray, StructArray, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, }, - buffer::Buffer, + buffer::{OffsetBuffer, ScalarBuffer}, compute::concat, datatypes::{DataType, Field}, }; @@ -13,7 +13,7 @@ use dora_ros2_bridge_msg_gen::types::{ MemberType, Message, }; use eyre::{Context, ContextCompat, Result}; -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, vec}; pub use serialize::TypedValue; @@ -40,7 +40,10 @@ pub fn for_message( .members .iter() .map(|m| { - let default = make_array(default_for_member(m, package_name, messages)?); + let default = make_array( + default_for_member(m, package_name, messages) + .context(format!("Could not create default value for {:#?}", m))?, + ); Result::<_, eyre::Report>::Ok(( Arc::new(Field::new( m.name.clone(), @@ -50,8 +53,8 @@ pub fn for_message( default, )) }) - .collect::>()?; - + .collect::>() + .context("Could not build default value")?; let default_struct: StructArray = default_struct_vec.into(); Ok(TypeInfo { @@ -77,28 +80,36 @@ pub fn default_for_member( Some(_) => eyre::bail!( "there should be only a single default value for non-sequence types" ), - None => default_for_nestable_type(t, package_name, messages)?, + None => default_for_nestable_type(t, package_name, messages, 1)?, }, NestableType::NamedType(_) => { if m.default.is_some() { eyre::bail!("default values for nested types are not supported") } else { - default_for_nestable_type(t, package_name, messages)? + default_for_nestable_type(t, package_name, messages, 1)? } } NestableType::NamespacedType(_) => { - default_for_nestable_type(t, package_name, messages)? + default_for_nestable_type(t, package_name, messages, 1)? } }, - MemberType::Array(array) => { - list_default_values(m, &array.value_type, package_name, messages)? - } + MemberType::Array(array) => list_default_values( + m, + &array.value_type, + package_name, + messages, + Some(array.size), + )?, MemberType::Sequence(seq) => { - list_default_values(m, &seq.value_type, package_name, messages)? - } - MemberType::BoundedSequence(seq) => { - list_default_values(m, &seq.value_type, package_name, messages)? + list_default_values(m, &seq.value_type, package_name, messages, None)? } + MemberType::BoundedSequence(seq) => list_default_values( + m, + &seq.value_type, + package_name, + messages, + Some(seq.max_size), + )?, }; Ok(value) } @@ -107,24 +118,25 @@ fn default_for_nestable_type( t: &NestableType, package_name: &str, messages: &HashMap>, + size: usize, ) -> Result { let empty = HashMap::new(); let package_messages = messages.get(package_name).unwrap_or(&empty); let array = match t { NestableType::BasicType(t) => match t { - BasicType::I8 => Int8Array::from(vec![0]).into(), - BasicType::I16 => Int16Array::from(vec![0]).into(), - BasicType::I32 => Int32Array::from(vec![0]).into(), - BasicType::I64 => Int64Array::from(vec![0]).into(), - BasicType::U8 => UInt8Array::from(vec![0]).into(), - BasicType::U16 => UInt16Array::from(vec![0]).into(), - BasicType::U32 => UInt32Array::from(vec![0]).into(), - BasicType::U64 => UInt64Array::from(vec![0]).into(), - BasicType::F32 => Float32Array::from(vec![0.]).into(), - BasicType::F64 => Float64Array::from(vec![0.]).into(), + BasicType::I8 => Int8Array::from(vec![0; size]).into(), + BasicType::I16 => Int16Array::from(vec![0; size]).into(), + BasicType::I32 => Int32Array::from(vec![0; size]).into(), + BasicType::I64 => Int64Array::from(vec![0; size]).into(), + BasicType::U8 => UInt8Array::from(vec![0; size]).into(), + BasicType::U16 => UInt16Array::from(vec![0; size]).into(), + BasicType::U32 => UInt32Array::from(vec![0; size]).into(), + BasicType::U64 => UInt64Array::from(vec![0; size]).into(), + BasicType::F32 => Float32Array::from(vec![0.; size]).into(), + BasicType::F64 => Float64Array::from(vec![0.; size]).into(), BasicType::Char => StringArray::from(vec![""]).into(), - BasicType::Byte => UInt8Array::from(vec![0u8] as Vec).into(), - BasicType::Bool => BooleanArray::from(vec![false]).into(), + BasicType::Byte => UInt8Array::from(vec![0u8; size]).into(), + BasicType::Bool => BooleanArray::from(vec![false; size]).into(), }, NestableType::GenericString(_) => StringArray::from(vec![""]).into(), NestableType::NamedType(name) => { @@ -230,6 +242,7 @@ fn list_default_values( value_type: &NestableType, package_name: &str, messages: &HashMap>, + size: Option, ) -> Result { let defaults = match &m.default.as_deref() { Some([]) => eyre::bail!("empty default value not supported"), @@ -253,23 +266,18 @@ fn list_default_values( default_values.to_data() } None => { + let size = size.unwrap_or(1); let default_nested_type = - default_for_nestable_type(value_type, package_name, messages)?; - - let value_offsets = Buffer::from_slice_ref([0i64, 1]); + default_for_nestable_type(value_type, package_name, messages, size)?; + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, size as i32])); - let list_data_type = DataType::List(Arc::new(Field::new( - &m.name, + let field = Arc::new(Field::new( + "item", default_nested_type.data_type().clone(), true, - ))); - // Construct a list array from the above two - ArrayData::builder(list_data_type) - .len(1) - .add_buffer(value_offsets) - .add_child_data(default_nested_type) - .build() - .context("Failed to build default list value")? + )); + let list = ListArray::new(field, offsets, make_array(default_nested_type), None); + list.to_data() } }; From f436fafc0770c95485dffede09b7eb578b4389c9 Mon Sep 17 00:00:00 2001 From: haixuanTao Date: Mon, 15 Jan 2024 22:15:41 +0100 Subject: [PATCH 02/14] Serialize by looping over the array This should fix: https://github.com/dora-rs/dora-autoware/discussions/9 --- .../extensions/ros2-bridge/python/src/lib.rs | 5 +- .../ros2-bridge/python/src/typed/serialize.rs | 189 ++++++++++++------ 2 files changed, 136 insertions(+), 58 deletions(-) diff --git a/libraries/extensions/ros2-bridge/python/src/lib.rs b/libraries/extensions/ros2-bridge/python/src/lib.rs index 37ac68a6b..6c8f79c2c 100644 --- a/libraries/extensions/ros2-bridge/python/src/lib.rs +++ b/libraries/extensions/ros2-bridge/python/src/lib.rs @@ -6,7 +6,7 @@ use std::{ use ::dora_ros2_bridge::{ros2_client, rustdds}; use arrow::{ - array::ArrayData, + array::{make_array, ArrayData}, pyarrow::{FromPyArrow, ToPyArrow}, }; use dora_ros2_bridge_msg_gen::types::Message; @@ -52,6 +52,7 @@ impl Ros2Context { ament_prefix_path_parsed.split(':').map(Path::new).collect() } }; + let packages = dora_ros2_bridge_msg_gen::get_packages(&paths) .map_err(|err| eyre!(err)) .context("failed to parse ROS2 message types")?; @@ -209,7 +210,7 @@ impl Ros2Publisher { //// add type info to ensure correct serialization (e.g. struct types //// and map types need to be serialized differently) let typed_value = TypedValue { - value: &value, + value: &make_array(value), type_info: &self.type_info, }; diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs index 15e8f622e..32129f45d 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs @@ -1,4 +1,7 @@ -use arrow::array::ArrayData; +use arrow::array::make_array; +use arrow::array::Array; +use arrow::array::ArrayRef; +use arrow::array::AsArray; use arrow::array::Float32Array; use arrow::array::Float64Array; use arrow::array::Int16Array; @@ -20,7 +23,7 @@ use super::TypeInfo; #[derive(Debug, Clone, PartialEq)] pub struct TypedValue<'a> { - pub value: &'a ArrayData, + pub value: &'a ArrayRef, pub type_info: &'a TypeInfo, } @@ -31,81 +34,154 @@ impl serde::Serialize for TypedValue<'_> { { match &self.type_info.data_type { DataType::UInt8 => { - let uint_array: UInt8Array = self.value.clone().into(); - let number = uint_array.value(0); - serializer.serialize_u8(number) + let uint_array: &UInt8Array = self.value.as_primitive(); + if uint_array.len() == 1 { + let number = uint_array.value(0); + serializer.serialize_u8(number) + } else { + let mut s = serializer.serialize_seq(Some(uint_array.len()))?; + for value in uint_array.iter() { + s.serialize_element(&value)?; + } + s.end() + } } DataType::UInt16 => { - let uint_array: UInt16Array = self.value.clone().into(); - let number = uint_array.value(0); - serializer.serialize_u16(number) + let uint_array: &UInt16Array = self.value.as_primitive(); + if uint_array.len() == 1 { + let number = uint_array.value(0); + serializer.serialize_u16(number) + } else { + let mut s = serializer.serialize_seq(Some(uint_array.len()))?; + for value in uint_array.iter() { + s.serialize_element(&value)?; + } + s.end() + } } DataType::UInt32 => { - let uint_array: UInt32Array = self.value.clone().into(); - let number = uint_array.value(0); - serializer.serialize_u32(number) + let array: &UInt32Array = self.value.as_primitive(); + if array.len() == 1 { + let number = array.value(0); + serializer.serialize_u32(number) + } else { + let mut s = serializer.serialize_seq(Some(array.len()))?; + for value in array.iter() { + s.serialize_element(&value)?; + } + s.end() + } } DataType::UInt64 => { - let uint_array: UInt64Array = self.value.clone().into(); - let number = uint_array.value(0); - serializer.serialize_u64(number) + let array: &UInt64Array = self.value.as_primitive(); + if array.len() == 1 { + let number = array.value(0); + serializer.serialize_u64(number) + } else { + let mut s = serializer.serialize_seq(Some(array.len()))?; + for value in array.iter() { + s.serialize_element(&value)?; + } + s.end() + } } DataType::Int8 => { - let int_array: Int8Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_i8(number) + let array: &Int8Array = self.value.as_primitive(); + if array.len() == 1 { + let number = array.value(0); + serializer.serialize_i8(number) + } else { + let mut s = serializer.serialize_seq(Some(array.len()))?; + for value in array.iter() { + s.serialize_element(&value)?; + } + s.end() + } } DataType::Int16 => { - let int_array: Int16Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_i16(number) + let array: &Int16Array = self.value.as_primitive(); + if array.len() == 1 { + let number = array.value(0); + serializer.serialize_i16(number) + } else { + let mut s = serializer.serialize_seq(Some(array.len()))?; + for value in array.iter() { + s.serialize_element(&value)?; + } + s.end() + } } DataType::Int32 => { - let int_array: Int32Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_i32(number) + let array: &Int32Array = self.value.as_primitive(); + if array.len() == 1 { + let number = array.value(0); + serializer.serialize_i32(number) + } else { + let mut s = serializer.serialize_seq(Some(array.len()))?; + for value in array.iter() { + s.serialize_element(&value)?; + } + s.end() + } } DataType::Int64 => { - let int_array: Int64Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_i64(number) + let array: &Int64Array = self.value.as_primitive(); + if array.len() == 1 { + let number = array.value(0); + serializer.serialize_i64(number) + } else { + let mut s = serializer.serialize_seq(Some(array.len()))?; + for value in array.iter() { + s.serialize_element(&value)?; + } + s.end() + } } DataType::Float32 => { - let int_array: Float32Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_f32(number) + let array: &Float32Array = self.value.as_primitive(); + if array.len() == 1 { + let number = array.value(0); + serializer.serialize_f32(number) + } else { + let mut s = serializer.serialize_seq(Some(array.len()))?; + for value in array.iter() { + s.serialize_element(&value)?; + } + s.end() + } } DataType::Float64 => { - let int_array: Float64Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_f64(number) + let array: &Float64Array = self.value.as_primitive(); + if array.len() == 1 { + let number = array.value(0); + serializer.serialize_f64(number) + } else { + let mut s = serializer.serialize_seq(Some(array.len()))?; + for value in array.iter() { + s.serialize_element(&value)?; + } + s.end() + } } DataType::Utf8 => { - let int_array: StringArray = self.value.clone().into(); + let int_array: &StringArray = self.value.as_string(); let string = int_array.value(0); serializer.serialize_str(string) } DataType::List(field) => { - let list_array: ListArray = self.value.clone().into(); - let values = list_array.values(); - let mut s = serializer.serialize_seq(Some(values.len()))?; + let list_array: &ListArray = self.value.as_list(); + // let values: &UInt16Array = list_array.values().as_primitive(); + let mut s = serializer.serialize_seq(Some(list_array.len()))?; for value in list_array.iter() { - let value = match value { - Some(value) => value.to_data(), - None => { - return Err(serde::ser::Error::custom( - "Value in ListArray is null and not yet supported".to_string(), - )) - } - }; - - s.serialize_element(&TypedValue { - value: &value, - type_info: &TypeInfo { - data_type: field.data_type().clone(), - defaults: self.type_info.defaults.clone(), - }, - })?; + if let Some(value) = value { + s.serialize_element(&TypedValue { + value: &value, + type_info: &TypeInfo { + data_type: field.data_type().clone(), + defaults: self.type_info.defaults.clone(), + }, + })?; + } } s.end() } @@ -121,7 +197,7 @@ impl serde::Serialize for TypedValue<'_> { const DUMMY_STRUCT_NAME: &str = "struct"; const DUMMY_FIELD_NAME: &str = "field"; - let struct_array: StructArray = self.value.clone().into(); + let struct_array: &StructArray = self.value.as_struct(); let mut s = serializer.serialize_struct(DUMMY_STRUCT_NAME, fields.len())?; let defaults: StructArray = self.type_info.defaults.clone().into(); for field in fields.iter() { @@ -134,15 +210,16 @@ impl serde::Serialize for TypedValue<'_> { ))) } }; + let value = make_array(default.clone()); let field_value = match struct_array.column_by_name(field.name()) { - Some(value) => value.to_data(), - None => default.clone(), + Some(value) => value, + None => &value, }; s.serialize_field( DUMMY_FIELD_NAME, &TypedValue { - value: &field_value, + value: &field_value.clone(), type_info: &TypeInfo { data_type: field.data_type().clone(), defaults: default, From 4085ebc46be527b893c54b0e591a0761e0882760 Mon Sep 17 00:00:00 2001 From: haixuanTao Date: Fri, 19 Jan 2024 19:16:25 +0100 Subject: [PATCH 03/14] Make PrimitiveArray only serialize one element at a time --- binaries/runtime/src/operator/mod.rs | 1 + .../ros2-bridge/python/src/typed/serialize.rs | 273 ++++++++++-------- 2 files changed, 161 insertions(+), 113 deletions(-) diff --git a/binaries/runtime/src/operator/mod.rs b/binaries/runtime/src/operator/mod.rs index 474a54c57..b78fb939b 100644 --- a/binaries/runtime/src/operator/mod.rs +++ b/binaries/runtime/src/operator/mod.rs @@ -13,6 +13,7 @@ pub mod channel; mod python; mod shared_lib; +#[allow(unused_variables)] pub fn run_operator( node_id: &NodeId, operator_definition: OperatorDefinition, diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs index 32129f45d..5219489da 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs @@ -34,153 +34,200 @@ impl serde::Serialize for TypedValue<'_> { { match &self.type_info.data_type { DataType::UInt8 => { - let uint_array: &UInt8Array = self.value.as_primitive(); - if uint_array.len() == 1 { - let number = uint_array.value(0); - serializer.serialize_u8(number) - } else { - let mut s = serializer.serialize_seq(Some(uint_array.len()))?; - for value in uint_array.iter() { - s.serialize_element(&value)?; - } - s.end() - } + let array: &UInt8Array = self.value.as_primitive(); + debug_assert!(array.len() == 1, "array length was: {}", array.len()); + let number = array.value(0); + serializer.serialize_u8(number) } DataType::UInt16 => { - let uint_array: &UInt16Array = self.value.as_primitive(); - if uint_array.len() == 1 { - let number = uint_array.value(0); - serializer.serialize_u16(number) - } else { - let mut s = serializer.serialize_seq(Some(uint_array.len()))?; - for value in uint_array.iter() { - s.serialize_element(&value)?; - } - s.end() - } + let array: &UInt16Array = self.value.as_primitive(); + debug_assert!(array.len() == 1); + let number = array.value(0); + serializer.serialize_u16(number) } DataType::UInt32 => { let array: &UInt32Array = self.value.as_primitive(); - if array.len() == 1 { - let number = array.value(0); - serializer.serialize_u32(number) - } else { - let mut s = serializer.serialize_seq(Some(array.len()))?; - for value in array.iter() { - s.serialize_element(&value)?; - } - s.end() - } + debug_assert!(array.len() == 1); + let number = array.value(0); + serializer.serialize_u32(number) } DataType::UInt64 => { let array: &UInt64Array = self.value.as_primitive(); - if array.len() == 1 { - let number = array.value(0); - serializer.serialize_u64(number) - } else { - let mut s = serializer.serialize_seq(Some(array.len()))?; - for value in array.iter() { - s.serialize_element(&value)?; - } - s.end() - } + debug_assert!(array.len() == 1); + let number = array.value(0); + serializer.serialize_u64(number) } DataType::Int8 => { let array: &Int8Array = self.value.as_primitive(); - if array.len() == 1 { - let number = array.value(0); - serializer.serialize_i8(number) - } else { - let mut s = serializer.serialize_seq(Some(array.len()))?; - for value in array.iter() { - s.serialize_element(&value)?; - } - s.end() - } + debug_assert!(array.len() == 1); + let number = array.value(0); + serializer.serialize_i8(number) } DataType::Int16 => { let array: &Int16Array = self.value.as_primitive(); - if array.len() == 1 { - let number = array.value(0); - serializer.serialize_i16(number) - } else { - let mut s = serializer.serialize_seq(Some(array.len()))?; - for value in array.iter() { - s.serialize_element(&value)?; - } - s.end() - } + debug_assert!(array.len() == 1); + let number = array.value(0); + serializer.serialize_i16(number) } DataType::Int32 => { let array: &Int32Array = self.value.as_primitive(); - if array.len() == 1 { - let number = array.value(0); - serializer.serialize_i32(number) - } else { - let mut s = serializer.serialize_seq(Some(array.len()))?; - for value in array.iter() { - s.serialize_element(&value)?; - } - s.end() - } + debug_assert!(array.len() == 1); + let number = array.value(0); + serializer.serialize_i32(number) } DataType::Int64 => { let array: &Int64Array = self.value.as_primitive(); - if array.len() == 1 { - let number = array.value(0); - serializer.serialize_i64(number) - } else { - let mut s = serializer.serialize_seq(Some(array.len()))?; - for value in array.iter() { - s.serialize_element(&value)?; - } - s.end() - } + debug_assert!(array.len() == 1, "array was: {:#?}", array); + let number = array.value(0); + serializer.serialize_i64(number) } DataType::Float32 => { let array: &Float32Array = self.value.as_primitive(); - if array.len() == 1 { - let number = array.value(0); - serializer.serialize_f32(number) - } else { - let mut s = serializer.serialize_seq(Some(array.len()))?; - for value in array.iter() { - s.serialize_element(&value)?; - } - s.end() - } + debug_assert!(array.len() == 1); + let number = array.value(0); + serializer.serialize_f32(number) } DataType::Float64 => { let array: &Float64Array = self.value.as_primitive(); - if array.len() == 1 { - let number = array.value(0); - serializer.serialize_f64(number) - } else { - let mut s = serializer.serialize_seq(Some(array.len()))?; - for value in array.iter() { - s.serialize_element(&value)?; - } - s.end() - } + debug_assert!(array.len() == 1); + let number = array.value(0); + serializer.serialize_f64(number) } DataType::Utf8 => { let int_array: &StringArray = self.value.as_string(); let string = int_array.value(0); serializer.serialize_str(string) } - DataType::List(field) => { + DataType::List(_field) => { let list_array: &ListArray = self.value.as_list(); - // let values: &UInt16Array = list_array.values().as_primitive(); let mut s = serializer.serialize_seq(Some(list_array.len()))?; - for value in list_array.iter() { - if let Some(value) = value { - s.serialize_element(&TypedValue { - value: &value, - type_info: &TypeInfo { - data_type: field.data_type().clone(), - defaults: self.type_info.defaults.clone(), - }, - })?; + for root in list_array.iter() { + if let Some(values) = root { + match values.data_type() { + DataType::UInt8 => { + let values: &UInt8Array = values.as_primitive(); + for value in values.iter() { + if let Some(value) = value { + s.serialize_element(&value)?; + } else { + todo!("Implement null management"); + } + } + } + DataType::UInt16 => { + let values: &UInt16Array = values.as_primitive(); + for value in values.iter() { + if let Some(value) = value { + s.serialize_element(&value)?; + } else { + todo!("Implement null management"); + } + } + } + DataType::UInt32 => { + let values: &UInt32Array = values.as_primitive(); + for value in values.iter() { + if let Some(value) = value { + s.serialize_element(&value)?; + } else { + todo!("Implement null management"); + } + } + } + DataType::UInt64 => { + let values: &UInt64Array = values.as_primitive(); + for value in values.iter() { + if let Some(value) = value { + s.serialize_element(&value)?; + } else { + todo!("Implement null management"); + } + } + } + DataType::Int8 => { + let values: &Int8Array = values.as_primitive(); + for value in values.iter() { + if let Some(value) = value { + s.serialize_element(&value)?; + } else { + todo!("Implement null management"); + } + } + } + DataType::Int16 => { + let values: &Int16Array = values.as_primitive(); + for value in values.iter() { + if let Some(value) = value { + s.serialize_element(&value)?; + } else { + todo!("Implement null management"); + } + } + } + DataType::Int32 => { + let values: &Int32Array = values.as_primitive(); + for value in values.iter() { + if let Some(value) = value { + s.serialize_element(&value)?; + } else { + todo!("Implement null management"); + } + } + } + DataType::Int64 => { + let values: &Int64Array = values.as_primitive(); + for value in values.iter() { + if let Some(value) = value { + s.serialize_element(&value)?; + } else { + todo!("Implement null management"); + } + } + } + DataType::Float32 => { + let values: &Float32Array = values.as_primitive(); + for value in values.iter() { + if let Some(value) = value { + s.serialize_element(&value)?; + } else { + todo!("Implement null management"); + } + } + } + DataType::Float64 => { + let values: &Float64Array = values.as_primitive(); + for value in values.iter() { + if let Some(value) = value { + s.serialize_element(&value)?; + } else { + todo!("Implement null management"); + } + } + } + DataType::Utf8 => { + let values: &StringArray = values.as_string(); + for value in values.iter() { + if let Some(value) = value { + s.serialize_element(&value)?; + } else { + todo!("Implement null management"); + } + } + } + DataType::Struct(_fields) => { + let list_array: ListArray = self.type_info.defaults.clone().into(); + s.serialize_element(&TypedValue { + value: &values, + type_info: &TypeInfo { + data_type: values.data_type().clone(), + defaults: list_array.value(0).to_data(), + }, + })?; + } + op => todo!("Implement additional type: {:?}", op), + } + } else { + todo!("Implement null management"); } } s.end() From 13273824e1fe5a27c98ab1a891a687ce7faba2af Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Sat, 20 Jan 2024 16:44:03 +0100 Subject: [PATCH 04/14] Rework python ROS2 (de)serialization using parsed ROS2 messages directly Use ROS2 types as type info to ensure that serialization type matches exactly. This is necessary because the types need to match exactly, otherwise the serialization or deserialization will result in invalid data. Using arrow schemas to specify the ROS2 message types does not work because not all ROS2 message types can be represented. For example, ROS2 serializes arrays with known length differently than sequences with dynamic length. --- .../extensions/ros2-bridge/python/src/lib.rs | 22 +- .../python/src/typed/deserialize.rs | 397 ------------------ .../python/src/typed/deserialize/array.rs | 17 + .../python/src/typed/deserialize/mod.rs | 292 +++++++++++++ .../python/src/typed/deserialize/primitive.rs | 155 +++++++ .../python/src/typed/deserialize/sequence.rs | 91 ++++ .../python/src/typed/deserialize/string.rs | 44 ++ .../ros2-bridge/python/src/typed/mod.rs | 294 +------------ .../ros2-bridge/python/src/typed/serialize.rs | 282 ------------- .../python/src/typed/serialize/array.rs | 186 ++++++++ .../python/src/typed/serialize/defaults.rs | 237 +++++++++++ .../python/src/typed/serialize/mod.rs | 181 ++++++++ .../python/src/typed/serialize/primitive.rs | 79 ++++ .../python/src/typed/serialize/sequence.rs | 152 +++++++ 14 files changed, 1463 insertions(+), 966 deletions(-) delete mode 100644 libraries/extensions/ros2-bridge/python/src/typed/deserialize.rs create mode 100644 libraries/extensions/ros2-bridge/python/src/typed/deserialize/array.rs create mode 100644 libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs create mode 100644 libraries/extensions/ros2-bridge/python/src/typed/deserialize/primitive.rs create mode 100644 libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs create mode 100644 libraries/extensions/ros2-bridge/python/src/typed/deserialize/string.rs delete mode 100644 libraries/extensions/ros2-bridge/python/src/typed/serialize.rs create mode 100644 libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs create mode 100644 libraries/extensions/ros2-bridge/python/src/typed/serialize/defaults.rs create mode 100644 libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs create mode 100644 libraries/extensions/ros2-bridge/python/src/typed/serialize/primitive.rs create mode 100644 libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs diff --git a/libraries/extensions/ros2-bridge/python/src/lib.rs b/libraries/extensions/ros2-bridge/python/src/lib.rs index 6c8f79c2c..9d06bbaf5 100644 --- a/libraries/extensions/ros2-bridge/python/src/lib.rs +++ b/libraries/extensions/ros2-bridge/python/src/lib.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, collections::HashMap, path::{Path, PathBuf}, sync::Arc, @@ -17,7 +18,7 @@ use pyo3::{ types::{PyDict, PyList, PyModule}, PyAny, PyObject, PyResult, Python, }; -use typed::{deserialize::TypedDeserializer, for_message, TypeInfo, TypedValue}; +use typed::{deserialize::StructDeserializer, TypeInfo, TypedValue}; pub mod qos; pub mod typed; @@ -112,10 +113,11 @@ impl Ros2Node { let topic = self .node .create_topic(&topic_name, message_type_name, &qos.into())?; - let type_info = - for_message(&self.messages, namespace_name, message_name).with_context(|| { - format!("failed to determine type info for message {namespace_name}/{message_name}") - })?; + let type_info = TypeInfo { + package_name: namespace_name.to_owned().into(), + message_name: message_name.to_owned().into(), + messages: self.messages.clone(), + }; Ok(Ros2Topic { topic, type_info }) } @@ -144,7 +146,7 @@ impl Ros2Node { .create_subscription(&topic.topic, qos.map(Into::into))?; Ok(Ros2Subscription { subscription: Some(subscription), - deserializer: TypedDeserializer::new(topic.type_info.clone()), + deserializer: StructDeserializer::new(Cow::Owned(topic.type_info.clone())), }) } } @@ -176,14 +178,14 @@ impl From for ros2_client::NodeOptions { #[non_exhaustive] pub struct Ros2Topic { topic: rustdds::Topic, - type_info: TypeInfo, + type_info: TypeInfo<'static>, } #[pyclass] #[non_exhaustive] pub struct Ros2Publisher { publisher: ros2_client::Publisher>, - type_info: TypeInfo, + type_info: TypeInfo<'static>, } #[pymethods] @@ -225,7 +227,7 @@ impl Ros2Publisher { #[pyclass] #[non_exhaustive] pub struct Ros2Subscription { - deserializer: TypedDeserializer, + deserializer: StructDeserializer<'static>, subscription: Option>, } @@ -264,7 +266,7 @@ impl Ros2Subscription { } pub struct Ros2SubscriptionStream { - deserializer: TypedDeserializer, + deserializer: StructDeserializer<'static>, subscription: ros2_client::Subscription, } diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize.rs deleted file mode 100644 index 30ba81fd7..000000000 --- a/libraries/extensions/ros2-bridge/python/src/typed/deserialize.rs +++ /dev/null @@ -1,397 +0,0 @@ -use super::TypeInfo; -use arrow::{ - array::{ - make_array, Array, ArrayData, BooleanBuilder, Float32Builder, Float64Builder, Int16Builder, - Int32Builder, Int64Builder, Int8Builder, ListArray, NullArray, StringBuilder, StructArray, - UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, - }, - buffer::OffsetBuffer, - compute::concat, - datatypes::{DataType, Field, Fields}, -}; -use core::fmt; -use std::sync::Arc; -#[derive(Debug, Clone, PartialEq)] -pub struct TypedDeserializer { - type_info: TypeInfo, -} - -impl TypedDeserializer { - pub fn new(type_info: TypeInfo) -> Self { - Self { type_info } - } -} - -impl<'de> serde::de::DeserializeSeed<'de> for TypedDeserializer { - type Value = ArrayData; - - fn deserialize(self, deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let data_type = self.type_info.data_type; - let value = match data_type.clone() { - DataType::Struct(fields) => { - /// Serde requires that struct and field names are known at - /// compile time with a `'static` lifetime, which is not - /// possible in this case. Thus, we need to use dummy names - /// instead. - /// - /// The actual names do not really matter because - /// the CDR format of ROS2 does not encode struct or field - /// names. - const DUMMY_STRUCT_NAME: &str = "struct"; - const DUMMY_FIELDS: &[&str] = &[""; 100]; - - deserializer.deserialize_struct( - DUMMY_STRUCT_NAME, - &DUMMY_FIELDS[..fields.len()], - StructVisitor { - fields, - defaults: self.type_info.defaults, - }, - ) - } - DataType::List(field) => deserializer.deserialize_seq(ListVisitor { - field, - defaults: self.type_info.defaults, - }), - DataType::UInt8 => deserializer.deserialize_u8(PrimitiveValueVisitor), - DataType::UInt16 => deserializer.deserialize_u16(PrimitiveValueVisitor), - DataType::UInt32 => deserializer.deserialize_u32(PrimitiveValueVisitor), - DataType::UInt64 => deserializer.deserialize_u64(PrimitiveValueVisitor), - DataType::Int8 => deserializer.deserialize_i8(PrimitiveValueVisitor), - DataType::Int16 => deserializer.deserialize_i16(PrimitiveValueVisitor), - DataType::Int32 => deserializer.deserialize_i32(PrimitiveValueVisitor), - DataType::Int64 => deserializer.deserialize_i64(PrimitiveValueVisitor), - DataType::Float32 => deserializer.deserialize_f32(PrimitiveValueVisitor), - DataType::Float64 => deserializer.deserialize_f64(PrimitiveValueVisitor), - DataType::Utf8 => deserializer.deserialize_str(PrimitiveValueVisitor), - _ => todo!(), - }?; - - debug_assert!( - value.data_type() == &data_type, - "Datatype does not correspond to default data type.\n Expected: {:#?} \n but got: {:#?}, with value: {:#?}", data_type, value.data_type(), value - ); - - Ok(value) - } -} - -/// Based on https://docs.rs/serde_yaml/0.9.22/src/serde_yaml/value/de.rs.html#14-121 -struct PrimitiveValueVisitor; - -impl<'de> serde::de::Visitor<'de> for PrimitiveValueVisitor { - type Value = ArrayData; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a primitive value") - } - - fn visit_bool(self, b: bool) -> Result - where - E: serde::de::Error, - { - let mut array = BooleanBuilder::new(); - array.append_value(b); - Ok(array.finish().into()) - } - - fn visit_i8(self, u: i8) -> Result - where - E: serde::de::Error, - { - let mut array = Int8Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - - fn visit_i16(self, u: i16) -> Result - where - E: serde::de::Error, - { - let mut array = Int16Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - fn visit_i32(self, u: i32) -> Result - where - E: serde::de::Error, - { - let mut array = Int32Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - fn visit_i64(self, i: i64) -> Result - where - E: serde::de::Error, - { - let mut array = Int64Builder::new(); - array.append_value(i); - Ok(array.finish().into()) - } - - fn visit_u8(self, u: u8) -> Result - where - E: serde::de::Error, - { - let mut array = UInt8Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - fn visit_u16(self, u: u16) -> Result - where - E: serde::de::Error, - { - let mut array = UInt16Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - fn visit_u32(self, u: u32) -> Result - where - E: serde::de::Error, - { - let mut array = UInt32Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - fn visit_u64(self, u: u64) -> Result - where - E: serde::de::Error, - { - let mut array = UInt64Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - - fn visit_f32(self, f: f32) -> Result - where - E: serde::de::Error, - { - let mut array = Float32Builder::new(); - array.append_value(f); - Ok(array.finish().into()) - } - - fn visit_f64(self, f: f64) -> Result - where - E: serde::de::Error, - { - let mut array = Float64Builder::new(); - array.append_value(f); - Ok(array.finish().into()) - } - - fn visit_str(self, s: &str) -> Result - where - E: serde::de::Error, - { - let mut array = StringBuilder::new(); - array.append_value(s); - Ok(array.finish().into()) - } - - fn visit_string(self, s: String) -> Result - where - E: serde::de::Error, - { - let mut array = StringBuilder::new(); - array.append_value(s); - Ok(array.finish().into()) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - let array = NullArray::new(0); - Ok(array.into()) - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - let array = NullArray::new(0); - Ok(array.into()) - } -} - -struct StructVisitor { - fields: Fields, - defaults: ArrayData, -} - -impl<'de> serde::de::Visitor<'de> for StructVisitor { - type Value = ArrayData; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a struct encoded as sequence") - } - - fn visit_seq(self, mut data: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut fields = vec![]; - let defaults: StructArray = self.defaults.clone().into(); - for field in self.fields.iter() { - let default = match defaults.column_by_name(field.name()) { - Some(value) => value.clone(), - None => { - return Err(serde::de::Error::custom(format!( - "missing field {} for deserialization", - &field.name() - ))) - } - }; - let value = match data.next_element_seed(TypedDeserializer { - type_info: TypeInfo { - data_type: field.data_type().clone(), - defaults: default.to_data(), - }, - })? { - Some(value) => make_array(value), - None => default, - }; - fields.push(( - // Recreate a new field as List(UInt8) can be converted to UInt8 - Arc::new(Field::new(field.name(), value.data_type().clone(), true)), - value, - )); - } - - let struct_array: StructArray = fields.into(); - - Ok(struct_array.into()) - } -} - -struct ListVisitor { - field: Arc, - defaults: ArrayData, -} - -impl<'de> serde::de::Visitor<'de> for ListVisitor { - type Value = ArrayData; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("an array encoded as sequence") - } - - fn visit_seq(self, mut data: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let data = match self.field.data_type().clone() { - DataType::UInt8 => { - let mut array = UInt8Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::UInt16 => { - let mut array = UInt16Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::UInt32 => { - let mut array = UInt32Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::UInt64 => { - let mut array = UInt64Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Int8 => { - let mut array = Int8Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Int16 => { - let mut array = Int16Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Int32 => { - let mut array = Int32Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Int64 => { - let mut array = Int64Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Float32 => { - let mut array = Float32Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Float64 => { - let mut array = Float64Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Utf8 => { - let mut array = StringBuilder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - _ => { - let mut buffer = vec![]; - while let Some(value) = data.next_element_seed(TypedDeserializer { - type_info: TypeInfo { - data_type: self.field.data_type().clone(), - defaults: self.defaults.clone(), - }, - })? { - let element = make_array(value); - buffer.push(element); - } - - concat( - buffer - .iter() - .map(|data| data.as_ref()) - .collect::>() - .as_slice(), - ) - .map(|op| op.to_data()) - } - }; - - if let Ok(values) = data { - let offsets = OffsetBuffer::new(vec![0, values.len() as i32].into()); - - let array = ListArray::new(self.field, offsets, make_array(values), None).to_data(); - Ok(array) - } else { - Ok(self.defaults) // TODO: Better handle deserialization error - } - } -} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/array.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/array.rs new file mode 100644 index 000000000..79030abe7 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/array.rs @@ -0,0 +1,17 @@ +use arrow::array::ArrayData; +use dora_ros2_bridge_msg_gen::types::sequences; + +use super::sequence::SequenceVisitor; + +pub struct ArrayDeserializer<'a>(pub &'a sequences::Array); + +impl<'de> serde::de::DeserializeSeed<'de> for ArrayDeserializer<'_> { + type Value = ArrayData; + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_tuple(self.0.size, SequenceVisitor(&self.0.value_type)) + } +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs new file mode 100644 index 000000000..93a94eee1 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs @@ -0,0 +1,292 @@ +use super::{TypeInfo, DUMMY_STRUCT_NAME}; +use arrow::{ + array::{make_array, ArrayData, StructArray}, + datatypes::Field, +}; +use core::fmt; +use std::{borrow::Cow, collections::HashMap, fmt::Display, sync::Arc}; + +mod array; +mod primitive; +mod sequence; +mod string; + +#[derive(Debug, Clone)] +pub struct StructDeserializer<'a> { + type_info: Cow<'a, TypeInfo<'a>>, +} + +impl<'a> StructDeserializer<'a> { + pub fn new(type_info: Cow<'a, TypeInfo<'a>>) -> Self { + Self { type_info } + } +} + +impl<'de> serde::de::DeserializeSeed<'de> for StructDeserializer<'_> { + type Value = ArrayData; + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let empty = HashMap::new(); + let package_messages = self + .type_info + .messages + .get(self.type_info.package_name.as_ref()) + .unwrap_or(&empty); + let message = package_messages + .get(self.type_info.message_name.as_ref()) + .ok_or_else(|| { + error(format!( + "could not find message type {}::{}", + self.type_info.package_name, self.type_info.message_name + )) + })?; + + let visitor = StructVisitor { + type_info: self.type_info.as_ref(), + }; + deserializer.deserialize_tuple_struct(DUMMY_STRUCT_NAME, message.members.len(), visitor) + } +} + +struct StructVisitor<'a> { + type_info: &'a TypeInfo<'a>, +} + +impl<'a, 'de> serde::de::Visitor<'de> for StructVisitor<'a> { + type Value = ArrayData; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a struct encoded as sequence") + } + + fn visit_seq(self, mut data: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let empty = HashMap::new(); + let package_messages = self + .type_info + .messages + .get(self.type_info.package_name.as_ref()) + .unwrap_or(&empty); + let message = package_messages + .get(self.type_info.message_name.as_ref()) + .ok_or_else(|| { + error(format!( + "could not find message type {}::{}", + self.type_info.package_name, self.type_info.message_name + )) + })?; + + let mut fields = vec![]; + for member in &message.members { + let value = match &member.r#type { + dora_ros2_bridge_msg_gen::types::MemberType::NestableType(t) => match t { + dora_ros2_bridge_msg_gen::types::primitives::NestableType::BasicType(t) => { + data.next_element_seed(primitive::PrimitiveDeserializer(t))? + } + dora_ros2_bridge_msg_gen::types::primitives::NestableType::NamedType(name) => { + data.next_element_seed(StructDeserializer { + type_info: Cow::Owned(TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + }), + })? + } + dora_ros2_bridge_msg_gen::types::primitives::NestableType::NamespacedType( + reference, + ) => { + if reference.namespace != "msg" { + return Err(error(format!( + "struct field {} references non-message type {reference:?}", + member.name + ))); + } + data.next_element_seed(StructDeserializer { + type_info: Cow::Owned(TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + }), + })? + } + dora_ros2_bridge_msg_gen::types::primitives::NestableType::GenericString(t) => { + match t { + dora_ros2_bridge_msg_gen::types::primitives::GenericString::String | dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedString(_)=> { + data.next_element_seed(string::StringDeserializer)? + }, + dora_ros2_bridge_msg_gen::types::primitives::GenericString::WString => todo!("deserialize WString"), + dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedWString(_) => todo!("deserialize BoundedWString"), + } + } + }, + dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => { + data.next_element_seed(array::ArrayDeserializer(a))? + }, + dora_ros2_bridge_msg_gen::types::MemberType::Sequence(s) => { + data.next_element_seed(sequence::SequenceDeserializer(&s.value_type))? + }, + dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(s) => { + data.next_element_seed(sequence::SequenceDeserializer(&s.value_type))? + }, + }; + + let value = value.ok_or_else(|| { + error(format!( + "struct member {} not present in message", + member.name + )) + })?; + + fields.push(( + // Recreate a new field as List(UInt8) can be converted to UInt8 + Arc::new(Field::new(&member.name, value.data_type().clone(), true)), + make_array(value), + )); + } + + let struct_array: StructArray = fields.into(); + + Ok(struct_array.into()) + } +} + +// struct ListVisitor { +// field: Arc, +// defaults: ArrayData, +// } + +// impl<'de> serde::de::Visitor<'de> for ListVisitor { +// type Value = ArrayData; + +// fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { +// formatter.write_str("an array encoded as sequence") +// } + +// fn visit_seq(self, mut data: A) -> Result +// where +// A: serde::de::SeqAccess<'de>, +// { +// let data = match self.field.data_type().clone() { +// DataType::UInt8 => { +// let mut array = UInt8Builder::new(); +// while let Some(value) = data.next_element::()? { +// array.append_value(value); +// } +// Ok(array.finish().into()) +// } +// DataType::UInt16 => { +// let mut array = UInt16Builder::new(); +// while let Some(value) = data.next_element::()? { +// array.append_value(value); +// } +// Ok(array.finish().into()) +// } +// DataType::UInt32 => { +// let mut array = UInt32Builder::new(); +// while let Some(value) = data.next_element::()? { +// array.append_value(value); +// } +// Ok(array.finish().into()) +// } +// DataType::UInt64 => { +// let mut array = UInt64Builder::new(); +// while let Some(value) = data.next_element::()? { +// array.append_value(value); +// } +// Ok(array.finish().into()) +// } +// DataType::Int8 => { +// let mut array = Int8Builder::new(); +// while let Some(value) = data.next_element::()? { +// array.append_value(value); +// } +// Ok(array.finish().into()) +// } +// DataType::Int16 => { +// let mut array = Int16Builder::new(); +// while let Some(value) = data.next_element::()? { +// array.append_value(value); +// } +// Ok(array.finish().into()) +// } +// DataType::Int32 => { +// let mut array = Int32Builder::new(); +// while let Some(value) = data.next_element::()? { +// array.append_value(value); +// } +// Ok(array.finish().into()) +// } +// DataType::Int64 => { +// let mut array = Int64Builder::new(); +// while let Some(value) = data.next_element::()? { +// array.append_value(value); +// } +// Ok(array.finish().into()) +// } +// DataType::Float32 => { +// let mut array = Float32Builder::new(); +// while let Some(value) = data.next_element::()? { +// array.append_value(value); +// } +// Ok(array.finish().into()) +// } +// DataType::Float64 => { +// let mut array = Float64Builder::new(); +// while let Some(value) = data.next_element::()? { +// array.append_value(value); +// } +// Ok(array.finish().into()) +// } +// DataType::Utf8 => { +// let mut array = StringBuilder::new(); +// while let Some(value) = data.next_element::()? { +// array.append_value(value); +// } +// Ok(array.finish().into()) +// } +// _ => { +// let mut buffer = vec![]; +// while let Some(value) = data.next_element_seed(StructDeserializer { +// type_info: TypeInfo { +// data_type: self.field.data_type().clone(), +// defaults: self.defaults.clone(), +// }, +// })? { +// let element = make_array(value); +// buffer.push(element); +// } + +// concat( +// buffer +// .iter() +// .map(|data| data.as_ref()) +// .collect::>() +// .as_slice(), +// ) +// .map(|op| op.to_data()) +// } +// }; + +// if let Ok(values) = data { +// let offsets = OffsetBuffer::new(vec![0, values.len() as i32].into()); + +// let array = ListArray::new(self.field, offsets, make_array(values), None).to_data(); +// Ok(array) +// } else { +// Ok(self.defaults) // TODO: Better handle deserialization error +// } +// } +// } + +fn error(e: T) -> E +where + T: Display, + E: serde::de::Error, +{ + serde::de::Error::custom(e) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/primitive.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/primitive.rs new file mode 100644 index 000000000..7f13b575f --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/primitive.rs @@ -0,0 +1,155 @@ +use arrow::array::{ + ArrayData, BooleanBuilder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, + Int64Builder, Int8Builder, NullArray, UInt16Builder, UInt32Builder, UInt64Builder, + UInt8Builder, +}; +use core::fmt; +use dora_ros2_bridge_msg_gen::types::primitives::BasicType; + +pub struct PrimitiveDeserializer<'a>(pub &'a BasicType); + +impl<'de> serde::de::DeserializeSeed<'de> for PrimitiveDeserializer<'_> { + type Value = ArrayData; + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + match self.0 { + BasicType::I8 => deserializer.deserialize_i8(PrimitiveValueVisitor), + BasicType::I16 => deserializer.deserialize_i16(PrimitiveValueVisitor), + BasicType::I32 => deserializer.deserialize_i32(PrimitiveValueVisitor), + BasicType::I64 => deserializer.deserialize_i64(PrimitiveValueVisitor), + BasicType::U8 | BasicType::Char | BasicType::Byte => { + deserializer.deserialize_u8(PrimitiveValueVisitor) + } + BasicType::U16 => deserializer.deserialize_u16(PrimitiveValueVisitor), + BasicType::U32 => deserializer.deserialize_u32(PrimitiveValueVisitor), + BasicType::U64 => deserializer.deserialize_u64(PrimitiveValueVisitor), + BasicType::F32 => deserializer.deserialize_f32(PrimitiveValueVisitor), + BasicType::F64 => deserializer.deserialize_f64(PrimitiveValueVisitor), + BasicType::Bool => deserializer.deserialize_bool(PrimitiveValueVisitor), + } + } +} + +/// Based on https://docs.rs/serde_yaml/0.9.22/src/serde_yaml/value/de.rs.html#14-121 +struct PrimitiveValueVisitor; + +impl<'de> serde::de::Visitor<'de> for PrimitiveValueVisitor { + type Value = ArrayData; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a primitive value") + } + + fn visit_bool(self, b: bool) -> Result + where + E: serde::de::Error, + { + let mut array = BooleanBuilder::new(); + array.append_value(b); + Ok(array.finish().into()) + } + + fn visit_i8(self, u: i8) -> Result + where + E: serde::de::Error, + { + let mut array = Int8Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + + fn visit_i16(self, u: i16) -> Result + where + E: serde::de::Error, + { + let mut array = Int16Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + fn visit_i32(self, u: i32) -> Result + where + E: serde::de::Error, + { + let mut array = Int32Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + fn visit_i64(self, i: i64) -> Result + where + E: serde::de::Error, + { + let mut array = Int64Builder::new(); + array.append_value(i); + Ok(array.finish().into()) + } + + fn visit_u8(self, u: u8) -> Result + where + E: serde::de::Error, + { + let mut array = UInt8Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + fn visit_u16(self, u: u16) -> Result + where + E: serde::de::Error, + { + let mut array = UInt16Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + fn visit_u32(self, u: u32) -> Result + where + E: serde::de::Error, + { + let mut array = UInt32Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + fn visit_u64(self, u: u64) -> Result + where + E: serde::de::Error, + { + let mut array = UInt64Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + + fn visit_f32(self, f: f32) -> Result + where + E: serde::de::Error, + { + let mut array = Float32Builder::new(); + array.append_value(f); + Ok(array.finish().into()) + } + + fn visit_f64(self, f: f64) -> Result + where + E: serde::de::Error, + { + let mut array = Float64Builder::new(); + array.append_value(f); + Ok(array.finish().into()) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + let array = NullArray::new(0); + Ok(array.into()) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + let array = NullArray::new(0); + Ok(array.into()) + } +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs new file mode 100644 index 000000000..473950129 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs @@ -0,0 +1,91 @@ +use arrow::{ + array::{ArrayData, BooleanBuilder, PrimitiveBuilder, StringBuilder}, + datatypes::{self, ArrowPrimitiveType}, +}; +use core::fmt; +use dora_ros2_bridge_msg_gen::types::primitives::{self, BasicType, NestableType}; +use serde::Deserialize; + +pub struct SequenceDeserializer<'a>(pub &'a NestableType); + +impl<'de> serde::de::DeserializeSeed<'de> for SequenceDeserializer<'_> { + type Value = ArrayData; + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_seq(SequenceVisitor(self.0)) + } +} + +pub struct SequenceVisitor<'a>(pub &'a NestableType); + +impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> { + type Value = ArrayData; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a sequence") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + match &self.0 { + NestableType::BasicType(t) => match t { + BasicType::I8 => deserialize_primitive_seq::<_, datatypes::Int8Type>(seq), + BasicType::I16 => deserialize_primitive_seq::<_, datatypes::Int16Type>(seq), + BasicType::I32 => deserialize_primitive_seq::<_, datatypes::Int32Type>(seq), + BasicType::I64 => deserialize_primitive_seq::<_, datatypes::Int64Type>(seq), + BasicType::U8 | BasicType::Char | BasicType::Byte => { + deserialize_primitive_seq::<_, datatypes::UInt8Type>(seq) + } + BasicType::U16 => deserialize_primitive_seq::<_, datatypes::UInt16Type>(seq), + BasicType::U32 => deserialize_primitive_seq::<_, datatypes::UInt32Type>(seq), + BasicType::U64 => deserialize_primitive_seq::<_, datatypes::UInt64Type>(seq), + BasicType::F32 => deserialize_primitive_seq::<_, datatypes::Float32Type>(seq), + BasicType::F64 => deserialize_primitive_seq::<_, datatypes::Float64Type>(seq), + BasicType::Bool => { + let mut array = BooleanBuilder::new(); + while let Some(value) = seq.next_element()? { + array.append_value(value); + } + Ok(array.finish().into()) + } + }, + NestableType::NamedType(_) => todo!("deserialize sequence of NestableType::NamedType"), + NestableType::NamespacedType(_) => { + todo!("deserialize sequence of NestableType::NamedspacedType") + } + NestableType::GenericString(t) => match t { + primitives::GenericString::String | primitives::GenericString::BoundedString(_) => { + let mut array = StringBuilder::new(); + while let Some(value) = seq.next_element::()? { + array.append_value(value); + } + Ok(array.finish().into()) + } + primitives::GenericString::WString => todo!("deserialize sequence of WString"), + primitives::GenericString::BoundedWString(_) => { + todo!("deserialize sequence of BoundedWString") + } + }, + } + } +} + +fn deserialize_primitive_seq<'de, S, T>( + mut seq: S, +) -> Result>::Error> +where + S: serde::de::SeqAccess<'de>, + T: ArrowPrimitiveType, + T::Native: Deserialize<'de>, +{ + let mut array = PrimitiveBuilder::::new(); + while let Some(value) = seq.next_element::()? { + array.append_value(value); + } + Ok(array.finish().into()) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/string.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/string.rs new file mode 100644 index 000000000..646ea38d3 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/string.rs @@ -0,0 +1,44 @@ +use arrow::array::{ArrayData, StringBuilder}; +use core::fmt; + +pub struct StringDeserializer; + +impl<'de> serde::de::DeserializeSeed<'de> for StringDeserializer { + type Value = ArrayData; + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(StringVisitor) + } +} + +/// Based on https://docs.rs/serde_yaml/0.9.22/src/serde_yaml/value/de.rs.html#14-121 +struct StringVisitor; + +impl<'de> serde::de::Visitor<'de> for StringVisitor { + type Value = ArrayData; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string value") + } + + fn visit_str(self, s: &str) -> Result + where + E: serde::de::Error, + { + let mut array = StringBuilder::new(); + array.append_value(s); + Ok(array.finish().into()) + } + + fn visit_string(self, s: String) -> Result + where + E: serde::de::Error, + { + let mut array = StringBuilder::new(); + array.append_value(s); + Ok(array.finish().into()) + } +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/mod.rs index 7fc4ddcd5..b8b76893f 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/mod.rs @@ -1,285 +1,25 @@ -use arrow::{ - array::{ - make_array, Array, ArrayData, BooleanArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, ListArray, StringArray, StructArray, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, - }, - buffer::{OffsetBuffer, ScalarBuffer}, - compute::concat, - datatypes::{DataType, Field}, -}; -use dora_ros2_bridge_msg_gen::types::{ - primitives::{BasicType, NestableType}, - MemberType, Message, -}; -use eyre::{Context, ContextCompat, Result}; -use std::{collections::HashMap, sync::Arc, vec}; +use dora_ros2_bridge_msg_gen::types::Message; +use std::{borrow::Cow, collections::HashMap, sync::Arc}; pub use serialize::TypedValue; pub mod deserialize; pub mod serialize; -#[derive(Debug, Clone, PartialEq)] -pub struct TypeInfo { - data_type: DataType, - defaults: ArrayData, +#[derive(Debug, Clone)] +pub struct TypeInfo<'a> { + pub package_name: Cow<'a, str>, + pub message_name: Cow<'a, str>, + pub messages: Arc>>, } -pub fn for_message( - messages: &HashMap>, - package_name: &str, - message_name: &str, -) -> eyre::Result { - let empty = HashMap::new(); - let package_messages = messages.get(package_name).unwrap_or(&empty); - let message = package_messages - .get(message_name) - .context("unknown type name")?; - let default_struct_vec: Vec<(Arc, Arc)> = message - .members - .iter() - .map(|m| { - let default = make_array( - default_for_member(m, package_name, messages) - .context(format!("Could not create default value for {:#?}", m))?, - ); - Result::<_, eyre::Report>::Ok(( - Arc::new(Field::new( - m.name.clone(), - default.data_type().clone(), - true, - )), - default, - )) - }) - .collect::>() - .context("Could not build default value")?; - let default_struct: StructArray = default_struct_vec.into(); - - Ok(TypeInfo { - data_type: default_struct.data_type().clone(), - defaults: default_struct.into(), - }) -} - -pub fn default_for_member( - m: &dora_ros2_bridge_msg_gen::types::Member, - package_name: &str, - messages: &HashMap>, -) -> eyre::Result { - let value = match &m.r#type { - MemberType::NestableType(t) => match t { - NestableType::BasicType(_) | NestableType::GenericString(_) => match &m - .default - .as_deref() - { - Some([]) => eyre::bail!("empty default value not supported"), - Some([default]) => preset_default_for_basic_type(t, default) - .with_context(|| format!("failed to parse default value for `{}`", m.name))?, - Some(_) => eyre::bail!( - "there should be only a single default value for non-sequence types" - ), - None => default_for_nestable_type(t, package_name, messages, 1)?, - }, - NestableType::NamedType(_) => { - if m.default.is_some() { - eyre::bail!("default values for nested types are not supported") - } else { - default_for_nestable_type(t, package_name, messages, 1)? - } - } - NestableType::NamespacedType(_) => { - default_for_nestable_type(t, package_name, messages, 1)? - } - }, - MemberType::Array(array) => list_default_values( - m, - &array.value_type, - package_name, - messages, - Some(array.size), - )?, - MemberType::Sequence(seq) => { - list_default_values(m, &seq.value_type, package_name, messages, None)? - } - MemberType::BoundedSequence(seq) => list_default_values( - m, - &seq.value_type, - package_name, - messages, - Some(seq.max_size), - )?, - }; - Ok(value) -} - -fn default_for_nestable_type( - t: &NestableType, - package_name: &str, - messages: &HashMap>, - size: usize, -) -> Result { - let empty = HashMap::new(); - let package_messages = messages.get(package_name).unwrap_or(&empty); - let array = match t { - NestableType::BasicType(t) => match t { - BasicType::I8 => Int8Array::from(vec![0; size]).into(), - BasicType::I16 => Int16Array::from(vec![0; size]).into(), - BasicType::I32 => Int32Array::from(vec![0; size]).into(), - BasicType::I64 => Int64Array::from(vec![0; size]).into(), - BasicType::U8 => UInt8Array::from(vec![0; size]).into(), - BasicType::U16 => UInt16Array::from(vec![0; size]).into(), - BasicType::U32 => UInt32Array::from(vec![0; size]).into(), - BasicType::U64 => UInt64Array::from(vec![0; size]).into(), - BasicType::F32 => Float32Array::from(vec![0.; size]).into(), - BasicType::F64 => Float64Array::from(vec![0.; size]).into(), - BasicType::Char => StringArray::from(vec![""]).into(), - BasicType::Byte => UInt8Array::from(vec![0u8; size]).into(), - BasicType::Bool => BooleanArray::from(vec![false; size]).into(), - }, - NestableType::GenericString(_) => StringArray::from(vec![""]).into(), - NestableType::NamedType(name) => { - let referenced_message = package_messages - .get(&name.0) - .context("unknown referenced message")?; - - default_for_referenced_message(referenced_message, package_name, messages)? - } - NestableType::NamespacedType(t) => { - let referenced_package_messages = messages.get(&t.package).unwrap_or(&empty); - let referenced_message = referenced_package_messages - .get(&t.name) - .context("unknown referenced message")?; - default_for_referenced_message(referenced_message, &t.package, messages)? - } - }; - Ok(array) -} - -fn preset_default_for_basic_type(t: &NestableType, preset: &str) -> Result { - Ok(match t { - NestableType::BasicType(t) => match t { - BasicType::I8 => Int8Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::I16 => Int16Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::I32 => Int32Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::I64 => Int64Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::U8 => UInt8Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::U16 => UInt16Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::U32 => UInt32Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::U64 => UInt64Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::F32 => Float32Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::F64 => Float64Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::Char => StringArray::from(vec![preset]).into(), - BasicType::Byte => UInt8Array::from(preset.as_bytes().to_owned()).into(), - BasicType::Bool => BooleanArray::from(vec![preset - .parse::() - .context("could not parse preset default value")?]) - .into(), - }, - NestableType::GenericString(_) => StringArray::from(vec![preset]).into(), - _ => todo!(), - }) -} - -fn default_for_referenced_message( - referenced_message: &Message, - package_name: &str, - messages: &HashMap>, -) -> eyre::Result { - let fields: Vec<(Arc, Arc)> = referenced_message - .members - .iter() - .map(|m| { - let default = default_for_member(m, package_name, messages)?; - Result::<_, eyre::Report>::Ok(( - Arc::new(Field::new( - m.name.clone(), - default.data_type().clone(), - true, - )), - make_array(default), - )) - }) - .collect::>()?; - - let struct_array: StructArray = fields.into(); - Ok(struct_array.into()) -} - -fn list_default_values( - m: &dora_ros2_bridge_msg_gen::types::Member, - value_type: &NestableType, - package_name: &str, - messages: &HashMap>, - size: Option, -) -> Result { - let defaults = match &m.default.as_deref() { - Some([]) => eyre::bail!("empty default value not supported"), - Some(defaults) => { - let raw_array: Vec> = defaults - .iter() - .map(|default| { - preset_default_for_basic_type(value_type, default) - .with_context(|| format!("failed to parse default value for `{}`", m.name)) - .map(make_array) - }) - .collect::>()?; - let default_values = concat( - raw_array - .iter() - .map(|data| data.as_ref()) - .collect::>() - .as_slice(), - ) - .context("Failed to concatenate default list value")?; - default_values.to_data() - } - None => { - let size = size.unwrap_or(1); - let default_nested_type = - default_for_nestable_type(value_type, package_name, messages, size)?; - let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, size as i32])); - - let field = Arc::new(Field::new( - "item", - default_nested_type.data_type().clone(), - true, - )); - let list = ListArray::new(field, offsets, make_array(default_nested_type), None); - list.to_data() - } - }; - - Ok(defaults) -} +/// Serde requires that struct and field names are known at +/// compile time with a `'static` lifetime, which is not +/// possible in this case. Thus, we need to use dummy names +/// instead. +/// +/// The actual names do not really matter because +/// the CDR format of ROS2 does not encode struct or field +/// names. +const DUMMY_STRUCT_NAME: &str = "struct"; +const DUMMY_FIELD_NAME: &str = "field"; diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs deleted file mode 100644 index 5219489da..000000000 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs +++ /dev/null @@ -1,282 +0,0 @@ -use arrow::array::make_array; -use arrow::array::Array; -use arrow::array::ArrayRef; -use arrow::array::AsArray; -use arrow::array::Float32Array; -use arrow::array::Float64Array; -use arrow::array::Int16Array; -use arrow::array::Int32Array; -use arrow::array::Int64Array; -use arrow::array::Int8Array; -use arrow::array::ListArray; -use arrow::array::StringArray; -use arrow::array::StructArray; -use arrow::array::UInt16Array; -use arrow::array::UInt32Array; -use arrow::array::UInt64Array; -use arrow::array::UInt8Array; -use arrow::datatypes::DataType; -use serde::ser::SerializeSeq; -use serde::ser::SerializeStruct; - -use super::TypeInfo; - -#[derive(Debug, Clone, PartialEq)] -pub struct TypedValue<'a> { - pub value: &'a ArrayRef, - pub type_info: &'a TypeInfo, -} - -impl serde::Serialize for TypedValue<'_> { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - match &self.type_info.data_type { - DataType::UInt8 => { - let array: &UInt8Array = self.value.as_primitive(); - debug_assert!(array.len() == 1, "array length was: {}", array.len()); - let number = array.value(0); - serializer.serialize_u8(number) - } - DataType::UInt16 => { - let array: &UInt16Array = self.value.as_primitive(); - debug_assert!(array.len() == 1); - let number = array.value(0); - serializer.serialize_u16(number) - } - DataType::UInt32 => { - let array: &UInt32Array = self.value.as_primitive(); - debug_assert!(array.len() == 1); - let number = array.value(0); - serializer.serialize_u32(number) - } - DataType::UInt64 => { - let array: &UInt64Array = self.value.as_primitive(); - debug_assert!(array.len() == 1); - let number = array.value(0); - serializer.serialize_u64(number) - } - DataType::Int8 => { - let array: &Int8Array = self.value.as_primitive(); - debug_assert!(array.len() == 1); - let number = array.value(0); - serializer.serialize_i8(number) - } - DataType::Int16 => { - let array: &Int16Array = self.value.as_primitive(); - debug_assert!(array.len() == 1); - let number = array.value(0); - serializer.serialize_i16(number) - } - DataType::Int32 => { - let array: &Int32Array = self.value.as_primitive(); - debug_assert!(array.len() == 1); - let number = array.value(0); - serializer.serialize_i32(number) - } - DataType::Int64 => { - let array: &Int64Array = self.value.as_primitive(); - debug_assert!(array.len() == 1, "array was: {:#?}", array); - let number = array.value(0); - serializer.serialize_i64(number) - } - DataType::Float32 => { - let array: &Float32Array = self.value.as_primitive(); - debug_assert!(array.len() == 1); - let number = array.value(0); - serializer.serialize_f32(number) - } - DataType::Float64 => { - let array: &Float64Array = self.value.as_primitive(); - debug_assert!(array.len() == 1); - let number = array.value(0); - serializer.serialize_f64(number) - } - DataType::Utf8 => { - let int_array: &StringArray = self.value.as_string(); - let string = int_array.value(0); - serializer.serialize_str(string) - } - DataType::List(_field) => { - let list_array: &ListArray = self.value.as_list(); - let mut s = serializer.serialize_seq(Some(list_array.len()))?; - for root in list_array.iter() { - if let Some(values) = root { - match values.data_type() { - DataType::UInt8 => { - let values: &UInt8Array = values.as_primitive(); - for value in values.iter() { - if let Some(value) = value { - s.serialize_element(&value)?; - } else { - todo!("Implement null management"); - } - } - } - DataType::UInt16 => { - let values: &UInt16Array = values.as_primitive(); - for value in values.iter() { - if let Some(value) = value { - s.serialize_element(&value)?; - } else { - todo!("Implement null management"); - } - } - } - DataType::UInt32 => { - let values: &UInt32Array = values.as_primitive(); - for value in values.iter() { - if let Some(value) = value { - s.serialize_element(&value)?; - } else { - todo!("Implement null management"); - } - } - } - DataType::UInt64 => { - let values: &UInt64Array = values.as_primitive(); - for value in values.iter() { - if let Some(value) = value { - s.serialize_element(&value)?; - } else { - todo!("Implement null management"); - } - } - } - DataType::Int8 => { - let values: &Int8Array = values.as_primitive(); - for value in values.iter() { - if let Some(value) = value { - s.serialize_element(&value)?; - } else { - todo!("Implement null management"); - } - } - } - DataType::Int16 => { - let values: &Int16Array = values.as_primitive(); - for value in values.iter() { - if let Some(value) = value { - s.serialize_element(&value)?; - } else { - todo!("Implement null management"); - } - } - } - DataType::Int32 => { - let values: &Int32Array = values.as_primitive(); - for value in values.iter() { - if let Some(value) = value { - s.serialize_element(&value)?; - } else { - todo!("Implement null management"); - } - } - } - DataType::Int64 => { - let values: &Int64Array = values.as_primitive(); - for value in values.iter() { - if let Some(value) = value { - s.serialize_element(&value)?; - } else { - todo!("Implement null management"); - } - } - } - DataType::Float32 => { - let values: &Float32Array = values.as_primitive(); - for value in values.iter() { - if let Some(value) = value { - s.serialize_element(&value)?; - } else { - todo!("Implement null management"); - } - } - } - DataType::Float64 => { - let values: &Float64Array = values.as_primitive(); - for value in values.iter() { - if let Some(value) = value { - s.serialize_element(&value)?; - } else { - todo!("Implement null management"); - } - } - } - DataType::Utf8 => { - let values: &StringArray = values.as_string(); - for value in values.iter() { - if let Some(value) = value { - s.serialize_element(&value)?; - } else { - todo!("Implement null management"); - } - } - } - DataType::Struct(_fields) => { - let list_array: ListArray = self.type_info.defaults.clone().into(); - s.serialize_element(&TypedValue { - value: &values, - type_info: &TypeInfo { - data_type: values.data_type().clone(), - defaults: list_array.value(0).to_data(), - }, - })?; - } - op => todo!("Implement additional type: {:?}", op), - } - } else { - todo!("Implement null management"); - } - } - s.end() - } - DataType::Struct(fields) => { - /// Serde requires that struct and field names are known at - /// compile time with a `'static` lifetime, which is not - /// possible in this case. Thus, we need to use dummy names - /// instead. - /// - /// The actual names do not really matter because - /// the CDR format of ROS2 does not encode struct or field - /// names. - const DUMMY_STRUCT_NAME: &str = "struct"; - const DUMMY_FIELD_NAME: &str = "field"; - - let struct_array: &StructArray = self.value.as_struct(); - let mut s = serializer.serialize_struct(DUMMY_STRUCT_NAME, fields.len())?; - let defaults: StructArray = self.type_info.defaults.clone().into(); - for field in fields.iter() { - let default = match defaults.column_by_name(field.name()) { - Some(value) => value.to_data(), - None => { - return Err(serde::ser::Error::custom(format!( - "missing field {} for serialization", - &field.name() - ))) - } - }; - let value = make_array(default.clone()); - let field_value = match struct_array.column_by_name(field.name()) { - Some(value) => value, - None => &value, - }; - - s.serialize_field( - DUMMY_FIELD_NAME, - &TypedValue { - value: &field_value.clone(), - type_info: &TypeInfo { - data_type: field.data_type().clone(), - defaults: default, - }, - }, - )?; - } - s.end() - } - _ => todo!(), - } - } -} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs new file mode 100644 index 000000000..517d2e6a2 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs @@ -0,0 +1,186 @@ +use std::marker::PhantomData; + +use arrow::{ + array::{Array, ArrayRef, AsArray, PrimitiveArray}, + datatypes::{self, ArrowPrimitiveType}, +}; +use dora_ros2_bridge_msg_gen::types::{ + primitives::{BasicType, NestableType}, + sequences, +}; +use serde::ser::SerializeTuple; + +use super::error; + +/// Serialize an array with known size as tuple. +pub struct ArraySerializeWrapper<'a> { + pub array_info: &'a sequences::Array, + pub column: &'a ArrayRef, +} + +impl serde::Serialize for ArraySerializeWrapper<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let entry = if let Some(list) = self.column.as_list_opt::() { + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + } else { + // try as large list + let list = self + .column + .as_list_opt::() + .ok_or_else(|| error("value is not compatible with expected Array type"))?; + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + }; + + match &self.array_info.value_type { + NestableType::BasicType(t) => match t { + BasicType::I8 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I16 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I32 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I64 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U8 | BasicType::Char | BasicType::Byte => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U16 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U32 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U64 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::F32 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::F64 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::Bool => BoolArrayAsTuple { + len: self.array_info.size, + value: &entry, + } + .serialize(serializer), + }, + NestableType::NamedType(_) => todo!("serializing arrays of NestableType::NamedType"), + NestableType::NamespacedType(_) => { + todo!("serializing arrays of NestableType::NamespacedType") + } + NestableType::GenericString(_) => { + todo!("serializing arrays of NestableType::GenericString") + } + } + } +} + +/// Serializes a primitive array with known size as tuple. +struct BasicArrayAsTuple<'a, T> { + len: usize, + value: &'a ArrayRef, + ty: PhantomData, +} + +impl serde::Serialize for BasicArrayAsTuple<'_, T> +where + T: ArrowPrimitiveType, + T::Native: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_tuple(self.len)?; + let array: &PrimitiveArray = self + .value + .as_primitive_opt() + .ok_or_else(|| error("not a primitive array"))?; + if array.len() != self.len { + return Err(error(format!( + "expected array with length {}, got length {}", + self.len, + array.len() + ))); + } + + for value in array.values() { + seq.serialize_element(value)?; + } + + seq.end() + } +} + +struct BoolArrayAsTuple<'a> { + len: usize, + value: &'a ArrayRef, +} + +impl serde::Serialize for BoolArrayAsTuple<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_tuple(self.len)?; + let array = self + .value + .as_boolean_opt() + .ok_or_else(|| error("not a boolean array"))?; + if array.len() != self.len { + return Err(error(format!( + "expected array with length {}, got length {}", + self.len, + array.len() + ))); + } + + for value in array.values() { + seq.serialize_element(&value)?; + } + + seq.end() + } +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/defaults.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/defaults.rs new file mode 100644 index 000000000..2cc1e6b73 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/defaults.rs @@ -0,0 +1,237 @@ +use arrow::{ + array::{ + make_array, Array, ArrayData, BooleanArray, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, ListArray, StringArray, StructArray, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, + }, + buffer::{OffsetBuffer, ScalarBuffer}, + compute::concat, + datatypes::Field, +}; +use dora_ros2_bridge_msg_gen::types::{ + primitives::{BasicType, NestableType}, + MemberType, Message, +}; +use eyre::{Context, ContextCompat, Result}; +use std::{collections::HashMap, sync::Arc, vec}; + +pub fn default_for_member( + m: &dora_ros2_bridge_msg_gen::types::Member, + package_name: &str, + messages: &HashMap>, +) -> eyre::Result { + let value = match &m.r#type { + MemberType::NestableType(t) => match t { + NestableType::BasicType(_) | NestableType::GenericString(_) => match &m + .default + .as_deref() + { + Some([]) => eyre::bail!("empty default value not supported"), + Some([default]) => preset_default_for_basic_type(t, default) + .with_context(|| format!("failed to parse default value for `{}`", m.name))?, + Some(_) => eyre::bail!( + "there should be only a single default value for non-sequence types" + ), + None => default_for_nestable_type(t, package_name, messages, 1)?, + }, + NestableType::NamedType(_) => { + if m.default.is_some() { + eyre::bail!("default values for nested types are not supported") + } else { + default_for_nestable_type(t, package_name, messages, 1)? + } + } + NestableType::NamespacedType(_) => { + default_for_nestable_type(t, package_name, messages, 1)? + } + }, + MemberType::Array(array) => list_default_values( + m, + &array.value_type, + package_name, + messages, + Some(array.size), + )?, + MemberType::Sequence(seq) => { + list_default_values(m, &seq.value_type, package_name, messages, None)? + } + MemberType::BoundedSequence(seq) => list_default_values( + m, + &seq.value_type, + package_name, + messages, + Some(seq.max_size), + )?, + }; + Ok(value) +} + +fn default_for_nestable_type( + t: &NestableType, + package_name: &str, + messages: &HashMap>, + size: usize, +) -> Result { + let empty = HashMap::new(); + let package_messages = messages.get(package_name).unwrap_or(&empty); + let array = match t { + NestableType::BasicType(t) => match t { + BasicType::I8 => Int8Array::from(vec![0; size]).into(), + BasicType::I16 => Int16Array::from(vec![0; size]).into(), + BasicType::I32 => Int32Array::from(vec![0; size]).into(), + BasicType::I64 => Int64Array::from(vec![0; size]).into(), + BasicType::U8 => UInt8Array::from(vec![0; size]).into(), + BasicType::U16 => UInt16Array::from(vec![0; size]).into(), + BasicType::U32 => UInt32Array::from(vec![0; size]).into(), + BasicType::U64 => UInt64Array::from(vec![0; size]).into(), + BasicType::F32 => Float32Array::from(vec![0.; size]).into(), + BasicType::F64 => Float64Array::from(vec![0.; size]).into(), + BasicType::Char => StringArray::from(vec![""]).into(), + BasicType::Byte => UInt8Array::from(vec![0u8; size]).into(), + BasicType::Bool => BooleanArray::from(vec![false; size]).into(), + }, + NestableType::GenericString(_) => StringArray::from(vec![""]).into(), + NestableType::NamedType(name) => { + let referenced_message = package_messages + .get(&name.0) + .context("unknown referenced message")?; + + default_for_referenced_message(referenced_message, package_name, messages)? + } + NestableType::NamespacedType(t) => { + let referenced_package_messages = messages.get(&t.package).unwrap_or(&empty); + let referenced_message = referenced_package_messages + .get(&t.name) + .context("unknown referenced message")?; + default_for_referenced_message(referenced_message, &t.package, messages)? + } + }; + Ok(array) +} + +fn preset_default_for_basic_type(t: &NestableType, preset: &str) -> Result { + Ok(match t { + NestableType::BasicType(t) => match t { + BasicType::I8 => Int8Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::I16 => Int16Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::I32 => Int32Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::I64 => Int64Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::U8 => UInt8Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::U16 => UInt16Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::U32 => UInt32Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::U64 => UInt64Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::F32 => Float32Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::F64 => Float64Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::Char => StringArray::from(vec![preset]).into(), + BasicType::Byte => UInt8Array::from(preset.as_bytes().to_owned()).into(), + BasicType::Bool => BooleanArray::from(vec![preset + .parse::() + .context("could not parse preset default value")?]) + .into(), + }, + NestableType::GenericString(_) => StringArray::from(vec![preset]).into(), + _ => todo!("preset_default_for_basic_type (other)"), + }) +} + +fn default_for_referenced_message( + referenced_message: &Message, + package_name: &str, + messages: &HashMap>, +) -> eyre::Result { + let fields: Vec<(Arc, Arc)> = referenced_message + .members + .iter() + .map(|m| { + let default = default_for_member(m, package_name, messages)?; + Result::<_, eyre::Report>::Ok(( + Arc::new(Field::new( + m.name.clone(), + default.data_type().clone(), + true, + )), + make_array(default), + )) + }) + .collect::>()?; + + let struct_array: StructArray = fields.into(); + Ok(struct_array.into()) +} + +fn list_default_values( + m: &dora_ros2_bridge_msg_gen::types::Member, + value_type: &NestableType, + package_name: &str, + messages: &HashMap>, + size: Option, +) -> Result { + let defaults = match &m.default.as_deref() { + Some([]) => eyre::bail!("empty default value not supported"), + Some(defaults) => { + let raw_array: Vec> = defaults + .iter() + .map(|default| { + preset_default_for_basic_type(value_type, default) + .with_context(|| format!("failed to parse default value for `{}`", m.name)) + .map(make_array) + }) + .collect::>()?; + let default_values = concat( + raw_array + .iter() + .map(|data| data.as_ref()) + .collect::>() + .as_slice(), + ) + .context("Failed to concatenate default list value")?; + default_values.to_data() + } + None => { + let size = size.unwrap_or(1); + let default_nested_type = + default_for_nestable_type(value_type, package_name, messages, size)?; + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, size as i32])); + + let field = Arc::new(Field::new( + "item", + default_nested_type.data_type().clone(), + true, + )); + let list = ListArray::new(field, offsets, make_array(default_nested_type), None); + list.to_data() + } + }; + + Ok(defaults) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs new file mode 100644 index 000000000..ab7556d4c --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs @@ -0,0 +1,181 @@ +use std::{borrow::Cow, collections::HashMap, fmt::Display}; + +use arrow::array::{Array, ArrayRef, AsArray}; +use dora_ros2_bridge_msg_gen::types::{ + primitives::{GenericString, NestableType}, + MemberType, +}; +use serde::ser::SerializeStruct; + +use super::{TypeInfo, DUMMY_FIELD_NAME, DUMMY_STRUCT_NAME}; + +mod array; +mod defaults; +mod primitive; +mod sequence; + +#[derive(Debug, Clone)] +pub struct TypedValue<'a> { + pub value: &'a ArrayRef, + pub type_info: &'a TypeInfo<'a>, +} + +impl serde::Serialize for TypedValue<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let empty = HashMap::new(); + let package_messages = self + .type_info + .messages + .get(self.type_info.package_name.as_ref()) + .unwrap_or(&empty); + let message = package_messages + .get(self.type_info.message_name.as_ref()) + .ok_or_else(|| { + error(format!( + "could not find message type {}::{}", + self.type_info.package_name, self.type_info.message_name + )) + })?; + + let input = self + .value + .as_struct_opt() + .ok_or_else(|| error("expected struct array"))?; + for column_name in input.column_names() { + if !message.members.iter().any(|m| m.name == column_name) { + return Err(error(format!( + "given struct has unknown field {column_name}" + )))?; + } + } + if input.is_empty() { + // TODO: publish default value + return Err(error("given struct is empty"))?; + } + if input.len() > 1 { + return Err(error(format!( + "expected single struct instance, got struct array with {} entries", + input.len() + )))?; + } + let mut s = serializer.serialize_struct(DUMMY_STRUCT_NAME, message.members.len())?; + for field in message.members.iter() { + let column: Cow<_> = match input.column_by_name(&field.name) { + Some(input) => Cow::Borrowed(input), + None => { + let default = defaults::default_for_member( + field, + &self.type_info.package_name, + &self.type_info.messages, + ) + .map_err(error)?; + Cow::Owned(arrow::array::make_array(default)) + } + }; + + match &field.r#type { + MemberType::NestableType(t) => match t { + NestableType::BasicType(t) => { + s.serialize_field( + DUMMY_FIELD_NAME, + &primitive::SerializeWrapper { + t, + column: column.as_ref(), + }, + )?; + } + NestableType::NamedType(name) => { + let referenced_value = &TypedValue { + value: column.as_ref(), + type_info: &TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + }, + }; + s.serialize_field(DUMMY_FIELD_NAME, &referenced_value)?; + } + NestableType::NamespacedType(reference) => { + if reference.namespace != "msg" { + return Err(error(format!( + "struct field {} references non-message type {reference:?}", + field.name + ))); + } + + let referenced_value: &TypedValue<'_> = &TypedValue { + value: column.as_ref(), + type_info: &TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + }, + }; + s.serialize_field(DUMMY_FIELD_NAME, &referenced_value)?; + } + NestableType::GenericString(t) => match t { + GenericString::String | GenericString::BoundedString(_) => { + let string = if let Some(string_array) = column.as_string_opt::() { + // should match the length of the outer struct array + assert_eq!(string_array.len(), 1); + string_array.value(0) + } else { + // try again with large offset type + let string_array = column + .as_string_opt::() + .ok_or_else(|| error("expected string array"))?; + // should match the length of the outer struct array + assert_eq!(string_array.len(), 1); + string_array.value(0) + }; + s.serialize_field(DUMMY_FIELD_NAME, string)?; + } + GenericString::WString => todo!("serializing WString types"), + GenericString::BoundedWString(_) => { + todo!("serializing BoundedWString types") + } + }, + }, + dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => { + s.serialize_field( + DUMMY_FIELD_NAME, + &array::ArraySerializeWrapper { + array_info: a, + column: column.as_ref(), + }, + )?; + } + dora_ros2_bridge_msg_gen::types::MemberType::Sequence(v) => { + s.serialize_field( + DUMMY_FIELD_NAME, + &sequence::SequenceSerializeWrapper { + item_type: &v.value_type, + column: column.as_ref(), + }, + )?; + } + dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(v) => { + s.serialize_field( + DUMMY_FIELD_NAME, + &sequence::SequenceSerializeWrapper { + item_type: &v.value_type, + column: column.as_ref(), + }, + )?; + } + } + } + s.end() + } +} + +fn error(e: T) -> E +where + T: Display, + E: serde::ser::Error, +{ + serde::ser::Error::custom(e) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/primitive.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/primitive.rs new file mode 100644 index 000000000..a13bf444d --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/primitive.rs @@ -0,0 +1,79 @@ +use arrow::{ + array::{ArrayRef, AsArray}, + datatypes::{self, ArrowPrimitiveType}, +}; +use dora_ros2_bridge_msg_gen::types::primitives::BasicType; + +pub struct SerializeWrapper<'a> { + pub t: &'a BasicType, + pub column: &'a ArrayRef, +} + +impl serde::Serialize for SerializeWrapper<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self.t { + BasicType::I8 => { + serializer.serialize_i8(as_single_primitive::(self.column)?) + } + BasicType::I16 => serializer + .serialize_i16(as_single_primitive::(self.column)?), + BasicType::I32 => serializer + .serialize_i32(as_single_primitive::(self.column)?), + BasicType::I64 => serializer + .serialize_i64(as_single_primitive::(self.column)?), + BasicType::U8 | BasicType::Char | BasicType::Byte => serializer + .serialize_u8(as_single_primitive::(self.column)?), + BasicType::U16 => serializer + .serialize_u16(as_single_primitive::( + self.column, + )?), + BasicType::U32 => serializer + .serialize_u32(as_single_primitive::( + self.column, + )?), + BasicType::U64 => serializer + .serialize_u64(as_single_primitive::( + self.column, + )?), + BasicType::F32 => serializer + .serialize_f32(as_single_primitive::( + self.column, + )?), + BasicType::F64 => serializer + .serialize_f64(as_single_primitive::( + self.column, + )?), + BasicType::Bool => { + let array = self.column.as_boolean_opt().ok_or_else(|| { + serde::ser::Error::custom( + "value is not compatible with expected `BooleanArray` type", + ) + })?; + // should match the length of the outer struct + assert_eq!(array.len(), 1); + let field_value = array.value(0); + serializer.serialize_bool(field_value) + } + } + } +} + +fn as_single_primitive(column: &ArrayRef) -> Result +where + T: ArrowPrimitiveType, + E: serde::ser::Error, +{ + let array: &arrow::array::PrimitiveArray = column.as_primitive_opt().ok_or_else(|| { + serde::ser::Error::custom(format!( + "value is not compatible with expected `{}` type", + std::any::type_name::() + )) + })?; + // should match the length of the outer struct + assert_eq!(array.len(), 1); + let number = array.value(0); + Ok(number) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs new file mode 100644 index 000000000..fd8dd5910 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs @@ -0,0 +1,152 @@ +use std::marker::PhantomData; + +use arrow::{ + array::{Array, ArrayRef, AsArray, PrimitiveArray}, + datatypes::{self, ArrowPrimitiveType}, +}; +use dora_ros2_bridge_msg_gen::types::primitives::{BasicType, NestableType}; +use serde::ser::{SerializeSeq, SerializeTuple}; + +use super::error; + +/// Serialize a variable-sized sequence. +pub struct SequenceSerializeWrapper<'a> { + pub item_type: &'a NestableType, + pub column: &'a ArrayRef, +} + +impl serde::Serialize for SequenceSerializeWrapper<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let entry = if let Some(list) = self.column.as_list_opt::() { + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + } else { + // try as large list + let list = self + .column + .as_list_opt::() + .ok_or_else(|| error("value is not compatible with expected Array type"))?; + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + }; + match &self.item_type { + NestableType::BasicType(t) => match t { + BasicType::I8 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I16 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I32 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I64 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U8 | BasicType::Char | BasicType::Byte => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U16 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U32 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U64 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::F32 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::F64 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::Bool => BoolArray { value: &entry }.serialize(serializer), + }, + NestableType::NamedType(_) => todo!("serializing NestableType::NamedType sequences"), + NestableType::NamespacedType(_) => { + todo!("serializing NestableType::NamespacedType sequences") + } + NestableType::GenericString(_) => { + todo!("serializing NestableType::Genericstring sequences") + } + } + } +} + +struct BasicSequence<'a, T> { + value: &'a ArrayRef, + ty: PhantomData, +} + +impl serde::Serialize for BasicSequence<'_, T> +where + T: ArrowPrimitiveType, + T::Native: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let array: &PrimitiveArray = self + .value + .as_primitive_opt() + .ok_or_else(|| error("not a primitive array"))?; + + let mut seq = serializer.serialize_seq(Some(array.len()))?; + + for value in array.values() { + seq.serialize_element(value)?; + } + + seq.end() + } +} + +struct BoolArray<'a> { + value: &'a ArrayRef, +} + +impl serde::Serialize for BoolArray<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let array = self + .value + .as_boolean_opt() + .ok_or_else(|| error("not a boolean array"))?; + let mut seq = serializer.serialize_tuple(array.len())?; + + for value in array.values() { + seq.serialize_element(&value)?; + } + + seq.end() + } +} From 4b8dbfc7c26e1fbe8a8186deabd2a08c1cb4e605 Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Sun, 21 Jan 2024 16:24:52 +0100 Subject: [PATCH 05/14] Add support for (de)serializing arrays/sequences of structs and strings --- .../python/src/typed/deserialize/array.rs | 15 ++- .../python/src/typed/deserialize/mod.rs | 6 +- .../python/src/typed/deserialize/sequence.rs | 66 ++++++++++++-- .../python/src/typed/serialize/array.rs | 91 +++++++++++++++++-- .../python/src/typed/serialize/mod.rs | 3 + .../python/src/typed/serialize/sequence.rs | 88 ++++++++++++++++-- 6 files changed, 239 insertions(+), 30 deletions(-) diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/array.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/array.rs index 79030abe7..170092dc3 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/array.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/array.rs @@ -1,9 +1,14 @@ use arrow::array::ArrayData; use dora_ros2_bridge_msg_gen::types::sequences; +use crate::typed::TypeInfo; + use super::sequence::SequenceVisitor; -pub struct ArrayDeserializer<'a>(pub &'a sequences::Array); +pub struct ArrayDeserializer<'a> { + pub array_type: &'a sequences::Array, + pub type_info: &'a TypeInfo<'a>, +} impl<'de> serde::de::DeserializeSeed<'de> for ArrayDeserializer<'_> { type Value = ArrayData; @@ -12,6 +17,12 @@ impl<'de> serde::de::DeserializeSeed<'de> for ArrayDeserializer<'_> { where D: serde::Deserializer<'de>, { - deserializer.deserialize_tuple(self.0.size, SequenceVisitor(&self.0.value_type)) + deserializer.deserialize_tuple( + self.array_type.size, + SequenceVisitor { + item_type: &self.array_type.value_type, + type_info: self.type_info, + }, + ) } } diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs index 93a94eee1..ca74e8e57 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs @@ -125,13 +125,13 @@ impl<'a, 'de> serde::de::Visitor<'de> for StructVisitor<'a> { } }, dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => { - data.next_element_seed(array::ArrayDeserializer(a))? + data.next_element_seed(array::ArrayDeserializer{ array_type : a, type_info: self.type_info})? }, dora_ros2_bridge_msg_gen::types::MemberType::Sequence(s) => { - data.next_element_seed(sequence::SequenceDeserializer(&s.value_type))? + data.next_element_seed(sequence::SequenceDeserializer{item_type: &s.value_type, type_info: self.type_info})? }, dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(s) => { - data.next_element_seed(sequence::SequenceDeserializer(&s.value_type))? + data.next_element_seed(sequence::SequenceDeserializer{ item_type: &s.value_type, type_info: self.type_info})? }, }; diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs index 473950129..ee918ef35 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs @@ -5,8 +5,16 @@ use arrow::{ use core::fmt; use dora_ros2_bridge_msg_gen::types::primitives::{self, BasicType, NestableType}; use serde::Deserialize; +use std::{borrow::Cow, ops::Deref}; -pub struct SequenceDeserializer<'a>(pub &'a NestableType); +use crate::typed::TypeInfo; + +use super::{error, StructDeserializer}; + +pub struct SequenceDeserializer<'a> { + pub item_type: &'a NestableType, + pub type_info: &'a TypeInfo<'a>, +} impl<'de> serde::de::DeserializeSeed<'de> for SequenceDeserializer<'_> { type Value = ArrayData; @@ -15,11 +23,17 @@ impl<'de> serde::de::DeserializeSeed<'de> for SequenceDeserializer<'_> { where D: serde::Deserializer<'de>, { - deserializer.deserialize_seq(SequenceVisitor(self.0)) + deserializer.deserialize_seq(SequenceVisitor { + item_type: self.item_type, + type_info: self.type_info, + }) } } -pub struct SequenceVisitor<'a>(pub &'a NestableType); +pub struct SequenceVisitor<'a> { + pub item_type: &'a NestableType, + pub type_info: &'a TypeInfo<'a>, +} impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> { type Value = ArrayData; @@ -32,7 +46,7 @@ impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> { where A: serde::de::SeqAccess<'de>, { - match &self.0 { + match &self.item_type { NestableType::BasicType(t) => match t { BasicType::I8 => deserialize_primitive_seq::<_, datatypes::Int8Type>(seq), BasicType::I16 => deserialize_primitive_seq::<_, datatypes::Int16Type>(seq), @@ -54,9 +68,30 @@ impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> { Ok(array.finish().into()) } }, - NestableType::NamedType(_) => todo!("deserialize sequence of NestableType::NamedType"), - NestableType::NamespacedType(_) => { - todo!("deserialize sequence of NestableType::NamedspacedType") + NestableType::NamedType(name) => { + let deserializer = StructDeserializer { + type_info: Cow::Owned(TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + }), + }; + deserialize_struct_array(&mut seq, deserializer) + } + NestableType::NamespacedType(reference) => { + if reference.namespace != "msg" { + return Err(error(format!( + "sequence item references non-message type {reference:?}", + ))); + } + let deserializer = StructDeserializer { + type_info: Cow::Owned(TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + }), + }; + deserialize_struct_array(&mut seq, deserializer) } NestableType::GenericString(t) => match t { primitives::GenericString::String | primitives::GenericString::BoundedString(_) => { @@ -75,6 +110,23 @@ impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> { } } +fn deserialize_struct_array<'de, A>( + seq: &mut A, + deserializer: StructDeserializer<'_>, +) -> Result>::Error> +where + A: serde::de::SeqAccess<'de>, +{ + let mut values = Vec::new(); + while let Some(value) = seq.next_element_seed(deserializer.clone())? { + values.push(arrow::array::make_array(value)); + } + let refs: Vec<_> = values.iter().map(|a| a.deref()).collect(); + arrow::compute::concat(&refs) + .map(|a| a.to_data()) + .map_err(super::error) +} + fn deserialize_primitive_seq<'de, S, T>( mut seq: S, ) -> Result>::Error> diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs index 517d2e6a2..d1454fd89 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs @@ -1,21 +1,24 @@ -use std::marker::PhantomData; +use std::{borrow::Cow, marker::PhantomData, sync::Arc}; use arrow::{ - array::{Array, ArrayRef, AsArray, PrimitiveArray}, + array::{Array, ArrayRef, AsArray, OffsetSizeTrait, PrimitiveArray}, datatypes::{self, ArrowPrimitiveType}, }; use dora_ros2_bridge_msg_gen::types::{ - primitives::{BasicType, NestableType}, + primitives::{BasicType, GenericString, NestableType}, sequences, }; use serde::ser::SerializeTuple; -use super::error; +use crate::typed::TypeInfo; + +use super::{error, TypedValue}; /// Serialize an array with known size as tuple. pub struct ArraySerializeWrapper<'a> { pub array_info: &'a sequences::Array, pub column: &'a ArrayRef, + pub type_info: &'a TypeInfo<'a>, } impl serde::Serialize for ArraySerializeWrapper<'_> { @@ -106,13 +109,67 @@ impl serde::Serialize for ArraySerializeWrapper<'_> { } .serialize(serializer), }, - NestableType::NamedType(_) => todo!("serializing arrays of NestableType::NamedType"), - NestableType::NamespacedType(_) => { - todo!("serializing arrays of NestableType::NamespacedType") + NestableType::NamedType(name) => { + let array = entry + .as_struct_opt() + .ok_or_else(|| error("not a struct array"))?; + let mut seq = serializer.serialize_tuple(self.array_info.size)?; + for i in 0..array.len() { + let row = array.slice(i, 1); + seq.serialize_element(&TypedValue { + value: &(Arc::new(row) as ArrayRef), + type_info: &crate::typed::TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + }, + })?; + } + seq.end() } - NestableType::GenericString(_) => { - todo!("serializing arrays of NestableType::GenericString") + NestableType::NamespacedType(reference) => { + if reference.namespace != "msg" { + return Err(error(format!( + "sequence references non-message type {reference:?}" + ))); + } + + let array = entry + .as_struct_opt() + .ok_or_else(|| error("not a struct array"))?; + let mut seq = serializer.serialize_tuple(self.array_info.size)?; + for i in 0..array.len() { + let row = array.slice(i, 1); + seq.serialize_element(&TypedValue { + value: &(Arc::new(row) as ArrayRef), + type_info: &crate::typed::TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + }, + })?; + } + seq.end() } + NestableType::GenericString(s) => match s { + GenericString::String | GenericString::BoundedString(_) => { + match entry.as_string_opt::() { + Some(array) => { + serialize_arrow_string(serializer, array, self.array_info.size) + } + None => { + let array = entry + .as_string_opt::() + .ok_or_else(|| error("expected string array"))?; + serialize_arrow_string(serializer, array, self.array_info.size) + } + } + } + GenericString::WString => { + todo!("serializing WString sequences") + } + GenericString::BoundedWString(_) => todo!("serializing BoundedWString sequences"), + }, } } } @@ -184,3 +241,19 @@ impl serde::Serialize for BoolArrayAsTuple<'_> { seq.end() } } + +fn serialize_arrow_string( + serializer: S, + array: &arrow::array::GenericByteArray>, + array_len: usize, +) -> Result<::Ok, ::Error> +where + S: serde::Serializer, + O: OffsetSizeTrait, +{ + let mut seq = serializer.serialize_tuple(array_len)?; + for s in array.iter() { + seq.serialize_element(s.unwrap_or_default())?; + } + seq.end() +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs index ab7556d4c..0382b355d 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs @@ -145,6 +145,7 @@ impl serde::Serialize for TypedValue<'_> { &array::ArraySerializeWrapper { array_info: a, column: column.as_ref(), + type_info: self.type_info, }, )?; } @@ -154,6 +155,7 @@ impl serde::Serialize for TypedValue<'_> { &sequence::SequenceSerializeWrapper { item_type: &v.value_type, column: column.as_ref(), + type_info: self.type_info, }, )?; } @@ -163,6 +165,7 @@ impl serde::Serialize for TypedValue<'_> { &sequence::SequenceSerializeWrapper { item_type: &v.value_type, column: column.as_ref(), + type_info: self.type_info, }, )?; } diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs index fd8dd5910..3fb6bdae7 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs @@ -1,18 +1,21 @@ -use std::marker::PhantomData; +use std::{borrow::Cow, marker::PhantomData, sync::Arc}; use arrow::{ - array::{Array, ArrayRef, AsArray, PrimitiveArray}, + array::{Array, ArrayRef, AsArray, OffsetSizeTrait, PrimitiveArray}, datatypes::{self, ArrowPrimitiveType}, }; -use dora_ros2_bridge_msg_gen::types::primitives::{BasicType, NestableType}; +use dora_ros2_bridge_msg_gen::types::primitives::{BasicType, GenericString, NestableType}; use serde::ser::{SerializeSeq, SerializeTuple}; -use super::error; +use crate::typed::TypeInfo; + +use super::{error, TypedValue}; /// Serialize a variable-sized sequence. pub struct SequenceSerializeWrapper<'a> { pub item_type: &'a NestableType, pub column: &'a ArrayRef, + pub type_info: &'a TypeInfo<'a>, } impl serde::Serialize for SequenceSerializeWrapper<'_> { @@ -88,17 +91,84 @@ impl serde::Serialize for SequenceSerializeWrapper<'_> { .serialize(serializer), BasicType::Bool => BoolArray { value: &entry }.serialize(serializer), }, - NestableType::NamedType(_) => todo!("serializing NestableType::NamedType sequences"), - NestableType::NamespacedType(_) => { - todo!("serializing NestableType::NamespacedType sequences") + NestableType::NamedType(name) => { + let array = entry + .as_struct_opt() + .ok_or_else(|| error("not a struct array"))?; + let mut seq = serializer.serialize_seq(Some(array.len()))?; + for i in 0..array.len() { + let row = array.slice(i, 1); + seq.serialize_element(&TypedValue { + value: &(Arc::new(row) as ArrayRef), + type_info: &crate::typed::TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + }, + })?; + } + seq.end() } - NestableType::GenericString(_) => { - todo!("serializing NestableType::Genericstring sequences") + NestableType::NamespacedType(reference) => { + if reference.namespace != "msg" { + return Err(error(format!( + "sequence references non-message type {reference:?}" + ))); + } + + let array = entry + .as_struct_opt() + .ok_or_else(|| error("not a struct array"))?; + let mut seq = serializer.serialize_seq(Some(array.len()))?; + for i in 0..array.len() { + let row = array.slice(i, 1); + seq.serialize_element(&TypedValue { + value: &(Arc::new(row) as ArrayRef), + type_info: &crate::typed::TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + }, + })?; + } + seq.end() } + NestableType::GenericString(s) => match s { + GenericString::String | GenericString::BoundedString(_) => { + match entry.as_string_opt::() { + Some(array) => serialize_arrow_string(serializer, array), + None => { + let array = entry + .as_string_opt::() + .ok_or_else(|| error("expected string array"))?; + serialize_arrow_string(serializer, array) + } + } + } + GenericString::WString => { + todo!("serializing WString sequences") + } + GenericString::BoundedWString(_) => todo!("serializing BoundedWString sequences"), + }, } } } +fn serialize_arrow_string( + serializer: S, + array: &arrow::array::GenericByteArray>, +) -> Result<::Ok, ::Error> +where + S: serde::Serializer, + O: OffsetSizeTrait, +{ + let mut seq = serializer.serialize_seq(Some(array.len()))?; + for s in array.iter() { + seq.serialize_element(s.unwrap_or_default())?; + } + seq.end() +} + struct BasicSequence<'a, T> { value: &'a ArrayRef, ty: PhantomData, From abc057dda2ebeda8a54c4cb947dbc9eb8cd502e1 Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Sun, 21 Jan 2024 16:25:29 +0100 Subject: [PATCH 06/14] Remove commented out code --- .../python/src/typed/deserialize/mod.rs | 128 ------------------ 1 file changed, 128 deletions(-) diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs index ca74e8e57..d8b9123d3 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs @@ -155,134 +155,6 @@ impl<'a, 'de> serde::de::Visitor<'de> for StructVisitor<'a> { } } -// struct ListVisitor { -// field: Arc, -// defaults: ArrayData, -// } - -// impl<'de> serde::de::Visitor<'de> for ListVisitor { -// type Value = ArrayData; - -// fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { -// formatter.write_str("an array encoded as sequence") -// } - -// fn visit_seq(self, mut data: A) -> Result -// where -// A: serde::de::SeqAccess<'de>, -// { -// let data = match self.field.data_type().clone() { -// DataType::UInt8 => { -// let mut array = UInt8Builder::new(); -// while let Some(value) = data.next_element::()? { -// array.append_value(value); -// } -// Ok(array.finish().into()) -// } -// DataType::UInt16 => { -// let mut array = UInt16Builder::new(); -// while let Some(value) = data.next_element::()? { -// array.append_value(value); -// } -// Ok(array.finish().into()) -// } -// DataType::UInt32 => { -// let mut array = UInt32Builder::new(); -// while let Some(value) = data.next_element::()? { -// array.append_value(value); -// } -// Ok(array.finish().into()) -// } -// DataType::UInt64 => { -// let mut array = UInt64Builder::new(); -// while let Some(value) = data.next_element::()? { -// array.append_value(value); -// } -// Ok(array.finish().into()) -// } -// DataType::Int8 => { -// let mut array = Int8Builder::new(); -// while let Some(value) = data.next_element::()? { -// array.append_value(value); -// } -// Ok(array.finish().into()) -// } -// DataType::Int16 => { -// let mut array = Int16Builder::new(); -// while let Some(value) = data.next_element::()? { -// array.append_value(value); -// } -// Ok(array.finish().into()) -// } -// DataType::Int32 => { -// let mut array = Int32Builder::new(); -// while let Some(value) = data.next_element::()? { -// array.append_value(value); -// } -// Ok(array.finish().into()) -// } -// DataType::Int64 => { -// let mut array = Int64Builder::new(); -// while let Some(value) = data.next_element::()? { -// array.append_value(value); -// } -// Ok(array.finish().into()) -// } -// DataType::Float32 => { -// let mut array = Float32Builder::new(); -// while let Some(value) = data.next_element::()? { -// array.append_value(value); -// } -// Ok(array.finish().into()) -// } -// DataType::Float64 => { -// let mut array = Float64Builder::new(); -// while let Some(value) = data.next_element::()? { -// array.append_value(value); -// } -// Ok(array.finish().into()) -// } -// DataType::Utf8 => { -// let mut array = StringBuilder::new(); -// while let Some(value) = data.next_element::()? { -// array.append_value(value); -// } -// Ok(array.finish().into()) -// } -// _ => { -// let mut buffer = vec![]; -// while let Some(value) = data.next_element_seed(StructDeserializer { -// type_info: TypeInfo { -// data_type: self.field.data_type().clone(), -// defaults: self.defaults.clone(), -// }, -// })? { -// let element = make_array(value); -// buffer.push(element); -// } - -// concat( -// buffer -// .iter() -// .map(|data| data.as_ref()) -// .collect::>() -// .as_slice(), -// ) -// .map(|op| op.to_data()) -// } -// }; - -// if let Ok(values) = data { -// let offsets = OffsetBuffer::new(vec![0, values.len() as i32].into()); - -// let array = ListArray::new(self.field, offsets, make_array(values), None).to_data(); -// Ok(array) -// } else { -// Ok(self.defaults) // TODO: Better handle deserialization error -// } -// } -// } - fn error(e: T) -> E where T: Display, From e5a037afeabb9bd9882fcab07246aaa0f94e2776 Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Sun, 21 Jan 2024 16:57:16 +0100 Subject: [PATCH 07/14] Improve error messages --- libraries/arrow-convert/src/from_impls.rs | 21 ++++++++++--------- .../python/src/typed/serialize/array.rs | 6 +++--- .../python/src/typed/serialize/sequence.rs | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/libraries/arrow-convert/src/from_impls.rs b/libraries/arrow-convert/src/from_impls.rs index db484c6b3..01e8a9518 100644 --- a/libraries/arrow-convert/src/from_impls.rs +++ b/libraries/arrow-convert/src/from_impls.rs @@ -39,7 +39,7 @@ impl TryFrom<&ArrowData> for u8 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive UInt8Type array")?; extract_single_primitive(array) } } @@ -48,7 +48,7 @@ impl TryFrom<&ArrowData> for u16 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive UInt16Type array")?; extract_single_primitive(array) } } @@ -57,7 +57,7 @@ impl TryFrom<&ArrowData> for u32 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive UInt32Type array")?; extract_single_primitive(array) } } @@ -66,7 +66,7 @@ impl TryFrom<&ArrowData> for u64 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive UInt64Type array")?; extract_single_primitive(array) } } @@ -75,7 +75,7 @@ impl TryFrom<&ArrowData> for i8 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive Int8Type array")?; extract_single_primitive(array) } } @@ -84,7 +84,7 @@ impl TryFrom<&ArrowData> for i16 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive Int16Type array")?; extract_single_primitive(array) } } @@ -93,7 +93,7 @@ impl TryFrom<&ArrowData> for i32 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive Int32Type array")?; extract_single_primitive(array) } } @@ -102,7 +102,7 @@ impl TryFrom<&ArrowData> for i64 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive Int64Type array")?; extract_single_primitive(array) } } @@ -127,8 +127,9 @@ impl<'a> TryFrom<&'a ArrowData> for &'a str { impl<'a> TryFrom<&'a ArrowData> for &'a [u8] { type Error = eyre::Report; fn try_from(value: &'a ArrowData) -> Result { - let array: &PrimitiveArray = - value.as_primitive_opt().wrap_err("not a primitive array")?; + let array: &PrimitiveArray = value + .as_primitive_opt() + .wrap_err("not a primitive UInt8Type array")?; if array.null_count() != 0 { eyre::bail!("array has nulls"); } diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs index d1454fd89..3b2bf889a 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, marker::PhantomData, sync::Arc}; +use std::{any::type_name, borrow::Cow, marker::PhantomData, sync::Arc}; use arrow::{ array::{Array, ArrayRef, AsArray, OffsetSizeTrait, PrimitiveArray}, @@ -35,7 +35,7 @@ impl serde::Serialize for ArraySerializeWrapper<'_> { let list = self .column .as_list_opt::() - .ok_or_else(|| error("value is not compatible with expected Array type"))?; + .ok_or_else(|| error("value is not compatible with expected array type"))?; // should match the length of the outer struct assert_eq!(list.len(), 1); list.value(0) @@ -194,7 +194,7 @@ where let array: &PrimitiveArray = self .value .as_primitive_opt() - .ok_or_else(|| error("not a primitive array"))?; + .ok_or_else(|| error(format!("not a primitive {} array", type_name::())))?; if array.len() != self.len { return Err(error(format!( "expected array with length {}, got length {}", diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs index 3fb6bdae7..e80c7cf9c 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs @@ -186,7 +186,7 @@ where let array: &PrimitiveArray = self .value .as_primitive_opt() - .ok_or_else(|| error("not a primitive array"))?; + .ok_or_else(|| error(format!("not a primitive {} array", type_name::())))?; let mut seq = serializer.serialize_seq(Some(array.len()))?; From 63188965c58bcd3ab178cca1cf09c7e3e2cd2c54 Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Sun, 21 Jan 2024 17:03:06 +0100 Subject: [PATCH 08/14] Add support for serializing arrow `BinaryArray` types Arrow has a special `BinaryArray` type for storing lists of variable-sized binary data efficiently. This type is equivalent to a `ListArray` of `PrimitiveArray`, but it is a different data type. This commit updates the Python ROS2 serialization code to permit BinaryArray for all uint8 array or sequence fields. --- .../python/src/typed/serialize/sequence.rs | 70 +++++++++++++++---- 1 file changed, 58 insertions(+), 12 deletions(-) diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs index e80c7cf9c..d42d45fb6 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs @@ -1,8 +1,8 @@ -use std::{borrow::Cow, marker::PhantomData, sync::Arc}; +use std::{any::type_name, borrow::Cow, marker::PhantomData, sync::Arc}; use arrow::{ array::{Array, ArrayRef, AsArray, OffsetSizeTrait, PrimitiveArray}, - datatypes::{self, ArrowPrimitiveType}, + datatypes::{self, ArrowPrimitiveType, UInt8Type}, }; use dora_ros2_bridge_msg_gen::types::primitives::{BasicType, GenericString, NestableType}; use serde::ser::{SerializeSeq, SerializeTuple}; @@ -27,15 +27,23 @@ impl serde::Serialize for SequenceSerializeWrapper<'_> { // should match the length of the outer struct assert_eq!(list.len(), 1); list.value(0) - } else { - // try as large list - let list = self - .column - .as_list_opt::() - .ok_or_else(|| error("value is not compatible with expected Array type"))?; + } else if let Some(list) = self.column.as_list_opt::() { // should match the length of the outer struct assert_eq!(list.len(), 1); list.value(0) + } else if let Some(list) = self.column.as_binary_opt::() { + // should match the length of the outer struct + assert_eq!(list.len(), 1); + Arc::new(list.slice(0, 1)) as ArrayRef + } else if let Some(list) = self.column.as_binary_opt::() { + // should match the length of the outer struct + assert_eq!(list.len(), 1); + Arc::new(list.slice(0, 1)) as ArrayRef + } else { + return Err(error(format!( + "value is not compatible with expected sequence type: {:?}", + self.column + ))); }; match &self.item_type { NestableType::BasicType(t) => match t { @@ -59,11 +67,9 @@ impl serde::Serialize for SequenceSerializeWrapper<'_> { ty: PhantomData::, } .serialize(serializer), - BasicType::U8 | BasicType::Char | BasicType::Byte => BasicSequence { - value: &entry, - ty: PhantomData::, + BasicType::U8 | BasicType::Char | BasicType::Byte => { + ByteSequence { value: &entry }.serialize(serializer) } - .serialize(serializer), BasicType::U16 => BasicSequence { value: &entry, ty: PhantomData::, @@ -198,6 +204,46 @@ where } } +struct ByteSequence<'a> { + value: &'a ArrayRef, +} + +impl serde::Serialize for ByteSequence<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + if let Some(binary) = self.value.as_binary_opt::() { + serialize_binary(serializer, binary) + } else if let Some(binary) = self.value.as_binary_opt::() { + serialize_binary(serializer, binary) + } else { + BasicSequence { + value: self.value, + ty: PhantomData::, + } + .serialize(serializer) + } + } +} + +fn serialize_binary( + serializer: S, + binary: &arrow::array::GenericByteArray>, +) -> Result<::Ok, ::Error> +where + S: serde::Serializer, + O: OffsetSizeTrait, +{ + let mut seq = serializer.serialize_seq(Some(binary.len()))?; + + for value in binary.iter() { + seq.serialize_element(value.unwrap_or_default())?; + } + + seq.end() +} + struct BoolArray<'a> { value: &'a ArrayRef, } From 945b3f887b04475bed42a3feef45534833880fcc Mon Sep 17 00:00:00 2001 From: haixuanTao Date: Tue, 23 Jan 2024 16:28:59 +0100 Subject: [PATCH 09/14] Changing `Struct` to `TupleStruct` --- .../python/src/typed/deserialize/mod.rs | 2 +- .../python/src/typed/serialize/mod.rs | 60 ++++++++----------- 2 files changed, 25 insertions(+), 37 deletions(-) diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs index d8b9123d3..13288aa5d 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs @@ -59,7 +59,7 @@ impl<'a, 'de> serde::de::Visitor<'de> for StructVisitor<'a> { type Value = ArrayData; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a struct encoded as sequence") + formatter.write_str("a struct encoded as TupleStruct") } fn visit_seq(self, mut data: A) -> Result diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs index 0382b355d..7f38c36c6 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs @@ -5,7 +5,7 @@ use dora_ros2_bridge_msg_gen::types::{ primitives::{GenericString, NestableType}, MemberType, }; -use serde::ser::SerializeStruct; +use serde::ser::{SerializeStruct, SerializeTupleStruct}; use super::{TypeInfo, DUMMY_FIELD_NAME, DUMMY_STRUCT_NAME}; @@ -61,7 +61,7 @@ impl serde::Serialize for TypedValue<'_> { input.len() )))?; } - let mut s = serializer.serialize_struct(DUMMY_STRUCT_NAME, message.members.len())?; + let mut s = serializer.serialize_tuple_struct(DUMMY_STRUCT_NAME, message.members.len())?; for field in message.members.iter() { let column: Cow<_> = match input.column_by_name(&field.name) { Some(input) => Cow::Borrowed(input), @@ -79,13 +79,10 @@ impl serde::Serialize for TypedValue<'_> { match &field.r#type { MemberType::NestableType(t) => match t { NestableType::BasicType(t) => { - s.serialize_field( - DUMMY_FIELD_NAME, - &primitive::SerializeWrapper { - t, - column: column.as_ref(), - }, - )?; + s.serialize_field(&primitive::SerializeWrapper { + t, + column: column.as_ref(), + })?; } NestableType::NamedType(name) => { let referenced_value = &TypedValue { @@ -96,7 +93,7 @@ impl serde::Serialize for TypedValue<'_> { messages: self.type_info.messages.clone(), }, }; - s.serialize_field(DUMMY_FIELD_NAME, &referenced_value)?; + s.serialize_field(&referenced_value)?; } NestableType::NamespacedType(reference) => { if reference.namespace != "msg" { @@ -114,7 +111,7 @@ impl serde::Serialize for TypedValue<'_> { messages: self.type_info.messages.clone(), }, }; - s.serialize_field(DUMMY_FIELD_NAME, &referenced_value)?; + s.serialize_field(&referenced_value)?; } NestableType::GenericString(t) => match t { GenericString::String | GenericString::BoundedString(_) => { @@ -131,7 +128,7 @@ impl serde::Serialize for TypedValue<'_> { assert_eq!(string_array.len(), 1); string_array.value(0) }; - s.serialize_field(DUMMY_FIELD_NAME, string)?; + s.serialize_field(string)?; } GenericString::WString => todo!("serializing WString types"), GenericString::BoundedWString(_) => { @@ -140,34 +137,25 @@ impl serde::Serialize for TypedValue<'_> { }, }, dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => { - s.serialize_field( - DUMMY_FIELD_NAME, - &array::ArraySerializeWrapper { - array_info: a, - column: column.as_ref(), - type_info: self.type_info, - }, - )?; + s.serialize_field(&array::ArraySerializeWrapper { + array_info: a, + column: column.as_ref(), + type_info: self.type_info, + })?; } dora_ros2_bridge_msg_gen::types::MemberType::Sequence(v) => { - s.serialize_field( - DUMMY_FIELD_NAME, - &sequence::SequenceSerializeWrapper { - item_type: &v.value_type, - column: column.as_ref(), - type_info: self.type_info, - }, - )?; + s.serialize_field(&sequence::SequenceSerializeWrapper { + item_type: &v.value_type, + column: column.as_ref(), + type_info: self.type_info, + })?; } dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(v) => { - s.serialize_field( - DUMMY_FIELD_NAME, - &sequence::SequenceSerializeWrapper { - item_type: &v.value_type, - column: column.as_ref(), - type_info: self.type_info, - }, - )?; + s.serialize_field(&sequence::SequenceSerializeWrapper { + item_type: &v.value_type, + column: column.as_ref(), + type_info: self.type_info, + })?; } } } From 4b7c43c3230e22552e0b28700d081d5032eb0375 Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Wed, 24 Jan 2024 12:24:40 +0100 Subject: [PATCH 10/14] Fix some unused code and import warnings --- libraries/extensions/ros2-bridge/python/src/typed/mod.rs | 1 - .../extensions/ros2-bridge/python/src/typed/serialize/mod.rs | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/libraries/extensions/ros2-bridge/python/src/typed/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/mod.rs index b8b76893f..2b91d08b1 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/mod.rs @@ -22,4 +22,3 @@ pub struct TypeInfo<'a> { /// the CDR format of ROS2 does not encode struct or field /// names. const DUMMY_STRUCT_NAME: &str = "struct"; -const DUMMY_FIELD_NAME: &str = "field"; diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs index 7f38c36c6..ce9480db1 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs @@ -5,9 +5,9 @@ use dora_ros2_bridge_msg_gen::types::{ primitives::{GenericString, NestableType}, MemberType, }; -use serde::ser::{SerializeStruct, SerializeTupleStruct}; +use serde::ser::SerializeTupleStruct; -use super::{TypeInfo, DUMMY_FIELD_NAME, DUMMY_STRUCT_NAME}; +use super::{TypeInfo, DUMMY_STRUCT_NAME}; mod array; mod defaults; From efb6a51c2246699ef46fa959919bd12308bb4d0b Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Wed, 24 Jan 2024 12:26:02 +0100 Subject: [PATCH 11/14] Fix: Wrap deserialized sequences into list of length 1 Required because arrow uses column-oriented data format, which requires all struct fields to have length 1. --- .../python/src/typed/deserialize/mod.rs | 1 - .../python/src/typed/deserialize/sequence.rs | 38 ++++++++++++++----- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs index 13288aa5d..db9249d1d 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs @@ -143,7 +143,6 @@ impl<'a, 'de> serde::de::Visitor<'de> for StructVisitor<'a> { })?; fields.push(( - // Recreate a new field as List(UInt8) can be converted to UInt8 Arc::new(Field::new(&member.name, value.data_type().clone(), true)), make_array(value), )); diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs index ee918ef35..ecf74de27 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs @@ -1,11 +1,14 @@ use arrow::{ - array::{ArrayData, BooleanBuilder, PrimitiveBuilder, StringBuilder}, - datatypes::{self, ArrowPrimitiveType}, + array::{ + Array, ArrayData, BooleanBuilder, ListArray, ListBuilder, PrimitiveBuilder, StringBuilder, + }, + buffer::OffsetBuffer, + datatypes::{self, ArrowPrimitiveType, Field}, }; use core::fmt; use dora_ros2_bridge_msg_gen::types::primitives::{self, BasicType, NestableType}; use serde::Deserialize; -use std::{borrow::Cow, ops::Deref}; +use std::{borrow::Cow, ops::Deref, sync::Arc}; use crate::typed::TypeInfo; @@ -65,7 +68,10 @@ impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> { while let Some(value) = seq.next_element()? { array.append_value(value); } - Ok(array.finish().into()) + // wrap array into list of length 1 + let mut list = ListBuilder::new(array); + list.append(true); + Ok(list.finish().into()) } }, NestableType::NamedType(name) => { @@ -99,7 +105,10 @@ impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> { while let Some(value) = seq.next_element::()? { array.append_value(value); } - Ok(array.finish().into()) + // wrap array into list of length 1 + let mut list = ListBuilder::new(array); + list.append(true); + Ok(list.finish().into()) } primitives::GenericString::WString => todo!("deserialize sequence of WString"), primitives::GenericString::BoundedWString(_) => { @@ -122,9 +131,17 @@ where values.push(arrow::array::make_array(value)); } let refs: Vec<_> = values.iter().map(|a| a.deref()).collect(); - arrow::compute::concat(&refs) - .map(|a| a.to_data()) - .map_err(super::error) + let concatenated = arrow::compute::concat(&refs).map_err(super::error)?; + + let list = ListArray::try_new( + Arc::new(Field::new("item", concatenated.data_type().clone(), true)), + OffsetBuffer::from_lengths([concatenated.len()]), + Arc::new(concatenated), + None, + ) + .map_err(error)?; + + Ok(list.to_data()) } fn deserialize_primitive_seq<'de, S, T>( @@ -139,5 +156,8 @@ where while let Some(value) = seq.next_element::()? { array.append_value(value); } - Ok(array.finish().into()) + // wrap array into list of length 1 + let mut list = ListBuilder::new(array); + list.append(true); + Ok(list.finish().into()) } From 9eecf22fee5b081eaf559032b962eb751b687635 Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Wed, 24 Jan 2024 12:26:19 +0100 Subject: [PATCH 12/14] Rename method --- .../ros2-bridge/python/src/typed/deserialize/sequence.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs index ecf74de27..a55921968 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs @@ -82,7 +82,7 @@ impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> { messages: self.type_info.messages.clone(), }), }; - deserialize_struct_array(&mut seq, deserializer) + deserialize_struct_seq(&mut seq, deserializer) } NestableType::NamespacedType(reference) => { if reference.namespace != "msg" { @@ -97,7 +97,7 @@ impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> { messages: self.type_info.messages.clone(), }), }; - deserialize_struct_array(&mut seq, deserializer) + deserialize_struct_seq(&mut seq, deserializer) } NestableType::GenericString(t) => match t { primitives::GenericString::String | primitives::GenericString::BoundedString(_) => { @@ -119,7 +119,7 @@ impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> { } } -fn deserialize_struct_array<'de, A>( +fn deserialize_struct_seq<'de, A>( seq: &mut A, deserializer: StructDeserializer<'_>, ) -> Result>::Error> From 1e510a4236fdbc223659028aabf3557f3e283ce7 Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Thu, 25 Jan 2024 13:10:16 +0100 Subject: [PATCH 13/14] Improve error message: print message type and value Co-authored-by: Haixuan Xavier Tao --- .../ros2-bridge/python/src/typed/serialize/mod.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs index ce9480db1..f4cee1b7d 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs @@ -40,10 +40,12 @@ impl serde::Serialize for TypedValue<'_> { )) })?; - let input = self - .value - .as_struct_opt() - .ok_or_else(|| error("expected struct array"))?; + let input = self.value.as_struct_opt().ok_or_else(|| { + error(format!( + "expected struct array for message: {}, with following format: {:#?} \n But, got value: {:#?}", + self.type_info.message_name, message, self.value + )) + })?; for column_name in input.column_names() { if !message.members.iter().any(|m| m.name == column_name) { return Err(error(format!( From 1addfa385a4815a32e817423f99a933f403a4fc1 Mon Sep 17 00:00:00 2001 From: Philipp Oppermann Date: Thu, 25 Jan 2024 16:22:45 +0100 Subject: [PATCH 14/14] Add more context to error messages when serializing nested message --- .../python/src/typed/serialize/mod.rs | 191 ++++++++++-------- 1 file changed, 111 insertions(+), 80 deletions(-) diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs index f4cee1b7d..8420f14f0 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs @@ -1,10 +1,14 @@ use std::{borrow::Cow, collections::HashMap, fmt::Display}; -use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::{ + array::{Array, ArrayRef, AsArray}, + error, +}; use dora_ros2_bridge_msg_gen::types::{ primitives::{GenericString, NestableType}, MemberType, }; +use eyre::Context; use serde::ser::SerializeTupleStruct; use super::{TypeInfo, DUMMY_STRUCT_NAME}; @@ -42,8 +46,8 @@ impl serde::Serialize for TypedValue<'_> { let input = self.value.as_struct_opt().ok_or_else(|| { error(format!( - "expected struct array for message: {}, with following format: {:#?} \n But, got value: {:#?}", - self.type_info.message_name, message, self.value + "expected struct array for message: {}, with following format: {:#?} \n But, got type: {:#?}", + self.type_info.message_name, message, self.value.data_type() )) })?; for column_name in input.column_names() { @@ -73,95 +77,122 @@ impl serde::Serialize for TypedValue<'_> { &self.type_info.package_name, &self.type_info.messages, ) - .map_err(error)?; + .with_context(|| { + format!( + "failed to calculate default value for field {}.{}", + message.name, field.name + ) + }) + .map_err(|e| error(format!("{e:?}")))?; Cow::Owned(arrow::array::make_array(default)) } }; - match &field.r#type { - MemberType::NestableType(t) => match t { - NestableType::BasicType(t) => { - s.serialize_field(&primitive::SerializeWrapper { - t, - column: column.as_ref(), - })?; - } - NestableType::NamedType(name) => { - let referenced_value = &TypedValue { - value: column.as_ref(), - type_info: &TypeInfo { - package_name: Cow::Borrowed(&self.type_info.package_name), - message_name: Cow::Borrowed(&name.0), - messages: self.type_info.messages.clone(), - }, - }; - s.serialize_field(&referenced_value)?; - } - NestableType::NamespacedType(reference) => { - if reference.namespace != "msg" { - return Err(error(format!( - "struct field {} references non-message type {reference:?}", - field.name - ))); - } + self.serialize_field::(field, column, &mut s) + .map_err(|e| { + error(format!( + "failed to serialize field {}.{}: {e}", + message.name, field.name + )) + })?; + } + s.end() + } +} - let referenced_value: &TypedValue<'_> = &TypedValue { - value: column.as_ref(), - type_info: &TypeInfo { - package_name: Cow::Borrowed(&reference.package), - message_name: Cow::Borrowed(&reference.name), - messages: self.type_info.messages.clone(), - }, - }; - s.serialize_field(&referenced_value)?; - } - NestableType::GenericString(t) => match t { - GenericString::String | GenericString::BoundedString(_) => { - let string = if let Some(string_array) = column.as_string_opt::() { - // should match the length of the outer struct array - assert_eq!(string_array.len(), 1); - string_array.value(0) - } else { - // try again with large offset type - let string_array = column - .as_string_opt::() - .ok_or_else(|| error("expected string array"))?; - // should match the length of the outer struct array - assert_eq!(string_array.len(), 1); - string_array.value(0) - }; - s.serialize_field(string)?; - } - GenericString::WString => todo!("serializing WString types"), - GenericString::BoundedWString(_) => { - todo!("serializing BoundedWString types") - } - }, - }, - dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => { - s.serialize_field(&array::ArraySerializeWrapper { - array_info: a, +impl<'a> TypedValue<'a> { + fn serialize_field( + &self, + field: &dora_ros2_bridge_msg_gen::types::Member, + column: Cow<'_, std::sync::Arc>, + s: &mut S::SerializeTupleStruct, + ) -> Result<(), S::Error> + where + S: serde::Serializer, + { + match &field.r#type { + MemberType::NestableType(t) => match t { + NestableType::BasicType(t) => { + s.serialize_field(&primitive::SerializeWrapper { + t, column: column.as_ref(), - type_info: self.type_info, })?; } - dora_ros2_bridge_msg_gen::types::MemberType::Sequence(v) => { - s.serialize_field(&sequence::SequenceSerializeWrapper { - item_type: &v.value_type, - column: column.as_ref(), - type_info: self.type_info, - })?; + NestableType::NamedType(name) => { + let referenced_value = &TypedValue { + value: column.as_ref(), + type_info: &TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + }, + }; + s.serialize_field(&referenced_value)?; } - dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(v) => { - s.serialize_field(&sequence::SequenceSerializeWrapper { - item_type: &v.value_type, - column: column.as_ref(), - type_info: self.type_info, - })?; + NestableType::NamespacedType(reference) => { + if reference.namespace != "msg" { + return Err(error(format!( + "struct field {} references non-message type {reference:?}", + field.name + ))); + } + + let referenced_value: &TypedValue<'_> = &TypedValue { + value: column.as_ref(), + type_info: &TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + }, + }; + s.serialize_field(&referenced_value)?; } + NestableType::GenericString(t) => match t { + GenericString::String | GenericString::BoundedString(_) => { + let string = if let Some(string_array) = column.as_string_opt::() { + // should match the length of the outer struct array + assert_eq!(string_array.len(), 1); + string_array.value(0) + } else { + // try again with large offset type + let string_array = column + .as_string_opt::() + .ok_or_else(|| error("expected string array"))?; + // should match the length of the outer struct array + assert_eq!(string_array.len(), 1); + string_array.value(0) + }; + s.serialize_field(string)?; + } + GenericString::WString => todo!("serializing WString types"), + GenericString::BoundedWString(_) => { + todo!("serializing BoundedWString types") + } + }, + }, + dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => { + s.serialize_field(&array::ArraySerializeWrapper { + array_info: a, + column: column.as_ref(), + type_info: self.type_info, + })?; + } + dora_ros2_bridge_msg_gen::types::MemberType::Sequence(v) => { + s.serialize_field(&sequence::SequenceSerializeWrapper { + item_type: &v.value_type, + column: column.as_ref(), + type_info: self.type_info, + })?; + } + dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(v) => { + s.serialize_field(&sequence::SequenceSerializeWrapper { + item_type: &v.value_type, + column: column.as_ref(), + type_info: self.type_info, + })?; } } - s.end() + Ok(()) } }