diff --git a/binaries/runtime/src/operator/mod.rs b/binaries/runtime/src/operator/mod.rs index 474a54c57..b78fb939b 100644 --- a/binaries/runtime/src/operator/mod.rs +++ b/binaries/runtime/src/operator/mod.rs @@ -13,6 +13,7 @@ pub mod channel; mod python; mod shared_lib; +#[allow(unused_variables)] pub fn run_operator( node_id: &NodeId, operator_definition: OperatorDefinition, diff --git a/libraries/arrow-convert/src/from_impls.rs b/libraries/arrow-convert/src/from_impls.rs index db484c6b3..01e8a9518 100644 --- a/libraries/arrow-convert/src/from_impls.rs +++ b/libraries/arrow-convert/src/from_impls.rs @@ -39,7 +39,7 @@ impl TryFrom<&ArrowData> for u8 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive UInt8Type array")?; extract_single_primitive(array) } } @@ -48,7 +48,7 @@ impl TryFrom<&ArrowData> for u16 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive UInt16Type array")?; extract_single_primitive(array) } } @@ -57,7 +57,7 @@ impl TryFrom<&ArrowData> for u32 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive UInt32Type array")?; extract_single_primitive(array) } } @@ -66,7 +66,7 @@ impl TryFrom<&ArrowData> for u64 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive UInt64Type array")?; extract_single_primitive(array) } } @@ -75,7 +75,7 @@ impl TryFrom<&ArrowData> for i8 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive Int8Type array")?; extract_single_primitive(array) } } @@ -84,7 +84,7 @@ impl TryFrom<&ArrowData> for i16 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive Int16Type array")?; extract_single_primitive(array) } } @@ -93,7 +93,7 @@ impl TryFrom<&ArrowData> for i32 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive Int32Type array")?; extract_single_primitive(array) } } @@ -102,7 +102,7 @@ impl TryFrom<&ArrowData> for i64 { fn try_from(value: &ArrowData) -> Result { let array = value .as_primitive_opt::() - .context("not a primitive array")?; + .context("not a primitive Int64Type array")?; extract_single_primitive(array) } } @@ -127,8 +127,9 @@ impl<'a> TryFrom<&'a ArrowData> for &'a str { impl<'a> TryFrom<&'a ArrowData> for &'a [u8] { type Error = eyre::Report; fn try_from(value: &'a ArrowData) -> Result { - let array: &PrimitiveArray = - value.as_primitive_opt().wrap_err("not a primitive array")?; + let array: &PrimitiveArray = value + .as_primitive_opt() + .wrap_err("not a primitive UInt8Type array")?; if array.null_count() != 0 { eyre::bail!("array has nulls"); } diff --git a/libraries/extensions/ros2-bridge/python/src/lib.rs b/libraries/extensions/ros2-bridge/python/src/lib.rs index 37ac68a6b..9d06bbaf5 100644 --- a/libraries/extensions/ros2-bridge/python/src/lib.rs +++ b/libraries/extensions/ros2-bridge/python/src/lib.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, collections::HashMap, path::{Path, PathBuf}, sync::Arc, @@ -6,7 +7,7 @@ use std::{ use ::dora_ros2_bridge::{ros2_client, rustdds}; use arrow::{ - array::ArrayData, + array::{make_array, ArrayData}, pyarrow::{FromPyArrow, ToPyArrow}, }; use dora_ros2_bridge_msg_gen::types::Message; @@ -17,7 +18,7 @@ use pyo3::{ types::{PyDict, PyList, PyModule}, PyAny, PyObject, PyResult, Python, }; -use typed::{deserialize::TypedDeserializer, for_message, TypeInfo, TypedValue}; +use typed::{deserialize::StructDeserializer, TypeInfo, TypedValue}; pub mod qos; pub mod typed; @@ -52,6 +53,7 @@ impl Ros2Context { ament_prefix_path_parsed.split(':').map(Path::new).collect() } }; + let packages = dora_ros2_bridge_msg_gen::get_packages(&paths) .map_err(|err| eyre!(err)) .context("failed to parse ROS2 message types")?; @@ -111,10 +113,11 @@ impl Ros2Node { let topic = self .node .create_topic(&topic_name, message_type_name, &qos.into())?; - let type_info = - for_message(&self.messages, namespace_name, message_name).with_context(|| { - format!("failed to determine type info for message {namespace_name}/{message_name}") - })?; + let type_info = TypeInfo { + package_name: namespace_name.to_owned().into(), + message_name: message_name.to_owned().into(), + messages: self.messages.clone(), + }; Ok(Ros2Topic { topic, type_info }) } @@ -143,7 +146,7 @@ impl Ros2Node { .create_subscription(&topic.topic, qos.map(Into::into))?; Ok(Ros2Subscription { subscription: Some(subscription), - deserializer: TypedDeserializer::new(topic.type_info.clone()), + deserializer: StructDeserializer::new(Cow::Owned(topic.type_info.clone())), }) } } @@ -175,14 +178,14 @@ impl From for ros2_client::NodeOptions { #[non_exhaustive] pub struct Ros2Topic { topic: rustdds::Topic, - type_info: TypeInfo, + type_info: TypeInfo<'static>, } #[pyclass] #[non_exhaustive] pub struct Ros2Publisher { publisher: ros2_client::Publisher>, - type_info: TypeInfo, + type_info: TypeInfo<'static>, } #[pymethods] @@ -209,7 +212,7 @@ impl Ros2Publisher { //// add type info to ensure correct serialization (e.g. struct types //// and map types need to be serialized differently) let typed_value = TypedValue { - value: &value, + value: &make_array(value), type_info: &self.type_info, }; @@ -224,7 +227,7 @@ impl Ros2Publisher { #[pyclass] #[non_exhaustive] pub struct Ros2Subscription { - deserializer: TypedDeserializer, + deserializer: StructDeserializer<'static>, subscription: Option>, } @@ -263,7 +266,7 @@ impl Ros2Subscription { } pub struct Ros2SubscriptionStream { - deserializer: TypedDeserializer, + deserializer: StructDeserializer<'static>, subscription: ros2_client::Subscription, } diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize.rs deleted file mode 100644 index 30ba81fd7..000000000 --- a/libraries/extensions/ros2-bridge/python/src/typed/deserialize.rs +++ /dev/null @@ -1,397 +0,0 @@ -use super::TypeInfo; -use arrow::{ - array::{ - make_array, Array, ArrayData, BooleanBuilder, Float32Builder, Float64Builder, Int16Builder, - Int32Builder, Int64Builder, Int8Builder, ListArray, NullArray, StringBuilder, StructArray, - UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, - }, - buffer::OffsetBuffer, - compute::concat, - datatypes::{DataType, Field, Fields}, -}; -use core::fmt; -use std::sync::Arc; -#[derive(Debug, Clone, PartialEq)] -pub struct TypedDeserializer { - type_info: TypeInfo, -} - -impl TypedDeserializer { - pub fn new(type_info: TypeInfo) -> Self { - Self { type_info } - } -} - -impl<'de> serde::de::DeserializeSeed<'de> for TypedDeserializer { - type Value = ArrayData; - - fn deserialize(self, deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let data_type = self.type_info.data_type; - let value = match data_type.clone() { - DataType::Struct(fields) => { - /// Serde requires that struct and field names are known at - /// compile time with a `'static` lifetime, which is not - /// possible in this case. Thus, we need to use dummy names - /// instead. - /// - /// The actual names do not really matter because - /// the CDR format of ROS2 does not encode struct or field - /// names. - const DUMMY_STRUCT_NAME: &str = "struct"; - const DUMMY_FIELDS: &[&str] = &[""; 100]; - - deserializer.deserialize_struct( - DUMMY_STRUCT_NAME, - &DUMMY_FIELDS[..fields.len()], - StructVisitor { - fields, - defaults: self.type_info.defaults, - }, - ) - } - DataType::List(field) => deserializer.deserialize_seq(ListVisitor { - field, - defaults: self.type_info.defaults, - }), - DataType::UInt8 => deserializer.deserialize_u8(PrimitiveValueVisitor), - DataType::UInt16 => deserializer.deserialize_u16(PrimitiveValueVisitor), - DataType::UInt32 => deserializer.deserialize_u32(PrimitiveValueVisitor), - DataType::UInt64 => deserializer.deserialize_u64(PrimitiveValueVisitor), - DataType::Int8 => deserializer.deserialize_i8(PrimitiveValueVisitor), - DataType::Int16 => deserializer.deserialize_i16(PrimitiveValueVisitor), - DataType::Int32 => deserializer.deserialize_i32(PrimitiveValueVisitor), - DataType::Int64 => deserializer.deserialize_i64(PrimitiveValueVisitor), - DataType::Float32 => deserializer.deserialize_f32(PrimitiveValueVisitor), - DataType::Float64 => deserializer.deserialize_f64(PrimitiveValueVisitor), - DataType::Utf8 => deserializer.deserialize_str(PrimitiveValueVisitor), - _ => todo!(), - }?; - - debug_assert!( - value.data_type() == &data_type, - "Datatype does not correspond to default data type.\n Expected: {:#?} \n but got: {:#?}, with value: {:#?}", data_type, value.data_type(), value - ); - - Ok(value) - } -} - -/// Based on https://docs.rs/serde_yaml/0.9.22/src/serde_yaml/value/de.rs.html#14-121 -struct PrimitiveValueVisitor; - -impl<'de> serde::de::Visitor<'de> for PrimitiveValueVisitor { - type Value = ArrayData; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a primitive value") - } - - fn visit_bool(self, b: bool) -> Result - where - E: serde::de::Error, - { - let mut array = BooleanBuilder::new(); - array.append_value(b); - Ok(array.finish().into()) - } - - fn visit_i8(self, u: i8) -> Result - where - E: serde::de::Error, - { - let mut array = Int8Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - - fn visit_i16(self, u: i16) -> Result - where - E: serde::de::Error, - { - let mut array = Int16Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - fn visit_i32(self, u: i32) -> Result - where - E: serde::de::Error, - { - let mut array = Int32Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - fn visit_i64(self, i: i64) -> Result - where - E: serde::de::Error, - { - let mut array = Int64Builder::new(); - array.append_value(i); - Ok(array.finish().into()) - } - - fn visit_u8(self, u: u8) -> Result - where - E: serde::de::Error, - { - let mut array = UInt8Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - fn visit_u16(self, u: u16) -> Result - where - E: serde::de::Error, - { - let mut array = UInt16Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - fn visit_u32(self, u: u32) -> Result - where - E: serde::de::Error, - { - let mut array = UInt32Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - fn visit_u64(self, u: u64) -> Result - where - E: serde::de::Error, - { - let mut array = UInt64Builder::new(); - array.append_value(u); - Ok(array.finish().into()) - } - - fn visit_f32(self, f: f32) -> Result - where - E: serde::de::Error, - { - let mut array = Float32Builder::new(); - array.append_value(f); - Ok(array.finish().into()) - } - - fn visit_f64(self, f: f64) -> Result - where - E: serde::de::Error, - { - let mut array = Float64Builder::new(); - array.append_value(f); - Ok(array.finish().into()) - } - - fn visit_str(self, s: &str) -> Result - where - E: serde::de::Error, - { - let mut array = StringBuilder::new(); - array.append_value(s); - Ok(array.finish().into()) - } - - fn visit_string(self, s: String) -> Result - where - E: serde::de::Error, - { - let mut array = StringBuilder::new(); - array.append_value(s); - Ok(array.finish().into()) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - let array = NullArray::new(0); - Ok(array.into()) - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - let array = NullArray::new(0); - Ok(array.into()) - } -} - -struct StructVisitor { - fields: Fields, - defaults: ArrayData, -} - -impl<'de> serde::de::Visitor<'de> for StructVisitor { - type Value = ArrayData; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a struct encoded as sequence") - } - - fn visit_seq(self, mut data: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut fields = vec![]; - let defaults: StructArray = self.defaults.clone().into(); - for field in self.fields.iter() { - let default = match defaults.column_by_name(field.name()) { - Some(value) => value.clone(), - None => { - return Err(serde::de::Error::custom(format!( - "missing field {} for deserialization", - &field.name() - ))) - } - }; - let value = match data.next_element_seed(TypedDeserializer { - type_info: TypeInfo { - data_type: field.data_type().clone(), - defaults: default.to_data(), - }, - })? { - Some(value) => make_array(value), - None => default, - }; - fields.push(( - // Recreate a new field as List(UInt8) can be converted to UInt8 - Arc::new(Field::new(field.name(), value.data_type().clone(), true)), - value, - )); - } - - let struct_array: StructArray = fields.into(); - - Ok(struct_array.into()) - } -} - -struct ListVisitor { - field: Arc, - defaults: ArrayData, -} - -impl<'de> serde::de::Visitor<'de> for ListVisitor { - type Value = ArrayData; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("an array encoded as sequence") - } - - fn visit_seq(self, mut data: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let data = match self.field.data_type().clone() { - DataType::UInt8 => { - let mut array = UInt8Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::UInt16 => { - let mut array = UInt16Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::UInt32 => { - let mut array = UInt32Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::UInt64 => { - let mut array = UInt64Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Int8 => { - let mut array = Int8Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Int16 => { - let mut array = Int16Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Int32 => { - let mut array = Int32Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Int64 => { - let mut array = Int64Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Float32 => { - let mut array = Float32Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Float64 => { - let mut array = Float64Builder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - DataType::Utf8 => { - let mut array = StringBuilder::new(); - while let Some(value) = data.next_element::()? { - array.append_value(value); - } - Ok(array.finish().into()) - } - _ => { - let mut buffer = vec![]; - while let Some(value) = data.next_element_seed(TypedDeserializer { - type_info: TypeInfo { - data_type: self.field.data_type().clone(), - defaults: self.defaults.clone(), - }, - })? { - let element = make_array(value); - buffer.push(element); - } - - concat( - buffer - .iter() - .map(|data| data.as_ref()) - .collect::>() - .as_slice(), - ) - .map(|op| op.to_data()) - } - }; - - if let Ok(values) = data { - let offsets = OffsetBuffer::new(vec![0, values.len() as i32].into()); - - let array = ListArray::new(self.field, offsets, make_array(values), None).to_data(); - Ok(array) - } else { - Ok(self.defaults) // TODO: Better handle deserialization error - } - } -} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/array.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/array.rs new file mode 100644 index 000000000..170092dc3 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/array.rs @@ -0,0 +1,28 @@ +use arrow::array::ArrayData; +use dora_ros2_bridge_msg_gen::types::sequences; + +use crate::typed::TypeInfo; + +use super::sequence::SequenceVisitor; + +pub struct ArrayDeserializer<'a> { + pub array_type: &'a sequences::Array, + pub type_info: &'a TypeInfo<'a>, +} + +impl<'de> serde::de::DeserializeSeed<'de> for ArrayDeserializer<'_> { + type Value = ArrayData; + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_tuple( + self.array_type.size, + SequenceVisitor { + item_type: &self.array_type.value_type, + type_info: self.type_info, + }, + ) + } +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs new file mode 100644 index 000000000..db9249d1d --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/mod.rs @@ -0,0 +1,163 @@ +use super::{TypeInfo, DUMMY_STRUCT_NAME}; +use arrow::{ + array::{make_array, ArrayData, StructArray}, + datatypes::Field, +}; +use core::fmt; +use std::{borrow::Cow, collections::HashMap, fmt::Display, sync::Arc}; + +mod array; +mod primitive; +mod sequence; +mod string; + +#[derive(Debug, Clone)] +pub struct StructDeserializer<'a> { + type_info: Cow<'a, TypeInfo<'a>>, +} + +impl<'a> StructDeserializer<'a> { + pub fn new(type_info: Cow<'a, TypeInfo<'a>>) -> Self { + Self { type_info } + } +} + +impl<'de> serde::de::DeserializeSeed<'de> for StructDeserializer<'_> { + type Value = ArrayData; + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let empty = HashMap::new(); + let package_messages = self + .type_info + .messages + .get(self.type_info.package_name.as_ref()) + .unwrap_or(&empty); + let message = package_messages + .get(self.type_info.message_name.as_ref()) + .ok_or_else(|| { + error(format!( + "could not find message type {}::{}", + self.type_info.package_name, self.type_info.message_name + )) + })?; + + let visitor = StructVisitor { + type_info: self.type_info.as_ref(), + }; + deserializer.deserialize_tuple_struct(DUMMY_STRUCT_NAME, message.members.len(), visitor) + } +} + +struct StructVisitor<'a> { + type_info: &'a TypeInfo<'a>, +} + +impl<'a, 'de> serde::de::Visitor<'de> for StructVisitor<'a> { + type Value = ArrayData; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a struct encoded as TupleStruct") + } + + fn visit_seq(self, mut data: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let empty = HashMap::new(); + let package_messages = self + .type_info + .messages + .get(self.type_info.package_name.as_ref()) + .unwrap_or(&empty); + let message = package_messages + .get(self.type_info.message_name.as_ref()) + .ok_or_else(|| { + error(format!( + "could not find message type {}::{}", + self.type_info.package_name, self.type_info.message_name + )) + })?; + + let mut fields = vec![]; + for member in &message.members { + let value = match &member.r#type { + dora_ros2_bridge_msg_gen::types::MemberType::NestableType(t) => match t { + dora_ros2_bridge_msg_gen::types::primitives::NestableType::BasicType(t) => { + data.next_element_seed(primitive::PrimitiveDeserializer(t))? + } + dora_ros2_bridge_msg_gen::types::primitives::NestableType::NamedType(name) => { + data.next_element_seed(StructDeserializer { + type_info: Cow::Owned(TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + }), + })? + } + dora_ros2_bridge_msg_gen::types::primitives::NestableType::NamespacedType( + reference, + ) => { + if reference.namespace != "msg" { + return Err(error(format!( + "struct field {} references non-message type {reference:?}", + member.name + ))); + } + data.next_element_seed(StructDeserializer { + type_info: Cow::Owned(TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + }), + })? + } + dora_ros2_bridge_msg_gen::types::primitives::NestableType::GenericString(t) => { + match t { + dora_ros2_bridge_msg_gen::types::primitives::GenericString::String | dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedString(_)=> { + data.next_element_seed(string::StringDeserializer)? + }, + dora_ros2_bridge_msg_gen::types::primitives::GenericString::WString => todo!("deserialize WString"), + dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedWString(_) => todo!("deserialize BoundedWString"), + } + } + }, + dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => { + data.next_element_seed(array::ArrayDeserializer{ array_type : a, type_info: self.type_info})? + }, + dora_ros2_bridge_msg_gen::types::MemberType::Sequence(s) => { + data.next_element_seed(sequence::SequenceDeserializer{item_type: &s.value_type, type_info: self.type_info})? + }, + dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(s) => { + data.next_element_seed(sequence::SequenceDeserializer{ item_type: &s.value_type, type_info: self.type_info})? + }, + }; + + let value = value.ok_or_else(|| { + error(format!( + "struct member {} not present in message", + member.name + )) + })?; + + fields.push(( + Arc::new(Field::new(&member.name, value.data_type().clone(), true)), + make_array(value), + )); + } + + let struct_array: StructArray = fields.into(); + + Ok(struct_array.into()) + } +} + +fn error(e: T) -> E +where + T: Display, + E: serde::de::Error, +{ + serde::de::Error::custom(e) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/primitive.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/primitive.rs new file mode 100644 index 000000000..7f13b575f --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/primitive.rs @@ -0,0 +1,155 @@ +use arrow::array::{ + ArrayData, BooleanBuilder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, + Int64Builder, Int8Builder, NullArray, UInt16Builder, UInt32Builder, UInt64Builder, + UInt8Builder, +}; +use core::fmt; +use dora_ros2_bridge_msg_gen::types::primitives::BasicType; + +pub struct PrimitiveDeserializer<'a>(pub &'a BasicType); + +impl<'de> serde::de::DeserializeSeed<'de> for PrimitiveDeserializer<'_> { + type Value = ArrayData; + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + match self.0 { + BasicType::I8 => deserializer.deserialize_i8(PrimitiveValueVisitor), + BasicType::I16 => deserializer.deserialize_i16(PrimitiveValueVisitor), + BasicType::I32 => deserializer.deserialize_i32(PrimitiveValueVisitor), + BasicType::I64 => deserializer.deserialize_i64(PrimitiveValueVisitor), + BasicType::U8 | BasicType::Char | BasicType::Byte => { + deserializer.deserialize_u8(PrimitiveValueVisitor) + } + BasicType::U16 => deserializer.deserialize_u16(PrimitiveValueVisitor), + BasicType::U32 => deserializer.deserialize_u32(PrimitiveValueVisitor), + BasicType::U64 => deserializer.deserialize_u64(PrimitiveValueVisitor), + BasicType::F32 => deserializer.deserialize_f32(PrimitiveValueVisitor), + BasicType::F64 => deserializer.deserialize_f64(PrimitiveValueVisitor), + BasicType::Bool => deserializer.deserialize_bool(PrimitiveValueVisitor), + } + } +} + +/// Based on https://docs.rs/serde_yaml/0.9.22/src/serde_yaml/value/de.rs.html#14-121 +struct PrimitiveValueVisitor; + +impl<'de> serde::de::Visitor<'de> for PrimitiveValueVisitor { + type Value = ArrayData; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a primitive value") + } + + fn visit_bool(self, b: bool) -> Result + where + E: serde::de::Error, + { + let mut array = BooleanBuilder::new(); + array.append_value(b); + Ok(array.finish().into()) + } + + fn visit_i8(self, u: i8) -> Result + where + E: serde::de::Error, + { + let mut array = Int8Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + + fn visit_i16(self, u: i16) -> Result + where + E: serde::de::Error, + { + let mut array = Int16Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + fn visit_i32(self, u: i32) -> Result + where + E: serde::de::Error, + { + let mut array = Int32Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + fn visit_i64(self, i: i64) -> Result + where + E: serde::de::Error, + { + let mut array = Int64Builder::new(); + array.append_value(i); + Ok(array.finish().into()) + } + + fn visit_u8(self, u: u8) -> Result + where + E: serde::de::Error, + { + let mut array = UInt8Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + fn visit_u16(self, u: u16) -> Result + where + E: serde::de::Error, + { + let mut array = UInt16Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + fn visit_u32(self, u: u32) -> Result + where + E: serde::de::Error, + { + let mut array = UInt32Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + fn visit_u64(self, u: u64) -> Result + where + E: serde::de::Error, + { + let mut array = UInt64Builder::new(); + array.append_value(u); + Ok(array.finish().into()) + } + + fn visit_f32(self, f: f32) -> Result + where + E: serde::de::Error, + { + let mut array = Float32Builder::new(); + array.append_value(f); + Ok(array.finish().into()) + } + + fn visit_f64(self, f: f64) -> Result + where + E: serde::de::Error, + { + let mut array = Float64Builder::new(); + array.append_value(f); + Ok(array.finish().into()) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + let array = NullArray::new(0); + Ok(array.into()) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + let array = NullArray::new(0); + Ok(array.into()) + } +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs new file mode 100644 index 000000000..a55921968 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/sequence.rs @@ -0,0 +1,163 @@ +use arrow::{ + array::{ + Array, ArrayData, BooleanBuilder, ListArray, ListBuilder, PrimitiveBuilder, StringBuilder, + }, + buffer::OffsetBuffer, + datatypes::{self, ArrowPrimitiveType, Field}, +}; +use core::fmt; +use dora_ros2_bridge_msg_gen::types::primitives::{self, BasicType, NestableType}; +use serde::Deserialize; +use std::{borrow::Cow, ops::Deref, sync::Arc}; + +use crate::typed::TypeInfo; + +use super::{error, StructDeserializer}; + +pub struct SequenceDeserializer<'a> { + pub item_type: &'a NestableType, + pub type_info: &'a TypeInfo<'a>, +} + +impl<'de> serde::de::DeserializeSeed<'de> for SequenceDeserializer<'_> { + type Value = ArrayData; + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_seq(SequenceVisitor { + item_type: self.item_type, + type_info: self.type_info, + }) + } +} + +pub struct SequenceVisitor<'a> { + pub item_type: &'a NestableType, + pub type_info: &'a TypeInfo<'a>, +} + +impl<'de> serde::de::Visitor<'de> for SequenceVisitor<'_> { + type Value = ArrayData; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a sequence") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + match &self.item_type { + NestableType::BasicType(t) => match t { + BasicType::I8 => deserialize_primitive_seq::<_, datatypes::Int8Type>(seq), + BasicType::I16 => deserialize_primitive_seq::<_, datatypes::Int16Type>(seq), + BasicType::I32 => deserialize_primitive_seq::<_, datatypes::Int32Type>(seq), + BasicType::I64 => deserialize_primitive_seq::<_, datatypes::Int64Type>(seq), + BasicType::U8 | BasicType::Char | BasicType::Byte => { + deserialize_primitive_seq::<_, datatypes::UInt8Type>(seq) + } + BasicType::U16 => deserialize_primitive_seq::<_, datatypes::UInt16Type>(seq), + BasicType::U32 => deserialize_primitive_seq::<_, datatypes::UInt32Type>(seq), + BasicType::U64 => deserialize_primitive_seq::<_, datatypes::UInt64Type>(seq), + BasicType::F32 => deserialize_primitive_seq::<_, datatypes::Float32Type>(seq), + BasicType::F64 => deserialize_primitive_seq::<_, datatypes::Float64Type>(seq), + BasicType::Bool => { + let mut array = BooleanBuilder::new(); + while let Some(value) = seq.next_element()? { + array.append_value(value); + } + // wrap array into list of length 1 + let mut list = ListBuilder::new(array); + list.append(true); + Ok(list.finish().into()) + } + }, + NestableType::NamedType(name) => { + let deserializer = StructDeserializer { + type_info: Cow::Owned(TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + }), + }; + deserialize_struct_seq(&mut seq, deserializer) + } + NestableType::NamespacedType(reference) => { + if reference.namespace != "msg" { + return Err(error(format!( + "sequence item references non-message type {reference:?}", + ))); + } + let deserializer = StructDeserializer { + type_info: Cow::Owned(TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + }), + }; + deserialize_struct_seq(&mut seq, deserializer) + } + NestableType::GenericString(t) => match t { + primitives::GenericString::String | primitives::GenericString::BoundedString(_) => { + let mut array = StringBuilder::new(); + while let Some(value) = seq.next_element::()? { + array.append_value(value); + } + // wrap array into list of length 1 + let mut list = ListBuilder::new(array); + list.append(true); + Ok(list.finish().into()) + } + primitives::GenericString::WString => todo!("deserialize sequence of WString"), + primitives::GenericString::BoundedWString(_) => { + todo!("deserialize sequence of BoundedWString") + } + }, + } + } +} + +fn deserialize_struct_seq<'de, A>( + seq: &mut A, + deserializer: StructDeserializer<'_>, +) -> Result>::Error> +where + A: serde::de::SeqAccess<'de>, +{ + let mut values = Vec::new(); + while let Some(value) = seq.next_element_seed(deserializer.clone())? { + values.push(arrow::array::make_array(value)); + } + let refs: Vec<_> = values.iter().map(|a| a.deref()).collect(); + let concatenated = arrow::compute::concat(&refs).map_err(super::error)?; + + let list = ListArray::try_new( + Arc::new(Field::new("item", concatenated.data_type().clone(), true)), + OffsetBuffer::from_lengths([concatenated.len()]), + Arc::new(concatenated), + None, + ) + .map_err(error)?; + + Ok(list.to_data()) +} + +fn deserialize_primitive_seq<'de, S, T>( + mut seq: S, +) -> Result>::Error> +where + S: serde::de::SeqAccess<'de>, + T: ArrowPrimitiveType, + T::Native: Deserialize<'de>, +{ + let mut array = PrimitiveBuilder::::new(); + while let Some(value) = seq.next_element::()? { + array.append_value(value); + } + // wrap array into list of length 1 + let mut list = ListBuilder::new(array); + list.append(true); + Ok(list.finish().into()) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/deserialize/string.rs b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/string.rs new file mode 100644 index 000000000..646ea38d3 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/deserialize/string.rs @@ -0,0 +1,44 @@ +use arrow::array::{ArrayData, StringBuilder}; +use core::fmt; + +pub struct StringDeserializer; + +impl<'de> serde::de::DeserializeSeed<'de> for StringDeserializer { + type Value = ArrayData; + + fn deserialize(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_str(StringVisitor) + } +} + +/// Based on https://docs.rs/serde_yaml/0.9.22/src/serde_yaml/value/de.rs.html#14-121 +struct StringVisitor; + +impl<'de> serde::de::Visitor<'de> for StringVisitor { + type Value = ArrayData; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string value") + } + + fn visit_str(self, s: &str) -> Result + where + E: serde::de::Error, + { + let mut array = StringBuilder::new(); + array.append_value(s); + Ok(array.finish().into()) + } + + fn visit_string(self, s: String) -> Result + where + E: serde::de::Error, + { + let mut array = StringBuilder::new(); + array.append_value(s); + Ok(array.finish().into()) + } +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/mod.rs index 150897fdb..2b91d08b1 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/mod.rs @@ -1,277 +1,24 @@ -use arrow::{ - array::{ - make_array, Array, ArrayData, BooleanArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, StringArray, StructArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, - }, - buffer::Buffer, - compute::concat, - datatypes::{DataType, Field}, -}; -use dora_ros2_bridge_msg_gen::types::{ - primitives::{BasicType, NestableType}, - MemberType, Message, -}; -use eyre::{Context, ContextCompat, Result}; -use std::{collections::HashMap, sync::Arc}; +use dora_ros2_bridge_msg_gen::types::Message; +use std::{borrow::Cow, collections::HashMap, sync::Arc}; pub use serialize::TypedValue; pub mod deserialize; pub mod serialize; -#[derive(Debug, Clone, PartialEq)] -pub struct TypeInfo { - data_type: DataType, - defaults: ArrayData, +#[derive(Debug, Clone)] +pub struct TypeInfo<'a> { + pub package_name: Cow<'a, str>, + pub message_name: Cow<'a, str>, + pub messages: Arc>>, } -pub fn for_message( - messages: &HashMap>, - package_name: &str, - message_name: &str, -) -> eyre::Result { - let empty = HashMap::new(); - let package_messages = messages.get(package_name).unwrap_or(&empty); - let message = package_messages - .get(message_name) - .context("unknown type name")?; - let default_struct_vec: Vec<(Arc, Arc)> = message - .members - .iter() - .map(|m| { - let default = make_array(default_for_member(m, package_name, messages)?); - Result::<_, eyre::Report>::Ok(( - Arc::new(Field::new( - m.name.clone(), - default.data_type().clone(), - true, - )), - default, - )) - }) - .collect::>()?; - - let default_struct: StructArray = default_struct_vec.into(); - - Ok(TypeInfo { - data_type: default_struct.data_type().clone(), - defaults: default_struct.into(), - }) -} - -pub fn default_for_member( - m: &dora_ros2_bridge_msg_gen::types::Member, - package_name: &str, - messages: &HashMap>, -) -> eyre::Result { - let value = match &m.r#type { - MemberType::NestableType(t) => match t { - NestableType::BasicType(_) | NestableType::GenericString(_) => match &m - .default - .as_deref() - { - Some([]) => eyre::bail!("empty default value not supported"), - Some([default]) => preset_default_for_basic_type(t, default) - .with_context(|| format!("failed to parse default value for `{}`", m.name))?, - Some(_) => eyre::bail!( - "there should be only a single default value for non-sequence types" - ), - None => default_for_nestable_type(t, package_name, messages)?, - }, - NestableType::NamedType(_) => { - if m.default.is_some() { - eyre::bail!("default values for nested types are not supported") - } else { - default_for_nestable_type(t, package_name, messages)? - } - } - NestableType::NamespacedType(_) => { - default_for_nestable_type(t, package_name, messages)? - } - }, - MemberType::Array(array) => { - list_default_values(m, &array.value_type, package_name, messages)? - } - MemberType::Sequence(seq) => { - list_default_values(m, &seq.value_type, package_name, messages)? - } - MemberType::BoundedSequence(seq) => { - list_default_values(m, &seq.value_type, package_name, messages)? - } - }; - Ok(value) -} - -fn default_for_nestable_type( - t: &NestableType, - package_name: &str, - messages: &HashMap>, -) -> Result { - let empty = HashMap::new(); - let package_messages = messages.get(package_name).unwrap_or(&empty); - let array = match t { - NestableType::BasicType(t) => match t { - BasicType::I8 => Int8Array::from(vec![0]).into(), - BasicType::I16 => Int16Array::from(vec![0]).into(), - BasicType::I32 => Int32Array::from(vec![0]).into(), - BasicType::I64 => Int64Array::from(vec![0]).into(), - BasicType::U8 => UInt8Array::from(vec![0]).into(), - BasicType::U16 => UInt16Array::from(vec![0]).into(), - BasicType::U32 => UInt32Array::from(vec![0]).into(), - BasicType::U64 => UInt64Array::from(vec![0]).into(), - BasicType::F32 => Float32Array::from(vec![0.]).into(), - BasicType::F64 => Float64Array::from(vec![0.]).into(), - BasicType::Char => StringArray::from(vec![""]).into(), - BasicType::Byte => UInt8Array::from(vec![0u8] as Vec).into(), - BasicType::Bool => BooleanArray::from(vec![false]).into(), - }, - NestableType::GenericString(_) => StringArray::from(vec![""]).into(), - NestableType::NamedType(name) => { - let referenced_message = package_messages - .get(&name.0) - .context("unknown referenced message")?; - - default_for_referenced_message(referenced_message, package_name, messages)? - } - NestableType::NamespacedType(t) => { - let referenced_package_messages = messages.get(&t.package).unwrap_or(&empty); - let referenced_message = referenced_package_messages - .get(&t.name) - .context("unknown referenced message")?; - default_for_referenced_message(referenced_message, &t.package, messages)? - } - }; - Ok(array) -} - -fn preset_default_for_basic_type(t: &NestableType, preset: &str) -> Result { - Ok(match t { - NestableType::BasicType(t) => match t { - BasicType::I8 => Int8Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::I16 => Int16Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::I32 => Int32Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::I64 => Int64Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::U8 => UInt8Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::U16 => UInt16Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::U32 => UInt32Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::U64 => UInt64Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::F32 => Float32Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::F64 => Float64Array::from(vec![preset - .parse::() - .context("Could not parse preset default value")?]) - .into(), - BasicType::Char => StringArray::from(vec![preset]).into(), - BasicType::Byte => UInt8Array::from(preset.as_bytes().to_owned()).into(), - BasicType::Bool => BooleanArray::from(vec![preset - .parse::() - .context("could not parse preset default value")?]) - .into(), - }, - NestableType::GenericString(_) => StringArray::from(vec![preset]).into(), - _ => todo!(), - }) -} - -fn default_for_referenced_message( - referenced_message: &Message, - package_name: &str, - messages: &HashMap>, -) -> eyre::Result { - let fields: Vec<(Arc, Arc)> = referenced_message - .members - .iter() - .map(|m| { - let default = default_for_member(m, package_name, messages)?; - Result::<_, eyre::Report>::Ok(( - Arc::new(Field::new( - m.name.clone(), - default.data_type().clone(), - true, - )), - make_array(default), - )) - }) - .collect::>()?; - - let struct_array: StructArray = fields.into(); - Ok(struct_array.into()) -} - -fn list_default_values( - m: &dora_ros2_bridge_msg_gen::types::Member, - value_type: &NestableType, - package_name: &str, - messages: &HashMap>, -) -> Result { - let defaults = match &m.default.as_deref() { - Some([]) => eyre::bail!("empty default value not supported"), - Some(defaults) => { - let raw_array: Vec> = defaults - .iter() - .map(|default| { - preset_default_for_basic_type(value_type, default) - .with_context(|| format!("failed to parse default value for `{}`", m.name)) - .map(make_array) - }) - .collect::>()?; - let default_values = concat( - raw_array - .iter() - .map(|data| data.as_ref()) - .collect::>() - .as_slice(), - ) - .context("Failed to concatenate default list value")?; - default_values.to_data() - } - None => { - let default_nested_type = - default_for_nestable_type(value_type, package_name, messages)?; - - let value_offsets = Buffer::from_slice_ref([0i64, 1]); - - let list_data_type = DataType::List(Arc::new(Field::new( - &m.name, - default_nested_type.data_type().clone(), - true, - ))); - // Construct a list array from the above two - ArrayData::builder(list_data_type) - .len(1) - .add_buffer(value_offsets) - .add_child_data(default_nested_type) - .build() - .context("Failed to build default list value")? - } - }; - - Ok(defaults) -} +/// Serde requires that struct and field names are known at +/// compile time with a `'static` lifetime, which is not +/// possible in this case. Thus, we need to use dummy names +/// instead. +/// +/// The actual names do not really matter because +/// the CDR format of ROS2 does not encode struct or field +/// names. +const DUMMY_STRUCT_NAME: &str = "struct"; diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs deleted file mode 100644 index 15e8f622e..000000000 --- a/libraries/extensions/ros2-bridge/python/src/typed/serialize.rs +++ /dev/null @@ -1,158 +0,0 @@ -use arrow::array::ArrayData; -use arrow::array::Float32Array; -use arrow::array::Float64Array; -use arrow::array::Int16Array; -use arrow::array::Int32Array; -use arrow::array::Int64Array; -use arrow::array::Int8Array; -use arrow::array::ListArray; -use arrow::array::StringArray; -use arrow::array::StructArray; -use arrow::array::UInt16Array; -use arrow::array::UInt32Array; -use arrow::array::UInt64Array; -use arrow::array::UInt8Array; -use arrow::datatypes::DataType; -use serde::ser::SerializeSeq; -use serde::ser::SerializeStruct; - -use super::TypeInfo; - -#[derive(Debug, Clone, PartialEq)] -pub struct TypedValue<'a> { - pub value: &'a ArrayData, - pub type_info: &'a TypeInfo, -} - -impl serde::Serialize for TypedValue<'_> { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - match &self.type_info.data_type { - DataType::UInt8 => { - let uint_array: UInt8Array = self.value.clone().into(); - let number = uint_array.value(0); - serializer.serialize_u8(number) - } - DataType::UInt16 => { - let uint_array: UInt16Array = self.value.clone().into(); - let number = uint_array.value(0); - serializer.serialize_u16(number) - } - DataType::UInt32 => { - let uint_array: UInt32Array = self.value.clone().into(); - let number = uint_array.value(0); - serializer.serialize_u32(number) - } - DataType::UInt64 => { - let uint_array: UInt64Array = self.value.clone().into(); - let number = uint_array.value(0); - serializer.serialize_u64(number) - } - DataType::Int8 => { - let int_array: Int8Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_i8(number) - } - DataType::Int16 => { - let int_array: Int16Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_i16(number) - } - DataType::Int32 => { - let int_array: Int32Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_i32(number) - } - DataType::Int64 => { - let int_array: Int64Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_i64(number) - } - DataType::Float32 => { - let int_array: Float32Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_f32(number) - } - DataType::Float64 => { - let int_array: Float64Array = self.value.clone().into(); - let number = int_array.value(0); - serializer.serialize_f64(number) - } - DataType::Utf8 => { - let int_array: StringArray = self.value.clone().into(); - let string = int_array.value(0); - serializer.serialize_str(string) - } - DataType::List(field) => { - let list_array: ListArray = self.value.clone().into(); - let values = list_array.values(); - let mut s = serializer.serialize_seq(Some(values.len()))?; - for value in list_array.iter() { - let value = match value { - Some(value) => value.to_data(), - None => { - return Err(serde::ser::Error::custom( - "Value in ListArray is null and not yet supported".to_string(), - )) - } - }; - - s.serialize_element(&TypedValue { - value: &value, - type_info: &TypeInfo { - data_type: field.data_type().clone(), - defaults: self.type_info.defaults.clone(), - }, - })?; - } - s.end() - } - DataType::Struct(fields) => { - /// Serde requires that struct and field names are known at - /// compile time with a `'static` lifetime, which is not - /// possible in this case. Thus, we need to use dummy names - /// instead. - /// - /// The actual names do not really matter because - /// the CDR format of ROS2 does not encode struct or field - /// names. - const DUMMY_STRUCT_NAME: &str = "struct"; - const DUMMY_FIELD_NAME: &str = "field"; - - let struct_array: StructArray = self.value.clone().into(); - let mut s = serializer.serialize_struct(DUMMY_STRUCT_NAME, fields.len())?; - let defaults: StructArray = self.type_info.defaults.clone().into(); - for field in fields.iter() { - let default = match defaults.column_by_name(field.name()) { - Some(value) => value.to_data(), - None => { - return Err(serde::ser::Error::custom(format!( - "missing field {} for serialization", - &field.name() - ))) - } - }; - let field_value = match struct_array.column_by_name(field.name()) { - Some(value) => value.to_data(), - None => default.clone(), - }; - - s.serialize_field( - DUMMY_FIELD_NAME, - &TypedValue { - value: &field_value, - type_info: &TypeInfo { - data_type: field.data_type().clone(), - defaults: default, - }, - }, - )?; - } - s.end() - } - _ => todo!(), - } - } -} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs new file mode 100644 index 000000000..3b2bf889a --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs @@ -0,0 +1,259 @@ +use std::{any::type_name, borrow::Cow, marker::PhantomData, sync::Arc}; + +use arrow::{ + array::{Array, ArrayRef, AsArray, OffsetSizeTrait, PrimitiveArray}, + datatypes::{self, ArrowPrimitiveType}, +}; +use dora_ros2_bridge_msg_gen::types::{ + primitives::{BasicType, GenericString, NestableType}, + sequences, +}; +use serde::ser::SerializeTuple; + +use crate::typed::TypeInfo; + +use super::{error, TypedValue}; + +/// Serialize an array with known size as tuple. +pub struct ArraySerializeWrapper<'a> { + pub array_info: &'a sequences::Array, + pub column: &'a ArrayRef, + pub type_info: &'a TypeInfo<'a>, +} + +impl serde::Serialize for ArraySerializeWrapper<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let entry = if let Some(list) = self.column.as_list_opt::() { + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + } else { + // try as large list + let list = self + .column + .as_list_opt::() + .ok_or_else(|| error("value is not compatible with expected array type"))?; + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + }; + + match &self.array_info.value_type { + NestableType::BasicType(t) => match t { + BasicType::I8 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I16 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I32 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I64 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U8 | BasicType::Char | BasicType::Byte => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U16 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U32 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U64 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::F32 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::F64 => BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::Bool => BoolArrayAsTuple { + len: self.array_info.size, + value: &entry, + } + .serialize(serializer), + }, + NestableType::NamedType(name) => { + let array = entry + .as_struct_opt() + .ok_or_else(|| error("not a struct array"))?; + let mut seq = serializer.serialize_tuple(self.array_info.size)?; + for i in 0..array.len() { + let row = array.slice(i, 1); + seq.serialize_element(&TypedValue { + value: &(Arc::new(row) as ArrayRef), + type_info: &crate::typed::TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + }, + })?; + } + seq.end() + } + NestableType::NamespacedType(reference) => { + if reference.namespace != "msg" { + return Err(error(format!( + "sequence references non-message type {reference:?}" + ))); + } + + let array = entry + .as_struct_opt() + .ok_or_else(|| error("not a struct array"))?; + let mut seq = serializer.serialize_tuple(self.array_info.size)?; + for i in 0..array.len() { + let row = array.slice(i, 1); + seq.serialize_element(&TypedValue { + value: &(Arc::new(row) as ArrayRef), + type_info: &crate::typed::TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + }, + })?; + } + seq.end() + } + NestableType::GenericString(s) => match s { + GenericString::String | GenericString::BoundedString(_) => { + match entry.as_string_opt::() { + Some(array) => { + serialize_arrow_string(serializer, array, self.array_info.size) + } + None => { + let array = entry + .as_string_opt::() + .ok_or_else(|| error("expected string array"))?; + serialize_arrow_string(serializer, array, self.array_info.size) + } + } + } + GenericString::WString => { + todo!("serializing WString sequences") + } + GenericString::BoundedWString(_) => todo!("serializing BoundedWString sequences"), + }, + } + } +} + +/// Serializes a primitive array with known size as tuple. +struct BasicArrayAsTuple<'a, T> { + len: usize, + value: &'a ArrayRef, + ty: PhantomData, +} + +impl serde::Serialize for BasicArrayAsTuple<'_, T> +where + T: ArrowPrimitiveType, + T::Native: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_tuple(self.len)?; + let array: &PrimitiveArray = self + .value + .as_primitive_opt() + .ok_or_else(|| error(format!("not a primitive {} array", type_name::())))?; + if array.len() != self.len { + return Err(error(format!( + "expected array with length {}, got length {}", + self.len, + array.len() + ))); + } + + for value in array.values() { + seq.serialize_element(value)?; + } + + seq.end() + } +} + +struct BoolArrayAsTuple<'a> { + len: usize, + value: &'a ArrayRef, +} + +impl serde::Serialize for BoolArrayAsTuple<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_tuple(self.len)?; + let array = self + .value + .as_boolean_opt() + .ok_or_else(|| error("not a boolean array"))?; + if array.len() != self.len { + return Err(error(format!( + "expected array with length {}, got length {}", + self.len, + array.len() + ))); + } + + for value in array.values() { + seq.serialize_element(&value)?; + } + + seq.end() + } +} + +fn serialize_arrow_string( + serializer: S, + array: &arrow::array::GenericByteArray>, + array_len: usize, +) -> Result<::Ok, ::Error> +where + S: serde::Serializer, + O: OffsetSizeTrait, +{ + let mut seq = serializer.serialize_tuple(array_len)?; + for s in array.iter() { + seq.serialize_element(s.unwrap_or_default())?; + } + seq.end() +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/defaults.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/defaults.rs new file mode 100644 index 000000000..2cc1e6b73 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/defaults.rs @@ -0,0 +1,237 @@ +use arrow::{ + array::{ + make_array, Array, ArrayData, BooleanArray, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, ListArray, StringArray, StructArray, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, + }, + buffer::{OffsetBuffer, ScalarBuffer}, + compute::concat, + datatypes::Field, +}; +use dora_ros2_bridge_msg_gen::types::{ + primitives::{BasicType, NestableType}, + MemberType, Message, +}; +use eyre::{Context, ContextCompat, Result}; +use std::{collections::HashMap, sync::Arc, vec}; + +pub fn default_for_member( + m: &dora_ros2_bridge_msg_gen::types::Member, + package_name: &str, + messages: &HashMap>, +) -> eyre::Result { + let value = match &m.r#type { + MemberType::NestableType(t) => match t { + NestableType::BasicType(_) | NestableType::GenericString(_) => match &m + .default + .as_deref() + { + Some([]) => eyre::bail!("empty default value not supported"), + Some([default]) => preset_default_for_basic_type(t, default) + .with_context(|| format!("failed to parse default value for `{}`", m.name))?, + Some(_) => eyre::bail!( + "there should be only a single default value for non-sequence types" + ), + None => default_for_nestable_type(t, package_name, messages, 1)?, + }, + NestableType::NamedType(_) => { + if m.default.is_some() { + eyre::bail!("default values for nested types are not supported") + } else { + default_for_nestable_type(t, package_name, messages, 1)? + } + } + NestableType::NamespacedType(_) => { + default_for_nestable_type(t, package_name, messages, 1)? + } + }, + MemberType::Array(array) => list_default_values( + m, + &array.value_type, + package_name, + messages, + Some(array.size), + )?, + MemberType::Sequence(seq) => { + list_default_values(m, &seq.value_type, package_name, messages, None)? + } + MemberType::BoundedSequence(seq) => list_default_values( + m, + &seq.value_type, + package_name, + messages, + Some(seq.max_size), + )?, + }; + Ok(value) +} + +fn default_for_nestable_type( + t: &NestableType, + package_name: &str, + messages: &HashMap>, + size: usize, +) -> Result { + let empty = HashMap::new(); + let package_messages = messages.get(package_name).unwrap_or(&empty); + let array = match t { + NestableType::BasicType(t) => match t { + BasicType::I8 => Int8Array::from(vec![0; size]).into(), + BasicType::I16 => Int16Array::from(vec![0; size]).into(), + BasicType::I32 => Int32Array::from(vec![0; size]).into(), + BasicType::I64 => Int64Array::from(vec![0; size]).into(), + BasicType::U8 => UInt8Array::from(vec![0; size]).into(), + BasicType::U16 => UInt16Array::from(vec![0; size]).into(), + BasicType::U32 => UInt32Array::from(vec![0; size]).into(), + BasicType::U64 => UInt64Array::from(vec![0; size]).into(), + BasicType::F32 => Float32Array::from(vec![0.; size]).into(), + BasicType::F64 => Float64Array::from(vec![0.; size]).into(), + BasicType::Char => StringArray::from(vec![""]).into(), + BasicType::Byte => UInt8Array::from(vec![0u8; size]).into(), + BasicType::Bool => BooleanArray::from(vec![false; size]).into(), + }, + NestableType::GenericString(_) => StringArray::from(vec![""]).into(), + NestableType::NamedType(name) => { + let referenced_message = package_messages + .get(&name.0) + .context("unknown referenced message")?; + + default_for_referenced_message(referenced_message, package_name, messages)? + } + NestableType::NamespacedType(t) => { + let referenced_package_messages = messages.get(&t.package).unwrap_or(&empty); + let referenced_message = referenced_package_messages + .get(&t.name) + .context("unknown referenced message")?; + default_for_referenced_message(referenced_message, &t.package, messages)? + } + }; + Ok(array) +} + +fn preset_default_for_basic_type(t: &NestableType, preset: &str) -> Result { + Ok(match t { + NestableType::BasicType(t) => match t { + BasicType::I8 => Int8Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::I16 => Int16Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::I32 => Int32Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::I64 => Int64Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::U8 => UInt8Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::U16 => UInt16Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::U32 => UInt32Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::U64 => UInt64Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::F32 => Float32Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::F64 => Float64Array::from(vec![preset + .parse::() + .context("Could not parse preset default value")?]) + .into(), + BasicType::Char => StringArray::from(vec![preset]).into(), + BasicType::Byte => UInt8Array::from(preset.as_bytes().to_owned()).into(), + BasicType::Bool => BooleanArray::from(vec![preset + .parse::() + .context("could not parse preset default value")?]) + .into(), + }, + NestableType::GenericString(_) => StringArray::from(vec![preset]).into(), + _ => todo!("preset_default_for_basic_type (other)"), + }) +} + +fn default_for_referenced_message( + referenced_message: &Message, + package_name: &str, + messages: &HashMap>, +) -> eyre::Result { + let fields: Vec<(Arc, Arc)> = referenced_message + .members + .iter() + .map(|m| { + let default = default_for_member(m, package_name, messages)?; + Result::<_, eyre::Report>::Ok(( + Arc::new(Field::new( + m.name.clone(), + default.data_type().clone(), + true, + )), + make_array(default), + )) + }) + .collect::>()?; + + let struct_array: StructArray = fields.into(); + Ok(struct_array.into()) +} + +fn list_default_values( + m: &dora_ros2_bridge_msg_gen::types::Member, + value_type: &NestableType, + package_name: &str, + messages: &HashMap>, + size: Option, +) -> Result { + let defaults = match &m.default.as_deref() { + Some([]) => eyre::bail!("empty default value not supported"), + Some(defaults) => { + let raw_array: Vec> = defaults + .iter() + .map(|default| { + preset_default_for_basic_type(value_type, default) + .with_context(|| format!("failed to parse default value for `{}`", m.name)) + .map(make_array) + }) + .collect::>()?; + let default_values = concat( + raw_array + .iter() + .map(|data| data.as_ref()) + .collect::>() + .as_slice(), + ) + .context("Failed to concatenate default list value")?; + default_values.to_data() + } + None => { + let size = size.unwrap_or(1); + let default_nested_type = + default_for_nestable_type(value_type, package_name, messages, size)?; + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, size as i32])); + + let field = Arc::new(Field::new( + "item", + default_nested_type.data_type().clone(), + true, + )); + let list = ListArray::new(field, offsets, make_array(default_nested_type), None); + list.to_data() + } + }; + + Ok(defaults) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs new file mode 100644 index 000000000..8420f14f0 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs @@ -0,0 +1,205 @@ +use std::{borrow::Cow, collections::HashMap, fmt::Display}; + +use arrow::{ + array::{Array, ArrayRef, AsArray}, + error, +}; +use dora_ros2_bridge_msg_gen::types::{ + primitives::{GenericString, NestableType}, + MemberType, +}; +use eyre::Context; +use serde::ser::SerializeTupleStruct; + +use super::{TypeInfo, DUMMY_STRUCT_NAME}; + +mod array; +mod defaults; +mod primitive; +mod sequence; + +#[derive(Debug, Clone)] +pub struct TypedValue<'a> { + pub value: &'a ArrayRef, + pub type_info: &'a TypeInfo<'a>, +} + +impl serde::Serialize for TypedValue<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let empty = HashMap::new(); + let package_messages = self + .type_info + .messages + .get(self.type_info.package_name.as_ref()) + .unwrap_or(&empty); + let message = package_messages + .get(self.type_info.message_name.as_ref()) + .ok_or_else(|| { + error(format!( + "could not find message type {}::{}", + self.type_info.package_name, self.type_info.message_name + )) + })?; + + let input = self.value.as_struct_opt().ok_or_else(|| { + error(format!( + "expected struct array for message: {}, with following format: {:#?} \n But, got type: {:#?}", + self.type_info.message_name, message, self.value.data_type() + )) + })?; + for column_name in input.column_names() { + if !message.members.iter().any(|m| m.name == column_name) { + return Err(error(format!( + "given struct has unknown field {column_name}" + )))?; + } + } + if input.is_empty() { + // TODO: publish default value + return Err(error("given struct is empty"))?; + } + if input.len() > 1 { + return Err(error(format!( + "expected single struct instance, got struct array with {} entries", + input.len() + )))?; + } + let mut s = serializer.serialize_tuple_struct(DUMMY_STRUCT_NAME, message.members.len())?; + for field in message.members.iter() { + let column: Cow<_> = match input.column_by_name(&field.name) { + Some(input) => Cow::Borrowed(input), + None => { + let default = defaults::default_for_member( + field, + &self.type_info.package_name, + &self.type_info.messages, + ) + .with_context(|| { + format!( + "failed to calculate default value for field {}.{}", + message.name, field.name + ) + }) + .map_err(|e| error(format!("{e:?}")))?; + Cow::Owned(arrow::array::make_array(default)) + } + }; + + self.serialize_field::(field, column, &mut s) + .map_err(|e| { + error(format!( + "failed to serialize field {}.{}: {e}", + message.name, field.name + )) + })?; + } + s.end() + } +} + +impl<'a> TypedValue<'a> { + fn serialize_field( + &self, + field: &dora_ros2_bridge_msg_gen::types::Member, + column: Cow<'_, std::sync::Arc>, + s: &mut S::SerializeTupleStruct, + ) -> Result<(), S::Error> + where + S: serde::Serializer, + { + match &field.r#type { + MemberType::NestableType(t) => match t { + NestableType::BasicType(t) => { + s.serialize_field(&primitive::SerializeWrapper { + t, + column: column.as_ref(), + })?; + } + NestableType::NamedType(name) => { + let referenced_value = &TypedValue { + value: column.as_ref(), + type_info: &TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + }, + }; + s.serialize_field(&referenced_value)?; + } + NestableType::NamespacedType(reference) => { + if reference.namespace != "msg" { + return Err(error(format!( + "struct field {} references non-message type {reference:?}", + field.name + ))); + } + + let referenced_value: &TypedValue<'_> = &TypedValue { + value: column.as_ref(), + type_info: &TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + }, + }; + s.serialize_field(&referenced_value)?; + } + NestableType::GenericString(t) => match t { + GenericString::String | GenericString::BoundedString(_) => { + let string = if let Some(string_array) = column.as_string_opt::() { + // should match the length of the outer struct array + assert_eq!(string_array.len(), 1); + string_array.value(0) + } else { + // try again with large offset type + let string_array = column + .as_string_opt::() + .ok_or_else(|| error("expected string array"))?; + // should match the length of the outer struct array + assert_eq!(string_array.len(), 1); + string_array.value(0) + }; + s.serialize_field(string)?; + } + GenericString::WString => todo!("serializing WString types"), + GenericString::BoundedWString(_) => { + todo!("serializing BoundedWString types") + } + }, + }, + dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => { + s.serialize_field(&array::ArraySerializeWrapper { + array_info: a, + column: column.as_ref(), + type_info: self.type_info, + })?; + } + dora_ros2_bridge_msg_gen::types::MemberType::Sequence(v) => { + s.serialize_field(&sequence::SequenceSerializeWrapper { + item_type: &v.value_type, + column: column.as_ref(), + type_info: self.type_info, + })?; + } + dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(v) => { + s.serialize_field(&sequence::SequenceSerializeWrapper { + item_type: &v.value_type, + column: column.as_ref(), + type_info: self.type_info, + })?; + } + } + Ok(()) + } +} + +fn error(e: T) -> E +where + T: Display, + E: serde::ser::Error, +{ + serde::ser::Error::custom(e) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/primitive.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/primitive.rs new file mode 100644 index 000000000..a13bf444d --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/primitive.rs @@ -0,0 +1,79 @@ +use arrow::{ + array::{ArrayRef, AsArray}, + datatypes::{self, ArrowPrimitiveType}, +}; +use dora_ros2_bridge_msg_gen::types::primitives::BasicType; + +pub struct SerializeWrapper<'a> { + pub t: &'a BasicType, + pub column: &'a ArrayRef, +} + +impl serde::Serialize for SerializeWrapper<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self.t { + BasicType::I8 => { + serializer.serialize_i8(as_single_primitive::(self.column)?) + } + BasicType::I16 => serializer + .serialize_i16(as_single_primitive::(self.column)?), + BasicType::I32 => serializer + .serialize_i32(as_single_primitive::(self.column)?), + BasicType::I64 => serializer + .serialize_i64(as_single_primitive::(self.column)?), + BasicType::U8 | BasicType::Char | BasicType::Byte => serializer + .serialize_u8(as_single_primitive::(self.column)?), + BasicType::U16 => serializer + .serialize_u16(as_single_primitive::( + self.column, + )?), + BasicType::U32 => serializer + .serialize_u32(as_single_primitive::( + self.column, + )?), + BasicType::U64 => serializer + .serialize_u64(as_single_primitive::( + self.column, + )?), + BasicType::F32 => serializer + .serialize_f32(as_single_primitive::( + self.column, + )?), + BasicType::F64 => serializer + .serialize_f64(as_single_primitive::( + self.column, + )?), + BasicType::Bool => { + let array = self.column.as_boolean_opt().ok_or_else(|| { + serde::ser::Error::custom( + "value is not compatible with expected `BooleanArray` type", + ) + })?; + // should match the length of the outer struct + assert_eq!(array.len(), 1); + let field_value = array.value(0); + serializer.serialize_bool(field_value) + } + } + } +} + +fn as_single_primitive(column: &ArrayRef) -> Result +where + T: ArrowPrimitiveType, + E: serde::ser::Error, +{ + let array: &arrow::array::PrimitiveArray = column.as_primitive_opt().ok_or_else(|| { + serde::ser::Error::custom(format!( + "value is not compatible with expected `{}` type", + std::any::type_name::() + )) + })?; + // should match the length of the outer struct + assert_eq!(array.len(), 1); + let number = array.value(0); + Ok(number) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs new file mode 100644 index 000000000..d42d45fb6 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs @@ -0,0 +1,268 @@ +use std::{any::type_name, borrow::Cow, marker::PhantomData, sync::Arc}; + +use arrow::{ + array::{Array, ArrayRef, AsArray, OffsetSizeTrait, PrimitiveArray}, + datatypes::{self, ArrowPrimitiveType, UInt8Type}, +}; +use dora_ros2_bridge_msg_gen::types::primitives::{BasicType, GenericString, NestableType}; +use serde::ser::{SerializeSeq, SerializeTuple}; + +use crate::typed::TypeInfo; + +use super::{error, TypedValue}; + +/// Serialize a variable-sized sequence. +pub struct SequenceSerializeWrapper<'a> { + pub item_type: &'a NestableType, + pub column: &'a ArrayRef, + pub type_info: &'a TypeInfo<'a>, +} + +impl serde::Serialize for SequenceSerializeWrapper<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let entry = if let Some(list) = self.column.as_list_opt::() { + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + } else if let Some(list) = self.column.as_list_opt::() { + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + } else if let Some(list) = self.column.as_binary_opt::() { + // should match the length of the outer struct + assert_eq!(list.len(), 1); + Arc::new(list.slice(0, 1)) as ArrayRef + } else if let Some(list) = self.column.as_binary_opt::() { + // should match the length of the outer struct + assert_eq!(list.len(), 1); + Arc::new(list.slice(0, 1)) as ArrayRef + } else { + return Err(error(format!( + "value is not compatible with expected sequence type: {:?}", + self.column + ))); + }; + match &self.item_type { + NestableType::BasicType(t) => match t { + BasicType::I8 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I16 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I32 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I64 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U8 | BasicType::Char | BasicType::Byte => { + ByteSequence { value: &entry }.serialize(serializer) + } + BasicType::U16 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U32 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U64 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::F32 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::F64 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::Bool => BoolArray { value: &entry }.serialize(serializer), + }, + NestableType::NamedType(name) => { + let array = entry + .as_struct_opt() + .ok_or_else(|| error("not a struct array"))?; + let mut seq = serializer.serialize_seq(Some(array.len()))?; + for i in 0..array.len() { + let row = array.slice(i, 1); + seq.serialize_element(&TypedValue { + value: &(Arc::new(row) as ArrayRef), + type_info: &crate::typed::TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + }, + })?; + } + seq.end() + } + NestableType::NamespacedType(reference) => { + if reference.namespace != "msg" { + return Err(error(format!( + "sequence references non-message type {reference:?}" + ))); + } + + let array = entry + .as_struct_opt() + .ok_or_else(|| error("not a struct array"))?; + let mut seq = serializer.serialize_seq(Some(array.len()))?; + for i in 0..array.len() { + let row = array.slice(i, 1); + seq.serialize_element(&TypedValue { + value: &(Arc::new(row) as ArrayRef), + type_info: &crate::typed::TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + }, + })?; + } + seq.end() + } + NestableType::GenericString(s) => match s { + GenericString::String | GenericString::BoundedString(_) => { + match entry.as_string_opt::() { + Some(array) => serialize_arrow_string(serializer, array), + None => { + let array = entry + .as_string_opt::() + .ok_or_else(|| error("expected string array"))?; + serialize_arrow_string(serializer, array) + } + } + } + GenericString::WString => { + todo!("serializing WString sequences") + } + GenericString::BoundedWString(_) => todo!("serializing BoundedWString sequences"), + }, + } + } +} + +fn serialize_arrow_string( + serializer: S, + array: &arrow::array::GenericByteArray>, +) -> Result<::Ok, ::Error> +where + S: serde::Serializer, + O: OffsetSizeTrait, +{ + let mut seq = serializer.serialize_seq(Some(array.len()))?; + for s in array.iter() { + seq.serialize_element(s.unwrap_or_default())?; + } + seq.end() +} + +struct BasicSequence<'a, T> { + value: &'a ArrayRef, + ty: PhantomData, +} + +impl serde::Serialize for BasicSequence<'_, T> +where + T: ArrowPrimitiveType, + T::Native: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let array: &PrimitiveArray = self + .value + .as_primitive_opt() + .ok_or_else(|| error(format!("not a primitive {} array", type_name::())))?; + + let mut seq = serializer.serialize_seq(Some(array.len()))?; + + for value in array.values() { + seq.serialize_element(value)?; + } + + seq.end() + } +} + +struct ByteSequence<'a> { + value: &'a ArrayRef, +} + +impl serde::Serialize for ByteSequence<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + if let Some(binary) = self.value.as_binary_opt::() { + serialize_binary(serializer, binary) + } else if let Some(binary) = self.value.as_binary_opt::() { + serialize_binary(serializer, binary) + } else { + BasicSequence { + value: self.value, + ty: PhantomData::, + } + .serialize(serializer) + } + } +} + +fn serialize_binary( + serializer: S, + binary: &arrow::array::GenericByteArray>, +) -> Result<::Ok, ::Error> +where + S: serde::Serializer, + O: OffsetSizeTrait, +{ + let mut seq = serializer.serialize_seq(Some(binary.len()))?; + + for value in binary.iter() { + seq.serialize_element(value.unwrap_or_default())?; + } + + seq.end() +} + +struct BoolArray<'a> { + value: &'a ArrayRef, +} + +impl serde::Serialize for BoolArray<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let array = self + .value + .as_boolean_opt() + .ok_or_else(|| error("not a boolean array"))?; + let mut seq = serializer.serialize_tuple(array.len())?; + + for value in array.values() { + seq.serialize_element(&value)?; + } + + seq.end() + } +}