diff --git a/src/de.rs b/src/de.rs index 046d8b9..65b3f57 100644 --- a/src/de.rs +++ b/src/de.rs @@ -8,6 +8,10 @@ use crate::int_key::IntKey; pub trait KeyDeserialize { type Output: Sized; + /// The number of key elements is used for the deserialization of compound keys. + /// It should be equal to PrimaryKey::key().len() + const KEY_ELEMS: u16; + fn from_vec(value: Vec) -> StdResult; fn from_slice(value: &[u8]) -> StdResult { @@ -18,6 +22,8 @@ pub trait KeyDeserialize { impl KeyDeserialize for () { type Output = (); + const KEY_ELEMS: u16 = 0; + #[inline(always)] fn from_vec(_value: Vec) -> StdResult { Ok(()) @@ -27,6 +33,8 @@ impl KeyDeserialize for () { impl KeyDeserialize for Vec { type Output = Vec; + const KEY_ELEMS: u16 = 1; + #[inline(always)] fn from_vec(value: Vec) -> StdResult { Ok(value) @@ -36,6 +44,8 @@ impl KeyDeserialize for Vec { impl KeyDeserialize for &Vec { type Output = Vec; + const KEY_ELEMS: u16 = 1; + #[inline(always)] fn from_vec(value: Vec) -> StdResult { Ok(value) @@ -45,6 +55,8 @@ impl KeyDeserialize for &Vec { impl KeyDeserialize for &[u8] { type Output = Vec; + const KEY_ELEMS: u16 = 1; + #[inline(always)] fn from_vec(value: Vec) -> StdResult { Ok(value) @@ -54,6 +66,8 @@ impl KeyDeserialize for &[u8] { impl KeyDeserialize for String { type Output = String; + const KEY_ELEMS: u16 = 1; + #[inline(always)] fn from_vec(value: Vec) -> StdResult { String::from_utf8(value).map_err(StdError::invalid_utf8) @@ -63,6 +77,8 @@ impl KeyDeserialize for String { impl KeyDeserialize for &String { type Output = String; + const KEY_ELEMS: u16 = 1; + #[inline(always)] fn from_vec(value: Vec) -> StdResult { Self::Output::from_vec(value) @@ -72,6 +88,8 @@ impl KeyDeserialize for &String { impl KeyDeserialize for &str { type Output = String; + const KEY_ELEMS: u16 = 1; + #[inline(always)] fn from_vec(value: Vec) -> StdResult { Self::Output::from_vec(value) @@ -81,6 +99,8 @@ impl KeyDeserialize for &str { impl KeyDeserialize for Addr { type Output = Addr; + const KEY_ELEMS: u16 = 1; + #[inline(always)] fn from_vec(value: Vec) -> StdResult { Ok(Addr::unchecked(String::from_vec(value)?)) @@ -90,6 +110,8 @@ impl KeyDeserialize for Addr { impl KeyDeserialize for &Addr { type Output = Addr; + const KEY_ELEMS: u16 = 1; + #[inline(always)] fn from_vec(value: Vec) -> StdResult { Self::Output::from_vec(value) @@ -101,6 +123,8 @@ macro_rules! integer_de { $(impl KeyDeserialize for $t { type Output = $t; + const KEY_ELEMS: u16 = 1; + #[inline(always)] fn from_vec(value: Vec) -> StdResult { Ok(<$t>::from_cw_bytes(value.as_slice().try_into() @@ -121,33 +145,54 @@ fn parse_length(value: &[u8]) -> StdResult { .into()) } +/// Splits the first key from the value based on the provided number of key elements. +/// The return value is ordered as (first_key, remainder). +/// +fn split_first_key(key_elems: u16, value: &[u8]) -> StdResult<(Vec, &[u8])> { + let mut index = 0; + let mut first_key = Vec::new(); + + // Iterate over the sub keys + for i in 0..key_elems { + let len_slice = &value[index..index + 2]; + index += 2; + let is_last_key = i == key_elems - 1; + + if !is_last_key { + first_key.extend_from_slice(len_slice); + } + + let subkey_len = parse_length(len_slice)?; + first_key.extend_from_slice(&value[index..index + subkey_len]); + index += subkey_len; + } + + let remainder = &value[index..]; + Ok((first_key, remainder)) +} + impl KeyDeserialize for (T, U) { type Output = (T::Output, U::Output); - #[inline(always)] - fn from_vec(mut value: Vec) -> StdResult { - let mut tu = value.split_off(2); - let t_len = parse_length(&value)?; - let u = tu.split_off(t_len); + const KEY_ELEMS: u16 = T::KEY_ELEMS + U::KEY_ELEMS; - Ok((T::from_vec(tu)?, U::from_vec(u)?)) + #[inline(always)] + fn from_vec(value: Vec) -> StdResult { + let (t, u) = split_first_key(T::KEY_ELEMS, value.as_ref())?; + Ok((T::from_vec(t)?, U::from_vec(u.to_vec())?)) } } impl KeyDeserialize for (T, U, V) { type Output = (T::Output, U::Output, V::Output); - #[inline(always)] - fn from_vec(mut value: Vec) -> StdResult { - let mut tuv = value.split_off(2); - let t_len = parse_length(&value)?; - let mut len_uv = tuv.split_off(t_len); - - let mut uv = len_uv.split_off(2); - let u_len = parse_length(&len_uv)?; - let v = uv.split_off(u_len); + const KEY_ELEMS: u16 = T::KEY_ELEMS + U::KEY_ELEMS + V::KEY_ELEMS; - Ok((T::from_vec(tuv)?, U::from_vec(uv)?, V::from_vec(v)?)) + #[inline(always)] + fn from_vec(value: Vec) -> StdResult { + let (t, remainder) = split_first_key(T::KEY_ELEMS, value.as_ref())?; + let (u, v) = split_first_key(U::KEY_ELEMS, remainder)?; + Ok((T::from_vec(t)?, U::from_vec(u)?, V::from_vec(v.to_vec())?)) } } @@ -257,6 +302,74 @@ mod test { ); } + #[test] + fn deserialize_tuple_of_tuples_works() { + assert_eq!( + <((&[u8], &str), (&[u8], &str))>::from_slice( + ((BYTES, STRING), (BYTES, STRING)).joined_key().as_slice() + ) + .unwrap(), + ( + (BYTES.to_vec(), STRING.to_string()), + (BYTES.to_vec(), STRING.to_string()) + ) + ); + } + + #[test] + fn deserialize_tuple_of_triples_works() { + assert_eq!( + <((&[u8], &str, u32), (&[u8], &str, u16))>::from_slice( + ((BYTES, STRING, 1234u32), (BYTES, STRING, 567u16)) + .joined_key() + .as_slice() + ) + .unwrap(), + ( + (BYTES.to_vec(), STRING.to_string(), 1234), + (BYTES.to_vec(), STRING.to_string(), 567) + ) + ); + } + + #[test] + fn deserialize_triple_of_tuples_works() { + assert_eq!( + <((u32, &str), (&str, &[u8]), (i32, i32))>::from_slice( + ((1234u32, STRING), (STRING, BYTES), (1234i32, 567i32)) + .joined_key() + .as_slice() + ) + .unwrap(), + ( + (1234, STRING.to_string()), + (STRING.to_string(), BYTES.to_vec()), + (1234, 567) + ) + ); + } + + #[test] + fn deserialize_triple_of_triples_works() { + assert_eq!( + <((u32, &str, &str), (&str, &[u8], u8), (i32, u8, i32))>::from_slice( + ( + (1234u32, STRING, STRING), + (STRING, BYTES, 123u8), + (4567i32, 89u8, 10i32) + ) + .joined_key() + .as_slice() + ) + .unwrap(), + ( + (1234, STRING.to_string(), STRING.to_string()), + (STRING.to_string(), BYTES.to_vec(), 123), + (4567, 89, 10) + ) + ); + } + #[test] fn deserialize_triple_works() { assert_eq!(