From 3671aa6b3b05776b727a727020366bb6c349f66a Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Wed, 17 Jan 2024 09:14:09 -0500 Subject: [PATCH] chore(internal): fix typing util function (#310) --- src/anthropic/_utils/_typing.py | 31 ++++++++++++- tests/test_utils/test_typing.py | 78 +++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 2 deletions(-) create mode 100644 tests/test_utils/test_typing.py diff --git a/src/anthropic/_utils/_typing.py b/src/anthropic/_utils/_typing.py index b5e2c2e3..a020822b 100644 --- a/src/anthropic/_utils/_typing.py +++ b/src/anthropic/_utils/_typing.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, cast +from typing import Any, TypeVar, cast from typing_extensions import Required, Annotated, get_args, get_origin from .._types import InheritsGeneric @@ -23,6 +23,12 @@ def is_required_type(typ: type) -> bool: return get_origin(typ) == Required +def is_typevar(typ: type) -> bool: + # type ignore is required because type checkers + # think this expression will always return False + return type(typ) == TypeVar # type: ignore + + # Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] def strip_annotated_type(typ: type) -> type: if is_required_type(typ) or is_annotated_type(typ): @@ -49,6 +55,15 @@ class MyResponse(Foo[bytes]): extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes ``` + + And where a generic subclass is given: + ```py + _T = TypeVar('_T') + class MyResponse(Foo[_T]): + ... + + extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes + ``` """ cls = cast(object, get_origin(typ) or typ) if cls in generic_bases: @@ -75,6 +90,18 @@ class MyResponse(Foo[bytes]): f"Does {cls} inherit from one of {generic_bases} ?" ) - return extract_type_arg(target_base_class, index) + extracted = extract_type_arg(target_base_class, index) + if is_typevar(extracted): + # If the extracted type argument is itself a type variable + # then that means the subclass itself is generic, so we have + # to resolve the type argument from the class itself, not + # the base class. + # + # Note: if there is more than 1 type argument, the subclass could + # change the ordering of the type arguments, this is not currently + # supported. + return extract_type_arg(typ, index) + + return extracted raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}") diff --git a/tests/test_utils/test_typing.py b/tests/test_utils/test_typing.py new file mode 100644 index 00000000..730a06b6 --- /dev/null +++ b/tests/test_utils/test_typing.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Generic, TypeVar, cast + +from anthropic._utils import extract_type_var_from_base + +_T = TypeVar("_T") +_T2 = TypeVar("_T2") +_T3 = TypeVar("_T3") + + +class BaseGeneric(Generic[_T]): + ... + + +class SubclassGeneric(BaseGeneric[_T]): + ... + + +class BaseGenericMultipleTypeArgs(Generic[_T, _T2, _T3]): + ... + + +class SubclassGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T, _T2, _T3]): + ... + + +class SubclassDifferentOrderGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T2, _T, _T3]): + ... + + +def test_extract_type_var() -> None: + assert ( + extract_type_var_from_base( + BaseGeneric[int], + index=0, + generic_bases=cast("tuple[type, ...]", (BaseGeneric,)), + ) + == int + ) + + +def test_extract_type_var_generic_subclass() -> None: + assert ( + extract_type_var_from_base( + SubclassGeneric[int], + index=0, + generic_bases=cast("tuple[type, ...]", (BaseGeneric,)), + ) + == int + ) + + +def test_extract_type_var_multiple() -> None: + typ = BaseGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) + + +def test_extract_type_var_generic_subclass_multiple() -> None: + typ = SubclassGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) + + +def test_extract_type_var_generic_subclass_different_ordering_multiple() -> None: + typ = SubclassDifferentOrderGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None)