diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 23f4137344..bb8feb3d9c 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -8,6 +8,7 @@ from dataclasses_json import config, dataclass_json from marshmallow import fields +from typing_extensions import Annotated, get_args, get_origin from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError @@ -335,6 +336,10 @@ def to_literal( if python_val is None: raise TypeTransformerFailedError("None value cannot be converted to a file.") + # Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type + if get_origin(python_type) is Annotated: + python_type = get_args(python_type)[0] + if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)): raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike") diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 1c1593ad4c..b7f0a1aeee 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -5,12 +5,14 @@ from unittest.mock import MagicMock import pytest +from typing_extensions import Annotated import flytekit.configuration from flytekit.configuration import Config, Image, ImageConfig from flytekit.core.context_manager import ExecutionState, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider, flyte_tmp_dir from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.hash import HashMethod from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine @@ -433,6 +435,21 @@ def wf(path: str) -> os.PathLike: assert flyte_tmp_dir in wf(path="s3://somewhere").path +def test_flyte_file_annotated_hashmethod(local_dummy_file): + def calc_hash(ff: FlyteFile) -> str: + return str(ff.path) + + @task + def t1(path: str) -> Annotated[FlyteFile, HashMethod(calc_hash)]: + return FlyteFile(path) + + @workflow + def wf(path: str) -> None: + t1(path=path) + + wf(path=local_dummy_file) + + @pytest.mark.sandbox_test def test_file_open_things(): @task