Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(internal): fix typing util function #310

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions src/anthropic/_utils/_typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, cast
from typing import Any, TypeVar, cast
from typing_extensions import Required, Annotated, get_args, get_origin

from .._types import InheritsGeneric
Expand All @@ -23,6 +23,12 @@ def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required


def is_typevar(typ: type) -> bool:
# type ignore is required because type checkers
# think this expression will always return False
return type(typ) == TypeVar # type: ignore


# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
Expand All @@ -49,6 +55,15 @@ class MyResponse(Foo[bytes]):

extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
```

And where a generic subclass is given:
```py
_T = TypeVar('_T')
class MyResponse(Foo[_T]):
...

extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
```
"""
cls = cast(object, get_origin(typ) or typ)
if cls in generic_bases:
Expand All @@ -75,6 +90,18 @@ class MyResponse(Foo[bytes]):
f"Does {cls} inherit from one of {generic_bases} ?"
)

return extract_type_arg(target_base_class, index)
extracted = extract_type_arg(target_base_class, index)
if is_typevar(extracted):
# If the extracted type argument is itself a type variable
# then that means the subclass itself is generic, so we have
# to resolve the type argument from the class itself, not
# the base class.
#
# Note: if there is more than 1 type argument, the subclass could
# change the ordering of the type arguments, this is not currently
# supported.
return extract_type_arg(typ, index)

return extracted

raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}")
78 changes: 78 additions & 0 deletions tests/test_utils/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

from typing import Generic, TypeVar, cast

from anthropic._utils import extract_type_var_from_base

_T = TypeVar("_T")
_T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")


class BaseGeneric(Generic[_T]):
...


class SubclassGeneric(BaseGeneric[_T]):
...


class BaseGenericMultipleTypeArgs(Generic[_T, _T2, _T3]):
...


class SubclassGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T, _T2, _T3]):
...


class SubclassDifferentOrderGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T2, _T, _T3]):
...


def test_extract_type_var() -> None:
assert (
extract_type_var_from_base(
BaseGeneric[int],
index=0,
generic_bases=cast("tuple[type, ...]", (BaseGeneric,)),
)
== int
)


def test_extract_type_var_generic_subclass() -> None:
assert (
extract_type_var_from_base(
SubclassGeneric[int],
index=0,
generic_bases=cast("tuple[type, ...]", (BaseGeneric,)),
)
== int
)


def test_extract_type_var_multiple() -> None:
typ = BaseGenericMultipleTypeArgs[int, str, None]

generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,))
assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int
assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str
assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None)


def test_extract_type_var_generic_subclass_multiple() -> None:
typ = SubclassGenericMultipleTypeArgs[int, str, None]

generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,))
assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int
assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str
assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None)


def test_extract_type_var_generic_subclass_different_ordering_multiple() -> None:
typ = SubclassDifferentOrderGenericMultipleTypeArgs[int, str, None]

generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,))
assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int
assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str
assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None)