From f9525acde150454e659afce39193271e21bf6be4 Mon Sep 17 00:00:00 2001 From: Ketan Umare <16888709+kumare3@users.noreply.github.com> Date: Wed, 2 Feb 2022 17:47:10 -0800 Subject: [PATCH] Fixing out of order for conditional outputs (#843) --- flytekit/core/condition.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index a38d0e2ab12..9b2670a6839 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -114,29 +114,39 @@ def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidP def if_(self, expr: bool) -> Case: return self._condition._if(expr) - def compute_output_set(self) -> typing.Optional[typing.Set[str]]: + def compute_output_vars(self) -> typing.Optional[typing.List[str]]: """ Computes and returns the minimum set of outputs for this conditional block, based on all the cases that have been registered """ - output_var_sets: typing.List[typing.Set[str]] = [] + output_vars: typing.List[str] = [] + output_vars_set = set() for c in self._cases: if c.output_promise is None and c.err is None: # One node returns a void output and no error, we will default to None return return None if c.output_promise is not None: + var = [] if isinstance(c.output_promise, tuple): - output_var_sets.append(set([i.var for i in c.output_promise])) + var = [i.var for i in c.output_promise] else: - output_var_sets.append({c.output_promise.var}) - curr = output_var_sets[0] - if len(output_var_sets) > 1: - for x in output_var_sets[1:]: - curr = curr.intersection(x) - return curr + var = [c.output_promise.var] + curr_set = set(var) + if not output_vars: + output_vars = var + output_vars_set = curr_set + else: + output_vars_set = output_vars_set.intersection(curr_set) + new_output_var = [] + for v in output_vars: + if v in output_vars_set: + new_output_var.append(v) + output_vars = new_output_var + + return output_vars def _compute_outputs(self, n: Node) -> Optional[Union[Promise, Tuple[Promise], VoidPromise]]: - curr = self.compute_output_set() + curr = self.compute_output_vars() if curr is None: return VoidPromise(n.id) promises = [Promise(var=x, val=NodeOutput(node=n, var=x)) for x in curr] @@ -197,7 +207,7 @@ def _compute_outputs(self, selected_output_promise) -> Optional[Union[Tuple[Prom """ For the local execution case only returns the least common set of outputs """ - curr = self.compute_output_set() + curr = self.compute_output_vars() if curr is None: return VoidPromise(self.name) if not isinstance(selected_output_promise, tuple): @@ -221,7 +231,7 @@ def end_branch(self) -> Optional[Union[Condition, Tuple[Promise], Promise, VoidP """ if self._last_case: FlyteContextManager.pop_context() - curr = self.compute_output_set() + curr = self.compute_output_vars() if curr is None: return VoidPromise(self.name) promises = [Promise(var=x, val=None) for x in curr]