Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements in union matching logic during validation #1332

Merged
merged 26 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
9117c6b
first pass at union match improvement
sydney-runkle Jun 16, 2024
fcc84ec
don't need to add explicit numbers
sydney-runkle Jun 16, 2024
0255211
figuring out fields set
sydney-runkle Jun 16, 2024
ceb5e1b
behavior for checking model fields set
sydney-runkle Jun 17, 2024
1a6b00c
working on my rust best practices
sydney-runkle Jun 17, 2024
1718e8b
using method on Validator trait to streamline things
sydney-runkle Jun 17, 2024
500a812
comment update
sydney-runkle Jun 17, 2024
6645ac2
new round of tests
sydney-runkle Jun 17, 2024
d479928
get all tests passing
sydney-runkle Jun 17, 2024
e16103a
use Union not pipe in tests
sydney-runkle Jun 17, 2024
9208282
abandon num_fields for a state based approach
sydney-runkle Jun 18, 2024
69f6ec0
get typed dicts working, add tests for other model like cases
sydney-runkle Jun 18, 2024
18b70d4
dataclass support + more efficient exact return
sydney-runkle Jun 18, 2024
e4f1e6b
all dataclass tests passing :)
sydney-runkle Jun 18, 2024
00ddca2
bubble up fields set in nested models
sydney-runkle Jun 18, 2024
fc21636
add nested counting for dataclasses, typed dicts
sydney-runkle Jun 18, 2024
72aa7bd
corresponding tests
sydney-runkle Jun 19, 2024
d2ef400
updating fields set at the end
sydney-runkle Jun 19, 2024
1585e96
comments and best practice with state updates
sydney-runkle Jun 19, 2024
d3f88b7
consistency w var names
sydney-runkle Jun 19, 2024
38f911f
abbreviated syntax
sydney-runkle Jun 19, 2024
c3b43e9
adding a bubble up test + doing some minor test refactoring
sydney-runkle Jun 19, 2024
31a439a
3.8 fixes
sydney-runkle Jun 19, 2024
ed920ad
oops, another list
sydney-runkle Jun 19, 2024
645f917
ugh, last 3.8 fix hopefully
sydney-runkle Jun 19, 2024
ed8ddb9
name change success -> best_match
sydney-runkle Jun 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/validators/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ impl Validator for ModelValidator {
for field_name in validated_fields_set {
fields_set.add(field_name)?;
}
state.fields_set_count = Some(fields_set.len());
}

