Skip to content

Commit

Permalink
Added more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Oct 14, 2021
1 parent b9bd683 commit be3083e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 12 deletions.
6 changes: 6 additions & 0 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def transform_signature_to_interface(signature: inspect.Signature, docstring: Op
try:
if hasattr(v, "__origin__") and v.__origin__ == list:
TypeEngine.get_transformer(v.__args__[0])
elif hasattr(v, "__origin__") and v.__origin__ == dict:
TypeEngine.get_transformer(v.__args__[1])
else:
TypeEngine.get_transformer(v)

Expand All @@ -275,6 +277,8 @@ def transform_signature_to_interface(signature: inspect.Signature, docstring: Op
try:
if hasattr(annotation, "__origin__") and annotation.__origin__ == list:
TypeEngine.get_transformer(annotation.__args__[0])
elif hasattr(annotation, "__origin__") and annotation.__origin__ == dict:
TypeEngine.get_transformer(annotation.__args__[1])
else:
TypeEngine.get_transformer(annotation)
except ValueError:
Expand Down Expand Up @@ -316,6 +320,8 @@ def transform_variable_map(
if isinstance(v, FlytePickle):
if hasattr(v.python_type, "__origin__") and v.python_type.__origin__ is list:
res[k].type.metadata = {"python_class_name": v.python_type.__args__[0].__name__}
elif hasattr(v.python_type, "__origin__") and v.python_type.__origin__ is dict:
res[k].type.metadata = {"python_class_name": v.python_type.__args__[1].__name__}
else:
res[k].type.metadata = {"python_class_name": v.python_type.__name__}
return res
Expand Down
4 changes: 4 additions & 0 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Primitive
from flytekit.types.pickle.pickle import FlytePickleTransformer


def translate_inputs_to_literals(
Expand Down Expand Up @@ -66,6 +67,9 @@ def my_wf(in1: int, in2: int) -> int:
def extract_value(
ctx: FlyteContext, input_val: Any, val_type: type, flyte_literal_type: _type_models.LiteralType
) -> _literal_models.Literal:
if flyte_literal_type.blob and flyte_literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT:
return TypeEngine.to_literal(ctx, input_val, val_type, flyte_literal_type)

if isinstance(input_val, list):
if flyte_literal_type.collection_type is None:
raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}")
Expand Down
19 changes: 9 additions & 10 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(self, python_type: Type[T]):


class FlytePickleTransformer(TypeTransformer[FlytePickle]):
PICKLE = "pickle"
PYTHON_PICKLE_FORMAT = "PythonPickle"

def __init__(self):
Expand All @@ -40,11 +39,11 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
# Deserialize the pickle, and return data in the pickle,
# and download pickle file to local first if file is not in the local file systems.
if ctx.file_access.is_remote(uri):
ctx.file_access.get_data(uri, self.PICKLE, False)
uri = self.PICKLE
infile = open(uri, "rb")
data = cloudpickle.load(infile)
infile.close()
local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, False)
uri = local_path
with open(uri, "rb") as infile:
data = cloudpickle.load(infile)
return data

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
Expand All @@ -56,10 +55,10 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
# Dump the task output into pickle
local_dir = ctx.file_access.get_random_local_directory()
os.makedirs(local_dir, exist_ok=True)
uri = os.path.join(local_dir, self.PICKLE)
outfile = open(uri, "w+b")
cloudpickle.dump(python_val, outfile)
outfile.close()
local_path = ctx.file_access.get_random_local_path()
uri = os.path.join(local_dir, local_path)
with open(uri, "w+b") as outfile:
cloudpickle.dump(python_val, outfile)

remote_path = ctx.file_access.get_random_remote_path(uri)
ctx.file_access.put_data(uri, remote_path, is_multipart=False)
Expand Down
17 changes: 15 additions & 2 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,15 +1008,28 @@ def my_wf(a: int, b: str) -> (MyCustomType, int):

def test_arbit_class():
class Foo(object):
pass
def __init__(self, number: int):
self.number = number

@task
def t1(a: int) -> Foo:
return Foo()
return Foo(number=a)

@task
def t2(a: Foo) -> typing.List[Foo]:
return [a, a]

@task
def t3(a: typing.List[Foo]) -> typing.Dict[str, Foo]:
return {"hello": a[0]}

def wf(a: int) -> typing.Dict[str, Foo]:
o1 = t1(a=a)
o2 = t2(a=o1)
return t3(a=o2)

wf(1)


def test_dataclass_more():
@dataclass_json
Expand Down

0 comments on commit be3083e

Please sign in to comment.