Skip to content

Commit

Permalink
feat: add support to msgspec for kw_only=True (#2162)
Browse files Browse the repository at this point in the history
Co-authored-by: Koudai Aono <[email protected]>
  • Loading branch information
indrat and koxudaxi authored Nov 23, 2024
1 parent 957ce04 commit 41a37e0
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 3 deletions.
7 changes: 6 additions & 1 deletion datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,13 @@ def validate_custom_file_header(cls, values: Dict[str, Any]) -> Dict[str, Any]:

@model_validator(mode='after')
def validate_keyword_only(cls, values: Dict[str, Any]) -> Dict[str, Any]:
output_model_type: DataModelType = values.get('output_model_type')
python_target: PythonVersion = values.get('target_python_version')
if values.get('keyword_only') and not python_target.has_kw_only_dataclass:
if (
values.get('keyword_only')
and output_model_type == DataModelType.DataclassesDataclass
and not python_target.has_kw_only_dataclass
):
raise Error(
f'`--keyword-only` requires `--target-python-version` {PythonVersion.PY_310.value} or higher.'
)
Expand Down
5 changes: 3 additions & 2 deletions datamodel_code_generator/model/msgspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
IMPORT_MSGSPEC_CONVERT,
IMPORT_MSGSPEC_FIELD,
IMPORT_MSGSPEC_META,
IMPORT_MSGSPEC_STRUCT,
)
from datamodel_code_generator.model.pydantic.base_model import (
Constraints as _Constraints,
Expand Down Expand Up @@ -88,7 +87,7 @@ class RootModel(_RootModel):
class Struct(DataModel):
TEMPLATE_FILE_PATH: ClassVar[str] = 'msgspec.jinja2'
BASE_CLASS: ClassVar[str] = 'msgspec.Struct'
DEFAULT_IMPORTS: ClassVar[Tuple[Import, ...]] = (IMPORT_MSGSPEC_STRUCT,)
DEFAULT_IMPORTS: ClassVar[Tuple[Import, ...]] = ()

def __init__(
self,
Expand Down Expand Up @@ -123,6 +122,8 @@ def __init__(
keyword_only=keyword_only,
)
self.extra_template_data.setdefault('base_class_kwargs', {})
if self.keyword_only:
self.add_base_class_kwarg('kw_only', 'True')

def add_base_class_kwarg(self, name: str, value):
self.extra_template_data['base_class_kwargs'][name] = value
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# generated by datamodel-codegen:
# filename: referenced.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import Optional

from msgspec import Struct


class Model(Struct):
some_optional_property: Optional[str] = None
some_optional_typed_property: Optional[str] = None
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# generated by datamodel-codegen:
# filename: required.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from .referenced import Model as Model_1


class Model(Model_1):
some_optional_property: str
some_optional_typed_property: str
20 changes: 20 additions & 0 deletions tests/data/expected/main/openapi/msgspec_keyword_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# generated by datamodel-codegen:
# filename: inheritance.yaml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import Optional

from msgspec import Struct


class Base(Struct, kw_only=True):
id: str
createdAt: Optional[str] = None
version: Optional[float] = 1


class Child(Base, kw_only=True):
title: str
url: Optional[str] = 'https://example.com'
4 changes: 4 additions & 0 deletions tests/main/jsonschema/test_main_jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,10 @@ def test_main_require_referenced_field_naivedatetime(output_model, expected_outp
'pydantic_v2.BaseModel',
'require_referenced_field',
),
(
'msgspec.Struct',
'require_referenced_field_msgspec',
),
],
)
@freeze_time('2019-07-26')
Expand Down
30 changes: 30 additions & 0 deletions tests/main/openapi/test_main_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2993,3 +2993,33 @@ def test_main_openapi_dataclass_with_NaiveDatetime(capsys: CaptureFixture):
captured.err
== '`--output-datetime-class` only allows "datetime" for `--output-model-type` dataclasses.dataclass\n'
)


@freeze_time('2019-07-26')
@pytest.mark.skipif(
black.__version__.split('.')[0] == '19',
reason="Installed black doesn't support the old style",
)
def test_main_openapi_keyword_only_msgspec():
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
[
'--input',
str(OPEN_API_DATA_PATH / 'inheritance.yaml'),
'--output',
str(output_file),
'--input-file-type',
'openapi',
'--output-model-type',
'msgspec.Struct',
'--keyword-only',
'--target-python-version',
'3.8',
]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== (EXPECTED_OPENAPI_PATH / 'msgspec_keyword_only.py').read_text()
)

0 comments on commit 41a37e0

Please sign in to comment.