From a19043116f507d149c7ff7858ab97100717e51ee Mon Sep 17 00:00:00 2001 From: Adrian Rumpold Date: Tue, 21 Mar 2023 22:21:41 +0100 Subject: [PATCH] Make `FlyteFile` compatible with `Annotated[..., HashMethod]` (#1544) * fix: Make FlyteFile compatible with Annotated[..., HashMethod] See issue #3424 Signed-off-by: Adrian Rumpold * tests: Add test case for FlyteFile with HashMethod annotation Issue: #3424 Signed-off-by: Adrian Rumpold * fix: Use typing_extensions.Annotated for py3.8 compatibility Issue: #3424 Signed-off-by: Adrian Rumpold * fix: Use `get_args` and `get_origin` from typing_extensions for py3.8 compatibility Issue: #3424 Signed-off-by: Adrian Rumpold * fix(tests): Use fixture for local dummy file Signed-off-by: Adrian Rumpold --------- Signed-off-by: Adrian Rumpold --- flytekit/types/file/file.py | 5 +++++ tests/flytekit/unit/core/test_flyte_file.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 23f4137344..bb8feb3d9c 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -8,6 +8,7 @@ from dataclasses_json import config, dataclass_json from marshmallow import fields +from typing_extensions import Annotated, get_args, get_origin from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError @@ -335,6 +336,10 @@ def to_literal( if python_val is None: raise TypeTransformerFailedError("None value cannot be converted to a file.") + # Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type + if get_origin(python_type) is Annotated: + python_type = get_args(python_type)[0] + if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)): raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike") diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 1c1593ad4c..b7f0a1aeee 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -5,12 +5,14 @@ from unittest.mock import MagicMock import pytest +from typing_extensions import Annotated import flytekit.configuration from flytekit.configuration import Config, Image, ImageConfig from flytekit.core.context_manager import ExecutionState, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider, flyte_tmp_dir from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.hash import HashMethod from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine @@ -433,6 +435,21 @@ def wf(path: str) -> os.PathLike: assert flyte_tmp_dir in wf(path="s3://somewhere").path +def test_flyte_file_annotated_hashmethod(local_dummy_file): + def calc_hash(ff: FlyteFile) -> str: + return str(ff.path) + + @task + def t1(path: str) -> Annotated[FlyteFile, HashMethod(calc_hash)]: + return FlyteFile(path) + + @workflow + def wf(path: str) -> None: + t1(path=path) + + wf(path=local_dummy_file) + + @pytest.mark.sandbox_test def test_file_open_things(): @task