diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9cbc698..e52a13f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -97,22 +97,22 @@ jobs: - name: Install test requirements with Pydantic v1 run: | python -m pip install -U -r requirements.txt - python -m pip install -U "pydantic<2.0" + python -m pip install -U "pydantic<2.0" "spacy" - name: Run tests for Pydantic v1 run: | python -c "import pydantic; print(pydantic.VERSION)" - python -m pytest --pyargs confection -Werror + python -m pytest --pyargs confection - name: Install test requirements with Pydantic v2 run: | - python -m pip install -U -r requirements.txt + python -m pip install -U -r requirements.txt spacy python -m pip install -U pydantic - name: Run tests for Pydantic v2 run: | python -c "import pydantic; print(pydantic.VERSION)" - python -m pytest --pyargs confection -Werror + python -m pytest --pyargs confection - name: Test for import conflicts with hypothesis run: | diff --git a/confection/__init__.py b/confection/__init__.py index e9dfca6..edcde34 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -14,7 +14,7 @@ NoSectionError, ParsingError, ) -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass from pathlib import Path from types import GeneratorType from typing import ( @@ -26,21 +26,24 @@ Mapping, Optional, Sequence, + Set, Tuple, Type, + TypeVar, Union, cast, ) +import catalogue import srsly +from pydantic import BaseModel, ValidationError, create_model +from pydantic.fields import FieldInfo -try: - from pydantic.v1 import BaseModel, Extra, ValidationError, create_model - from pydantic.v1.fields import ModelField - from pydantic.v1.main import ModelMetaclass -except ImportError: - from pydantic import BaseModel, create_model, ValidationError, Extra # type: ignore - from pydantic.main import ModelMetaclass # type: ignore +from .util import PYDANTIC_V2, Decorator, SimpleFrozenDict, 
def get_model_config_extra(model: Type[BaseModel]) -> str:
    """Return the model's "extra" policy as a plain string.

    model (Type[BaseModel]): The pydantic model class to inspect.
    RETURNS (str): The configured policy, e.g. "allow", "ignore" or "forbid".
        Falls back to "forbid" when no policy is configured.
    """
    if PYDANTIC_V2 and _schema_is_pydantic_v2(model):
        extra = model.model_config.get("extra")  # type: ignore
    else:
        # v1-style model: either pydantic v1 is installed, or the model was
        # imported from `pydantic.v1` while pydantic v2 is installed. Prefer
        # the merged `__config__` class and fall back to the user-declared
        # `Config`, which is not guaranteed to define `extra` at all.
        config = getattr(model, "__config__", None) or getattr(model, "Config", None)
        extra = getattr(config, "extra", None)
    if extra is None:
        return "forbid"
    # pydantic v1 stores the policy as an `Extra` enum member. Calling str()
    # on it yields "Extra.forbid" rather than "forbid", so unwrap the value
    # before stringifying. Plain strings pass through unchanged.
    extra = getattr(extra, "value", extra)
    return str(extra) or "forbid"
def get_model_extra(instance: BaseModel) -> Dict[str, Any]:
    """Return the extra (undeclared) values stored on a model instance.

    instance (BaseModel): The model instance to inspect.
    RETURNS (Dict[str, Any]): The instance's `model_extra` mapping for a
        pydantic v2 model. An empty dict is returned for v1-style models and
        for v2 models whose `model_extra` is None.
    """
    if PYDANTIC_V2 and _schema_is_pydantic_v2(instance):
        extra = instance.model_extra  # type: ignore
        # `model_extra` may be None, so normalize it here and let callers
        # always treat the result as a dict.
        if extra is not None:
            return extra
    return {}
def _safe_is_subclass(cls: type, expected: type) -> bool:
    """Check whether `cls` is a class derived from `expected`, without raising.

    cls (type): The candidate. May be any object (e.g. a typing construct or
        an instance); anything that is not a class yields False.
    expected (type): The class that `cls` is tested against.
    RETURNS (bool): True if `cls` is a class and a subclass of `expected`.
    """
    if not inspect.isclass(cls):
        return False
    try:
        return issubclass(cls, expected)
    except TypeError:
        # Some class-like objects (e.g. parametrized generic aliases on
        # certain Python versions) pass `inspect.isclass` but are still
        # rejected by `issubclass`.
        return False
