Skip to content

Commit

Permalink
Add guess to DictTransformer (#658)
Browse files Browse the repository at this point in the history
* add guess

Signed-off-by: Yee Hing Tong <[email protected]>

* test

Signed-off-by: Yee Hing Tong <[email protected]>

* Add expected value for test_dict3

Signed-off-by: Eduardo Apolinario <[email protected]>

Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
wild-endeavor and eapolinario authored Sep 15, 2021
1 parent 7d09ec2 commit 373e50b
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
4 changes: 4 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,10 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
if literal_type.map_value_type:
mt = TypeEngine.guess_python_type(literal_type.map_value_type)
return typing.Dict[str, mt] # type: ignore

if literal_type.simple == SimpleType.STRUCT and literal_type.metadata is None:
return dict

raise ValueError(f"Dictionary transformer cannot reverse {literal_type}")


Expand Down
67 changes: 64 additions & 3 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
import pandas
import pytest
from dataclasses_json import dataclass_json
from google.protobuf.struct_pb2 import Struct

import flytekit
from flytekit import ContainerTask, Secret, SQLTask, dynamic, kwtypes, map_task
from flytekit.common.translator import get_serializable
from flytekit.core import context_manager, launch_plan, promise
from flytekit.core.condition import conditional
from flytekit.core.context_manager import ExecutionState, FastSerializationSettings, Image, ImageConfig

# from flytekit.interfaces.data.data_proxy import FileAccessProvider
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.node import Node
from flytekit.core.promise import NodeOutput, Promise, VoidPromise
Expand All @@ -30,9 +29,17 @@
from flytekit.models.core import types as _core_types
from flytekit.models.interface import Parameter
from flytekit.models.task import Resources as _resource_models
from flytekit.models.types import LiteralType
from flytekit.models.types import LiteralType, SimpleType
from flytekit.types.schema import FlyteSchema, SchemaOpenMode

serialization_settings = context_manager.SerializationSettings(
project="proj",
domain="dom",
version="123",
image_config=ImageConfig(Image(name="name", fqn="asdf/fdsa", tag="123")),
env={},
)


def test_default_wf_params_works():
@task
Expand Down Expand Up @@ -1277,3 +1284,57 @@ def consume_outputs(my_input: int, seed: int = 5) -> int:

assert consume_outputs(my_input=4, seed=7) == 32
assert consume_outputs(my_input=4) == 16


def test_guess_dict():
@task
def t2(a: dict) -> str:
return ", ".join([f"K: {k} V: {v}" for k, v in a.items()])

task_spec = get_serializable(OrderedDict(), serialization_settings, t2)
assert task_spec.template.interface.inputs["a"].type.simple == SimpleType.STRUCT

pt = TypeEngine.guess_python_type(task_spec.template.interface.inputs["a"].type)
assert pt is dict

input_map = {"a": {"k1": "v1", "k2": "2"}}
guessed_types = {"a": pt}
ctx = context_manager.FlyteContext.current_context()
lm = TypeEngine.dict_to_literal_map(ctx, d=input_map, guessed_python_types=guessed_types)
assert isinstance(lm.literals["a"].scalar.generic, Struct)

output_lm = t2.dispatch_execute(ctx, lm)
str_value = output_lm.literals["o0"].scalar.primitive.string_value
assert str_value == "K: k2 V: 2, K: k1 V: v1" or str_value == "K: k1 V: v1, K: k2 V: 2"


def test_guess_dict2():
@task
def t2(a: typing.List[dict]) -> str:
strs = []
for input_dict in a:
strs.append(", ".join([f"K: {k} V: {v}" for k, v in input_dict.items()]))
return " ".join(strs)

task_spec = get_serializable(OrderedDict(), serialization_settings, t2)
assert task_spec.template.interface.inputs["a"].type.collection_type.simple == SimpleType.STRUCT
pt_map = TypeEngine.guess_python_types(task_spec.template.interface.inputs)
assert pt_map == {"a": typing.List[dict]}


def test_guess_dict3():
@task
def t2() -> dict:
return {"k1": "v1", "k2": 3, 4: {"one": [1, "two", [3]]}}

task_spec = get_serializable(OrderedDict(), serialization_settings, t2)
print(task_spec.template.interface.outputs["o0"])

pt_map = TypeEngine.guess_python_types(task_spec.template.interface.outputs)
assert pt_map["o0"] is dict

ctx = context_manager.FlyteContextManager.current_context()
output_lm = t2.dispatch_execute(ctx, _literal_models.LiteralMap(literals={}))
expected_struct = Struct()
expected_struct.update({"k1": "v1", "k2": 3, "4": {"one": [1, "two", [3]]}})
assert output_lm.literals["o0"].scalar.generic == expected_struct

0 comments on commit 373e50b

Please sign in to comment.