Skip to content

Commit

Permalink
Enable dict support
Browse files Browse the repository at this point in the history
Signed-off-by: Mecoli1219 <[email protected]>
  • Loading branch information
Mecoli1219 committed Oct 17, 2024
1 parent 6b67687 commit fee5905
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 19 deletions.
27 changes: 15 additions & 12 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ def my_wf(in1: int, in2: int) -> int:
translate_inputs_to_literals = loop_manager.synced(_translate_inputs_to_literals)


async def resolve_attr_path_recursively(v: Any) -> Any:
"""
This function resolves the attribute path in a nested structure recursively.
"""
if isinstance(v, Promise):
v = await resolve_attr_path_in_promise(v)
elif isinstance(v, list):
for i, elem in enumerate(v):
v[i] = await resolve_attr_path_recursively(elem)
elif isinstance(v, dict):
for k, elem in v.items():
v[k] = await resolve_attr_path_recursively(elem)
return v


async def resolve_attr_path_in_promise(p: Promise) -> Promise:
"""
resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value
Expand Down Expand Up @@ -170,18 +185,6 @@ async def resolve_attr_path_in_promise(p: Promise) -> Promise:
return p


async def resolve_attr_path_recursively(v: Any) -> Any:
"""
This function resolves the attribute path in a nested structure recursively.
"""
if isinstance(v, Promise):
v = await resolve_attr_path_in_promise(v)
elif isinstance(v, list):
for i, elem in enumerate(v):
v[i] = await resolve_attr_path_recursively(elem)
return v


def resolve_attr_path_in_dict(d: dict, attr_path: List[Union[str, int]]) -> Any:
curr_val = d
for attr in attr_path:
Expand Down
64 changes: 57 additions & 7 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,14 @@ def t1(a: List[str]):


lits = translate_inputs_to_literals(ctx, {"a": [
i0, i1, src_promise["a"][0]["a"][0]["b"], src_promise["a"][0]["b"]
i0, i1, src_promise["a"][0].a[0]["b"], src_promise["a"][0]["b"]
]}, t1.interface.inputs, t1.python_interface.inputs)
literal_map = literal_models.LiteralMap(literals=lits)
python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"]
assert python_values == ["foo", "bar", "foo", "bar"]


def test_translate_inputs_to_literals_with_nested_list():
def test_translate_inputs_to_literals_with_dict():
ctx = context_manager.FlyteContext.current_context()

@dataclass_json
Expand All @@ -214,16 +214,66 @@ class Bar:
i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str)))

@task
def t1(a: List[List[str]]):
def t1(a: Dict[str, str]):
print(a)


lits = translate_inputs_to_literals(ctx, {"a": [[
i0, i1, src_promise["a"][0]["a"][0]["b"], src_promise["a"][0]["b"]
]]}, t1.interface.inputs, t1.python_interface.inputs)
lits = translate_inputs_to_literals(ctx, {"a": {
"k0": i0,
"k1": i1,
"k2": src_promise["a"][0].a[0]["b"],
"k3": src_promise["a"][0]["b"]
}}, t1.interface.inputs, t1.python_interface.inputs)
literal_map = literal_models.LiteralMap(literals=lits)
python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"]
assert python_values == [["foo", "bar", "foo", "bar"]]
assert python_values == {
"k0": "foo",
"k1": "bar",
"k2": "foo",
"k3": "bar",
}


def test_translate_inputs_to_literals_with_nested_list_and_dict():
ctx = context_manager.FlyteContext.current_context()

@dataclass_json
@dataclass
class Foo:
b: str

@dataclass_json
@dataclass
class Bar:
a: List[Foo]
b: str

src = {"a": [Bar(a=[Foo(b="foo")], b="bar")]}
src_lit = TypeEngine.to_literal(
ctx,
src,
Dict[str, List[Bar]],
TypeEngine.to_literal_type(Dict[str, List[Bar]]),
)
src_promise = Promise("val1", src_lit)

i0 = "foo"
i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str)))

@task
def t1(a: Dict[str, List[Dict[str, str]]]):
print(a)

lits = translate_inputs_to_literals(ctx, {"a": {
"k0": [{"k00": i0, "k01": i1}],
"k1": [{"k10": src_promise["a"][0].a[0]["b"], "k11": src_promise["a"][0]["b"]}]
}}, t1.interface.inputs, t1.python_interface.inputs)
literal_map = literal_models.LiteralMap(literals=lits)
python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"]
assert python_values == {
"k0": [{"k00": "foo", "k01": "bar"}],
"k1": [{"k10": "foo", "k11": "bar"}]
}


def test_translate_inputs_to_literals_with_wrong_types():
Expand Down

0 comments on commit fee5905

Please sign in to comment.