getter_result = getter(*args, **kwargs) + + if isinstance(getter_result, BaseModel) or is_dataclass( + getter_result + ): + resolved_object_keys.add(key) else: # We're not resolving and calling the function, so replace # the getter_result with a Promise class @@ -890,12 +1000,14 @@ def _fill( validation[v_key] = [] elif hasattr(value, "items"): field_type = EmptySchema - if key in schema.__fields__: - field = schema.__fields__[key] - field_type = field.type_ - if not isinstance(field.type_, ModelMetaclass): - # If we don't have a pydantic schema and just a type - field_type = EmptySchema + fields = get_model_fields(schema) + if key in fields: + field = fields[key] + annotation = get_field_annotation(field) + if annotation is not None and _safe_is_subclass( + annotation, BaseModel + ): + field_type = annotation filled[key], validation[v_key], final[key] = cls._fill( value, field_type, @@ -921,21 +1033,39 @@ def _fill( exclude = [] if validate: try: - result = schema.parse_obj(validation) + result = model_validate(schema, validation) except ValidationError as e: raise ConfigValidationError( config=config, errors=e.errors(), parent=parent ) from None else: - # Same as parse_obj, but without validation - result = schema.construct(**validation) + # Same as model_validate, but without validation + fields_set = set(get_model_fields(schema).keys()) + result = model_construct(schema, fields_set, validation) # If our schema doesn't allow extra values, we need to filter them # manually because .construct doesn't parse anything - if schema.Config.extra in (Extra.forbid, Extra.ignore): - fields = schema.__fields__.keys() - exclude = [k for k in result.__fields_set__ if k not in fields] + if get_model_config_extra(schema) in ("forbid", "extra"): + result_field_names = get_model_fields_set(result) + exclude = [ + k for k in dict(result).keys() if k not in result_field_names + ] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) - 
validation.update(result.dict(exclude=exclude_validation)) + # Do a shallow serialization first + # If any of the sub-objects are Pydantic models, first check if they + # were resolved earlier from a registry. If they weren't resolved + # they are part of a nested schema and need to be serialized with + # model.dict() + # Allows for returning Pydantic models from a registered function + shallow_result_dict = dict(result) + update_from_model_extra(shallow_result_dict, result) + result_dict = {} + for k, v in shallow_result_dict.items(): + if k in exclude_validation: + continue + result_dict[k] = v + if isinstance(v, BaseModel) and k not in resolved_object_keys: + result_dict[k] = model_dump(v) + validation.update(result_dict) filled, final = cls._update_from_parsed(validation, filled, final) if exclude: filled = {k: v for k, v in filled.items() if k not in exclude} @@ -969,6 +1099,8 @@ def _update_from_parsed( # Check numpy first, just in case. Use stringified type so that numpy dependency can be ditched. 
elif str(type(value)) == "": final[key] = value + elif isinstance(value, BaseModel) and isinstance(final[key], BaseModel): + final[key] = value elif ( value != final[key] or not isinstance(type(value), type(final[key])) ) and not isinstance(final[key], GeneratorType): diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 4f80fce..b177ca5 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -1,20 +1,17 @@ +# type: ignore import inspect import pickle import platform from types import GeneratorType -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union import catalogue import pytest +from pydantic import BaseModel, PositiveInt, StrictFloat +from pydantic.fields import Field +from pydantic.types import StrictBool -try: - from pydantic.v1 import BaseModel, PositiveInt, StrictFloat, constr - from pydantic.v1.types import StrictBool -except ImportError: - from pydantic import BaseModel, StrictFloat, PositiveInt, constr # type: ignore - from pydantic.types import StrictBool # type: ignore - -from confection import Config, ConfigValidationError +from confection import PYDANTIC_V2, Config, ConfigValidationError, get_model_fields from confection.tests.util import Cat, make_tempdir, my_registry from confection.util import Generator, partial @@ -49,7 +46,7 @@ """ -OPTIMIZER_CFG = """ +OPTIMIZER_DATACLASS_CFG = """ [optimizer] @optimizers = "Adam.v1" beta1 = 0.9 @@ -64,6 +61,27 @@ """ +OPTIMIZER_PYDANTIC_CFG = """ +[optimizer] +@optimizers = "Adam.pydantic.v1" +beta1 = 0.9 +beta2 = 0.999 +use_averages = true + +[optimizer.learn_rate] +@schedules = "warmup_linear.v1" +initial_rate = 0.1 +warmup_steps = 10000 +total_steps = 100000 +""" + + +if PYDANTIC_V2: + INT_PARSING_ERROR_TYPE = "int_parsing" +else: + INT_PARSING_ERROR_TYPE = "type_error.integer" + + class HelloIntsSchema(BaseModel): hello: int world: 
@pytest.mark.parametrize(
    "optimizer_cfg_str",
    [
        pytest.param(OPTIMIZER_DATACLASS_CFG, id="dataclasses"),
        pytest.param(OPTIMIZER_PYDANTIC_CFG, id="pydantic"),
    ],
)
def test_optimizer_config(optimizer_cfg_str: str):
    """Resolve each optimizer config with validation and check `beta1`."""
    parsed = Config().from_str(optimizer_cfg_str)
    resolved = my_registry.resolve(parsed, validate=True)
    assert resolved["optimizer"].beta1 == 0.9
