Skip to content

Commit

Permalink
fix repr breakages
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Sep 26, 2023
1 parent 791c0e3 commit cddcc71
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 26 deletions.
104 changes: 90 additions & 14 deletions src/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
use std::{
collections::hash_map::Entry,
fmt::Debug,
sync::{Arc, OnceLock},
sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock,
},
};

use pyo3::{prelude::*, PyTraverseError, PyVisit};
Expand Down Expand Up @@ -35,14 +38,19 @@ impl<T> Definitions<T> {
}

/// Internal type which contains a definition to be filled
pub struct Definition<T>(Arc<OnceLock<T>>);
pub struct Definition<T>(Arc<DefinitionInner<T>>);

impl<T> Definition<T> {
pub fn get(&self) -> Option<&T> {
self.0.get()
self.0.value.get()
}
}

/// Shared state stored behind the `Arc` of a `Definition`: the value slot
/// together with a lazily computed name.
struct DefinitionInner<T> {
    /// The definition itself; empty until a value has been set.
    value: OnceLock<T>,
    /// Name cache, filled on demand through `LazyName::get_or_init`.
    name: LazyName,
}

/// Reference to a definition.
pub struct DefinitionRef<T> {
name: Arc<String>,
Expand All @@ -64,12 +72,15 @@ impl<T> DefinitionRef<T> {
Arc::as_ptr(&self.value.0) as usize
}

pub fn name(&self) -> &str {
&self.name
pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str {
match self.value.0.value.get() {
Some(value) => self.value.0.name.get_or_init(|| init(value)),
None => "...",
}
}

pub fn get(&self) -> Option<&T> {
self.value.0.get()
self.value.0.value.get()
}
}

Expand All @@ -83,7 +94,17 @@ impl<T: Debug> Debug for DefinitionRef<T> {

impl<T: Debug> Debug for Definitions<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
// Formatted as a list for backwards compatibility; in principle
// this could be formatted as a map. Maybe change in a future
// minor release of pydantic.
write![f, "["]?;
let mut first = true;
for def in self.0.values() {
write![f, "{sep}{def:?}", sep = if first { "" } else { ", " }]?;
first = false;
}
write![f, "]"]?;
Ok(())
}
}

Expand All @@ -95,7 +116,7 @@ impl<T> Clone for Definition<T> {

impl<T: Debug> Debug for Definition<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.0.get() {
match self.0.value.get() {
Some(value) => value.fmt(f),
None => "...".fmt(f),
}
Expand All @@ -104,7 +125,7 @@ impl<T: Debug> Debug for Definition<T> {

impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
if let Some(value) = self.value.0.get() {
if let Some(value) = self.value.0.value.get() {
value.py_gc_traverse(visit)?;
}
Ok(())
Expand All @@ -114,7 +135,7 @@ impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
impl<T: PyGcTraverse> PyGcTraverse for Definitions<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
for value in self.0.values() {
if let Some(value) = value.0.get() {
if let Some(value) = value.0.value.get() {
value.py_gc_traverse(visit)?;
}
}
Expand Down Expand Up @@ -142,7 +163,10 @@ impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
let name = Arc::new(reference.to_string());
let value = match self.definitions.0.entry(name.clone()) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(OnceLock::new()))),
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::new(),
name: LazyName::new(),
}))),
};
DefinitionRef {
name,
Expand All @@ -156,23 +180,75 @@ impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
let value = match self.definitions.0.entry(name.clone()) {
Entry::Occupied(entry) => {
let definition = entry.into_mut();
match definition.0.set(value) {
match definition.0.value.set(value) {
Ok(()) => definition.clone(),
Err(_) => return py_schema_err!("Duplicate ref: `{}`", name),
}
}
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(OnceLock::from(value)))).clone(),
Entry::Vacant(entry) => entry
.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::from(value),
name: LazyName::new(),
})))
.clone(),
};
Ok(DefinitionRef { name, value })
}

/// Consume this builder and return its definitions, failing if any referenced definition was never filled
pub fn finish(self) -> PyResult<Definitions<T>> {
for (reference, def) in &self.definitions.0 {
if def.0.get().is_none() {
if def.0.value.get().is_none() {
return py_schema_err!("Definitions error: definition `{}` was never filled", reference);
}
}
Ok(self.definitions)
}
}

/// A lazily computed, cached name that tolerates re-entrant initialisation.
///
/// `initialized` holds the name once it has been computed. `in_recursion` is
/// raised while an initialiser is running so that a call arriving in the
/// meantime (for example a nested call made by the initialiser itself) can
/// return a placeholder instead of re-entering the `OnceLock`.
struct LazyName {
    initialized: OnceLock<String>,
    in_recursion: AtomicBool,
}

