Skip to content

Commit

Permalink
Allow annotated FlyteFile as task input argument (#1632)
Browse files Browse the repository at this point in the history
* fix: Allow annotated FlyteFile as task input argument

Using an annotated FlyteFile type as an input to a task was previously impossible due
to an exception being raised in `FlyteFilePathTransformer.to_python_value`.

This commit applies the fix previously used in `FlyteFilePathTransformer.to_literal`
to permit using annotated FlyteFiles as either inputs or outputs of a task.

Issue: #3424
Signed-off-by: Adrian Rumpold <[email protected]>

* refactor: Unified handling of annotated types in type engine

Issue: #3424
Signed-off-by: Adrian Rumpold <[email protected]>

* fix: Use py3.8-compatible types in type engine tests

Issue: #3424
Signed-off-by: Adrian Rumpold <[email protected]>

---------

Signed-off-by: Adrian Rumpold <[email protected]>
  • Loading branch information
AdrianoKF authored May 19, 2023
1 parent 47be974 commit 06fffc7
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 27 deletions.
47 changes: 27 additions & 20 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
return self._to_literal_transformer(python_val)

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
if get_origin(expected_python_type) is Annotated:
expected_python_type = get_args(expected_python_type)[0]
expected_python_type = get_underlying_type(expected_python_type)

if expected_python_type != self._type:
raise TypeTransformerFailedError(
Expand Down Expand Up @@ -311,7 +310,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
Extracts the Literal type definition for a Dataclass and returns a type Struct.
If possible also extracts the JSONSchema for the dataclass.
"""
if get_origin(t) is Annotated:
if is_annotated(t):
raise ValueError(
"Flytekit does not currently have support for FlyteAnnotations applied to Dataclass."
f"Type {t} cannot be parsed."
Expand Down Expand Up @@ -368,7 +367,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
self._get_origin_type_in_annotation(get_args(python_type)[0]),
self._get_origin_type_in_annotation(get_args(python_type)[1]),
]
elif get_origin(python_type) is Annotated:
elif is_annotated(python_type):
return get_args(python_type)[0]
elif dataclasses.is_dataclass(python_type):
for field in dataclasses.fields(copy.deepcopy(python_type)):
Expand Down Expand Up @@ -737,7 +736,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
"""
cls.lazy_import_transformers()
# Step 1
if get_origin(python_type) is Annotated:
if is_annotated(python_type):
args = get_args(python_type)
for annotation in args:
if isinstance(annotation, TypeTransformer):
Expand All @@ -752,7 +751,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
if hasattr(python_type, "__origin__"):
# Handling of annotated generics, eg:
# Annotated[typing.List[int], 'foo']
if get_origin(python_type) is Annotated:
if is_annotated(python_type):
return cls.get_transformer(get_args(python_type)[0])

if python_type.__origin__ in cls._REGISTRY:
Expand Down Expand Up @@ -823,7 +822,7 @@ def to_literal_type(cls, python_type: Type) -> LiteralType:
transformer = cls.get_transformer(python_type)
res = transformer.get_literal_type(python_type)
data = None
if get_origin(python_type) is Annotated:
if is_annotated(python_type):
for x in get_args(python_type)[1:]:
if not isinstance(x, FlyteAnnotation):
continue
Expand Down Expand Up @@ -851,9 +850,9 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type

# In case the value is an annotated type we inspect the annotations and look for hash-related annotations.
hash = None
if get_origin(python_type) is Annotated:
if is_annotated(python_type):
# We are now dealing with one of two cases:
# 1. The annotated type is a `HashMethod`, which indicates that we should we should produce the hash using
# 1. The annotated type is a `HashMethod`, which indicates that we should produce the hash using
# the method indicated in the annotation.
# 2. The annotated type is being used for a different purpose other than calculating hash values, in which case
# we should just continue.
Expand All @@ -880,7 +879,7 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T
@classmethod
def to_html(cls, ctx: FlyteContext, python_val: typing.Any, expected_python_type: Type[typing.Any]) -> str:
transformer = cls.get_transformer(expected_python_type)
if get_origin(expected_python_type) is Annotated:
if is_annotated(expected_python_type):
expected_python_type, *annotate_args = get_args(expected_python_type)
from flytekit.deck.renderer import Renderable

Expand Down Expand Up @@ -1004,7 +1003,7 @@ def get_sub_type(t: Type[T]) -> Type[T]:
if hasattr(t, "__origin__"):
# Handle annotation on list generic, eg:
# Annotated[typing.List[int], 'foo']
if get_origin(t) is Annotated:
if is_annotated(t):
return ListTransformer.get_sub_type(get_args(t)[0])

if getattr(t, "__origin__") is list and hasattr(t, "__args__"):
Expand All @@ -1030,7 +1029,7 @@ def is_batchable(t: Type):
"""
from flytekit.types.pickle import FlytePickle

if get_origin(t) is Annotated:
if is_annotated(t):
return ListTransformer.is_batchable(get_args(t)[0])
if get_origin(t) is list:
subtype = get_args(t)[0]
Expand All @@ -1047,7 +1046,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp

batch_size = len(python_val) # default batch size
# parse annotated to get the number of items saved in a pickle file.
if get_origin(python_type) is Annotated:
if is_annotated(python_type):
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, BatchSize):
batch_size = annotation.val
Expand Down Expand Up @@ -1191,8 +1190,7 @@ def get_sub_type_in_optional(t: Type[T]) -> Type[T]:
return get_args(t)[0]

def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
if get_origin(t) is Annotated:
t = get_args(t)[0]
t = get_underlying_type(t)

try:
trans: typing.List[typing.Tuple[TypeTransformer, typing.Any]] = [
Expand All @@ -1206,8 +1204,7 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
raise ValueError(f"Type of Generic Union type is not supported, {e}")

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
python_type = get_underlying_type(python_type)

found_res = False
res = None
Expand All @@ -1232,8 +1229,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}")

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[typing.Any]:
if get_origin(expected_python_type) is Annotated:
expected_python_type = get_args(expected_python_type)[0]
expected_python_type = get_underlying_type(expected_python_type)

union_tag = None
union_type = None
Expand Down Expand Up @@ -1468,7 +1464,7 @@ def __init__(self):
super().__init__(name="DefaultEnumTransformer", t=enum.Enum)

def get_literal_type(self, t: Type[T]) -> LiteralType:
if get_origin(t) is Annotated:
if is_annotated(t):
raise ValueError(
f"Flytekit does not currently have support \
for FlyteAnnotations applied to enums. {t} cannot be \
Expand Down Expand Up @@ -1782,3 +1778,14 @@ def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any:


_register_default_type_transformers()


def is_annotated(t: Type) -> bool:
    """Return True if ``t`` is an ``Annotated[...]`` type hint, False otherwise.

    Only the outermost origin of ``t`` is inspected, so a plain generic such as
    ``typing.List[int]`` or a bare class such as ``int`` yields False.
    """
    return get_origin(t) is Annotated


def get_underlying_type(t: Type) -> Type:
    """Return the underlying type for annotated types or the type itself.

    For ``Annotated[X, ...]`` the result is ``X`` (the first type argument);
    any type that is not annotated is returned unchanged.
    """
    if is_annotated(t):
        # The first argument of Annotated is the wrapped type; the remaining
        # arguments are metadata and are discarded here.
        return get_args(t)[0]
    return t
9 changes: 5 additions & 4 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

from dataclasses_json import config, dataclass_json
from marshmallow import fields
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type
from flytekit.loggers import logger
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
Expand Down Expand Up @@ -337,8 +336,7 @@ def to_literal(
raise TypeTransformerFailedError("None value cannot be converted to a file.")

# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
python_type = get_underlying_type(python_type)

if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)):
raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike")
Expand Down Expand Up @@ -413,6 +411,9 @@ def to_python_value(
if expected_python_type is os.PathLike:
return FlyteFile(uri)

# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type
expected_python_type = get_underlying_type(expected_python_type)

# The rest of the logic is only for FlyteFile types.
if not issubclass(expected_python_type, FlyteFile): # type: ignore
raise TypeError(f"Neither os.PathLike nor FlyteFile specified {expected_python_type}")
Expand Down
13 changes: 10 additions & 3 deletions tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,13 +439,20 @@ def test_flyte_file_annotated_hashmethod(local_dummy_file):
def calc_hash(ff: FlyteFile) -> str:
return str(ff.path)

HashedFlyteFile = Annotated[FlyteFile, HashMethod(calc_hash)]

@task
def t1(path: str) -> Annotated[FlyteFile, HashMethod(calc_hash)]:
return FlyteFile(path)
def t1(path: str) -> HashedFlyteFile:
return HashedFlyteFile(path)

@task
def t2(ff: HashedFlyteFile) -> None:
print(ff.path)

@workflow
def wf(path: str) -> None:
t1(path=path)
ff = t1(path=path)
t2(ff=ff)

wf(path=local_dummy_file)

Expand Down
27 changes: 27 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
UnionTransformer,
convert_json_schema_to_python_class,
dataclass_from_dict,
get_underlying_type,
is_annotated,
)
from flytekit.exceptions import user as user_exceptions
from flytekit.models import types as model_types
Expand Down Expand Up @@ -1685,3 +1687,28 @@ def test_batch_pickle_list(python_val, python_type, expected_list_length):
# data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)]
# task1(data=data) # task1(data: typing.List[FlytePickle])
assert pv == python_val


# Verify that is_annotated() returns True for Annotated[...] hints, whatever
# the wrapped type or attached metadata, and False for a plain type.
@pytest.mark.parametrize(
"t,expected",
[
(list, False),
(Annotated[int, "tag"], True),
(Annotated[typing.List[str], "a", "b"], True),
(Annotated[typing.Dict[int, str], FlyteAnnotation({"foo": "bar"})], True),
],
)
def test_is_annotated(t, expected):
assert is_annotated(t) == expected


# Verify that get_underlying_type() unwraps Annotated[...] hints to their first
# type argument and returns a non-annotated type unchanged.
@pytest.mark.parametrize(
"t,expected",
[
(typing.List, typing.List),
(Annotated[int, "tag"], int),
(Annotated[typing.List[str], "a", "b"], typing.List[str]),
],
)
def test_get_underlying_type(t, expected):
assert get_underlying_type(t) == expected

0 comments on commit 06fffc7

Please sign in to comment.