diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs new file mode 100644 index 000000000..26f684f51 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs @@ -0,0 +1,174 @@ +use std::marker::PhantomData; + +use arrow::{ + array::{Array, ArrayRef, AsArray, PrimitiveArray}, + datatypes::{self, ArrowPrimitiveType}, +}; +use dora_ros2_bridge_msg_gen::types::{ + primitives::{BasicType, NestableType}, + sequences, +}; +use serde::ser::SerializeTuple; + +use super::error; + +/// Serialize an array with known size as tuple. +pub struct ArraySerializeWrapper<'a> { + pub array_info: &'a sequences::Array, + pub column: &'a ArrayRef, +} + +impl serde::Serialize for ArraySerializeWrapper<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let entry = if let Some(list) = self.column.as_list_opt::() { + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + } else { + // try as large list + let list = self.column.as_list_opt::().ok_or_else(|| { + error(format!("value is not compatible with expected Array type")) + })?; + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + }; + + let mut s = serializer.serialize_tuple(self.array_info.size)?; + match self.array_info.value_type { + NestableType::BasicType(t) => match t { + BasicType::I8 => s.serialize_element(&BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + })?, + BasicType::I16 => s.serialize_element(&BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + })?, + BasicType::I32 => s.serialize_element(&BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + })?, + BasicType::I64 => s.serialize_element(&BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + })?, + BasicType::U8 | BasicType::Char | BasicType::Byte => { + s.serialize_element(&BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + })? + } + BasicType::U16 => s.serialize_element(&BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + })?, + BasicType::U32 => s.serialize_element(&BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + })?, + BasicType::U64 => s.serialize_element(&BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + })?, + BasicType::F32 => s.serialize_element(&BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + })?, + BasicType::F64 => s.serialize_element(&BasicArrayAsTuple { + len: self.array_info.size, + value: &entry, + ty: PhantomData::, + })?, + BasicType::Bool => s.serialize_element(&BoolArrayAsTuple { + len: self.array_info.size, + value: &entry, + })?, + }, + NestableType::NamedType(_) => todo!(), + NestableType::NamespacedType(_) => todo!(), + NestableType::GenericString(_) => todo!(), + }; + s.end() + } +} + +/// Serializes a primitive array with known size as tuple. +struct BasicArrayAsTuple<'a, T> { + len: usize, + value: &'a ArrayRef, + ty: PhantomData, +} + +impl serde::Serialize for BasicArrayAsTuple<'_, T> +where + T: ArrowPrimitiveType, + T::Native: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_tuple(self.len)?; + let array: &PrimitiveArray = self + .value + .as_primitive_opt() + .ok_or_else(|| error(format!("not a primitive array")))?; + if array.len() != self.len { + return Err(error(format!( + "expected array with length {}, got length {}", + self.len, + array.len() + ))); + } + + for value in array.values() { + seq.serialize_element(value)?; + } + + seq.end() + } +} + +struct BoolArrayAsTuple<'a> { + len: usize, + value: &'a ArrayRef, +} + +impl serde::Serialize for BoolArrayAsTuple<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_tuple(self.len)?; + let array = self + .value + .as_boolean_opt() + .ok_or_else(|| error(format!("not a boolean array")))?; + if array.len() != self.len { + return Err(error(format!( + "expected array with length {}, got length {}", + self.len, + array.len() + ))); + } + + for value in array.values() { + seq.serialize_element(&value)?; + } + + seq.end() + } +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs new file mode 100644 index 000000000..11b2bcf95 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs @@ -0,0 +1,140 @@ +use std::{borrow::Cow, collections::HashMap, fmt::Display}; + +use arrow::array::{Array, ArrayRef, AsArray}; +use dora_ros2_bridge_msg_gen::types::{primitives::NestableType, MemberType}; +use serde::ser::SerializeStruct; + +use super::TypeInfo; + +mod array; +mod primitive; +mod sequence; + +#[derive(Debug, Clone)] +pub struct TypedValue<'a> { + pub value: &'a ArrayRef, + pub type_info: &'a TypeInfo<'a>, +} + +impl serde::Serialize for TypedValue<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let empty = HashMap::new(); + let package_messages = self + .type_info + .messages + .get(self.type_info.package_name.as_ref()) + .unwrap_or(&empty); + let message = package_messages + .get(self.type_info.package_name.as_ref()) + .ok_or_else(|| { + error(format!( + "could not find message type {}::{}", + self.type_info.package_name, self.type_info.message_name + )) + })?; + + let input = self + .value + .as_struct_opt() + .ok_or_else(|| error("expected struct array"))?; + for column_name in input.column_names() { + if !message.members.iter().any(|m| m.name == column_name) { + return Err(error(format!( + "given struct has unknown field {column_name}" + )))?; + } + } + if input.is_empty() { + // TODO: publish default value + return Err(error(format!("given struct is empty")))?; + } + if input.len() > 1 { + return Err(error(format!( + "expected single struct instance, got struct array with {} entries", + input.len() + )))?; + } + let mut s = serializer.serialize_struct(&message.name, message.members.len())?; + for field in message.members.iter() { + match input.column_by_name(&field.name) { + Some(column) => match &field.r#type { + MemberType::NestableType(t) => match t { + NestableType::BasicType(t) => { + s.serialize_field(&field.name, &primitive::SerializeWrapper{t, column})?; + } + NestableType::NamedType(name) => { + let referenced_value = &TypedValue { + value: &column, + type_info: &TypeInfo { + package_name: Cow::Borrowed(&self.type_info.package_name), + message_name: Cow::Borrowed(&name.0), + messages: self.type_info.messages.clone(), + } + }; + s.serialize_field(&field.name, &referenced_value)?; + } + NestableType::NamespacedType(reference) => { + if reference.namespace != "msg" { + return Err(error(format!( + "struct field {} references non-message type {reference:?}", + field.name + ))); + } + + let referenced_value: &TypedValue<'_> = &TypedValue { + value: &column, + type_info: &TypeInfo { + package_name: Cow::Borrowed(&reference.package), + message_name: Cow::Borrowed(&reference.name), + messages: self.type_info.messages.clone(), + } + }; + s.serialize_field(&field.name, &referenced_value)?; + } + NestableType::GenericString(t) => match t { + dora_ros2_bridge_msg_gen::types::primitives::GenericString::String | + dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedString(_) => { + let string = if let Some(string_array) = column.as_string_opt::() { + // should match the length of the outer struct array + assert_eq!(string_array.len(), 1); + string_array.value(0) + } else { + // try again with large offset type + let string_array = column.as_string_opt::().ok_or_else(|| error("expected string array"))?; + // should match the length of the outer struct array + assert_eq!(string_array.len(), 1); + string_array.value(0) + }; + s.serialize_field(&field.name, string); + }, + dora_ros2_bridge_msg_gen::types::primitives::GenericString::WString => todo!(), + dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedWString(_) => todo!(), + }, + }, + dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => { + s.serialize_field(&field.name, &array::ArraySerializeWrapper {array_info: a, column})?; + } + dora_ros2_bridge_msg_gen::types::MemberType::Sequence(v) => { + s.serialize_field(&field.name, &sequence::SequenceSerializeWrapper {item_type: &v.value_type, column})?; + }, + dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(v) => { + s.serialize_field(&field.name, &sequence::SequenceSerializeWrapper {item_type: &v.value_type, column})?; + }, + }, + None => todo!(), // TODO use default value + }; + } + s.end() + } +} + +fn error(e: T) -> E +where + T: Display, + E: serde::ser::Error, +{ + serde::ser::Error::custom(e) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/primitive.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/primitive.rs new file mode 100644 index 000000000..b44945f69 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/primitive.rs @@ -0,0 +1,85 @@ +use arrow::{ + array::{ArrayRef, AsArray}, + datatypes::{self, ArrowPrimitiveType}, +}; +use dora_ros2_bridge_msg_gen::types::primitives::BasicType; + +pub struct SerializeWrapper<'a> { + pub t: &'a BasicType, + pub column: &'a ArrayRef, +} + +impl serde::Serialize for SerializeWrapper<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self.t { + BasicType::I8 => serializer + .serialize_i8(as_single_primitive::(&self.column)?), + BasicType::I16 => serializer + .serialize_i16(as_single_primitive::( + &self.column, + )?), + BasicType::I32 => serializer + .serialize_i32(as_single_primitive::( + &self.column, + )?), + BasicType::I64 => serializer + .serialize_i64(as_single_primitive::( + &self.column, + )?), + BasicType::U8 | BasicType::Char | BasicType::Byte => serializer.serialize_u8( + as_single_primitive::(&self.column)?, + ), + BasicType::U16 => serializer + .serialize_u16(as_single_primitive::( + &self.column, + )?), + BasicType::U32 => serializer + .serialize_u32(as_single_primitive::( + &self.column, + )?), + BasicType::U64 => serializer + .serialize_u64(as_single_primitive::( + &self.column, + )?), + BasicType::F32 => serializer + .serialize_f32(as_single_primitive::( + &self.column, + )?), + BasicType::F64 => serializer + .serialize_f64(as_single_primitive::( + &self.column, + )?), + BasicType::Bool => { + let array = self.column.as_boolean_opt().ok_or_else(|| { + serde::ser::Error::custom( + "value is not compatible with expected `BooleanArray` type", + ) + })?; + // should match the length of the outer struct + assert_eq!(array.len(), 1); + let field_value = array.value(0); + serializer.serialize_bool(field_value) + } + } + } +} + +fn as_single_primitive(column: &ArrayRef) -> Result +where + T: ArrowPrimitiveType, + E: serde::ser::Error, +{ + let array: &arrow::array::PrimitiveArray = column.as_primitive_opt().ok_or_else(|| { + serde::ser::Error::custom(format!( + "value is not compatible with expected `{}` type", + std::any::type_name::() + )) + })?; + // should match the length of the outer struct + assert_eq!(array.len(), 1); + let number = array.value(0); + Ok(number) +} diff --git a/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs new file mode 100644 index 000000000..16e720571 --- /dev/null +++ b/libraries/extensions/ros2-bridge/python/src/typed/serialize/sequence.rs @@ -0,0 +1,148 @@ +use std::marker::PhantomData; + +use arrow::{ + array::{Array, ArrayRef, AsArray, PrimitiveArray}, + datatypes::{self, ArrowPrimitiveType}, +}; +use dora_ros2_bridge_msg_gen::types::primitives::{BasicType, NestableType}; +use serde::ser::{SerializeSeq, SerializeTuple}; + +use super::error; + +/// Serialize a variable-sized sequence. +pub struct SequenceSerializeWrapper<'a> { + pub item_type: &'a NestableType, + pub column: &'a ArrayRef, +} + +impl serde::Serialize for SequenceSerializeWrapper<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let entry = if let Some(list) = self.column.as_list_opt::() { + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + } else { + // try as large list + let list = self + .column + .as_list_opt::() + .ok_or_else(|| error("value is not compatible with expected Array type"))?; + // should match the length of the outer struct + assert_eq!(list.len(), 1); + list.value(0) + }; + match &self.item_type { + NestableType::BasicType(t) => match t { + BasicType::I8 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I16 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I32 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::I64 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U8 | BasicType::Char | BasicType::Byte => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U16 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U32 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::U64 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::F32 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::F64 => BasicSequence { + value: &entry, + ty: PhantomData::, + } + .serialize(serializer), + BasicType::Bool => BoolArray { value: &entry }.serialize(serializer), + }, + NestableType::NamedType(_) => todo!(), + NestableType::NamespacedType(_) => todo!(), + NestableType::GenericString(_) => todo!(), + } + } +} + +struct BasicSequence<'a, T> { + value: &'a ArrayRef, + ty: PhantomData, +} + +impl serde::Serialize for BasicSequence<'_, T> +where + T: ArrowPrimitiveType, + T::Native: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let array: &PrimitiveArray = self + .value + .as_primitive_opt() + .ok_or_else(|| error(format!("not a primitive array")))?; + + let mut seq = serializer.serialize_seq(Some(array.len()))?; + + for value in array.values() { + seq.serialize_element(value)?; + } + + seq.end() + } +} + +struct BoolArray<'a> { + value: &'a ArrayRef, +} + +impl serde::Serialize for BoolArray<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let array = self + .value + .as_boolean_opt() + .ok_or_else(|| error(format!("not a boolean array")))?; + let mut seq = serializer.serialize_tuple(array.len())?; + + for value in array.values() { + seq.serialize_element(&value)?; + } + + seq.end() + } +}