diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index e95835b53f..2394a9b1af 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -97,8 +97,7 @@ def my_wf(in1: int, in2: int) -> int: var = flyte_interface_types[k] t = native_types[k] try: - if type(v) is Promise: - v = await resolve_attr_path_in_promise(v) + v = await resolve_attr_path_recursively(v) result[k] = await TypeEngine.async_to_literal(ctx, v, t, var.type) except TypeTransformerFailedError as exc: exc.args = (f"Failed argument '{k}': {exc.args[0]}",) @@ -110,6 +109,21 @@ def my_wf(in1: int, in2: int) -> int: translate_inputs_to_literals = loop_manager.synced(_translate_inputs_to_literals) +async def resolve_attr_path_recursively(v: Any) -> Any: + """ + This function resolves the attribute path in a nested structure recursively. + """ + if isinstance(v, Promise): + v = await resolve_attr_path_in_promise(v) + elif isinstance(v, list): + for i, elem in enumerate(v): + v[i] = await resolve_attr_path_recursively(elem) + elif isinstance(v, dict): + for k, elem in v.items(): + v[k] = await resolve_attr_path_recursively(elem) + return v + + async def resolve_attr_path_in_promise(p: Promise) -> Promise: """ resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 8154f9994e..a8a6e3444d 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -21,6 +21,7 @@ ) from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.user import FlyteAssertion, FlytePromiseAttributeResolveException +from flytekit.models import literals as literal_models from flytekit.models.types import LiteralType, SimpleType, TypeStructure from flytekit.types.pickle.pickle import BatchSize @@ -147,6 +148,134 @@ def t1(a: typing.Union[float, MyDataclass, Annotated[typing.List[typing.Any], Ba translate_inputs_to_literals(ctx, {"a": input}, t1.interface.inputs, t1.python_interface.inputs) +def test_translate_inputs_to_literals_with_list(): + ctx = context_manager.FlyteContext.current_context() + + @dataclass_json + @dataclass + class Foo: + b: str + + @dataclass_json + @dataclass + class Bar: + a: List[Foo] + b: str + + src = {"a": [Bar(a=[Foo(b="foo")], b="bar")]} + src_lit = TypeEngine.to_literal( + ctx, + src, + Dict[str, List[Bar]], + TypeEngine.to_literal_type(Dict[str, List[Bar]]), + ) + src_promise = Promise("val1", src_lit) + + i0 = "foo" + i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str))) + + @task + def t1(a: List[str]): + print(a) + + + lits = translate_inputs_to_literals(ctx, {"a": [ + i0, i1, src_promise["a"][0].a[0]["b"], src_promise["a"][0]["b"] + ]}, t1.interface.inputs, t1.python_interface.inputs) + literal_map = literal_models.LiteralMap(literals=lits) + python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"] + assert python_values == ["foo", "bar", "foo", "bar"] + + +def test_translate_inputs_to_literals_with_dict(): + ctx = context_manager.FlyteContext.current_context() + + @dataclass_json + @dataclass + class Foo: + b: str + + @dataclass_json + @dataclass + class Bar: + a: List[Foo] + b: str + + src = {"a": [Bar(a=[Foo(b="foo")], b="bar")]} + src_lit = TypeEngine.to_literal( + ctx, + src, + Dict[str, List[Bar]], + TypeEngine.to_literal_type(Dict[str, List[Bar]]), + ) + src_promise = Promise("val1", src_lit) + + i0 = "foo" + i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str))) + + @task + def t1(a: Dict[str, str]): + print(a) + + + lits = translate_inputs_to_literals(ctx, {"a": { + "k0": i0, + "k1": i1, + "k2": src_promise["a"][0].a[0]["b"], + "k3": src_promise["a"][0]["b"] + }}, t1.interface.inputs, t1.python_interface.inputs) + literal_map = literal_models.LiteralMap(literals=lits) + python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"] + assert python_values == { + "k0": "foo", + "k1": "bar", + "k2": "foo", + "k3": "bar", + } + + +def test_translate_inputs_to_literals_with_nested_list_and_dict(): + ctx = context_manager.FlyteContext.current_context() + + @dataclass_json + @dataclass + class Foo: + b: str + + @dataclass_json + @dataclass + class Bar: + a: List[Foo] + b: str + + src = {"a": [Bar(a=[Foo(b="foo")], b="bar")]} + src_lit = TypeEngine.to_literal( + ctx, + src, + Dict[str, List[Bar]], + TypeEngine.to_literal_type(Dict[str, List[Bar]]), + ) + src_promise = Promise("val1", src_lit) + + i0 = "foo" + i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str))) + + @task + def t1(a: Dict[str, List[Dict[str, str]]]): + print(a) + + lits = translate_inputs_to_literals(ctx, {"a": { + "k0": [{"k00": i0, "k01": i1}], + "k1": [{"k10": src_promise["a"][0].a[0]["b"], "k11": src_promise["a"][0]["b"]}] + }}, t1.interface.inputs, t1.python_interface.inputs) + literal_map = literal_models.LiteralMap(literals=lits) + python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"] + assert python_values == { + "k0": [{"k00": "foo", "k01": "bar"}], + "k1": [{"k10": "foo", "k11": "bar"}] + } + + def test_translate_inputs_to_literals_with_wrong_types(): ctx = context_manager.FlyteContext.current_context() with pytest.raises(TypeError, match="Cannot convert"):