-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
547 additions
and
0 deletions.
There are no files selected for viewing
174 changes: 174 additions & 0 deletions
174
libraries/extensions/ros2-bridge/python/src/typed/serialize/array.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
use std::marker::PhantomData; | ||
|
||
use arrow::{ | ||
array::{Array, ArrayRef, AsArray, PrimitiveArray}, | ||
datatypes::{self, ArrowPrimitiveType}, | ||
}; | ||
use dora_ros2_bridge_msg_gen::types::{ | ||
primitives::{BasicType, NestableType}, | ||
sequences, | ||
}; | ||
use serde::ser::SerializeTuple; | ||
|
||
use super::error; | ||
|
||
/// Serialize an array with known size as tuple. | ||
pub struct ArraySerializeWrapper<'a> { | ||
pub array_info: &'a sequences::Array, | ||
pub column: &'a ArrayRef, | ||
} | ||
|
||
impl serde::Serialize for ArraySerializeWrapper<'_> { | ||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | ||
where | ||
S: serde::Serializer, | ||
{ | ||
let entry = if let Some(list) = self.column.as_list_opt::<i32>() { | ||
// should match the length of the outer struct | ||
assert_eq!(list.len(), 1); | ||
list.value(0) | ||
} else { | ||
// try as large list | ||
let list = self.column.as_list_opt::<i64>().ok_or_else(|| { | ||
error(format!("value is not compatible with expected Array type")) | ||
})?; | ||
// should match the length of the outer struct | ||
assert_eq!(list.len(), 1); | ||
list.value(0) | ||
}; | ||
|
||
let mut s = serializer.serialize_tuple(self.array_info.size)?; | ||
match self.array_info.value_type { | ||
NestableType::BasicType(t) => match t { | ||
BasicType::I8 => s.serialize_element(&BasicArrayAsTuple { | ||
len: self.array_info.size, | ||
value: &entry, | ||
ty: PhantomData::<datatypes::Int8Type>, | ||
})?, | ||
BasicType::I16 => s.serialize_element(&BasicArrayAsTuple { | ||
len: self.array_info.size, | ||
value: &entry, | ||
ty: PhantomData::<datatypes::Int16Type>, | ||
})?, | ||
BasicType::I32 => s.serialize_element(&BasicArrayAsTuple { | ||
len: self.array_info.size, | ||
value: &entry, | ||
ty: PhantomData::<datatypes::Int32Type>, | ||
})?, | ||
BasicType::I64 => s.serialize_element(&BasicArrayAsTuple { | ||
len: self.array_info.size, | ||
value: &entry, | ||
ty: PhantomData::<datatypes::Int64Type>, | ||
})?, | ||
BasicType::U8 | BasicType::Char | BasicType::Byte => { | ||
s.serialize_element(&BasicArrayAsTuple { | ||
len: self.array_info.size, | ||
value: &entry, | ||
ty: PhantomData::<datatypes::UInt8Type>, | ||
})? | ||
} | ||
BasicType::U16 => s.serialize_element(&BasicArrayAsTuple { | ||
len: self.array_info.size, | ||
value: &entry, | ||
ty: PhantomData::<datatypes::UInt16Type>, | ||
})?, | ||
BasicType::U32 => s.serialize_element(&BasicArrayAsTuple { | ||
len: self.array_info.size, | ||
value: &entry, | ||
ty: PhantomData::<datatypes::UInt32Type>, | ||
})?, | ||
BasicType::U64 => s.serialize_element(&BasicArrayAsTuple { | ||
len: self.array_info.size, | ||
value: &entry, | ||
ty: PhantomData::<datatypes::UInt64Type>, | ||
})?, | ||
BasicType::F32 => s.serialize_element(&BasicArrayAsTuple { | ||
len: self.array_info.size, | ||
value: &entry, | ||
ty: PhantomData::<datatypes::Float32Type>, | ||
})?, | ||
BasicType::F64 => s.serialize_element(&BasicArrayAsTuple { | ||
len: self.array_info.size, | ||
value: &entry, | ||
ty: PhantomData::<datatypes::Float64Type>, | ||
})?, | ||
BasicType::Bool => s.serialize_element(&BoolArrayAsTuple { | ||
len: self.array_info.size, | ||
value: &entry, | ||
})?, | ||
}, | ||
NestableType::NamedType(_) => todo!(), | ||
NestableType::NamespacedType(_) => todo!(), | ||
NestableType::GenericString(_) => todo!(), | ||
}; | ||
s.end() | ||
} | ||
} | ||
|
||
/// Serializes a primitive array with known size as tuple. | ||
struct BasicArrayAsTuple<'a, T> { | ||
len: usize, | ||
value: &'a ArrayRef, | ||
ty: PhantomData<T>, | ||
} | ||
|
||
impl<T> serde::Serialize for BasicArrayAsTuple<'_, T> | ||
where | ||
T: ArrowPrimitiveType, | ||
T::Native: serde::Serialize, | ||
{ | ||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | ||
where | ||
S: serde::Serializer, | ||
{ | ||
let mut seq = serializer.serialize_tuple(self.len)?; | ||
let array: &PrimitiveArray<T> = self | ||
.value | ||
.as_primitive_opt() | ||
.ok_or_else(|| error(format!("not a primitive array")))?; | ||
if array.len() != self.len { | ||
return Err(error(format!( | ||
"expected array with length {}, got length {}", | ||
self.len, | ||
array.len() | ||
))); | ||
} | ||
|
||
for value in array.values() { | ||
seq.serialize_element(value)?; | ||
} | ||
|
||
seq.end() | ||
} | ||
} | ||
|
||
struct BoolArrayAsTuple<'a> { | ||
len: usize, | ||
value: &'a ArrayRef, | ||
} | ||
|
||
impl serde::Serialize for BoolArrayAsTuple<'_> { | ||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | ||
where | ||
S: serde::Serializer, | ||
{ | ||
let mut seq = serializer.serialize_tuple(self.len)?; | ||
let array = self | ||
.value | ||
.as_boolean_opt() | ||
.ok_or_else(|| error(format!("not a boolean array")))?; | ||
if array.len() != self.len { | ||
return Err(error(format!( | ||
"expected array with length {}, got length {}", | ||
self.len, | ||
array.len() | ||
))); | ||
} | ||
|
||
for value in array.values() { | ||
seq.serialize_element(&value)?; | ||
} | ||
|
||
seq.end() | ||
} | ||
} |
140 changes: 140 additions & 0 deletions
140
libraries/extensions/ros2-bridge/python/src/typed/serialize/mod.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
use std::{borrow::Cow, collections::HashMap, fmt::Display}; | ||
|
||
use arrow::array::{Array, ArrayRef, AsArray}; | ||
use dora_ros2_bridge_msg_gen::types::{primitives::NestableType, MemberType}; | ||
use serde::ser::SerializeStruct; | ||
|
||
use super::TypeInfo; | ||
|
||
mod array; | ||
mod primitive; | ||
mod sequence; | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct TypedValue<'a> { | ||
pub value: &'a ArrayRef, | ||
pub type_info: &'a TypeInfo<'a>, | ||
} | ||
|
||
impl serde::Serialize for TypedValue<'_> { | ||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | ||
where | ||
S: serde::Serializer, | ||
{ | ||
let empty = HashMap::new(); | ||
let package_messages = self | ||
.type_info | ||
.messages | ||
.get(self.type_info.package_name.as_ref()) | ||
.unwrap_or(&empty); | ||
let message = package_messages | ||
.get(self.type_info.package_name.as_ref()) | ||
.ok_or_else(|| { | ||
error(format!( | ||
"could not find message type {}::{}", | ||
self.type_info.package_name, self.type_info.message_name | ||
)) | ||
})?; | ||
|
||
let input = self | ||
.value | ||
.as_struct_opt() | ||
.ok_or_else(|| error("expected struct array"))?; | ||
for column_name in input.column_names() { | ||
if !message.members.iter().any(|m| m.name == column_name) { | ||
return Err(error(format!( | ||
"given struct has unknown field {column_name}" | ||
)))?; | ||
} | ||
} | ||
if input.is_empty() { | ||
// TODO: publish default value | ||
return Err(error(format!("given struct is empty")))?; | ||
} | ||
if input.len() > 1 { | ||
return Err(error(format!( | ||
"expected single struct instance, got struct array with {} entries", | ||
input.len() | ||
)))?; | ||
} | ||
let mut s = serializer.serialize_struct(&message.name, message.members.len())?; | ||
for field in message.members.iter() { | ||
match input.column_by_name(&field.name) { | ||
Some(column) => match &field.r#type { | ||
MemberType::NestableType(t) => match t { | ||
NestableType::BasicType(t) => { | ||
s.serialize_field(&field.name, &primitive::SerializeWrapper{t, column})?; | ||
} | ||
NestableType::NamedType(name) => { | ||
let referenced_value = &TypedValue { | ||
value: &column, | ||
type_info: &TypeInfo { | ||
package_name: Cow::Borrowed(&self.type_info.package_name), | ||
message_name: Cow::Borrowed(&name.0), | ||
messages: self.type_info.messages.clone(), | ||
} | ||
}; | ||
s.serialize_field(&field.name, &referenced_value)?; | ||
} | ||
NestableType::NamespacedType(reference) => { | ||
if reference.namespace != "msg" { | ||
return Err(error(format!( | ||
"struct field {} references non-message type {reference:?}", | ||
field.name | ||
))); | ||
} | ||
|
||
let referenced_value: &TypedValue<'_> = &TypedValue { | ||
value: &column, | ||
type_info: &TypeInfo { | ||
package_name: Cow::Borrowed(&reference.package), | ||
message_name: Cow::Borrowed(&reference.name), | ||
messages: self.type_info.messages.clone(), | ||
} | ||
}; | ||
s.serialize_field(&field.name, &referenced_value)?; | ||
} | ||
NestableType::GenericString(t) => match t { | ||
dora_ros2_bridge_msg_gen::types::primitives::GenericString::String | | ||
dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedString(_) => { | ||
let string = if let Some(string_array) = column.as_string_opt::<i32>() { | ||
// should match the length of the outer struct array | ||
assert_eq!(string_array.len(), 1); | ||
string_array.value(0) | ||
} else { | ||
// try again with large offset type | ||
let string_array = column.as_string_opt::<i64>().ok_or_else(|| error("expected string array"))?; | ||
// should match the length of the outer struct array | ||
assert_eq!(string_array.len(), 1); | ||
string_array.value(0) | ||
}; | ||
s.serialize_field(&field.name, string); | ||
}, | ||
dora_ros2_bridge_msg_gen::types::primitives::GenericString::WString => todo!(), | ||
dora_ros2_bridge_msg_gen::types::primitives::GenericString::BoundedWString(_) => todo!(), | ||
}, | ||
}, | ||
dora_ros2_bridge_msg_gen::types::MemberType::Array(a) => { | ||
s.serialize_field(&field.name, &array::ArraySerializeWrapper {array_info: a, column})?; | ||
} | ||
dora_ros2_bridge_msg_gen::types::MemberType::Sequence(v) => { | ||
s.serialize_field(&field.name, &sequence::SequenceSerializeWrapper {item_type: &v.value_type, column})?; | ||
}, | ||
dora_ros2_bridge_msg_gen::types::MemberType::BoundedSequence(v) => { | ||
s.serialize_field(&field.name, &sequence::SequenceSerializeWrapper {item_type: &v.value_type, column})?; | ||
}, | ||
}, | ||
None => todo!(), // TODO use default value | ||
}; | ||
} | ||
s.end() | ||
} | ||
} | ||
|
||
fn error<E, T>(e: T) -> E | ||
where | ||
T: Display, | ||
E: serde::ser::Error, | ||
{ | ||
serde::ser::Error::custom(e) | ||
} |
Oops, something went wrong.