Skip to content

Commit

Permalink
Fix serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
phil-opp committed Jan 21, 2024
1 parent 713ac5e commit 23625a2
Show file tree
Hide file tree
Showing 4 changed files with 547 additions and 0 deletions.
174 changes: 174 additions & 0 deletions libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
use std::marker::PhantomData;

use arrow::{
array::{Array, ArrayRef, AsArray, PrimitiveArray},
datatypes::{self, ArrowPrimitiveType},
};
use dora_ros2_bridge_msg_gen::types::{
primitives::{BasicType, NestableType},
sequences,
};
use serde::ser::SerializeTuple;

use super::error;

/// Serialize an array with known size as tuple.
pub struct ArraySerializeWrapper<'a> {
pub array_info: &'a sequences::Array,
pub column: &'a ArrayRef,
}

impl serde::Serialize for ArraySerializeWrapper<'_> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let entry = if let Some(list) = self.column.as_list_opt::<i32>() {
// should match the length of the outer struct
assert_eq!(list.len(), 1);
list.value(0)
} else {
// try as large list
let list = self.column.as_list_opt::<i64>().ok_or_else(|| {
error(format!("value is not compatible with expected Array type"))
})?;
// should match the length of the outer struct
assert_eq!(list.len(), 1);
list.value(0)
};

let mut s = serializer.serialize_tuple(self.array_info.size)?;
match self.array_info.value_type {
NestableType::BasicType(t) => match t {
BasicType::I8 => s.serialize_element(&BasicArrayAsTuple {
len: self.array_info.size,
value: &entry,
ty: PhantomData::<datatypes::Int8Type>,
})?,
BasicType::I16 => s.serialize_element(&BasicArrayAsTuple {
len: self.array_info.size,
value: &entry,
ty: PhantomData::<datatypes::Int16Type>,
})?,
BasicType::I32 => s.serialize_element(&BasicArrayAsTuple {
len: self.array_info.size,
value: &entry,
ty: PhantomData::<datatypes::Int32Type>,
})?,
BasicType::I64 => s.serialize_element(&BasicArrayAsTuple {
len: self.array_info.size,
value: &entry,
ty: PhantomData::<datatypes::Int64Type>,
})?,
BasicType::U8 | BasicType::Char | BasicType::Byte => {
s.serialize_element(&BasicArrayAsTuple {
len: self.array_info.size,
value: &entry,
ty: PhantomData::<datatypes::UInt8Type>,
})?
}
BasicType::U16 => s.serialize_element(&BasicArrayAsTuple {
len: self.array_info.size,
value: &entry,
ty: PhantomData::<datatypes::UInt16Type>,
})?,
BasicType::U32 => s.serialize_element(&BasicArrayAsTuple {
len: self.array_info.size,
value: &entry,
ty: PhantomData::<datatypes::UInt32Type>,
})?,
BasicType::U64 => s.serialize_element(&BasicArrayAsTuple {
len: self.array_info.size,
value: &entry,
ty: PhantomData::<datatypes::UInt64Type>,
})?,
BasicType::F32 => s.serialize_element(&BasicArrayAsTuple {
len: self.array_info.size,
value: &entry,
ty: PhantomData::<datatypes::Float32Type>,
})?,
BasicType::F64 => s.serialize_element(&BasicArrayAsTuple {
len: self.array_info.size,
value: &entry,
ty: PhantomData::<datatypes::Float64Type>,
})?,
BasicType::Bool => s.serialize_element(&BoolArrayAsTuple {
len: self.array_info.size,
value: &entry,
})?,
},
NestableType::NamedType(_) => todo!(),
NestableType::NamespacedType(_) => todo!(),
NestableType::GenericString(_) => todo!(),
};
s.end()
}
}

/// Serializes a primitive array with known size as tuple.
struct BasicArrayAsTuple<'a, T> {
len: usize,
value: &'a ArrayRef,
ty: PhantomData<T>,
}

impl<T> serde::Serialize for BasicArrayAsTuple<'_, T>
where
T: ArrowPrimitiveType,
T::Native: serde::Serialize,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut seq = serializer.serialize_tuple(self.len)?;
let array: &PrimitiveArray<T> = self
.value
.as_primitive_opt()
.ok_or_else(|| error(format!("not a primitive array")))?;
if array.len() != self.len {
return Err(error(format!(
"expected array with length {}, got length {}",
self.len,
array.len()
)));
}

for value in array.values() {
seq.serialize_element(value)?;
}

seq.end()
}
}

struct BoolArrayAsTuple<'a> {
len: usize,
value: &'a ArrayRef,
}

