Skip to content

Commit

Permalink
Fixing out of order for conditional outputs (flyteorg#843)
Browse files Browse the repository at this point in the history
  • Loading branch information
kumare3 authored and kennyworkman committed Feb 8, 2022
1 parent b431c0d commit f9525ac
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions flytekit/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,29 +114,39 @@ def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidP
def if_(self, expr: bool) -> Case:
return self._condition._if(expr)

def compute_output_set(self) -> typing.Optional[typing.Set[str]]:
def compute_output_vars(self) -> typing.Optional[typing.List[str]]:
"""
Computes and returns the minimum set of outputs for this conditional block, based on all the cases that have
been registered
"""
output_var_sets: typing.List[typing.Set[str]] = []
output_vars: typing.List[str] = []
output_vars_set = set()
for c in self._cases:
if c.output_promise is None and c.err is None:
# One node returns a void output and no error, we will default to None return
return None
if c.output_promise is not None:
var = []
if isinstance(c.output_promise, tuple):
output_var_sets.append(set([i.var for i in c.output_promise]))
var = [i.var for i in c.output_promise]
else:
output_var_sets.append({c.output_promise.var})
curr = output_var_sets[0]
if len(output_var_sets) > 1:
for x in output_var_sets[1:]:
curr = curr.intersection(x)
return curr
var = [c.output_promise.var]
curr_set = set(var)
if not output_vars:
output_vars = var
output_vars_set = curr_set
else:
output_vars_set = output_vars_set.intersection(curr_set)
new_output_var = []
for v in output_vars:
if v in output_vars_set:
new_output_var.append(v)
output_vars = new_output_var

return output_vars

def _compute_outputs(self, n: Node) -> Optional[Union[Promise, Tuple[Promise], VoidPromise]]:
curr = self.compute_output_set()
curr = self.compute_output_vars()
if curr is None:
return VoidPromise(n.id)
promises = [Promise(var=x, val=NodeOutput(node=n, var=x)) for x in curr]
Expand Down Expand Up @@ -197,7 +207,7 @@ def _compute_outputs(self, selected_output_promise) -> Optional[Union[Tuple[Prom
"""
For the local execution case only returns the least common set of outputs
"""
curr = self.compute_output_set()
curr = self.compute_output_vars()
if curr is None:
return VoidPromise(self.name)
if not isinstance(selected_output_promise, tuple):
Expand All @@ -221,7 +231,7 @@ def end_branch(self) -> Optional[Union[Condition, Tuple[Promise], Promise, VoidP
"""
if self._last_case:
FlyteContextManager.pop_context()
curr = self.compute_output_set()
curr = self.compute_output_vars()
if curr is None:
return VoidPromise(self.name)
promises = [Promise(var=x, val=None) for x in curr]
Expand Down

0 comments on commit f9525ac

Please sign in to comment.