From 2a74a29f55e112120539baaaf5f27b8dec411f75 Mon Sep 17 00:00:00 2001 From: i1i1 Date: Thu, 24 Jun 2021 13:56:47 +0300 Subject: [PATCH] Migrate to another dictionary type Signed-off-by: i1i1 --- Cargo.toml | 2 +- src/de.rs | 48 +++++--- src/lib.rs | 1 + src/py.rs | 333 +++++++++++++++++++++++++++++++++++++++++++++++++++++ src/ser.rs | 68 ++++++----- 5 files changed, 407 insertions(+), 45 deletions(-) create mode 100644 src/py.rs diff --git a/Cargo.toml b/Cargo.toml index 2349e52..89be457 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ documentation = "https://docs.rs/crate/pythonize/" [dependencies] serde = { version = "1.0", default-features = false, features = ["std"] } -pyo3 = { version = "0.13", default-features = false } +pyo3 = { version = "0.13", default-features = false, features = ["macros"] } [dev-dependencies] serde = { version = "1.0", default-features = false, features = ["derive"] } diff --git a/src/de.rs b/src/de.rs index 73399d8..cffef93 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1,9 +1,10 @@ -use pyo3::types::*; +use pyo3::{types::*, PyNativeType}; use serde::de::{self, IntoDeserializer}; use serde::Deserialize; use std::convert::TryInto; use crate::error::{PythonizeError, Result}; +use crate::py::Dict; /// Attempt to convert a Python object to an instance of `T` pub fn depythonize<'de, T>(obj: &'de PyAny) -> Result @@ -38,9 +39,13 @@ impl<'de> Depythonizer<'de> { fn dict_access( &self, - ) -> Result>> { - let dict: &PyDict = self.input.downcast()?; - Ok(PyDictAccess::new(dict.iter())) + ) -> Result>> { + let dict = if self.input.is_instance::()? { + self.input.downcast::()?.try_into()? + } else { + self.input.extract::()? + }; + Ok(DictAccess::new(dict.into_iter(self.input.py()))) } } @@ -70,7 +75,7 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Depythonizer<'de> { self.deserialize_bool(visitor) } else if obj.is_instance::()? || obj.is_instance::()? { self.deserialize_bytes(visitor) - } else if obj.is_instance::()? { + } else if obj.is_instance::()? || obj.is_instance::()? { self.deserialize_map(visitor) } else if obj.is_instance::()? { self.deserialize_f64(visitor) @@ -248,19 +253,25 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Depythonizer<'de> { V: de::Visitor<'de>, { let item = self.input; - if item.is_instance::()? { + if item.is_instance::()? || item.is_instance::()? { // Get the enum variant from the dict key - let d: &PyDict = item.cast_as().unwrap(); + let d = if self.input.is_instance::()? { + self.input.downcast::()?.try_into()? + } else { + self.input.extract::()? + }; if d.len() != 1 { return Err(PythonizeError::invalid_length_enum()); } let variant: &PyString = d - .keys() - .get_item(0) + .clone() + .into_keys(self.input.py()) + .next() + .unwrap() .cast_as() .map_err(|_| PythonizeError::dict_key_not_string())?; - let value = d.get_item(variant).unwrap(); - let mut de = Depythonizer::from_object(value); + let value = d.get_item(self.input.py(), variant)?.unwrap(); + let mut de = Depythonizer::from_object(value.into_ref(self.input.py())); visitor.visit_enum(PyEnumAccess::new(&mut de, variant)) } else if item.is_instance::()? { let s: &PyString = self.input.cast_as()?; @@ -318,15 +329,15 @@ impl<'de> de::SeqAccess<'de> for PySequenceAccess<'de> { } } -struct PyDictAccess<'de, Iter> +struct DictAccess<'de, Iter> where Iter: Iterator + 'de, { - iter: Iter, // TODO: figure out why PyDictIterator is not publicly accessible upstream? + iter: Iter, next_value: Option<&'de PyAny>, } -impl<'de, Iter> PyDictAccess<'de, Iter> +impl<'de, Iter> DictAccess<'de, Iter> where Iter: Iterator + 'de, { @@ -338,7 +349,7 @@ where } } -impl<'de, Iter> de::MapAccess<'de> for PyDictAccess<'de, Iter> +impl<'de, Iter> de::MapAccess<'de> for DictAccess<'de, Iter> where Iter: Iterator + 'de, { @@ -430,6 +441,7 @@ mod test { use super::*; use crate::error::ErrorImpl; use maplit::hashmap; + use pyo3::types::Dict; use pyo3::Python; use serde_json::{json, Value as JsonValue}; @@ -439,7 +451,7 @@ mod test { { let gil = Python::acquire_gil(); let py = gil.python(); - let locals = PyDict::new(py); + let locals = Dict::new(py); py.run(&format!("obj = {}", code), None, Some(locals)) .unwrap(); let obj = locals.get_item("obj").unwrap(); @@ -495,7 +507,7 @@ mod test { let gil = Python::acquire_gil(); let py = gil.python(); - let locals = PyDict::new(py); + let locals = Dict::new(py); py.run(&format!("obj = {}", code), None, Some(locals)) .unwrap(); let obj = locals.get_item("obj").unwrap(); @@ -525,7 +537,7 @@ mod test { let gil = Python::acquire_gil(); let py = gil.python(); - let locals = PyDict::new(py); + let locals = Dict::new(py); py.run(&format!("obj = {}", code), None, Some(locals)) .unwrap(); let obj = locals.get_item("obj").unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 175a42c..4be01da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,6 +38,7 @@ //! ``` mod de; mod error; +pub mod py; mod ser; pub use crate::de::depythonize; diff --git a/src/py.rs b/src/py.rs new file mode 100644 index 0000000..f55082d --- /dev/null +++ b/src/py.rs @@ -0,0 +1,333 @@ +//! Module with python dictionary which can be hashed + +use std::collections::btree_map::IntoIter; +use std::collections::BTreeMap; +use std::convert::TryFrom; + +use pyo3::basic::CompareOp; +use pyo3::class::iter::{IterNextOutput, PyIterProtocol}; +use pyo3::conversion::{ToBorrowedObject, ToPyObject}; +use pyo3::exceptions::PyKeyError; +use pyo3::types::PyDict; +use pyo3::{prelude::*, PyMappingProtocol, PyNativeType, PyObjectProtocol}; + +type Hash = isize; + +/// Dictionary which can be hashed. Requires both keys and values to be hashable +#[pyclass] +#[derive(Debug, Clone, Default)] +pub struct Dict { + map: BTreeMap< + // Key hash + Hash, + BTreeMap< + // Value hash + Hash, + // Entry + (PyObject, PyObject), + >, + >, +} + +impl TryFrom<&PyDict> for Dict { + type Error = PyErr; + fn try_from(dict: &PyDict) -> PyResult { + let mut map = Dict::new(); + for (k, v) in dict { + map.set_item(dict.py(), k, v)?; + } + Ok(map) + } +} + +fn as_dict_obj(py: Python, obj: PyObject) -> PyResult { + let obj = obj.into_ref(py); + let obj = if obj.is_instance::()? { + Dict::try_from(obj.downcast::()?)?.into_py(py) + } else { + obj.to_object(py) + }; + Ok(obj) +} + +impl Dict { + fn is_eq(py: Python, a: &PyObject, b: &PyObject) -> bool { + a.as_ref(py) + .rich_compare(b.as_ref(py), CompareOp::Eq) + .and_then(|b| b.extract::()) + .unwrap_or(false) + } + + /// Sets item with key and value + pub fn set_item( + &mut self, + py: Python, + k: K, + v: V, + ) -> PyResult<()> { + let k = as_dict_obj(py, k.to_object(py))?; + let v = as_dict_obj(py, v.to_object(py))?; + let k_hash = k.as_ref(py).hash()?; + let v_hash = v.as_ref(py).hash()?; + self.map.entry(k_hash).or_default().insert(v_hash, (k, v)); + Ok(()) + } + + /// Gets item by key + pub fn get_item(&self, py: Python, k: K) -> PyResult> { + let k = as_dict_obj(py, k.to_object(py))?; + let k_hash = k.as_ref(py).hash()?; + let bucket = match self.map.get(&k_hash) { + Some(bucket) => bucket, + None => return Ok(None), + }; + Ok(bucket + .values() + .find(|(bucket_k, _)| Self::is_eq(py, &k, bucket_k)) + .map(|(_, v)| v.clone())) + } + + /// Remove item by key + pub fn remove_item( + &mut self, + py: Python, + k: K, + ) -> PyResult> { + let k = as_dict_obj(py, k.to_object(py))?; + let k_hash = k.as_ref(py).hash()?; + let bucket = match self.map.get_mut(&k_hash) { + Some(bucket) => bucket, + None => return Ok(None), + }; + Ok(bucket + .values_mut() + .find(|(bucket_k, _)| Self::is_eq(py, &k, bucket_k)) + .map(|(_, v)| v.clone())) + } + + /// Iterator over both keys and values + pub fn iter<'py>( + &'py self, + py: Python<'py>, + ) -> impl Iterator + 'py { + self.map + .values() + .flat_map(BTreeMap::values) + .map(move |(k, v)| (k.clone().into_ref(py), v.clone().into_ref(py))) + } + + /// Iterator which consumes object + pub fn into_iter(self, py: Python<'_>) -> impl Iterator { + self.map + .into_iter() + .flat_map(|(_, bucket)| bucket.into_iter()) + .map(move |(_, (k, v))| (k.into_ref(py), v.into_ref(py))) + } + + /// Iterator over keys + pub fn keys<'py>(&'py self, py: Python<'py>) -> impl Iterator + 'py { + self.iter(py).map(|(k, _)| k) + } + + /// Iterator over keys which consumes object + pub fn into_keys(self, py: Python<'_>) -> impl Iterator { + self.into_iter(py).map(|(k, _)| k) + } + + /// Iterator over values + pub fn values<'py>(&'py self, py: Python<'py>) -> impl Iterator + 'py { + self.iter(py).map(|(_, v)| v) + } + + /// Iterator over values which consumes object + pub fn into_values(self, py: Python<'_>) -> impl Iterator { + self.into_iter(py).map(|(_, v)| v) + } + + /// Hashes object + pub fn hash(&self, py: Python) -> Hash { + let sum_hash = + |a: isize, b: isize| a.wrapping_add(b).into_py(py).as_ref(py).hash().unwrap(); + + self.map + .iter() + .flat_map(|(&key, bucket)| bucket.keys().map(move |&value| sum_hash(key, value))) + .fold(0, sum_hash) + } + + /// Returns number of items in dictionary + pub fn len(&self) -> usize { + self.map.iter().fold(0, |prev, (_, next)| prev + next.len()) + } + + /// Checks whether dict is empty + pub fn is_empty(&self) -> bool { + self.map.len() == 0 + } +} + +#[pymethods] +impl Dict { + /// Constructor + #[new] + pub fn new() -> Self { + Self::default() + } + + /// Returns an iterator over keys + #[name = "keys"] + fn _keys(&self) -> Keys { + Keys::new(self.clone()) + } + + /// Returns an iterator over values + #[name = "values"] + fn _values(&self) -> Values { + Values::new(self.clone()) + } + + /// Returns an iterator over both keys and values + fn items(&self) -> Items { + Items::new(self.clone()) + } +} + +#[pyproto] +impl PyObjectProtocol for Dict { + /// Comparison which relies on hashes + fn __richcmp__(&self, other: Self, op: CompareOp) -> bool { + matches!(op, CompareOp::Eq if Python::with_gil(|py| self.hash(py) == other.hash(py))) + } + + fn __hash__(&self) -> isize { + Python::with_gil(|py| self.hash(py)) + } + + fn __str__(&self) -> PyResult { + Python::with_gil(|py| { + let items = self + .iter(py) + .map(|(k, v)| Ok(format!("{}: {}", k.str()?, v.str()?))) + .collect::>>()? + .join(","); + Ok(format!("{{{}}}", items)) + }) + } + + fn __repr__(&self) -> PyResult { + Python::with_gil(|py| { + let items = self + .iter(py) + .map(|(k, v)| Ok(format!("{}: {}", k.repr()?, v.repr()?))) + .collect::>>()? + .join(","); + Ok(format!("{{{}}}", items)) + }) + } +} + +#[pyproto] +impl PyMappingProtocol for Dict { + fn __len__(&self) -> usize { + self.len() + } + + fn __setitem__(&mut self, key: PyObject, value: PyObject) -> PyResult<()> { + Python::with_gil(|py| self.set_item(py, key, value)) + } + + fn __delitem__(&mut self, key: PyObject) -> PyResult<()> { + Python::with_gil(|py| self.remove_item(py, key))?; + Ok(()) + } + + fn __getitem__(&self, key: PyObject) -> PyResult { + match Python::with_gil(|py| self.get_item(py, key.clone()))? { + Some(obj) => Ok(obj), + None => Err(PyErr::new::(key)), + } + } +} + +#[pyclass] +struct Items { + dict: IntoIter>, + bucket: Option>, +} + +impl Items { + fn new(dict: Dict) -> Self { + Self { + dict: dict.map.into_iter(), + bucket: None, + } + } + + fn next(&mut self) -> Option<(PyObject, PyObject)> { + if let Some(bucket) = &mut self.bucket { + if let Some((_, entry)) = bucket.next() { + return Some(entry); + } + } + self.bucket = match self.dict.next() { + Some((_, bucket)) => Some(bucket.into_iter()), + None => return None, + }; + + self.next() + } +} + +#[pyclass] +struct Keys { + items: Items, +} + +impl Keys { + fn new(dict: Dict) -> Self { + let items = Items::new(dict); + Self { items } + } +} + +#[pyclass] +struct Values { + items: Items, +} + +impl Values { + fn new(dict: Dict) -> Self { + let items = Items::new(dict); + Self { items } + } +} + +#[pyproto] +impl PyIterProtocol for Items { + fn __next__(mut slf: PyRefMut) -> IterNextOutput<(PyObject, PyObject), ()> { + match slf.next() { + Some(entry) => IterNextOutput::Yield(entry), + None => IterNextOutput::Return(()), + } + } +} + +#[pyproto] +impl PyIterProtocol for Keys { + fn __next__(mut slf: PyRefMut) -> IterNextOutput { + match slf.items.next() { + Some((entry, _)) => IterNextOutput::Yield(entry), + None => IterNextOutput::Return(()), + } + } +} + +#[pyproto] +impl PyIterProtocol for Values { + fn __next__(mut slf: PyRefMut) -> IterNextOutput { + match slf.items.next() { + Some((_, entry)) => IterNextOutput::Yield(entry), + None => IterNextOutput::Return(()), + } + } +} diff --git a/src/ser.rs b/src/ser.rs index c8b9912..7470ebd 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -1,8 +1,9 @@ -use pyo3::types::{PyDict, PyList, PyTuple}; -use pyo3::{IntoPy, PyNativeType, PyObject, Python}; +use pyo3::types::{PyList, PyTuple}; +use pyo3::{IntoPy, PyObject, Python}; use serde::{ser, Serialize}; use crate::error::{PythonizeError, Result}; +use crate::py::Dict as PyDict; /// Attempt to convert the given data into a Python object pub fn pythonize(py: Python, value: &T) -> Result @@ -37,12 +38,14 @@ pub struct PythonStructVariantSerializer<'py> { #[doc(hidden)] pub struct PythonDictSerializer<'py> { - dict: &'py PyDict, + py: Python<'py>, + dict: PyDict, } #[doc(hidden)] pub struct PythonMapSerializer<'py> { - dict: &'py PyDict, + py: Python<'py>, + dict: PyDict, key: Option, } @@ -158,9 +161,9 @@ impl<'py> ser::Serializer for Pythonizer<'py> { where T: ?Sized + Serialize, { - let d = PyDict::new(self.py); - d.set_item(variant, value.serialize(self)?)?; - Ok(d.into()) + let mut d = PyDict::new(); + d.set_item(self.py, variant, value.serialize(self)?)?; + Ok(d.into_py(self.py)) } fn serialize_seq(self, len: Option) -> Result> { @@ -199,7 +202,8 @@ impl<'py> ser::Serializer for Pythonizer<'py> { fn serialize_map(self, _len: Option) -> Result> { Ok(PythonMapSerializer { - dict: PyDict::new(self.py), + py: self.py, + dict: PyDict::new(), key: None, }) } @@ -210,7 +214,8 @@ impl<'py> ser::Serializer for Pythonizer<'py> { _len: usize, ) -> Result> { Ok(PythonDictSerializer { - dict: PyDict::new(self.py), + py: self.py, + dict: PyDict::new(), }) } @@ -224,7 +229,8 @@ impl<'py> ser::Serializer for Pythonizer<'py> { Ok(PythonStructVariantSerializer { variant, inner: PythonDictSerializer { - dict: PyDict::new(self.py), + py: self.py, + dict: PyDict::new(), }, }) } @@ -291,13 +297,14 @@ impl ser::SerializeTupleVariant for PythonTupleVariantSerializer<'_> { } fn end(self) -> Result { - let d = PyDict::new(self.inner.py); - d.set_item(self.variant, ser::SerializeTuple::end(self.inner)?)?; - Ok(d.into()) + let mut d = PyDict::new(); + let py = self.inner.py; + d.set_item(py, self.variant, ser::SerializeTuple::end(self.inner)?)?; + Ok(d.into_py(py)) } } -impl ser::SerializeMap for PythonMapSerializer<'_> { +impl<'py> ser::SerializeMap for PythonMapSerializer<'py> { type Ok = PyObject; type Error = PythonizeError; @@ -305,8 +312,10 @@ impl ser::SerializeMap for PythonMapSerializer<'_> { where T: ?Sized + Serialize, { - self.key = Some(pythonize(self.dict.py(), key)?); - Ok(()) + Python::with_gil(|py| -> Result<()> { + self.key = Some(pythonize(py, key)?); + Ok(()) + }) } fn serialize_value(&mut self, value: &T) -> Result<()> @@ -314,20 +323,21 @@ impl ser::SerializeMap for PythonMapSerializer<'_> { T: ?Sized + Serialize, { self.dict.set_item( + self.py, self.key .take() .expect("serialize_value should always be called after serialize_key"), - pythonize(self.dict.py(), value)?, + pythonize(self.py, value)?, )?; Ok(()) } fn end(self) -> Result { - Ok(self.dict.into()) + Ok(self.dict.into_py(self.py)) } } -impl ser::SerializeStruct for PythonDictSerializer<'_> { +impl<'py> ser::SerializeStruct for PythonDictSerializer<'py> { type Ok = PyObject; type Error = PythonizeError; @@ -335,15 +345,17 @@ impl ser::SerializeStruct for PythonDictSerializer<'_> { where T: ?Sized + Serialize, { - Ok(self.dict.set_item(key, pythonize(self.dict.py(), value)?)?) + Ok(self + .dict + .set_item(self.py, key, pythonize(self.py, value)?)?) } fn end(self) -> Result { - Ok(self.dict.into()) + Ok(self.dict.into_py(self.py)) } } -impl ser::SerializeStructVariant for PythonStructVariantSerializer<'_> { +impl<'py> ser::SerializeStructVariant for PythonStructVariantSerializer<'py> { type Ok = PyObject; type Error = PythonizeError; @@ -353,14 +365,18 @@ impl ser::SerializeStructVariant for PythonStructVariantSerializer<'_> { { self.inner .dict - .set_item(key, pythonize(self.inner.dict.py(), value)?)?; + .set_item(self.inner.py, key, pythonize(self.inner.py, value)?)?; Ok(()) } fn end(self) -> Result { - let d = PyDict::new(self.inner.dict.py()); - d.set_item(self.variant, self.inner.dict)?; - Ok(d.into()) + let mut d = PyDict::new(); + d.set_item( + self.inner.py, + self.variant, + self.inner.dict.into_py(self.inner.py), + )?; + Ok(d.into_py(self.inner.py)) } }