Skip to content

Commit

Permalink
refactor: Unified handling of annotated types in type engine
Browse files Browse the repository at this point in the history
Issue: #3424
Signed-off-by: Adrian Rumpold <[email protected]>
  • Loading branch information
AdrianoKF committed May 12, 2023
1 parent c7e866a commit 093536f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 32 deletions.
48 changes: 27 additions & 21 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
return self._to_literal_transformer(python_val)

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
if get_origin(expected_python_type) is Annotated:
expected_python_type = get_args(expected_python_type)[0]
expected_python_type = get_underlying_type(expected_python_type)

if expected_python_type != self._type:
raise TypeTransformerFailedError(
Expand Down Expand Up @@ -311,7 +310,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
Extracts the Literal type definition for a Dataclass and returns a type Struct.
If possible also extracts the JSONSchema for the dataclass.
"""
if get_origin(t) is Annotated:
if is_annotated(t):
raise ValueError(
"Flytekit does not currently have support for FlyteAnnotations applied to Dataclass."
f"Type {t} cannot be parsed."
Expand Down Expand Up @@ -368,7 +367,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
self._get_origin_type_in_annotation(get_args(python_type)[0]),
self._get_origin_type_in_annotation(get_args(python_type)[1]),
]
elif get_origin(python_type) is Annotated:
elif is_annotated(python_type):
return get_args(python_type)[0]
elif dataclasses.is_dataclass(python_type):
for field in dataclasses.fields(copy.deepcopy(python_type)):
Expand Down Expand Up @@ -734,8 +733,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
"""
cls.lazy_import_transformers()
# Step 1
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
python_type = get_underlying_type(python_type)

if python_type in cls._REGISTRY:
return cls._REGISTRY[python_type]
Expand All @@ -744,7 +742,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
if hasattr(python_type, "__origin__"):
# Handling of annotated generics, eg:
# Annotated[typing.List[int], 'foo']
if get_origin(python_type) is Annotated:
if is_annotated(python_type):
return cls.get_transformer(get_args(python_type)[0])

if python_type.__origin__ in cls._REGISTRY:
Expand Down Expand Up @@ -815,7 +813,7 @@ def to_literal_type(cls, python_type: Type) -> LiteralType:
transformer = cls.get_transformer(python_type)
res = transformer.get_literal_type(python_type)
data = None
if get_origin(python_type) is Annotated:
if is_annotated(python_type):
for x in get_args(python_type)[1:]:
if not isinstance(x, FlyteAnnotation):
continue
Expand Down Expand Up @@ -843,9 +841,9 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type

# In case the value is an annotated type we inspect the annotations and look for hash-related annotations.
hash = None
if get_origin(python_type) is Annotated:
if is_annotated(python_type):
# We are now dealing with one of two cases:
# 1. The annotated type is a `HashMethod`, which indicates that we should we should produce the hash using
# 1. The annotated type is a `HashMethod`, which indicates that we should produce the hash using
# the method indicated in the annotation.
# 2. The annotated type is being used for a different purpose other than calculating hash values, in which case
# we should just continue.
Expand All @@ -872,7 +870,7 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T
@classmethod
def to_html(cls, ctx: FlyteContext, python_val: typing.Any, expected_python_type: Type[typing.Any]) -> str:
transformer = cls.get_transformer(expected_python_type)
if get_origin(expected_python_type) is Annotated:
if is_annotated(expected_python_type):
expected_python_type, *annotate_args = get_args(expected_python_type)
from flytekit.deck.renderer import Renderable

Expand Down Expand Up @@ -996,7 +994,7 @@ def get_sub_type(t: Type[T]) -> Type[T]:
if hasattr(t, "__origin__"):
# Handle annotation on list generic, eg:
# Annotated[typing.List[int], 'foo']
if get_origin(t) is Annotated:
if is_annotated(t):
return ListTransformer.get_sub_type(get_args(t)[0])

if getattr(t, "__origin__") is list and hasattr(t, "__args__"):
Expand All @@ -1022,7 +1020,7 @@ def is_batchable(t: Type):
"""
from flytekit.types.pickle import FlytePickle

if get_origin(t) is Annotated:
if is_annotated(t):
return ListTransformer.is_batchable(get_args(t)[0])
if get_origin(t) is list:
subtype = get_args(t)[0]
Expand All @@ -1039,7 +1037,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp

batch_size = len(python_val) # default batch size
# parse annotated to get the number of items saved in a pickle file.
if get_origin(python_type) is Annotated:
if is_annotated(python_type):
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, BatchSize):
batch_size = annotation.val
Expand Down Expand Up @@ -1183,8 +1181,7 @@ def get_sub_type_in_optional(t: Type[T]) -> Type[T]:
return get_args(t)[0]

def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
if get_origin(t) is Annotated:
t = get_args(t)[0]
t = get_underlying_type(t)

try:
trans: typing.List[typing.Tuple[TypeTransformer, typing.Any]] = [
Expand All @@ -1198,8 +1195,7 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
raise ValueError(f"Type of Generic Union type is not supported, {e}")

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
python_type = get_underlying_type(python_type)

found_res = False
res = None
Expand All @@ -1224,8 +1220,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}")

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[typing.Any]:
if get_origin(expected_python_type) is Annotated:
expected_python_type = get_args(expected_python_type)[0]
expected_python_type = get_underlying_type(expected_python_type)

union_tag = None
union_type = None
Expand Down Expand Up @@ -1460,7 +1455,7 @@ def __init__(self):
super().__init__(name="DefaultEnumTransformer", t=enum.Enum)

def get_literal_type(self, t: Type[T]) -> LiteralType:
if get_origin(t) is Annotated:
if is_annotated(t):
raise ValueError(
f"Flytekit does not currently have support \
for FlyteAnnotations applied to enums. {t} cannot be \
Expand Down Expand Up @@ -1774,3 +1769,14 @@ def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any:


_register_default_type_transformers()


def is_annotated(t: Type) -> bool:
    """Tell whether ``t`` is an ``Annotated[...]`` type.

    :param t: The type to inspect.
    :return: ``True`` if the origin of ``t`` is ``Annotated``, ``False`` otherwise.
    """
    origin = get_origin(t)
    return origin is Annotated


def get_underlying_type(t: Type) -> Type:
    """Return the underlying type for annotated types or the type itself.

    :param t: A type, possibly wrapped as ``Annotated[X, ...]``.
    :return: ``X`` (the first type argument) when ``t`` is annotated, otherwise ``t`` unchanged.
    """
    # Guard clause: anything that is not Annotated passes straight through.
    if not is_annotated(t):
        return t
    # The first argument of Annotated is the wrapped type; the rest is metadata.
    underlying, *_metadata = get_args(t)
    return underlying
14 changes: 3 additions & 11 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

from dataclasses_json import config, dataclass_json
from marshmallow import fields
from typing_extensions import Annotated, Type, get_args, get_origin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type
from flytekit.loggers import logger
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
Expand All @@ -26,13 +25,6 @@ def noop():
T = typing.TypeVar("T")


def _get_origin_type(t: Type) -> Type:
"""Return the origin type for annotated types or the type itself, such that it can be used with ``issubclass()``"""
if get_origin(t) is Annotated:
return get_args(t)[0]
return t


@dataclass_json
@dataclass
class FlyteFile(os.PathLike, typing.Generic[T]):
Expand Down Expand Up @@ -344,7 +336,7 @@ def to_literal(
raise TypeTransformerFailedError("None value cannot be converted to a file.")

# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type
python_type = _get_origin_type(python_type)
python_type = get_underlying_type(python_type)

if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)):
raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike")
Expand Down Expand Up @@ -420,7 +412,7 @@ def to_python_value(
return FlyteFile(uri)

# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type
expected_python_type = _get_origin_type(expected_python_type)
expected_python_type = get_underlying_type(expected_python_type)

# The rest of the logic is only for FlyteFile types.
if not issubclass(expected_python_type, FlyteFile): # type: ignore
Expand Down
22 changes: 22 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
UnionTransformer,
convert_json_schema_to_python_class,
dataclass_from_dict,
get_underlying_type,
is_annotated,
)
from flytekit.exceptions import user as user_exceptions
from flytekit.models import types as model_types
Expand Down Expand Up @@ -1638,3 +1640,23 @@ def test_batch_pickle_list(python_val, python_type, expected_list_length):
# data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)]
# task1(data=data) # task1(data: typing.List[FlytePickle])
assert pv == python_val


@pytest.mark.parametrize(
    "t,expected",
    [
        # A plain builtin is not annotated.
        (list, False),
        # Annotated forms, with one or several metadata items.
        (Annotated[int, "tag"], True),
        (Annotated[list[str], "a", "b"], True),
        (Annotated[dict[int, str], FlyteAnnotation({"foo": "bar"})], True),
    ],
)
def test_is_annotated(t, expected):
    """Check that ``is_annotated`` gives the expected result for each parametrized type."""
    actual = is_annotated(t)
    assert actual == expected


@pytest.mark.parametrize(
    "t,expected",
    [
        # A type without annotations is returned as-is.
        (list, list),
        # Annotated types are reduced to their first type argument.
        (Annotated[int, "tag"], int),
        (Annotated[list[str], "a", "b"], list[str]),
    ],
)
def test_get_underlying_type(t, expected):
    """Check that ``get_underlying_type`` returns the expected type for each parametrized input."""
    actual = get_underlying_type(t)
    assert actual == expected

0 comments on commit 093536f

Please sign in to comment.