impl LazyName {
    /// Creates an empty cell: no cached name, no initialiser running.
    fn new() -> Self {
        Self {
            initialized: OnceLock::new(),
            in_recursion: AtomicBool::new(false),
        }
    }

    /// Gets the validator name, returning the default in the case of recursion loops.
    ///
    /// * If a name is already cached it is returned and `init` is not called.
    /// * If an initialiser is currently running on this cell, `"..."` is
    ///   returned and nothing is cached.
    /// * Otherwise `init` is run, its result is cached and returned.
    ///
    /// If `init` panics the panic propagates to the caller; the cell is left
    /// empty and the recursion flag is lowered again, so a later call can
    /// retry instead of being stuck on the `"..."` placeholder.
    fn get_or_init(&self, init: impl FnOnce() -> String) -> &str {
        // Lowers the recursion flag when dropped, so the flag is cleared on
        // both the normal return path and when `init` unwinds.
        struct ClearOnDrop<'a>(&'a AtomicBool);

        impl Drop for ClearOnDrop<'_> {
            fn drop(&mut self) {
                self.0.store(false, Ordering::SeqCst);
            }
        }

        if let Some(s) = self.initialized.get() {
            return s.as_str();
        }

        // Only the caller that flips the flag from `false` to `true` may run
        // the initialiser; everyone else gets the placeholder.
        if self
            .in_recursion
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            return "...";
        }

        // Created only after the flag was successfully raised, so an early
        // return above never clears a flag owned by an outer call.
        let _clear_flag = ClearOnDrop(&self.in_recursion);
        self.initialized.get_or_init(init).as_str()
    }
}

impl Debug for LazyName {
    /// Formats the cached name, or `"..."` when no name has been computed yet.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shown: &str = self.initialized.get().map_or("...", String::as_str);
        Debug::fmt(shown, f)
    }
}

impl Clone for LazyName {
    /// Produces a fresh, empty cell; the cached name is not copied and will
    /// be recomputed on the clone's first `get_or_init`.
    fn clone(&self) -> Self {
        Self::new()
    }
}
14 changes: 8 additions & 6 deletions src/validators/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ impl BuildValidator for DefinitionRefValidator {
let schema_ref = schema.get_as_req(intern!(schema.py(), "schema_ref"))?;

let definition = definitions.get_definition(schema_ref);

Ok(Self { definition }.into())
Ok(Self::new(definition).into())
}
}

