Skip to content

Commit

Permalink
Pickle in Union Type (#1147)
Browse files Browse the repository at this point in the history
* Pickel in Union type

Signed-off-by: Kevin Su <[email protected]>

* Pickel in Union type

Signed-off-by: Kevin Su <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* update tests

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* Address comment

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Sep 8, 2022
1 parent eced84b commit 61bd02a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 14 deletions.
20 changes: 14 additions & 6 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import OrderedDict
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union

from typing_extensions import get_args, get_origin, get_type_hints
from typing_extensions import Annotated, get_args, get_origin, get_type_hints

from flytekit.core import context_manager
from flytekit.core.docstring import Docstring
Expand Down Expand Up @@ -259,12 +259,16 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface:
def _change_unrecognized_type_to_pickle(t: Type[T]) -> Type[T]:
try:
if hasattr(t, "__origin__") and hasattr(t, "__args__"):
if t.__origin__ == list:
if get_origin(t) is list:
return typing.List[_change_unrecognized_type_to_pickle(t.__args__[0])]
elif t.__origin__ == dict and t.__args__[0] == str:
elif get_origin(t) is dict and t.__args__[0] == str:
return typing.Dict[str, _change_unrecognized_type_to_pickle(t.__args__[1])]
else:
TypeEngine.get_transformer(t)
elif get_origin(t) is typing.Union:
return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))]
elif get_origin(t) is Annotated:
base_type, *config = get_args(t)
return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)]
TypeEngine.get_transformer(t)
except ValueError:
logger.warning(
f"Unsupported Type {t} found, Flyte will default to use PickleFile as the transport. "
Expand Down Expand Up @@ -329,7 +333,11 @@ def transform_variable_map(
elif v.__origin__ is dict:
sub_type = v.__args__[1]
if hasattr(sub_type, "__origin__") and sub_type.__origin__ is FlytePickle:
res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__}
if hasattr(sub_type.python_type(), "__name__"):
res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__}
elif hasattr(sub_type.python_type(), "_name"):
# If the class doesn't have the __name__ attribute, like typing.Sequence, use _name instead.
res[k].type.metadata = {"python_class_name": sub_type.python_type()._name}

return res

Expand Down
19 changes: 18 additions & 1 deletion tests/flytekit/unit/core/test_flyte_pickle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from collections import OrderedDict
from typing import Dict, List
from collections.abc import Sequence
from typing import Dict, List, Union

import numpy as np
import pandas as pd
from typing_extensions import Annotated

import flytekit.configuration
from flytekit.configuration import Image, ImageConfig
Expand Down Expand Up @@ -80,3 +85,15 @@ def t1(a: int) -> List[Dict[str, Foo]]:
task_spec.template.interface.outputs["o0"].type.collection_type.map_value_type.blob.format
is FlytePickleTransformer.PYTHON_PICKLE_FORMAT
)


def test_union():
@task
def t1(data: Annotated[Union[np.ndarray, pd.DataFrame, Sequence], "some annotation"]):
print(data)

task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
variants = task_spec.template.interface.inputs["data"].type.union_type.variants
assert variants[0].blob.format == "NumpyArray"
assert variants[1].structured_dataset_type.format == "parquet"
assert variants[2].blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import os
import typing

import pytest

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from typing_extensions import Annotated

from flytekit import FlyteContext, FlyteContextManager, kwtypes, task, workflow
from flytekit.models import literals
Expand Down

0 comments on commit 61bd02a

Please sign in to comment.