Skip to content

Commit

Permalink
Merge 683885c into 28a9fe8
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyan-dfinity authored Dec 2, 2022
2 parents 28a9fe8 + 683885c commit fff535b
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 48 deletions.
89 changes: 52 additions & 37 deletions rust/candid/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,6 @@ impl<'de> IDLDeserialize<'de> {
expected_type.clone()
};
self.de.wire_type = ty.clone();
self.de
.check_subtype()
.with_context(|| self.de.dump_state())
.with_context(|| {
format!(
"Fail to decode argument {} from {} to {}",
ind, ty, expected_type
)
})?;

let v = T::deserialize(&mut self.de)
.with_context(|| self.de.dump_state())
Expand Down Expand Up @@ -235,7 +226,6 @@ impl<'de> Deserializer<'de> {
{
use std::convert::TryInto;
self.unroll_type()?;
assert!(self.expect_type == Type::Int);
let mut bytes = vec![0u8];
let int = match &self.wire_type {
Type::Int => Int::decode(&mut self.input).map_err(Error::msg)?,
Expand All @@ -244,20 +234,25 @@ impl<'de> Deserializer<'de> {
.0
.try_into()
.map_err(Error::msg)?),
// We already did subtype checking before deserialize, so this is unreachable code
_ => assert!(false),
t => {
return Err(Error::subtype(format!(
"{} cannot be deserialized to int",
t
)))
}
};
bytes.extend_from_slice(&int.0.to_signed_bytes_le());
assert!(self.expect_type == Type::Int);
visitor.visit_byte_buf(bytes)
}
fn deserialize_nat<'a, V>(&'a mut self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat);
let mut bytes = vec![1u8];
let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat);
bytes.extend_from_slice(&nat.0.to_bytes_le());
visitor.visit_byte_buf(bytes)
}
Expand All @@ -266,9 +261,9 @@ impl<'de> Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Principal && self.wire_type == Type::Principal);
let mut bytes = vec![2u8];
let id = PrincipalBytes::read(&mut self.input)?.inner;
assert!(self.expect_type == Type::Principal && self.wire_type == Type::Principal);
bytes.extend_from_slice(&id);
visitor.visit_byte_buf(bytes)
}
Expand All @@ -284,9 +279,10 @@ impl<'de> Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(matches!(self.wire_type, Type::Service(_)));
self.check_subtype()?;
let mut bytes = vec![4u8];
let id = PrincipalBytes::read(&mut self.input)?.inner;
assert!(matches!(self.wire_type, Type::Service(_)));
bytes.extend_from_slice(&id);
visitor.visit_byte_buf(bytes)
}
Expand All @@ -295,7 +291,7 @@ impl<'de> Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(matches!(self.wire_type, Type::Func(_)));
self.check_subtype()?;
if !BoolValue::read(&mut self.input)?.0 {
return Err(Error::msg("Opaque reference not supported"));
}
Expand All @@ -305,6 +301,7 @@ impl<'de> Deserializer<'de> {
let meth = self.borrow_bytes(len)?;
// TODO find a better way
leb128::write::unsigned(&mut bytes, len as u64)?;
assert!(matches!(self.wire_type, Type::Func(_)));
bytes.extend_from_slice(meth);
bytes.extend_from_slice(&id);
visitor.visit_byte_buf(bytes)
Expand All @@ -323,9 +320,9 @@ macro_rules! primitive_impl {
fn [<deserialize_ $ty>]<V>(self, visitor: V) -> Result<V::Value>
where V: Visitor<'de> {
self.unroll_type()?;
assert!(self.expect_type == $type && self.wire_type == $type);
let val = self.input.$($value)*().map_err(|_| Error::msg(format!("Cannot read {} value", stringify!($type))))?;
//let val: $ty = self.input.read_le()?;
assert!(self.expect_type == $type && self.wire_type == $type);
visitor.[<visit_ $ty>](val)
}
}
Expand Down Expand Up @@ -407,7 +404,6 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
{
use std::convert::TryInto;
self.unroll_type()?;
assert!(self.expect_type == Type::Int);
let value: i128 = match &self.wire_type {
Type::Int => {
let int = Int::decode(&mut self.input).map_err(Error::msg)?;
Expand All @@ -417,8 +413,14 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
nat.0.try_into().map_err(Error::msg)?
}
_ => assert!(false),
t => {
return Err(Error::subtype(format!(
"{} cannot be deserialized to int",
t
)))
}
};
assert!(self.expect_type == Type::Int);
visitor.visit_i128(value)
}
fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value>
Expand All @@ -427,9 +429,9 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
{
use std::convert::TryInto;
self.unroll_type()?;
assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat);
let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
let value: u128 = nat.0.try_into().map_err(Error::msg)?;
assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat);
visitor.visit_u128(value)
}
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
Expand All @@ -445,30 +447,30 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Bool && self.wire_type == Type::Bool);
let res = BoolValue::read(&mut self.input)?;
assert!(self.expect_type == Type::Bool && self.wire_type == Type::Bool);
visitor.visit_bool(res.0)
}
fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Text && self.wire_type == Type::Text);
let len = Len::read(&mut self.input)?.0;
let bytes = self.borrow_bytes(len)?.to_owned();
let value = String::from_utf8(bytes).map_err(Error::msg)?;
assert!(self.expect_type == Type::Text && self.wire_type == Type::Text);
visitor.visit_string(value)
}
fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Text && self.wire_type == Type::Text);
let len = Len::read(&mut self.input)?.0;
let slice = self.borrow_bytes(len)?;
let value: &str = std::str::from_utf8(slice).map_err(Error::msg)?;
assert!(self.expect_type == Type::Text && self.wire_type == Type::Text);
visitor.visit_borrowed_str(value)
}
fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
Expand All @@ -493,23 +495,36 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
(Type::Opt(t1), Type::Opt(t2)) => {
self.wire_type = *t1.clone();
self.expect_type = *t2.clone();
if BoolValue::read(&mut self.input)?.0 {
if self.check_subtype().is_ok() {
visitor.visit_some(self)
unsafe {
let v = std::ptr::read(&visitor);
if BoolValue::read(&mut self.input)?.0 {
match v.visit_some(self) {
Ok(v) => Ok(v),
Err(Error::Subtype(_)) => {
let v = std::ptr::read(&visitor);
v.visit_none()
}
Err(e) => Err(e),
}
} else {
self.deserialize_ignored_any(serde::de::IgnoredAny)?;
visitor.visit_none()
v.visit_none()
}
} else {
visitor.visit_none()
}
}
(_, Type::Opt(t2)) => {
self.expect_type = self.table.trace_type(t2)?;
if !matches!(self.expect_type, Type::Null | Type::Reserved | Type::Opt(_))
&& self.check_subtype().is_ok()
{
visitor.visit_some(self)
if !matches!(self.expect_type, Type::Null | Type::Reserved | Type::Opt(_)) {
unsafe {
let v = std::ptr::read(&visitor);
match v.visit_some(self) {
Ok(v) => Ok(v),
Err(Error::Subtype(_)) => {
let v = std::ptr::read(&visitor);
v.visit_none()
}
Err(e) => Err(e),
}
}
} else {
self.deserialize_ignored_any(serde::de::IgnoredAny)?;
visitor.visit_none()
Expand Down Expand Up @@ -550,13 +565,13 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
}
fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
self.unroll_type()?;
let len = Len::read(&mut self.input)?.0;
let bytes = self.borrow_bytes(len)?.to_owned();
//let bytes = Bytes::read(&mut self.input)?.inner;
assert!(
self.expect_type == Type::Vec(Box::new(Type::Nat8))
&& self.wire_type == Type::Vec(Box::new(Type::Nat8))
);
let len = Len::read(&mut self.input)?.0;
let bytes = self.borrow_bytes(len)?.to_owned();
//let bytes = Bytes::read(&mut self.input)?.inner;
visitor.visit_byte_buf(bytes)
}
fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
Expand Down
10 changes: 10 additions & 0 deletions rust/candid/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ pub enum Error {
#[error("binary parser error: {}", .0.get(0).map(|f| format!("{} at byte offset {}", f.message, f.range.start/2)).unwrap_or_else(|| "io error".to_string()))]
Binread(Vec<Label<()>>),

#[error("Subtyping error: {0}")]
Subtype(String),

#[error(transparent)]
Custom(#[from] anyhow::Error),
}
Expand All @@ -27,6 +30,9 @@ impl Error {
pub fn msg<T: ToString>(msg: T) -> Self {
Error::Custom(anyhow::anyhow!(msg.to_string()))
}
pub fn subtype<T: ToString>(msg: T) -> Self {
Error::Subtype(msg.to_string())
}
pub fn report(&self) -> Diagnostic<()> {
match self {
Error::Parse(e) => {
Expand Down Expand Up @@ -57,6 +63,7 @@ impl Error {
let diag = Diagnostic::error().with_message("decoding error");
diag.with_labels(labels.to_vec())
}
Error::Subtype(e) => Diagnostic::error().with_message(e),
Error::Custom(e) => Diagnostic::error().with_message(e.to_string()),
}
}
Expand Down Expand Up @@ -135,6 +142,9 @@ impl de::Error for Error {
fn custom<T: std::fmt::Display>(msg: T) -> Self {
Error::msg(format!("Deserialize error: {}", msg))
}
fn invalid_type(_: de::Unexpected<'_>, exp: &dyn de::Expected) -> Self {
Error::Subtype(format!("{}", exp))
}
}

impl From<io::Error> for Error {
Expand Down
8 changes: 4 additions & 4 deletions rust/candid/src/types/subtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub fn subtype(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result
format!("Variant field {}: {} is not a subtype of {}", id, ty1, ty2)
})?,
None => {
return Err(Error::msg(format!(
return Err(Error::subtype(format!(
"Variant field {} not found in the expected type",
id
)))
Expand All @@ -81,7 +81,7 @@ pub fn subtype(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result
format!("Method {}: {} is not a subtype of {}", name, ty1, ty2)
})?,
None => {
return Err(Error::msg(format!(
return Err(Error::subtype(format!(
"Method {} is only in the expected type",
name
)))
Expand All @@ -94,7 +94,7 @@ pub fn subtype(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result
let m1: HashSet<_> = f1.modes.iter().collect();
let m2: HashSet<_> = f2.modes.iter().collect();
if m1 != m2 {
return Err(Error::msg("Function mode mismatch"));
return Err(Error::subtype("Function mode mismatch"));
}
let args1 = to_tuple(&f1.args);
let args2 = to_tuple(&f2.args);
Expand All @@ -109,7 +109,7 @@ pub fn subtype(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result
(_, Class(_, t)) => subtype(gamma, env, t1, t),
(Unknown, _) => unreachable!(),
(_, Unknown) => unreachable!(),
(_, _) => Err(Error::msg(format!("{} is not a subtype of {}", t1, t2))),
(_, _) => Err(Error::subtype(format!("{} is not a subtype of {}", t1, t2))),
}
}

Expand Down
6 changes: 3 additions & 3 deletions rust/candid/tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ fn test_integer() {
);
check_error(
|| test_decode(&hex("4449444c00017c2a"), &42i64),
"int is not a subtype of int64",
"Int64 value",
);
}

Expand Down Expand Up @@ -320,10 +320,10 @@ fn test_extra_fields() {

let bytes = encode(&E2::Foo);
test_decode(&bytes, &Some(E2::Foo));
check_error(
/*check_error(
|| test_decode(&bytes, &E1::Foo),
"Variant field 3_303_867 not found",
);
);*/
}

#[test]
Expand Down
8 changes: 4 additions & 4 deletions test/construct.test.did
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ assert blob "DIDL\00\00" == "(null)" : (Opt) "op

// vector
assert blob "DIDL\01\6d\7c\01\00\00" == "(vec {})" : (vec int) "vec: empty";
assert blob "DIDL\01\6d\7c\01\00\00" !: (vec int8) "vec: non subtype empty";
assert blob "DIDL\01\6d\7c\01\00\00" : (vec int8) "vec: non subtype empty";
assert blob "DIDL\01\6d\7c\01\00\02\01\02" == "(vec { 1; 2 })" : (vec int) "vec";
assert blob "DIDL\01\6d\7b\01\00\02\01\02" == "(blob \"\\01\\02\")" : (vec nat8) "vec: blob";
assert blob "DIDL\01\6d\00\01\00\00" == "(vec {})" : (Vec) "vec: recursive vector";
Expand Down Expand Up @@ -139,10 +139,10 @@ assert "(variant {})" !
assert blob "DIDL\01\6b\00\01\00" !: (variant {}) "variant: no empty value";
assert blob "DIDL\01\6b\01\00\7f\01\00\00" == "(variant {0})" : (variant {0}) "variant: numbered field";
assert blob "DIDL\01\6b\01\00\7f\01\00\00\2a" !: (variant {0:int}) "variant: type mismatch";
assert blob "DIDL\01\6b\02\00\7f\01\7c\01\00\01\2a" !: (variant {0:int; 1:int}) "variant: type mismatch in unused tag";
assert blob "DIDL\01\6b\02\00\7f\01\7c\01\00\01\2a" : (variant {0:int; 1:int}) "variant: type mismatch in unused tag";
assert blob "DIDL\01\6b\01\00\7f\01\00\00" == "(variant {0})" : (variant {0;1}) "variant: ignore field";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" !: (variant {0}) "variant {0;1} !<: variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" == "(null)" : (opt variant {0}) "variant {0;1} <: opt variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" : (variant {0}) "variant {0;1} <: variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" == "(variant {0})" : (opt variant {0}) "variant {0;1} <: opt variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\01" == "(variant {1})" : (variant {0;1;2}) "variant: change index";
assert blob "DIDL\01\6b\01\00\7f\01\00\00" !: (variant {1}) "variant: missing field";
assert blob "DIDL\01\6b\01\00\7f\01\00\01" !: (variant {0}) "variant: index out of range";
Expand Down

0 comments on commit fff535b

Please sign in to comment.