impl serde::Serialize for BoolArrayAsTuple<'_> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut seq = serializer.serialize_tuple(self.len)?;
let array = self
.value
.as_boolean_opt()
.ok_or_else(|| error(format!("not a boolean array")))?;
if array.len() != self.len {
return Err(error(format!(
"expected array with length {}, got length {}",
self.len,
array.len()
)));
}

for value in array.values() {
seq.serialize_element(&value)?;
}

seq.end()
}
}
140 changes: 140 additions & 0 deletions libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
use std::{borrow::Cow, collections::HashMap, fmt::Display};

use arrow::array::{Array, ArrayRef, AsArray};
use dora_ros2_bridge_msg_gen::types::{primitives::NestableType, MemberType};
use serde::ser::SerializeStruct;

use super::TypeInfo;

mod array;
mod primitive;
mod sequence;

#[derive(Debug, Clone)]
pub struct TypedValue<'a> {
pub value: &'a ArrayRef,
pub type_info: &'a TypeInfo<'a>,
}

impl serde::Serialize for TypedValue<'_> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let empty = HashMap::new();
let package_messages = self
.type_info
.messages
.get(self.type_info.package_name.as_ref())
.unwrap_or(&empty);
let message = package_messages
.get(self.type_info.package_name.as_ref())
.ok_or_else(|| {
error(format!(
"could not find message type {}::{}",
self.type_info.package_name, self.type_info.message_name
))
})?;

let input = self
.value
.as_struct_opt()
.ok_or_else(|| error("expected struct array"))?;
for column_name in input.column_names() {
if !message.members.iter().any(|m| m.name == column_name) {
return Err(error(format!(
"given struct has unknown field {column_name}"
)))?;
}
}
if input.is_empty() {
// TODO: publish default value
return Err(error(format!("given struct is empty")))?;
}
if input.len() > 1 {
return Err(error(format!(
"expected single struct instance, got struct array with {} entries",
input.len()
)))?;
}
let mut s = serializer.serialize_struct(&message.name, message.members.len())?;
for field in message.members.iter() {
match input.column_by_name(&field.name) {
Some(column) => match &field.r#type {
MemberType::NestableType(t) => match t {
NestableType::BasicType(t) => {
s.serialize_field(&field.name, &primitive::SerializeWrapper{t, column})?;
}
NestableType::NamedType(name) => {
let referenced_value = &TypedValue {
value: &column,
type_info: &TypeInfo {
package_name: Cow::Borrowed(&self.type_info.package_name),
message_name: Cow::Borrowed(&name.0),
messages: self.type_info.messages.clone(),
}
};
s.serialize_field(&field.name, &referenced_value)?;
}
NestableType::NamespacedType(reference) => {
if reference.namespace != "msg" {
return Err(error(format!(
"struct field {} references non-message type {reference:?}",
field.name
)));
}

let referenced_value: &TypedValue<'_> = &TypedValue {
value: &column,
type_info: &TypeInfo {
package_name: Cow::Borrowed(&reference.package),
message_name: Cow::Borrowed(&reference.name),
messages: self.type_info.messages.clone(),
}
};
s.serialize_field(&field.name, &referenced_value)?;
}
NestableType::GenericString(t) => match t {
dora_ros2_bridge_msg_gen::types::primitives::GenericString::String |
dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedString(_) => {
let string = if let Some(string_array) = column.as_string_opt::<i32>() {
// should match the length of the outer struct array
assert_eq!(string_array.len(), 1);
string_array.value(0)
} else {
// try again with large offset type
let string_array = column.as_string_opt::<i64>().ok_or_else(|| error("expected string array"))?;
// should match the length of the outer struct array
assert_eq!(string_array.len(), 1);
string_array.value(0)
};
s.serialize_field(&field.name, string);
},
dora_ros2_bridge_msg_gen::types::primitives::GenericString::WString => todo!(),
dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedWString(_) => todo!(),
},
},
dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => {
s.serialize_field(&field.name, &array::ArraySerializeWrapper {array_info: a, column})?;
}
dora_ros2_bridge_msg_gen::types::MemberType::Sequence(v) => {
s.serialize_field(&field.name, &sequence::SequenceSerializeWrapper {item_type: &v.value_type, column})?;
},
dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(v) => {
s.serialize_field(&field.name, &sequence::SequenceSerializeWrapper {item_type: &v.value_type, column})?;
},
},
None => todo!(), // TODO use default value
};
}
s.end()
}
}

fn error<E, T>(e: T) -> E
where
T: Display,
E: serde::ser::Error,
{
serde::ser::Error::custom(e)
}
Loading

0 comments on commit 23625a2

Please sign in to comment.