@pytest.mark.parametrize(
    "optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG]
)
def test_config_to_str(optimizer_cfg_str: str):
    """Check that `to_str` reproduces the parsed config text.

    The check runs once starting from an empty Config and once starting from
    a Config that already holds an "optimizer" section.
    """
    expected = optimizer_cfg_str.strip()

    from_empty = Config().from_str(optimizer_cfg_str)
    assert from_empty.to_str().strip() == expected

    from_prefilled = Config({"optimizer": {"foo": "bar"}}).from_str(optimizer_cfg_str)
    assert from_prefilled.to_str().strip() == expected
new_cfg = Config().from_disk(cfg_path) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() + assert new_cfg.to_str().strip() == optimizer_cfg_str.strip() def test_config_to_str_invalid_defaults(): @@ -328,10 +372,15 @@ def test_config_to_str_invalid_defaults(): def test_validation_custom_types(): + if PYDANTIC_V2: + log_field = Field("ERROR", pattern="(DEBUG|INFO|WARNING|ERROR)") + else: + log_field = Field("ERROR", regex="(DEBUG|INFO|WARNING|ERROR)") + def complex_args( rate: StrictFloat, - steps: PositiveInt = 10, # type: ignore - log_level: constr(regex="(DEBUG|INFO|WARNING|ERROR)") = "ERROR", # noqa: F821 + steps: PositiveInt = 10, + log_level: str = log_field, ): return None @@ -583,10 +632,10 @@ def test_schedule(): cfg = {"@schedules": "test_schedule.v2"} result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) + assert isinstance(result, Iterator) @my_registry.optimizers("test_optimizer.v2") - def test_optimizer2(rate: Generator) -> Generator: + def test_optimizer2(rate: Iterable[float]) -> Iterable[float]: return rate cfg = { @@ -594,10 +643,10 @@ def test_optimizer2(rate: Generator) -> Generator: "rate": {"@schedules": "test_schedule.v2"}, } result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) + assert isinstance(result, Iterator) @my_registry.optimizers("test_optimizer.v3") - def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: + def test_optimizer3(schedules: Dict[str, Iterable[float]]) -> Iterable[float]: return schedules["rate"] cfg = { @@ -605,10 +654,10 @@ def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: "schedules": {"rate": {"@schedules": "test_schedule.v2"}}, } result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) + assert isinstance(result, Iterator) @my_registry.optimizers("test_optimizer.v4") - def test_optimizer4(*schedules: Generator) -> Generator: + def test_optimizer4(*schedules: 
def test_spacy_init_config():
    """Regression test: `spacy.cli.init_config` should return a Config.

    The test is skipped when spaCy cannot be imported.
    """
    try:
        from spacy.cli import init_config
    except ImportError:
        pytest.skip("SpaCy not installed")
    else:
        generated = init_config(pipeline=["tagger"])
        assert isinstance(generated, Config)
@my_registry.optimizers("Adam.pydantic.v1")
def Adam_pydantic(
    learn_rate: FloatOrSeq = 0.001,
    *,
    beta1: FloatOrSeq = 0.001,
    beta2: FloatOrSeq = 0.001,
    use_averages: bool = True,
):
    """
    Mock optimizer factory that returns a pydantic model instead of a real
    optimizer. It only exists to show how a registered function can hand back
    a pydantic object, e.g. when used with thinc.
    """

    class Optimizer(BaseModel):
        learn_rate: FloatOrSeq
        beta1: FloatOrSeq
        beta2: FloatOrSeq
        use_averages: bool

    # Collect the arguments first, then build the model from them.
    settings = {
        "learn_rate": learn_rate,
        "beta1": beta1,
        "beta2": beta2,
        "use_averages": use_averages,
    }
    return Optimizer(**settings)