Expand Down Expand Up @@ -131,8 +130,12 @@ impl Validator for DefinitionRefValidator {

let id = self as *const _ as usize;
// have to unwrap here, because we can't return an error from this function, should be okay
let validator = self.definition.get().unwrap();
if RECURSION_SET.with(|set| set.borrow_mut().get_or_insert_with(HashSet::new).insert(id)) {
let validator: &CombinedValidator = self.definition.get().unwrap();
if RECURSION_SET.with(
|set: &RefCell<Option<std::collections::HashSet<usize, ahash::RandomState>>>| {
set.borrow_mut().get_or_insert_with(HashSet::new).insert(id)
},
) {
let different_strict_behavior = validator.different_strict_behavior(ultra_strict);
RECURSION_SET.with(|set| set.borrow_mut().get_or_insert_with(HashSet::new).remove(&id));
different_strict_behavior
Expand All @@ -142,10 +145,9 @@ impl Validator for DefinitionRefValidator {
}

fn get_name(&self) -> &str {
self.definition.get().map_or("...", |validator| validator.get_name())
self.definition.get_or_init_name(|v| v.get_name().into())
}

/// don't need to call complete on the inner validator here, complete_validators takes care of that.
fn complete(&self) -> PyResult<()> {
Ok(())
}
Expand Down
19 changes: 16 additions & 3 deletions src/validators/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,27 @@ impl Validator for ListValidator {
}

fn get_name(&self) -> &str {
self.name.get().map_or("list[...]", String::as_str)
// The logic here is a little janky, it's done to try to cache the formatted name
// while also trying to render definitions correctly when possible.
//
// Probably an opportunity for a future refactor
match self.name.get() {
Some(s) => s.as_str(),
None => {
let name = self.item_validator.as_ref().map_or("any", |v| v.get_name());
if name == "..." {
// when inner name is not initialized yet, don't cache it here
"list[...]"
} else {
self.name.get_or_init(|| format!("list[{name}]")).as_str()
}
}
}
}

fn complete(&self) -> PyResult<()> {
if let Some(v) = &self.item_validator {
v.complete()?;
let inner_name = v.get_name();
let _ = self.name.set(format!("{}[{inner_name}]", Self::EXPECTED_TYPE));
}
Ok(())
}
Expand Down
4 changes: 2 additions & 2 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ impl SchemaValidator {
let mut definitions_builder = DefinitionsBuilder::new();

let validator = build_validator(schema, config, &mut definitions_builder)?;
validator.complete()?;
let definitions = definitions_builder.finish()?;
validator.complete()?;
for val in definitions.values() {
val.get().unwrap().complete()?;
}
Expand Down Expand Up @@ -387,8 +387,8 @@ impl<'py> SelfValidator<'py> {
Ok(v) => v,
Err(err) => return py_schema_err!("Error building self-schema:\n {}", err),
};
validator.complete()?;
let definitions = definitions_builder.finish()?;
validator.complete()?;
for val in definitions.values() {
val.get().unwrap().complete()?;
}
Expand Down
100 changes: 99 additions & 1 deletion tests/validators/test_definitions_recursive.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import platform
from dataclasses import dataclass
from typing import List, Optional
Expand Down Expand Up @@ -243,7 +244,7 @@ class Branch:


def test_invalid_schema():
with pytest.raises(SchemaError, match='Definitions error: attempted to use `Branch` before it was filled'):
with pytest.raises(SchemaError, match='Definitions error: definition `Branch` was never filled'):
SchemaValidator(
{
'type': 'list',
Expand Down Expand Up @@ -987,3 +988,100 @@ def test_cyclic_data_threeway() -> None:
'input': cyclic_data,
}
]


def test_complex_recursive_type() -> None:
schema = core_schema.definitions_schema(
core_schema.definition_reference_schema('JsonType'),
[
core_schema.nullable_schema(
core_schema.union_schema(
[
core_schema.list_schema(core_schema.definition_reference_schema('JsonType')),
core_schema.dict_schema(
core_schema.str_schema(), core_schema.definition_reference_schema('JsonType')
),
core_schema.str_schema(),
core_schema.int_schema(),
core_schema.float_schema(),
core_schema.bool_schema(),
]
),
ref='JsonType',
)
],
)

validator = SchemaValidator(schema)

with pytest.raises(ValidationError) as exc_info:
validator.validate_python({'a': datetime.date(year=1992, month=12, day=11)})

assert exc_info.value.errors(include_url=False) == [
{
'type': 'list_type',
'loc': ('list[nullable[union[list[...],dict[str,...],str,int,float,bool]]]',),
'msg': 'Input should be a valid list',
'input': {'a': datetime.date(1992, 12, 11)},
},
{
'type': 'list_type',
'loc': ('dict[str,...]', 'a', 'list[nullable[union[list[...],dict[str,...],str,int,float,bool]]]'),
'msg': 'Input should be a valid list',
'input': datetime.date(1992, 12, 11),
},
{
'type': 'dict_type',
'loc': ('dict[str,...]', 'a', 'dict[str,...]'),
'msg': 'Input should be a valid dictionary',
'input': datetime.date(1992, 12, 11),
},
{
'type': 'string_type',
'loc': ('dict[str,...]', 'a', 'str'),
'msg': 'Input should be a valid string',
'input': datetime.date(1992, 12, 11),
},
{
'type': 'int_type',
'loc': ('dict[str,...]', 'a', 'int'),
'msg': 'Input should be a valid integer',
'input': datetime.date(1992, 12, 11),
},
{
'type': 'float_type',
'loc': ('dict[str,...]', 'a', 'float'),
'msg': 'Input should be a valid number',
'input': datetime.date(1992, 12, 11),
},
{
'type': 'bool_type',
'loc': ('dict[str,...]', 'a', 'bool'),
'msg': 'Input should be a valid boolean',
'input': datetime.date(1992, 12, 11),
},
{
'type': 'string_type',
'loc': ('str',),
'msg': 'Input should be a valid string',
'input': {'a': datetime.date(1992, 12, 11)},
},
{
'type': 'int_type',
'loc': ('int',),
'msg': 'Input should be a valid integer',
'input': {'a': datetime.date(1992, 12, 11)},
},
{
'type': 'float_type',
'loc': ('float',),
'msg': 'Input should be a valid number',
'input': {'a': datetime.date(1992, 12, 11)},
},
{
'type': 'bool_type',
'loc': ('bool',),
'msg': 'Input should be a valid boolean',
'input': {'a': datetime.date(1992, 12, 11)},
},
]

0 comments on commit cddcc71

Please sign in to comment.