Skip to content

Commit

Permalink
Migrate to another dictionary type
Browse files Browse the repository at this point in the history
Signed-off-by: i1i1 <[email protected]>
  • Loading branch information
i1i1 committed Jun 25, 2021
1 parent 5b73751 commit 2a74a29
Show file tree
Hide file tree
Showing 5 changed files with 407 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ documentation = "https://docs.rs/crate/pythonize/"

[dependencies]
serde = { version = "1.0", default-features = false, features = ["std"] }
pyo3 = { version = "0.13", default-features = false }
pyo3 = { version = "0.13", default-features = false, features = ["macros"] }

[dev-dependencies]
serde = { version = "1.0", default-features = false, features = ["derive"] }
Expand Down
48 changes: 30 additions & 18 deletions src/de.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use pyo3::types::*;
use pyo3::{types::*, PyNativeType};
use serde::de::{self, IntoDeserializer};
use serde::Deserialize;
use std::convert::TryInto;

use crate::error::{PythonizeError, Result};
use crate::py::Dict;

/// Attempt to convert a Python object to an instance of `T`
pub fn depythonize<'de, T>(obj: &'de PyAny) -> Result<T>
Expand Down Expand Up @@ -38,9 +39,13 @@ impl<'de> Depythonizer<'de> {

fn dict_access(
&self,
) -> Result<PyDictAccess<'de, impl Iterator<Item = (&'de PyAny, &'de PyAny)>>> {
let dict: &PyDict = self.input.downcast()?;
Ok(PyDictAccess::new(dict.iter()))
) -> Result<DictAccess<'de, impl Iterator<Item = (&'de PyAny, &'de PyAny)>>> {
let dict = if self.input.is_instance::<PyDict>()? {
self.input.downcast::<PyDict>()?.try_into()?
} else {
self.input.extract::<Dict>()?
};
Ok(DictAccess::new(dict.into_iter(self.input.py())))
}
}

Expand Down Expand Up @@ -70,7 +75,7 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Depythonizer<'de> {
self.deserialize_bool(visitor)
} else if obj.is_instance::<PyByteArray>()? || obj.is_instance::<PyBytes>()? {
self.deserialize_bytes(visitor)
} else if obj.is_instance::<PyDict>()? {
} else if obj.is_instance::<Dict>()? || obj.is_instance::<PyDict>()? {
self.deserialize_map(visitor)
} else if obj.is_instance::<PyFloat>()? {
self.deserialize_f64(visitor)
Expand Down Expand Up @@ -248,19 +253,25 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Depythonizer<'de> {
V: de::Visitor<'de>,
{
let item = self.input;
if item.is_instance::<PyDict>()? {
if item.is_instance::<Dict>()? || item.is_instance::<PyDict>()? {
// Get the enum variant from the dict key
let d: &PyDict = item.cast_as().unwrap();
let d = if self.input.is_instance::<PyDict>()? {
self.input.downcast::<PyDict>()?.try_into()?
} else {
self.input.extract::<Dict>()?
};
if d.len() != 1 {
return Err(PythonizeError::invalid_length_enum());
}
let variant: &PyString = d
.keys()
.get_item(0)
.clone()
.into_keys(self.input.py())
.next()
.unwrap()
.cast_as()
.map_err(|_| PythonizeError::dict_key_not_string())?;
let value = d.get_item(variant).unwrap();
let mut de = Depythonizer::from_object(value);
let value = d.get_item(self.input.py(), variant)?.unwrap();
let mut de = Depythonizer::from_object(value.into_ref(self.input.py()));
visitor.visit_enum(PyEnumAccess::new(&mut de, variant))
} else if item.is_instance::<PyString>()? {
let s: &PyString = self.input.cast_as()?;
Expand Down Expand Up @@ -318,15 +329,15 @@ impl<'de> de::SeqAccess<'de> for PySequenceAccess<'de> {
}
}

struct PyDictAccess<'de, Iter>
struct DictAccess<'de, Iter>
where
Iter: Iterator<Item = (&'de PyAny, &'de PyAny)> + 'de,
{
iter: Iter, // TODO: figure out why PyDictIterator is not publicly accessible upstream?
iter: Iter,
next_value: Option<&'de PyAny>,
}

impl<'de, Iter> PyDictAccess<'de, Iter>
impl<'de, Iter> DictAccess<'de, Iter>
where
Iter: Iterator<Item = (&'de PyAny, &'de PyAny)> + 'de,
{
Expand All @@ -338,7 +349,7 @@ where
}
}

impl<'de, Iter> de::MapAccess<'de> for PyDictAccess<'de, Iter>
impl<'de, Iter> de::MapAccess<'de> for DictAccess<'de, Iter>
where
Iter: Iterator<Item = (&'de PyAny, &'de PyAny)> + 'de,
{
Expand Down Expand Up @@ -430,6 +441,7 @@ mod test {
use super::*;
use crate::error::ErrorImpl;
use maplit::hashmap;
use pyo3::types::Dict;
use pyo3::Python;
use serde_json::{json, Value as JsonValue};

Expand All @@ -439,7 +451,7 @@ mod test {
{
let gil = Python::acquire_gil();
let py = gil.python();
let locals = PyDict::new(py);
let locals = Dict::new(py);
py.run(&format!("obj = {}", code), None, Some(locals))
.unwrap();
let obj = locals.get_item("obj").unwrap();
Expand Down Expand Up @@ -495,7 +507,7 @@ mod test {

let gil = Python::acquire_gil();
let py = gil.python();
let locals = PyDict::new(py);
let locals = Dict::new(py);
py.run(&format!("obj = {}", code), None, Some(locals))
.unwrap();
let obj = locals.get_item("obj").unwrap();
Expand Down Expand Up @@ -525,7 +537,7 @@ mod test {

let gil = Python::acquire_gil();
let py = gil.python();
let locals = PyDict::new(py);
let locals = Dict::new(py);
py.run(&format!("obj = {}", code), None, Some(locals))
.unwrap();
let obj = locals.get_item("obj").unwrap();
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
//! ```
mod de;
mod error;
pub mod py;
mod ser;

pub use crate::de::depythonize;
Expand Down
Loading

0 comments on commit 2a74a29

Please sign in to comment.