Skip to content

Commit

Permalink
Enable Resolve Attr Path for List or Dict of Promise (#2828)
Browse files Browse the repository at this point in the history
* enable attr path to list

Signed-off-by: Mecoli1219 <[email protected]>

* enable nested list

Signed-off-by: Mecoli1219 <[email protected]>

* nit

Signed-off-by: Mecoli1219 <[email protected]>

* Enable dict support

Signed-off-by: Mecoli1219 <[email protected]>

---------

Signed-off-by: Mecoli1219 <[email protected]>
  • Loading branch information
Mecoli1219 authored Oct 18, 2024
1 parent c4674c3 commit 816015d
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 2 deletions.
18 changes: 16 additions & 2 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def my_wf(in1: int, in2: int) -> int:
var = flyte_interface_types[k]
t = native_types[k]
try:
if type(v) is Promise:
v = await resolve_attr_path_in_promise(v)
v = await resolve_attr_path_recursively(v)
result[k] = await TypeEngine.async_to_literal(ctx, v, t, var.type)
except TypeTransformerFailedError as exc:
exc.args = (f"Failed argument '{k}': {exc.args[0]}",)
Expand All @@ -110,6 +109,21 @@ def my_wf(in1: int, in2: int) -> int:
translate_inputs_to_literals = loop_manager.synced(_translate_inputs_to_literals)


async def resolve_attr_path_recursively(v: Any) -> Any:
"""
This function resolves the attribute path in a nested structure recursively.
"""
if isinstance(v, Promise):
v = await resolve_attr_path_in_promise(v)
elif isinstance(v, list):
for i, elem in enumerate(v):
v[i] = await resolve_attr_path_recursively(elem)
elif isinstance(v, dict):
for k, elem in v.items():
v[k] = await resolve_attr_path_recursively(elem)
return v


async def resolve_attr_path_in_promise(p: Promise) -> Promise:
"""
resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value
Expand Down
129 changes: 129 additions & 0 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.user import FlyteAssertion, FlytePromiseAttributeResolveException
from flytekit.models import literals as literal_models
from flytekit.models.types import LiteralType, SimpleType, TypeStructure
from flytekit.types.pickle.pickle import BatchSize

Expand Down Expand Up @@ -147,6 +148,134 @@ def t1(a: typing.Union[float, MyDataclass, Annotated[typing.List[typing.Any], Ba
translate_inputs_to_literals(ctx, {"a": input}, t1.interface.inputs, t1.python_interface.inputs)


def test_translate_inputs_to_literals_with_list():
ctx = context_manager.FlyteContext.current_context()

@dataclass_json
@dataclass
class Foo:
b: str

@dataclass_json
@dataclass
class Bar:
a: List[Foo]
b: str

src = {"a": [Bar(a=[Foo(b="foo")], b="bar")]}
src_lit = TypeEngine.to_literal(
ctx,
src,
Dict[str, List[Bar]],
TypeEngine.to_literal_type(Dict[str, List[Bar]]),
)
src_promise = Promise("val1", src_lit)

i0 = "foo"
i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str)))

@task
def t1(a: List[str]):
print(a)


lits = translate_inputs_to_literals(ctx, {"a": [
i0, i1, src_promise["a"][0].a[0]["b"], src_promise["a"][0]["b"]
]}, t1.interface.inputs, t1.python_interface.inputs)
literal_map = literal_models.LiteralMap(literals=lits)
python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"]
assert python_values == ["foo", "bar", "foo", "bar"]


def test_translate_inputs_to_literals_with_dict():
ctx = context_manager.FlyteContext.current_context()

@dataclass_json
@dataclass
class Foo:
b: str

@dataclass_json
@dataclass
class Bar:
a: List[Foo]
b: str

src = {"a": [Bar(a=[Foo(b="foo")], b="bar")]}
src_lit = TypeEngine.to_literal(
ctx,
src,
Dict[str, List[Bar]],
TypeEngine.to_literal_type(Dict[str, List[Bar]]),
)
src_promise = Promise("val1", src_lit)

i0 = "foo"
i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str)))

@task
def t1(a: Dict[str, str]):
print(a)


lits = translate_inputs_to_literals(ctx, {"a": {
"k0": i0,
"k1": i1,
"k2": src_promise["a"][0].a[0]["b"],
"k3": src_promise["a"][0]["b"]
}}, t1.interface.inputs, t1.python_interface.inputs)
literal_map = literal_models.LiteralMap(literals=lits)
python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"]
assert python_values == {
"k0": "foo",
"k1": "bar",
"k2": "foo",
"k3": "bar",
}


def test_translate_inputs_to_literals_with_nested_list_and_dict():
ctx = context_manager.FlyteContext.current_context()

@dataclass_json
@dataclass
class Foo:
b: str

@dataclass_json
@dataclass
class Bar:
a: List[Foo]
b: str

src = {"a": [Bar(a=[Foo(b="foo")], b="bar")]}
src_lit = TypeEngine.to_literal(
ctx,
src,
Dict[str, List[Bar]],
TypeEngine.to_literal_type(Dict[str, List[Bar]]),
)
src_promise = Promise("val1", src_lit)

i0 = "foo"
i1 = Promise("n1", TypeEngine.to_literal(ctx, "bar", str, TypeEngine.to_literal_type(str)))

@task
def t1(a: Dict[str, List[Dict[str, str]]]):
print(a)

lits = translate_inputs_to_literals(ctx, {"a": {
"k0": [{"k00": i0, "k01": i1}],
"k1": [{"k10": src_promise["a"][0].a[0]["b"], "k11": src_promise["a"][0]["b"]}]
}}, t1.interface.inputs, t1.python_interface.inputs)
literal_map = literal_models.LiteralMap(literals=lits)
python_values = TypeEngine.literal_map_to_kwargs(ctx, literal_map, t1.python_interface.inputs)["a"]
assert python_values == {
"k0": [{"k00": "foo", "k01": "bar"}],
"k1": [{"k10": "foo", "k11": "bar"}]
}


def test_translate_inputs_to_literals_with_wrong_types():
ctx = context_manager.FlyteContext.current_context()
with pytest.raises(TypeError, match="Cannot convert"):
Expand Down

0 comments on commit 816015d

Please sign in to comment.