Upgrade Pydantic to v2 #31
Open · kabirkhan wants to merge 51 commits into main from kab/pydantic-v2 (base: main)
Changes from all commits · 51 commits
42696e4 convert library to compatibility w pydantic v2, start fixing tests
ccedeb0 fix constr and model config access
05ce65b add support for Pydantic models and dataclasses out of registered fun…
a822d3c update reqs
59f3f55 start compat
8ee7a24 Merge branch 'main' of ssh://github.com/explosion/confection into kab…
3b51749 small corrections around new model_construct behavior
2df560f use Iterator instead of Generator and GeneratorType
5213981 don't validate in fill_without_resolve test
ff3b55f bump reqs
1999477 refactor and fix for mypy
45f12ba disable python 3.6
efd1737 rm extra python 3.6 ref
ca99729 check that pydantic and dataclass versions of Optimizer both work
caf94a2 Merge branch 'kab/pydantic-v2' of ssh://github.com/explosion/confecti…
beb567e fix conflict
07870e7 move back to old Config nested class
36bc368 fix tests
0b31287 update from model_extra
2247265 fix pydantic generator equals
712e0ed fixes for organization
ee6b10c allow pydantic v1/v2 in reqs/setup and test both in CI
81fa915 only run CI push to main, not other branches
0fb6858 fix issue with model_validate
04354f2 fix filter warnings for tests
a5f2d5a try run ci
21196e5 try run ci
20974fd smaller test matrix
995b9af print pydantic version before tests
3acfe90 fixes for mypy
0bac230 test fixes
6d95b50 re-enable mypy
7aa2207 Merge pull request #36 from explosion/kab/pydantic-v2-compat
aa7d13b Undo unrelated changes to CI tests
adrianeboyd fc29ccd Ignore tests for mypy
adrianeboyd 65e69c1 Add mypy for pydantic v1
adrianeboyd a9dd2a3 Format
adrianeboyd 36d1d1c Lower typing_extensions pin for python 3.6
adrianeboyd de14431 Undo changes to typing_extensions
adrianeboyd 3283e4a Allow older pydantic v1 for tests for python 3.6
adrianeboyd cbada4b Add CI test for spacy init config
adrianeboyd bf6624f Merge branch 'main' of ssh://github.com/explosion/confection into kab…
680e224 black formatting
4f7d5b3 Fix spacy init issue (#37)
c4f78e8 Simplify pydantic requirements
adrianeboyd 10b45d5 Merge remote-tracking branch 'upstream/main' into kab/pydantic-v2
adrianeboyd e6ac8ec Merge branch 'kab/pydantic-v2' of ssh://github.com/explosion/confecti…
8ae9252 add a spacy init config regression test if spacy is installed
883d76b rm unused import
d25d81a rm spacy init config step in gha in favor of test in pytest
ba068d6 fix checks
@@ -14,7 +14,7 @@
     NoSectionError,
     ParsingError,
 )
-from dataclasses import dataclass
+from dataclasses import dataclass, is_dataclass
 from pathlib import Path
 from types import GeneratorType
 from typing import (
@@ -26,21 +26,24 @@
     Mapping,
     Optional,
     Sequence,
+    Set,
     Tuple,
     Type,
     TypeVar,
     Union,
     cast,
 )

 import catalogue
 import srsly
+from pydantic import BaseModel, ValidationError, create_model
+from pydantic.fields import FieldInfo

-try:
-    from pydantic.v1 import BaseModel, Extra, ValidationError, create_model
-    from pydantic.v1.fields import ModelField
-    from pydantic.v1.main import ModelMetaclass
-except ImportError:
-    from pydantic import BaseModel, create_model, ValidationError, Extra  # type: ignore
-    from pydantic.main import ModelMetaclass  # type: ignore
+from .util import PYDANTIC_V2, Decorator, SimpleFrozenDict, SimpleFrozenList

-from .util import SimpleFrozenDict, SimpleFrozenList  # noqa: F401
+if PYDANTIC_V2:
+    from pydantic.v1.fields import ModelField  # type: ignore
+else:
+    from pydantic.fields import ModelField  # type: ignore
@@ -689,10 +692,7 @@ def alias_generator(name: str) -> str:
     return name


-def copy_model_field(field: ModelField, type_: Any) -> ModelField:
-    """Copy a model field and assign a new type, e.g. to accept an Any type
-    even though the original value is typed differently.
-    """
+def _copy_model_field_v1(field: ModelField, type_: Any) -> ModelField:
     return ModelField(
         name=field.name,
         type_=type_,
@@ -704,6 +704,107 @@ def copy_model_field(field: ModelField, type_: Any) -> ModelField:
     )


+def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo:
+    """Copy a model field and assign a new type, e.g. to accept an Any type
+    even though the original value is typed differently.
+    """
+    if PYDANTIC_V2:
+        field_info = copy.deepcopy(field)
+        field_info.annotation = type_  # type: ignore
+        return field_info
+    else:
+        return _copy_model_field_v1(field, type_)  # type: ignore
+
+
+def get_model_config_extra(model: Type[BaseModel]) -> str:
+    if PYDANTIC_V2:
+        extra = str(model.model_config.get("extra", "forbid"))  # type: ignore
+    else:
+        extra = str(model.Config.extra) or "forbid"  # type: ignore
+    assert isinstance(extra, str)
+    return extra
+
+
+_ModelT = TypeVar("_ModelT", bound=BaseModel)
+
+
+def _schema_is_pydantic_v2(Schema: Union[Type[BaseModel], BaseModel]) -> bool:
+    """If the `model_fields` attr is present, we have a schema or instance of a
+    pydantic v2 BaseModel. Even when running on Pydantic v2, users could still
+    import from `pydantic.v1`, which would otherwise break our compat checks.
+    Schema (Union[Type[BaseModel], BaseModel]): Input schema or instance.
+    RETURNS (bool): True if the model is a pydantic v2 model.
+    """
+    return hasattr(Schema, "model_fields")
+
+
+def model_validate(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT:
+    if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema):
+        return Schema.model_validate(data)  # type: ignore
+    else:
+        return Schema.validate(data)  # type: ignore
+
+
+def model_construct(
+    Schema: Type[_ModelT], fields_set: Optional[Set[str]], data: Dict[str, Any]
+) -> _ModelT:
+    if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema):
+        return Schema.model_construct(fields_set, **data)  # type: ignore
+    else:
+        return Schema.construct(fields_set, **data)  # type: ignore
+
+
+def model_dump(instance: BaseModel) -> Dict[str, Any]:
+    if PYDANTIC_V2 and _schema_is_pydantic_v2(instance):
+        return instance.model_dump()  # type: ignore
+    else:
+        return instance.dict()
+
+
+def get_field_annotation(field: FieldInfo) -> Type:
+    return field.annotation if PYDANTIC_V2 else field.type_  # type: ignore
+
+
+def get_model_fields(Schema: Union[Type[BaseModel], BaseModel]) -> Dict[str, FieldInfo]:
+    if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema):
+        return Schema.model_fields  # type: ignore
+    else:
+        return Schema.__fields__  # type: ignore
+
+
+def get_model_fields_set(Schema: Union[Type[BaseModel], BaseModel]) -> Set[str]:
+    if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema):
+        return Schema.model_fields_set  # type: ignore
+    else:
+        return Schema.__fields_set__  # type: ignore
+
+
+def get_model_extra(instance: BaseModel) -> Dict[str, FieldInfo]:
+    if PYDANTIC_V2 and _schema_is_pydantic_v2(instance):
+        return instance.model_extra  # type: ignore
+    else:
+        return {}
+
+
+def set_model_field(Schema: Type[BaseModel], key: str, field: FieldInfo):
+    if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema):
+        Schema.model_fields[key] = field  # type: ignore
+    else:
+        Schema.__fields__[key] = field  # type: ignore
+
+
+def update_from_model_extra(
+    shallow_result_dict: Dict[str, Any], result: BaseModel
+) -> None:
+    if PYDANTIC_V2 and _schema_is_pydantic_v2(result):
+        if result.model_extra is not None:  # type: ignore
+            shallow_result_dict.update(result.model_extra)  # type: ignore
+
+
+def _safe_is_subclass(cls: type, expected: type) -> bool:
+    return inspect.isclass(cls) and issubclass(cls, expected)
+
+
 class EmptySchema(BaseModel):
     class Config:
         extra = "allow"
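The helpers above all branch the same way: use the v2 `model_*` API when `PYDANTIC_V2` is set and the object really is a v2 model, otherwise fall back to the v1 method names. A minimal stdlib-only sketch of that duck-typing dispatch, using two hypothetical stand-in classes rather than pydantic itself:

```python
def dump_compat(instance):
    """Dispatch on `model_fields`, the attribute that marks a pydantic v2
    model, mirroring the `_schema_is_pydantic_v2` check above."""
    if hasattr(instance, "model_fields"):
        return instance.model_dump()  # v2-style API
    return instance.dict()  # v1-style API


class V1Style:
    """Hypothetical stand-in for a pydantic v1 model."""
    __fields__ = {"x": None}

    def dict(self):
        return {"x": 1}


class V2Style:
    """Hypothetical stand-in for a pydantic v2 model."""
    model_fields = {"x": None}

    def model_dump(self):
        return {"x": 1}
```

Both `dump_compat(V1Style())` and `dump_compat(V2Style())` yield the same dict, but through different underlying methods, which is exactly the property the compat shims rely on.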
@@ -829,6 +930,7 @@ def _fill(
         resolve: bool = True,
         parent: str = "",
         overrides: Dict[str, Dict[str, Any]] = {},
+        resolved_object_keys: Set[str] = set(),
     ) -> Tuple[
         Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any]
     ]:
@@ -850,12 +952,14 @@ def _fill(
             value = overrides[key_parent]
             config[key] = value
         if cls.is_promise(value):
-            if key in schema.__fields__ and not resolve:
+            model_fields = get_model_fields(schema)
+            if key in model_fields and not resolve:
                 # If we're not resolving the config, make sure that the field
                 # expecting the promise is typed Any so it doesn't fail
                 # validation if it doesn't receive the function return value
-                field = schema.__fields__[key]
-                schema.__fields__[key] = copy_model_field(field, Any)
+                field = model_fields[key]
+                new_field = copy_model_field(field, Any)
+                set_model_field(schema, key, new_field)
             promise_schema = cls.make_promise_schema(value, resolve=resolve)
             filled[key], validation[v_key], final[key] = cls._fill(
                 value,
@@ -864,6 +968,7 @@ def _fill(
                 resolve=resolve,
                 parent=key_parent,
                 overrides=overrides,
+                resolved_object_keys=resolved_object_keys,
             )
             reg_name, func_name = cls.get_constructor(final[key])
             args, kwargs = cls.parse_args(final[key])
@@ -875,6 +980,11 @@ def _fill(
                 # We don't want to try/except this and raise our own error
                 # here, because we want the traceback if the function fails.
                 getter_result = getter(*args, **kwargs)
+
+                if isinstance(getter_result, BaseModel) or is_dataclass(
+                    getter_result
+                ):
+                    resolved_object_keys.add(key)
             else:
                 # We're not resolving and calling the function, so replace
                 # the getter_result with a Promise class
@@ -890,12 +1000,14 @@ def _fill(
             validation[v_key] = []
         elif hasattr(value, "items"):
             field_type = EmptySchema
-            if key in schema.__fields__:
-                field = schema.__fields__[key]
-                field_type = field.type_
-                if not isinstance(field.type_, ModelMetaclass):
-                    # If we don't have a pydantic schema and just a type
-                    field_type = EmptySchema
+            fields = get_model_fields(schema)
+            if key in fields:
+                field = fields[key]
+                annotation = get_field_annotation(field)
+                if annotation is not None and _safe_is_subclass(
+                    annotation, BaseModel
+                ):
+                    field_type = annotation
             filled[key], validation[v_key], final[key] = cls._fill(
                 value,
                 field_type,
@@ -921,21 +1033,39 @@ def _fill(
         exclude = []
         if validate:
             try:
-                result = schema.parse_obj(validation)
+                result = model_validate(schema, validation)
             except ValidationError as e:
                 raise ConfigValidationError(
                     config=config, errors=e.errors(), parent=parent
                 ) from None
         else:
-            # Same as parse_obj, but without validation
-            result = schema.construct(**validation)
+            # Same as model_validate, but without validation
+            fields_set = set(get_model_fields(schema).keys())
+            result = model_construct(schema, fields_set, validation)
         # If our schema doesn't allow extra values, we need to filter them
         # manually because .construct doesn't parse anything
-        if schema.Config.extra in (Extra.forbid, Extra.ignore):
-            fields = schema.__fields__.keys()
-            exclude = [k for k in result.__fields_set__ if k not in fields]
+        if get_model_config_extra(schema) in ("forbid", "ignore"):
+            result_field_names = get_model_fields_set(result)
+            exclude = [
+                k for k in dict(result).keys() if k not in result_field_names
+            ]
         exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()])
-        validation.update(result.dict(exclude=exclude_validation))
+        # Do a shallow serialization first.
+        # If any of the sub-objects are Pydantic models, first check if they
+        # were resolved earlier from a registry. If they weren't resolved,
+        # they are part of a nested schema and need to be serialized with
+        # model.dict().
+        # Allows for returning Pydantic models from a registered function.
+        shallow_result_dict = dict(result)
+        update_from_model_extra(shallow_result_dict, result)
+        result_dict = {}
+        for k, v in shallow_result_dict.items():
+            if k in exclude_validation:
+                continue
+            result_dict[k] = v
+            if isinstance(v, BaseModel) and k not in resolved_object_keys:
+                result_dict[k] = model_dump(v)
+        validation.update(result_dict)
         filled, final = cls._update_from_parsed(validation, filled, final)
         if exclude:
             filled = {k: v for k, v in filled.items() if k not in exclude}
@@ -969,6 +1099,8 @@ def _update_from_parsed(
             # Check numpy first, just in case. Use stringified type so that numpy dependency can be ditched.
             elif str(type(value)) == "<class 'numpy.ndarray'>":
                 final[key] = value
+            elif isinstance(value, BaseModel) and isinstance(final[key], BaseModel):
+                final[key] = value
             elif (
                 value != final[key] or not isinstance(type(value), type(final[key]))
             ) and not isinstance(final[key], GeneratorType):

Review comment on the added branch: This case can occur from the above change to pass through resolved pydantic models.
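The shallow-serialization pass described in the diff's comments can be sketched with plain Python, no pydantic required. The `FakeModel` class below is a hypothetical stand-in for a pydantic model, and `shallow_serialize` mirrors the loop in `_fill`: sub-models whose keys were resolved from the registry pass through as objects, while the rest get fully dumped:

```python
class FakeModel:
    """Hypothetical stand-in for a pydantic BaseModel."""

    def __init__(self, **data):
        self._data = data

    def __iter__(self):
        # dict(result) in the diff relies on iterating (key, value) pairs
        return iter(self._data.items())

    def model_dump(self):
        # Deep serialization: recurse into nested sub-models
        return {
            k: (v.model_dump() if isinstance(v, FakeModel) else v)
            for k, v in self._data.items()
        }


def shallow_serialize(result, resolved_object_keys):
    """Keep sub-models that were resolved from a registry as live objects;
    serialize everything else, mirroring the loop in `_fill`."""
    out = {}
    for k, v in dict(result).items():
        out[k] = v
        if isinstance(v, FakeModel) and k not in resolved_object_keys:
            out[k] = v.model_dump()
    return out


inner = FakeModel(lr=0.1)
outer = FakeModel(optimizer=inner, seed=0)
kept = shallow_serialize(outer, {"optimizer"})  # resolved: passed through
dumped = shallow_serialize(outer, set())        # not resolved: serialized
```

With `"optimizer"` in `resolved_object_keys`, the inner object survives as an instance; without it, the same key becomes a plain dict, which is the distinction the `resolved_object_keys` parameter exists to make.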
This change might seem a bit weird, but it came up because Pydantic now treats Pydantic dataclasses and BaseModels the same way during serialization.

In Pydantic v1, calling `model.dict()` on a model containing a `dataclass` instance would not JSON-serialize that dataclass; it would just return the instance as-is. This allowed our Optimizer README example to work and resolve that dataclass instance. E.g. with Pydantic v1 this still worked:

However, after swapping to Pydantic v2, the final line here was resolving the `optimizer` to a dict. This has actually always been a bug, because previously we could not make `MyCoolOptimizer` a Pydantic model; it had to be a `dataclass` (or any other class that wasn't a Pydantic model). With the code above, we can make `MyCoolOptimizer` a Pydantic model and we'd get a Pydantic model back.
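The v1-vs-v2 serialization difference described here can be illustrated with stdlib dataclasses alone: v1's `.dict()` left dataclass instances untouched (shallow), while v2's `model_dump()` recurses into them, roughly like `dataclasses.asdict`. A minimal sketch under that analogy (the `MyCoolOptimizer` fields are illustrative, not the README's exact class):

```python
from dataclasses import dataclass, asdict, is_dataclass


@dataclass
class MyCoolOptimizer:
    """Illustrative optimizer dataclass, like the README example."""
    learn_rate: float
    gamma: float


opt = MyCoolOptimizer(learn_rate=0.001, gamma=1e-8)

# v1-style shallow dump: the dataclass instance passes through unchanged
shallow = {"optimizer": opt}

# v2-style deep dump: dataclass instances are serialized to plain dicts
deep = {k: asdict(v) if is_dataclass(v) else v for k, v in shallow.items()}
```

The shallow result still holds the live `MyCoolOptimizer` instance, while the deep result only holds its field values, which is exactly why the resolved-object pass-through above is needed under Pydantic v2.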