diff --git a/src/anthropic/_utils/_transform.py b/src/anthropic/_utils/_transform.py index d524b329..db40bff2 100644 --- a/src/anthropic/_utils/_transform.py +++ b/src/anthropic/_utils/_transform.py @@ -4,6 +4,8 @@ from datetime import date, datetime from typing_extensions import Literal, get_args, override, get_type_hints +import pydantic + from ._utils import ( is_list, is_mapping, @@ -14,7 +16,7 @@ is_annotated_type, strip_annotated_type, ) -from .._compat import is_typeddict +from .._compat import model_dump, is_typeddict _T = TypeVar("_T") @@ -165,6 +167,9 @@ def _transform_recursive( data = _transform_recursive(data, annotation=annotation, inner_type=subtype) return data + if isinstance(data, pydantic.BaseModel): + return model_dump(data, exclude_unset=True, exclude_defaults=True) + return _transform_value(data, annotation) diff --git a/tests/test_transform.py b/tests/test_transform.py index b7334957..8e1d4724 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import List, Union, Optional +from typing import Any, List, Union, Optional from datetime import date, datetime from typing_extensions import Required, Annotated, TypedDict +import pytest + from anthropic._utils import PropertyInfo, transform, parse_datetime +from anthropic._models import BaseModel class Foo1(TypedDict): @@ -186,3 +189,44 @@ class DateDictWithRequiredAlias(TypedDict, total=False): def test_datetime_with_alias() -> None: assert transform({"required_prop": None}, DateDictWithRequiredAlias) == {"prop": None} # type: ignore[comparison-overlap] assert transform({"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias) == {"prop": "2023-02-23"} # type: ignore[comparison-overlap] + + +class MyModel(BaseModel): + foo: str + + +def test_pydantic_model_to_dictionary() -> None: + assert transform(MyModel(foo="hi!"), Any) == {"foo": "hi!"} + assert transform(MyModel.construct(foo="hi!"), Any) == {"foo": "hi!"} + + +def test_pydantic_empty_model() -> None: + assert transform(MyModel.construct(), Any) == {} + + +def test_pydantic_unknown_field() -> None: + assert transform(MyModel.construct(my_untyped_field=True), Any) == {"my_untyped_field": True} + + +def test_pydantic_mismatched_types() -> None: + model = MyModel.construct(foo=True) + with pytest.warns(UserWarning): + params = transform(model, Any) + assert params == {"foo": True} + + +def test_pydantic_mismatched_object_type() -> None: + model = MyModel.construct(foo=MyModel.construct(hello="world")) + with pytest.warns(UserWarning): + params = transform(model, Any) + assert params == {"foo": {"hello": "world"}} + + +class ModelNestedObjects(BaseModel): + nested: MyModel + + +def test_pydantic_nested_objects() -> None: + model = ModelNestedObjects.construct(nested={"foo": "stainless"}) + assert isinstance(model.nested, MyModel) + assert transform(model, Any) == {"nested": {"foo": "stainless"}}