diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 40b39eae90..6ddeb5c58c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -361,6 +361,12 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A from flytekit.types.schema.types import FlyteSchema from flytekit.types.structured.structured_dataset import StructuredDataset + # Handle Optional + if get_origin(python_type) is typing.Union and type(None) in get_args(python_type): + if python_val is None: + return None + return self._serialize_flyte_type(python_val, get_args(python_type)[0]) + if hasattr(python_type, "__origin__") and python_type.__origin__ is list: return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val] @@ -400,12 +406,18 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A python_val.__setattr__(v.name, self._serialize_flyte_type(val, field_type)) return python_val - def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T: + def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> Optional[T]: from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine + # Handle Optional + if get_origin(expected_python_type) is typing.Union and type(None) in get_args(expected_python_type): + if python_val is None: + return None + return self._deserialize_flyte_type(python_val, get_args(expected_python_type)[0]) + if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is list: return [self._deserialize_flyte_type(v, expected_python_type.__args__[0]) for v in python_val] # type: ignore diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 3e813c0fb7..bbe46845fd 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -6,6 +6,7 @@ from datetime import timedelta from enum import Enum +import mock import pandas as pd import pyarrow as pa import pytest @@ -569,6 +570,90 @@ def test_dataclass_int_preserving(): assert ot == o +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +def test_optional_flytefile_in_dataclass(mock_upload_dir): + mock_upload_dir.return_value = True + + @dataclass_json + @dataclass + class A(object): + a: int + + @dataclass_json + @dataclass + class TestFileStruct(object): + a: FlyteFile + b: typing.Optional[FlyteFile] + b_prime: typing.Optional[FlyteFile] + c: typing.Union[FlyteFile, None] + d: typing.List[FlyteFile] + e: typing.List[typing.Optional[FlyteFile]] + e_prime: typing.List[typing.Optional[FlyteFile]] + f: typing.Dict[str, FlyteFile] + g: typing.Dict[str, typing.Optional[FlyteFile]] + g_prime: typing.Dict[str, typing.Optional[FlyteFile]] + h: typing.Optional[FlyteFile] = None + h_prime: typing.Optional[FlyteFile] = None + i: typing.Optional[A] = None + i_prime: typing.Optional[A] = A(a=99) + + remote_path = "s3://tmp/file" + with tempfile.TemporaryFile() as f: + f.write(b"abc") + f1 = FlyteFile("f1", remote_path=remote_path) + o = TestFileStruct( + a=f1, + b=f1, + b_prime=None, + c=f1, + d=[f1], + e=[f1], + e_prime=[None], + f={"a": f1}, + g={"a": f1}, + g_prime={"a": None}, + h=f1, + i=A(a=42), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct) + lv = tf.to_literal(ctx, o, TestFileStruct, lt) + + assert lv.scalar.generic["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["b"].fields["path"].string_value == remote_path + assert lv.scalar.generic["b_prime"] is None + assert lv.scalar.generic["c"].fields["path"].string_value == remote_path + assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path + assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path + assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value" + assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["g_prime"]["a"] is None + assert lv.scalar.generic["h"].fields["path"].string_value == remote_path + assert lv.scalar.generic["h_prime"] is None + assert lv.scalar.generic["i"].fields["a"].number_value == 42 + assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99 + + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct) + + assert o.a.path == ot.a.remote_source + assert o.b.path == ot.b.remote_source + assert ot.b_prime is None + assert o.c.path == ot.c.remote_source + assert o.d[0].path == ot.d[0].remote_source + assert o.e[0].path == ot.e[0].remote_source + assert o.e_prime == [None] + assert o.f["a"].path == ot.f["a"].remote_source + assert o.g["a"].path == ot.g["a"].remote_source + assert o.g_prime == {"a": None} + assert o.h.path == ot.h.remote_source + assert ot.h_prime is None + assert o.i == ot.i + assert o.i_prime == A(a=99) + + def test_flyte_file_in_dataclass(): @dataclass_json @dataclass