Skip to content

Commit

Permalink
Simplify some tests to reduce missing code coverage, and move utils a…
Browse files Browse the repository at this point in the history
…round.
  • Loading branch information
mattalbr committed Aug 10, 2023
1 parent 0af7d24 commit aab8a82
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 30 deletions.
11 changes: 4 additions & 7 deletions strawberry/private.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import TypeVar
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated

from strawberry.utils.typing import type_has_annotation


class StrawberryPrivate:
Expand All @@ -22,9 +24,4 @@ class StrawberryPrivate:


def is_private(type_: object) -> bool:
if get_origin(type_) is Annotated:
return any(
isinstance(argument, StrawberryPrivate) for argument in get_args(type_)
)

return False
return type_has_annotation(type_, StrawberryPrivate)
9 changes: 1 addition & 8 deletions strawberry/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Union,
overload,
)
from typing_extensions import Annotated, Literal, Protocol, Self, get_args, get_origin
from typing_extensions import Literal, Protocol, Self

from strawberry.utils.typing import is_concrete_generic

Expand Down Expand Up @@ -229,10 +229,3 @@ def get_object_definition(
if strict and definition is None:
raise TypeError(f"{obj!r} does not have a StrawberryObjectDefinition")
return definition


def is_annotated_type(type_: object, annotated: Type) -> bool:
if get_origin(type_) is Annotated:
return any(isinstance(argument, annotated) for argument in get_args(type_))

return False
5 changes: 3 additions & 2 deletions strawberry/types/fields/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
MissingArgumentsAnnotationsError,
)
from strawberry.parent import StrawberryParent
from strawberry.type import StrawberryType, has_object_definition, is_annotated_type
from strawberry.type import StrawberryType, has_object_definition
from strawberry.types.info import Info
from strawberry.utils.cached_property import cached_property
from strawberry.utils.typing import type_has_annotation

if TYPE_CHECKING:
import builtins
Expand Down Expand Up @@ -140,7 +141,7 @@ def is_reserved_type(self, other: builtins.type) -> bool:
origin = cast(type, get_origin(other)) or other
if origin is Annotated:
# Handle annotated arguments such as Private[str] and DirectiveValue[str]
return is_annotated_type(other, self.type)
return type_has_annotation(other, self.type)
else:
# Handle both concrete and generic types (i.e Info, and Info[Any, Any])
return (
Expand Down
8 changes: 8 additions & 0 deletions strawberry/utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@ def is_classvar(cls: type, annotation: Union[ForwardRef, str]) -> bool:
)


def type_has_annotation(type_: object, annotation: Type) -> bool:
"""Returns True if the type_ has been annotated with annotation."""
if get_origin(type_) is Annotated:
return any(isinstance(argument, annotation) for argument in get_args(type_))

return False


def get_parameters(annotation: Type) -> Union[Tuple[object], Tuple[()]]:
if (
isinstance(annotation, _GenericAlias)
Expand Down
17 changes: 4 additions & 13 deletions tests/schema/test_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,12 +593,12 @@ def user(self, user_id: str) -> User:
assert result.data["user"]["name"] == "User 🍓"


def multiple_parents(user: Parent[UserLiteral], user2: Parent[UserLiteral]) -> str:
return f"User {user.id}"
def multiple_parents(user: Parent[Any], user2: Parent[Any]) -> str:
raise AssertionError("Unreachable code.")


def multiple_infos(root, info1: Info, info2: Info) -> str:
return f"User {root.id}"
raise AssertionError("Unreachable code.")


@pytest.mark.parametrize(
Expand All @@ -609,18 +609,9 @@ def multiple_infos(root, info1: Info, info2: Info) -> str:
),
)
def test_multiple_conflicting_reserved_arguments(resolver):
@strawberry.type
class User:
id: str

name: str = strawberry.field(resolver=resolver)

@strawberry.type
class Query:
@strawberry.field
@staticmethod
def user(self, user_id: str) -> User:
return UserLiteral(user_id)
name: str = strawberry.field(resolver=resolver)

# Would be awesome to give a more helpful error here, but c'est la vie.
with pytest.raises(TypeError):
Expand Down

0 comments on commit aab8a82

Please sign in to comment.