force_setattr(py, model, intern!(py, DUNDER_DICT), validated_dict.to_object(py))?;
Expand Down Expand Up @@ -241,10 +242,13 @@ impl ModelValidator {
} else {
PySet::new_bound(py, [&String::from(ROOT_FIELD)])?
};
force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), fields_set)?;
force_setattr(py, self_instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
force_setattr(py, self_instance, intern!(py, ROOT_FIELD), &output)?;
state.fields_set_count = Some(fields_set.len());
} else {
let (model_dict, model_extra, fields_set) = output.extract(py)?;
let (model_dict, model_extra, fields_set): (Bound<PyAny>, Bound<PyAny>, Bound<PyAny>) =
output.extract(py)?;
state.fields_set_count = fields_set.len().ok();
set_model_attrs(self_instance, &model_dict, &model_extra, &fields_set)?;
}
self.call_post_init(py, self_instance.clone(), input, state.extra())
Expand Down Expand Up @@ -281,11 +285,13 @@ impl ModelValidator {
} else {
PySet::new_bound(py, [&String::from(ROOT_FIELD)])?
};
force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), fields_set)?;
force_setattr(py, &instance, intern!(py, DUNDER_FIELDS_SET_KEY), &fields_set)?;
force_setattr(py, &instance, intern!(py, ROOT_FIELD), output)?;
state.fields_set_count = Some(fields_set.len());
} else {
let (model_dict, model_extra, val_fields_set) = output.extract(py)?;
let fields_set = existing_fields_set.unwrap_or(&val_fields_set);
state.fields_set_count = fields_set.len().ok();
set_model_attrs(&instance, &model_dict, &model_extra, fields_set)?;
}
self.call_post_init(py, instance, input, state.extra())
Expand Down
4 changes: 4 additions & 0 deletions src/validators/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ impl Validator for TypedDictValidator {

{
let state = &mut state.rebind_extra(|extra| extra.data = Some(output_dict.clone()));
let mut fields_set_count: usize = 0;

for field in &self.fields {
let op_key_value = match dict.get_item(&field.lookup_key) {
Expand All @@ -186,6 +187,7 @@ impl Validator for TypedDictValidator {
match field.validator.validate(py, value.borrow_input(), state) {
Ok(value) => {
output_dict.set_item(&field.name_py, value)?;
fields_set_count += 1;
}
Err(ValError::Omit) => continue,
Err(ValError::LineErrors(line_errors)) => {
Expand Down Expand Up @@ -227,6 +229,8 @@ impl Validator for TypedDictValidator {
Err(err) => return Err(err),
}
}

state.fields_set_count = (fields_set_count != 0).then_some(fields_set_count);
}

if let Some(used_keys) = used_keys {
Expand Down
28 changes: 17 additions & 11 deletions src/validators/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ impl UnionValidator {
let strict = state.strict_or(self.strict);
let mut errors = MaybeErrors::new(self.custom_error.as_ref());

let mut success = None;
// we use this to track the validation against the most compatible union member
// up to the current point
let mut success: Option<(Py<PyAny>, Exactness, Option<usize>)> = None;

for (choice, label) in &self.choices {
let state = &mut state.rebind_extra(|extra| {
Expand All @@ -134,16 +136,20 @@ impl UnionValidator {
_ => {
// success should always have an exactness
debug_assert_ne!(state.exactness, None);

let new_exactness = state.exactness.unwrap_or(Exactness::Lax);
// if the new result has higher exactness than the current success, replace it
if success
.as_ref()
.map_or(true, |(_, current_exactness)| *current_exactness < new_exactness)
{
// TODO: is there a possible optimization here, where once there has
// been one success, we turn on strict mode, to avoid unnecessary
// coercions for further validation?
success = Some((new_success, new_exactness));
let new_fields_set = state.fields_set_count;
Copy link
Contributor

@davidhewitt davidhewitt Jun 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we probably need to reset fields_set_count between each loop turn? (e.g. like exactness is reset on line 124).

Propagating on the state like this potentially has surprising effects like list[ModelA] | list[ModelB] probably now gets chosen based on whether the last entry fitted ModelA or ModelB better. We could deal with that by explicitly clearing num_fields_set in compound validators like list validators.

Overall I'm fine with this approach of using state, I think either this or num_fields (and exactness) all are fiddly to implement and will likely just need lots of hardening via test cases.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Behavior looks good for the list case you've mentioned above:

from pydantic import BaseModel, TypeAdapter

class ModelA(BaseModel):
    a: int

class ModelB(ModelA):
    b: int

ta = TypeAdapter(list[ModelA] | list[ModelB])

print(repr(ta.validate_python([{'a': 1}, {'a': 2}])))
#> [ModelA(a=1), ModelA(a=2)]

print(repr(ta.validate_python([{'a': 1, 'b': 2}, {'a': 2, 'b': 3}])))
#> [ModelB(a=1, b=2), ModelB(a=2, b=3)]

print(repr(ta.validate_python([{'a': 1, 'b': 2}, {'a': 2}])))
#> [ModelA(a=1), ModelA(a=2)]

print(repr(ta.validate_python([{'a': 1}, {'a': 2, 'b': 3}])))
#> [ModelA(a=1), ModelA(a=2)]

I think this is because I'm now clearing the relevant parts of the state, and fields_set_count is by default None, in which case we only consider exactness.


let new_success_is_best_match: bool =
success.as_ref().map_or(true, |(_, cur_exactness, cur_fields_set)| {
match (*cur_fields_set, new_fields_set) {
(Some(cur), Some(new)) if cur != new => cur < new,
_ => *cur_exactness < new_exactness,
}
});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment here describing the logic? I read it as follows:

  • We prefer number of fields to exactness
  • In case where number of fields is tied, or one side doesn't have number of fields, we use exactness

What about ModelA | dict[str, int]? If we pass {'a': 2, 'b': 3} as input, then according to the above the dict validator doesn't currently have number of fields, and will win based on exactness.

I think instead if we passed {'a': '2', 'b': '3'} with string values, the exactness is the same (lax) and so ModelA probably wins as the leftmost member.

Maybe dict validators should also count fields set?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added the comment :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct re current behavior:

from pydantic import BaseModel, TypeAdapter

class DictModel(BaseModel):
    a: int
    b: int

ta = TypeAdapter(DictModel | dict[str, int])

print(repr(ta.validate_python({'a': 1, 'b': 2})))
#> {'a': 1, 'b': 2}

print(repr(ta.validate_python({'a': '1', 'b': '2'})))
#> DictModel(a=1, b=2)

print(repr(ta.validate_python(DictModel(a=1, b=2))))
#> DictModel(a=1, b=2)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, I think this might be a case where we suggest that users use left-to-right or a discriminator if they really want to customize behavior.


if new_success_is_best_match {
success = Some((new_success, new_exactness, new_fields_set));
}
}
},
Expand All @@ -158,7 +164,7 @@ impl UnionValidator {
}
state.exactness = old_exactness;

if let Some((success, exactness)) = success {
if let Some((success, exactness, _fields_set)) = success {
state.floor_exactness(exactness);
return Ok(success);
}
Expand Down
2 changes: 2 additions & 0 deletions src/validators/validation_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub enum Exactness {
pub struct ValidationState<'a, 'py> {
pub recursion_guard: &'a mut RecursionState,
pub exactness: Option<Exactness>,
pub fields_set_count: Option<usize>,
// deliberately make Extra readonly
extra: Extra<'a, 'py>,
}
Expand All @@ -27,6 +28,7 @@ impl<'a, 'py> ValidationState<'a, 'py> {
Self {
recursion_guard, // Don't care about exactness unless doing union validation
exactness: None,
fields_set_count: None,
extra,
}
}
Expand Down
166 changes: 161 additions & 5 deletions tests/validators/test_union.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from datetime import date, time
from enum import Enum, IntEnum
from typing import Any
from typing import Any, Optional, Union
from uuid import UUID

import pytest
Expand Down Expand Up @@ -170,13 +170,15 @@ def test_model_a(self, schema_validator: SchemaValidator):
assert m.b == 'hello'
assert not hasattr(m, 'c')

def test_model_b_ignored(self, schema_validator: SchemaValidator):
# first choice works, so second choice is not used
def test_model_b_preferred(self, schema_validator: SchemaValidator):
# Note, this is a different behavior to previous smart union behavior,
# where the first match would be preferred. However, we believe it is better
# to prefer the match with the greatest number of valid fields set.
m = schema_validator.validate_python({'a': 1, 'b': 'hello', 'c': 2.0})
assert isinstance(m, self.ModelA)
assert isinstance(m, self.ModelB)
assert m.a == 1
assert m.b == 'hello'
assert not hasattr(m, 'c')
assert m.c == 2.0

def test_model_b_not_ignored(self, schema_validator: SchemaValidator):
m1 = self.ModelB()
Expand Down Expand Up @@ -803,3 +805,157 @@ class ModelA:
assert isinstance(m, ModelA)
assert m.a == 42
assert validator.validate_python(True) is True


def test_union_with_subclass() -> None:
class ModelA:
a: int

class ModelB(ModelA):
b: int

model_a_schema = core_schema.model_schema(
ModelA, core_schema.model_fields_schema(fields={'a': core_schema.model_field(core_schema.int_schema())})
)
model_b_schema = core_schema.model_schema(
ModelB,
core_schema.model_fields_schema(
fields={
'a': core_schema.model_field(core_schema.int_schema()),
'b': core_schema.model_field(core_schema.int_schema()),
}
),
)

for choices in [[model_a_schema, model_b_schema], [model_b_schema, model_a_schema]]:
validator = SchemaValidator(schema=core_schema.union_schema(choices))
assert isinstance(validator.validate_python({'a': 1}), ModelA)
assert isinstance(validator.validate_python({'a': 1, 'b': 2}), ModelB)

# confirm that a model that matches in lax mode with 2 fields
# is preferred over a model that matches in strict mode with 1 field
assert isinstance(validator.validate_python({'a': '1', 'b': '2'}), ModelB)
assert isinstance(validator.validate_python({'a': '1', 'b': 2}), ModelB)
assert isinstance(validator.validate_python({'a': 1, 'b': '2'}), ModelB)
assert isinstance(validator.validate_python({'a': 1, 'b': 2}), ModelB)


def test_union_with_default() -> None:
class ModelA:
a: int = 0

class ModelB:
b: int = 0

val = SchemaValidator(
{
'type': 'union',
'choices': [
{
'type': 'model',
'cls': ModelA,
'schema': {
'type': 'model-fields',
'fields': {
'a': {
'type': 'model-field',
'schema': {'type': 'default', 'schema': {'type': 'int'}, 'default': 0},
},
},
},
},
{
'type': 'model',
'cls': ModelB,
'schema': {
'type': 'model-fields',
'fields': {
'b': {
'type': 'model-field',
'schema': {'type': 'default', 'schema': {'type': 'int'}, 'default': 0},
},
},
},
},
],
}
)

assert isinstance(val.validate_python({'a': 1}), ModelA)
assert isinstance(val.validate_python({'b': 1}), ModelB)

# defaults to leftmost choice if there's a tie
assert isinstance(val.validate_python({}), ModelA)


def test_optional_union_with_members_having_defaults() -> None:
class ModelA:
a: int = 0

class ModelB:
b: int = 0

class WrapModel:
val: Optional[Union[ModelA, ModelB]] = None

val = SchemaValidator(
{
'type': 'model',
'cls': WrapModel,
'schema': {
'type': 'model-fields',
'fields': {
'val': {
'type': 'model-field',
'schema': {
'type': 'default',
'schema': {
'type': 'union',
'choices': [
{
'type': 'model',
'cls': ModelA,
'schema': {
'type': 'model-fields',
'fields': {
'a': {
'type': 'model-field',
'schema': {
'type': 'default',
'schema': {'type': 'int'},
'default': 0,
},
}
},
},
},
{
'type': 'model',
'cls': ModelB,
'schema': {
'type': 'model-fields',
'fields': {
'b': {
'type': 'model-field',
'schema': {
'type': 'default',
'schema': {'type': 'int'},
'default': 0,
},
}
},
},
},
],
},
'default': None,
},
}
},
},
}
)

assert isinstance(val.validate_python({'val': {'a': 1}}).val, ModelA)
assert isinstance(val.validate_python({'val': {'b': 1}}).val, ModelB)
assert val.validate_python({}).val is None
Loading