Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotations branch fix #255

Merged
merged 10 commits into from
Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions flytekit/annotated/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,34 @@ def else_(self) -> Case:
}


def create_branch_node_promise_var(node_id: str, var: str) -> str:
    """
    Generates a globally (wf-level) unique id for a variable.

    When building bindings for the branch node, the inputs to the conditions (e.g. (x==5)) need to have variable names
    (e.g. x). Because it's currently infeasible to get the name (e.g. x), we resort to using the referenced node's
    output name (e.g. out_0, my_out, etc.). In order to avoid naming collisions (for example, when the conditions
    reference two identically named outputs of two different nodes), we build a variable name composed of the
    referenced node id + '.' + the referenced output name. Ideally we would use something like
    (https://github.com/pwwang/python-varname) to retrieve the assigned variable name (e.g. x). However, because of
    https://github.com/pwwang/python-varname/issues/28, this is not currently supported for all AST node types.

    :param str node_id: the original node_id that produced the variable.
    :param str var: the output variable name from the original node.
    :return: The generated unique id of the variable, formatted as ``<node_id>.<var>``.
    """
    return f"{node_id}.{var}"
EngHabu marked this conversation as resolved.
Show resolved Hide resolved


def merge_promises(*args: Promise) -> typing.List[Promise]:
    """
    Collects the distinct promises among ``args``, renamed for use as branch node inputs.

    Promises are deduplicated on the ``(node_id, var)`` pair of their ``ref``; only the first promise seen for a
    pair is kept. Each kept promise is replaced by the result of ``p.with_var(...)`` called with the name produced
    by ``create_branch_node_promise_var``. Entries that are ``None`` or whose ``ref`` is falsy are skipped.

    :param Promise args: the promises to merge; ``None`` entries are tolerated.
    :return: The renamed promises, in first-seen order.
    """
    node_vars: typing.Set[typing.Tuple[str, str]] = set()
    merged_promises: typing.List[Promise] = []
    for p in args:
        # Guard clause: nothing to bind for missing promises or promises without a node reference.
        if p is None or not p.ref:
            continue
        node_var = (p.ref.node_id, p.ref.var)
        if node_var in node_vars:
            continue
        # Only the renamed promise is appended; appending the original ``p`` as well would
        # yield two entries for the same node output.
        new_p = p.with_var(create_branch_node_promise_var(p.ref.node_id, p.ref.var))
        merged_promises.append(new_p)
        node_vars.add(node_var)
    return merged_promises

Expand All @@ -292,7 +312,7 @@ def transform_to_conj_expr(expr: ConjunctionExpression) -> (_core_cond.Conjuncti

def transform_to_operand(v: Union[Promise, Literal]) -> (_core_cond.Operand, Optional[Promise]):
    """
    Converts one side of a comparison into a branch-condition operand.

    :param v: either a Promise or a Literal.
    :return: A pair ``(operand, promise)``. For a Promise, the operand refers to it by the name generated with
        ``create_branch_node_promise_var`` and the promise itself is returned alongside. For a Literal, the operand
        carries ``v.scalar.primitive`` and the second element is ``None``.
    """
    if isinstance(v, Promise):
        # The operand must use the same generated "<node_id>.<var>" name that merge_promises gives the
        # branch node inputs; returning the bare ``v.var`` here would leave the two out of sync.
        # NOTE(review): this passes ``v.var`` while merge_promises passes ``p.ref.var`` -- assumes both are
        # equal and that ``v.ref`` is set; confirm.
        # Open question carried over from the PR review thread: was this the only change needed on the
        # flytekit side?
        return _core_cond.Operand(var=create_branch_node_promise_var(v.ref.node_id, v.var)), v
    return _core_cond.Operand(primitive=v.scalar.primitive), None


Expand Down
104 changes: 104 additions & 0 deletions tests/flytekit/unit/annotated/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,110 @@ def my_wf(a: int, b: str) -> (int, str):
assert wf.nodes[2].branch_node is not None


def test_serialization_branch_sub_wf():
    """
    Serializes a workflow whose conditional picks between a task and a sub-workflow.

    Checks that the first node of the registerable workflow has exactly one input and that it is named ".a".
    NOTE(review): the leading "." suggests the workflow-level input has an empty node id -- confirm.
    """

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_sub_wf(a: int) -> int:
        return t1(a=a)

    @workflow
    def my_wf(a: int) -> int:
        d = conditional("test1").if_(a > 3).then(t1(a=a)).else_().then(my_sub_wf(a=a))
        return d

    ctx = FlyteContext.current_context()
    default_img = Image(name="default", fqn="test", tag="tag")
    registration_settings = context_manager.RegistrationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    with ctx.current_context().new_registration_settings(registration_settings=registration_settings):
        wf = my_wf.get_registerable_entity()
        assert wf is not None
        # Check the node exists before inspecting its inputs, so a missing node is reported as such.
        assert wf.nodes[0] is not None
        assert len(wf.nodes[0].inputs) == 1
        assert wf.nodes[0].inputs[0].var == ".a"


def test_serialization_branch_compound_conditions():
    """
    Serializes a workflow whose conditional uses a compound (``|``) expression plus an ``elif_`` on the same input.

    Although the input ``a`` is referenced three times across the conditions, the first node of the registerable
    workflow is expected to have a single input, named ".a".
    """

    @task
    def t1(a: int) -> int:
        return a + 2

    @workflow
    def my_wf(a: int) -> int:
        branch_result = (
            conditional("test1")
            .if_((a == 4) | (a == 3))
            .then(t1(a=a))
            .elif_(a < 6)
            .then(t1(a=a))
            .else_()
            .fail("Unable to choose branch")
        )
        return branch_result

    ctx = FlyteContext.current_context()
    image = Image(name="default", fqn="test", tag="tag")
    images = ImageConfig(default_image=image, images=[image])
    settings = context_manager.RegistrationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=images,
    )
    with ctx.current_context().new_registration_settings(registration_settings=settings):
        wf = my_wf.get_registerable_entity()
        assert wf is not None
        first_node_inputs = wf.nodes[0].inputs
        assert len(first_node_inputs) == 1
        assert first_node_inputs[0].var == ".a"


def test_serialization_branch_complex_2():
    """
    Serializes a workflow with two chained conditionals, the first of which branches on a task output.

    Checks that the first input of the second node is named "node-0.t1_int_output", i.e. the
    "<node_id>.<output name>" form used for branch node inputs.
    """

    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        return a + 2, "world"

    @task
    def t2(a: str) -> str:
        return a

    @workflow
    def my_wf(a: int, b: str) -> (int, str):
        x, y = t1(a=a)
        d = (
            conditional("test1")
            .if_(x == 4)
            .then(t2(a=b))
            .elif_(x >= 5)
            .then(t2(a=y))
            .else_()
            .fail("Unable to choose branch")
        )
        f = conditional("test2").if_(d == "hello ").then(t2(a="It is hello")).else_().then(t2(a="Not Hello!"))
        return x, f

    ctx = FlyteContext.current_context()
    default_img = Image(name="default", fqn="test", tag="tag")
    registration_settings = context_manager.RegistrationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    with ctx.current_context().new_registration_settings(registration_settings=registration_settings):
        wf = my_wf.get_registerable_entity()
        assert wf is not None
        assert wf.nodes[1].inputs[0].var == "node-0.t1_int_output"


def test_serialization_branch():
@task
def mimic(a: int) -> typing.NamedTuple("OutputsBC", c=int):
Expand Down