Skip to content

Commit

Permalink
enable nested list
Browse files Browse the repository at this point in the history
Signed-off-by: Mecoli1219 <[email protected]>
  • Loading branch information
Mecoli1219 committed Oct 17, 2024
1 parent a8f3ccc commit 38ba736
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
13 changes: 10 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def my_wf(in1: int, in2: int) -> int:
if type(v) is Promise:
v = await resolve_attr_path_in_promise(v)
elif type(v) is list:
for i, elem in enumerate(v):
if type(elem) is Promise:
v[i] = await resolve_attr_path_in_promise(elem)
v = await resolve_attr_path_in_list(v)
result[k] = await TypeEngine.async_to_literal(ctx, v, t, var.type)
except TypeTransformerFailedError as exc:
exc.args = (f"Failed argument '{k}': {exc.args[0]}",)
Expand Down Expand Up @@ -175,6 +173,15 @@ async def resolve_attr_path_in_promise(p: Promise) -> Promise:
return p


async def resolve_attr_path_in_list(l: List[Any]) -> List[Any]:
for i, elem in enumerate(l):
if type(elem) is Promise:
l[i] = await resolve_attr_path_in_promise(elem)
elif type(elem) is list:
l[i] = await resolve_attr_path_in_list(elem)
return l


def resolve_attr_path_in_dict(d: dict, attr_path: List[Union[str, int]]) -> Any:
curr_val = d
for attr in attr_path:
Expand Down
39 changes: 39 additions & 0 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,45 @@ def t1(a: List[str]):
assert python_values == ["foo", "bar", "foo", "bar"]


def test_translate_inputs_to_literals_with_nested_list():
ctx = context_manager.FlyteContext.current_context()

@dataclass_json
@dataclass
class Foo:
b: str

@dataclass_json
@dataclass
class Bar:
a: List[Foo]
b: str

src = {"a": [Bar(a=[Foo(b="foo")], b="bar")]}
src_lit = TypeEngine.to_literal(
ctx,
src,
Dict[str, List[Bar]],
TypeEngine.to_literal_type(Dict[str, List[Bar]]),
)
src_promise = Promise("val1", src_lit)

i0 = "foo"
i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str)))

@task
def t1(a: List[List[str]]):
print(a)


lits = translate_inputs_to_literals(ctx, {"a": [[
i0, i1, src_promise["a"][0]["a"][0]["b"], src_promise["a"][0]["b"]
]]}, t1.interface.inputs, t1.python_interface.inputs)
literal_map = literal_models.LiteralMap(literals=lits)
python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"]
assert python_values == [["foo", "bar", "foo", "bar"]]


def test_translate_inputs_to_literals_with_wrong_types():
ctx = context_manager.FlyteContext.current_context()
with pytest.raises(TypeError, match="Cannot convert"):
Expand Down

0 comments on commit 38ba736

Please sign in to comment.