diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 40b39eae90..6ddeb5c58c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -361,6 +361,12 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A from flytekit.types.schema.types import FlyteSchema from flytekit.types.structured.structured_dataset import StructuredDataset + # Handle Optional + if get_origin(python_type) is typing.Union and type(None) in get_args(python_type): + if python_val is None: + return None + return self._serialize_flyte_type(python_val, get_args(python_type)[0]) + if hasattr(python_type, "__origin__") and python_type.__origin__ is list: return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val] @@ -400,12 +406,18 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A python_val.__setattr__(v.name, self._serialize_flyte_type(val, field_type)) return python_val - def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> T: + def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> Optional[T]: from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine + # Handle Optional + if get_origin(expected_python_type) is typing.Union and type(None) in get_args(expected_python_type): + if python_val is None: + return None + return self._deserialize_flyte_type(python_val, get_args(expected_python_type)[0]) + if hasattr(expected_python_type, "__origin__") and expected_python_type.__origin__ is list: return [self._deserialize_flyte_type(v, expected_python_type.__args__[0]) for v in python_val] # type: ignore diff --git a/plugins/flytekit-deck-standard/requirements.txt b/plugins/flytekit-deck-standard/requirements.txt index 6270d54a81..4efdce8486 100644 --- a/plugins/flytekit-deck-standard/requirements.txt +++ b/plugins/flytekit-deck-standard/requirements.txt @@ -6,10 +6,18 @@ # -e file:.#egg=flytekitplugins-deck-standard # via -r requirements.in +appnope==0.1.3 + # via + # ipykernel + # ipython arrow==1.2.3 # via jinja2-time +asttokens==2.2.1 + # via stack-data attrs==22.1.0 # via visions +backcall==0.2.0 + # via ipython binaryornot==0.4.4 # via cookiecutter certifi==2022.12.7 @@ -26,6 +34,8 @@ click==8.1.3 # flytekit cloudpickle==2.2.0 # via flytekit +comm==0.1.2 + # via ipykernel contourpy==1.0.6 # via matplotlib cookiecutter==2.1.1 @@ -33,15 +43,17 @@ cookiecutter==2.1.1 croniter==1.3.8 # via flytekit cryptography==38.0.4 - # via - # pyopenssl - # secretstorage + # via pyopenssl cycler==0.11.0 # via matplotlib dataclasses-json==0.5.7 # via flytekit +debugpy==1.6.4 + # via ipykernel decorator==5.1.1 - # via retry + # via + # ipython + # retry deprecated==1.2.13 # via flytekit diskcache==5.4.0 @@ -52,6 +64,10 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit +entrypoints==0.4 + # via jupyter-client +executing==1.2.0 + # via stack-data flyteidl==1.3.0 # via flytekit flytekit==1.3.0b2 @@ -79,12 +95,18 @@ importlib-metadata==5.1.0 # flytekit # keyring # markdown +ipykernel==6.19.4 + # via ipywidgets +ipython==8.7.0 + # via + # ipykernel + # ipywidgets +ipywidgets==8.0.4 + # via flytekitplugins-deck-standard jaraco-classes==3.2.3 # via keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage +jedi==0.18.2 + # via ipython jinja2==3.1.2 # via # cookiecutter @@ -96,6 +118,12 @@ joblib==1.2.0 # via # flytekit # phik +jupyter-client==7.4.8 + # via ipykernel +jupyter-core==5.1.1 + # via jupyter-client +jupyterlab-widgets==3.0.5 + # via ipywidgets keyring==23.11.0 # via flytekit kiwisolver==1.4.4 @@ -118,6 +146,10 @@ matplotlib==3.6.2 # pandas-profiling # phik # seaborn +matplotlib-inline==0.1.6 + # via + # ipykernel + # ipython more-itertools==9.0.0 # via jaraco-classes multimethod==1.9 @@ -128,6 +160,10 @@ mypy-extensions==0.4.3 # via typing-inspect natsort==8.2.0 # via flytekit +nest-asyncio==1.5.6 + # via + # ipykernel + # jupyter-client networkx==2.8.8 # via visions numpy==1.23.5 @@ -148,6 +184,7 @@ numpy==1.23.5 packaging==22.0 # via # docker + # ipykernel # marshmallow # matplotlib # statsmodels @@ -161,17 +198,27 @@ pandas==1.5.2 # visions pandas-profiling==3.5.0 # via flytekitplugins-deck-standard +parso==0.8.3 + # via jedi patsy==0.5.3 # via statsmodels +pexpect==4.8.0 + # via ipython phik==0.12.3 # via pandas-profiling +pickleshare==0.7.5 + # via ipython pillow==9.3.0 # via # imagehash # matplotlib # visions +platformdirs==2.6.0 + # via jupyter-core plotly==5.11.0 # via flytekitplugins-deck-standard +prompt-toolkit==3.0.36 + # via ipython protobuf==4.21.11 # via # flyteidl @@ -180,6 +227,12 @@ protobuf==4.21.11 # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl +psutil==5.9.4 + # via ipykernel +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.2 + # via stack-data py==1.11.0 # via retry pyarrow==10.0.1 @@ -188,6 +241,8 @@ pycparser==2.21 # via cffi pydantic==1.10.2 # via pandas-profiling +pygments==2.13.0 + # via ipython pyopenssl==22.1.0 # via flytekit pyparsing==3.0.9 @@ -197,6 +252,7 @@ python-dateutil==2.8.2 # arrow # croniter # flytekit + # jupyter-client # matplotlib # pandas python-json-logger==2.0.4 @@ -216,6 +272,10 @@ pyyaml==6.0 # cookiecutter # flytekit # pandas-profiling +pyzmq==24.0.1 + # via + # ipykernel + # jupyter-client regex==2022.10.31 # via docker-image-py requests==2.28.1 @@ -237,14 +297,15 @@ scipy==1.9.3 # statsmodels seaborn==0.12.1 # via pandas-profiling -secretstorage==3.3.3 - # via keyring six==1.16.0 # via + # asttokens # patsy # python-dateutil sortedcontainers==2.4.0 # via flytekit +stack-data==0.6.2 + # via ipython statsd==3.3.0 # via flytekit statsmodels==0.13.5 @@ -257,8 +318,21 @@ text-unidecode==1.3 # via python-slugify toml==0.10.2 # via responses +tornado==6.2 + # via + # ipykernel + # jupyter-client tqdm==4.64.1 # via pandas-profiling +traitlets==5.8.0 + # via + # comm + # ipykernel + # ipython + # ipywidgets + # jupyter-client + # jupyter-core + # matplotlib-inline typeguard==2.13.3 # via pandas-profiling types-toml==0.10.8.1 @@ -278,10 +352,14 @@ urllib3==1.26.13 # responses visions[type_image_path]==0.7.5 # via pandas-profiling +wcwidth==0.2.5 + # via prompt-toolkit websocket-client==1.4.2 # via docker wheel==0.38.4 # via flytekit +widgetsnbextension==4.0.5 + # via ipywidgets wrapt==1.14.1 # via # deprecated diff --git a/plugins/flytekit-deck-standard/setup.py b/plugins/flytekit-deck-standard/setup.py index a2e087eb20..82ac6788dd 100644 --- a/plugins/flytekit-deck-standard/setup.py +++ b/plugins/flytekit-deck-standard/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}-standard" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "markdown", "plotly", "pandas_profiling"] +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "markdown", "plotly", "pandas_profiling", "ipywidgets"] __version__ = "0.0.0+develop" diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 3e813c0fb7..bbe46845fd 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -6,6 +6,7 @@ from datetime import timedelta from enum import Enum +import mock import pandas as pd import pyarrow as pa import pytest @@ -569,6 +570,90 @@ def test_dataclass_int_preserving(): assert ot == o +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +def test_optional_flytefile_in_dataclass(mock_upload_dir): + mock_upload_dir.return_value = True + + @dataclass_json + @dataclass + class A(object): + a: int + + @dataclass_json + @dataclass + class TestFileStruct(object): + a: FlyteFile + b: typing.Optional[FlyteFile] + b_prime: typing.Optional[FlyteFile] + c: typing.Union[FlyteFile, None] + d: typing.List[FlyteFile] + e: typing.List[typing.Optional[FlyteFile]] + e_prime: typing.List[typing.Optional[FlyteFile]] + f: typing.Dict[str, FlyteFile] + g: typing.Dict[str, typing.Optional[FlyteFile]] + g_prime: typing.Dict[str, typing.Optional[FlyteFile]] + h: typing.Optional[FlyteFile] = None + h_prime: typing.Optional[FlyteFile] = None + i: typing.Optional[A] = None + i_prime: typing.Optional[A] = A(a=99) + + remote_path = "s3://tmp/file" + with tempfile.TemporaryFile() as f: + f.write(b"abc") + f1 = FlyteFile("f1", remote_path=remote_path) + o = TestFileStruct( + a=f1, + b=f1, + b_prime=None, + c=f1, + d=[f1], + e=[f1], + e_prime=[None], + f={"a": f1}, + g={"a": f1}, + g_prime={"a": None}, + h=f1, + i=A(a=42), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct) + lv = tf.to_literal(ctx, o, TestFileStruct, lt) + + assert lv.scalar.generic["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["b"].fields["path"].string_value == remote_path + assert lv.scalar.generic["b_prime"] is None + assert lv.scalar.generic["c"].fields["path"].string_value == remote_path + assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path + assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path + assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value" + assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["g_prime"]["a"] is None + assert lv.scalar.generic["h"].fields["path"].string_value == remote_path + assert lv.scalar.generic["h_prime"] is None + assert lv.scalar.generic["i"].fields["a"].number_value == 42 + assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99 + + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct) + + assert o.a.path == ot.a.remote_source + assert o.b.path == ot.b.remote_source + assert ot.b_prime is None + assert o.c.path == ot.c.remote_source + assert o.d[0].path == ot.d[0].remote_source + assert o.e[0].path == ot.e[0].remote_source + assert o.e_prime == [None] + assert o.f["a"].path == ot.f["a"].remote_source + assert o.g["a"].path == ot.g["a"].remote_source + assert o.g_prime == {"a": None} + assert o.h.path == ot.h.remote_source + assert ot.h_prime is None + assert o.i == ot.i + assert o.i_prime == A(a=99) + + def test_flyte_file_in_dataclass(): @dataclass_json @dataclass