From a8f3ccc9609b9ba1a5cd73ea3c95afbab2e94fda Mon Sep 17 00:00:00 2001 From: Mecoli1219 Date: Thu, 17 Oct 2024 02:05:05 -0700 Subject: [PATCH 1/4] enable attr path to list Signed-off-by: Mecoli1219 --- flytekit/core/promise.py | 4 +++ tests/flytekit/unit/core/test_promise.py | 40 ++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index e95835b53f..44a50d4c6f 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -99,6 +99,10 @@ def my_wf(in1: int, in2: int) -> int: try: if type(v) is Promise: v = await resolve_attr_path_in_promise(v) + elif type(v) is list: + for i, elem in enumerate(v): + if type(elem) is Promise: + v[i] = await resolve_attr_path_in_promise(elem) result[k] = await TypeEngine.async_to_literal(ctx, v, t, var.type) except TypeTransformerFailedError as exc: exc.args = (f"Failed argument '{k}': {exc.args[0]}",) diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 8154f9994e..1b57c03416 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -21,6 +21,7 @@ ) from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.user import FlyteAssertion, FlytePromiseAttributeResolveException +from flytekit.models import literals as literal_models from flytekit.models.types import LiteralType, SimpleType, TypeStructure from flytekit.types.pickle.pickle import BatchSize @@ -147,6 +148,45 @@ def t1(a: typing.Union[float, MyDataclass, Annotated[typing.List[typing.Any], Ba translate_inputs_to_literals(ctx, {"a": input}, t1.interface.inputs, t1.python_interface.inputs) +def test_translate_inputs_to_literals_with_list(): + ctx = context_manager.FlyteContext.current_context() + + @dataclass_json + @dataclass + class Foo: + b: str + + @dataclass_json + @dataclass + class Bar: + a: List[Foo] + b: str + + src = {"a": [Bar(a=[Foo(b="foo")], b="bar")]} + src_lit = TypeEngine.to_literal( + ctx, + src, + Dict[str, List[Bar]], + TypeEngine.to_literal_type(Dict[str, List[Bar]]), + ) + src_promise = Promise("val1", src_lit) + + i0 = "foo" + i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str))) + + @task + def t1(a: List[str]): + print(a) + + + lits = translate_inputs_to_literals(ctx, {"a": [ + i0, i1, src_promise["a"][0]["a"][0]["b"], src_promise["a"][0]["b"] + ]}, t1.interface.inputs, t1.python_interface.inputs) + literal_map = literal_models.LiteralMap(literals=lits) + python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"] + assert python_values == ["foo", "bar", "foo", "bar"] + + def test_translate_inputs_to_literals_with_wrong_types(): ctx = context_manager.FlyteContext.current_context() with pytest.raises(TypeError, match="Cannot convert"): From 38ba736eae875abb8d2830b0ca60f5b7290cf49e Mon Sep 17 00:00:00 2001 From: Mecoli1219 Date: Thu, 17 Oct 2024 11:30:49 -0700 Subject: [PATCH 2/4] enable nested list Signed-off-by: Mecoli1219 --- flytekit/core/promise.py | 13 ++++++-- tests/flytekit/unit/core/test_promise.py | 39 ++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 44a50d4c6f..bbb421c852 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -100,9 +100,7 @@ def my_wf(in1: int, in2: int) -> int: if type(v) is Promise: v = await resolve_attr_path_in_promise(v) elif type(v) is list: - for i, elem in enumerate(v): - if type(elem) is Promise: - v[i] = await resolve_attr_path_in_promise(elem) + v = await resolve_attr_path_in_list(v) result[k] = await TypeEngine.async_to_literal(ctx, v, t, var.type) except TypeTransformerFailedError as exc: exc.args = (f"Failed argument '{k}': {exc.args[0]}",) @@ -175,6 +173,15 @@ async def resolve_attr_path_in_promise(p: Promise) -> Promise: return p +async def resolve_attr_path_in_list(l: List[Any]) -> List[Any]: + for i, elem in enumerate(l): + if type(elem) is Promise: + l[i] = await resolve_attr_path_in_promise(elem) + elif type(elem) is list: + l[i] = await resolve_attr_path_in_list(elem) + return l + + def resolve_attr_path_in_dict(d: dict, attr_path: List[Union[str, int]]) -> Any: curr_val = d for attr in attr_path: diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 1b57c03416..9a821071b2 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -187,6 +187,45 @@ def t1(a: List[str]): assert python_values == ["foo", "bar", "foo", "bar"] +def test_translate_inputs_to_literals_with_nested_list(): + ctx = context_manager.FlyteContext.current_context() + + @dataclass_json + @dataclass + class Foo: + b: str + + @dataclass_json + @dataclass + class Bar: + a: List[Foo] + b: str + + src = {"a": [Bar(a=[Foo(b="foo")], b="bar")]} + src_lit = TypeEngine.to_literal( + ctx, + src, + Dict[str, List[Bar]], + TypeEngine.to_literal_type(Dict[str, List[Bar]]), + ) + src_promise = Promise("val1", src_lit) + + i0 = "foo" + i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str))) + + @task + def t1(a: List[List[str]]): + print(a) + + + lits = translate_inputs_to_literals(ctx, {"a": [[ + i0, i1, src_promise["a"][0]["a"][0]["b"], src_promise["a"][0]["b"] + ]]}, t1.interface.inputs, t1.python_interface.inputs) + literal_map = literal_models.LiteralMap(literals=lits) + python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"] + assert python_values == [["foo", "bar", "foo", "bar"]] + + def test_translate_inputs_to_literals_with_wrong_types(): ctx = context_manager.FlyteContext.current_context() with pytest.raises(TypeError, match="Cannot convert"): From 6b676870f20cdcff49464e7f8a824c43980ace20 Mon Sep 17 00:00:00 2001 From: Mecoli1219 Date: Thu, 17 Oct 2024 16:42:35 -0700 Subject: [PATCH 3/4] nit Signed-off-by: Mecoli1219 --- flytekit/core/promise.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index bbb421c852..6a73d2231d 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -97,10 +97,7 @@ def my_wf(in1: int, in2: int) -> int: var = flyte_interface_types[k] t = native_types[k] try: - if type(v) is Promise: - v = await resolve_attr_path_in_promise(v) - elif type(v) is list: - v = await resolve_attr_path_in_list(v) + v = await resolve_attr_path_recursively(v) result[k] = await TypeEngine.async_to_literal(ctx, v, t, var.type) except TypeTransformerFailedError as exc: exc.args = (f"Failed argument '{k}': {exc.args[0]}",) @@ -173,13 +170,16 @@ async def resolve_attr_path_in_promise(p: Promise) -> Promise: return p -async def resolve_attr_path_in_list(l: List[Any]) -> List[Any]: - for i, elem in enumerate(l): - if type(elem) is Promise: - l[i] = await resolve_attr_path_in_promise(elem) - elif type(elem) is list: - l[i] = await resolve_attr_path_in_list(elem) - return l +async def resolve_attr_path_recursively(v: Any) -> Any: + """ + This function resolves the attribute path in a nested structure recursively. + """ + if isinstance(v, Promise): + v = await resolve_attr_path_in_promise(v) + elif isinstance(v, list): + for i, elem in enumerate(v): + v[i] = await resolve_attr_path_recursively(elem) + return v def resolve_attr_path_in_dict(d: dict, attr_path: List[Union[str, int]]) -> Any: From fee59052382e8ac62f2b5a5e8d2bbb5c8662734d Mon Sep 17 00:00:00 2001 From: Mecoli1219 Date: Thu, 17 Oct 2024 16:46:06 -0700 Subject: [PATCH 4/4] Enable dict support Signed-off-by: Mecoli1219 --- flytekit/core/promise.py | 27 +++++----- tests/flytekit/unit/core/test_promise.py | 64 +++++++++++++++++++++--- 2 files changed, 72 insertions(+), 19 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 6a73d2231d..2394a9b1af 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -109,6 +109,21 @@ def my_wf(in1: int, in2: int) -> int: translate_inputs_to_literals = loop_manager.synced(_translate_inputs_to_literals) +async def resolve_attr_path_recursively(v: Any) -> Any: + """ + This function resolves the attribute path in a nested structure recursively. + """ + if isinstance(v, Promise): + v = await resolve_attr_path_in_promise(v) + elif isinstance(v, list): + for i, elem in enumerate(v): + v[i] = await resolve_attr_path_recursively(elem) + elif isinstance(v, dict): + for k, elem in v.items(): + v[k] = await resolve_attr_path_recursively(elem) + return v + + async def resolve_attr_path_in_promise(p: Promise) -> Promise: """ resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value @@ -170,18 +185,6 @@ async def resolve_attr_path_in_promise(p: Promise) -> Promise: return p -async def resolve_attr_path_recursively(v: Any) -> Any: - """ - This function resolves the attribute path in a nested structure recursively. - """ - if isinstance(v, Promise): - v = await resolve_attr_path_in_promise(v) - elif isinstance(v, list): - for i, elem in enumerate(v): - v[i] = await resolve_attr_path_recursively(elem) - return v - - def resolve_attr_path_in_dict(d: dict, attr_path: List[Union[str, int]]) -> Any: curr_val = d for attr in attr_path: diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 9a821071b2..a8a6e3444d 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -180,14 +180,14 @@ def t1(a: List[str]): lits = translate_inputs_to_literals(ctx, {"a": [ - i0, i1, src_promise["a"][0]["a"][0]["b"], src_promise["a"][0]["b"] + i0, i1, src_promise["a"][0].a[0]["b"], src_promise["a"][0]["b"] ]}, t1.interface.inputs, t1.python_interface.inputs) literal_map = literal_models.LiteralMap(literals=lits) python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"] assert python_values == ["foo", "bar", "foo", "bar"] -def test_translate_inputs_to_literals_with_nested_list(): +def test_translate_inputs_to_literals_with_dict(): ctx = context_manager.FlyteContext.current_context() @dataclass_json @@ -214,16 +214,66 @@ class Bar: i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str))) @task - def t1(a: List[List[str]]): + def t1(a: Dict[str, str]): print(a) - lits = translate_inputs_to_literals(ctx, {"a": [[ - i0, i1, src_promise["a"][0]["a"][0]["b"], src_promise["a"][0]["b"] - ]]}, t1.interface.inputs, t1.python_interface.inputs) + lits = translate_inputs_to_literals(ctx, {"a": { + "k0": i0, + "k1": i1, + "k2": src_promise["a"][0].a[0]["b"], + "k3": src_promise["a"][0]["b"] + }}, t1.interface.inputs, t1.python_interface.inputs) literal_map = literal_models.LiteralMap(literals=lits) python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"] - assert python_values == [["foo", "bar", "foo", "bar"]] + assert python_values == { + "k0": "foo", + "k1": "bar", + "k2": "foo", + "k3": "bar", + } + + +def test_translate_inputs_to_literals_with_nested_list_and_dict(): + ctx = context_manager.FlyteContext.current_context() + + @dataclass_json + @dataclass + class Foo: + b: str + + @dataclass_json + @dataclass + class Bar: + a: List[Foo] + b: str + + src = {"a": [Bar(a=[Foo(b="foo")], b="bar")]} + src_lit = TypeEngine.to_literal( + ctx, + src, + Dict[str, List[Bar]], + TypeEngine.to_literal_type(Dict[str, List[Bar]]), + ) + src_promise = Promise("val1", src_lit) + + i0 = "foo" + i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str))) + + @task + def t1(a: Dict[str, List[Dict[str, str]]]): + print(a) + + lits = translate_inputs_to_literals(ctx, {"a": { + "k0": [{"k00": i0, "k01": i1}], + "k1": [{"k10": src_promise["a"][0].a[0]["b"], "k11": src_promise["a"][0]["b"]}] + }}, t1.interface.inputs, t1.python_interface.inputs) + literal_map = literal_models.LiteralMap(literals=lits) + python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"] + assert python_values == { + "k0": [{"k00": "foo", "k01": "bar"}], + "k1": [{"k10": "foo", "k11": "bar"}] + } def test_translate_inputs_to_literals_with_wrong_types():