Skip to content

Commit

Permalink
Merge 0b88a37 into bf4fde0
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyan-dfinity authored Dec 4, 2022
2 parents bf4fde0 + 0b88a37 commit f10d655
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 53 deletions.
126 changes: 91 additions & 35 deletions rust/candid/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,6 @@ impl<'de> IDLDeserialize<'de> {
expected_type.clone()
};
self.de.wire_type = ty.clone();
self.de
.check_subtype()
.with_context(|| self.de.dump_state())
.with_context(|| {
format!(
"Fail to decode argument {} from {} to {}",
ind, ty, expected_type
)
})?;

let v = T::deserialize(&mut self.de)
.with_context(|| self.de.dump_state())
Expand Down Expand Up @@ -129,6 +120,14 @@ macro_rules! assert {
}};
}

macro_rules! check {
($exp:expr, $msg:expr) => {
if !$exp {
return Err(Error::Subtype($msg.to_string()));
}
};
}

struct Deserializer<'de> {
input: Cursor<&'de [u8]>,
table: TypeEnv,
Expand Down Expand Up @@ -244,8 +243,12 @@ impl<'de> Deserializer<'de> {
.0
.try_into()
.map_err(Error::msg)?),
// We already did subtype checking before deserialize, so this is unreachable code
_ => assert!(false),
t => {
return Err(Error::subtype(format!(
"{} cannot be deserialized to int",
t
)))
}
};
bytes.extend_from_slice(&int.0.to_signed_bytes_le());
visitor.visit_byte_buf(bytes)
Expand All @@ -255,7 +258,10 @@ impl<'de> Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat);
check!(
self.expect_type == Type::Nat && self.wire_type == Type::Nat,
"nat"
);
let mut bytes = vec![1u8];
let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
bytes.extend_from_slice(&nat.0.to_bytes_le());
Expand All @@ -266,7 +272,10 @@ impl<'de> Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Principal && self.wire_type == Type::Principal);
check!(
self.expect_type == Type::Principal && self.wire_type == Type::Principal,
"principal"
);
let mut bytes = vec![2u8];
let id = PrincipalBytes::read(&mut self.input)?.inner;
bytes.extend_from_slice(&id);
Expand All @@ -284,7 +293,8 @@ impl<'de> Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(matches!(self.wire_type, Type::Service(_)));
check!(matches!(self.wire_type, Type::Service(_)), "service");
self.check_subtype()?;
let mut bytes = vec![4u8];
let id = PrincipalBytes::read(&mut self.input)?.inner;
bytes.extend_from_slice(&id);
Expand All @@ -295,7 +305,8 @@ impl<'de> Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(matches!(self.wire_type, Type::Func(_)));
check!(matches!(self.wire_type, Type::Func(_)), "function");
self.check_subtype()?;
if !BoolValue::read(&mut self.input)?.0 {
return Err(Error::msg("Opaque reference not supported"));
}
Expand Down Expand Up @@ -323,7 +334,7 @@ macro_rules! primitive_impl {
fn [<deserialize_ $ty>]<V>(self, visitor: V) -> Result<V::Value>
where V: Visitor<'de> {
self.unroll_type()?;
assert!(self.expect_type == $type && self.wire_type == $type);
check!(self.expect_type == $type && self.wire_type == $type, stringify!($type));
let val = self.input.$($value)*().map_err(|_| Error::msg(format!("Cannot read {} value", stringify!($type))))?;
//let val: $ty = self.input.read_le()?;
visitor.[<visit_ $ty>](val)
Expand Down Expand Up @@ -417,7 +428,12 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
nat.0.try_into().map_err(Error::msg)?
}
_ => assert!(false),
t => {
return Err(Error::subtype(format!(
"{} cannot be deserialized to int",
t
)))
}
};
visitor.visit_i128(value)
}
Expand All @@ -427,7 +443,10 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
{
use std::convert::TryInto;
self.unroll_type()?;
assert!(self.expect_type == Type::Nat && self.wire_type == Type::Nat);
check!(
self.expect_type == Type::Nat && self.wire_type == Type::Nat,
"nat"
);
let nat = Nat::decode(&mut self.input).map_err(Error::msg)?;
let value: u128 = nat.0.try_into().map_err(Error::msg)?;
visitor.visit_u128(value)
Expand All @@ -437,15 +456,21 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Null && self.wire_type == Type::Null);
check!(
self.expect_type == Type::Null && self.wire_type == Type::Null,
"unit"
);
visitor.visit_unit()
}
fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Bool && self.wire_type == Type::Bool);
check!(
self.expect_type == Type::Bool && self.wire_type == Type::Bool,
"bool"
);
let res = BoolValue::read(&mut self.input)?;
visitor.visit_bool(res.0)
}
Expand All @@ -454,7 +479,10 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Text && self.wire_type == Type::Text);
check!(
self.expect_type == Type::Text && self.wire_type == Type::Text,
"text"
);
let len = Len::read(&mut self.input)?.0;
let bytes = self.borrow_bytes(len)?.to_owned();
let value = String::from_utf8(bytes).map_err(Error::msg)?;
Expand All @@ -465,7 +493,10 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
V: Visitor<'de>,
{
self.unroll_type()?;
assert!(self.expect_type == Type::Text && self.wire_type == Type::Text);
check!(
self.expect_type == Type::Text && self.wire_type == Type::Text,
"text"
);
let len = Len::read(&mut self.input)?.0;
let slice = self.borrow_bytes(len)?;
let value: &str = std::str::from_utf8(slice).map_err(Error::msg)?;
Expand Down Expand Up @@ -494,22 +525,38 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
self.wire_type = *t1.clone();
self.expect_type = *t2.clone();
if BoolValue::read(&mut self.input)?.0 {
if self.check_subtype().is_ok() {
visitor.visit_some(self)
} else {
self.deserialize_ignored_any(serde::de::IgnoredAny)?;
visitor.visit_none()
unsafe {
let v = std::ptr::read(&visitor);
let self_ptr = std::ptr::read(&self);
match v.visit_some(self_ptr) {
Ok(v) => Ok(v),
Err(Error::Subtype(_)) => {
// TODO make sure it consumes a complete value
self.deserialize_ignored_any(serde::de::IgnoredAny)?;
visitor.visit_none()
}
Err(e) => Err(e),
}
}
} else {
visitor.visit_none()
}
}
(_, Type::Opt(t2)) => {
self.expect_type = self.table.trace_type(t2)?;
if !matches!(self.expect_type, Type::Null | Type::Reserved | Type::Opt(_))
&& self.check_subtype().is_ok()
{
visitor.visit_some(self)
if !matches!(self.expect_type, Type::Null | Type::Reserved | Type::Opt(_)) {
unsafe {
let v = std::ptr::read(&visitor);
let self_ptr = std::ptr::read(&self);
match v.visit_some(self_ptr) {
Ok(v) => Ok(v),
Err(Error::Subtype(_)) => {
self.deserialize_ignored_any(serde::de::IgnoredAny)?;
visitor.visit_none()
}
Err(e) => Err(e),
}
}
} else {
self.deserialize_ignored_any(serde::de::IgnoredAny)?;
visitor.visit_none()
Expand Down Expand Up @@ -550,9 +597,10 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
}
fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
self.unroll_type()?;
assert!(
check!(
self.expect_type == Type::Vec(Box::new(Type::Nat8))
&& self.wire_type == Type::Vec(Box::new(Type::Nat8))
&& self.wire_type == Type::Vec(Box::new(Type::Nat8)),
"vec nat8"
);
let len = Len::read(&mut self.input)?.0;
let bytes = self.borrow_bytes(len)?.to_owned();
Expand Down Expand Up @@ -646,7 +694,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
visitor.visit_map(Compound::new(self, Style::Struct { expect, wire }))?;
Ok(value)
}
(Type::Record(_), Type::Empty) => Err(Error::msg("Cannot decode empty type")),
//(Type::Record(_), Type::Empty) => Err(Error::msg("Cannot decode empty type")),
_ => assert!(false),
}
}
Expand Down Expand Up @@ -789,8 +837,16 @@ impl<'de, 'a> de::MapAccess<'de> for Compound<'a, 'de> {
}
Ordering::Less => {
// by subtyping rules, expect_type can only be opt, reserved or null.
self.de.set_field_name(e.id.clone());
let field = e.id.clone();
self.de.set_field_name(field.clone());
self.de.expect_type = expect.pop_front().unwrap().ty;
check!(
matches!(
self.de.expect_type,
Type::Opt(_) | Type::Reserved | Type::Null
),
format!("field {} is not optional field", field)
);
self.de.wire_type = Type::Reserved;
}
Ordering::Greater => {
Expand Down
10 changes: 10 additions & 0 deletions rust/candid/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ pub enum Error {
#[error("binary parser error: {}", .0.get(0).map(|f| format!("{} at byte offset {}", f.message, f.range.start/2)).unwrap_or_else(|| "io error".to_string()))]
Binread(Vec<Label<()>>),

#[error("Subtyping error: {0}")]
Subtype(String),

#[error(transparent)]
Custom(#[from] anyhow::Error),
}
Expand All @@ -27,6 +30,9 @@ impl Error {
pub fn msg<T: ToString>(msg: T) -> Self {
Error::Custom(anyhow::anyhow!(msg.to_string()))
}
pub fn subtype<T: ToString>(msg: T) -> Self {
Error::Subtype(msg.to_string())
}
pub fn report(&self) -> Diagnostic<()> {
match self {
Error::Parse(e) => {
Expand Down Expand Up @@ -57,6 +63,7 @@ impl Error {
let diag = Diagnostic::error().with_message("decoding error");
diag.with_labels(labels.to_vec())
}
Error::Subtype(e) => Diagnostic::error().with_message(e),
Error::Custom(e) => Diagnostic::error().with_message(e.to_string()),
}
}
Expand Down Expand Up @@ -135,6 +142,9 @@ impl de::Error for Error {
fn custom<T: std::fmt::Display>(msg: T) -> Self {
Error::msg(format!("Deserialize error: {}", msg))
}
fn invalid_type(_: de::Unexpected<'_>, exp: &dyn de::Expected) -> Self {
Error::Subtype(format!("{}", exp))
}
}

impl From<io::Error> for Error {
Expand Down
8 changes: 4 additions & 4 deletions rust/candid/src/types/subtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub fn subtype(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result
format!("Variant field {}: {} is not a subtype of {}", id, ty1, ty2)
})?,
None => {
return Err(Error::msg(format!(
return Err(Error::subtype(format!(
"Variant field {} not found in the expected type",
id
)))
Expand All @@ -81,7 +81,7 @@ pub fn subtype(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result
format!("Method {}: {} is not a subtype of {}", name, ty1, ty2)
})?,
None => {
return Err(Error::msg(format!(
return Err(Error::subtype(format!(
"Method {} is only in the expected type",
name
)))
Expand All @@ -94,7 +94,7 @@ pub fn subtype(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result
let m1: HashSet<_> = f1.modes.iter().collect();
let m2: HashSet<_> = f2.modes.iter().collect();
if m1 != m2 {
return Err(Error::msg("Function mode mismatch"));
return Err(Error::subtype("Function mode mismatch"));
}
let args1 = to_tuple(&f1.args);
let args2 = to_tuple(&f2.args);
Expand All @@ -109,7 +109,7 @@ pub fn subtype(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result
(_, Class(_, t)) => subtype(gamma, env, t1, t),
(Unknown, _) => unreachable!(),
(_, Unknown) => unreachable!(),
(_, _) => Err(Error::msg(format!("{} is not a subtype of {}", t1, t2))),
(_, _) => Err(Error::subtype(format!("{} is not a subtype of {}", t1, t2))),
}
}

Expand Down
14 changes: 4 additions & 10 deletions rust/candid/tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,7 @@ fn test_integer() {
Int::parse(b"-60000000000000000").unwrap(),
"4449444c00017c8080e88b96cab5957f",
);
check_error(
|| test_decode(&hex("4449444c00017c2a"), &42i64),
"int is not a subtype of int64",
);
check_error(|| test_decode(&hex("4449444c00017c2a"), &42i64), "Int64");
}

#[test]
Expand Down Expand Up @@ -320,10 +317,7 @@ fn test_extra_fields() {

let bytes = encode(&E2::Foo);
test_decode(&bytes, &Some(E2::Foo));
check_error(
|| test_decode(&bytes, &E1::Foo),
"Variant field 3_303_867 not found",
);
test_decode(&bytes, &E1::Foo);
}

#[test]
Expand Down Expand Up @@ -490,7 +484,7 @@ fn test_tuple() {
&(Int::from(42), "💩"),
)
},
"field 1: text is only in the expected type",
"is not a tuple type",
);
}

Expand All @@ -509,7 +503,7 @@ fn test_variant() {

check_error(
|| test_decode(&hex("4449444c016b02b4d3c9017fe6fdd5017f010000"), &Unit::Bar),
"Variant field 3_303_860 not found",
"Unknown variant field 3_303_860",
);

let res: Result<String, String> = Ok("good".to_string());
Expand Down
8 changes: 4 additions & 4 deletions test/construct.test.did
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ assert blob "DIDL\00\00" == "(null)" : (Opt) "op

// vector
assert blob "DIDL\01\6d\7c\01\00\00" == "(vec {})" : (vec int) "vec: empty";
assert blob "DIDL\01\6d\7c\01\00\00" !: (vec int8) "vec: non subtype empty";
assert blob "DIDL\01\6d\7c\01\00\00" : (vec int8) "vec: non subtype empty";
assert blob "DIDL\01\6d\7c\01\00\02\01\02" == "(vec { 1; 2 })" : (vec int) "vec";
assert blob "DIDL\01\6d\7b\01\00\02\01\02" == "(blob \"\\01\\02\")" : (vec nat8) "vec: blob";
assert blob "DIDL\01\6d\00\01\00\00" == "(vec {})" : (Vec) "vec: recursive vector";
Expand Down Expand Up @@ -139,10 +139,10 @@ assert "(variant {})" !
assert blob "DIDL\01\6b\00\01\00" !: (variant {}) "variant: no empty value";
assert blob "DIDL\01\6b\01\00\7f\01\00\00" == "(variant {0})" : (variant {0}) "variant: numbered field";
assert blob "DIDL\01\6b\01\00\7f\01\00\00\2a" !: (variant {0:int}) "variant: type mismatch";
assert blob "DIDL\01\6b\02\00\7f\01\7c\01\00\01\2a" !: (variant {0:int; 1:int}) "variant: type mismatch in unused tag";
assert blob "DIDL\01\6b\02\00\7f\01\7c\01\00\01\2a" : (variant {0:int; 1:int}) "variant: type mismatch in unused tag";
assert blob "DIDL\01\6b\01\00\7f\01\00\00" == "(variant {0})" : (variant {0;1}) "variant: ignore field";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" !: (variant {0}) "variant {0;1} !<: variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" == "(null)" : (opt variant {0}) "variant {0;1} <: opt variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" : (variant {0}) "variant {0;1} <: variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\00" == "(opt variant {0})" : (opt variant {0}) "variant {0;1} <: opt variant {0}";
assert blob "DIDL\01\6b\02\00\7f\01\7f\01\00\01" == "(variant {1})" : (variant {0;1;2}) "variant: change index";
assert blob "DIDL\01\6b\01\00\7f\01\00\00" !: (variant {1}) "variant: missing field";
assert blob "DIDL\01\6b\01\00\7f\01\00\01" !: (variant {0}) "variant: index out of range";
Expand Down

0 comments on commit f10d655

Please sign in to comment.