From 9117c6b01b5cc9ff59332ddfc2764d74ec8ad6a6 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Sat, 15 Jun 2024 21:40:17 -0400 Subject: [PATCH 01/26] first pass at union match improvement --- src/validators/dataclass.rs | 16 +++++++++- src/validators/mod.rs | 4 +++ src/validators/model.rs | 12 +++++++- src/validators/model_fields.rs | 10 ++++++- src/validators/typed_dict.rs | 10 ++++++- src/validators/union.rs | 48 ++++++++++++++++++++++-------- src/validators/validation_state.rs | 6 ++-- 7 files changed, 87 insertions(+), 19 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 4e98538a8..edc00874f 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -17,7 +17,9 @@ use crate::validators::function::convert_err; use super::model::{create_class, force_setattr, Revalidate}; use super::validation_state::Exactness; -use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; +use super::{ + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, HasNumFields, ValidationState, Validator, +}; #[derive(Debug)] struct Field { @@ -423,6 +425,12 @@ impl Validator for DataclassArgsValidator { } } +impl HasNumFields for DataclassArgsValidator { + fn num_fields(&self) -> Option { + Some(self.fields.len()) + } +} + #[derive(Debug)] pub struct DataclassValidator { strict: bool, @@ -631,3 +639,9 @@ impl DataclassValidator { Ok(()) } } + +impl HasNumFields for DataclassValidator { + fn num_fields(&self) -> Option { + Some(self.fields.len()) + } +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index ede6489ab..1ae2af9bb 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -582,6 +582,10 @@ pub fn build_validator( ) } +pub trait HasNumFields { + fn num_fields(&self) -> Option; +} + /// More (mostly immutable) data to pass between validators, should probably be class `Context`, /// but that would confuse it with context as per pydantic/pydantic#1549 #[derive(Debug, Clone)] diff --git a/src/validators/model.rs b/src/validators/model.rs index 741b6f9f0..a74133a9d 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -8,7 +8,8 @@ use pyo3::{intern, prelude::*}; use super::function::convert_err; use super::validation_state::Exactness; use super::{ - build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, HasNumFields, ValidationState, + Validator, }; use crate::build_tools::py_schema_err; use crate::build_tools::schema_or_config_same; @@ -307,6 +308,15 @@ impl ModelValidator { } } +impl HasNumFields for ModelValidator { + fn num_fields(&self) -> Option { + match &*self.validator { + CombinedValidator::ModelFields(model_fields_validator) => model_fields_validator.num_fields(), + _ => None, + } + } +} + /// based on the following but with the second argument of new_func set to an empty tuple as required /// https://github.com/PyO3/pyo3/blob/d2caa056e9aacc46374139ef491d112cb8af1a25/src/pyclass_init.rs#L35-L77 pub(super) fn create_class<'py>(class: &Bound<'py, PyType>) -> PyResult> { diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index eda76056a..fb2791eae 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -14,7 +14,9 @@ use crate::input::{BorrowInput, Input, ValidatedDict, ValidationMatch}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; +use super::{ + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, HasNumFields, ValidationState, Validator, +}; #[derive(Debug)] struct Field { @@ -436,3 +438,9 @@ impl Validator for ModelFieldsValidator { Self::EXPECTED_TYPE } } + +impl HasNumFields for ModelFieldsValidator { + fn num_fields(&self) -> Option { + Some(self.fields.len()) + } +} diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 5aa8a5f00..f872a27cf 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -15,7 +15,9 @@ use crate::input::{Input, ValidatedDict}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; -use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; +use super::{ + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, HasNumFields, ValidationState, Validator, +}; #[derive(Debug)] struct TypedDictField { @@ -329,3 +331,9 @@ impl Validator for TypedDictValidator { Self::EXPECTED_TYPE } } + +impl HasNumFields for TypedDictValidator { + fn num_fields(&self) -> Option { + Some(self.fields.len()) + } +} diff --git a/src/validators/union.rs b/src/validators/union.rs index 8a33870ec..e9eac85fc 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -17,7 +17,8 @@ use crate::tools::SchemaDict; use super::custom_error::CustomError; use super::literal::LiteralLookup; use super::{ - build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, ValidationState, Validator, + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, HasNumFields, ValidationState, + Validator, }; #[derive(Debug)] @@ -111,7 +112,9 @@ impl UnionValidator { let strict = state.strict_or(self.strict); let mut errors = MaybeErrors::new(self.custom_error.as_ref()); - let mut success = None; + // we use this to track the validation against the most compatible union member + // up to the current point + let mut success: Option<(Py, Exactness, Option)> = None; for (choice, label) in &self.choices { let state = &mut state.rebind_extra(|extra| { @@ -134,16 +137,37 @@ impl UnionValidator { _ => { // success should always have an exactness debug_assert_ne!(state.exactness, None); + let new_exactness = state.exactness.unwrap_or(Exactness::Lax); - // if the new result has higher exactness than the current success, replace it - if success - .as_ref() - .map_or(true, |(_, current_exactness)| *current_exactness < new_exactness) - { - // TODO: is there a possible optimization here, where once there has - // been one success, we turn on strict mode, to avoid unnecessary - // coercions for further validation? - success = Some((new_success, new_exactness)); + + // in the case where the exactness is Strict or Lax, we also check + // the number of fields in the model, and prefer the one with more fields + // which is useful for choosing a more specific model in a union, like a subclass + let new_num_fields: Option = match choice { + CombinedValidator::Model(x) => x.num_fields(), + CombinedValidator::Dataclass(y) => y.num_fields(), + CombinedValidator::TypedDict(z) => z.num_fields(), + _ => None, + }; + + let new_success_is_best_match: bool = + success.as_ref().map_or(true, |(_, cur_exactness, cur_num_fields)| { + match (*cur_num_fields, new_num_fields) { + // if the number of fields is greater and the exactness is more precise + // then the new union member is a better match + (Some(cur_fields), Some(new_fields)) => { + cur_fields < new_fields && *cur_exactness <= new_exactness + } + // if the number of fields isn't known, the new union + // member is only a better match if the exactness is more precise + // which ensures that we prefer left-to-right order in the union + // when exactness ties occur + _ => *cur_exactness < new_exactness, + } + }); + + if new_success_is_best_match { + success = Some((new_success, new_exactness, new_num_fields)); } } }, @@ -158,7 +182,7 @@ impl UnionValidator { } state.exactness = old_exactness; - if let Some((success, exactness)) = success { + if let Some((success, exactness, _fields)) = success { state.floor_exactness(exactness); return Ok(success); } diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index ef6954618..ee870916c 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -10,9 +10,9 @@ use super::Extra; #[derive(Clone, Copy, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)] pub enum Exactness { - Lax, - Strict, - Exact, + Lax = 1, + Strict = 2, + Exact = 3, } pub struct ValidationState<'a, 'py> { From fcc84ecb3e64d3a64f28632bb90531d1717e4e89 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Sat, 15 Jun 2024 21:45:30 -0400 Subject: [PATCH 02/26] don't need to add explicit numbers --- src/validators/validation_state.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index ee870916c..ef6954618 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -10,9 +10,9 @@ use super::Extra; #[derive(Clone, Copy, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)] pub enum Exactness { - Lax = 1, - Strict = 2, - Exact = 3, + Lax, + Strict, + Exact, } pub struct ValidationState<'a, 'py> { From 025521127508bd51567c0a2fe891e0ad3fc23d80 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Sun, 16 Jun 2024 18:48:53 -0400 Subject: [PATCH 03/26] figuring out fields set --- src/validators/union.rs | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/validators/union.rs b/src/validators/union.rs index e9eac85fc..debac4150 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -2,7 +2,7 @@ use std::fmt::Write; use std::str::FromStr; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PyString, PyTuple}; +use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple}; use pyo3::{intern, PyTraverseError, PyVisit}; use smallvec::SmallVec; @@ -21,6 +21,8 @@ use super::{ Validator, }; +const DUNDER_FIELDS_SET_KEY: &str = "__pydantic_fields_set__"; + #[derive(Debug)] enum UnionMode { Smart, @@ -150,6 +152,36 @@ impl UnionValidator { _ => None, }; + // let fields_set: Option> = match new_success.getattr(py, DUNDER_FIELDS_SET_KEY) { + // Ok(fields_set) => fields_set.extract(py)?, + // Err(_) => None, + // }; + // dbg!(fields_set); + + // let new_fields_set: Option = match new_success.getattr(py, intern!(py, DUNDER_FIELDS_SET_KEY)) { + // Ok(fields_set) => match fields_set.downcast_bound::(py) { + // Ok(fields_set) => Some(fields_set), + // Err(_) => None, + // } + // Err(_) => None, + // }; + // dbg!(new_fields_set); + + let fields_set_length: Option = match new_success.getattr(py, intern!(py, DUNDER_FIELDS_SET_KEY)) { + Ok(fields_set) => match fields_set.downcast_bound::(py) { + Ok(py_set) => Some(py_set.len()), + Err(_) => None, + }, + Err(_) => None, + }; + + dbg!(fields_set_length); + + dbg!(fields_set_length); + + // dbg!(new_success.getattr(py, DUNDER_FIELDS_SET_KEY)?.extract()?); + + let new_success_is_best_match: bool = success.as_ref().map_or(true, |(_, cur_exactness, cur_num_fields)| { match (*cur_num_fields, new_num_fields) { From ceb5e1b7d496b99435b9305d3174a9c44c6d804a Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Sun, 16 Jun 2024 21:04:20 -0400 Subject: [PATCH 04/26] behavior for checking model fields set --- src/validators/union.rs | 64 +++++++++++----------------------- tests/validators/test_union.py | 2 ++ 2 files changed, 23 insertions(+), 43 deletions(-) diff --git a/src/validators/union.rs b/src/validators/union.rs index debac4150..7ed64d1f7 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -21,8 +21,6 @@ use super::{ Validator, }; -const DUNDER_FIELDS_SET_KEY: &str = "__pydantic_fields_set__"; - #[derive(Debug)] enum UnionMode { Smart, @@ -146,55 +144,35 @@ impl UnionValidator { // the number of fields in the model, and prefer the one with more fields // which is useful for choosing a more specific model in a union, like a subclass let new_num_fields: Option = match choice { - CombinedValidator::Model(x) => x.num_fields(), + CombinedValidator::Model(x) => { + // for models, we attempt to fetch the number of fields set, as this is a better + // measure of the fit of the input data than the total number of fields on a model + // but we fallback to the total number of fields + let num_model_fields_set: Option = + match new_success.getattr(py, intern!(py, "__pydantic_fields_set__")) { + Ok(fields_set) => match fields_set.downcast_bound::(py) { + Ok(py_set) => Some(py_set.len()), + Err(_) => x.num_fields(), + }, + Err(_) => x.num_fields(), + }; + num_model_fields_set + } CombinedValidator::Dataclass(y) => y.num_fields(), CombinedValidator::TypedDict(z) => z.num_fields(), _ => None, }; - // let fields_set: Option> = match new_success.getattr(py, DUNDER_FIELDS_SET_KEY) { - // Ok(fields_set) => fields_set.extract(py)?, - // Err(_) => None, - // }; - // dbg!(fields_set); - - // let new_fields_set: Option = match new_success.getattr(py, intern!(py, DUNDER_FIELDS_SET_KEY)) { - // Ok(fields_set) => match fields_set.downcast_bound::(py) { - // Ok(fields_set) => Some(fields_set), - // Err(_) => None, - // } - // Err(_) => None, - // }; - // dbg!(new_fields_set); - - let fields_set_length: Option = match new_success.getattr(py, intern!(py, DUNDER_FIELDS_SET_KEY)) { - Ok(fields_set) => match fields_set.downcast_bound::(py) { - Ok(py_set) => Some(py_set.len()), - Err(_) => None, - }, - Err(_) => None, - }; - - dbg!(fields_set_length); - - dbg!(fields_set_length); - - // dbg!(new_success.getattr(py, DUNDER_FIELDS_SET_KEY)?.extract()?); - - let new_success_is_best_match: bool = success.as_ref().map_or(true, |(_, cur_exactness, cur_num_fields)| { - match (*cur_num_fields, new_num_fields) { - // if the number of fields is greater and the exactness is more precise + if *cur_exactness < new_exactness { + true + } else if let (Some(cur_fields), Some(new_fields)) = (*cur_num_fields, new_num_fields) { + // if the number of fields is greater and the exactness is the same, // then the new union member is a better match - (Some(cur_fields), Some(new_fields)) => { - cur_fields < new_fields && *cur_exactness <= new_exactness - } - // if the number of fields isn't known, the new union - // member is only a better match if the exactness is more precise - // which ensures that we prefer left-to-right order in the union - // when exactness ties occur - _ => *cur_exactness < new_exactness, + cur_fields < new_fields && *cur_exactness == new_exactness + } else { + false } }); diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 332b23257..a00967fcb 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -172,6 +172,8 @@ def test_model_a(self, schema_validator: SchemaValidator): def test_model_b_ignored(self, schema_validator: SchemaValidator): # first choice works, so second choice is not used + # TODO: I would argue that this is the wrong behavior - we should initialize + # a ModelB in this case. m = schema_validator.validate_python({'a': 1, 'b': 'hello', 'c': 2.0}) assert isinstance(m, self.ModelA) assert m.a == 1 From 1a6b00c414071fef2eead1a8ebbfc8464f7623f3 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 17 Jun 2024 05:57:06 -0400 Subject: [PATCH 05/26] working on my rust best practices --- src/validators/dataclass.rs | 6 ------ src/validators/union.rs | 31 +++++++++++++------------------ 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index edc00874f..55208d5a7 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -425,12 +425,6 @@ impl Validator for DataclassArgsValidator { } } -impl HasNumFields for DataclassArgsValidator { - fn num_fields(&self) -> Option { - Some(self.fields.len()) - } -} - #[derive(Debug)] pub struct DataclassValidator { strict: bool, diff --git a/src/validators/union.rs b/src/validators/union.rs index 7ed64d1f7..1228651c2 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -148,15 +148,16 @@ impl UnionValidator { // for models, we attempt to fetch the number of fields set, as this is a better // measure of the fit of the input data than the total number of fields on a model // but we fallback to the total number of fields - let num_model_fields_set: Option = - match new_success.getattr(py, intern!(py, "__pydantic_fields_set__")) { - Ok(fields_set) => match fields_set.downcast_bound::(py) { - Ok(py_set) => Some(py_set.len()), - Err(_) => x.num_fields(), - }, - Err(_) => x.num_fields(), - }; - num_model_fields_set + new_success + .getattr(py, intern!(py, "__pydantic_fields_set__")) + .ok() + .and_then(|fields_set| { + fields_set + .downcast_bound::(py) + .ok() + .map(pyo3::prelude::PySetMethods::len) + }) + .or_else(|| x.num_fields()) } CombinedValidator::Dataclass(y) => y.num_fields(), CombinedValidator::TypedDict(z) => z.num_fields(), @@ -165,15 +166,9 @@ impl UnionValidator { let new_success_is_best_match: bool = success.as_ref().map_or(true, |(_, cur_exactness, cur_num_fields)| { - if *cur_exactness < new_exactness { - true - } else if let (Some(cur_fields), Some(new_fields)) = (*cur_num_fields, new_num_fields) { - // if the number of fields is greater and the exactness is the same, - // then the new union member is a better match - cur_fields < new_fields && *cur_exactness == new_exactness - } else { - false - } + *cur_exactness < new_exactness + || (*cur_exactness == new_exactness + && cur_num_fields.unwrap_or(0) < new_num_fields.unwrap_or(0)) }); if new_success_is_best_match { From 1718e8b1dc3d3b18f09ebf632cfc0bf171313791 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 17 Jun 2024 11:17:09 -0400 Subject: [PATCH 06/26] using method on Validator trait to streamline things --- src/validators/dataclass.rs | 14 +++++-------- src/validators/mod.rs | 12 +++++++---- src/validators/model.rs | 16 +++++---------- src/validators/model_fields.rs | 14 +++++-------- src/validators/typed_dict.rs | 14 +++++-------- src/validators/union.rs | 37 ++++++++++++---------------------- tests/validators/test_union.py | 12 +++++------ 7 files changed, 47 insertions(+), 72 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 55208d5a7..16f946da2 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -17,9 +17,7 @@ use crate::validators::function::convert_err; use super::model::{create_class, force_setattr, Revalidate}; use super::validation_state::Exactness; -use super::{ - build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, HasNumFields, ValidationState, Validator, -}; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug)] struct Field { @@ -569,6 +567,10 @@ impl Validator for DataclassValidator { Ok(obj.to_object(py)) } + fn num_fields(&self) -> Option { + Some(self.fields.len()) + } + fn get_name(&self) -> &str { &self.name } @@ -633,9 +635,3 @@ impl DataclassValidator { Ok(()) } } - -impl HasNumFields for DataclassValidator { - fn num_fields(&self) -> Option { - Some(self.fields.len()) - } -} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 1ae2af9bb..5804c3595 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -582,10 +582,6 @@ pub fn build_validator( ) } -pub trait HasNumFields { - fn num_fields(&self) -> Option; -} - /// More (mostly immutable) data to pass between validators, should probably be class `Context`, /// but that would confuse it with context as per pydantic/pydantic#1549 #[derive(Debug, Clone)] @@ -774,6 +770,14 @@ pub trait Validator: Send + Sync + Debug { Err(py_err.into()) } + // Used as a tie-breaking score for union validators for model like validators + // Eventually, it'd be nice to make this a more robust metric that applies + // to a greater variety of validators, but we want to avoid the arbitrary metric, + // z-index like spiral of doom for now, see https://danielrotter.at/2020/04/08/avoid-z-index-whenever-possible.html + fn num_fields(&self) -> Option { + None + } + /// `get_name` generally returns `Self::EXPECTED_TYPE` or some other clear identifier of the validator /// this is used in the error location in unions, and in the top level message in `ValidationError` fn get_name(&self) -> &str; diff --git a/src/validators/model.rs b/src/validators/model.rs index a74133a9d..ce5c545d5 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -8,8 +8,7 @@ use pyo3::{intern, prelude::*}; use super::function::convert_err; use super::validation_state::Exactness; use super::{ - build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, HasNumFields, ValidationState, - Validator, + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, }; use crate::build_tools::py_schema_err; use crate::build_tools::schema_or_config_same; @@ -217,6 +216,10 @@ impl Validator for ModelValidator { Ok(model.into_py(py)) } + fn num_fields(&self) -> Option { + self.validator.num_fields() + } + fn get_name(&self) -> &str { &self.name } @@ -308,15 +311,6 @@ impl ModelValidator { } } -impl HasNumFields for ModelValidator { - fn num_fields(&self) -> Option { - match &*self.validator { - CombinedValidator::ModelFields(model_fields_validator) => model_fields_validator.num_fields(), - _ => None, - } - } -} - /// based on the following but with the second argument of new_func set to an empty tuple as required /// https://github.com/PyO3/pyo3/blob/d2caa056e9aacc46374139ef491d112cb8af1a25/src/pyclass_init.rs#L35-L77 pub(super) fn create_class<'py>(class: &Bound<'py, PyType>) -> PyResult> { diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index fb2791eae..939fb43d4 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -14,9 +14,7 @@ use crate::input::{BorrowInput, Input, ValidatedDict, ValidationMatch}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; -use super::{ - build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, HasNumFields, ValidationState, Validator, -}; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug)] struct Field { @@ -434,13 +432,11 @@ impl Validator for ModelFieldsValidator { Ok((new_data.to_object(py), new_extra, fields_set.to_object(py)).to_object(py)) } - fn get_name(&self) -> &str { - Self::EXPECTED_TYPE - } -} - -impl HasNumFields for ModelFieldsValidator { fn num_fields(&self) -> Option { Some(self.fields.len()) } + + fn get_name(&self) -> &str { + Self::EXPECTED_TYPE + } } diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index f872a27cf..977a98502 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -15,9 +15,7 @@ use crate::input::{Input, ValidatedDict}; use crate::lookup_key::LookupKey; use crate::tools::SchemaDict; -use super::{ - build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, HasNumFields, ValidationState, Validator, -}; +use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator}; #[derive(Debug)] struct TypedDictField { @@ -327,13 +325,11 @@ impl Validator for TypedDictValidator { } } - fn get_name(&self) -> &str { - Self::EXPECTED_TYPE - } -} - -impl HasNumFields for TypedDictValidator { fn num_fields(&self) -> Option { Some(self.fields.len()) } + + fn get_name(&self) -> &str { + Self::EXPECTED_TYPE + } } diff --git a/src/validators/union.rs b/src/validators/union.rs index 1228651c2..5fcb845a9 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -17,8 +17,7 @@ use crate::tools::SchemaDict; use super::custom_error::CustomError; use super::literal::LiteralLookup; use super::{ - build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, HasNumFields, ValidationState, - Validator, + build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, ValidationState, Validator, }; #[derive(Debug)] @@ -140,29 +139,19 @@ impl UnionValidator { let new_exactness = state.exactness.unwrap_or(Exactness::Lax); - // in the case where the exactness is Strict or Lax, we also check - // the number of fields in the model, and prefer the one with more fields - // which is useful for choosing a more specific model in a union, like a subclass - let new_num_fields: Option = match choice { - CombinedValidator::Model(x) => { - // for models, we attempt to fetch the number of fields set, as this is a better - // measure of the fit of the input data than the total number of fields on a model - // but we fallback to the total number of fields - new_success - .getattr(py, intern!(py, "__pydantic_fields_set__")) + // in the case where the exactness is Strict or Lax, we also check the number of fields set on the result + // (or number of fields on the model, as a fallback), and prefer the one with more fields which is useful + // for choosing a more specific model in a union, like a subclass over a superclass + let new_num_fields = new_success + .getattr(py, intern!(py, "__pydantic_fields_set__")) + .ok() + .and_then(|fields_set| { + fields_set + .downcast_bound::(py) .ok() - .and_then(|fields_set| { - fields_set - .downcast_bound::(py) - .ok() - .map(pyo3::prelude::PySetMethods::len) - }) - .or_else(|| x.num_fields()) - } - CombinedValidator::Dataclass(y) => y.num_fields(), - CombinedValidator::TypedDict(z) => z.num_fields(), - _ => None, - }; + .map(pyo3::prelude::PySetMethods::len) + }) + .or_else(|| choice.num_fields()); let new_success_is_best_match: bool = success.as_ref().map_or(true, |(_, cur_exactness, cur_num_fields)| { diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index a00967fcb..3f3a2b2f2 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -170,15 +170,15 @@ def test_model_a(self, schema_validator: SchemaValidator): assert m.b == 'hello' assert not hasattr(m, 'c') - def test_model_b_ignored(self, schema_validator: SchemaValidator): - # first choice works, so second choice is not used - # TODO: I would argue that this is the wrong behavior - we should initialize - # a ModelB in this case. + def test_model_b_preferred(self, schema_validator: SchemaValidator): + # Note, this is a different behavior to previous smart union behavior, + # where the first match would be preferred. However, we believe is it better + # to prefer the match with the greatest number of valid fields set. m = schema_validator.validate_python({'a': 1, 'b': 'hello', 'c': 2.0}) - assert isinstance(m, self.ModelA) + assert isinstance(m, self.ModelB) assert m.a == 1 assert m.b == 'hello' - assert not hasattr(m, 'c') + assert m.c == 2.0 def test_model_b_not_ignored(self, schema_validator: SchemaValidator): m1 = self.ModelB() From 500a8123c81bad28354bb00dd72de2e3367f0797 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 17 Jun 2024 11:23:07 -0400 Subject: [PATCH 07/26] comment udpate --- src/validators/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 5804c3595..0ccf8df11 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -772,8 +772,8 @@ pub trait Validator: Send + Sync + Debug { // Used as a tie-breaking score for union validators for model like validators // Eventually, it'd be nice to make this a more robust metric that applies - // to a greater variety of validators, but we want to avoid the arbitrary metric, - // z-index like spiral of doom for now, see https://danielrotter.at/2020/04/08/avoid-z-index-whenever-possible.html + // to a greater variety of validators, but we want to avoid the arbitrary metric, z-index like spiral of doom for now. + // see https://danielrotter.at/2020/04/08/avoid-z-index-whenever-possible.html fn num_fields(&self) -> Option { None } From 6645ac28b99d815e141750cff74d2017ba50532c Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 17 Jun 2024 15:01:58 -0400 Subject: [PATCH 08/26] new round of tests --- tests/validators/test_union.py | 174 ++++++++++++++++++++++++++++++++- 1 file changed, 173 insertions(+), 1 deletion(-) diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 3f3a2b2f2..bd5a37867 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from datetime import date, time from enum import Enum, IntEnum -from typing import Any +from typing import Any, Optional from uuid import UUID import pytest @@ -805,3 +805,175 @@ class ModelA: assert isinstance(m, ModelA) assert m.a == 42 assert validator.validate_python(True) is True + + +def test_union_with_subclass() -> None: + class ModelA: + a: int + + class ModelB(ModelA): + b: int + + model_a_schema = { + 'type': 'model', + 'cls': ModelA, + 'schema': { + 'type': 'model-fields', + 'fields': { + 'a': {'type': 'model-field', 'schema': {'type': 'int'}}, + }, + }, + } + + model_b_schema = { + 'type': 'model', + 'cls': ModelB, + 'schema': { + 'type': 'model-fields', + 'fields': { + 'a': {'type': 'model-field', 'schema': {'type': 'int'}}, + 'b': {'type': 'model-field', 'schema': {'type': 'int'}}, + }, + }, + } + + a_b_val = SchemaValidator( + { + 'type': 'union', + 'choices': [model_a_schema, model_b_schema], + } + ) + + b_a_val = SchemaValidator( + { + 'type': 'union', + 'choices': [model_b_schema, model_a_schema], + } + ) + + assert isinstance(a_b_val.validate_python({'a': 1}), ModelA) + assert isinstance(b_a_val.validate_python({'a': 1}), ModelA) + + assert isinstance(a_b_val.validate_python({'a': 1, 'b': 2}), ModelB) + assert isinstance(b_a_val.validate_python({'a': 1, 'b': 2}), ModelB) + + +def test_union_with_default() -> None: + class ModelA: + a: int = 0 + + class ModelB: + b: int = 0 + + val = SchemaValidator( + { + 'type': 'union', + 'choices': [ + { + 'type': 'model', + 'cls': ModelA, + 'schema': { + 'type': 'model-fields', + 'fields': { + 'a': { + 'type': 'model-field', + 'schema': {'type': 'default', 'schema': {'type': 'int'}, 'default': 0}, + }, + }, + }, + }, + { + 'type': 'model', + 'cls': ModelB, + 'schema': { + 'type': 'model-fields', + 'fields': { + 'b': { + 'type': 'model-field', + 'schema': {'type': 'default', 'schema': {'type': 'int'}, 'default': 0}, + }, + }, + }, + }, + ], + } + ) + + assert isinstance(val.validate_python({'a': 1}), ModelA) + assert isinstance(val.validate_python({'b': 1}), ModelB) + + # defaults to leftmost choice if there's a tie + assert isinstance(val.validate_python({}), ModelA) + + +def test_optional_union_with_members_having_defaults() -> None: + class ModelA: + a: int = 0 + + class ModelB: + b: int = 0 + + class WrapModel: + val: Optional[ModelA | ModelB] = None + + val = SchemaValidator( + { + 'type': 'model', + 'cls': WrapModel, + 'schema': { + 'type': 'model-fields', + 'fields': { + 'val': { + 'type': 'model-field', + 'schema': { + 'type': 'default', + 'schema': { + 'type': 'union', + 'choices': [ + { + 'type': 'model', + 'cls': ModelA, + 'schema': { + 'type': 'model-fields', + 'fields': { + 'a': { + 'type': 'model-field', + 'schema': { + 'type': 'default', + 'schema': {'type': 'int'}, + 'default': 0, + }, + } + }, + }, + }, + { + 'type': 'model', + 'cls': ModelB, + 'schema': { + 'type': 'model-fields', + 'fields': { + 'b': { + 'type': 'model-field', + 'schema': { + 'type': 'default', + 'schema': {'type': 'int'}, + 'default': 0, + }, + } + }, + }, + }, + ], + }, + 'default': None, + }, + } + }, + }, + } + ) + + assert isinstance(val.validate_python({'val': {'a': 1}}).val, ModelA) + assert isinstance(val.validate_python({'val': {'b': 1}}).val, ModelB) + assert val.validate_python({}).val is None From d4799281d8660415a5b0b27a3fe1ebf331905599 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 17 Jun 2024 18:03:25 -0400 Subject: [PATCH 09/26] get all tests passing --- src/validators/union.rs | 7 ++-- tests/validators/test_union.py | 60 ++++++++++++---------------------- 2 files changed, 25 insertions(+), 42 deletions(-) diff --git a/src/validators/union.rs b/src/validators/union.rs index 5fcb845a9..c7b1a6861 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -155,9 +155,10 @@ impl UnionValidator { let new_success_is_best_match: bool = success.as_ref().map_or(true, |(_, cur_exactness, cur_num_fields)| { - *cur_exactness < new_exactness - || (*cur_exactness == new_exactness - && cur_num_fields.unwrap_or(0) < new_num_fields.unwrap_or(0)) + match (*cur_num_fields, new_num_fields) { + (Some(cur), Some(new)) if cur != new => cur < new, + _ => *cur_exactness < new_exactness, + } }); if new_success_is_best_match { diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index bd5a37867..94c95594d 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -814,48 +814,30 @@ class ModelA: class ModelB(ModelA): b: int - model_a_schema = { - 'type': 'model', - 'cls': ModelA, - 'schema': { - 'type': 'model-fields', - 'fields': { - 'a': {'type': 'model-field', 'schema': {'type': 'int'}}, - }, - }, - } - - model_b_schema = { - 'type': 'model', - 'cls': ModelB, - 'schema': { - 'type': 'model-fields', - 'fields': { - 'a': {'type': 'model-field', 'schema': {'type': 'int'}}, - 'b': {'type': 'model-field', 'schema': {'type': 'int'}}, - }, - }, - } - - a_b_val = SchemaValidator( - { - 'type': 'union', - 'choices': [model_a_schema, model_b_schema], - } + model_a_schema = core_schema.model_schema( + ModelA, core_schema.model_fields_schema(fields={'a': core_schema.model_field(core_schema.int_schema())}) ) - - b_a_val = SchemaValidator( - { - 'type': 'union', - 'choices': [model_b_schema, model_a_schema], - } + model_b_schema = core_schema.model_schema( + ModelB, + core_schema.model_fields_schema( + fields={ + 'a': core_schema.model_field(core_schema.int_schema()), + 'b': core_schema.model_field(core_schema.int_schema()), + } + ), ) - assert isinstance(a_b_val.validate_python({'a': 1}), ModelA) - assert isinstance(b_a_val.validate_python({'a': 1}), ModelA) - - assert isinstance(a_b_val.validate_python({'a': 1, 'b': 2}), ModelB) - assert isinstance(b_a_val.validate_python({'a': 1, 'b': 2}), ModelB) + for choices in [[model_a_schema, model_b_schema], [model_b_schema, model_a_schema]]: + validator = SchemaValidator(schema=core_schema.union_schema(choices)) + assert isinstance(validator.validate_python({'a': 1}), ModelA) + assert isinstance(validator.validate_python({'a': 1, 'b': 2}), ModelB) + + # confirm that a model that matches in lax mode with 2 fields + # is preferred over a model that matches in strict mode with 1 field + assert isinstance(validator.validate_python({'a': '1', 'b': '2'}), ModelB) + assert isinstance(validator.validate_python({'a': '1', 'b': 2}), ModelB) + assert isinstance(validator.validate_python({'a': 1, 'b': '2'}), ModelB) + assert isinstance(validator.validate_python({'a': 1, 'b': 2}), ModelB) def test_union_with_default() -> None: From e16103a20bed6f33df4e25aaa044564c4b3eeea8 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 17 Jun 2024 18:14:48 -0400 Subject: [PATCH 10/26] use Union not pipe in tests --- tests/validators/test_union.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 94c95594d..b6ace83b7 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from datetime import date, time from enum import Enum, IntEnum -from typing import Any, Optional +from typing import Any, Optional, Union from uuid import UUID import pytest @@ -896,7 +896,7 @@ class ModelB: b: int = 0 class WrapModel: - val: Optional[ModelA | ModelB] = None + val: Optional[Union[ModelA, ModelB]] = None val = SchemaValidator( { From 9208282f3391ba8e05b7c5b8e0b43c615b44d286 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Mon, 17 Jun 2024 20:48:43 -0400 Subject: [PATCH 11/26] abandon num_fields for a state based approach --- src/validators/dataclass.rs | 4 ---- src/validators/mod.rs | 8 -------- src/validators/model.rs | 16 +++++++++------- src/validators/model_fields.rs | 4 ---- src/validators/typed_dict.rs | 8 ++++---- src/validators/union.rs | 25 ++++++------------------- src/validators/validation_state.rs | 2 ++ 7 files changed, 21 insertions(+), 46 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 16f946da2..4e98538a8 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -567,10 +567,6 @@ impl Validator for DataclassValidator { Ok(obj.to_object(py)) } - fn num_fields(&self) -> Option { - Some(self.fields.len()) - } - fn get_name(&self) -> &str { &self.name } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 0ccf8df11..ede6489ab 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -770,14 +770,6 @@ pub trait Validator: Send + Sync + Debug { Err(py_err.into()) } - // Used as a tie-breaking score for union validators for model like validators - // Eventually, it'd be nice to make this a more robust metric that applies - // to a greater variety of validators, but we want to avoid the arbitrary metric, z-index like spiral of doom for now. - // see https://danielrotter.at/2020/04/08/avoid-z-index-whenever-possible.html - fn num_fields(&self) -> Option { - None - } - /// `get_name` generally returns `Self::EXPECTED_TYPE` or some other clear identifier of the validator /// this is used in the error location in unions, and in the top level message in `ValidationError` fn get_name(&self) -> &str; diff --git a/src/validators/model.rs b/src/validators/model.rs index ce5c545d5..e359669bc 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -204,6 +204,7 @@ impl Validator for ModelValidator { for field_name in validated_fields_set { fields_set.add(field_name)?; } + state.fields_set_count = Some(fields_set.len()); } force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?; @@ -216,10 +217,6 @@ impl Validator for ModelValidator { Ok(model.into_py(py)) } - fn num_fields(&self) -> Option { - self.validator.num_fields() - } - fn get_name(&self) -> &str { &self.name } @@ -245,10 +242,13 @@ impl ModelValidator { } else { PySet::new_bound(py, [&String::from(ROOT_FIELD)])? }; - force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), fields_set)?; + force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?; force_setattr(py, self_instance, intern!(py, ROOT_FIELD), &output)?; + state.fields_set_count = Some(fields_set.len()); } else { - let (model_dict, model_extra, fields_set) = output.extract(py)?; + let (model_dict, model_extra, fields_set): (Bound, Bound, Bound) = + output.extract(py)?; + state.fields_set_count = fields_set.len().ok(); set_model_attrs(self_instance, &model_dict, &model_extra, &fields_set)?; } self.call_post_init(py, self_instance.clone(), input, state.extra()) @@ -285,11 +285,13 @@ impl ModelValidator { } else { PySet::new_bound(py, [&String::from(ROOT_FIELD)])? }; - force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), fields_set)?; + force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?; force_setattr(py, &instance, intern!(py, ROOT_FIELD), output)?; + state.fields_set_count = Some(fields_set.len()); } else { let (model_dict, model_extra, val_fields_set) = output.extract(py)?; let fields_set = existing_fields_set.unwrap_or(&val_fields_set); + state.fields_set_count = fields_set.len().ok(); set_model_attrs(&instance, &model_dict, &model_extra, fields_set)?; } self.call_post_init(py, instance, input, state.extra()) diff --git a/src/validators/model_fields.rs b/src/validators/model_fields.rs index 939fb43d4..eda76056a 100644 --- a/src/validators/model_fields.rs +++ b/src/validators/model_fields.rs @@ -432,10 +432,6 @@ impl Validator for ModelFieldsValidator { Ok((new_data.to_object(py), new_extra, fields_set.to_object(py)).to_object(py)) } - fn num_fields(&self) -> Option { - Some(self.fields.len()) - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 977a98502..da351642b 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -165,6 +165,7 @@ impl Validator for TypedDictValidator { { let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone())); + let mut fields_set_count: usize = 0; for field in &self.fields { let op_key_value = match dict.get_item(&field.lookup_key) { @@ -186,6 +187,7 @@ impl Validator for TypedDictValidator { match field.validator.validate(py, value.borrow_input(), state) { Ok(value) => { output_dict.set_item(&field.name_py, value)?; + fields_set_count += 1; } Err(ValError::Omit) => continue, Err(ValError::LineErrors(line_errors)) => { @@ -227,6 +229,8 @@ impl Validator for TypedDictValidator { Err(err) => return Err(err), } } + + state.fields_set_count = (fields_set_count != 0).then_some(fields_set_count); } if let Some(used_keys) = used_keys { @@ -325,10 +329,6 @@ impl Validator for TypedDictValidator { } } - fn num_fields(&self) -> Option { - Some(self.fields.len()) - } - fn get_name(&self) -> &str { Self::EXPECTED_TYPE } diff --git a/src/validators/union.rs b/src/validators/union.rs index c7b1a6861..05fea3e8e 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -2,7 +2,7 @@ use std::fmt::Write; use std::str::FromStr; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PySet, PyString, PyTuple}; +use pyo3::types::{PyDict, PyList, PyString, PyTuple}; use pyo3::{intern, PyTraverseError, PyVisit}; use smallvec::SmallVec; @@ -138,31 +138,18 @@ impl UnionValidator { debug_assert_ne!(state.exactness, None); let new_exactness = state.exactness.unwrap_or(Exactness::Lax); - - // in the case where the exactness is Strict or Lax, we also check the number of fields set on the result - // (or number of fields on the model, as a fallback), and prefer the one with more fields which is useful - // for choosing a more specific model in a union, like a subclass over a superclass - let new_num_fields = new_success - .getattr(py, intern!(py, "__pydantic_fields_set__")) - .ok() - .and_then(|fields_set| { - fields_set - .downcast_bound::(py) - .ok() - .map(pyo3::prelude::PySetMethods::len) - }) - .or_else(|| choice.num_fields()); + let new_fields_set = state.fields_set_count; let new_success_is_best_match: bool = - success.as_ref().map_or(true, |(_, cur_exactness, cur_num_fields)| { - match (*cur_num_fields, new_num_fields) { + success.as_ref().map_or(true, |(_, cur_exactness, cur_fields_set)| { + match (*cur_fields_set, new_fields_set) { (Some(cur), Some(new)) if cur != new => cur < new, _ => *cur_exactness < new_exactness, } }); if new_success_is_best_match { - success = Some((new_success, new_exactness, new_num_fields)); + success = Some((new_success, new_exactness, new_fields_set)); } } }, @@ -177,7 +164,7 @@ impl UnionValidator { } state.exactness = old_exactness; - if let Some((success, exactness, _fields)) = success { + if let Some((success, exactness, _fields_set)) = success { state.floor_exactness(exactness); return Ok(success); } diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index ef6954618..c7db59c22 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -18,6 +18,7 @@ pub enum Exactness { pub struct ValidationState<'a, 'py> { pub recursion_guard: &'a mut RecursionState, pub exactness: Option, + pub fields_set_count: Option, // deliberately make Extra readonly extra: Extra<'a, 'py>, } @@ -27,6 +28,7 @@ impl<'a, 'py> ValidationState<'a, 'py> { Self { recursion_guard, // Don't care about exactness unless doing union validation exactness: None, + fields_set_count: None, extra, } } From 69f6ec01ebe2adf5c53a1f067e76f8b9d9b81f20 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Tue, 18 Jun 2024 11:12:48 -0400 Subject: [PATCH 12/26] get typed dicts working, add tests for other model like cases --- src/validators/union.rs | 50 +++---- tests/validators/test_union.py | 262 ++++++++++++++++++++++----------- 2 files changed, 198 insertions(+), 114 deletions(-) diff --git a/src/validators/union.rs b/src/validators/union.rs index 05fea3e8e..f4bbb6cf4 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -122,37 +122,28 @@ impl UnionValidator { } }); state.exactness = Some(Exactness::Exact); + state.fields_set_count = None; let result = choice.validate(py, input, state); match result { - Ok(new_success) => match state.exactness { - // exact match, return - Some(Exactness::Exact) => { - return { - // exact match, return, restore any previous exactness - state.exactness = old_exactness; - Ok(new_success) - }; + Ok(new_success) => { + // success should always have an exactness + debug_assert_ne!(state.exactness, None); + + let new_exactness = state.exactness.unwrap_or(Exactness::Lax); + let new_fields_set = state.fields_set_count; + + let new_success_is_best_match: bool = + success.as_ref().map_or(true, |(_, cur_exactness, cur_fields_set)| { + match (*cur_fields_set, new_fields_set) { + (Some(cur), Some(new)) if cur != new => cur < new, + _ => *cur_exactness < new_exactness, + } + }); + + if new_success_is_best_match { + success = Some((new_success, new_exactness, new_fields_set)); } - _ => { - // success should always have an exactness - debug_assert_ne!(state.exactness, None); - - let new_exactness = state.exactness.unwrap_or(Exactness::Lax); - let new_fields_set = state.fields_set_count; - - let new_success_is_best_match: bool = - success.as_ref().map_or(true, |(_, cur_exactness, cur_fields_set)| { - match (*cur_fields_set, new_fields_set) { - (Some(cur), Some(new)) if cur != new => cur < new, - _ => *cur_exactness < new_exactness, - } - }); - - if new_success_is_best_match { - success = Some((new_success, new_exactness, new_fields_set)); - } - } - }, + } Err(ValError::LineErrors(lines)) => { // if we don't yet know this validation will succeed, record the error if success.is_none() { @@ -162,7 +153,10 @@ impl UnionValidator { otherwise => return otherwise, } } + + // restore previous validation state to prepare for any future validations state.exactness = old_exactness; + state.fields_set_count = None; if let Some((success, exactness, _fields_set)) = success { state.floor_exactness(exactness); diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index b6ace83b7..43aff5f84 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -6,6 +6,7 @@ import pytest from dirty_equals import IsFloat, IsInt +from typing_extensions import TypedDict from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema @@ -848,37 +849,30 @@ class ModelB: b: int = 0 val = SchemaValidator( - { - 'type': 'union', - 'choices': [ - { - 'type': 'model', - 'cls': ModelA, - 'schema': { - 'type': 'model-fields', - 'fields': { - 'a': { - 'type': 'model-field', - 'schema': {'type': 'default', 'schema': {'type': 'int'}, 'default': 0}, - }, - }, - }, - }, - { - 'type': 'model', - 'cls': ModelB, - 'schema': { - 'type': 'model-fields', - 'fields': { - 'b': { - 'type': 'model-field', - 'schema': {'type': 'default', 'schema': {'type': 'int'}, 'default': 0}, - }, - }, - }, - }, - ], - } + core_schema.union_schema( + [ + core_schema.model_schema( + ModelA, + core_schema.model_fields_schema( + fields={ + 'a': core_schema.model_field( + core_schema.with_default_schema(core_schema.int_schema(), default=0) + ) + } + ), + ), + core_schema.model_schema( + ModelB, + core_schema.model_fields_schema( + fields={ + 'b': core_schema.model_field( + core_schema.with_default_schema(core_schema.int_schema(), default=0) + ) + } + ), + ), + ] + ), ) assert isinstance(val.validate_python({'a': 1}), ModelA) @@ -899,63 +893,159 @@ class WrapModel: val: Optional[Union[ModelA, ModelB]] = None val = SchemaValidator( - { - 'type': 'model', - 'cls': WrapModel, - 'schema': { - 'type': 'model-fields', - 'fields': { - 'val': { - 'type': 'model-field', - 'schema': { - 'type': 'default', - 'schema': { - 'type': 'union', - 'choices': [ - { - 'type': 'model', - 'cls': ModelA, - 'schema': { - 'type': 'model-fields', - 'fields': { - 'a': { - 'type': 'model-field', - 'schema': { - 'type': 'default', - 'schema': {'type': 'int'}, - 'default': 0, - }, - } - }, - }, - }, - { - 'type': 'model', - 'cls': ModelB, - 'schema': { - 'type': 'model-fields', - 'fields': { - 'b': { - 'type': 'model-field', - 'schema': { - 'type': 'default', - 'schema': {'type': 'int'}, - 'default': 0, - }, - } - }, - }, - }, - ], - }, - 'default': None, - }, - } - }, - }, - } + core_schema.model_schema( + WrapModel, + core_schema.model_fields_schema( + fields={ + 'val': core_schema.model_field( + core_schema.with_default_schema( + core_schema.union_schema( + [ + core_schema.model_schema( + ModelA, + core_schema.model_fields_schema( + fields={ + 'a': core_schema.model_field( + core_schema.with_default_schema(core_schema.int_schema(), default=0) + ) + } + ), + ), + core_schema.model_schema( + ModelB, + core_schema.model_fields_schema( + fields={ + 'b': core_schema.model_field( + core_schema.with_default_schema(core_schema.int_schema(), default=0) + ) + } + ), + ), + ] + ), + default=None, + ) + ) + } + ), + ) ) assert isinstance(val.validate_python({'val': {'a': 1}}).val, ModelA) assert isinstance(val.validate_python({'val': {'b': 1}}).val, ModelB) assert val.validate_python({}).val is None + + +def test_dc_smart_union_by_fields_set() -> None: + @dataclass + class ModelA: + x: int + + @dataclass + class ModelB(ModelA): + y: int + + dc_a_schema = core_schema.dataclass_schema( + ModelA, + core_schema.dataclass_args_schema('ModelA', [core_schema.dataclass_field('x', core_schema.int_schema())]), + ['x'], + ) + + dc_b_schema = core_schema.dataclass_schema( + ModelB, + core_schema.dataclass_args_schema( + 'ModelB', + [ + core_schema.dataclass_field('x', core_schema.int_schema()), + core_schema.dataclass_field('y', core_schema.int_schema()), + ], + ), + ['x', 'y'], + ) + + for choices in [[dc_a_schema, dc_b_schema], [dc_b_schema, dc_a_schema]]: + validator = SchemaValidator(core_schema.union_schema(choices=choices)) + + assert isinstance(validator.validate_python({'x': 1}), ModelA) + assert isinstance(validator.validate_python({'x': '1'}), ModelA) + + assert isinstance(validator.validate_python({'x': 1, 'y': 2}), ModelB) + assert isinstance(validator.validate_python({'x': 1, 'y': '2'}), ModelB) + assert isinstance(validator.validate_python({'x': '1', 'y': 2}), ModelB) + assert isinstance(validator.validate_python({'x': '1', 'y': '2'}), ModelB) + + +def test_dc_smart_union_with_defaults() -> None: + @dataclass + class ModelA: + a: int = 0 + + @dataclass + class ModelB: + b: int = 0 + + dc_a_schema = core_schema.dataclass_schema( + ModelA, + core_schema.dataclass_args_schema( + 'ModelA', + [ + core_schema.dataclass_field( + 'a', core_schema.with_default_schema(schema=core_schema.int_schema(), default=0) + ) + ], + ), + ['a'], + ) + + dc_b_schema = core_schema.dataclass_schema( + ModelB, + core_schema.dataclass_args_schema( + 'ModelB', + [ + core_schema.dataclass_field( + 'b', core_schema.with_default_schema(schema=core_schema.int_schema(), default=0) + ) + ], + ), + ['b'], + ) + + for choices in [[dc_a_schema, dc_b_schema], [dc_b_schema, dc_a_schema]]: + validator = SchemaValidator(core_schema.union_schema(choices=choices)) + + assert isinstance(validator.validate_python({'a': 1}), ModelA) + assert isinstance(validator.validate_python({'b': 1}), ModelB) + + # defaults to leftmost choice if there's a tie + assert isinstance(validator.validate_python({}), ModelA) + + +def test_td_smart_union_by_fields_set() -> None: + class ModelA(TypedDict): + x: int + + class ModelB(TypedDict): + x: int + y: int + + td_a_schema = core_schema.typed_dict_schema( + fields={'x': core_schema.typed_dict_field(core_schema.int_schema())}, + ) + + td_b_schema = core_schema.typed_dict_schema( + fields={ + 'x': core_schema.typed_dict_field(core_schema.int_schema()), + 'y': core_schema.typed_dict_field(core_schema.int_schema()), + }, + ) + + for choices in [[td_a_schema, td_b_schema], [td_b_schema, td_a_schema]]: + validator = SchemaValidator(core_schema.union_schema(choices=choices)) + + assert validator.validate_python({'x': 1}).keys() == ModelA.__required_keys__ + assert validator.validate_python({'x': '1'}).keys() == ModelA.__required_keys__ + + assert validator.validate_python({'x': 1, 'y': 2}).keys() == ModelB.__required_keys__ + assert validator.validate_python({'x': 1, 'y': '2'}).keys() == ModelB.__required_keys__ + assert validator.validate_python({'x': '1', 'y': 2}).keys() == ModelB.__required_keys__ + assert validator.validate_python({'x': '1', 'y': '2'}).keys() == ModelB.__required_keys__ From 18b70d40083891e9a6f4f191a05d70b8aae2a793 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Tue, 18 Jun 2024 11:32:25 -0400 Subject: [PATCH 13/26] dataclass support + more efficient exact return --- src/validators/dataclass.rs | 14 +++++++++-- src/validators/union.rs | 46 ++++++++++++++++++++++--------------- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 4e98538a8..2554b38d1 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -154,6 +154,7 @@ impl Validator for DataclassArgsValidator { let mut used_keys: AHashSet<&str> = AHashSet::with_capacity(self.fields.len()); let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone())); + let mut fields_set_count: usize = 0; macro_rules! set_item { ($field:ident, $value:expr) => {{ @@ -175,6 +176,7 @@ impl Validator for DataclassArgsValidator { Ok(Some(value)) => { // Default value exists, and passed validation if required set_item!(field, value); + fields_set_count += 1; } Ok(None) | Err(ValError::Omit) => continue, // Note: this will always use the field name even if there is an alias @@ -214,7 +216,10 @@ impl Validator for DataclassArgsValidator { } // found a positional argument, validate it (Some(pos_value), None) => match field.validator.validate(py, pos_value.borrow_input(), state) { - Ok(value) => set_item!(field, value), + Ok(value) => { + set_item!(field, value); + fields_set_count += 1; + } Err(ValError::LineErrors(line_errors)) => { errors.extend(line_errors.into_iter().map(|err| err.with_outer_location(index))); } @@ -222,7 +227,10 @@ impl Validator for DataclassArgsValidator { }, // found a keyword argument, validate it (None, Some((lookup_path, kw_value))) => match field.validator.validate(py, kw_value, state) { - Ok(value) => set_item!(field, value), + Ok(value) => { + set_item!(field, value); + fields_set_count += 1; + } Err(ValError::LineErrors(line_errors)) => { errors.extend( line_errors @@ -336,6 +344,8 @@ impl Validator for DataclassArgsValidator { } } + state.fields_set_count = (fields_set_count != 0).then_some(fields_set_count); + if errors.is_empty() { if let Some(init_only_args) = init_only_args { Ok((output_dict, PyTuple::new_bound(py, init_only_args)).to_object(py)) diff --git a/src/validators/union.rs b/src/validators/union.rs index f4bbb6cf4..7568d8c20 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -125,25 +125,35 @@ impl UnionValidator { state.fields_set_count = None; let result = choice.validate(py, input, state); match result { - Ok(new_success) => { - // success should always have an exactness - debug_assert_ne!(state.exactness, None); - - let new_exactness = state.exactness.unwrap_or(Exactness::Lax); - let new_fields_set = state.fields_set_count; - - let new_success_is_best_match: bool = - success.as_ref().map_or(true, |(_, cur_exactness, cur_fields_set)| { - match (*cur_fields_set, new_fields_set) { - (Some(cur), Some(new)) if cur != new => cur < new, - _ => *cur_exactness < new_exactness, - } - }); - - if new_success_is_best_match { - success = Some((new_success, new_exactness, new_fields_set)); + Ok(new_success) => match (state.exactness, state.fields_set_count) { + (Some(Exactness::Exact), None) => { + // exact match with no fields set data, return immediately + return { + // exact match, return, restore any previous exactness + state.exactness = old_exactness; + Ok(new_success) + }; } - } + _ => { + // success should always have an exactness + debug_assert_ne!(state.exactness, None); + + let new_exactness = state.exactness.unwrap_or(Exactness::Lax); + let new_fields_set = state.fields_set_count; + + let new_success_is_best_match: bool = + success.as_ref().map_or(true, |(_, cur_exactness, cur_fields_set)| { + match (*cur_fields_set, new_fields_set) { + (Some(cur), Some(new)) if cur != new => cur < new, + _ => *cur_exactness < new_exactness, + } + }); + + if new_success_is_best_match { + success = Some((new_success, new_exactness, new_fields_set)); + } + } + }, Err(ValError::LineErrors(lines)) => { // if we don't yet know this validation will succeed, record the error if success.is_none() { From e4f1e6bbf0e8a3137f4703fde3acb37378ff7712 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Tue, 18 Jun 2024 11:39:08 -0400 Subject: [PATCH 14/26] all dataclass tests passing :) --- src/validators/dataclass.rs | 2 +- src/validators/typed_dict.rs | 2 +- tests/validators/test_union.py | 3 --- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 2554b38d1..855249bf7 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -344,7 +344,7 @@ impl Validator for DataclassArgsValidator { } } - state.fields_set_count = (fields_set_count != 0).then_some(fields_set_count); + state.fields_set_count = Some(fields_set_count); if errors.is_empty() { if let Some(init_only_args) = init_only_args { diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index da351642b..98755512b 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -230,7 +230,7 @@ impl Validator for TypedDictValidator { } } - state.fields_set_count = (fields_set_count != 0).then_some(fields_set_count); + state.fields_set_count = Some(fields_set_count); } if let Some(used_keys) = used_keys { diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 43aff5f84..0f19d9328 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1016,9 +1016,6 @@ class ModelB: assert isinstance(validator.validate_python({'a': 1}), ModelA) assert isinstance(validator.validate_python({'b': 1}), ModelB) - # defaults to leftmost choice if there's a tie - assert isinstance(validator.validate_python({}), ModelA) - def test_td_smart_union_by_fields_set() -> None: class ModelA(TypedDict): From 00ddca26d813970f8d9def1d1b57b4da9addcdff Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Tue, 18 Jun 2024 19:02:40 -0400 Subject: [PATCH 15/26] bubble up fields set in nested models --- src/validators/model.rs | 10 +++++----- src/validators/union.rs | 5 ++++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/validators/model.rs b/src/validators/model.rs index e359669bc..b06718251 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -204,7 +204,7 @@ impl Validator for ModelValidator { for field_name in validated_fields_set { fields_set.add(field_name)?; } - state.fields_set_count = Some(fields_set.len()); + state.fields_set_count = Some(fields_set.len() + state.fields_set_count.unwrap_or(0)); } force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?; @@ -244,11 +244,11 @@ impl ModelValidator { }; force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?; force_setattr(py, self_instance, intern!(py, ROOT_FIELD), &output)?; - state.fields_set_count = Some(fields_set.len()); + state.fields_set_count = Some(fields_set.len() + state.fields_set_count.unwrap_or(0)); } else { let (model_dict, model_extra, fields_set): (Bound, Bound, Bound) = output.extract(py)?; - state.fields_set_count = fields_set.len().ok(); + state.fields_set_count = Some(fields_set.len().unwrap_or(0) + state.fields_set_count.unwrap_or(0)); set_model_attrs(self_instance, &model_dict, &model_extra, &fields_set)?; } self.call_post_init(py, self_instance.clone(), input, state.extra()) @@ -287,11 +287,11 @@ impl ModelValidator { }; force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?; force_setattr(py, &instance, intern!(py, ROOT_FIELD), output)?; - state.fields_set_count = Some(fields_set.len()); + state.fields_set_count = Some(fields_set.len() + state.fields_set_count.unwrap_or(0)); } else { let (model_dict, model_extra, val_fields_set) = output.extract(py)?; let fields_set = existing_fields_set.unwrap_or(&val_fields_set); - state.fields_set_count = fields_set.len().ok(); + state.fields_set_count = Some(fields_set.len().unwrap_or(0) + state.fields_set_count.unwrap_or(0)); set_model_attrs(&instance, &model_dict, &model_extra, fields_set)?; } self.call_post_init(py, instance, input, state.extra()) diff --git a/src/validators/union.rs b/src/validators/union.rs index 7568d8c20..085db0289 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -108,6 +108,8 @@ impl UnionValidator { state: &mut ValidationState<'_, 'py>, ) -> ValResult { let old_exactness = state.exactness; + let old_fields_set_count = state.fields_set_count; + let strict = state.strict_or(self.strict); let mut errors = MaybeErrors::new(self.custom_error.as_ref()); @@ -131,6 +133,7 @@ impl UnionValidator { return { // exact match, return, restore any previous exactness state.exactness = old_exactness; + state.fields_set_count = old_fields_set_count; Ok(new_success) }; } @@ -166,7 +169,7 @@ impl UnionValidator { // restore previous validation state to prepare for any future validations state.exactness = old_exactness; - state.fields_set_count = None; + state.fields_set_count = old_fields_set_count; if let Some((success, exactness, _fields_set)) = success { state.floor_exactness(exactness); From fc21636855fb0bc5c967c3f21eba136640a610d4 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Tue, 18 Jun 2024 19:05:23 -0400 Subject: [PATCH 16/26] add nested counting for dataclasses, typed dicts --- src/validators/dataclass.rs | 2 +- src/validators/typed_dict.rs | 2 +- tests/validators/test_union.py | 9 +++++++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 855249bf7..6455049db 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -344,7 +344,7 @@ impl Validator for DataclassArgsValidator { } } - state.fields_set_count = Some(fields_set_count); + state.fields_set_count = Some(fields_set_count + state.fields_set_count.unwrap_or(0)); if errors.is_empty() { if let Some(init_only_args) = init_only_args { diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 98755512b..eb933eb74 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -230,7 +230,7 @@ impl Validator for TypedDictValidator { } } - state.fields_set_count = Some(fields_set_count); + state.fields_set_count = Some(fields_set_count + state.fields_set_count.unwrap_or(0)); } if let Some(used_keys) = used_keys { diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 0f19d9328..37f705c74 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1046,3 +1046,12 @@ class ModelB(TypedDict): assert validator.validate_python({'x': 1, 'y': '2'}).keys() == ModelB.__required_keys__ assert validator.validate_python({'x': '1', 'y': 2}).keys() == ModelB.__required_keys__ assert validator.validate_python({'x': '1', 'y': '2'}).keys() == ModelB.__required_keys__ + + +def test_smart_union_does_nested_model_field_counting() -> None: ... + + +def test_smart_union_does_nested_dataclass_field_counting() -> None: ... + + +def test_smart_union_does_nested_typed_dict_field_counting() -> None: ... From 72aa7bdebdb36987b5386b96ed0981173b852593 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Tue, 18 Jun 2024 20:22:54 -0400 Subject: [PATCH 17/26] corresponding tests --- tests/validators/test_union.py | 187 ++++++++++++++++++++++++++++++--- 1 file changed, 170 insertions(+), 17 deletions(-) diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 37f705c74..928424d09 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -6,7 +6,6 @@ import pytest from dirty_equals import IsFloat, IsInt -from typing_extensions import TypedDict from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema @@ -1018,13 +1017,6 @@ class ModelB: def test_td_smart_union_by_fields_set() -> None: - class ModelA(TypedDict): - x: int - - class ModelB(TypedDict): - x: int - y: int - td_a_schema = core_schema.typed_dict_schema( fields={'x': core_schema.typed_dict_field(core_schema.int_schema())}, ) @@ -1039,19 +1031,180 @@ class ModelB(TypedDict): for choices in [[td_a_schema, td_b_schema], [td_b_schema, td_a_schema]]: validator = SchemaValidator(core_schema.union_schema(choices=choices)) - assert validator.validate_python({'x': 1}).keys() == ModelA.__required_keys__ - assert validator.validate_python({'x': '1'}).keys() == ModelA.__required_keys__ + assert set(validator.validate_python({'x': 1}).keys()) == {'x'} + assert set(validator.validate_python({'x': '1'}).keys()) == {'x'} + + assert set(validator.validate_python({'x': 1, 'y': 2}).keys()) == {'x', 'y'} + assert set(validator.validate_python({'x': 1, 'y': '2'}).keys()) == {'x', 'y'} + assert set(validator.validate_python({'x': '1', 'y': 2}).keys()) == {'x', 'y'} + assert set(validator.validate_python({'x': '1', 'y': '2'}).keys()) == {'x', 'y'} + + +def test_smart_union_does_nested_model_field_counting() -> None: + class SubModelA: + x: int = 1 + + class SubModelB: + y: int = 2 + + class ModelA: + sub: SubModelA + + class ModelB: + sub: SubModelB + + model_a_schema = core_schema.model_schema( + ModelA, + core_schema.model_fields_schema( + fields={ + 'sub': core_schema.model_field( + core_schema.model_schema( + SubModelA, + core_schema.model_fields_schema( + fields={ + 'x': core_schema.model_field( + core_schema.with_default_schema(core_schema.int_schema(), default=1) + ) + } + ), + ) + ) + } + ), + ) + + model_b_schema = core_schema.model_schema( + ModelB, + core_schema.model_fields_schema( + fields={ + 'sub': core_schema.model_field( + core_schema.model_schema( + SubModelB, + core_schema.model_fields_schema( + fields={ + 'y': core_schema.model_field( + core_schema.with_default_schema(core_schema.int_schema(), default=2) + ) + } + ), + ) + ) + } + ), + ) + + for choices in [[model_a_schema, model_b_schema], [model_b_schema, model_a_schema]]: + validator = SchemaValidator(core_schema.union_schema(choices=choices)) + + assert isinstance(validator.validate_python({'sub': {'x': 1}}), ModelA) + assert isinstance(validator.validate_python({'sub': {'y': 3}}), ModelB) + + # defaults to leftmost choice if there's a tie + assert isinstance(validator.validate_python({'sub': {}}), choices[0]['cls']) + + +def test_smart_union_does_nested_dataclass_field_counting() -> None: + @dataclass + class SubModelA: + x: int = 1 + + @dataclass + class SubModelB: + y: int = 2 + + @dataclass + class ModelA: + sub: SubModelA + + @dataclass + class ModelB: + sub: SubModelB - assert validator.validate_python({'x': 1, 'y': 2}).keys() == ModelB.__required_keys__ - assert validator.validate_python({'x': 1, 'y': '2'}).keys() == ModelB.__required_keys__ - assert validator.validate_python({'x': '1', 'y': 2}).keys() == ModelB.__required_keys__ - assert validator.validate_python({'x': '1', 'y': '2'}).keys() == ModelB.__required_keys__ + dc_a_schema = core_schema.dataclass_schema( + ModelA, + core_schema.dataclass_args_schema( + 'ModelA', + [ + core_schema.dataclass_field( + 'sub', + core_schema.with_default_schema( + core_schema.dataclass_schema( + SubModelA, + core_schema.dataclass_args_schema( + 'SubModelA', + [ + core_schema.dataclass_field( + 'x', core_schema.with_default_schema(core_schema.int_schema(), default=1) + ) + ], + ), + ['x'], + ), + default=SubModelA(), + ), + ) + ], + ), + ['sub'], + ) + dc_b_schema = core_schema.dataclass_schema( + ModelB, + core_schema.dataclass_args_schema( + 'ModelB', + [ + core_schema.dataclass_field( + 'sub', + core_schema.with_default_schema( + core_schema.dataclass_schema( + SubModelB, + core_schema.dataclass_args_schema( + 'SubModelB', + [ + core_schema.dataclass_field( + 'y', core_schema.with_default_schema(core_schema.int_schema(), default=2) + ) + ], + ), + ['y'], + ), + default=SubModelB(), + ), + ) + ], + ), + ['sub'], + ) + + for choices in [[dc_a_schema, dc_b_schema], [dc_b_schema, dc_a_schema]]: + validator = SchemaValidator(core_schema.union_schema(choices=choices)) -def test_smart_union_does_nested_model_field_counting() -> None: ... + assert isinstance(validator.validate_python({'sub': {'x': 1}}), ModelA) + assert isinstance(validator.validate_python({'sub': {'y': 3}}), ModelB) + # defaults to leftmost choice if there's a tie + assert isinstance(validator.validate_python({'sub': {}}), choices[0]['cls']) -def test_smart_union_does_nested_dataclass_field_counting() -> None: ... +def test_smart_union_does_nested_typed_dict_field_counting() -> None: + td_a_schema = core_schema.typed_dict_schema( + fields={ + 'sub': core_schema.typed_dict_field( + core_schema.typed_dict_schema(fields={'x': core_schema.typed_dict_field(core_schema.int_schema())}) + ) + } + ) + + td_b_schema = core_schema.typed_dict_schema( + fields={ + 'sub': core_schema.typed_dict_field( + core_schema.typed_dict_schema(fields={'y': core_schema.typed_dict_field(core_schema.int_schema())}) + ) + } + ) + + for choices in [[td_a_schema, td_b_schema], [td_b_schema, td_a_schema]]: + validator = SchemaValidator(core_schema.union_schema(choices=choices)) -def test_smart_union_does_nested_typed_dict_field_counting() -> None: ... + assert set(validator.validate_python({'sub': {'x': 1}})['sub'].keys()) == {'x'} + assert set(validator.validate_python({'sub': {'y': 2}})['sub'].keys()) == {'y'} From d2ef40032c169bf859153b3a08537ea6f3ae578f Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 19 Jun 2024 07:27:40 -0400 Subject: [PATCH 18/26] updating fields set at the end) --- src/validators/union.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/validators/union.rs b/src/validators/union.rs index 085db0289..5da73807e 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -171,8 +171,13 @@ impl UnionValidator { state.exactness = old_exactness; state.fields_set_count = old_fields_set_count; - if let Some((success, exactness, _fields_set)) = success { + if let Some((success, exactness, fields_set_count)) = success { state.floor_exactness(exactness); + state.fields_set_count = match (state.fields_set_count, fields_set_count) { + (Some(cur), Some(new)) => Some(cur + new), + (None, Some(new)) => Some(new), + _ => state.fields_set_count, + }; return Ok(success); } From 1585e96708c1ffc048d1971f1b714392c7e8cbdc Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 19 Jun 2024 08:09:38 -0400 Subject: [PATCH 19/26] comments and best practice with state updates --- src/validators/dataclass.rs | 2 +- src/validators/model.rs | 10 +++++----- src/validators/typed_dict.rs | 2 +- src/validators/union.rs | 8 ++++++++ src/validators/validation_state.rs | 4 ++++ tests/validators/test_union.py | 3 +++ 6 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 6455049db..72d738440 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -344,7 +344,7 @@ impl Validator for DataclassArgsValidator { } } - state.fields_set_count = Some(fields_set_count + state.fields_set_count.unwrap_or(0)); + state.add_fields_set(fields_set_count); if errors.is_empty() { if let Some(init_only_args) = init_only_args { diff --git a/src/validators/model.rs b/src/validators/model.rs index b06718251..fc4ca75d3 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -204,7 +204,7 @@ impl Validator for ModelValidator { for field_name in validated_fields_set { fields_set.add(field_name)?; } - state.fields_set_count = Some(fields_set.len() + state.fields_set_count.unwrap_or(0)); + state.add_fields_set(fields_set.len()); } force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?; @@ -244,11 +244,11 @@ impl ModelValidator { }; force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?; force_setattr(py, self_instance, intern!(py, ROOT_FIELD), &output)?; - state.fields_set_count = Some(fields_set.len() + state.fields_set_count.unwrap_or(0)); + state.add_fields_set(fields_set.len()); } else { let (model_dict, model_extra, fields_set): (Bound, Bound, Bound) = output.extract(py)?; - state.fields_set_count = Some(fields_set.len().unwrap_or(0) + state.fields_set_count.unwrap_or(0)); + state.add_fields_set(fields_set.len().unwrap_or(0)); set_model_attrs(self_instance, &model_dict, &model_extra, &fields_set)?; } self.call_post_init(py, self_instance.clone(), input, state.extra()) @@ -287,11 +287,11 @@ impl ModelValidator { }; force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?; force_setattr(py, &instance, intern!(py, ROOT_FIELD), output)?; - state.fields_set_count = Some(fields_set.len() + state.fields_set_count.unwrap_or(0)); + state.add_fields_set(fields_set.len()); } else { let (model_dict, model_extra, val_fields_set) = output.extract(py)?; let fields_set = existing_fields_set.unwrap_or(&val_fields_set); - state.fields_set_count = Some(fields_set.len().unwrap_or(0) + state.fields_set_count.unwrap_or(0)); + state.add_fields_set(fields_set.len().unwrap_or(0)); set_model_attrs(&instance, &model_dict, &model_extra, fields_set)?; } self.call_post_init(py, instance, input, state.extra()) diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index eb933eb74..f72d969a8 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -230,7 +230,7 @@ impl Validator for TypedDictValidator { } } - state.fields_set_count = Some(fields_set_count + state.fields_set_count.unwrap_or(0)); + state.add_fields_set(fields_set_count); } if let Some(used_keys) = used_keys { diff --git a/src/validators/union.rs b/src/validators/union.rs index 5da73807e..bda00fac5 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -144,6 +144,12 @@ impl UnionValidator { let new_exactness = state.exactness.unwrap_or(Exactness::Lax); let new_fields_set = state.fields_set_count; + // we use both the exactness and the fields_set_count to determine the best union member match + // if fields_set_count is available for the current best match and the new candidate, we use this + // as the primary metric. If the new fields_set_count is greater, the new candidate is better. + // if the fields_set_count is the same, we use the exactness as a tie breaker to determine the best match. + // if the fields_set_count is not available for either the current best match or the new candidate, + // we use the exactness to determine the best match. let new_success_is_best_match: bool = success.as_ref().map_or(true, |(_, cur_exactness, cur_fields_set)| { match (*cur_fields_set, new_fields_set) { @@ -176,6 +182,8 @@ impl UnionValidator { state.fields_set_count = match (state.fields_set_count, fields_set_count) { (Some(cur), Some(new)) => Some(cur + new), (None, Some(new)) => Some(new), + // if both are None, or if new is None, we keep state.fields_set_count = None + // because for some validators, fields_set_count is not relevant _ => state.fields_set_count, }; return Ok(success); diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index c7db59c22..92edfbbe9 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -70,6 +70,10 @@ impl<'a, 'py> ValidationState<'a, 'py> { } } + pub fn add_fields_set(&mut self, fields_set_count: usize) { + *self.fields_set_count.get_or_insert(0) += fields_set_count; + } + pub fn cache_str(&self) -> StringCacheMode { self.extra.cache_str } diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 928424d09..98fe1cc98 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1208,3 +1208,6 @@ def test_smart_union_does_nested_typed_dict_field_counting() -> None: assert set(validator.validate_python({'sub': {'x': 1}})['sub'].keys()) == {'x'} assert set(validator.validate_python({'sub': {'y': 2}})['sub'].keys()) == {'y'} + + +def test_nested_unions_bubble_up_field_count() -> None: ... From d3f88b7f94e02c8f39bed47ee379903e286e1bf9 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 19 Jun 2024 08:11:07 -0400 Subject: [PATCH 20/26] consistency w var names --- src/validators/union.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/validators/union.rs b/src/validators/union.rs index bda00fac5..72e88674a 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -142,7 +142,7 @@ impl UnionValidator { debug_assert_ne!(state.exactness, None); let new_exactness = state.exactness.unwrap_or(Exactness::Lax); - let new_fields_set = state.fields_set_count; + let new_fields_set_count = state.fields_set_count; // we use both the exactness and the fields_set_count to determine the best union member match // if fields_set_count is available for the current best match and the new candidate, we use this @@ -151,15 +151,17 @@ impl UnionValidator { // if the fields_set_count is not available for either the current best match or the new candidate, // we use the exactness to determine the best match. let new_success_is_best_match: bool = - success.as_ref().map_or(true, |(_, cur_exactness, cur_fields_set)| { - match (*cur_fields_set, new_fields_set) { - (Some(cur), Some(new)) if cur != new => cur < new, - _ => *cur_exactness < new_exactness, - } - }); + success + .as_ref() + .map_or(true, |(_, cur_exactness, cur_fields_set_count)| { + match (*cur_fields_set_count, new_fields_set_count) { + (Some(cur), Some(new)) if cur != new => cur < new, + _ => *cur_exactness < new_exactness, + } + }); if new_success_is_best_match { - success = Some((new_success, new_exactness, new_fields_set)); + success = Some((new_success, new_exactness, new_fields_set_count)); } } }, From 38f911ff624c58b5549ea25240e1fe47a78ecc82 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 19 Jun 2024 08:16:07 -0400 Subject: [PATCH 21/26] abbreviated syntax --- src/validators/union.rs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/validators/union.rs b/src/validators/union.rs index 72e88674a..405a19c02 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -181,13 +181,9 @@ impl UnionValidator { if let Some((success, exactness, fields_set_count)) = success { state.floor_exactness(exactness); - state.fields_set_count = match (state.fields_set_count, fields_set_count) { - (Some(cur), Some(new)) => Some(cur + new), - (None, Some(new)) => Some(new), - // if both are None, or if new is None, we keep state.fields_set_count = None - // because for some validators, fields_set_count is not relevant - _ => state.fields_set_count, - }; + if let Some(count) = fields_set_count { + state.add_fields_set(count); + } return Ok(success); } From c3b43e9965cf423a1c3a71546d4e20b057a916a5 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 19 Jun 2024 09:25:18 -0400 Subject: [PATCH 22/26] adding a bubble up test + doing some minor test refactoring --- tests/validators/test_union.py | 255 +++++++++++++++++++++------------ 1 file changed, 162 insertions(+), 93 deletions(-) diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 98fe1cc98..2baab9ade 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from datetime import date, time from enum import Enum, IntEnum +from itertools import permutations from typing import Any, Optional, Union from uuid import UUID @@ -807,7 +808,11 @@ class ModelA: assert validator.validate_python(True) is True -def test_union_with_subclass() -> None: +def permute_choices(choices: list[core_schema.CoreSchema]) -> list[list[core_schema.CoreSchema]]: + return [list(p) for p in permutations(choices)] + + +class TestSmartUnionWithSubclass: class ModelA: a: int @@ -827,112 +832,74 @@ class ModelB(ModelA): ), ) - for choices in [[model_a_schema, model_b_schema], [model_b_schema, model_a_schema]]: + @pytest.mark.parametrize('choices', permute_choices([model_a_schema, model_b_schema])) + def test_more_specific_data_matches_subclass(self, choices) -> None: validator = SchemaValidator(schema=core_schema.union_schema(choices)) - assert isinstance(validator.validate_python({'a': 1}), ModelA) - assert isinstance(validator.validate_python({'a': 1, 'b': 2}), ModelB) + assert isinstance(validator.validate_python({'a': 1}), self.ModelA) + assert isinstance(validator.validate_python({'a': 1, 'b': 2}), self.ModelB) + + assert isinstance(validator.validate_python({'a': 1, 'b': 2}), self.ModelB) # confirm that a model that matches in lax mode with 2 fields # is preferred over a model that matches in strict mode with 1 field - assert isinstance(validator.validate_python({'a': '1', 'b': '2'}), ModelB) - assert isinstance(validator.validate_python({'a': '1', 'b': 2}), ModelB) - assert isinstance(validator.validate_python({'a': 1, 'b': '2'}), ModelB) - assert isinstance(validator.validate_python({'a': 1, 'b': 2}), ModelB) + assert isinstance(validator.validate_python({'a': '1', 'b': '2'}), self.ModelB) + assert isinstance(validator.validate_python({'a': '1', 'b': 2}), self.ModelB) + assert isinstance(validator.validate_python({'a': 1, 'b': '2'}), self.ModelB) -def test_union_with_default() -> None: +class TestSmartUnionWithDefaults: class ModelA: a: int = 0 class ModelB: b: int = 0 - val = SchemaValidator( - core_schema.union_schema( - [ - core_schema.model_schema( - ModelA, - core_schema.model_fields_schema( - fields={ - 'a': core_schema.model_field( - core_schema.with_default_schema(core_schema.int_schema(), default=0) - ) - } - ), - ), - core_schema.model_schema( - ModelB, - core_schema.model_fields_schema( - fields={ - 'b': core_schema.model_field( - core_schema.with_default_schema(core_schema.int_schema(), default=0) - ) - } - ), - ), - ] + model_a_schema = core_schema.model_schema( + ModelA, + core_schema.model_fields_schema( + fields={'a': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0))} + ), + ) + model_b_schema = core_schema.model_schema( + ModelB, + core_schema.model_fields_schema( + fields={'b': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0))} ), ) - assert isinstance(val.validate_python({'a': 1}), ModelA) - assert isinstance(val.validate_python({'b': 1}), ModelB) - - # defaults to leftmost choice if there's a tie - assert isinstance(val.validate_python({}), ModelA) - + @pytest.mark.parametrize('choices', permute_choices([model_a_schema, model_b_schema])) + def test_fields_set_ensures_best_match(self, choices) -> None: + validator = SchemaValidator(schema=core_schema.union_schema(choices)) + assert isinstance(validator.validate_python({'a': 1}), self.ModelA) + assert isinstance(validator.validate_python({'b': 1}), self.ModelB) -def test_optional_union_with_members_having_defaults() -> None: - class ModelA: - a: int = 0 + # defaults to leftmost choice if there's a tie + assert isinstance(validator.validate_python({}), choices[0]['cls']) - class ModelB: - b: int = 0 + @pytest.mark.parametrize('choices', permute_choices([model_a_schema, model_b_schema])) + def test_optional_union_with_members_having_defaults(self, choices) -> None: + class WrapModel: + val: Optional[Union[self.ModelA, self.ModelB]] = None - class WrapModel: - val: Optional[Union[ModelA, ModelB]] = None - - val = SchemaValidator( - core_schema.model_schema( - WrapModel, - core_schema.model_fields_schema( - fields={ - 'val': core_schema.model_field( - core_schema.with_default_schema( - core_schema.union_schema( - [ - core_schema.model_schema( - ModelA, - core_schema.model_fields_schema( - fields={ - 'a': core_schema.model_field( - core_schema.with_default_schema(core_schema.int_schema(), default=0) - ) - } - ), - ), - core_schema.model_schema( - ModelB, - core_schema.model_fields_schema( - fields={ - 'b': core_schema.model_field( - core_schema.with_default_schema(core_schema.int_schema(), default=0) - ) - } - ), - ), - ] - ), - default=None, + val = SchemaValidator( + core_schema.model_schema( + WrapModel, + core_schema.model_fields_schema( + fields={ + 'val': core_schema.model_field( + core_schema.with_default_schema( + core_schema.union_schema(choices), + default=None, + ) ) - ) - } - ), + } + ), + ) ) - ) - assert isinstance(val.validate_python({'val': {'a': 1}}).val, ModelA) - assert isinstance(val.validate_python({'val': {'b': 1}}).val, ModelB) - assert val.validate_python({}).val is None + assert isinstance(val.validate_python({'val': {'a': 1}}).val, self.ModelA) + assert isinstance(val.validate_python({'val': {'b': 1}}).val, self.ModelB) + assert val.validate_python({}).val is None def test_dc_smart_union_by_fields_set() -> None: @@ -962,7 +929,7 @@ class ModelB(ModelA): ['x', 'y'], ) - for choices in [[dc_a_schema, dc_b_schema], [dc_b_schema, dc_a_schema]]: + for choices in permute_choices([dc_a_schema, dc_b_schema]): validator = SchemaValidator(core_schema.union_schema(choices=choices)) assert isinstance(validator.validate_python({'x': 1}), ModelA) @@ -1009,7 +976,7 @@ class ModelB: ['b'], ) - for choices in [[dc_a_schema, dc_b_schema], [dc_b_schema, dc_a_schema]]: + for choices in permute_choices([dc_a_schema, dc_b_schema]): validator = SchemaValidator(core_schema.union_schema(choices=choices)) assert isinstance(validator.validate_python({'a': 1}), ModelA) @@ -1028,7 +995,7 @@ def test_td_smart_union_by_fields_set() -> None: }, ) - for choices in [[td_a_schema, td_b_schema], [td_b_schema, td_a_schema]]: + for choices in permute_choices([td_a_schema, td_b_schema]): validator = SchemaValidator(core_schema.union_schema(choices=choices)) assert set(validator.validate_python({'x': 1}).keys()) == {'x'} @@ -1093,7 +1060,7 @@ class ModelB: ), ) - for choices in [[model_a_schema, model_b_schema], [model_b_schema, model_a_schema]]: + for choices in permute_choices([model_a_schema, model_b_schema]): validator = SchemaValidator(core_schema.union_schema(choices=choices)) assert isinstance(validator.validate_python({'sub': {'x': 1}}), ModelA) @@ -1176,7 +1143,7 @@ class ModelB: ['sub'], ) - for choices in [[dc_a_schema, dc_b_schema], [dc_b_schema, dc_a_schema]]: + for choices in permute_choices([dc_a_schema, dc_b_schema]): validator = SchemaValidator(core_schema.union_schema(choices=choices)) assert isinstance(validator.validate_python({'sub': {'x': 1}}), ModelA) @@ -1203,11 +1170,113 @@ def test_smart_union_does_nested_typed_dict_field_counting() -> None: } ) - for choices in [[td_a_schema, td_b_schema], [td_b_schema, td_a_schema]]: + for choices in permute_choices([td_a_schema, td_b_schema]): validator = SchemaValidator(core_schema.union_schema(choices=choices)) assert set(validator.validate_python({'sub': {'x': 1}})['sub'].keys()) == {'x'} assert set(validator.validate_python({'sub': {'y': 2}})['sub'].keys()) == {'y'} -def test_nested_unions_bubble_up_field_count() -> None: ... +def test_nested_unions_bubble_up_field_count() -> None: + class SubModelX: + x1: int = 0 + x2: int = 0 + x3: int = 0 + + class SubModelY: + x1: int = 0 + x2: int = 0 + x3: int = 0 + + class SubModelZ: + z1: int = 0 + z2: int = 0 + z3: int = 0 + + class SubModelW: + w1: int = 0 + w2: int = 0 + w3: int = 0 + + class ModelA: + a: SubModelX | SubModelY + + class ModelB: + b: SubModelZ | SubModelW + + model_x_schema = core_schema.model_schema( + SubModelX, + core_schema.model_fields_schema( + fields={ + 'x1': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + 'x2': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + 'x3': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + } + ), + ) + + model_y_schema = core_schema.model_schema( + SubModelY, + core_schema.model_fields_schema( + fields={ + 'x1': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + 'x2': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + 'x3': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + } + ), + ) + + model_z_schema = core_schema.model_schema( + SubModelZ, + core_schema.model_fields_schema( + fields={ + 'z1': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + 'z2': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + 'z3': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + } + ), + ) + + model_w_schema = core_schema.model_schema( + SubModelW, + core_schema.model_fields_schema( + fields={ + 'w1': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + 'w2': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + 'w3': core_schema.model_field(core_schema.with_default_schema(core_schema.int_schema(), default=0)), + } + ), + ) + + model_a_schema_options = [ + core_schema.union_schema([model_x_schema, model_y_schema]), + core_schema.union_schema([model_y_schema, model_x_schema]), + ] + + model_b_schema_options = [ + core_schema.union_schema([model_z_schema, model_w_schema]), + core_schema.union_schema([model_w_schema, model_z_schema]), + ] + + for model_a_schema in model_a_schema_options: + for model_b_schema in model_b_schema_options: + validator = SchemaValidator( + core_schema.union_schema( + [ + core_schema.model_schema( + ModelA, + core_schema.model_fields_schema(fields={'a': core_schema.model_field(model_a_schema)}), + ), + core_schema.model_schema( + ModelB, + core_schema.model_fields_schema(fields={'b': core_schema.model_field(model_b_schema)}), + ), + ] + ) + ) + + result = validator.validate_python( + {'a': {'x1': 1, 'x2': 2, 'y1': 1, 'y2': 2}, 'b': {'w1': 1, 'w2': 2, 'w3': 3}} + ) + assert isinstance(result, ModelB) + assert isinstance(result.b, SubModelW) From 31a439a01376cc3b1071f5a76d39c7907a16e646 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 19 Jun 2024 09:32:49 -0400 Subject: [PATCH 23/26] 3.8 fixes --- tests/validators/test_union.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 2baab9ade..5099506bd 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -2,7 +2,7 @@ from datetime import date, time from enum import Enum, IntEnum from itertools import permutations -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union from uuid import UUID import pytest @@ -808,7 +808,7 @@ class ModelA: assert validator.validate_python(True) is True -def permute_choices(choices: list[core_schema.CoreSchema]) -> list[list[core_schema.CoreSchema]]: +def permute_choices(choices: List[core_schema.CoreSchema]) -> List[list[core_schema.CoreSchema]]: return [list(p) for p in permutations(choices)] From ed920ad3496abe42e5403d71ae33efabb93c0c87 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 19 Jun 2024 10:12:59 -0400 Subject: [PATCH 24/26] oops, another list --- tests/validators/test_union.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index 5099506bd..e0919e2ea 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -808,7 +808,7 @@ class ModelA: assert validator.validate_python(True) is True -def permute_choices(choices: List[core_schema.CoreSchema]) -> List[list[core_schema.CoreSchema]]: +def permute_choices(choices: List[core_schema.CoreSchema]) -> List[List[core_schema.CoreSchema]]: return [list(p) for p in permutations(choices)] From 645f9176e5566641ef4ad66d6b0af1edaa36f8aa Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 19 Jun 2024 10:18:14 -0400 Subject: [PATCH 25/26] ugh, last 3.8 fix hopefully --- tests/validators/test_union.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index e0919e2ea..f2d10f36b 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -1199,10 +1199,10 @@ class SubModelW: w3: int = 0 class ModelA: - a: SubModelX | SubModelY + a: Union[SubModelX, SubModelY] class ModelB: - b: SubModelZ | SubModelW + b: Union[SubModelZ, SubModelW] model_x_schema = core_schema.model_schema( SubModelX, From ed8ddb994439e28725fd92c5ab41893c55ea8b04 Mon Sep 17 00:00:00 2001 From: sydney-runkle Date: Wed, 19 Jun 2024 11:34:30 -0400 Subject: [PATCH 26/26] name change success -> best_match --- src/validators/union.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/validators/union.rs b/src/validators/union.rs index 405a19c02..dc46bf52f 100644 --- a/src/validators/union.rs +++ b/src/validators/union.rs @@ -113,9 +113,7 @@ impl UnionValidator { let strict = state.strict_or(self.strict); let mut errors = MaybeErrors::new(self.custom_error.as_ref()); - // we use this to track the validation against the most compatible union member - // up to the current point - let mut success: Option<(Py, Exactness, Option)> = None; + let mut best_match: Option<(Py, Exactness, Option)> = None; for (choice, label) in &self.choices { let state = &mut state.rebind_extra(|extra| { @@ -151,7 +149,7 @@ impl UnionValidator { // if the fields_set_count is not available for either the current best match or the new candidate, // we use the exactness to determine the best match. let new_success_is_best_match: bool = - success + best_match .as_ref() .map_or(true, |(_, cur_exactness, cur_fields_set_count)| { match (*cur_fields_set_count, new_fields_set_count) { @@ -161,13 +159,13 @@ impl UnionValidator { }); if new_success_is_best_match { - success = Some((new_success, new_exactness, new_fields_set_count)); + best_match = Some((new_success, new_exactness, new_fields_set_count)); } } }, Err(ValError::LineErrors(lines)) => { // if we don't yet know this validation will succeed, record the error - if success.is_none() { + if best_match.is_none() { errors.push(choice, label.as_deref(), lines); } } @@ -179,12 +177,12 @@ impl UnionValidator { state.exactness = old_exactness; state.fields_set_count = old_fields_set_count; - if let Some((success, exactness, fields_set_count)) = success { + if let Some((best_match, exactness, fields_set_count)) = best_match { state.floor_exactness(exactness); if let Some(count) = fields_set_count { state.add_fields_set(count); } - return Ok(success); + return Ok(best_match); } // no matches, build errors