From 7ba3e891728228327e3ca4f63ef2baf12f80ae8f Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Thu, 16 Apr 2020 10:25:02 -0700 Subject: [PATCH] Refactor to remove Workflow from Node executor (#102) --- cmd/kubectl-flyte/cmd/get.go | 2 + cmd/kubectl-flyte/cmd/printers/node.go | 5 +- cmd/kubectl-flyte/cmd/printers/workflow.go | 25 +- pkg/apis/flyteworkflow/v1alpha1/iface.go | 52 +- .../v1alpha1/mocks/BaseWorkflow.go | 41 ++ .../v1alpha1/mocks/BaseWorkflowWithStatus.go | 41 ++ .../v1alpha1/mocks/ExecutableSubWorkflow.go | 41 ++ .../v1alpha1/mocks/ExecutableWorkflow.go | 41 ++ .../v1alpha1/mocks/ExecutionTimeInfo.go | 115 +++++ pkg/apis/flyteworkflow/v1alpha1/mocks/Meta.go | 341 +++++++++++++ .../v1alpha1/mocks/MetaExtended.go | 450 ++++++++++++++++++ .../v1alpha1/mocks/NodeGetter.go | 54 +++ .../v1alpha1/mocks/SubWorkflowGetter.go | 47 ++ .../v1alpha1/mocks/WorkflowMeta.go | 2 +- .../v1alpha1/mocks/WorkflowMetaExtended.go | 2 +- .../flyteworkflow/v1alpha1/node_status.go | 1 + pkg/apis/flyteworkflow/v1alpha1/workflow.go | 8 + pkg/controller/executors/contextual.go | 30 -- pkg/controller/executors/dag_structure.go | 48 ++ pkg/controller/executors/execution_context.go | 86 ++++ .../executors/mocks/dag_structure.go | 92 ++++ .../mocks/dag_structure_with_start_node.go | 129 +++++ .../executors/mocks/execution_context.go | 448 +++++++++++++++++ .../mocks/immutable_execution_context.go | 373 +++++++++++++++ pkg/controller/executors/mocks/node.go | 64 +-- pkg/controller/executors/mocks/node_lookup.go | 91 ++++ .../executors/mocks/sub_workflow_getter.go | 47 ++ .../executors/mocks/task_details_getter.go | 54 +++ pkg/controller/executors/node.go | 9 +- pkg/controller/executors/node_lookup.go | 72 +++ pkg/controller/nodes/branch/evaluator.go | 9 +- pkg/controller/nodes/branch/handler.go | 92 +++- pkg/controller/nodes/branch/handler_test.go | 217 ++++----- pkg/controller/nodes/dynamic/handler.go | 65 +-- pkg/controller/nodes/dynamic/handler_test.go | 284 +++++------ pkg/controller/nodes/dynamic/subworkflow.go | 91 ---- .../nodes/dynamic/subworkflow_test.go | 52 -- pkg/controller/nodes/errors/codes.go | 1 + pkg/controller/nodes/executor.go | 344 ++++++------- pkg/controller/nodes/executor_test.go | 232 ++++++--- .../handler/mocks/node_execution_context.go | 107 +++-- .../handler/mocks/node_execution_metadata.go | 71 +-- .../nodes/handler/node_exec_context.go | 8 +- pkg/controller/nodes/node_exec_context.go | 105 ++-- .../nodes/node_exec_context_test.go | 2 +- pkg/controller/nodes/output_resolver.go | 7 +- pkg/controller/nodes/predicate.go | 29 +- pkg/controller/nodes/predicate_test.go | 106 ++--- pkg/controller/nodes/resolve.go | 21 +- pkg/controller/nodes/resolve_test.go | 5 + pkg/controller/nodes/subworkflow/handler.go | 9 +- .../nodes/subworkflow/handler_test.go | 181 +++---- .../nodes/subworkflow/launchplan.go | 35 +- .../nodes/subworkflow/launchplan_test.go | 128 ++--- .../nodes/subworkflow/subworkflow.go | 158 +++--- .../nodes/subworkflow/subworkflow_test.go | 150 ++++-- pkg/controller/nodes/subworkflow/util.go | 10 +- pkg/controller/nodes/subworkflow/util_test.go | 14 +- pkg/controller/nodes/task/handler_test.go | 327 +++++++------ .../nodes/task/taskexec_context_test.go | 64 +-- pkg/controller/nodes/task/transformer.go | 9 +- pkg/controller/workflow/executor.go | 25 +- pkg/visualize/sort.go | 4 +- 63 files changed, 4264 insertions(+), 1509 deletions(-) create mode 100644 pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutionTimeInfo.go create mode 100644 pkg/apis/flyteworkflow/v1alpha1/mocks/Meta.go create mode 100644 pkg/apis/flyteworkflow/v1alpha1/mocks/MetaExtended.go create mode 100644 pkg/apis/flyteworkflow/v1alpha1/mocks/NodeGetter.go create mode 100644 pkg/apis/flyteworkflow/v1alpha1/mocks/SubWorkflowGetter.go delete mode 100644 pkg/controller/executors/contextual.go create mode 100644 pkg/controller/executors/dag_structure.go create mode 100644 pkg/controller/executors/execution_context.go create mode 100644 pkg/controller/executors/mocks/dag_structure.go create mode 100644 pkg/controller/executors/mocks/dag_structure_with_start_node.go create mode 100644 pkg/controller/executors/mocks/execution_context.go create mode 100644 pkg/controller/executors/mocks/immutable_execution_context.go create mode 100644 pkg/controller/executors/mocks/node_lookup.go create mode 100644 pkg/controller/executors/mocks/sub_workflow_getter.go create mode 100644 pkg/controller/executors/mocks/task_details_getter.go create mode 100644 pkg/controller/executors/node_lookup.go delete mode 100644 pkg/controller/nodes/dynamic/subworkflow.go delete mode 100644 pkg/controller/nodes/dynamic/subworkflow_test.go diff --git a/cmd/kubectl-flyte/cmd/get.go b/cmd/kubectl-flyte/cmd/get.go index 4973fe89b7..ff0afdd0d3 100644 --- a/cmd/kubectl-flyte/cmd/get.go +++ b/cmd/kubectl-flyte/cmd/get.go @@ -7,6 +7,7 @@ import ( "strings" gotree "github.com/DiSiqueira/GoTree" + "github.com/lyft/flytestdlib/storage" "github.com/spf13/cobra" v12 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -64,6 +65,7 @@ func (g *GetOpts) getWorkflow(ctx context.Context, name string) error { } wp := printers.WorkflowPrinter{} tree := gotree.New("Workflow") + w.DataReferenceConstructor = storage.URLPathConstructor{} if err := wp.Print(ctx, tree, w); err != nil { return err } diff --git a/cmd/kubectl-flyte/cmd/printers/node.go b/cmd/kubectl-flyte/cmd/printers/node.go index d7f6b220f7..08e7e04245 100644 --- a/cmd/kubectl-flyte/cmd/printers/node.go +++ b/cmd/kubectl-flyte/cmd/printers/node.go @@ -11,8 +11,8 @@ import ( gotree "github.com/DiSiqueira/GoTree" "github.com/fatih/color" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/lyft/flytepropeller/pkg/controller/executors" "github.com/lyft/flytepropeller/pkg/controller/nodes/task" "github.com/lyft/flytepropeller/pkg/utils" ) @@ -113,8 +113,7 @@ func (p NodePrinter) traverseNode(ctx context.Context, tree gotree.Tree, w v1alp if node.GetWorkflowNode().GetSubWorkflowRef() != nil { s := w.FindSubWorkflow(*node.GetWorkflowNode().GetSubWorkflowRef()) wp := WorkflowPrinter{} - cw := executors.NewSubContextualWorkflow(w, s, nodeStatus) - return wp.Print(ctx, tree, cw) + return wp.PrintSubWorkflow(ctx, tree, w, s, nodeStatus) } case v1alpha1.NodeKindTask: sub := tree.Add(strings.Join(p.NodeInfo(w.GetName(), node, nodeStatus), " | ")) diff --git a/cmd/kubectl-flyte/cmd/printers/workflow.go b/cmd/kubectl-flyte/cmd/printers/workflow.go index 8dcc8efb78..bd807b150b 100644 --- a/cmd/kubectl-flyte/cmd/printers/workflow.go +++ b/cmd/kubectl-flyte/cmd/printers/workflow.go @@ -7,6 +7,7 @@ import ( gotree "github.com/DiSiqueira/GoTree" "github.com/fatih/color" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/lyft/flytepropeller/pkg/visualize" ) @@ -25,7 +26,7 @@ func ColorizeWorkflowPhase(p v1alpha1.WorkflowPhase) string { return color.CyanString("%s", p.String()) } -func CalculateWorkflowRuntime(s v1alpha1.ExecutableWorkflowStatus) string { +func CalculateWorkflowRuntime(s v1alpha1.ExecutionTimeInfo) string { if s.GetStartedAt() != nil { if s.GetStoppedAt() != nil { return s.GetStoppedAt().Sub(s.GetStartedAt().Time).String() @@ -35,6 +36,12 @@ func CalculateWorkflowRuntime(s v1alpha1.ExecutableWorkflowStatus) string { return "na" } +type ContextualWorkflow struct { + v1alpha1.MetaExtended + v1alpha1.ExecutableSubWorkflow + v1alpha1.NodeStatusGetter +} + type WorkflowPrinter struct { } @@ -53,6 +60,22 @@ func (p WorkflowPrinter) Print(ctx context.Context, tree gotree.Tree, w v1alpha1 return np.PrintList(ctx, newTree, w, sortedNodes) } +func (p WorkflowPrinter) PrintSubWorkflow(ctx context.Context, tree gotree.Tree, w v1alpha1.ExecutableWorkflow, swf v1alpha1.ExecutableSubWorkflow, ns v1alpha1.ExecutableNodeStatus) error { + sortedNodes, err := visualize.TopologicalSort(swf) + if err != nil { + return err + } + newTree := gotree.New(fmt.Sprintf("SubWorkflow [%s] (%s %s %s)", + swf.GetID(), CalculateWorkflowRuntime(ns), + ColorizeNodePhase(ns.GetPhase()), ns.GetMessage())) + if tree != nil { + tree.AddTree(newTree) + } + np := NodePrinter{} + + return np.PrintList(ctx, newTree, &ContextualWorkflow{MetaExtended: w, ExecutableSubWorkflow: swf, NodeStatusGetter: ns}, sortedNodes) +} + func (p WorkflowPrinter) PrintShort(tree gotree.Tree, w v1alpha1.ExecutableWorkflow) error { if tree == nil { return fmt.Errorf("bad state in printer") diff --git a/pkg/apis/flyteworkflow/v1alpha1/iface.go b/pkg/apis/flyteworkflow/v1alpha1/iface.go index 5e7e270045..9a11dfe957 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -2,7 +2,6 @@ package v1alpha1 import ( "context" - "time" v1 "k8s.io/api/core/v1" @@ -238,16 +237,20 @@ type MutableNodeStatus interface { ClearSubNodeStatus() } +type ExecutionTimeInfo interface { + GetStoppedAt() *metav1.Time + GetStartedAt() *metav1.Time + GetLastUpdatedAt() *metav1.Time +} + // Interface for a Node p. This provides a mutable API. type ExecutableNodeStatus interface { NodeStatusGetter MutableNodeStatus NodeStatusVisitor + ExecutionTimeInfo GetPhase() NodePhase GetQueuedAt() *metav1.Time - GetStoppedAt() *metav1.Time - GetStartedAt() *metav1.Time - GetLastUpdatedAt() *metav1.Time GetLastAttemptStartedAt() *metav1.Time GetParentNodeID() *NodeID GetParentTaskID() *core.TaskExecutionIdentifier @@ -324,11 +327,9 @@ type ExecutableNode interface { // Interface for the Workflow p. This is the mutable portion for a Workflow type ExecutableWorkflowStatus interface { NodeStatusGetter + ExecutionTimeInfo UpdatePhase(p WorkflowPhase, msg string) GetPhase() WorkflowPhase - GetStoppedAt() *metav1.Time - GetStartedAt() *metav1.Time - GetLastUpdatedAt() *metav1.Time IsTerminated() bool GetMessage() string SetDataDir(DataReference) @@ -340,13 +341,18 @@ type ExecutableWorkflowStatus interface { ConstructNodeDataDir(ctx context.Context, name NodeID) (storage.DataReference, error) } +type NodeGetter interface { + GetNode(nodeID NodeID) (ExecutableNode, bool) +} + type BaseWorkflow interface { + NodeGetter StartNode() ExecutableNode GetID() WorkflowID // From returns all nodes that can be reached directly // from the node with the given unique name. FromNode(name NodeID) ([]NodeID, error) - GetNode(nodeID NodeID) (ExecutableNode, bool) + ToNode(name NodeID) ([]NodeID, error) } type BaseWorkflowWithStatus interface { @@ -365,9 +371,9 @@ type ExecutableSubWorkflow interface { GetOutputs() *OutputVarMap } -// WorkflowMeta provides an interface to retrieve labels, annotations and other concepts that are declared only once +// Meta provides an interface to retrieve labels, annotations and other concepts that are declared only once // for the top level workflow -type WorkflowMeta interface { +type Meta interface { GetExecutionID() ExecutionID GetK8sWorkflowID() types.NamespacedName GetOwnerReference() metav1.OwnerReference @@ -384,17 +390,21 @@ type TaskDetailsGetter interface { GetTask(id TaskID) (ExecutableTask, error) } -type WorkflowMetaExtended interface { - WorkflowMeta - TaskDetailsGetter +type SubWorkflowGetter interface { FindSubWorkflow(subID WorkflowID) ExecutableSubWorkflow +} + +type MetaExtended interface { + Meta + TaskDetailsGetter + SubWorkflowGetter GetExecutionStatus() ExecutableWorkflowStatus } -// A Top level Workflow is a combination of WorkflowMeta and an ExecutableSubWorkflow +// A Top level Workflow is a combination of Meta and an ExecutableSubWorkflow type ExecutableWorkflow interface { ExecutableSubWorkflow - WorkflowMetaExtended + MetaExtended NodeStatusGetter } @@ -420,15 +430,3 @@ func GetOutputsFile(outputDir DataReference) DataReference { func GetInputsFile(inputDir DataReference) DataReference { return inputDir + "/inputs.pb" } - -func GetOutputErrorFile(inputDir DataReference) DataReference { - return inputDir + "/error.pb" -} - -func GetFutureFile() string { - return "futures.pb" -} - -func GetCompiledFutureFile() string { - return "futures_compiled.pb" -} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflow.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflow.go index 884c38459d..648b8ee537 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflow.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflow.go @@ -159,3 +159,44 @@ func (_m *BaseWorkflow) StartNode() v1alpha1.ExecutableNode { return r0 } + +type BaseWorkflow_ToNode struct { + *mock.Call +} + +func (_m BaseWorkflow_ToNode) Return(_a0 []string, _a1 error) *BaseWorkflow_ToNode { + return &BaseWorkflow_ToNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *BaseWorkflow) OnToNode(name string) *BaseWorkflow_ToNode { + c := _m.On("ToNode", name) + return &BaseWorkflow_ToNode{Call: c} +} + +func (_m *BaseWorkflow) OnToNodeMatch(matchers ...interface{}) *BaseWorkflow_ToNode { + c := _m.On("ToNode", matchers...) + return &BaseWorkflow_ToNode{Call: c} +} + +// ToNode provides a mock function with given fields: name +func (_m *BaseWorkflow) ToNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go index dc80fea299..af867ac3b4 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/BaseWorkflowWithStatus.go @@ -195,3 +195,44 @@ func (_m *BaseWorkflowWithStatus) StartNode() v1alpha1.ExecutableNode { return r0 } + +type BaseWorkflowWithStatus_ToNode struct { + *mock.Call +} + +func (_m BaseWorkflowWithStatus_ToNode) Return(_a0 []string, _a1 error) *BaseWorkflowWithStatus_ToNode { + return &BaseWorkflowWithStatus_ToNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *BaseWorkflowWithStatus) OnToNode(name string) *BaseWorkflowWithStatus_ToNode { + c := _m.On("ToNode", name) + return &BaseWorkflowWithStatus_ToNode{Call: c} +} + +func (_m *BaseWorkflowWithStatus) OnToNodeMatch(matchers ...interface{}) *BaseWorkflowWithStatus_ToNode { + c := _m.On("ToNode", matchers...) + return &BaseWorkflowWithStatus_ToNode{Call: c} +} + +// ToNode provides a mock function with given fields: name +func (_m *BaseWorkflowWithStatus) ToNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflow.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflow.go index bb46e82d75..1e7e05a443 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflow.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableSubWorkflow.go @@ -329,3 +329,44 @@ func (_m *ExecutableSubWorkflow) StartNode() v1alpha1.ExecutableNode { return r0 } + +type ExecutableSubWorkflow_ToNode struct { + *mock.Call +} + +func (_m ExecutableSubWorkflow_ToNode) Return(_a0 []string, _a1 error) *ExecutableSubWorkflow_ToNode { + return &ExecutableSubWorkflow_ToNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ExecutableSubWorkflow) OnToNode(name string) *ExecutableSubWorkflow_ToNode { + c := _m.On("ToNode", name) + return &ExecutableSubWorkflow_ToNode{Call: c} +} + +func (_m *ExecutableSubWorkflow) OnToNodeMatch(matchers ...interface{}) *ExecutableSubWorkflow_ToNode { + c := _m.On("ToNode", matchers...) + return &ExecutableSubWorkflow_ToNode{Call: c} +} + +// ToNode provides a mock function with given fields: name +func (_m *ExecutableSubWorkflow) ToNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go index 4526db431f..e2102afd73 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableWorkflow.go @@ -802,3 +802,44 @@ func (_m *ExecutableWorkflow) StartNode() v1alpha1.ExecutableNode { return r0 } + +type ExecutableWorkflow_ToNode struct { + *mock.Call +} + +func (_m ExecutableWorkflow_ToNode) Return(_a0 []string, _a1 error) *ExecutableWorkflow_ToNode { + return &ExecutableWorkflow_ToNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ExecutableWorkflow) OnToNode(name string) *ExecutableWorkflow_ToNode { + c := _m.On("ToNode", name) + return &ExecutableWorkflow_ToNode{Call: c} +} + +func (_m *ExecutableWorkflow) OnToNodeMatch(matchers ...interface{}) *ExecutableWorkflow_ToNode { + c := _m.On("ToNode", matchers...) + return &ExecutableWorkflow_ToNode{Call: c} +} + +// ToNode provides a mock function with given fields: name +func (_m *ExecutableWorkflow) ToNode(name string) ([]string, error) { + ret := _m.Called(name) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(name) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutionTimeInfo.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutionTimeInfo.go new file mode 100644 index 0000000000..4199378cbb --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutionTimeInfo.go @@ -0,0 +1,115 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// ExecutionTimeInfo is an autogenerated mock type for the ExecutionTimeInfo type +type ExecutionTimeInfo struct { + mock.Mock +} + +type ExecutionTimeInfo_GetLastUpdatedAt struct { + *mock.Call +} + +func (_m ExecutionTimeInfo_GetLastUpdatedAt) Return(_a0 *v1.Time) *ExecutionTimeInfo_GetLastUpdatedAt { + return &ExecutionTimeInfo_GetLastUpdatedAt{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionTimeInfo) OnGetLastUpdatedAt() *ExecutionTimeInfo_GetLastUpdatedAt { + c := _m.On("GetLastUpdatedAt") + return &ExecutionTimeInfo_GetLastUpdatedAt{Call: c} +} + +func (_m *ExecutionTimeInfo) OnGetLastUpdatedAtMatch(matchers ...interface{}) *ExecutionTimeInfo_GetLastUpdatedAt { + c := _m.On("GetLastUpdatedAt", matchers...) + return &ExecutionTimeInfo_GetLastUpdatedAt{Call: c} +} + +// GetLastUpdatedAt provides a mock function with given fields: +func (_m *ExecutionTimeInfo) GetLastUpdatedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +type ExecutionTimeInfo_GetStartedAt struct { + *mock.Call +} + +func (_m ExecutionTimeInfo_GetStartedAt) Return(_a0 *v1.Time) *ExecutionTimeInfo_GetStartedAt { + return &ExecutionTimeInfo_GetStartedAt{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionTimeInfo) OnGetStartedAt() *ExecutionTimeInfo_GetStartedAt { + c := _m.On("GetStartedAt") + return &ExecutionTimeInfo_GetStartedAt{Call: c} +} + +func (_m *ExecutionTimeInfo) OnGetStartedAtMatch(matchers ...interface{}) *ExecutionTimeInfo_GetStartedAt { + c := _m.On("GetStartedAt", matchers...) + return &ExecutionTimeInfo_GetStartedAt{Call: c} +} + +// GetStartedAt provides a mock function with given fields: +func (_m *ExecutionTimeInfo) GetStartedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} + +type ExecutionTimeInfo_GetStoppedAt struct { + *mock.Call +} + +func (_m ExecutionTimeInfo_GetStoppedAt) Return(_a0 *v1.Time) *ExecutionTimeInfo_GetStoppedAt { + return &ExecutionTimeInfo_GetStoppedAt{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionTimeInfo) OnGetStoppedAt() *ExecutionTimeInfo_GetStoppedAt { + c := _m.On("GetStoppedAt") + return &ExecutionTimeInfo_GetStoppedAt{Call: c} +} + +func (_m *ExecutionTimeInfo) OnGetStoppedAtMatch(matchers ...interface{}) *ExecutionTimeInfo_GetStoppedAt { + c := _m.On("GetStoppedAt", matchers...) + return &ExecutionTimeInfo_GetStoppedAt{Call: c} +} + +// GetStoppedAt provides a mock function with given fields: +func (_m *ExecutionTimeInfo) GetStoppedAt() *v1.Time { + ret := _m.Called() + + var r0 *v1.Time + if rf, ok := ret.Get(0).(func() *v1.Time); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.Time) + } + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/Meta.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/Meta.go new file mode 100644 index 0000000000..7692bd130c --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/Meta.go @@ -0,0 +1,341 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + types "k8s.io/apimachinery/pkg/types" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// Meta is an autogenerated mock type for the Meta type +type Meta struct { + mock.Mock +} + +type Meta_GetAnnotations struct { + *mock.Call +} + +func (_m Meta_GetAnnotations) Return(_a0 map[string]string) *Meta_GetAnnotations { + return &Meta_GetAnnotations{Call: _m.Call.Return(_a0)} +} + +func (_m *Meta) OnGetAnnotations() *Meta_GetAnnotations { + c := _m.On("GetAnnotations") + return &Meta_GetAnnotations{Call: c} +} + +func (_m *Meta) OnGetAnnotationsMatch(matchers ...interface{}) *Meta_GetAnnotations { + c := _m.On("GetAnnotations", matchers...) + return &Meta_GetAnnotations{Call: c} +} + +// GetAnnotations provides a mock function with given fields: +func (_m *Meta) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +type Meta_GetCreationTimestamp struct { + *mock.Call +} + +func (_m Meta_GetCreationTimestamp) Return(_a0 v1.Time) *Meta_GetCreationTimestamp { + return &Meta_GetCreationTimestamp{Call: _m.Call.Return(_a0)} +} + +func (_m *Meta) OnGetCreationTimestamp() *Meta_GetCreationTimestamp { + c := _m.On("GetCreationTimestamp") + return &Meta_GetCreationTimestamp{Call: c} +} + +func (_m *Meta) OnGetCreationTimestampMatch(matchers ...interface{}) *Meta_GetCreationTimestamp { + c := _m.On("GetCreationTimestamp", matchers...) + return &Meta_GetCreationTimestamp{Call: c} +} + +// GetCreationTimestamp provides a mock function with given fields: +func (_m *Meta) GetCreationTimestamp() v1.Time { + ret := _m.Called() + + var r0 v1.Time + if rf, ok := ret.Get(0).(func() v1.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.Time) + } + + return r0 +} + +type Meta_GetExecutionID struct { + *mock.Call +} + +func (_m Meta_GetExecutionID) Return(_a0 v1alpha1.WorkflowExecutionIdentifier) *Meta_GetExecutionID { + return &Meta_GetExecutionID{Call: _m.Call.Return(_a0)} +} + +func (_m *Meta) OnGetExecutionID() *Meta_GetExecutionID { + c := _m.On("GetExecutionID") + return &Meta_GetExecutionID{Call: c} +} + +func (_m *Meta) OnGetExecutionIDMatch(matchers ...interface{}) *Meta_GetExecutionID { + c := _m.On("GetExecutionID", matchers...) + return &Meta_GetExecutionID{Call: c} +} + +// GetExecutionID provides a mock function with given fields: +func (_m *Meta) GetExecutionID() v1alpha1.WorkflowExecutionIdentifier { + ret := _m.Called() + + var r0 v1alpha1.WorkflowExecutionIdentifier + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowExecutionIdentifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowExecutionIdentifier) + } + + return r0 +} + +type Meta_GetK8sWorkflowID struct { + *mock.Call +} + +func (_m Meta_GetK8sWorkflowID) Return(_a0 types.NamespacedName) *Meta_GetK8sWorkflowID { + return &Meta_GetK8sWorkflowID{Call: _m.Call.Return(_a0)} +} + +func (_m *Meta) OnGetK8sWorkflowID() *Meta_GetK8sWorkflowID { + c := _m.On("GetK8sWorkflowID") + return &Meta_GetK8sWorkflowID{Call: c} +} + +func (_m *Meta) OnGetK8sWorkflowIDMatch(matchers ...interface{}) *Meta_GetK8sWorkflowID { + c := _m.On("GetK8sWorkflowID", matchers...) + return &Meta_GetK8sWorkflowID{Call: c} +} + +// GetK8sWorkflowID provides a mock function with given fields: +func (_m *Meta) GetK8sWorkflowID() types.NamespacedName { + ret := _m.Called() + + var r0 types.NamespacedName + if rf, ok := ret.Get(0).(func() types.NamespacedName); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.NamespacedName) + } + + return r0 +} + +type Meta_GetLabels struct { + *mock.Call +} + +func (_m Meta_GetLabels) Return(_a0 map[string]string) *Meta_GetLabels { + return &Meta_GetLabels{Call: _m.Call.Return(_a0)} +} + +func (_m *Meta) OnGetLabels() *Meta_GetLabels { + c := _m.On("GetLabels") + return &Meta_GetLabels{Call: c} +} + +func (_m *Meta) OnGetLabelsMatch(matchers ...interface{}) *Meta_GetLabels { + c := _m.On("GetLabels", matchers...) + return &Meta_GetLabels{Call: c} +} + +// GetLabels provides a mock function with given fields: +func (_m *Meta) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +type Meta_GetName struct { + *mock.Call +} + +func (_m Meta_GetName) Return(_a0 string) *Meta_GetName { + return &Meta_GetName{Call: _m.Call.Return(_a0)} +} + +func (_m *Meta) OnGetName() *Meta_GetName { + c := _m.On("GetName") + return &Meta_GetName{Call: c} +} + +func (_m *Meta) OnGetNameMatch(matchers ...interface{}) *Meta_GetName { + c := _m.On("GetName", matchers...) + return &Meta_GetName{Call: c} +} + +// GetName provides a mock function with given fields: +func (_m *Meta) GetName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type Meta_GetNamespace struct { + *mock.Call +} + +func (_m Meta_GetNamespace) Return(_a0 string) *Meta_GetNamespace { + return &Meta_GetNamespace{Call: _m.Call.Return(_a0)} +} + +func (_m *Meta) OnGetNamespace() *Meta_GetNamespace { + c := _m.On("GetNamespace") + return &Meta_GetNamespace{Call: c} +} + +func (_m *Meta) OnGetNamespaceMatch(matchers ...interface{}) *Meta_GetNamespace { + c := _m.On("GetNamespace", matchers...) + return &Meta_GetNamespace{Call: c} +} + +// GetNamespace provides a mock function with given fields: +func (_m *Meta) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type Meta_GetOwnerReference struct { + *mock.Call +} + +func (_m Meta_GetOwnerReference) Return(_a0 v1.OwnerReference) *Meta_GetOwnerReference { + return &Meta_GetOwnerReference{Call: _m.Call.Return(_a0)} +} + +func (_m *Meta) OnGetOwnerReference() *Meta_GetOwnerReference { + c := _m.On("GetOwnerReference") + return &Meta_GetOwnerReference{Call: c} +} + +func (_m *Meta) OnGetOwnerReferenceMatch(matchers ...interface{}) *Meta_GetOwnerReference { + c := _m.On("GetOwnerReference", matchers...) + return &Meta_GetOwnerReference{Call: c} +} + +// GetOwnerReference provides a mock function with given fields: +func (_m *Meta) GetOwnerReference() v1.OwnerReference { + ret := _m.Called() + + var r0 v1.OwnerReference + if rf, ok := ret.Get(0).(func() v1.OwnerReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.OwnerReference) + } + + return r0 +} + +type Meta_GetServiceAccountName struct { + *mock.Call +} + +func (_m Meta_GetServiceAccountName) Return(_a0 string) *Meta_GetServiceAccountName { + return &Meta_GetServiceAccountName{Call: _m.Call.Return(_a0)} +} + +func (_m *Meta) OnGetServiceAccountName() *Meta_GetServiceAccountName { + c := _m.On("GetServiceAccountName") + return &Meta_GetServiceAccountName{Call: c} +} + +func (_m *Meta) OnGetServiceAccountNameMatch(matchers ...interface{}) *Meta_GetServiceAccountName { + c := _m.On("GetServiceAccountName", matchers...) + return &Meta_GetServiceAccountName{Call: c} +} + +// GetServiceAccountName provides a mock function with given fields: +func (_m *Meta) GetServiceAccountName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type Meta_IsInterruptible struct { + *mock.Call +} + +func (_m Meta_IsInterruptible) Return(_a0 bool) *Meta_IsInterruptible { + return &Meta_IsInterruptible{Call: _m.Call.Return(_a0)} +} + +func (_m *Meta) OnIsInterruptible() *Meta_IsInterruptible { + c := _m.On("IsInterruptible") + return &Meta_IsInterruptible{Call: c} +} + +func (_m *Meta) OnIsInterruptibleMatch(matchers ...interface{}) *Meta_IsInterruptible { + c := _m.On("IsInterruptible", matchers...) + return &Meta_IsInterruptible{Call: c} +} + +// IsInterruptible provides a mock function with given fields: +func (_m *Meta) IsInterruptible() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/MetaExtended.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/MetaExtended.go new file mode 100644 index 0000000000..282a782564 --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/MetaExtended.go @@ -0,0 +1,450 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + types "k8s.io/apimachinery/pkg/types" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// MetaExtended is an autogenerated mock type for the MetaExtended type +type MetaExtended struct { + mock.Mock +} + +type MetaExtended_FindSubWorkflow struct { + *mock.Call +} + +func (_m MetaExtended_FindSubWorkflow) Return(_a0 v1alpha1.ExecutableSubWorkflow) *MetaExtended_FindSubWorkflow { + return &MetaExtended_FindSubWorkflow{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnFindSubWorkflow(subID string) *MetaExtended_FindSubWorkflow { + c := _m.On("FindSubWorkflow", subID) + return &MetaExtended_FindSubWorkflow{Call: c} +} + +func (_m *MetaExtended) OnFindSubWorkflowMatch(matchers ...interface{}) *MetaExtended_FindSubWorkflow { + c := _m.On("FindSubWorkflow", matchers...) + return &MetaExtended_FindSubWorkflow{Call: c} +} + +// FindSubWorkflow provides a mock function with given fields: subID +func (_m *MetaExtended) FindSubWorkflow(subID string) v1alpha1.ExecutableSubWorkflow { + ret := _m.Called(subID) + + var r0 v1alpha1.ExecutableSubWorkflow + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableSubWorkflow); ok { + r0 = rf(subID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableSubWorkflow) + } + } + + return r0 +} + +type MetaExtended_GetAnnotations struct { + *mock.Call +} + +func (_m MetaExtended_GetAnnotations) Return(_a0 map[string]string) *MetaExtended_GetAnnotations { + return &MetaExtended_GetAnnotations{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnGetAnnotations() *MetaExtended_GetAnnotations { + c := _m.On("GetAnnotations") + return &MetaExtended_GetAnnotations{Call: c} +} + +func (_m *MetaExtended) OnGetAnnotationsMatch(matchers ...interface{}) *MetaExtended_GetAnnotations { + c := _m.On("GetAnnotations", matchers...) + return &MetaExtended_GetAnnotations{Call: c} +} + +// GetAnnotations provides a mock function with given fields: +func (_m *MetaExtended) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +type MetaExtended_GetCreationTimestamp struct { + *mock.Call +} + +func (_m MetaExtended_GetCreationTimestamp) Return(_a0 v1.Time) *MetaExtended_GetCreationTimestamp { + return &MetaExtended_GetCreationTimestamp{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnGetCreationTimestamp() *MetaExtended_GetCreationTimestamp { + c := _m.On("GetCreationTimestamp") + return &MetaExtended_GetCreationTimestamp{Call: c} +} + +func (_m *MetaExtended) OnGetCreationTimestampMatch(matchers ...interface{}) *MetaExtended_GetCreationTimestamp { + c := _m.On("GetCreationTimestamp", matchers...) + return &MetaExtended_GetCreationTimestamp{Call: c} +} + +// GetCreationTimestamp provides a mock function with given fields: +func (_m *MetaExtended) GetCreationTimestamp() v1.Time { + ret := _m.Called() + + var r0 v1.Time + if rf, ok := ret.Get(0).(func() v1.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.Time) + } + + return r0 +} + +type MetaExtended_GetExecutionID struct { + *mock.Call +} + +func (_m MetaExtended_GetExecutionID) Return(_a0 v1alpha1.WorkflowExecutionIdentifier) *MetaExtended_GetExecutionID { + return &MetaExtended_GetExecutionID{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnGetExecutionID() *MetaExtended_GetExecutionID { + c := _m.On("GetExecutionID") + return &MetaExtended_GetExecutionID{Call: c} +} + +func (_m *MetaExtended) OnGetExecutionIDMatch(matchers ...interface{}) *MetaExtended_GetExecutionID { + c := _m.On("GetExecutionID", matchers...) + return &MetaExtended_GetExecutionID{Call: c} +} + +// GetExecutionID provides a mock function with given fields: +func (_m *MetaExtended) GetExecutionID() v1alpha1.WorkflowExecutionIdentifier { + ret := _m.Called() + + var r0 v1alpha1.WorkflowExecutionIdentifier + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowExecutionIdentifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowExecutionIdentifier) + } + + return r0 +} + +type MetaExtended_GetExecutionStatus struct { + *mock.Call +} + +func (_m MetaExtended_GetExecutionStatus) Return(_a0 v1alpha1.ExecutableWorkflowStatus) *MetaExtended_GetExecutionStatus { + return &MetaExtended_GetExecutionStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnGetExecutionStatus() *MetaExtended_GetExecutionStatus { + c := _m.On("GetExecutionStatus") + return &MetaExtended_GetExecutionStatus{Call: c} +} + +func (_m *MetaExtended) OnGetExecutionStatusMatch(matchers ...interface{}) *MetaExtended_GetExecutionStatus { + c := _m.On("GetExecutionStatus", matchers...) + return &MetaExtended_GetExecutionStatus{Call: c} +} + +// GetExecutionStatus provides a mock function with given fields: +func (_m *MetaExtended) GetExecutionStatus() v1alpha1.ExecutableWorkflowStatus { + ret := _m.Called() + + var r0 v1alpha1.ExecutableWorkflowStatus + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableWorkflowStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableWorkflowStatus) + } + } + + return r0 +} + +type MetaExtended_GetK8sWorkflowID struct { + *mock.Call +} + +func (_m MetaExtended_GetK8sWorkflowID) Return(_a0 types.NamespacedName) *MetaExtended_GetK8sWorkflowID { + return &MetaExtended_GetK8sWorkflowID{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnGetK8sWorkflowID() *MetaExtended_GetK8sWorkflowID { + c := _m.On("GetK8sWorkflowID") + return &MetaExtended_GetK8sWorkflowID{Call: c} +} + +func (_m *MetaExtended) OnGetK8sWorkflowIDMatch(matchers ...interface{}) *MetaExtended_GetK8sWorkflowID { + c := _m.On("GetK8sWorkflowID", matchers...) + return &MetaExtended_GetK8sWorkflowID{Call: c} +} + +// GetK8sWorkflowID provides a mock function with given fields: +func (_m *MetaExtended) GetK8sWorkflowID() types.NamespacedName { + ret := _m.Called() + + var r0 types.NamespacedName + if rf, ok := ret.Get(0).(func() types.NamespacedName); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.NamespacedName) + } + + return r0 +} + +type MetaExtended_GetLabels struct { + *mock.Call +} + +func (_m MetaExtended_GetLabels) Return(_a0 map[string]string) *MetaExtended_GetLabels { + return &MetaExtended_GetLabels{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnGetLabels() *MetaExtended_GetLabels { + c := _m.On("GetLabels") + return &MetaExtended_GetLabels{Call: c} +} + +func (_m *MetaExtended) OnGetLabelsMatch(matchers ...interface{}) *MetaExtended_GetLabels { + c := _m.On("GetLabels", matchers...) + return &MetaExtended_GetLabels{Call: c} +} + +// GetLabels provides a mock function with given fields: +func (_m *MetaExtended) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +type MetaExtended_GetName struct { + *mock.Call +} + +func (_m MetaExtended_GetName) Return(_a0 string) *MetaExtended_GetName { + return &MetaExtended_GetName{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnGetName() *MetaExtended_GetName { + c := _m.On("GetName") + return &MetaExtended_GetName{Call: c} +} + +func (_m *MetaExtended) OnGetNameMatch(matchers ...interface{}) *MetaExtended_GetName { + c := _m.On("GetName", matchers...) + return &MetaExtended_GetName{Call: c} +} + +// GetName provides a mock function with given fields: +func (_m *MetaExtended) GetName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type MetaExtended_GetNamespace struct { + *mock.Call +} + +func (_m MetaExtended_GetNamespace) Return(_a0 string) *MetaExtended_GetNamespace { + return &MetaExtended_GetNamespace{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnGetNamespace() *MetaExtended_GetNamespace { + c := _m.On("GetNamespace") + return &MetaExtended_GetNamespace{Call: c} +} + +func (_m *MetaExtended) OnGetNamespaceMatch(matchers ...interface{}) *MetaExtended_GetNamespace { + c := _m.On("GetNamespace", matchers...) + return &MetaExtended_GetNamespace{Call: c} +} + +// GetNamespace provides a mock function with given fields: +func (_m *MetaExtended) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type MetaExtended_GetOwnerReference struct { + *mock.Call +} + +func (_m MetaExtended_GetOwnerReference) Return(_a0 v1.OwnerReference) *MetaExtended_GetOwnerReference { + return &MetaExtended_GetOwnerReference{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnGetOwnerReference() *MetaExtended_GetOwnerReference { + c := _m.On("GetOwnerReference") + return &MetaExtended_GetOwnerReference{Call: c} +} + +func (_m *MetaExtended) OnGetOwnerReferenceMatch(matchers ...interface{}) *MetaExtended_GetOwnerReference { + c := _m.On("GetOwnerReference", matchers...) + return &MetaExtended_GetOwnerReference{Call: c} +} + +// GetOwnerReference provides a mock function with given fields: +func (_m *MetaExtended) GetOwnerReference() v1.OwnerReference { + ret := _m.Called() + + var r0 v1.OwnerReference + if rf, ok := ret.Get(0).(func() v1.OwnerReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.OwnerReference) + } + + return r0 +} + +type MetaExtended_GetServiceAccountName struct { + *mock.Call +} + +func (_m MetaExtended_GetServiceAccountName) Return(_a0 string) *MetaExtended_GetServiceAccountName { + return &MetaExtended_GetServiceAccountName{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnGetServiceAccountName() *MetaExtended_GetServiceAccountName { + c := _m.On("GetServiceAccountName") + return &MetaExtended_GetServiceAccountName{Call: c} +} + +func (_m *MetaExtended) OnGetServiceAccountNameMatch(matchers ...interface{}) *MetaExtended_GetServiceAccountName { + c := _m.On("GetServiceAccountName", matchers...) + return &MetaExtended_GetServiceAccountName{Call: c} +} + +// GetServiceAccountName provides a mock function with given fields: +func (_m *MetaExtended) GetServiceAccountName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type MetaExtended_GetTask struct { + *mock.Call +} + +func (_m MetaExtended_GetTask) Return(_a0 v1alpha1.ExecutableTask, _a1 error) *MetaExtended_GetTask { + return &MetaExtended_GetTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *MetaExtended) OnGetTask(id string) *MetaExtended_GetTask { + c := _m.On("GetTask", id) + return &MetaExtended_GetTask{Call: c} +} + +func (_m *MetaExtended) OnGetTaskMatch(matchers ...interface{}) *MetaExtended_GetTask { + c := _m.On("GetTask", matchers...) + return &MetaExtended_GetTask{Call: c} +} + +// GetTask provides a mock function with given fields: id +func (_m *MetaExtended) GetTask(id string) (v1alpha1.ExecutableTask, error) { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableTask + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableTask); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableTask) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type MetaExtended_IsInterruptible struct { + *mock.Call +} + +func (_m MetaExtended_IsInterruptible) Return(_a0 bool) *MetaExtended_IsInterruptible { + return &MetaExtended_IsInterruptible{Call: _m.Call.Return(_a0)} +} + +func (_m *MetaExtended) OnIsInterruptible() *MetaExtended_IsInterruptible { + c := _m.On("IsInterruptible") + return &MetaExtended_IsInterruptible{Call: c} +} + +func (_m *MetaExtended) OnIsInterruptibleMatch(matchers ...interface{}) *MetaExtended_IsInterruptible { + c := _m.On("IsInterruptible", matchers...) + return &MetaExtended_IsInterruptible{Call: c} +} + +// IsInterruptible provides a mock function with given fields: +func (_m *MetaExtended) IsInterruptible() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeGetter.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeGetter.go new file mode 100644 index 0000000000..38535fb76c --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/NodeGetter.go @@ -0,0 +1,54 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mock "github.com/stretchr/testify/mock" +) + +// NodeGetter is an autogenerated mock type for the NodeGetter type +type NodeGetter struct { + mock.Mock +} + +type NodeGetter_GetNode struct { + *mock.Call +} + +func (_m NodeGetter_GetNode) Return(_a0 v1alpha1.ExecutableNode, _a1 bool) *NodeGetter_GetNode { + return &NodeGetter_GetNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeGetter) OnGetNode(nodeID string) *NodeGetter_GetNode { + c := _m.On("GetNode", nodeID) + return &NodeGetter_GetNode{Call: c} +} + +func (_m *NodeGetter) OnGetNodeMatch(matchers ...interface{}) *NodeGetter_GetNode { + c := _m.On("GetNode", matchers...) + return &NodeGetter_GetNode{Call: c} +} + +// GetNode provides a mock function with given fields: nodeID +func (_m *NodeGetter) GetNode(nodeID string) (v1alpha1.ExecutableNode, bool) { + ret := _m.Called(nodeID) + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNode); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/SubWorkflowGetter.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/SubWorkflowGetter.go new file mode 100644 index 0000000000..5a6367b45f --- /dev/null +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/SubWorkflowGetter.go @@ -0,0 +1,47 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mock "github.com/stretchr/testify/mock" +) + +// SubWorkflowGetter is an autogenerated mock type for the SubWorkflowGetter type +type SubWorkflowGetter struct { + mock.Mock +} + +type SubWorkflowGetter_FindSubWorkflow struct { + *mock.Call +} + +func (_m SubWorkflowGetter_FindSubWorkflow) Return(_a0 v1alpha1.ExecutableSubWorkflow) *SubWorkflowGetter_FindSubWorkflow { + return &SubWorkflowGetter_FindSubWorkflow{Call: _m.Call.Return(_a0)} +} + +func (_m *SubWorkflowGetter) OnFindSubWorkflow(subID string) *SubWorkflowGetter_FindSubWorkflow { + c := _m.On("FindSubWorkflow", subID) + return &SubWorkflowGetter_FindSubWorkflow{Call: c} +} + +func (_m *SubWorkflowGetter) OnFindSubWorkflowMatch(matchers ...interface{}) *SubWorkflowGetter_FindSubWorkflow { + c := _m.On("FindSubWorkflow", matchers...) + return &SubWorkflowGetter_FindSubWorkflow{Call: c} +} + +// FindSubWorkflow provides a mock function with given fields: subID +func (_m *SubWorkflowGetter) FindSubWorkflow(subID string) v1alpha1.ExecutableSubWorkflow { + ret := _m.Called(subID) + + var r0 v1alpha1.ExecutableSubWorkflow + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableSubWorkflow); ok { + r0 = rf(subID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableSubWorkflow) + } + } + + return r0 +} diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMeta.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMeta.go index 9934b3efd6..25475c2ec4 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMeta.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMeta.go @@ -11,7 +11,7 @@ import ( v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" ) -// WorkflowMeta is an autogenerated mock type for the WorkflowMeta type +// Meta is an autogenerated mock type for the Meta type type WorkflowMeta struct { mock.Mock } diff --git a/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMetaExtended.go b/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMetaExtended.go index 6865c07f42..6fe8e6357e 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMetaExtended.go +++ b/pkg/apis/flyteworkflow/v1alpha1/mocks/WorkflowMetaExtended.go @@ -11,7 +11,7 @@ import ( v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" ) -// WorkflowMetaExtended is an autogenerated mock type for the WorkflowMetaExtended type +// MetaExtended is an autogenerated mock type for the MetaExtended type type WorkflowMetaExtended struct { mock.Mock } diff --git a/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/pkg/apis/flyteworkflow/v1alpha1/node_status.go index a47c7719da..9a9b7f6c06 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -134,6 +134,7 @@ type WorkflowNodePhase int const ( WorkflowNodePhaseUndefined WorkflowNodePhase = iota WorkflowNodePhaseExecuting + WorkflowNodePhaseFailing ) type WorkflowNodeStatus struct { diff --git a/pkg/apis/flyteworkflow/v1alpha1/workflow.go b/pkg/apis/flyteworkflow/v1alpha1/workflow.go index 7a824abbc3..1176b6469b 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/workflow.go +++ b/pkg/apis/flyteworkflow/v1alpha1/workflow.go @@ -205,6 +205,14 @@ func (in *WorkflowSpec) GetID() WorkflowID { return in.ID } +func (in *WorkflowSpec) ToNode(name NodeID) ([]NodeID, error) { + if _, ok := in.Nodes[name]; !ok { + return nil, errors.Errorf("Bad Node [%v], is not defined in the Workflow [%v]", name, in.ID) + } + upstreamNodes := in.Connections.UpstreamEdges[name] + return upstreamNodes, nil +} + func (in *WorkflowSpec) FromNode(name NodeID) ([]NodeID, error) { if _, ok := in.Nodes[name]; !ok { return nil, errors.Errorf("Bad Node [%v], is not defined in the Workflow [%v]", name, in.ID) diff --git a/pkg/controller/executors/contextual.go b/pkg/controller/executors/contextual.go deleted file mode 100644 index 7d02a6f9a7..0000000000 --- a/pkg/controller/executors/contextual.go +++ /dev/null @@ -1,30 +0,0 @@ -package executors - -import ( - "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" -) - -type ContextualWorkflow struct { - v1alpha1.WorkflowMetaExtended - v1alpha1.ExecutableSubWorkflow - v1alpha1.NodeStatusGetter -} - -func NewBaseContextualWorkflow(baseWorkflow v1alpha1.ExecutableWorkflow) v1alpha1.ExecutableWorkflow { - return &ContextualWorkflow{ - ExecutableSubWorkflow: baseWorkflow, - WorkflowMetaExtended: baseWorkflow, - NodeStatusGetter: baseWorkflow.GetExecutionStatus(), - } -} - -// Creates a contextual workflow using the provided interface implementations. -func NewSubContextualWorkflow(baseWorkflow v1alpha1.ExecutableWorkflow, subWF v1alpha1.ExecutableSubWorkflow, - nodeStatus v1alpha1.ExecutableNodeStatus) v1alpha1.ExecutableWorkflow { - - return &ContextualWorkflow{ - ExecutableSubWorkflow: subWF, - WorkflowMetaExtended: baseWorkflow, - NodeStatusGetter: nodeStatus, - } -} diff --git a/pkg/controller/executors/dag_structure.go b/pkg/controller/executors/dag_structure.go new file mode 100644 index 0000000000..aa094fc0f6 --- /dev/null +++ b/pkg/controller/executors/dag_structure.go @@ -0,0 +1,48 @@ +package executors + +import ( + "fmt" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// An interface that captures the Directed Acyclic Graph structure in which the nodes are connected. +// If NodeLookup and DAGStructure are used together a traversal can be implemented. +type DAGStructure interface { + // Lookup for upstream edges, find all node ids from which this node can be reached. + ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) + // Lookup for downstream edges, find all node ids that can be reached from the given node id. + FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) +} + +type DAGStructureWithStartNode interface { + DAGStructure + // The Starting node for the DAG + StartNode() v1alpha1.ExecutableNode +} + +type leafNodeDAGStructure struct { + parentNodes []v1alpha1.NodeID + currentNode v1alpha1.NodeID +} + +func (l leafNodeDAGStructure) StartNode() v1alpha1.ExecutableNode { + return nil +} + +func (l leafNodeDAGStructure) ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + if id == l.currentNode { + return l.parentNodes, nil + } + return nil, fmt.Errorf("unknown Node ID [%s]", id) +} + +func (l leafNodeDAGStructure) FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + return []v1alpha1.NodeID{}, nil +} + +// Returns a new DAGStructure for a leafNode. i.e., there are only incoming edges and no outgoing edges. +// Also there is no StartNode for this Structure +func NewLeafNodeDAGStructure(leafNode v1alpha1.NodeID, parentNodes ...v1alpha1.NodeID) DAGStructure { + return leafNodeDAGStructure{currentNode: leafNode, parentNodes: parentNodes} +} diff --git a/pkg/controller/executors/execution_context.go b/pkg/controller/executors/execution_context.go new file mode 100644 index 0000000000..9e7d4d3338 --- /dev/null +++ b/pkg/controller/executors/execution_context.go @@ -0,0 +1,86 @@ +package executors + +import ( + "fmt" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +type TaskDetailsGetter interface { + GetTask(id v1alpha1.TaskID) (v1alpha1.ExecutableTask, error) +} + +// Retrieves the Task details from a static inmemory HashMap +type staticTaskDetailsGetter struct { + // As this is an additional taskmap created, we can use a name to identify the taskmap. Every error message will have this name + name string + tasks map[v1alpha1.TaskID]v1alpha1.ExecutableTask +} + +func (t *staticTaskDetailsGetter) GetTask(id v1alpha1.TaskID) (v1alpha1.ExecutableTask, error) { + if task, ok := t.tasks[id]; ok { + return task, nil + } + return nil, fmt.Errorf("unable to find task with id [%s] in task set [%s]", id, t.name) +} + +// As this is an additional taskmap created, we can use a name to identify the taskmap. Every error message will have this name +func NewStaticTaskDetailsGetter(name string, tasks map[v1alpha1.TaskID]v1alpha1.ExecutableTask) TaskDetailsGetter { + return &staticTaskDetailsGetter{name: name, tasks: tasks} +} + +type SubWorkflowGetter interface { + FindSubWorkflow(subID v1alpha1.WorkflowID) v1alpha1.ExecutableSubWorkflow +} + +// Retrieves the Task details from a static inmemory HashMap +type staticSubWorkflowGetter struct { + // As this is an additional subWorkflow Map created, we can use a name to identify the SubWorkflow Set. Every error message will have this name + name string + subWorkflows map[v1alpha1.WorkflowID]v1alpha1.ExecutableSubWorkflow +} + +func (t *staticSubWorkflowGetter) FindSubWorkflow(subID v1alpha1.WorkflowID) v1alpha1.ExecutableSubWorkflow { + if swf, ok := t.subWorkflows[subID]; ok { + return swf + } + return nil +} + +// As this is an additional taskmap created, we can use a name to identify the taskmap. Every error message will have this name +func NewStaticSubWorkflowsGetter(name string, subworkflows map[v1alpha1.WorkflowID]v1alpha1.ExecutableSubWorkflow) SubWorkflowGetter { + return &staticSubWorkflowGetter{name: name, subWorkflows: subworkflows} +} + +type ImmutableExecutionContext interface { + v1alpha1.Meta + GetID() v1alpha1.WorkflowID +} + +type ExecutionContext interface { + ImmutableExecutionContext + TaskDetailsGetter + SubWorkflowGetter +} + +type execContext struct { + ImmutableExecutionContext + TaskDetailsGetter + SubWorkflowGetter +} + +func NewExecutionContextWithTasksGetter(prevExecContext ExecutionContext, taskGetter TaskDetailsGetter) ExecutionContext { + return NewExecutionContext(prevExecContext, taskGetter, prevExecContext) +} + +func NewExecutionContextWithWorkflowGetter(prevExecContext ExecutionContext, getter SubWorkflowGetter) ExecutionContext { + return NewExecutionContext(prevExecContext, prevExecContext, getter) +} + +func NewExecutionContext(immExecContext ImmutableExecutionContext, tasksGetter TaskDetailsGetter, workflowGetter SubWorkflowGetter) ExecutionContext { + return execContext{ + ImmutableExecutionContext: immExecContext, + TaskDetailsGetter: tasksGetter, + SubWorkflowGetter: workflowGetter, + } +} diff --git a/pkg/controller/executors/mocks/dag_structure.go b/pkg/controller/executors/mocks/dag_structure.go new file mode 100644 index 0000000000..dbe939bf3c --- /dev/null +++ b/pkg/controller/executors/mocks/dag_structure.go @@ -0,0 +1,92 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// DAGStructure is an autogenerated mock type for the DAGStructure type +type DAGStructure struct { + mock.Mock +} + +type DAGStructure_FromNode struct { + *mock.Call +} + +func (_m DAGStructure_FromNode) Return(_a0 []string, _a1 error) *DAGStructure_FromNode { + return &DAGStructure_FromNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *DAGStructure) OnFromNode(id string) *DAGStructure_FromNode { + c := _m.On("FromNode", id) + return &DAGStructure_FromNode{Call: c} +} + +func (_m *DAGStructure) OnFromNodeMatch(matchers ...interface{}) *DAGStructure_FromNode { + c := _m.On("FromNode", matchers...) + return &DAGStructure_FromNode{Call: c} +} + +// FromNode provides a mock function with given fields: id +func (_m *DAGStructure) FromNode(id string) ([]string, error) { + ret := _m.Called(id) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type DAGStructure_ToNode struct { + *mock.Call +} + +func (_m DAGStructure_ToNode) Return(_a0 []string, _a1 error) *DAGStructure_ToNode { + return &DAGStructure_ToNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *DAGStructure) OnToNode(id string) *DAGStructure_ToNode { + c := _m.On("ToNode", id) + return &DAGStructure_ToNode{Call: c} +} + +func (_m *DAGStructure) OnToNodeMatch(matchers ...interface{}) *DAGStructure_ToNode { + c := _m.On("ToNode", matchers...) + return &DAGStructure_ToNode{Call: c} +} + +// ToNode provides a mock function with given fields: id +func (_m *DAGStructure) ToNode(id string) ([]string, error) { + ret := _m.Called(id) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/executors/mocks/dag_structure_with_start_node.go b/pkg/controller/executors/mocks/dag_structure_with_start_node.go new file mode 100644 index 0000000000..6e74e5e910 --- /dev/null +++ b/pkg/controller/executors/mocks/dag_structure_with_start_node.go @@ -0,0 +1,129 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mock "github.com/stretchr/testify/mock" +) + +// DAGStructureWithStartNode is an autogenerated mock type for the DAGStructureWithStartNode type +type DAGStructureWithStartNode struct { + mock.Mock +} + +type DAGStructureWithStartNode_FromNode struct { + *mock.Call +} + +func (_m DAGStructureWithStartNode_FromNode) Return(_a0 []string, _a1 error) *DAGStructureWithStartNode_FromNode { + return &DAGStructureWithStartNode_FromNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *DAGStructureWithStartNode) OnFromNode(id string) *DAGStructureWithStartNode_FromNode { + c := _m.On("FromNode", id) + return &DAGStructureWithStartNode_FromNode{Call: c} +} + +func (_m *DAGStructureWithStartNode) OnFromNodeMatch(matchers ...interface{}) *DAGStructureWithStartNode_FromNode { + c := _m.On("FromNode", matchers...) + return &DAGStructureWithStartNode_FromNode{Call: c} +} + +// FromNode provides a mock function with given fields: id +func (_m *DAGStructureWithStartNode) FromNode(id string) ([]string, error) { + ret := _m.Called(id) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type DAGStructureWithStartNode_StartNode struct { + *mock.Call +} + +func (_m DAGStructureWithStartNode_StartNode) Return(_a0 v1alpha1.ExecutableNode) *DAGStructureWithStartNode_StartNode { + return &DAGStructureWithStartNode_StartNode{Call: _m.Call.Return(_a0)} +} + +func (_m *DAGStructureWithStartNode) OnStartNode() *DAGStructureWithStartNode_StartNode { + c := _m.On("StartNode") + return &DAGStructureWithStartNode_StartNode{Call: c} +} + +func (_m *DAGStructureWithStartNode) OnStartNodeMatch(matchers ...interface{}) *DAGStructureWithStartNode_StartNode { + c := _m.On("StartNode", matchers...) + return &DAGStructureWithStartNode_StartNode{Call: c} +} + +// StartNode provides a mock function with given fields: +func (_m *DAGStructureWithStartNode) StartNode() v1alpha1.ExecutableNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + return r0 +} + +type DAGStructureWithStartNode_ToNode struct { + *mock.Call +} + +func (_m DAGStructureWithStartNode_ToNode) Return(_a0 []string, _a1 error) *DAGStructureWithStartNode_ToNode { + return &DAGStructureWithStartNode_ToNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *DAGStructureWithStartNode) OnToNode(id string) *DAGStructureWithStartNode_ToNode { + c := _m.On("ToNode", id) + return &DAGStructureWithStartNode_ToNode{Call: c} +} + +func (_m *DAGStructureWithStartNode) OnToNodeMatch(matchers ...interface{}) *DAGStructureWithStartNode_ToNode { + c := _m.On("ToNode", matchers...) + return &DAGStructureWithStartNode_ToNode{Call: c} +} + +// ToNode provides a mock function with given fields: id +func (_m *DAGStructureWithStartNode) ToNode(id string) ([]string, error) { + ret := _m.Called(id) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/executors/mocks/execution_context.go b/pkg/controller/executors/mocks/execution_context.go new file mode 100644 index 0000000000..18093be24f --- /dev/null +++ b/pkg/controller/executors/mocks/execution_context.go @@ -0,0 +1,448 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + types "k8s.io/apimachinery/pkg/types" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// ExecutionContext is an autogenerated mock type for the ExecutionContext type +type ExecutionContext struct { + mock.Mock +} + +type ExecutionContext_FindSubWorkflow struct { + *mock.Call +} + +func (_m ExecutionContext_FindSubWorkflow) Return(_a0 v1alpha1.ExecutableSubWorkflow) *ExecutionContext_FindSubWorkflow { + return &ExecutionContext_FindSubWorkflow{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnFindSubWorkflow(subID string) *ExecutionContext_FindSubWorkflow { + c := _m.On("FindSubWorkflow", subID) + return &ExecutionContext_FindSubWorkflow{Call: c} +} + +func (_m *ExecutionContext) OnFindSubWorkflowMatch(matchers ...interface{}) *ExecutionContext_FindSubWorkflow { + c := _m.On("FindSubWorkflow", matchers...) + return &ExecutionContext_FindSubWorkflow{Call: c} +} + +// FindSubWorkflow provides a mock function with given fields: subID +func (_m *ExecutionContext) FindSubWorkflow(subID string) v1alpha1.ExecutableSubWorkflow { + ret := _m.Called(subID) + + var r0 v1alpha1.ExecutableSubWorkflow + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableSubWorkflow); ok { + r0 = rf(subID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableSubWorkflow) + } + } + + return r0 +} + +type ExecutionContext_GetAnnotations struct { + *mock.Call +} + +func (_m ExecutionContext_GetAnnotations) Return(_a0 map[string]string) *ExecutionContext_GetAnnotations { + return &ExecutionContext_GetAnnotations{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnGetAnnotations() *ExecutionContext_GetAnnotations { + c := _m.On("GetAnnotations") + return &ExecutionContext_GetAnnotations{Call: c} +} + +func (_m *ExecutionContext) OnGetAnnotationsMatch(matchers ...interface{}) *ExecutionContext_GetAnnotations { + c := _m.On("GetAnnotations", matchers...) + return &ExecutionContext_GetAnnotations{Call: c} +} + +// GetAnnotations provides a mock function with given fields: +func (_m *ExecutionContext) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +type ExecutionContext_GetCreationTimestamp struct { + *mock.Call +} + +func (_m ExecutionContext_GetCreationTimestamp) Return(_a0 v1.Time) *ExecutionContext_GetCreationTimestamp { + return &ExecutionContext_GetCreationTimestamp{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnGetCreationTimestamp() *ExecutionContext_GetCreationTimestamp { + c := _m.On("GetCreationTimestamp") + return &ExecutionContext_GetCreationTimestamp{Call: c} +} + +func (_m *ExecutionContext) OnGetCreationTimestampMatch(matchers ...interface{}) *ExecutionContext_GetCreationTimestamp { + c := _m.On("GetCreationTimestamp", matchers...) + return &ExecutionContext_GetCreationTimestamp{Call: c} +} + +// GetCreationTimestamp provides a mock function with given fields: +func (_m *ExecutionContext) GetCreationTimestamp() v1.Time { + ret := _m.Called() + + var r0 v1.Time + if rf, ok := ret.Get(0).(func() v1.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.Time) + } + + return r0 +} + +type ExecutionContext_GetExecutionID struct { + *mock.Call +} + +func (_m ExecutionContext_GetExecutionID) Return(_a0 v1alpha1.WorkflowExecutionIdentifier) *ExecutionContext_GetExecutionID { + return &ExecutionContext_GetExecutionID{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnGetExecutionID() *ExecutionContext_GetExecutionID { + c := _m.On("GetExecutionID") + return &ExecutionContext_GetExecutionID{Call: c} +} + +func (_m *ExecutionContext) OnGetExecutionIDMatch(matchers ...interface{}) *ExecutionContext_GetExecutionID { + c := _m.On("GetExecutionID", matchers...) + return &ExecutionContext_GetExecutionID{Call: c} +} + +// GetExecutionID provides a mock function with given fields: +func (_m *ExecutionContext) GetExecutionID() v1alpha1.WorkflowExecutionIdentifier { + ret := _m.Called() + + var r0 v1alpha1.WorkflowExecutionIdentifier + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowExecutionIdentifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowExecutionIdentifier) + } + + return r0 +} + +type ExecutionContext_GetID struct { + *mock.Call +} + +func (_m ExecutionContext_GetID) Return(_a0 string) *ExecutionContext_GetID { + return &ExecutionContext_GetID{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnGetID() *ExecutionContext_GetID { + c := _m.On("GetID") + return &ExecutionContext_GetID{Call: c} +} + +func (_m *ExecutionContext) OnGetIDMatch(matchers ...interface{}) *ExecutionContext_GetID { + c := _m.On("GetID", matchers...) + return &ExecutionContext_GetID{Call: c} +} + +// GetID provides a mock function with given fields: +func (_m *ExecutionContext) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type ExecutionContext_GetK8sWorkflowID struct { + *mock.Call +} + +func (_m ExecutionContext_GetK8sWorkflowID) Return(_a0 types.NamespacedName) *ExecutionContext_GetK8sWorkflowID { + return &ExecutionContext_GetK8sWorkflowID{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnGetK8sWorkflowID() *ExecutionContext_GetK8sWorkflowID { + c := _m.On("GetK8sWorkflowID") + return &ExecutionContext_GetK8sWorkflowID{Call: c} +} + +func (_m *ExecutionContext) OnGetK8sWorkflowIDMatch(matchers ...interface{}) *ExecutionContext_GetK8sWorkflowID { + c := _m.On("GetK8sWorkflowID", matchers...) + return &ExecutionContext_GetK8sWorkflowID{Call: c} +} + +// GetK8sWorkflowID provides a mock function with given fields: +func (_m *ExecutionContext) GetK8sWorkflowID() types.NamespacedName { + ret := _m.Called() + + var r0 types.NamespacedName + if rf, ok := ret.Get(0).(func() types.NamespacedName); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.NamespacedName) + } + + return r0 +} + +type ExecutionContext_GetLabels struct { + *mock.Call +} + +func (_m ExecutionContext_GetLabels) Return(_a0 map[string]string) *ExecutionContext_GetLabels { + return &ExecutionContext_GetLabels{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnGetLabels() *ExecutionContext_GetLabels { + c := _m.On("GetLabels") + return &ExecutionContext_GetLabels{Call: c} +} + +func (_m *ExecutionContext) OnGetLabelsMatch(matchers ...interface{}) *ExecutionContext_GetLabels { + c := _m.On("GetLabels", matchers...) + return &ExecutionContext_GetLabels{Call: c} +} + +// GetLabels provides a mock function with given fields: +func (_m *ExecutionContext) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +type ExecutionContext_GetName struct { + *mock.Call +} + +func (_m ExecutionContext_GetName) Return(_a0 string) *ExecutionContext_GetName { + return &ExecutionContext_GetName{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnGetName() *ExecutionContext_GetName { + c := _m.On("GetName") + return &ExecutionContext_GetName{Call: c} +} + +func (_m *ExecutionContext) OnGetNameMatch(matchers ...interface{}) *ExecutionContext_GetName { + c := _m.On("GetName", matchers...) + return &ExecutionContext_GetName{Call: c} +} + +// GetName provides a mock function with given fields: +func (_m *ExecutionContext) GetName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type ExecutionContext_GetNamespace struct { + *mock.Call +} + +func (_m ExecutionContext_GetNamespace) Return(_a0 string) *ExecutionContext_GetNamespace { + return &ExecutionContext_GetNamespace{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnGetNamespace() *ExecutionContext_GetNamespace { + c := _m.On("GetNamespace") + return &ExecutionContext_GetNamespace{Call: c} +} + +func (_m *ExecutionContext) OnGetNamespaceMatch(matchers ...interface{}) *ExecutionContext_GetNamespace { + c := _m.On("GetNamespace", matchers...) + return &ExecutionContext_GetNamespace{Call: c} +} + +// GetNamespace provides a mock function with given fields: +func (_m *ExecutionContext) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type ExecutionContext_GetOwnerReference struct { + *mock.Call +} + +func (_m ExecutionContext_GetOwnerReference) Return(_a0 v1.OwnerReference) *ExecutionContext_GetOwnerReference { + return &ExecutionContext_GetOwnerReference{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnGetOwnerReference() *ExecutionContext_GetOwnerReference { + c := _m.On("GetOwnerReference") + return &ExecutionContext_GetOwnerReference{Call: c} +} + +func (_m *ExecutionContext) OnGetOwnerReferenceMatch(matchers ...interface{}) *ExecutionContext_GetOwnerReference { + c := _m.On("GetOwnerReference", matchers...) + return &ExecutionContext_GetOwnerReference{Call: c} +} + +// GetOwnerReference provides a mock function with given fields: +func (_m *ExecutionContext) GetOwnerReference() v1.OwnerReference { + ret := _m.Called() + + var r0 v1.OwnerReference + if rf, ok := ret.Get(0).(func() v1.OwnerReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.OwnerReference) + } + + return r0 +} + +type ExecutionContext_GetServiceAccountName struct { + *mock.Call +} + +func (_m ExecutionContext_GetServiceAccountName) Return(_a0 string) *ExecutionContext_GetServiceAccountName { + return &ExecutionContext_GetServiceAccountName{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnGetServiceAccountName() *ExecutionContext_GetServiceAccountName { + c := _m.On("GetServiceAccountName") + return &ExecutionContext_GetServiceAccountName{Call: c} +} + +func (_m *ExecutionContext) OnGetServiceAccountNameMatch(matchers ...interface{}) *ExecutionContext_GetServiceAccountName { + c := _m.On("GetServiceAccountName", matchers...) + return &ExecutionContext_GetServiceAccountName{Call: c} +} + +// GetServiceAccountName provides a mock function with given fields: +func (_m *ExecutionContext) GetServiceAccountName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type ExecutionContext_GetTask struct { + *mock.Call +} + +func (_m ExecutionContext_GetTask) Return(_a0 v1alpha1.ExecutableTask, _a1 error) *ExecutionContext_GetTask { + return &ExecutionContext_GetTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ExecutionContext) OnGetTask(id string) *ExecutionContext_GetTask { + c := _m.On("GetTask", id) + return &ExecutionContext_GetTask{Call: c} +} + +func (_m *ExecutionContext) OnGetTaskMatch(matchers ...interface{}) *ExecutionContext_GetTask { + c := _m.On("GetTask", matchers...) + return &ExecutionContext_GetTask{Call: c} +} + +// GetTask provides a mock function with given fields: id +func (_m *ExecutionContext) GetTask(id string) (v1alpha1.ExecutableTask, error) { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableTask + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableTask); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableTask) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ExecutionContext_IsInterruptible struct { + *mock.Call +} + +func (_m ExecutionContext_IsInterruptible) Return(_a0 bool) *ExecutionContext_IsInterruptible { + return &ExecutionContext_IsInterruptible{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutionContext) OnIsInterruptible() *ExecutionContext_IsInterruptible { + c := _m.On("IsInterruptible") + return &ExecutionContext_IsInterruptible{Call: c} +} + +func (_m *ExecutionContext) OnIsInterruptibleMatch(matchers ...interface{}) *ExecutionContext_IsInterruptible { + c := _m.On("IsInterruptible", matchers...) + return &ExecutionContext_IsInterruptible{Call: c} +} + +// IsInterruptible provides a mock function with given fields: +func (_m *ExecutionContext) IsInterruptible() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/pkg/controller/executors/mocks/immutable_execution_context.go b/pkg/controller/executors/mocks/immutable_execution_context.go new file mode 100644 index 0000000000..d5d23da77d --- /dev/null +++ b/pkg/controller/executors/mocks/immutable_execution_context.go @@ -0,0 +1,373 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + types "k8s.io/apimachinery/pkg/types" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// ImmutableExecutionContext is an autogenerated mock type for the ImmutableExecutionContext type +type ImmutableExecutionContext struct { + mock.Mock +} + +type ImmutableExecutionContext_GetAnnotations struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_GetAnnotations) Return(_a0 map[string]string) *ImmutableExecutionContext_GetAnnotations { + return &ImmutableExecutionContext_GetAnnotations{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnGetAnnotations() *ImmutableExecutionContext_GetAnnotations { + c := _m.On("GetAnnotations") + return &ImmutableExecutionContext_GetAnnotations{Call: c} +} + +func (_m *ImmutableExecutionContext) OnGetAnnotationsMatch(matchers ...interface{}) *ImmutableExecutionContext_GetAnnotations { + c := _m.On("GetAnnotations", matchers...) + return &ImmutableExecutionContext_GetAnnotations{Call: c} +} + +// GetAnnotations provides a mock function with given fields: +func (_m *ImmutableExecutionContext) GetAnnotations() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +type ImmutableExecutionContext_GetCreationTimestamp struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_GetCreationTimestamp) Return(_a0 v1.Time) *ImmutableExecutionContext_GetCreationTimestamp { + return &ImmutableExecutionContext_GetCreationTimestamp{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnGetCreationTimestamp() *ImmutableExecutionContext_GetCreationTimestamp { + c := _m.On("GetCreationTimestamp") + return &ImmutableExecutionContext_GetCreationTimestamp{Call: c} +} + +func (_m *ImmutableExecutionContext) OnGetCreationTimestampMatch(matchers ...interface{}) *ImmutableExecutionContext_GetCreationTimestamp { + c := _m.On("GetCreationTimestamp", matchers...) + return &ImmutableExecutionContext_GetCreationTimestamp{Call: c} +} + +// GetCreationTimestamp provides a mock function with given fields: +func (_m *ImmutableExecutionContext) GetCreationTimestamp() v1.Time { + ret := _m.Called() + + var r0 v1.Time + if rf, ok := ret.Get(0).(func() v1.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.Time) + } + + return r0 +} + +type ImmutableExecutionContext_GetExecutionID struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_GetExecutionID) Return(_a0 v1alpha1.WorkflowExecutionIdentifier) *ImmutableExecutionContext_GetExecutionID { + return &ImmutableExecutionContext_GetExecutionID{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnGetExecutionID() *ImmutableExecutionContext_GetExecutionID { + c := _m.On("GetExecutionID") + return &ImmutableExecutionContext_GetExecutionID{Call: c} +} + +func (_m *ImmutableExecutionContext) OnGetExecutionIDMatch(matchers ...interface{}) *ImmutableExecutionContext_GetExecutionID { + c := _m.On("GetExecutionID", matchers...) + return &ImmutableExecutionContext_GetExecutionID{Call: c} +} + +// GetExecutionID provides a mock function with given fields: +func (_m *ImmutableExecutionContext) GetExecutionID() v1alpha1.WorkflowExecutionIdentifier { + ret := _m.Called() + + var r0 v1alpha1.WorkflowExecutionIdentifier + if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowExecutionIdentifier); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.WorkflowExecutionIdentifier) + } + + return r0 +} + +type ImmutableExecutionContext_GetID struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_GetID) Return(_a0 string) *ImmutableExecutionContext_GetID { + return &ImmutableExecutionContext_GetID{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnGetID() *ImmutableExecutionContext_GetID { + c := _m.On("GetID") + return &ImmutableExecutionContext_GetID{Call: c} +} + +func (_m *ImmutableExecutionContext) OnGetIDMatch(matchers ...interface{}) *ImmutableExecutionContext_GetID { + c := _m.On("GetID", matchers...) + return &ImmutableExecutionContext_GetID{Call: c} +} + +// GetID provides a mock function with given fields: +func (_m *ImmutableExecutionContext) GetID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type ImmutableExecutionContext_GetK8sWorkflowID struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_GetK8sWorkflowID) Return(_a0 types.NamespacedName) *ImmutableExecutionContext_GetK8sWorkflowID { + return &ImmutableExecutionContext_GetK8sWorkflowID{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnGetK8sWorkflowID() *ImmutableExecutionContext_GetK8sWorkflowID { + c := _m.On("GetK8sWorkflowID") + return &ImmutableExecutionContext_GetK8sWorkflowID{Call: c} +} + +func (_m *ImmutableExecutionContext) OnGetK8sWorkflowIDMatch(matchers ...interface{}) *ImmutableExecutionContext_GetK8sWorkflowID { + c := _m.On("GetK8sWorkflowID", matchers...) + return &ImmutableExecutionContext_GetK8sWorkflowID{Call: c} +} + +// GetK8sWorkflowID provides a mock function with given fields: +func (_m *ImmutableExecutionContext) GetK8sWorkflowID() types.NamespacedName { + ret := _m.Called() + + var r0 types.NamespacedName + if rf, ok := ret.Get(0).(func() types.NamespacedName); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.NamespacedName) + } + + return r0 +} + +type ImmutableExecutionContext_GetLabels struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_GetLabels) Return(_a0 map[string]string) *ImmutableExecutionContext_GetLabels { + return &ImmutableExecutionContext_GetLabels{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnGetLabels() *ImmutableExecutionContext_GetLabels { + c := _m.On("GetLabels") + return &ImmutableExecutionContext_GetLabels{Call: c} +} + +func (_m *ImmutableExecutionContext) OnGetLabelsMatch(matchers ...interface{}) *ImmutableExecutionContext_GetLabels { + c := _m.On("GetLabels", matchers...) + return &ImmutableExecutionContext_GetLabels{Call: c} +} + +// GetLabels provides a mock function with given fields: +func (_m *ImmutableExecutionContext) GetLabels() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + +type ImmutableExecutionContext_GetName struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_GetName) Return(_a0 string) *ImmutableExecutionContext_GetName { + return &ImmutableExecutionContext_GetName{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnGetName() *ImmutableExecutionContext_GetName { + c := _m.On("GetName") + return &ImmutableExecutionContext_GetName{Call: c} +} + +func (_m *ImmutableExecutionContext) OnGetNameMatch(matchers ...interface{}) *ImmutableExecutionContext_GetName { + c := _m.On("GetName", matchers...) + return &ImmutableExecutionContext_GetName{Call: c} +} + +// GetName provides a mock function with given fields: +func (_m *ImmutableExecutionContext) GetName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type ImmutableExecutionContext_GetNamespace struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_GetNamespace) Return(_a0 string) *ImmutableExecutionContext_GetNamespace { + return &ImmutableExecutionContext_GetNamespace{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnGetNamespace() *ImmutableExecutionContext_GetNamespace { + c := _m.On("GetNamespace") + return &ImmutableExecutionContext_GetNamespace{Call: c} +} + +func (_m *ImmutableExecutionContext) OnGetNamespaceMatch(matchers ...interface{}) *ImmutableExecutionContext_GetNamespace { + c := _m.On("GetNamespace", matchers...) + return &ImmutableExecutionContext_GetNamespace{Call: c} +} + +// GetNamespace provides a mock function with given fields: +func (_m *ImmutableExecutionContext) GetNamespace() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type ImmutableExecutionContext_GetOwnerReference struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_GetOwnerReference) Return(_a0 v1.OwnerReference) *ImmutableExecutionContext_GetOwnerReference { + return &ImmutableExecutionContext_GetOwnerReference{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnGetOwnerReference() *ImmutableExecutionContext_GetOwnerReference { + c := _m.On("GetOwnerReference") + return &ImmutableExecutionContext_GetOwnerReference{Call: c} +} + +func (_m *ImmutableExecutionContext) OnGetOwnerReferenceMatch(matchers ...interface{}) *ImmutableExecutionContext_GetOwnerReference { + c := _m.On("GetOwnerReference", matchers...) + return &ImmutableExecutionContext_GetOwnerReference{Call: c} +} + +// GetOwnerReference provides a mock function with given fields: +func (_m *ImmutableExecutionContext) GetOwnerReference() v1.OwnerReference { + ret := _m.Called() + + var r0 v1.OwnerReference + if rf, ok := ret.Get(0).(func() v1.OwnerReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1.OwnerReference) + } + + return r0 +} + +type ImmutableExecutionContext_GetServiceAccountName struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_GetServiceAccountName) Return(_a0 string) *ImmutableExecutionContext_GetServiceAccountName { + return &ImmutableExecutionContext_GetServiceAccountName{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnGetServiceAccountName() *ImmutableExecutionContext_GetServiceAccountName { + c := _m.On("GetServiceAccountName") + return &ImmutableExecutionContext_GetServiceAccountName{Call: c} +} + +func (_m *ImmutableExecutionContext) OnGetServiceAccountNameMatch(matchers ...interface{}) *ImmutableExecutionContext_GetServiceAccountName { + c := _m.On("GetServiceAccountName", matchers...) + return &ImmutableExecutionContext_GetServiceAccountName{Call: c} +} + +// GetServiceAccountName provides a mock function with given fields: +func (_m *ImmutableExecutionContext) GetServiceAccountName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type ImmutableExecutionContext_IsInterruptible struct { + *mock.Call +} + +func (_m ImmutableExecutionContext_IsInterruptible) Return(_a0 bool) *ImmutableExecutionContext_IsInterruptible { + return &ImmutableExecutionContext_IsInterruptible{Call: _m.Call.Return(_a0)} +} + +func (_m *ImmutableExecutionContext) OnIsInterruptible() *ImmutableExecutionContext_IsInterruptible { + c := _m.On("IsInterruptible") + return &ImmutableExecutionContext_IsInterruptible{Call: c} +} + +func (_m *ImmutableExecutionContext) OnIsInterruptibleMatch(matchers ...interface{}) *ImmutableExecutionContext_IsInterruptible { + c := _m.On("IsInterruptible", matchers...) + return &ImmutableExecutionContext_IsInterruptible{Call: c} +} + +// IsInterruptible provides a mock function with given fields: +func (_m *ImmutableExecutionContext) IsInterruptible() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} diff --git a/pkg/controller/executors/mocks/node.go b/pkg/controller/executors/mocks/node.go index 73d154e1d5..68804620cf 100644 --- a/pkg/controller/executors/mocks/node.go +++ b/pkg/controller/executors/mocks/node.go @@ -26,8 +26,8 @@ func (_m Node_AbortHandler) Return(_a0 error) *Node_AbortHandler { return &Node_AbortHandler{Call: _m.Call.Return(_a0)} } -func (_m *Node) OnAbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode, reason string) *Node_AbortHandler { - c := _m.On("AbortHandler", ctx, w, currentNode, reason) +func (_m *Node) OnAbortHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) *Node_AbortHandler { + c := _m.On("AbortHandler", ctx, execContext, dag, nl, currentNode, reason) return &Node_AbortHandler{Call: c} } @@ -36,13 +36,13 @@ func (_m *Node) OnAbortHandlerMatch(matchers ...interface{}) *Node_AbortHandler return &Node_AbortHandler{Call: c} } -// AbortHandler provides a mock function with given fields: ctx, w, currentNode, reason -func (_m *Node) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode, reason string) error { - ret := _m.Called(ctx, w, currentNode, reason) +// AbortHandler provides a mock function with given fields: ctx, execContext, dag, nl, currentNode, reason +func (_m *Node) AbortHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error { + ret := _m.Called(ctx, execContext, dag, nl, currentNode, reason) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode, string) error); ok { - r0 = rf(ctx, w, currentNode, reason) + if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, v1alpha1.ExecutableNode, string) error); ok { + r0 = rf(ctx, execContext, dag, nl, currentNode, reason) } else { r0 = ret.Error(0) } @@ -58,8 +58,8 @@ func (_m Node_FinalizeHandler) Return(_a0 error) *Node_FinalizeHandler { return &Node_FinalizeHandler{Call: _m.Call.Return(_a0)} } -func (_m *Node) OnFinalizeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) *Node_FinalizeHandler { - c := _m.On("FinalizeHandler", ctx, w, currentNode) +func (_m *Node) OnFinalizeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) *Node_FinalizeHandler { + c := _m.On("FinalizeHandler", ctx, execContext, dag, nl, currentNode) return &Node_FinalizeHandler{Call: c} } @@ -68,13 +68,13 @@ func (_m *Node) OnFinalizeHandlerMatch(matchers ...interface{}) *Node_FinalizeHa return &Node_FinalizeHandler{Call: c} } -// FinalizeHandler provides a mock function with given fields: ctx, w, currentNode -func (_m *Node) FinalizeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error { - ret := _m.Called(ctx, w, currentNode) +// FinalizeHandler provides a mock function with given fields: ctx, execContext, dag, nl, currentNode +func (_m *Node) FinalizeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error { + ret := _m.Called(ctx, execContext, dag, nl, currentNode) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode) error); ok { - r0 = rf(ctx, w, currentNode) + if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, v1alpha1.ExecutableNode) error); ok { + r0 = rf(ctx, execContext, dag, nl, currentNode) } else { r0 = ret.Error(0) } @@ -122,8 +122,8 @@ func (_m Node_RecursiveNodeHandler) Return(_a0 executors.NodeStatus, _a1 error) return &Node_RecursiveNodeHandler{Call: _m.Call.Return(_a0, _a1)} } -func (_m *Node) OnRecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) *Node_RecursiveNodeHandler { - c := _m.On("RecursiveNodeHandler", ctx, w, currentNode) +func (_m *Node) OnRecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) *Node_RecursiveNodeHandler { + c := _m.On("RecursiveNodeHandler", ctx, execContext, dag, nl, currentNode) return &Node_RecursiveNodeHandler{Call: c} } @@ -132,20 +132,20 @@ func (_m *Node) OnRecursiveNodeHandlerMatch(matchers ...interface{}) *Node_Recur return &Node_RecursiveNodeHandler{Call: c} } -// RecursiveNodeHandler provides a mock function with given fields: ctx, w, currentNode -func (_m *Node) RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { - ret := _m.Called(ctx, w, currentNode) +// RecursiveNodeHandler provides a mock function with given fields: ctx, execContext, dag, nl, currentNode +func (_m *Node) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { + ret := _m.Called(ctx, execContext, dag, nl, currentNode) var r0 executors.NodeStatus - if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode) executors.NodeStatus); ok { - r0 = rf(ctx, w, currentNode) + if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, v1alpha1.ExecutableNode) executors.NodeStatus); ok { + r0 = rf(ctx, execContext, dag, nl, currentNode) } else { r0 = ret.Get(0).(executors.NodeStatus) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, v1alpha1.ExecutableNode) error); ok { - r1 = rf(ctx, w, currentNode) + if rf, ok := ret.Get(1).(func(context.Context, executors.ExecutionContext, executors.DAGStructure, executors.NodeLookup, v1alpha1.ExecutableNode) error); ok { + r1 = rf(ctx, execContext, dag, nl, currentNode) } else { r1 = ret.Error(1) } @@ -161,8 +161,8 @@ func (_m Node_SetInputsForStartNode) Return(_a0 executors.NodeStatus, _a1 error) return &Node_SetInputsForStartNode{Call: _m.Call.Return(_a0, _a1)} } -func (_m *Node) OnSetInputsForStartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, inputs *core.LiteralMap) *Node_SetInputsForStartNode { - c := _m.On("SetInputsForStartNode", ctx, w, inputs) +func (_m *Node) OnSetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) *Node_SetInputsForStartNode { + c := _m.On("SetInputsForStartNode", ctx, execContext, dag, nl, inputs) return &Node_SetInputsForStartNode{Call: c} } @@ -171,20 +171,20 @@ func (_m *Node) OnSetInputsForStartNodeMatch(matchers ...interface{}) *Node_SetI return &Node_SetInputsForStartNode{Call: c} } -// SetInputsForStartNode provides a mock function with given fields: ctx, w, inputs -func (_m *Node) SetInputsForStartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, inputs *core.LiteralMap) (executors.NodeStatus, error) { - ret := _m.Called(ctx, w, inputs) +// SetInputsForStartNode provides a mock function with given fields: ctx, execContext, dag, nl, inputs +func (_m *Node) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (executors.NodeStatus, error) { + ret := _m.Called(ctx, execContext, dag, nl, inputs) var r0 executors.NodeStatus - if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.ExecutableWorkflow, *core.LiteralMap) executors.NodeStatus); ok { - r0 = rf(ctx, w, inputs) + if rf, ok := ret.Get(0).(func(context.Context, executors.ExecutionContext, executors.DAGStructureWithStartNode, executors.NodeLookup, *core.LiteralMap) executors.NodeStatus); ok { + r0 = rf(ctx, execContext, dag, nl, inputs) } else { r0 = ret.Get(0).(executors.NodeStatus) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.ExecutableWorkflow, *core.LiteralMap) error); ok { - r1 = rf(ctx, w, inputs) + if rf, ok := ret.Get(1).(func(context.Context, executors.ExecutionContext, executors.DAGStructureWithStartNode, executors.NodeLookup, *core.LiteralMap) error); ok { + r1 = rf(ctx, execContext, dag, nl, inputs) } else { r1 = ret.Error(1) } diff --git a/pkg/controller/executors/mocks/node_lookup.go b/pkg/controller/executors/mocks/node_lookup.go new file mode 100644 index 0000000000..9cc6bf94f3 --- /dev/null +++ b/pkg/controller/executors/mocks/node_lookup.go @@ -0,0 +1,91 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// NodeLookup is an autogenerated mock type for the NodeLookup type +type NodeLookup struct { + mock.Mock +} + +type NodeLookup_GetNode struct { + *mock.Call +} + +func (_m NodeLookup_GetNode) Return(_a0 v1alpha1.ExecutableNode, _a1 bool) *NodeLookup_GetNode { + return &NodeLookup_GetNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeLookup) OnGetNode(nodeID string) *NodeLookup_GetNode { + c := _m.On("GetNode", nodeID) + return &NodeLookup_GetNode{Call: c} +} + +func (_m *NodeLookup) OnGetNodeMatch(matchers ...interface{}) *NodeLookup_GetNode { + c := _m.On("GetNode", matchers...) + return &NodeLookup_GetNode{Call: c} +} + +// GetNode provides a mock function with given fields: nodeID +func (_m *NodeLookup) GetNode(nodeID string) (v1alpha1.ExecutableNode, bool) { + ret := _m.Called(nodeID) + + var r0 v1alpha1.ExecutableNode + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableNode); ok { + r0 = rf(nodeID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNode) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(nodeID) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} + +type NodeLookup_GetNodeExecutionStatus struct { + *mock.Call +} + +func (_m NodeLookup_GetNodeExecutionStatus) Return(_a0 v1alpha1.ExecutableNodeStatus) *NodeLookup_GetNodeExecutionStatus { + return &NodeLookup_GetNodeExecutionStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeLookup) OnGetNodeExecutionStatus(ctx context.Context, id string) *NodeLookup_GetNodeExecutionStatus { + c := _m.On("GetNodeExecutionStatus", ctx, id) + return &NodeLookup_GetNodeExecutionStatus{Call: c} +} + +func (_m *NodeLookup) OnGetNodeExecutionStatusMatch(matchers ...interface{}) *NodeLookup_GetNodeExecutionStatus { + c := _m.On("GetNodeExecutionStatus", matchers...) + return &NodeLookup_GetNodeExecutionStatus{Call: c} +} + +// GetNodeExecutionStatus provides a mock function with given fields: ctx, id +func (_m *NodeLookup) GetNodeExecutionStatus(ctx context.Context, id string) v1alpha1.ExecutableNodeStatus { + ret := _m.Called(ctx, id) + + var r0 v1alpha1.ExecutableNodeStatus + if rf, ok := ret.Get(0).(func(context.Context, string) v1alpha1.ExecutableNodeStatus); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableNodeStatus) + } + } + + return r0 +} diff --git a/pkg/controller/executors/mocks/sub_workflow_getter.go b/pkg/controller/executors/mocks/sub_workflow_getter.go new file mode 100644 index 0000000000..5a6367b45f --- /dev/null +++ b/pkg/controller/executors/mocks/sub_workflow_getter.go @@ -0,0 +1,47 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mock "github.com/stretchr/testify/mock" +) + +// SubWorkflowGetter is an autogenerated mock type for the SubWorkflowGetter type +type SubWorkflowGetter struct { + mock.Mock +} + +type SubWorkflowGetter_FindSubWorkflow struct { + *mock.Call +} + +func (_m SubWorkflowGetter_FindSubWorkflow) Return(_a0 v1alpha1.ExecutableSubWorkflow) *SubWorkflowGetter_FindSubWorkflow { + return &SubWorkflowGetter_FindSubWorkflow{Call: _m.Call.Return(_a0)} +} + +func (_m *SubWorkflowGetter) OnFindSubWorkflow(subID string) *SubWorkflowGetter_FindSubWorkflow { + c := _m.On("FindSubWorkflow", subID) + return &SubWorkflowGetter_FindSubWorkflow{Call: c} +} + +func (_m *SubWorkflowGetter) OnFindSubWorkflowMatch(matchers ...interface{}) *SubWorkflowGetter_FindSubWorkflow { + c := _m.On("FindSubWorkflow", matchers...) + return &SubWorkflowGetter_FindSubWorkflow{Call: c} +} + +// FindSubWorkflow provides a mock function with given fields: subID +func (_m *SubWorkflowGetter) FindSubWorkflow(subID string) v1alpha1.ExecutableSubWorkflow { + ret := _m.Called(subID) + + var r0 v1alpha1.ExecutableSubWorkflow + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableSubWorkflow); ok { + r0 = rf(subID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableSubWorkflow) + } + } + + return r0 +} diff --git a/pkg/controller/executors/mocks/task_details_getter.go b/pkg/controller/executors/mocks/task_details_getter.go new file mode 100644 index 0000000000..81fdc38cd2 --- /dev/null +++ b/pkg/controller/executors/mocks/task_details_getter.go @@ -0,0 +1,54 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mock "github.com/stretchr/testify/mock" +) + +// TaskDetailsGetter is an autogenerated mock type for the TaskDetailsGetter type +type TaskDetailsGetter struct { + mock.Mock +} + +type TaskDetailsGetter_GetTask struct { + *mock.Call +} + +func (_m TaskDetailsGetter_GetTask) Return(_a0 v1alpha1.ExecutableTask, _a1 error) *TaskDetailsGetter_GetTask { + return &TaskDetailsGetter_GetTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *TaskDetailsGetter) OnGetTask(id string) *TaskDetailsGetter_GetTask { + c := _m.On("GetTask", id) + return &TaskDetailsGetter_GetTask{Call: c} +} + +func (_m *TaskDetailsGetter) OnGetTaskMatch(matchers ...interface{}) *TaskDetailsGetter_GetTask { + c := _m.On("GetTask", matchers...) + return &TaskDetailsGetter_GetTask{Call: c} +} + +// GetTask provides a mock function with given fields: id +func (_m *TaskDetailsGetter) GetTask(id string) (v1alpha1.ExecutableTask, error) { + ret := _m.Called(id) + + var r0 v1alpha1.ExecutableTask + if rf, ok := ret.Get(0).(func(string) v1alpha1.ExecutableTask); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableTask) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/executors/node.go b/pkg/controller/executors/node.go index 9cd38eeb64..f8177ae195 100644 --- a/pkg/controller/executors/node.go +++ b/pkg/controller/executors/node.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" ) @@ -62,19 +63,19 @@ func (p NodePhase) String() string { type Node interface { // This method is used specifically to set inputs for start node. This is because start node does not retrieve inputs // from predecessors, but the inputs are inputs to the workflow or inputs to the parent container (workflow) node. - SetInputsForStartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, inputs *core.LiteralMap) (NodeStatus, error) + SetInputsForStartNode(ctx context.Context, execContext ExecutionContext, dag DAGStructureWithStartNode, nl NodeLookup, inputs *core.LiteralMap) (NodeStatus, error) // This is the main entrypoint to execute a node. It recursively depth-first goes through all ready nodes and starts their execution // This returns either // - 1. It finds a blocking node (not ready, or running) // - 2. A node fails and hence the workflow will fail // - 3. The final/end node has completed and the workflow should be stopped - RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (NodeStatus, error) + RecursiveNodeHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode) (NodeStatus, error) // This aborts the given node. If the given node is complete then it recursively finds the running nodes and aborts them - AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode, reason string) error + AbortHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error - FinalizeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error + FinalizeHandler(ctx context.Context, execContext ExecutionContext, dag DAGStructure, nl NodeLookup, currentNode v1alpha1.ExecutableNode) error // This method should be used to initialize Node executor Initialize(ctx context.Context) error diff --git a/pkg/controller/executors/node_lookup.go b/pkg/controller/executors/node_lookup.go new file mode 100644 index 0000000000..e5adee8c7b --- /dev/null +++ b/pkg/controller/executors/node_lookup.go @@ -0,0 +1,72 @@ +package executors + +import ( + "context" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// NodeLookup provides a structure that enables looking up all nodes within the current execution hierarchy/context. +// NOTE: execution hierarchy may change the nodes available, this is because when a SubWorkflow is being executed, only +// the nodes within the subworkflow are visible +type NodeLookup interface { + GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) + GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus +} + +// Implements a de-generate case of NodeLookup, where only one Node is always looked up +type singleNodeLookup struct { + n v1alpha1.ExecutableNode + v1alpha1.NodeStatusGetter +} + +func (s singleNodeLookup) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + if nodeID != s.n.GetID() { + return nil, false + } + return s.n, true +} + +// Returns a De-generate NodeLookup that always returns one node and the status of that node +func NewSingleNodeLookup(n v1alpha1.ExecutableNode, s v1alpha1.NodeStatusGetter) NodeLookup { + return singleNodeLookup{NodeStatusGetter: s, n: n} +} + +// Implements a contextual NodeLookup that can be composed of a disparate NodeGetter and a NodeStatusGetter +type contextualNodeLookup struct { + v1alpha1.NodeGetter + v1alpha1.NodeStatusGetter +} + +// Returns a Contextual NodeLookup using the given NodeGetter and a separate NodeStatusGetter. +// Very useful in Subworkflows where the Subworkflow is the reservoir of the nodes, but the status for these nodes +// maybe stored int he Top-level workflow node itself. +func NewNodeLookup(n v1alpha1.NodeGetter, s v1alpha1.NodeStatusGetter) NodeLookup { + return contextualNodeLookup{ + NodeGetter: n, + NodeStatusGetter: s, + } +} + +// Implements a nodeLookup using Maps, very useful in Testing +type staticNodeLookup struct { + nodes map[v1alpha1.NodeID]v1alpha1.ExecutableNode + status map[v1alpha1.NodeID]v1alpha1.ExecutableNodeStatus +} + +func (s staticNodeLookup) GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) { + n, ok := s.nodes[nodeID] + return n, ok +} + +func (s staticNodeLookup) GetNodeExecutionStatus(_ context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus { + return s.status[id] +} + +// Returns a new NodeLookup useful in Testing. Not recommended to be used in production +func NewTestNodeLookup(nodes map[v1alpha1.NodeID]v1alpha1.ExecutableNode, status map[v1alpha1.NodeID]v1alpha1.ExecutableNodeStatus) NodeLookup { + return staticNodeLookup{ + nodes: nodes, + status: status, + } +} diff --git a/pkg/controller/nodes/branch/evaluator.go b/pkg/controller/nodes/branch/evaluator.go index 89183ab394..618730476e 100644 --- a/pkg/controller/nodes/branch/evaluator.go +++ b/pkg/controller/nodes/branch/evaluator.go @@ -8,6 +8,7 @@ import ( v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" regErrors "github.com/pkg/errors" @@ -88,9 +89,9 @@ func EvaluateIfBlock(block v1alpha1.ExecutableIfBlock, nodeInputs *core.LiteralM } // Decides the branch to be taken, returns the nodeId of the selected node or an error -// The branchnode is marked as success. This is used by downstream node to determine if it can be executed +// The branchNode is marked as success. This is used by downstream node to determine if it can be executed // All downstream nodes are marked as skipped -func DecideBranch(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, nodeID v1alpha1.NodeID, node v1alpha1.ExecutableBranchNode, nodeInputs *core.LiteralMap) (*v1alpha1.NodeID, error) { +func DecideBranch(ctx context.Context, nl executors.NodeLookup, nodeID v1alpha1.NodeID, node v1alpha1.ExecutableBranchNode, nodeInputs *core.LiteralMap) (*v1alpha1.NodeID, error) { var selectedNodeID *v1alpha1.NodeID var skippedNodeIds []*v1alpha1.NodeID var err error @@ -119,11 +120,11 @@ func DecideBranch(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, nodeID } for _, nodeIDPtr := range skippedNodeIds { skippedNodeID := *nodeIDPtr - n, ok := w.GetNode(skippedNodeID) + n, ok := nl.GetNode(skippedNodeID) if !ok { return nil, errors.Errorf(errors.DownstreamNodeNotFoundError, nodeID, "Downstream node [%v] not found", skippedNodeID) } - nStatus := w.GetNodeExecutionStatus(ctx, n.GetID()) + nStatus := nl.GetNodeExecutionStatus(ctx, n.GetID()) logger.Infof(ctx, "Branch Setting Node[%v] status to Skipped!", skippedNodeID) nStatus.UpdatePhase(v1alpha1.NodePhaseSkipped, v1.Now(), "Branch evaluated to false") } diff --git a/pkg/controller/nodes/branch/handler.go b/pkg/controller/nodes/branch/handler.go index 31816d576f..e4b409f522 100644 --- a/pkg/controller/nodes/branch/handler.go +++ b/pkg/controller/nodes/branch/handler.go @@ -31,21 +31,14 @@ func (b *branchHandler) Setup(ctx context.Context, setupContext handler.SetupCon return nil } -func (b *branchHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { - logger.Debug(ctx, "Starting Branch Node") - branchNode := nCtx.Node().GetBranchNode() - if branchNode == nil { - return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.IllegalStateError, "Invoked branch handler, for a non branch node.", nil)), nil - } - +func (b *branchHandler) HandleBranchNode(ctx context.Context, branchNode v1alpha1.ExecutableBranchNode, nCtx handler.NodeExecutionContext, nl executors.NodeLookup) (handler.Transition, error) { if nCtx.NodeStateReader().GetBranchNode().FinalizedNodeID == nil { nodeInputs, err := nCtx.InputReader().Get(ctx) if err != nil { errMsg := fmt.Sprintf("Failed to read input. Error [%s]", err) return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.RuntimeExecutionError, errMsg, nil)), nil } - w := nCtx.Workflow() - finalNodeID, err := DecideBranch(ctx, w, nCtx.NodeID(), branchNode, nodeInputs) + finalNodeID, err := DecideBranch(ctx, nl, nCtx.NodeID(), branchNode, nodeInputs) if err != nil { errMsg := fmt.Sprintf("Branch evaluation failed. Error [%s]", err) return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.IllegalStateError, errMsg, nil)), nil @@ -59,18 +52,18 @@ func (b *branchHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionCo } var ok bool - finalNode, ok := w.GetNode(*finalNodeID) + finalNode, ok := nl.GetNode(*finalNodeID) if !ok { errMsg := fmt.Sprintf("Branch downstream finalized node not found. FinalizedNode [%s]", *finalNodeID) logger.Debugf(ctx, errMsg) return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.DownstreamNodeNotFoundError, errMsg, nil)), nil } i := nCtx.NodeID() - childNodeStatus := w.GetNodeExecutionStatus(ctx, finalNode.GetID()) + childNodeStatus := nl.GetNodeExecutionStatus(ctx, finalNode.GetID()) childNodeStatus.SetParentNodeID(&i) - logger.Debugf(ctx, "Recursing down branchNodestatus node") - nodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) + logger.Debugf(ctx, "Recursively executing branchNode's chosen path") + nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID()) return b.recurseDownstream(ctx, nCtx, nodeStatus, finalNode) } @@ -88,28 +81,42 @@ func (b *branchHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionCo return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.IllegalStateError, errMsg, nil)), nil } - w := nCtx.Workflow() - branchTakenNode, ok := w.GetNode(*finalNodeID) + branchTakenNode, ok := nl.GetNode(*finalNodeID) if !ok { errMsg := fmt.Sprintf("Downstream node [%v] not found", *finalNodeID) return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.DownstreamNodeNotFoundError, errMsg, nil)), nil } // Recurse downstream - nodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) + nodeStatus := nl.GetNodeExecutionStatus(ctx, nCtx.NodeID()) return b.recurseDownstream(ctx, nCtx, nodeStatus, branchTakenNode) } +func (b *branchHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { + logger.Debug(ctx, "Starting Branch Node") + branchNode := nCtx.Node().GetBranchNode() + if branchNode == nil { + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.IllegalStateError, "Invoked branch handler, for a non branch node.", nil)), nil + } + + nl := nCtx.ContextualNodeLookup() + + return b.HandleBranchNode(ctx, branchNode, nCtx, nl) +} + func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.NodeExecutionContext, nodeStatus v1alpha1.ExecutableNodeStatus, branchTakenNode v1alpha1.ExecutableNode) (handler.Transition, error) { - w := nCtx.Workflow() - downstreamStatus, err := b.nodeExecutor.RecursiveNodeHandler(ctx, w, branchTakenNode) + // TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time + // There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc + // The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node. + dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID()) + downstreamStatus, err := b.nodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), dag, nCtx.ContextualNodeLookup(), branchTakenNode) if err != nil { return handler.UnknownTransition, err } if downstreamStatus.IsComplete() { // For branch node we set the output node to be the same as the child nodes output - childNodeStatus := w.GetNodeExecutionStatus(ctx, branchTakenNode.GetID()) + childNodeStatus := nCtx.ContextualNodeLookup().GetNodeExecutionStatus(ctx, branchTakenNode.GetID()) nodeStatus.SetDataDir(childNodeStatus.GetDataDir()) nodeStatus.SetOutputDir(childNodeStatus.GetOutputDir()) phase := handler.PhaseInfoSuccess(&handler.ExecutionInfo{ @@ -134,9 +141,8 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.Node func (b *branchHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { branch := nCtx.Node().GetBranchNode() - w := nCtx.Workflow() if branch == nil { - return errors.Errorf(errors.IllegalStateError, w.GetID(), nCtx.NodeID(), "Invoked branch handler, for a non branch node.") + return errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "Invoked branch handler, for a non branch node.") } // If the branch was already evaluated i.e, Node is in Running status @@ -154,19 +160,53 @@ func (b *branchHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionCon } finalNodeID := branchNodeState.FinalizedNodeID - branchTakenNode, ok := w.GetNode(*finalNodeID) + branchTakenNode, ok := nCtx.ContextualNodeLookup().GetNode(*finalNodeID) if !ok { logger.Errorf(ctx, "Downstream node [%v] not found", *finalNodeID) return nil } // Recurse downstream - return b.nodeExecutor.AbortHandler(ctx, w, branchTakenNode, reason) + // TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time + // There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc + // The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node. + dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID()) + return b.nodeExecutor.AbortHandler(ctx, nCtx.ExecutionContext(), dag, nCtx.ContextualNodeLookup(), branchTakenNode, reason) } -func (b *branchHandler) Finalize(ctx context.Context, executionContext handler.NodeExecutionContext) error { - logger.Debugf(ctx, "BranchNode::Finalizer: nothing to do") - return nil +func (b *branchHandler) Finalize(ctx context.Context, nCtx handler.NodeExecutionContext) error { + branch := nCtx.Node().GetBranchNode() + if branch == nil { + return errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "Invoked branch handler, for a non branch node.") + } + + // If the branch was already evaluated i.e, Node is in Running status + branchNodeState := nCtx.NodeStateReader().GetBranchNode() + if branchNodeState.Phase == v1alpha1.BranchNodeNotYetEvaluated { + logger.Errorf(ctx, "No node finalized through previous branch evaluation.") + return nil + } else if branchNodeState.Phase == v1alpha1.BranchNodeError { + // We should never reach here, but for safety and completeness + errMsg := "branch evaluation failed" + if branch.GetElseFail() != nil { + errMsg = branch.GetElseFail().Message + } + return errors.Errorf(errors.UserProvidedError, nCtx.NodeID(), errMsg) + } + + finalNodeID := branchNodeState.FinalizedNodeID + branchTakenNode, ok := nCtx.ContextualNodeLookup().GetNode(*finalNodeID) + if !ok { + logger.Errorf(ctx, "Downstream node [%v] not found", *finalNodeID) + return nil + } + + // Recurse downstream + // TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time + // There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc + // The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node. + dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID()) + return b.nodeExecutor.FinalizeHandler(ctx, nCtx.ExecutionContext(), dag, nCtx.ContextualNodeLookup(), branchTakenNode) } func New(executor executors.Node, scope promutils.Scope) handler.Node { diff --git a/pkg/controller/nodes/branch/handler_test.go b/pkg/controller/nodes/branch/handler_test.go index 8a35d35934..e1975939b9 100644 --- a/pkg/controller/nodes/branch/handler_test.go +++ b/pkg/controller/nodes/branch/handler_test.go @@ -20,20 +20,12 @@ import ( "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" mocks2 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" "github.com/lyft/flytepropeller/pkg/controller/executors" + execMocks "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" "github.com/lyft/flytepropeller/pkg/controller/nodes/handler/mocks" ) -type recursiveNodeHandlerFn func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) -type abortNodeHandlerCbFn func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error - -type mockNodeExecutor struct { - executors.Node - RecursiveNodeHandlerCB recursiveNodeHandlerFn - AbortNodeHandlerCB abortNodeHandlerCbFn -} - type branchNodeStateHolder struct { s handler.BranchNodeState } @@ -55,18 +47,9 @@ func (t branchNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) e panic("not implemented") } -func (m *mockNodeExecutor) RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { - return m.RecursiveNodeHandlerCB(ctx, w, currentNode) -} - -func (m *mockNodeExecutor) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode, reason string) error { - return m.AbortNodeHandlerCB(ctx, w, currentNode) -} - -func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.NodeID, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, inputs *core.LiteralMap) *mocks.NodeExecutionContext { - nodeID := "nodeID" +func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.NodeID, n v1alpha1.ExecutableNode, inputs *core.LiteralMap, nl executors.NodeLookup) (*mocks.NodeExecutionContext, *branchNodeStateHolder) { branchNodeState := handler.BranchNodeState{ - FinalizedNodeID: &nodeID, + FinalizedNodeID: childNodeID, Phase: phase, } s := &branchNodeStateHolder{s: branchNodeState} @@ -78,106 +61,124 @@ func createNodeContext(phase v1alpha1.BranchNodePhase, childNodeID *v1alpha1.Nod } nm := &mocks.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + ExecutionId: wfExecID, + NodeId: n.GetID(), }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v1.OwnerReference{ + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v1.OwnerReference{ Kind: "sample", Name: "name", }) ns := &mocks2.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetPhase").Return(v1alpha1.NodePhaseNotYetStarted) + ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetPhase().Return(v1alpha1.NodePhaseNotYetStarted) ir := &mocks3.InputReader{} - ir.On("Get", mock.Anything).Return(inputs, nil) + ir.OnGetMatch(mock.Anything).Return(inputs, nil) nCtx := &mocks.NodeExecutionContext{} - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("Node").Return(n) - nCtx.On("InputReader").Return(ir) - nCtx.On("DataStore").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeStatus").Return(ns) - - nCtx.On("NodeID").Return("n1") - nCtx.On("EnqueueOwner").Return(nil) - nCtx.On("Workflow").Return(w) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + tmpDataStore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + nCtx.OnDataStore().Return(tmpDataStore) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeStatus().Return(ns) + + nCtx.OnNodeID().Return("n1") + nCtx.OnEnqueueOwnerFunc().Return(nil) nr := &mocks.NodeStateReader{} - nr.On("GetBranchNode").Return(handler.BranchNodeState{ + nr.OnGetBranchNode().Return(handler.BranchNodeState{ FinalizedNodeID: childNodeID, Phase: phase, }) - nCtx.On("NodeStateReader").Return(nr) - nCtx.On("NodeStateWriter").Return(s) - return nCtx + nCtx.OnNodeStateReader().Return(nr) + nCtx.OnNodeStateWriter().Return(s) + + eCtx := &execMocks.ExecutionContext{} + nCtx.OnExecutionContext().Return(eCtx) + + nCtx.OnContextualNodeLookup().Return(nl) + return nCtx, s } func TestBranchHandler_RecurseDownstream(t *testing.T) { ctx := context.TODO() - m := &mockNodeExecutor{} - branch := New(m, promutils.NewTestScope()).(*branchHandler) childNodeID := "child" - childDatadir := v1alpha1.DataReference("test") - - dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) - assert.NoError(t, err) - w := &v1alpha1.FlyteWorkflow{ - Status: v1alpha1.WorkflowStatus{ - NodeStatus: map[v1alpha1.NodeID]*v1alpha1.NodeStatus{ - childNodeID: { - DataDir: childDatadir, - }, - }, - }, - DataReferenceConstructor: dataStore, - } + nodeID := "n1" res := &v12.ResourceRequirements{} n := &mocks2.ExecutableNode{} - n.On("GetResources").Return(res) + n.OnGetResources().Return(res) + n.OnGetID().Return(nodeID) expectedError := fmt.Errorf("error") - recursiveNodeHandlerFnArchetype := func(status executors.NodeStatus, err error) recursiveNodeHandlerFn { - return func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { - return status, err - } - } + bn := &mocks2.ExecutableNode{} + bn.OnGetID().Return(childNodeID) tests := []struct { - name string - recursiveNodeHandlerFn recursiveNodeHandlerFn - nodeStatus v1alpha1.ExecutableNodeStatus - branchTakenNode v1alpha1.ExecutableNode - isErr bool - expectedPhase handler.EPhase - childPhase v1alpha1.NodePhase + name string + ns executors.NodeStatus + err error + nodeStatus *mocks2.ExecutableNodeStatus + branchTakenNode v1alpha1.ExecutableNode + isErr bool + expectedPhase handler.EPhase + childPhase v1alpha1.NodePhase + nl *execMocks.NodeLookup }{ - {"childNodeError", recursiveNodeHandlerFnArchetype(executors.NodeStatusUndefined, expectedError), - nil, &v1alpha1.NodeSpec{}, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed}, - {"childPending", recursiveNodeHandlerFnArchetype(executors.NodeStatusPending, nil), - nil, &v1alpha1.NodeSpec{}, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued}, - {"childStillRunning", recursiveNodeHandlerFnArchetype(executors.NodeStatusRunning, nil), - nil, &v1alpha1.NodeSpec{}, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning}, - {"childFailure", recursiveNodeHandlerFnArchetype(executors.NodeStatusFailed(expectedError), nil), - nil, &v1alpha1.NodeSpec{}, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed}, - {"childComplete", recursiveNodeHandlerFnArchetype(executors.NodeStatusComplete, nil), - &v1alpha1.NodeStatus{}, &v1alpha1.NodeSpec{ID: childNodeID}, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded}, + {"childNodeError", executors.NodeStatusUndefined, expectedError, + nil, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, nil}, + {"childPending", executors.NodeStatusPending, nil, + nil, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, nil}, + {"childStillRunning", executors.NodeStatusRunning, nil, + nil, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, nil}, + {"childFailure", executors.NodeStatusFailed(expectedError), nil, + nil, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, nil}, + {"childComplete", executors.NodeStatusComplete, nil, + &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, &execMocks.NodeLookup{}}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - m.RecursiveNodeHandlerCB = test.recursiveNodeHandlerFn - - nCtx := createNodeContext(v1alpha1.BranchNodeNotYetEvaluated, &childNodeID, w, n, nil) + nCtx, _ := createNodeContext(v1alpha1.BranchNodeNotYetEvaluated, &childNodeID, n, nil, test.nl) + mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor.OnRecursiveNodeHandlerMatch( + mock.Anything, // ctx + mock.MatchedBy(func(e executors.ExecutionContext) bool { return assert.Equal(t, e, nCtx.ExecutionContext()) }), + mock.MatchedBy(func(d executors.DAGStructure) bool { + if assert.NotNil(t, d) { + fList, err1 := d.FromNode("x") + dList, err2 := d.ToNode(childNodeID) + b := assert.NoError(t, err1) + b = b && assert.Equal(t, fList, []v1alpha1.NodeID{}) + b = b && assert.NoError(t, err2) + b = b && assert.Equal(t, dList, []v1alpha1.NodeID{nodeID}) + return b + } + return false + }), + mock.MatchedBy(func(lookup executors.NodeLookup) bool { return assert.Equal(t, lookup, test.nl) }), + mock.MatchedBy(func(n v1alpha1.ExecutableNode) bool { return assert.Equal(t, n.GetID(), childNodeID) }), + ).Return(test.ns, test.err) + + childNodeStatus := &mocks2.ExecutableNodeStatus{} + if test.nl != nil { + childNodeStatus.OnGetDataDir().Return("child-data-dir") + childNodeStatus.OnGetOutputDir().Return("child-output-dir") + test.nl.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus) + test.nodeStatus.On("SetDataDir", storage.DataReference("child-data-dir")).Once() + test.nodeStatus.On("SetOutputDir", storage.DataReference("child-output-dir")).Once() + } + branch := New(mockNodeExecutor, promutils.NewTestScope()).(*branchHandler) h, err := branch.recurseDownstream(ctx, nCtx, test.nodeStatus, test.branchTakenNode) if test.isErr { assert.Error(t, err) @@ -185,17 +186,12 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { assert.NoError(t, err) } assert.Equal(t, test.expectedPhase, h.Info().GetPhase()) - if test.nodeStatus != nil { - assert.Equal(t, w.GetNodeExecutionStatus(ctx, test.branchTakenNode.GetID()).GetDataDir(), test.nodeStatus.GetDataDir()) - } }) } } func TestBranchHandler_AbortNode(t *testing.T) { ctx := context.TODO() - m := &mockNodeExecutor{} - branch := New(m, promutils.NewTestScope()) b1 := "b1" n1 := "n1" n2 := "n2" @@ -253,21 +249,25 @@ func TestBranchHandler_AbortNode(t *testing.T) { }, }, } + assert.NotNil(t, w) t.Run("NoBranchNode", func(t *testing.T) { - nCtx := createNodeContext(v1alpha1.BranchNodeError, nil, w, n, nil) + mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor.OnAbortHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("err")) + nCtx, _ := createNodeContext(v1alpha1.BranchNodeError, nil, n, nil, nil) + branch := New(mockNodeExecutor, promutils.NewTestScope()) err := branch.Abort(ctx, nCtx, "") assert.Error(t, err) assert.True(t, errors.Matches(err, errors.UserProvidedError)) }) t.Run("BranchNodeSuccess", func(t *testing.T) { - - nCtx := createNodeContext(v1alpha1.BranchNodeSuccess, &n1, w, n, nil) - m.AbortNodeHandlerCB = func(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error { - assert.Equal(t, n1, currentNode.GetID()) - return nil - } + mockNodeExecutor := &execMocks.Node{} + mockNodeExecutor.OnAbortHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + nl := &execMocks.NodeLookup{} + nCtx, s := createNodeContext(v1alpha1.BranchNodeSuccess, &n1, n, nil, nl) + nl.OnGetNode(*s.s.FinalizedNodeID).Return(n, true) + branch := New(mockNodeExecutor, promutils.NewTestScope()) err := branch.Abort(ctx, nCtx, "") assert.NoError(t, err) }) @@ -275,16 +275,16 @@ func TestBranchHandler_AbortNode(t *testing.T) { func TestBranchHandler_Initialize(t *testing.T) { ctx := context.TODO() - m := &mockNodeExecutor{} - branch := New(m, promutils.NewTestScope()) + mockNodeExecutor := &execMocks.Node{} + branch := New(mockNodeExecutor, promutils.NewTestScope()) assert.NoError(t, branch.Setup(ctx, nil)) } // TODO incomplete test suite, add more func TestBranchHandler_HandleNode(t *testing.T) { ctx := context.TODO() - m := &mockNodeExecutor{} - branch := New(m, promutils.NewTestScope()) + mockNodeExecutor := &execMocks.Node{} + branch := New(mockNodeExecutor, promutils.NewTestScope()) childNodeID := "child" childDatadir := v1alpha1.DataReference("test") w := &v1alpha1.FlyteWorkflow{ @@ -299,6 +299,8 @@ func TestBranchHandler_HandleNode(t *testing.T) { }, }, } + assert.NotNil(t, w) + _, inputs := getComparisonExpression(1, core.ComparisonExpression_NEQ, 1) tests := []struct { @@ -313,9 +315,10 @@ func TestBranchHandler_HandleNode(t *testing.T) { t.Run(test.name, func(t *testing.T) { res := &v12.ResourceRequirements{} n := &mocks2.ExecutableNode{} - n.On("GetResources").Return(res) - n.On("GetBranchNode").Return(nil) - nCtx := createNodeContext(v1alpha1.BranchNodeSuccess, &childNodeID, w, n, inputs) + n.OnGetResources().Return(res) + n.OnGetBranchNode().Return(nil) + n.OnGetID().Return("n1") + nCtx, _ := createNodeContext(v1alpha1.BranchNodeSuccess, &childNodeID, n, inputs, nil) s, err := branch.Handle(ctx, nCtx) if test.isErr { diff --git a/pkg/controller/nodes/dynamic/handler.go b/pkg/controller/nodes/dynamic/handler.go index 9ac9e40595..1c4ceae81c 100644 --- a/pkg/controller/nodes/dynamic/handler.go +++ b/pkg/controller/nodes/dynamic/handler.go @@ -96,13 +96,13 @@ func (d dynamicNodeTaskNodeHandler) handleParentNode(ctx context.Context, prevSt } func (d dynamicNodeTaskNodeHandler) handleDynamicSubNodes(ctx context.Context, nCtx handler.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) { - dynamicWF, _, err := d.buildContextualDynamicWorkflow(ctx, nCtx) + execContext, dynamicWF, nl, _, err := d.buildContextualDynamicWorkflow(ctx, nCtx) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure( "DynamicWorkflowBuildFailed", err.Error(), nil)), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: err.Error()}, nil } - trns, newState, err := d.progressDynamicWorkflow(ctx, dynamicWF, nCtx, prevState) + trns, newState, err := d.progressDynamicWorkflow(ctx, execContext, dynamicWF, nl, nCtx, prevState) if err != nil { return handler.UnknownTransition, prevState, err } @@ -188,7 +188,7 @@ func (d dynamicNodeTaskNodeHandler) Abort(ctx context.Context, nCtx handler.Node fallthrough case v1alpha1.DynamicNodePhaseExecuting: logger.Infof(ctx, "Aborting dynamic workflow at RetryAttempt [%d]", nCtx.CurrentAttempt()) - dynamicWF, isDynamic, err := d.buildContextualDynamicWorkflow(ctx, nCtx) + execContext, dynamicWF, nl, isDynamic, err := d.buildContextualDynamicWorkflow(ctx, nCtx) if err != nil { return err } @@ -197,7 +197,7 @@ func (d dynamicNodeTaskNodeHandler) Abort(ctx context.Context, nCtx handler.Node return nil } - return d.nodeExecutor.AbortHandler(ctx, dynamicWF, dynamicWF.StartNode(), reason) + return d.nodeExecutor.AbortHandler(ctx, execContext, dynamicWF, nl, dynamicWF.StartNode(), reason) default: logger.Infof(ctx, "Aborting regular node RetryAttempt [%d]", nCtx.CurrentAttempt()) // The parent node has not yet completed, so we will abort the parent node @@ -221,12 +221,12 @@ func (d dynamicNodeTaskNodeHandler) Finalize(ctx context.Context, nCtx handler.N ds := nCtx.NodeStateReader().GetDynamicNodeState() if ds.Phase == v1alpha1.DynamicNodePhaseFailing || ds.Phase == v1alpha1.DynamicNodePhaseExecuting { logger.Infof(ctx, "Finalizing dynamic workflow RetryAttempt [%d]", nCtx.CurrentAttempt()) - dynamicWF, isDynamic, err := d.buildContextualDynamicWorkflow(ctx, nCtx) + execContext, dynamicWF, nl, isDynamic, err := d.buildContextualDynamicWorkflow(ctx, nCtx) if err != nil { errs = append(errs, err) } else { if isDynamic { - if err := d.nodeExecutor.FinalizeHandler(ctx, dynamicWF, dynamicWF.StartNode()); err != nil { + if err := d.nodeExecutor.FinalizeHandler(ctx, execContext, dynamicWF, nl, dynamicWF.StartNode()); err != nil { logger.Errorf(ctx, "failed to finalize dynamic workflow, err: %s", err) errs = append(errs, err) } @@ -316,8 +316,8 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Con return &core.WorkflowTemplate{ Id: &core.Identifier{ - Project: nCtx.NodeExecutionMetadata().GetExecutionID().Project, - Domain: nCtx.NodeExecutionMetadata().GetExecutionID().Domain, + Project: nCtx.NodeExecutionMetadata().GetNodeExecutionID().GetExecutionId().Project, + Domain: nCtx.NodeExecutionMetadata().GetNodeExecutionID().GetExecutionId().Domain, Version: rand.String(10), Name: rand.String(10), ResourceType: core.ResourceType_WORKFLOW, @@ -328,22 +328,21 @@ func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Con }, nil } -func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (dynamicWf v1alpha1.ExecutableWorkflow, isDynamic bool, err error) { +func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (execContext executors.ExecutionContext, subwf v1alpha1.ExecutableWorkflow, nodeLookup executors.NodeLookup, isDynamic bool, err error) { t := d.metrics.buildDynamicWorkflow.Start(ctx) defer t.Stop() - f, err := task.NewRemoteFutureFileReader(ctx, nCtx.NodeStatus().GetOutputDir(), nCtx.DataStore()) if err != nil { - return nil, false, err + return } // TODO: This is a hack to set parent task execution id, we should move to node-node relationship. execID := task.GetTaskExecutionIdentifier(nCtx) - nStatus := nCtx.NodeStatus().GetNodeExecutionStatus(ctx, dynamicNodeID) - nStatus.SetDataDir(nCtx.NodeStatus().GetDataDir()) - nStatus.SetOutputDir(nCtx.NodeStatus().GetOutputDir()) - nStatus.SetParentTaskID(execID) + dynamicNodeStatus := nCtx.NodeStatus().GetNodeExecutionStatus(ctx, dynamicNodeID) + dynamicNodeStatus.SetDataDir(nCtx.NodeStatus().GetDataDir()) + dynamicNodeStatus.SetOutputDir(nCtx.NodeStatus().GetOutputDir()) + dynamicNodeStatus.SetParentTaskID(execID) // cacheHitStopWatch := d.metrics.CacheHit.Start(ctx) // Check if we have compiled the workflow before: @@ -359,37 +358,39 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C // d.metrics.CacheError.Inc(ctx) // } else { // cacheHitStopWatch.Stop() - // return newContextualWorkflow(nCtx.Workflow(), compiledWf, nStatus, compiledWf.Tasks, compiledWf.SubWorkflows), true, nil + // return newContextualWorkflow(nCtx.Workflow(), compiledWf, dynamicNodeStatus, compiledWf.Tasks, compiledWf.SubWorkflows), true, nil // } // } // We know for sure that futures file was generated. Lets read it djSpec, err := f.Read(ctx) if err != nil { - return nil, false, errors.Wrapf(errors.RuntimeExecutionError, nCtx.NodeID(), err, "unable to read futures file, maybe corrupted") + err = errors.Wrapf(errors.RuntimeExecutionError, nCtx.NodeID(), err, "unable to read futures file, maybe corrupted") + return } var closure *core.CompiledWorkflowClosure - wf, err := d.buildDynamicWorkflowTemplate(ctx, djSpec, nCtx, nStatus) + wf, err := d.buildDynamicWorkflowTemplate(ctx, djSpec, nCtx, dynamicNodeStatus) + isDynamic = true if err != nil { - return nil, true, err + return } compiledTasks, err := compileTasks(ctx, djSpec.Tasks) if err != nil { - return nil, true, err + return } // Get the requirements, that is, a list of all the task IDs and the launch plan IDs that will be called as part of this dynamic task. // The definition of these will need to be fetched from Admin (in order to get the interface). requirements, err := compiler.GetRequirements(wf, djSpec.Subworkflows) if err != nil { - return nil, true, err + return } launchPlanInterfaces, err := d.getLaunchPlanInterfaces(ctx, requirements.GetRequiredLaunchPlanIds()) if err != nil { - return nil, true, err + return } // TODO: In addition to querying Admin for launch plans, we also need to get all the tasks that are missing from the dynamic job spec. @@ -398,19 +399,22 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C closure, err = compiler.CompileWorkflow(wf, djSpec.Subworkflows, compiledTasks, launchPlanInterfaces) if err != nil { - return nil, true, err + return } - subwf, err := k8s.BuildFlyteWorkflow(closure, &core.LiteralMap{}, nil, "") + dynamicWf, err := k8s.BuildFlyteWorkflow(closure, &core.LiteralMap{}, nil, "") if err != nil { - return nil, true, err + return } - if err := f.Cache(ctx, subwf); err != nil { + if err := f.Cache(ctx, dynamicWf); err != nil { logger.Errorf(ctx, "Failed to cache Dynamic workflow [%s]", err.Error()) } - return newContextualWorkflow(nCtx.Workflow(), subwf, nStatus, subwf.Tasks, subwf.SubWorkflows, nCtx.DataStore()), true, nil + subwf = dynamicWf + execContext = executors.NewExecutionContext(nCtx.ExecutionContext(), subwf, subwf) + nodeLookup = executors.NewNodeLookup(subwf, dynamicNodeStatus) + return execContext, subwf, nodeLookup, true, nil } func (d dynamicNodeTaskNodeHandler) getLaunchPlanInterfaces(ctx context.Context, launchPlanIDs []compiler.LaunchPlanRefIdentifier) ( @@ -429,10 +433,10 @@ func (d dynamicNodeTaskNodeHandler) getLaunchPlanInterfaces(ctx context.Context, return launchPlanInterfaces, nil } -func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, dynamicWorkflow v1alpha1.ExecutableWorkflow, +func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, execContext executors.ExecutionContext, dynamicWorkflow v1alpha1.ExecutableWorkflow, nl executors.NodeLookup, nCtx handler.NodeExecutionContext, prevState handler.DynamicNodeState) (handler.Transition, handler.DynamicNodeState, error) { - state, err := d.nodeExecutor.RecursiveNodeHandler(ctx, dynamicWorkflow, dynamicWorkflow.StartNode()) + state, err := d.nodeExecutor.RecursiveNodeHandler(ctx, execContext, dynamicWorkflow, nl, dynamicWorkflow.StartNode()) if err != nil { return handler.UnknownTransition, prevState, err } @@ -461,7 +465,8 @@ func (d dynamicNodeTaskNodeHandler) progressDynamicWorkflow(ctx context.Context, var o *handler.OutputInfo // If the WF interface has outputs, validate that the outputs file was written. if outputBindings := dynamicWorkflow.GetOutputBindings(); len(outputBindings) > 0 { - endNodeStatus := dynamicWorkflow.GetNodeExecutionStatus(ctx, v1alpha1.EndNodeID) + dynamicNodeStatus := nCtx.NodeStatus().GetNodeExecutionStatus(ctx, dynamicNodeID) + endNodeStatus := dynamicNodeStatus.GetNodeExecutionStatus(ctx, v1alpha1.EndNodeID) if endNodeStatus == nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure("MalformedDynamicWorkflow", "no end-node found in dynamic workflow", nil)), handler.DynamicNodeState{Phase: v1alpha1.DynamicNodePhaseFailing, Reason: "no end-node found in dynamic workflow"}, diff --git a/pkg/controller/nodes/dynamic/handler_test.go b/pkg/controller/nodes/dynamic/handler_test.go index 28a8bbb00b..b055bdc7a1 100644 --- a/pkg/controller/nodes/dynamic/handler_test.go +++ b/pkg/controller/nodes/dynamic/handler_test.go @@ -21,6 +21,7 @@ import ( "k8s.io/apimachinery/pkg/types" ioMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks" + lpMocks "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" @@ -61,16 +62,22 @@ func Test_dynamicNodeHandler_Handle_Parent(t *testing.T) { Name: "name", } + res := &v12.ResourceRequirements{} + n := &flyteMocks.ExecutableNode{} + n.OnGetResources().Return(res) + n.OnGetID().Return("n1") + nm := &nodeMocks.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + ExecutionId: wfExecID, + NodeId: n.GetID(), }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v1.OwnerReference{ + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v1.OwnerReference{ Kind: "sample", Name: "name", }) @@ -97,38 +104,33 @@ func Test_dynamicNodeHandler_Handle_Parent(t *testing.T) { }, } tr := &nodeMocks.TaskReader{} - tr.On("GetTaskID").Return(taskID) - tr.On("GetTaskType").Return(ttype) - tr.On("Read", mock.Anything).Return(tk, nil) + tr.OnGetTaskID().Return(taskID) + tr.OnGetTaskType().Return(ttype) + tr.OnReadMatch(mock.Anything).Return(tk, nil) ns := &flyteMocks.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetOutputDir").Return(storage.DataReference("data-dir")) - - res := &v12.ResourceRequirements{} - n := &flyteMocks.ExecutableNode{} - n.On("GetResources").Return(res) + ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetOutputDir().Return(storage.DataReference("data-dir")) dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) ir := &ioMocks.InputReader{} nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("Node").Return(n) - nCtx.On("InputReader").Return(ir) - nCtx.On("DataReferenceConstructor").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("TaskReader").Return(tr) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeStatus").Return(ns) - nCtx.On("NodeID").Return("n1") - nCtx.On("EnqueueOwner").Return(nil) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnTaskReader().Return(tr) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnNodeID().Return("n1") + nCtx.OnEnqueueOwnerFunc().Return(nil) nCtx.OnDataStore().Return(dataStore) r := &nodeMocks.NodeStateReader{} - r.On("GetDynamicNodeState").Return(handler.DynamicNodeState{}) - nCtx.On("NodeStateReader").Return(r) + r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{}) + nCtx.OnNodeStateReader().Return(r) return nCtx } @@ -391,7 +393,7 @@ func createDynamicJobSpec() *core.DynamicJobSpec { func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { createNodeContext := func(ttype string, finalOutput storage.DataReference) *nodeMocks.NodeExecutionContext { ctx := context.TODO() - + nodeID := "n1" wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -399,15 +401,16 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { } nm := &nodeMocks.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + ExecutionId: wfExecID, + NodeId: nodeID, }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v1.OwnerReference{ + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v1.OwnerReference{ Kind: "sample", Name: "name", }) @@ -434,39 +437,40 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { }, } tr := &nodeMocks.TaskReader{} - tr.On("GetTaskID").Return(taskID) - tr.On("GetTaskType").Return(ttype) - tr.On("Read", mock.Anything).Return(tk, nil) + tr.OnGetTaskID().Return(taskID) + tr.OnGetTaskType().Return(ttype) + tr.OnRead(ctx).Return(tk, nil) n := &flyteMocks.ExecutableNode{} tID := "task-1" - n.On("GetTaskID").Return(&tID) + n.OnGetTaskID().Return(&tID) dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) ir := &ioMocks.InputReader{} nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("Node").Return(n) - nCtx.On("InputReader").Return(ir) - nCtx.On("DataReferenceConstructor").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("TaskReader").Return(tr) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeID").Return("n1") - nCtx.On("EnqueueOwnerFunc").Return(func() error { return nil }) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnTaskReader().Return(tr) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeID().Return(nodeID) + nCtx.OnEnqueueOwnerFunc().Return(func() error { return nil }) nCtx.OnDataStore().Return(dataStore) + execContext := &executorMocks.ExecutionContext{} + nCtx.OnExecutionContext().Return(execContext) endNodeStatus := &flyteMocks.ExecutableNodeStatus{} - endNodeStatus.On("GetDataDir").Return(storage.DataReference("end-node")) - endNodeStatus.On("GetOutputDir").Return(storage.DataReference("end-node")) + endNodeStatus.OnGetDataDir().Return("end-node") + endNodeStatus.OnGetOutputDir().Return("end-node") subNs := &flyteMocks.ExecutableNodeStatus{} subNs.On("SetDataDir", mock.Anything).Return() subNs.On("SetOutputDir", mock.Anything).Return() subNs.On("ResetDirty").Return() - subNs.On("GetOutputDir").Return(finalOutput) + subNs.OnGetOutputDir().Return(finalOutput) subNs.On("SetParentTaskID", mock.Anything).Return() subNs.OnGetAttempts().Return(0) @@ -480,23 +484,21 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { dynamicNS.OnGetNodeExecutionStatus(ctx, v1alpha1.EndNodeID).Return(endNodeStatus) ns := &flyteMocks.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetOutputDir").Return(storage.DataReference("output-dir")) - ns.On("GetNodeExecutionStatus", dynamicNodeID).Return(dynamicNS) + ns.OnGetDataDir().Return("data-dir") + ns.OnGetOutputDir().Return("output-dir") ns.OnGetNodeExecutionStatus(ctx, dynamicNodeID).Return(dynamicNS) - nCtx.On("NodeStatus").Return(ns) + nCtx.OnNodeStatus().Return(ns) w := &flyteMocks.ExecutableWorkflow{} ws := &flyteMocks.ExecutableWorkflowStatus{} - ws.OnGetNodeExecutionStatus(ctx, "n1").Return(ns) - w.On("GetExecutionStatus").Return(ws) - nCtx.On("Workflow").Return(w) + ws.OnGetNodeExecutionStatus(ctx, nodeID).Return(ns) + w.OnGetExecutionStatus().Return(ws) r := &nodeMocks.NodeStateReader{} - r.On("GetDynamicNodeState").Return(handler.DynamicNodeState{ + r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseExecuting, }) - nCtx.On("NodeStateReader").Return(r) + nCtx.OnNodeStateReader().Return(r) return nCtx } @@ -533,7 +535,7 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { finalOutput := storage.DataReference("/subnode") nCtx := createNodeContext("test", finalOutput) s := &dynamicNodeStateHolder{} - nCtx.On("NodeStateWriter").Return(s) + nCtx.OnNodeStateWriter().Return(s) f, err := nCtx.DataStore().ConstructReference(context.TODO(), nCtx.NodeStatus().GetOutputDir(), "futures.pb") assert.NoError(t, err) if tt.args.dj != nil { @@ -548,9 +550,9 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { } n := &executorMocks.Node{} if tt.args.isErr { - n.On("RecursiveNodeHandler", mock.Anything, mock.Anything, mock.Anything).Return(executors.NodeStatusUndefined, fmt.Errorf("error")) + n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(executors.NodeStatusUndefined, fmt.Errorf("error")) } else { - n.On("RecursiveNodeHandler", mock.Anything, mock.Anything, mock.Anything).Return(tt.args.s, nil) + n.OnRecursiveNodeHandlerMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tt.args.s, nil) } if tt.args.generateOutputs { endF := v1alpha1.GetOutputsFile("end-node") @@ -618,16 +620,15 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t Name: "name", } + nodeID := "n1" nm := &nodeMocks.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, - }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v1.OwnerReference{ + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ExecutionId: wfExecID, NodeId: nodeID}) + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v1.OwnerReference{ Kind: "sample", Name: "name", }) @@ -654,39 +655,38 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t }, } tr := &nodeMocks.TaskReader{} - tr.On("GetTaskID").Return(taskID) - tr.On("GetTaskType").Return(ttype) - tr.On("Read", mock.Anything).Return(tk, nil) + tr.OnGetTaskID().Return(taskID) + tr.OnGetTaskType().Return(ttype) + tr.OnReadMatch(mock.Anything).Return(tk, nil) n := &flyteMocks.ExecutableNode{} tID := "dyn-task-1" - n.On("GetTaskID").Return(&tID) + n.OnGetTaskID().Return(&tID) dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) ir := &ioMocks.InputReader{} nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("Node").Return(n) - nCtx.On("InputReader").Return(ir) - nCtx.On("DataReferenceConstructor").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("TaskReader").Return(tr) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeID").Return("n1") - nCtx.On("EnqueueOwnerFunc").Return(func() error { return nil }) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnTaskReader().Return(tr) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeID().Return("n1") + nCtx.OnEnqueueOwnerFunc().Return(func() error { return nil }) nCtx.OnDataStore().Return(dataStore) endNodeStatus := &flyteMocks.ExecutableNodeStatus{} - endNodeStatus.On("GetDataDir").Return(storage.DataReference("end-node")) - endNodeStatus.On("GetOutputDir").Return(storage.DataReference("end-node")) + endNodeStatus.OnGetDataDir().Return(storage.DataReference("end-node")) + endNodeStatus.OnGetOutputDir().Return(storage.DataReference("end-node")) subNs := &flyteMocks.ExecutableNodeStatus{} subNs.On("SetDataDir", mock.Anything).Return() subNs.On("SetOutputDir", mock.Anything).Return() subNs.On("ResetDirty").Return() - subNs.On("GetOutputDir").Return(finalOutput) + subNs.OnGetOutputDir().Return(finalOutput) subNs.On("SetParentTaskID", mock.Anything).Return() subNs.OnGetAttempts().Return(0) @@ -698,23 +698,24 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t dynamicNS.OnGetNodeExecutionStatus(ctx, v1alpha1.EndNodeID).Return(endNodeStatus) ns := &flyteMocks.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetOutputDir").Return(storage.DataReference("output-dir")) - ns.On("GetNodeExecutionStatus", dynamicNodeID).Return(dynamicNS) + ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) ns.OnGetNodeExecutionStatus(ctx, dynamicNodeID).Return(dynamicNS) - nCtx.On("NodeStatus").Return(ns) + ns.OnGetNodeExecutionStatus(ctx, dynamicNodeID).Return(dynamicNS) + nCtx.OnNodeStatus().Return(ns) w := &flyteMocks.ExecutableWorkflow{} ws := &flyteMocks.ExecutableWorkflowStatus{} ws.OnGetNodeExecutionStatus(ctx, "n1").Return(ns) - w.On("GetExecutionStatus").Return(ws) - nCtx.On("Workflow").Return(w) + w.OnGetExecutionStatus().Return(ws) r := &nodeMocks.NodeStateReader{} - r.On("GetDynamicNodeState").Return(handler.DynamicNodeState{ + r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseExecuting, }) - nCtx.On("NodeStateReader").Return(r) + nCtx.OnNodeStateReader().Return(r) + execContext := &executorMocks.ExecutionContext{} + nCtx.OnExecutionContext().Return(execContext) return nCtx } @@ -766,11 +767,13 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t lpReader: mockLPLauncher, metrics: newMetrics(promutils.NewTestScope()), } - executableWorkflow, isDynamic, err := d.buildContextualDynamicWorkflow(ctx, nCtx) + execCtx, executableWorkflow, nl, isDynamic, err := d.buildContextualDynamicWorkflow(ctx, nCtx) assert.True(t, callsAdmin) assert.True(t, isDynamic) assert.NoError(t, err) assert.NotNil(t, executableWorkflow) + assert.NotNil(t, execCtx) + assert.NotNil(t, nl) }) t.Run("launch plan interfaces do not parent task interface", func(t *testing.T) { @@ -785,7 +788,7 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t finalOutput := storage.DataReference("/subnode") nCtx := createNodeContext("test", finalOutput) s := &dynamicNodeStateHolder{} - nCtx.On("NodeStateWriter").Return(s) + nCtx.OnNodeStateWriter().Return(s) f, err := nCtx.DataStore().ConstructReference(ctx, nCtx.NodeStatus().GetOutputDir(), "futures.pb") assert.NoError(t, err) assert.NoError(t, nCtx.DataStore().WriteProtobuf(context.TODO(), f, storage.Options{}, djSpec)) @@ -821,11 +824,13 @@ func Test_dynamicNodeHandler_buildContextualDynamicWorkflow_withLaunchPlans(t *t lpReader: mockLPLauncher, metrics: newMetrics(promutils.NewTestScope()), } - executableWorkflow, isDynamic, err := d.buildContextualDynamicWorkflow(ctx, nCtx) + execCtx, executableWorkflow, nl, isDynamic, err := d.buildContextualDynamicWorkflow(ctx, nCtx) assert.True(t, callsAdmin) assert.True(t, isDynamic) assert.Error(t, err) assert.Nil(t, executableWorkflow) + assert.Nil(t, execCtx) + assert.Nil(t, nl) }) } @@ -855,7 +860,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { createNodeContext := func(ttype string, finalOutput storage.DataReference) *nodeMocks.NodeExecutionContext { ctx := context.TODO() - + nodeID := "n1" wfExecID := &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", @@ -863,15 +868,16 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { } nm := &nodeMocks.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + ExecutionId: wfExecID, + NodeId: nodeID, }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v1.OwnerReference{ + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v1.OwnerReference{ Kind: "sample", Name: "name", }) @@ -898,39 +904,40 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { }, } tr := &nodeMocks.TaskReader{} - tr.On("GetTaskID").Return(taskID) - tr.On("GetTaskType").Return(ttype) - tr.On("Read", mock.Anything).Return(tk, nil) + tr.OnGetTaskID().Return(taskID) + tr.OnGetTaskType().Return(ttype) + tr.OnRead(ctx).Return(tk, nil) n := &flyteMocks.ExecutableNode{} tID := "task-1" - n.On("GetTaskID").Return(&tID) + n.OnGetTaskID().Return(&tID) dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) ir := &ioMocks.InputReader{} nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("Node").Return(n) - nCtx.On("InputReader").Return(ir) - nCtx.On("DataReferenceConstructor").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("TaskReader").Return(tr) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeID").Return("n1") - nCtx.On("EnqueueOwnerFunc").Return(func() error { return nil }) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnTaskReader().Return(tr) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeID().Return(nodeID) + nCtx.OnEnqueueOwnerFunc().Return(func() error { return nil }) nCtx.OnDataStore().Return(dataStore) + execContext := &executorMocks.ExecutionContext{} + nCtx.OnExecutionContext().Return(execContext) endNodeStatus := &flyteMocks.ExecutableNodeStatus{} - endNodeStatus.On("GetDataDir").Return(storage.DataReference("end-node")) - endNodeStatus.On("GetOutputDir").Return(storage.DataReference("end-node")) + endNodeStatus.OnGetDataDir().Return("end-node") + endNodeStatus.OnGetOutputDir().Return("end-node") subNs := &flyteMocks.ExecutableNodeStatus{} subNs.On("SetDataDir", mock.Anything).Return() subNs.On("SetOutputDir", mock.Anything).Return() subNs.On("ResetDirty").Return() - subNs.On("GetOutputDir").Return(finalOutput) + subNs.OnGetOutputDir().Return(finalOutput) subNs.On("SetParentTaskID", mock.Anything).Return() subNs.OnGetAttempts().Return(0) @@ -944,26 +951,23 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { dynamicNS.OnGetNodeExecutionStatus(ctx, v1alpha1.EndNodeID).Return(endNodeStatus) ns := &flyteMocks.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetOutputDir").Return(storage.DataReference("output-dir")) - ns.On("GetNodeExecutionStatus", dynamicNodeID).Return(dynamicNS) + ns.OnGetDataDir().Return("data-dir") + ns.OnGetOutputDir().Return("output-dir") ns.OnGetNodeExecutionStatus(ctx, dynamicNodeID).Return(dynamicNS) - nCtx.On("NodeStatus").Return(ns) + nCtx.OnNodeStatus().Return(ns) w := &flyteMocks.ExecutableWorkflow{} ws := &flyteMocks.ExecutableWorkflowStatus{} - ws.OnGetNodeExecutionStatus(ctx, "n1").Return(ns) - w.On("GetExecutionStatus").Return(ws) - nCtx.On("Workflow").Return(w) + ws.OnGetNodeExecutionStatus(ctx, nodeID).Return(ns) + w.OnGetExecutionStatus().Return(ws) r := &nodeMocks.NodeStateReader{} - r.On("GetDynamicNodeState").Return(handler.DynamicNodeState{ + r.OnGetDynamicNodeState().Return(handler.DynamicNodeState{ Phase: v1alpha1.DynamicNodePhaseExecuting, }) - nCtx.On("NodeStateReader").Return(r) + nCtx.OnNodeStateReader().Return(r) return nCtx } - t.Run("dynamicnodephase-executing", func(t *testing.T) { nCtx := createNodeContext("test", "x") @@ -976,7 +980,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { h := &mocks.TaskNodeHandler{} h.OnFinalize(ctx, nCtx).Return(nil) n := &executorMocks.Node{} - n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything).Return(nil) + n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) d := New(h, n, mockLPLauncher, promutils.NewTestScope()) assert.NoError(t, d.Finalize(ctx, nCtx)) assert.NotZero(t, len(h.ExpectedCalls)) @@ -997,7 +1001,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { h := &mocks.TaskNodeHandler{} h.OnFinalize(ctx, nCtx).Return(fmt.Errorf("err")) n := &executorMocks.Node{} - n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything).Return(nil) + n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) d := New(h, n, mockLPLauncher, promutils.NewTestScope()) assert.Error(t, d.Finalize(ctx, nCtx)) assert.NotZero(t, len(h.ExpectedCalls)) @@ -1018,7 +1022,7 @@ func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { h := &mocks.TaskNodeHandler{} h.OnFinalize(ctx, nCtx).Return(nil) n := &executorMocks.Node{} - n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything).Return(fmt.Errorf("err")) + n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("err")) d := New(h, n, mockLPLauncher, promutils.NewTestScope()) assert.Error(t, d.Finalize(ctx, nCtx)) assert.NotZero(t, len(h.ExpectedCalls)) diff --git a/pkg/controller/nodes/dynamic/subworkflow.go b/pkg/controller/nodes/dynamic/subworkflow.go deleted file mode 100644 index 9d2350755d..0000000000 --- a/pkg/controller/nodes/dynamic/subworkflow.go +++ /dev/null @@ -1,91 +0,0 @@ -package dynamic - -import ( - "context" - - "github.com/lyft/flytepropeller/pkg/controller/executors" - - "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/lyft/flytestdlib/storage" -) - -// Defines a sub-contextual workflow that is built in-memory to represent a dynamic job execution plan. -type contextualWorkflow struct { - v1alpha1.ExecutableWorkflow - - extraTasks map[v1alpha1.TaskID]*v1alpha1.TaskSpec - extraWorkflows map[v1alpha1.WorkflowID]*v1alpha1.WorkflowSpec - status *ContextualWorkflowStatus -} - -func newContextualWorkflow(baseWorkflow v1alpha1.ExecutableWorkflow, - subwf v1alpha1.ExecutableSubWorkflow, - status v1alpha1.ExecutableNodeStatus, - tasks map[v1alpha1.TaskID]*v1alpha1.TaskSpec, - workflows map[v1alpha1.WorkflowID]*v1alpha1.WorkflowSpec, - refConstructor storage.ReferenceConstructor) v1alpha1.ExecutableWorkflow { - - return &contextualWorkflow{ - ExecutableWorkflow: executors.NewSubContextualWorkflow(baseWorkflow, subwf, status), - extraTasks: tasks, - extraWorkflows: workflows, - status: newContextualWorkflowStatus(baseWorkflow.GetExecutionStatus(), status, refConstructor), - } -} - -func (w contextualWorkflow) GetExecutionStatus() v1alpha1.ExecutableWorkflowStatus { - return w.status -} - -func (w contextualWorkflow) GetTask(id v1alpha1.TaskID) (v1alpha1.ExecutableTask, error) { - if task, found := w.extraTasks[id]; found { - return task, nil - } - - return w.ExecutableWorkflow.GetTask(id) -} - -func (w contextualWorkflow) FindSubWorkflow(id v1alpha1.WorkflowID) v1alpha1.ExecutableSubWorkflow { - if wf, found := w.extraWorkflows[id]; found { - return wf - } - - return w.ExecutableWorkflow.FindSubWorkflow(id) -} - -// A contextual workflow status to override some of the implementations. -type ContextualWorkflowStatus struct { - v1alpha1.ExecutableWorkflowStatus - baseStatus v1alpha1.ExecutableNodeStatus - referenceConstructor storage.ReferenceConstructor -} - -func (w ContextualWorkflowStatus) GetDataDir() v1alpha1.DataReference { - return w.baseStatus.GetDataDir() -} - -// Overrides default node data dir to work around the contractual assumption between Propeller and Futures to write all -// sub-node inputs into current node data directory. -// E.g. -// if current node data dir is /wf_exec/node-1/data/ -// and the task ran and yielded 2 nodes, the structure will look like this: -// /wf_exec/node-1/data/ -// |_ inputs.pb -// |_ futures.pb -// |_ sub-node1/inputs.pb -// |_ sub-node2/inputs.pb -// TODO: This is just a stop-gap until we transition the DynamicJobSpec to be a full-fledged workflow spec. -// TODO: this will allow us to have proper data bindings between nodes then we can stop making assumptions about data refs. -func (w ContextualWorkflowStatus) ConstructNodeDataDir(ctx context.Context, name v1alpha1.NodeID) (storage.DataReference, error) { - return w.referenceConstructor.ConstructReference(ctx, w.GetDataDir(), name) -} - -func newContextualWorkflowStatus(baseWfStatus v1alpha1.ExecutableWorkflowStatus, - baseStatus v1alpha1.ExecutableNodeStatus, constructor storage.ReferenceConstructor) *ContextualWorkflowStatus { - - return &ContextualWorkflowStatus{ - ExecutableWorkflowStatus: baseWfStatus, - baseStatus: baseStatus, - referenceConstructor: constructor, - } -} diff --git a/pkg/controller/nodes/dynamic/subworkflow_test.go b/pkg/controller/nodes/dynamic/subworkflow_test.go deleted file mode 100644 index f55016ee75..0000000000 --- a/pkg/controller/nodes/dynamic/subworkflow_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package dynamic - -import ( - "context" - "testing" - - "github.com/lyft/flytestdlib/promutils" - "github.com/lyft/flytestdlib/storage" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" -) - -func TestNewContextualWorkflow(t *testing.T) { - wf := &mocks.ExecutableWorkflow{} - calledBase := false - wf.On("GetAnnotations").Return(map[string]string{}).Run(func(_ mock.Arguments) { - calledBase = true - }) - - wf.On("GetExecutionStatus").Return(&mocks.ExecutableWorkflowStatus{}) - - subwf := &mocks.ExecutableSubWorkflow{} - cWF := newContextualWorkflow(wf, subwf, nil, nil, nil, nil) - cWF.GetAnnotations() - - assert.True(t, calledBase) -} - -func TestConstructNodeDataDir(t *testing.T) { - wf := &mocks.ExecutableWorkflow{} - wf.On("GetExecutionStatus").Return(&mocks.ExecutableWorkflowStatus{}) - - wfStatus := &mocks.ExecutableWorkflowStatus{} - wfStatus.On("GetDataDir").Return(storage.DataReference("fk://wrong/")).Run(func(_ mock.Arguments) { - assert.FailNow(t, "Should call the override") - }) - - nodeStatus := &mocks.ExecutableNodeStatus{} - nodeStatus.On("GetDataDir").Return(storage.DataReference("fk://right/")) - - ds, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) - assert.NoError(t, err) - cWF := newContextualWorkflowStatus(wfStatus, nodeStatus, ds) - - dataDir, err := cWF.ConstructNodeDataDir(context.TODO(), "my_node") - assert.NoError(t, err) - assert.NotNil(t, dataDir) - assert.Equal(t, "fk://right/my_node", dataDir.String()) -} diff --git a/pkg/controller/nodes/errors/codes.go b/pkg/controller/nodes/errors/codes.go index d1304f49a4..19bb41cf87 100644 --- a/pkg/controller/nodes/errors/codes.go +++ b/pkg/controller/nodes/errors/codes.go @@ -17,6 +17,7 @@ const ( CausedByError ErrorCode = "CausedByError" RuntimeExecutionError ErrorCode = "RuntimeExecutionError" SubWorkflowExecutionFailed ErrorCode = "SubWorkflowExecutionFailed" + SubWorkflowExecutionFailing ErrorCode = "SubWorkflowExecutionFailing" RemoteChildWorkflowExecutionFailed ErrorCode = "RemoteChildWorkflowExecutionFailed" NoBranchTakenError ErrorCode = "NoBranchTakenError" OutputsNotFoundError ErrorCode = "OutputsNotFoundError" diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index 317381ad1d..a02d51a72b 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -72,10 +72,10 @@ type nodeExecutor struct { shardSelector ioutils.ShardSelector } -func (c *nodeExecutor) RecordTransitionLatency(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) { +func (c *nodeExecutor) RecordTransitionLatency(ctx context.Context, dag executors.DAGStructure, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) { if nodeStatus.GetPhase() == v1alpha1.NodePhaseNotYetStarted || nodeStatus.GetPhase() == v1alpha1.NodePhaseQueued { // Log transition latency (The most recently finished parent node endAt time to this node's queuedAt time -now-) - t, err := GetParentNodeMaxEndTime(ctx, w, node) + t, err := GetParentNodeMaxEndTime(ctx, dag, nl, node) if err != nil { logger.Warnf(ctx, "Failed to record transition latency for node. Error: %s", err.Error()) return @@ -119,10 +119,10 @@ func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *eve // In this method we check if the queue is ready to be processed and if so, we prime it in Admin as queued // Before we start the node execution, we need to transition this Node status to Queued. // This is because a node execution has to exist before task/wf executions can start. -func (c *nodeExecutor) preExecute(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.PhaseInfo, error) { +func (c *nodeExecutor) preExecute(ctx context.Context, dag executors.DAGStructure, nCtx handler.NodeExecutionContext) (handler.PhaseInfo, error) { logger.Debugf(ctx, "Node not yet started") // Query the nodes information to figure out if it can be executed. - predicatePhase, err := CanExecute(ctx, w, node) + predicatePhase, err := CanExecute(ctx, dag, nCtx.ContextualNodeLookup(), nCtx.Node()) if err != nil { logger.Debugf(ctx, "Node failed in CanExecute. Error [%s]", err) return handler.PhaseInfoUndefined, err @@ -131,6 +131,8 @@ func (c *nodeExecutor) preExecute(ctx context.Context, w v1alpha1.ExecutableWork if predicatePhase == PredicatePhaseReady { // TODO: Performance problem, we maybe in a retry loop and do not need to resolve the inputs again. // For now we will do this. + node := nCtx.Node() + nodeStatus := nCtx.NodeStatus() dataDir := nodeStatus.GetDataDir() var nodeInputs *core.LiteralMap if !node.IsStartNode() { @@ -138,7 +140,7 @@ func (c *nodeExecutor) preExecute(ctx context.Context, w v1alpha1.ExecutableWork defer t.Stop() // Can execute var err error - nodeInputs, err = Resolve(ctx, c.outputResolver, w, node.GetID(), node.GetInputBindings()) + nodeInputs, err = Resolve(ctx, c.outputResolver, nCtx.ContextualNodeLookup(), node.GetID(), node.GetInputBindings()) // TODO we need to handle retryable, network errors here!! if err != nil { c.metrics.ResolutionFailure.Inc(ctx) @@ -183,7 +185,7 @@ func (c *nodeExecutor) isTimeoutExpired(queuedAt *metav1.Time, timeout time.Dura return false } -func (c *nodeExecutor) isEligibleForRetry(nCtx *execContext, nodeStatus v1alpha1.ExecutableNodeStatus, err *core.ExecutionError) (currentAttempt, maxAttempts uint32, isEligible bool) { +func (c *nodeExecutor) isEligibleForRetry(nCtx *nodeExecContext, nodeStatus v1alpha1.ExecutableNodeStatus, err *core.ExecutionError) (currentAttempt, maxAttempts uint32, isEligible bool) { if err.Kind == core.ExecutionError_SYSTEM { currentAttempt = nodeStatus.GetSystemFailures() maxAttempts = c.maxNodeRetriesForSystemFailures @@ -199,7 +201,7 @@ func (c *nodeExecutor) isEligibleForRetry(nCtx *execContext, nodeStatus v1alpha1 return } -func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx *execContext, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.PhaseInfo, error) { +func (c *nodeExecutor) execute(ctx context.Context, h handler.Node, nCtx *nodeExecContext, nodeStatus v1alpha1.ExecutableNodeStatus) (handler.PhaseInfo, error) { logger.Debugf(ctx, "Executing node") defer logger.Debugf(ctx, "Node execution round complete") @@ -265,149 +267,66 @@ func (c *nodeExecutor) finalize(ctx context.Context, h handler.Node, nCtx handle return h.Finalize(ctx, nCtx) } -func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (executors.NodeStatus, error) { - logger.Debugf(ctx, "Handling Node [%s]", node.GetID()) - defer logger.Debugf(ctx, "Completed node [%s]", node.GetID()) - - nodeExecID := &core.NodeExecutionIdentifier{ - NodeId: node.GetID(), - ExecutionId: w.GetExecutionID().WorkflowExecutionIdentifier, +func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx *nodeExecContext, _ handler.Node) (executors.NodeStatus, error) { + logger.Debugf(ctx, "Node not yet started, running pre-execute") + defer logger.Debugf(ctx, "Node pre-execute completed") + p, err := c.preExecute(ctx, dag, nCtx) + if err != nil { + logger.Errorf(ctx, "failed preExecute for node. Error: %s", err.Error()) + return executors.NodeStatusUndefined, err } - nodeStatus := w.GetNodeExecutionStatus(ctx, node.GetID()) - - if nodeStatus.IsDirty() { - return executors.NodeStatusRunning, nil + if p.GetPhase() == handler.EPhaseUndefined { + return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") } - // Now depending on the node type decide - h, err := c.nodeHandlerFactory.GetHandler(node.GetKind()) - if err != nil { - return executors.NodeStatusUndefined, err + if p.GetPhase() == handler.EPhaseNotReady { + return executors.NodeStatusPending, nil } - nCtx, err := c.newNodeExecContextDefault(ctx, w, node, nodeStatus) + np, err := ToNodePhase(p.GetPhase()) if err != nil { - return executors.NodeStatusUndefined, err + return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") } - currentPhase := nodeStatus.GetPhase() - - // Optimization! - // If it is start node we directly move it to Queued without needing to run preExecute - if currentPhase == v1alpha1.NodePhaseNotYetStarted && !node.IsStartNode() { - logger.Debugf(ctx, "Node not yet started, running pre-execute") - defer logger.Debugf(ctx, "Node pre-execute completed") - p, err := c.preExecute(ctx, w, node, nodeStatus) + nodeStatus := nCtx.NodeStatus() + if np != nodeStatus.GetPhase() { + // assert np == Queued! + logger.Infof(ctx, "Change in node state detected from [%s] -> [%s]", nodeStatus.GetPhase().String(), np.String()) + nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), p, nCtx.InputReader(), nodeStatus) if err != nil { - logger.Errorf(ctx, "failed preExecute for node. Error: %s", err.Error()) - return executors.NodeStatusUndefined, err + return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") } - - if p.GetPhase() == handler.EPhaseUndefined { - return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, node.GetID(), "received undefined phase.") - } - - if p.GetPhase() == handler.EPhaseNotReady { - return executors.NodeStatusPending, nil - } - - np, err := ToNodePhase(p.GetPhase()) + err = c.IdempotentRecordEvent(ctx, nev) if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, node.GetID(), err, "failed to move from queued") - } - - if np != nodeStatus.GetPhase() { - // assert np == Queued! - logger.Infof(ctx, "Change in node state detected from [%s] -> [%s]", nodeStatus.GetPhase().String(), np.String()) - nev, err := ToNodeExecutionEvent(nodeExecID, p, nCtx.InputReader(), nCtx.NodeStatus()) - if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, node.GetID(), err, "could not convert phase info to event") - } - err = c.IdempotentRecordEvent(ctx, nev) - if err != nil { - logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return executors.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, node.GetID(), err, "failed to record node event") - } - UpdateNodeStatus(np, p, nCtx.nsm, nodeStatus) - c.RecordTransitionLatency(ctx, w, node, nodeStatus) - } - - if np == v1alpha1.NodePhaseQueued { - return executors.NodeStatusQueued, nil - } else if np == v1alpha1.NodePhaseSkipped { - return executors.NodeStatusSuccess, nil - } - - return executors.NodeStatusPending, nil - } - - if currentPhase == v1alpha1.NodePhaseFailing { - logger.Debugf(ctx, "node failing") - if err := c.finalize(ctx, h, nCtx); err != nil { - return executors.NodeStatusUndefined, err - } - - nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, v1.Now(), nodeStatus.GetMessage()) - c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) - // TODO we need to have a way to find the error message from failing to failed! - return executors.NodeStatusFailed(fmt.Errorf(nodeStatus.GetMessage())), nil - } - - if currentPhase == v1alpha1.NodePhaseTimingOut { - logger.Debugf(ctx, "node timing out") - if err := c.abort(ctx, h, nCtx, "node timed out"); err != nil { - return executors.NodeStatusUndefined, err + logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) + return executors.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") } - - nodeStatus.ClearSubNodeStatus() - nodeStatus.UpdatePhase(v1alpha1.NodePhaseTimedOut, v1.Now(), nodeStatus.GetMessage()) - c.metrics.TimedOutFailure.Inc(ctx) - return executors.NodeStatusTimedOut, nil + UpdateNodeStatus(np, p, nCtx.nsm, nodeStatus) + c.RecordTransitionLatency(ctx, dag, nCtx.ContextualNodeLookup(), nCtx.Node(), nodeStatus) } - if currentPhase == v1alpha1.NodePhaseSucceeding { - logger.Debugf(ctx, "node succeeding") - if err := c.finalize(ctx, h, nCtx); err != nil { - return executors.NodeStatusUndefined, err - } - - nodeStatus.ClearSubNodeStatus() - nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeded, v1.Now(), "completed successfully") - c.metrics.SuccessDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) + if np == v1alpha1.NodePhaseQueued { + return executors.NodeStatusQueued, nil + } else if np == v1alpha1.NodePhaseSkipped { return executors.NodeStatusSuccess, nil } - if currentPhase == v1alpha1.NodePhaseRetryableFailure { - logger.Debugf(ctx, "node failed with retryable failure, aborting and finalizing, message: %s", nodeStatus.GetMessage()) - if err := c.abort(ctx, h, nCtx, nodeStatus.GetMessage()); err != nil { - return executors.NodeStatusUndefined, err - } + return executors.NodeStatusPending, nil +} - // NOTE: It is important to increment attempts only after abort has been called. Increment attempt mutates the state - // Attempt is used throughout the system to determine the idempotent resource version. - nodeStatus.IncrementAttempts() - nodeStatus.UpdatePhase(v1alpha1.NodePhaseRunning, v1.Now(), "retrying") - // We are going to retry in the next round, so we should clear all current state - nodeStatus.ClearSubNodeStatus() - nodeStatus.ClearTaskStatus() - nodeStatus.ClearWorkflowStatus() - nodeStatus.ClearDynamicNodeStatus() - return executors.NodeStatusPending, nil - } +func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx *nodeExecContext, h handler.Node) (executors.NodeStatus, error) { + nodeStatus := nCtx.NodeStatus() + currentPhase := nodeStatus.GetPhase() - if currentPhase == v1alpha1.NodePhaseFailed { - // This should never happen - return executors.NodeStatusFailed(fmt.Errorf(nodeStatus.GetMessage())), nil - } + // case v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning: + logger.Debugf(ctx, "node executing, current phase [%s]", currentPhase) + defer logger.Debugf(ctx, "node execution completed") // Since we reset node status inside execute for retryable failure, we use lastAttemptStartTime to carry that information // across execute which is used to emit metrics lastAttemptStartTime := nodeStatus.GetLastAttemptStartedAt() - // case v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning: - logger.Debugf(ctx, "node executing, current phase [%s]", currentPhase) - defer logger.Debugf(ctx, "node execution completed") p, err := c.execute(ctx, h, nCtx, nodeStatus) if err != nil { logger.Errorf(ctx, "failed Execute for node. Error: %s", err.Error()) @@ -436,12 +355,12 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork } if p.GetPhase() == handler.EPhaseUndefined { - return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, node.GetID(), "received undefined phase.") + return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, nCtx.NodeID(), "received undefined phase.") } np, err := ToNodePhase(p.GetPhase()) if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, node.GetID(), err, "failed to move from queued") + return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "failed to move from queued") } finalStatus := executors.NodeStatusRunning @@ -467,15 +386,15 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork if np != nodeStatus.GetPhase() && np != v1alpha1.NodePhaseRetryableFailure { // assert np == skipped, succeeding or failing logger.Infof(ctx, "Change in node state detected from [%s] -> [%s], (handler phase [%s])", nodeStatus.GetPhase().String(), np.String(), p.GetPhase().String()) - nev, err := ToNodeExecutionEvent(nodeExecID, p, nCtx.InputReader(), nCtx.NodeStatus()) + nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), p, nCtx.InputReader(), nCtx.NodeStatus()) if err != nil { - return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, node.GetID(), err, "could not convert phase info to event") + return executors.NodeStatusUndefined, errors.Wrapf(errors.IllegalStateError, nCtx.NodeID(), err, "could not convert phase info to event") } err = c.IdempotentRecordEvent(ctx, nev) if err != nil { logger.Warningf(ctx, "Failed to record nodeEvent, error [%s]", err.Error()) - return executors.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, node.GetID(), err, "failed to record node event") + return executors.NodeStatusUndefined, errors.Wrapf(errors.EventRecordingFailed, nCtx.NodeID(), err, "failed to record node event") } // We reach here only when transitioning from Queued to Running. In this case, the startedAt is not set. @@ -490,15 +409,93 @@ func (c *nodeExecutor) handleNode(ctx context.Context, w v1alpha1.ExecutableWork return finalStatus, nil } +func (c *nodeExecutor) handleRetryableFailure(ctx context.Context, nCtx *nodeExecContext, h handler.Node) (executors.NodeStatus, error) { + nodeStatus := nCtx.NodeStatus() + logger.Debugf(ctx, "node failed with retryable failure, aborting and finalizing, message: %s", nodeStatus.GetMessage()) + if err := c.abort(ctx, h, nCtx, nodeStatus.GetMessage()); err != nil { + return executors.NodeStatusUndefined, err + } + + // NOTE: It is important to increment attempts only after abort has been called. Increment attempt mutates the state + // Attempt is used throughout the system to determine the idempotent resource version. + nodeStatus.IncrementAttempts() + nodeStatus.UpdatePhase(v1alpha1.NodePhaseRunning, v1.Now(), "retrying") + // We are going to retry in the next round, so we should clear all current state + nodeStatus.ClearSubNodeStatus() + nodeStatus.ClearTaskStatus() + nodeStatus.ClearWorkflowStatus() + nodeStatus.ClearDynamicNodeStatus() + return executors.NodeStatusPending, nil +} + +func (c *nodeExecutor) handleNode(ctx context.Context, dag executors.DAGStructure, nCtx *nodeExecContext, h handler.Node) (executors.NodeStatus, error) { + logger.Debugf(ctx, "Handling Node [%s]", nCtx.NodeID()) + defer logger.Debugf(ctx, "Completed node [%s]", nCtx.NodeID()) + + nodeStatus := nCtx.NodeStatus() + currentPhase := nodeStatus.GetPhase() + + // Optimization! + // If it is start node we directly move it to Queued without needing to run preExecute + if currentPhase == v1alpha1.NodePhaseNotYetStarted && !nCtx.Node().IsStartNode() { + return c.handleNotYetStartedNode(ctx, dag, nCtx, h) + } + + if currentPhase == v1alpha1.NodePhaseFailing { + logger.Debugf(ctx, "node failing") + if err := c.finalize(ctx, h, nCtx); err != nil { + return executors.NodeStatusUndefined, err + } + nodeStatus.UpdatePhase(v1alpha1.NodePhaseFailed, v1.Now(), nodeStatus.GetMessage()) + c.metrics.FailureDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) + return executors.NodeStatusFailed(fmt.Errorf(nodeStatus.GetMessage())), nil + } + + if currentPhase == v1alpha1.NodePhaseTimingOut { + logger.Debugf(ctx, "node timing out") + if err := c.abort(ctx, h, nCtx, "node timed out"); err != nil { + return executors.NodeStatusUndefined, err + } + + nodeStatus.ClearSubNodeStatus() + nodeStatus.UpdatePhase(v1alpha1.NodePhaseTimedOut, v1.Now(), nodeStatus.GetMessage()) + c.metrics.TimedOutFailure.Inc(ctx) + return executors.NodeStatusTimedOut, nil + } + + if currentPhase == v1alpha1.NodePhaseSucceeding { + logger.Debugf(ctx, "node succeeding") + if err := c.finalize(ctx, h, nCtx); err != nil { + return executors.NodeStatusUndefined, err + } + + nodeStatus.ClearSubNodeStatus() + nodeStatus.UpdatePhase(v1alpha1.NodePhaseSucceeded, v1.Now(), "completed successfully") + c.metrics.SuccessDuration.Observe(ctx, nodeStatus.GetStartedAt().Time, nodeStatus.GetStoppedAt().Time) + return executors.NodeStatusSuccess, nil + } + + if currentPhase == v1alpha1.NodePhaseRetryableFailure { + return c.handleRetryableFailure(ctx, nCtx, h) + } + + if currentPhase == v1alpha1.NodePhaseFailed { + // This should never happen + return executors.NodeStatusFailed(fmt.Errorf(nodeStatus.GetMessage())), nil + } + + return c.handleQueuedOrRunningNode(ctx, nCtx, h) +} + // The space search for the next node to execute is implemented like a DFS algorithm. handleDownstream visits all the nodes downstream from // the currentNode. Visit a node is the RecursiveNodeHandler. A visit may be partial, complete or may result in a failure. -func (c *nodeExecutor) handleDownstream(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { +func (c *nodeExecutor) handleDownstream(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { logger.Debugf(ctx, "Handling downstream Nodes") // This node is success. Handle all downstream nodes - downstreamNodes, err := w.FromNode(currentNode.GetID()) + downstreamNodes, err := dag.FromNode(currentNode.GetID()) if err != nil { - logger.Debugf(ctx, "Error when retrieving downstream nodes. Error [%v]", err) - return executors.NodeStatusFailed(err), nil + logger.Debugf(ctx, "Error when retrieving downstream nodes, [%s]", err) + return executors.NodeStatusFailed(errors.Wrapf(errors.BadSpecificationError, currentNode.GetID(), err, "failed to retrieve downstream nodes")), nil } if len(downstreamNodes) == 0 { logger.Debugf(ctx, "No downstream nodes found. Complete.") @@ -510,11 +507,11 @@ func (c *nodeExecutor) handleDownstream(ctx context.Context, w v1alpha1.Executab allCompleted := true partialNodeCompletion := false for _, downstreamNodeName := range downstreamNodes { - downstreamNode, ok := w.GetNode(downstreamNodeName) + downstreamNode, ok := nl.GetNode(downstreamNodeName) if !ok { return executors.NodeStatusFailed(errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", downstreamNodeName)), nil } - state, err := c.RecursiveNodeHandler(ctx, w, downstreamNode) + state, err := c.RecursiveNodeHandler(ctx, execContext, dag, nl, downstreamNode) if err != nil { return executors.NodeStatusUndefined, err } @@ -546,12 +543,8 @@ func (c *nodeExecutor) handleDownstream(ctx context.Context, w v1alpha1.Executab return executors.NodeStatusPending, nil } -func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, w v1alpha1.ExecutableWorkflow, inputs *core.LiteralMap) (executors.NodeStatus, error) { - startNode := w.StartNode() - if startNode == nil { - return executors.NodeStatusFailed(errors.Errorf(errors.BadSpecificationError, v1alpha1.StartNodeID, "Start node not found")), nil - } - +func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructureWithStartNode, nl executors.NodeLookup, inputs *core.LiteralMap) (executors.NodeStatus, error) { + startNode := dag.StartNode() ctx = contextutils.WithNodeID(ctx, startNode.GetID()) if inputs == nil { logger.Infof(ctx, "No inputs for the workflow. Skipping storing inputs") @@ -559,7 +552,7 @@ func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, w v1alpha1.Exe } // StartNode is special. It does not have any processing step. It just takes the workflow (or subworkflow) inputs and converts to its own outputs - nodeStatus := w.GetNodeExecutionStatus(ctx, startNode.GetID()) + nodeStatus := nl.GetNodeExecutionStatus(ctx, startNode.GetID()) if len(nodeStatus.GetDataDir()) == 0 { return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, startNode.GetID(), "no data-dir set, cannot store inputs") @@ -575,23 +568,50 @@ func (c *nodeExecutor) SetInputsForStartNode(ctx context.Context, w v1alpha1.Exe return executors.NodeStatusComplete, nil } -func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { +func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) (executors.NodeStatus, error) { currentNodeCtx := contextutils.WithNodeID(ctx, currentNode.GetID()) - nodeStatus := w.GetNodeExecutionStatus(ctx, currentNode.GetID()) + nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) switch nodeStatus.GetPhase() { case v1alpha1.NodePhaseNotYetStarted, v1alpha1.NodePhaseQueued, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseTimingOut, v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseSucceeding: + // TODO Follow up Pull Request, + // 1. Rename this method to DAGTraversalHandleNode (accepts a DAGStructure along-with) the remaining arguments + // 2. Create a new method called HandleNode (part of the interface) (remaining all args as the previous method, but no DAGStructure + // 3. Additional both methods will receive inputs reader + // 4. The Downstream nodes handler will Resolve the Inputs + // 5. the method will delegate all other node handling to HandleNode. + // 6. Thus we can get rid of SetInputs for StartNode as well logger.Debugf(currentNodeCtx, "Handling node Status [%v]", nodeStatus.GetPhase().String()) t := c.metrics.NodeExecutionTime.Start(ctx) defer t.Stop() - return c.handleNode(currentNodeCtx, w, currentNode) + + // This is an optimization to avoid creating the nodeContext object in case the node has already been looked at. + // If the overhead was zero, we would just do the isDirtyCheck after the nodeContext is created + nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) + if nodeStatus.IsDirty() { + return executors.NodeStatusRunning, nil + } + + nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl) + if err != nil { + // NodeExecution creation failure is a permanent fail / system error. + // Should a system failure always return an err? + return executors.NodeStatusFailed(err), nil + } + + // Now depending on the node type decide + h, err := c.nodeHandlerFactory.GetHandler(nCtx.Node().GetKind()) + if err != nil { + return executors.NodeStatusUndefined, err + } + return c.handleNode(currentNodeCtx, dag, nCtx, h) // TODO we can optimize skip state handling by iterating down the graph and marking all as skipped // Currently we treat either Skip or Success the same way. In this approach only one node will be skipped // at a time. As we iterate down, further nodes will be skipped case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSkipped: - return c.handleDownstream(ctx, w, currentNode) + return c.handleDownstream(ctx, execContext, dag, nl, currentNode) case v1alpha1.NodePhaseFailed: logger.Debugf(currentNodeCtx, "Node Failed") return executors.NodeStatusFailed(errors.Errorf(errors.RuntimeExecutionError, currentNode.GetID(), "Node Failed.")), nil @@ -602,8 +622,8 @@ func (c *nodeExecutor) RecursiveNodeHandler(ctx context.Context, w v1alpha1.Exec return executors.NodeStatusUndefined, errors.Errorf(errors.IllegalStateError, currentNode.GetID(), "Should never reach here") } -func (c *nodeExecutor) FinalizeHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode) error { - nodeStatus := w.GetNodeExecutionStatus(ctx, currentNode.GetID()) +func (c *nodeExecutor) FinalizeHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode) error { + nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) switch nodeStatus.GetPhase() { case v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseRetryableFailure: @@ -615,7 +635,7 @@ func (c *nodeExecutor) FinalizeHandler(ctx context.Context, w v1alpha1.Executabl return err } - nCtx, err := c.newNodeExecContextDefault(ctx, w, currentNode, nodeStatus) + nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl) if err != nil { return err } @@ -626,7 +646,7 @@ func (c *nodeExecutor) FinalizeHandler(ctx context.Context, w v1alpha1.Executabl } default: // Abort downstream nodes - downstreamNodes, err := w.FromNode(currentNode.GetID()) + downstreamNodes, err := dag.FromNode(currentNode.GetID()) if err != nil { logger.Debugf(ctx, "Error when retrieving downstream nodes. Error [%v]", err) return nil @@ -634,12 +654,12 @@ func (c *nodeExecutor) FinalizeHandler(ctx context.Context, w v1alpha1.Executabl errs := make([]error, 0, len(downstreamNodes)) for _, d := range downstreamNodes { - downstreamNode, ok := w.GetNode(d) + downstreamNode, ok := nl.GetNode(d) if !ok { return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d) } - if err := c.FinalizeHandler(ctx, w, downstreamNode); err != nil { + if err := c.FinalizeHandler(ctx, execContext, dag, nl, downstreamNode); err != nil { logger.Infof(ctx, "Failed to abort node [%v]. Error: %v", d, err) errs = append(errs, err) } @@ -655,8 +675,8 @@ func (c *nodeExecutor) FinalizeHandler(ctx context.Context, w v1alpha1.Executabl return nil } -func (c *nodeExecutor) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode, reason string) error { - nodeStatus := w.GetNodeExecutionStatus(ctx, currentNode.GetID()) +func (c *nodeExecutor) AbortHandler(ctx context.Context, execContext executors.ExecutionContext, dag executors.DAGStructure, nl executors.NodeLookup, currentNode v1alpha1.ExecutableNode, reason string) error { + nodeStatus := nl.GetNodeExecutionStatus(ctx, currentNode.GetID()) switch nodeStatus.GetPhase() { case v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseFailing, v1alpha1.NodePhaseSucceeding, v1alpha1.NodePhaseRetryableFailure, v1alpha1.NodePhaseQueued: @@ -668,7 +688,7 @@ func (c *nodeExecutor) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWo return err } - nCtx, err := c.newNodeExecContextDefault(ctx, w, currentNode, nodeStatus) + nCtx, err := c.newNodeExecContextDefault(ctx, currentNode.GetID(), execContext, nl) if err != nil { return err } @@ -677,12 +697,8 @@ func (c *nodeExecutor) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWo if err != nil { return err } - nodeExecID := &core.NodeExecutionIdentifier{ - NodeId: nCtx.NodeID(), - ExecutionId: w.GetExecutionID().WorkflowExecutionIdentifier, - } err = c.IdempotentRecordEvent(ctx, &event.NodeExecutionEvent{ - Id: nodeExecID, + Id: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), Phase: core.NodeExecution_ABORTED, OccurredAt: ptypes.TimestampNow(), OutputResult: &event.NodeExecutionEvent_Error{ @@ -702,7 +718,7 @@ func (c *nodeExecutor) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWo } case v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseSkipped: // Abort downstream nodes - downstreamNodes, err := w.FromNode(currentNode.GetID()) + downstreamNodes, err := dag.FromNode(currentNode.GetID()) if err != nil { logger.Debugf(ctx, "Error when retrieving downstream nodes. Error [%v]", err) return nil @@ -710,12 +726,12 @@ func (c *nodeExecutor) AbortHandler(ctx context.Context, w v1alpha1.ExecutableWo errs := make([]error, 0, len(downstreamNodes)) for _, d := range downstreamNodes { - downstreamNode, ok := w.GetNode(d) + downstreamNode, ok := nl.GetNode(d) if !ok { return errors.Errorf(errors.BadSpecificationError, currentNode.GetID(), "Unable to find Downstream Node [%v]", d) } - if err := c.AbortHandler(ctx, w, downstreamNode, reason); err != nil { + if err := c.AbortHandler(ctx, execContext, dag, nl, downstreamNode, reason); err != nil { logger.Infof(ctx, "Failed to abort node [%v]. Error: %v", d, err) errs = append(errs, err) } diff --git a/pkg/controller/nodes/executor_test.go b/pkg/controller/nodes/executor_test.go index 574c958f66..ee202baea3 100644 --- a/pkg/controller/nodes/executor_test.go +++ b/pkg/controller/nodes/executor_test.go @@ -61,7 +61,7 @@ func TestSetInputsForStartNode(t *testing.T) { w.DummyStartNode = &v1alpha1.NodeSpec{ ID: v1alpha1.StartNodeID, } - s, err := exec.SetInputsForStartNode(ctx, w, nil) + s, err := exec.SetInputsForStartNode(ctx, w, w, w, nil) assert.NoError(t, err) assert.Equal(t, executors.NodeStatusComplete, s) }) @@ -73,7 +73,7 @@ func TestSetInputsForStartNode(t *testing.T) { w.DummyStartNode = &v1alpha1.NodeSpec{ ID: v1alpha1.StartNodeID, } - s, err := exec.SetInputsForStartNode(ctx, w, inputs) + s, err := exec.SetInputsForStartNode(ctx, w, w, w, inputs) assert.NoError(t, err) assert.Equal(t, executors.NodeStatusComplete, s) actual := &core.LiteralMap{} @@ -87,7 +87,7 @@ func TestSetInputsForStartNode(t *testing.T) { w.DummyStartNode = &v1alpha1.NodeSpec{ ID: v1alpha1.StartNodeID, } - s, err := exec.SetInputsForStartNode(ctx, w, inputs) + s, err := exec.SetInputsForStartNode(ctx, w, w, w, inputs) assert.Error(t, err) assert.Equal(t, executors.NodeStatusUndefined, s) }) @@ -102,7 +102,7 @@ func TestSetInputsForStartNode(t *testing.T) { w.DummyStartNode = &v1alpha1.NodeSpec{ ID: v1alpha1.StartNodeID, } - s, err := execFail.SetInputsForStartNode(ctx, w, inputs) + s, err := execFail.SetInputsForStartNode(ctx, w, w, w, inputs) assert.Error(t, err) assert.Equal(t, executors.NodeStatusUndefined, s) }) @@ -233,7 +233,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { hf.On("GetHandler", v1alpha1.NodeKindStart).Return(h, nil) mockWf, startNode, startNodeStatus := createStartNodeWf(test.currentNodePhase, 0) - s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockWf, mockWf, startNode) if test.expectedError { assert.Error(t, err) } else { @@ -318,7 +318,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { hf.On("GetHandler", v1alpha1.NodeKindEnd).Return(h, nil) mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.parentNodePhase, 0) - s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockWf, mockWf, mockNode) if test.expectedError { assert.Error(t, err) } else { @@ -423,7 +423,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { startNode := mockWf.StartNode() startStatus := mockWf.GetNodeExecutionStatus(ctx, startNode.GetID()) assert.Equal(t, v1alpha1.NodePhaseSucceeded, startStatus.GetPhase()) - s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockWf, mockWf, startNode) if test.expectedError { assert.Error(t, err) } else { @@ -517,70 +517,71 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { // Setup mockN2Status := &mocks.ExecutableNodeStatus{} // No parent node - mockN2Status.On("GetParentNodeID").Return(nil) - mockN2Status.On("GetParentTaskID").Return(nil) - mockN2Status.On("GetPhase").Return(n2Phase) + mockN2Status.OnGetParentNodeID().Return(nil) + mockN2Status.OnGetParentTaskID().Return(nil) + mockN2Status.OnGetPhase().Return(n2Phase) mockN2Status.On("SetDataDir", mock.AnythingOfType(reflect.TypeOf(storage.DataReference("x")).String())) - mockN2Status.On("GetDataDir").Return(storage.DataReference("blah")) + mockN2Status.OnGetDataDir().Return(storage.DataReference("blah")) mockN2Status.On("SetOutputDir", mock.AnythingOfType(reflect.TypeOf(storage.DataReference("x")).String())) - mockN2Status.On("GetOutputDir").Return(storage.DataReference("blah")) - mockN2Status.On("GetWorkflowNodeStatus").Return(nil) + mockN2Status.OnGetOutputDir().Return(storage.DataReference("blah")) + mockN2Status.OnGetWorkflowNodeStatus().Return(nil) - mockN2Status.On("GetStoppedAt").Return(nil) + mockN2Status.OnGetStoppedAt().Return(nil) mockN2Status.On("UpdatePhase", expectedN2Phase, mock.Anything, mock.AnythingOfType("string")) - mockN2Status.On("IsDirty").Return(false) - mockN2Status.On("GetTaskNodeStatus").Return(nil) + mockN2Status.OnIsDirty().Return(false) + mockN2Status.OnGetTaskNodeStatus().Return(nil) mockN2Status.On("ClearDynamicNodeStatus").Return(nil) - mockN2Status.On("GetAttempts").Return(uint32(0)) + mockN2Status.OnGetAttempts().Return(uint32(0)) mockNode := &mocks.ExecutableNode{} - mockNode.On("GetID").Return(nodeN2) - mockNode.On("GetBranchNode").Return(nil) - mockNode.On("GetKind").Return(v1alpha1.NodeKindTask) - mockNode.On("IsStartNode").Return(false) - mockNode.On("IsEndNode").Return(false) - mockNode.On("GetTaskID").Return(&taskID) - mockNode.On("GetInputBindings").Return([]*v1alpha1.Binding{}) - mockNode.On("IsInterruptible").Return(nil) + mockNode.OnGetID().Return(nodeN2) + mockNode.OnGetBranchNode().Return(nil) + mockNode.OnGetKind().Return(v1alpha1.NodeKindTask) + mockNode.OnIsStartNode().Return(false) + mockNode.OnIsEndNode().Return(false) + mockNode.OnGetTaskID().Return(&taskID) + mockNode.OnGetInputBindings().Return([]*v1alpha1.Binding{}) + mockNode.OnIsInterruptible().Return(nil) mockNodeN0 := &mocks.ExecutableNode{} - mockNodeN0.On("GetID").Return(nodeN0) - mockNodeN0.On("GetBranchNode").Return(nil) - mockNodeN0.On("GetKind").Return(v1alpha1.NodeKindTask) - mockNodeN0.On("IsStartNode").Return(false) - mockNodeN0.On("IsEndNode").Return(false) - mockNodeN0.On("GetTaskID").Return(&taskID0) - mockNodeN0.On("IsInterruptible").Return(nil) + mockNodeN0.OnGetID().Return(nodeN0) + mockNodeN0.OnGetBranchNode().Return(nil) + mockNodeN0.OnGetKind().Return(v1alpha1.NodeKindTask) + mockNodeN0.OnIsStartNode().Return(false) + mockNodeN0.OnIsEndNode().Return(false) + mockNodeN0.OnGetTaskID().Return(&taskID0) + mockNodeN0.OnIsInterruptible().Return(nil) mockN0Status := &mocks.ExecutableNodeStatus{} - mockN0Status.On("GetPhase").Return(n0Phase) - mockN0Status.On("GetAttempts").Return(uint32(0)) + mockN0Status.OnGetPhase().Return(n0Phase) + mockN0Status.OnGetAttempts().Return(uint32(0)) - mockN0Status.On("IsDirty").Return(false) - mockN0Status.On("GetParentTaskID").Return(nil) + mockN0Status.OnIsDirty().Return(false) + mockN0Status.OnGetParentTaskID().Return(nil) n := v1.Now() - mockN0Status.On("GetStoppedAt").Return(&n) + mockN0Status.OnGetStoppedAt().Return(&n) tk := &mocks.ExecutableTask{} - tk.On("CoreTask").Return(&core.TaskTemplate{}) + tk.OnCoreTask().Return(&core.TaskTemplate{}) mockWfStatus := &mocks.ExecutableWorkflowStatus{} mockWf := &mocks.ExecutableWorkflow{} - mockWf.On("StartNode").Return(mockNodeN0) - mockWf.On("GetNode", nodeN2).Return(mockNode, true) + mockWf.OnStartNode().Return(mockNodeN0) + mockWf.OnGetNode(nodeN2).Return(mockNode, true) mockWf.OnGetNodeExecutionStatusMatch(mock.Anything, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatusMatch(mock.Anything, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) - mockWf.On("GetID").Return("w1") - mockWf.On("FromNode", nodeN0).Return([]string{nodeN2}, nil) - mockWf.On("FromNode", nodeN2).Return([]string{}, fmt.Errorf("did not expect")) - mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{}) - mockWf.On("GetExecutionStatus").Return(mockWfStatus) - mockWf.On("GetTask", taskID0).Return(tk, nil) - mockWf.On("GetTask", taskID).Return(tk, nil) - mockWf.On("GetLabels").Return(make(map[string]string)) - mockWf.On("IsInterruptible").Return(false) - mockWfStatus.On("GetDataDir").Return(storage.DataReference("x")) - mockWfStatus.On("ConstructNodeDataDir", mock.Anything, mock.Anything, mock.Anything).Return(storage.DataReference("x"), nil) + mockWf.OnGetConnections().Return(connections) + mockWf.OnGetID().Return("w1") + mockWf.OnToNode(nodeN2).Return([]string{nodeN0}, nil) + mockWf.OnFromNode(nodeN0).Return([]string{nodeN2}, nil) + mockWf.OnFromNode(nodeN2).Return([]string{}, fmt.Errorf("did not expect")) + mockWf.OnGetExecutionID().Return(v1alpha1.WorkflowExecutionIdentifier{}) + mockWf.OnGetExecutionStatus().Return(mockWfStatus) + mockWf.OnGetTask(taskID0).Return(tk, nil) + mockWf.OnGetTask(taskID).Return(tk, nil) + mockWf.OnGetLabels().Return(make(map[string]string)) + mockWf.OnIsInterruptible().Return(false) + mockWfStatus.OnGetDataDir().Return(storage.DataReference("x")) + mockWfStatus.OnConstructNodeDataDirMatch(mock.Anything, mock.Anything, mock.Anything).Return("x", nil) return mockWf, mockN2Status } @@ -602,7 +603,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { hf := &mocks2.HandlerFactory{} h := &nodeHandlerMocks.Node{} - h.On("Handle", + h.OnHandleMatch( mock.MatchedBy(func(ctx context.Context) bool { return true }), mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), ).Return(handler.UnknownTransition, fmt.Errorf("should not be called")) @@ -620,7 +621,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf - s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockWf, mockWf, startNode) if test.expectedError { assert.Error(t, err) } else { @@ -722,7 +723,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { startNode := mockWf.StartNode() startStatus := mockWf.GetNodeExecutionStatus(ctx, startNode.GetID()) assert.Equal(t, v1alpha1.NodePhaseSucceeded, startStatus.GetPhase()) - s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockWf, mockWf, startNode) if test.expectedError { assert.Error(t, err) } else { @@ -821,7 +822,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { mockWf, _, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 1) startNode := mockWf.StartNode() - s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockWf, mockWf, startNode) if test.expectedError { assert.Error(t, err) } else { @@ -857,7 +858,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { mockWf, _, mockNodeStatus := createSingleNodeWf(v1alpha1.NodePhaseRunning, 0) startNode := mockWf.StartNode() - s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockWf, mockWf, startNode) assert.NoError(t, err) assert.Equal(t, executors.NodePhasePending.String(), s.NodePhase.String()) assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) @@ -886,7 +887,7 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { mockWf, _, mockNodeStatus := createSingleNodeWf(v1alpha1.NodePhaseRunning, 1) startNode := mockWf.StartNode() - s, err := exec.RecursiveNodeHandler(ctx, mockWf, startNode) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockWf, mockWf, startNode) assert.NoError(t, err) assert.Equal(t, executors.NodePhasePending.String(), s.NodePhase.String()) assert.Equal(t, uint32(0), mockNodeStatus.GetAttempts()) @@ -990,7 +991,7 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.currentNodePhase, 1) - s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockWf, mockWf, mockNode) if test.expectedError { assert.Error(t, err) } else { @@ -1094,7 +1095,7 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { hf.On("GetHandler", v1alpha1.NodeKindTask).Return(h, nil) mockWf, mockNode, mockNodeStatus := createSingleNodeWf(test.parentNodePhase, 0) - s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockNode) + s, err := exec.RecursiveNodeHandler(ctx, mockWf, mockWf, mockWf, mockNode) if test.expectedError { assert.Error(t, err) } else { @@ -1108,6 +1109,109 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { } } +func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { + ctx := context.TODO() + enQWf := func(workflowID v1alpha1.WorkflowID) { + } + mockEventSink := events.NewMockEventSink().(*events.MockEventSink) + + store := createInmemoryDataStore(t, promutils.NewTestScope()) + + adminClient := launchplan.NewFailFastLaunchPlanExecutor() + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, + 10, "s3://bucket", fakeKubeClient, catalogClient, promutils.NewTestScope()) + assert.NoError(t, err) + exec := execIface.(*nodeExecutor) + // Node not yet started + { + tests := []struct { + name string + parentNodePhase v1alpha1.BranchNodePhase + currentNodePhase v1alpha1.NodePhase + phaseUpdateExpected bool + expectedPhase executors.NodePhase + expectedError bool + }{ + {"branchSuccess", v1alpha1.BranchNodeSuccess, v1alpha1.NodePhaseNotYetStarted, true, executors.NodePhaseQueued, false}, + {"branchNotYetDone", v1alpha1.BranchNodeNotYetEvaluated, v1alpha1.NodePhaseNotYetStarted, false, executors.NodePhaseUndefined, true}, + {"branchError", v1alpha1.BranchNodeError, v1alpha1.NodePhaseNotYetStarted, false, executors.NodePhaseUndefined, true}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + hf := &mocks2.HandlerFactory{} + exec.nodeHandlerFactory = hf + h := &nodeHandlerMocks.Node{} + h.OnHandleMatch( + mock.MatchedBy(func(ctx context.Context) bool { return true }), + mock.MatchedBy(func(o handler.NodeExecutionContext) bool { return true }), + ).Return(handler.UnknownTransition, fmt.Errorf("should not be called")) + h.OnFinalizeRequired().Return(true) + h.OnFinalizeMatch(mock.Anything, mock.Anything).Return(fmt.Errorf("error")) + + hf.OnGetHandlerMatch(v1alpha1.NodeKindTask).Return(h, nil) + + parentBranchNodeID := "branchNode" + parentBranchNode := &mocks.ExecutableNode{} + parentBranchNode.OnGetID().Return(parentBranchNodeID) + parentBranchNode.OnGetBranchNode().Return(&mocks.ExecutableBranchNode{}) + parentBranchNodeStatus := &mocks.ExecutableNodeStatus{} + parentBranchNodeStatus.OnGetPhase().Return(v1alpha1.NodePhaseRunning) + parentBranchNodeStatus.OnIsDirty().Return(false) + bns := &mocks.MutableBranchNodeStatus{} + parentBranchNodeStatus.OnGetBranchStatus().Return(bns) + bns.OnGetPhase().Return(test.parentNodePhase) + + tk := &mocks.ExecutableTask{} + tk.OnCoreTask().Return(&core.TaskTemplate{}) + + tid := "tid" + eCtx := &mocks4.ExecutionContext{} + eCtx.OnGetTask(tid).Return(tk, nil) + eCtx.OnIsInterruptible().Return(true) + eCtx.OnGetExecutionID().Return(v1alpha1.WorkflowExecutionIdentifier{WorkflowExecutionIdentifier: &core.WorkflowExecutionIdentifier{}}) + eCtx.OnGetLabels().Return(nil) + + branchTakenNodeID := "branchTakenNode" + branchTakenNode := &mocks.ExecutableNode{} + branchTakenNode.OnGetID().Return(branchTakenNodeID) + branchTakenNode.OnGetKind().Return(v1alpha1.NodeKindTask) + branchTakenNode.OnGetTaskID().Return(&tid) + branchTakenNode.OnIsInterruptible().Return(nil) + branchTakenNode.OnIsStartNode().Return(false) + branchTakenNode.OnIsEndNode().Return(false) + branchTakenNode.OnGetInputBindings().Return(nil) + branchTakeNodeStatus := &mocks.ExecutableNodeStatus{} + branchTakeNodeStatus.OnGetPhase().Return(test.currentNodePhase) + branchTakeNodeStatus.OnIsDirty().Return(false) + branchTakeNodeStatus.OnGetSystemFailures().Return(1) + branchTakeNodeStatus.OnGetDataDir().Return("data") + branchTakeNodeStatus.OnGetParentNodeID().Return(&parentBranchNodeID) + branchTakeNodeStatus.OnGetParentTaskID().Return(nil) + + if test.phaseUpdateExpected { + branchTakeNodeStatus.On("UpdatePhase", v1alpha1.NodePhaseQueued, mock.Anything, mock.Anything).Return() + } + + leafDag := executors.NewLeafNodeDAGStructure(branchTakenNodeID, parentBranchNodeID) + + nl := executors.NewTestNodeLookup( + map[v1alpha1.NodeID]v1alpha1.ExecutableNode{branchTakenNodeID: branchTakenNode, parentBranchNodeID: parentBranchNode}, + map[v1alpha1.NodeID]v1alpha1.ExecutableNodeStatus{branchTakenNodeID: branchTakeNodeStatus, parentBranchNodeID: parentBranchNodeStatus}, + ) + + s, err := exec.RecursiveNodeHandler(ctx, eCtx, leafDag, nl, branchTakenNode) + if test.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedPhase, s.NodePhase, "expected: %s, received %s", test.expectedPhase.String(), s.NodePhase.String()) + }) + } + } +} + func Test_nodeExecutor_RecordTransitionLatency(t *testing.T) { testScope := promutils.NewTestScope() type fields struct { @@ -1162,7 +1266,7 @@ func Test_nodeExecutor_RecordTransitionLatency(t *testing.T) { nodeRecorder: tt.fields.nodeRecorder, metrics: tt.fields.metrics, } - c.RecordTransitionLatency(context.TODO(), tt.args.w, tt.args.node, tt.args.nodeStatus) + c.RecordTransitionLatency(context.TODO(), tt.args.w, tt.args.w, tt.args.node, tt.args.nodeStatus) ch := make(chan prometheus.Metric, 2) tt.fields.metrics.TransitionLatency.Collect(ch) @@ -1271,7 +1375,7 @@ func Test_nodeExecutor_timeout(t *testing.T) { mockNode.On("GetExecutionDeadline").Return(&tt.executionDeadline) mockNode.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &tt.retries}) - nCtx := &execContext{node: mockNode, nsm: &nodeStateManager{nodeStatus: ns}} + nCtx := &nodeExecContext{node: mockNode, nsm: &nodeStateManager{nodeStatus: ns}} phaseInfo, err := c.execute(context.TODO(), h, nCtx, ns) if tt.err != nil { @@ -1318,7 +1422,7 @@ func Test_nodeExecutor_system_error(t *testing.T) { retries := 2 mockNode.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &retries}) - nCtx := &execContext{node: mockNode, nsm: &nodeStateManager{nodeStatus: ns}} + nCtx := &nodeExecContext{node: mockNode, nsm: &nodeStateManager{nodeStatus: ns}} phaseInfo, err := c.execute(context.TODO(), h, nCtx, ns) assert.Equal(t, handler.EPhaseRetryableFailure, phaseInfo.GetPhase()) assert.NoError(t, err) @@ -1328,7 +1432,7 @@ func Test_nodeExecutor_system_error(t *testing.T) { func Test_nodeExecutor_abort(t *testing.T) { ctx := context.Background() exec := nodeExecutor{} - nCtx := &execContext{} + nCtx := &nodeExecContext{} t.Run("abort error calls finalize", func(t *testing.T) { h := &nodeHandlerMocks.Node{} diff --git a/pkg/controller/nodes/handler/mocks/node_execution_context.go b/pkg/controller/nodes/handler/mocks/node_execution_context.go index 59456d50eb..04aa13a6a4 100644 --- a/pkg/controller/nodes/handler/mocks/node_execution_context.go +++ b/pkg/controller/nodes/handler/mocks/node_execution_context.go @@ -4,9 +4,12 @@ package mocks import ( events "github.com/lyft/flyteidl/clients/go/events" - io "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" + executors "github.com/lyft/flytepropeller/pkg/controller/executors" + handler "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + io "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" + ioutils "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" mock "github.com/stretchr/testify/mock" @@ -21,6 +24,40 @@ type NodeExecutionContext struct { mock.Mock } +type NodeExecutionContext_ContextualNodeLookup struct { + *mock.Call +} + +func (_m NodeExecutionContext_ContextualNodeLookup) Return(_a0 executors.NodeLookup) *NodeExecutionContext_ContextualNodeLookup { + return &NodeExecutionContext_ContextualNodeLookup{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeExecutionContext) OnContextualNodeLookup() *NodeExecutionContext_ContextualNodeLookup { + c := _m.On("ContextualNodeLookup") + return &NodeExecutionContext_ContextualNodeLookup{Call: c} +} + +func (_m *NodeExecutionContext) OnContextualNodeLookupMatch(matchers ...interface{}) *NodeExecutionContext_ContextualNodeLookup { + c := _m.On("ContextualNodeLookup", matchers...) + return &NodeExecutionContext_ContextualNodeLookup{Call: c} +} + +// ContextualNodeLookup provides a mock function with given fields: +func (_m *NodeExecutionContext) ContextualNodeLookup() executors.NodeLookup { + ret := _m.Called() + + var r0 executors.NodeLookup + if rf, ok := ret.Get(0).(func() executors.NodeLookup); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(executors.NodeLookup) + } + } + + return r0 +} + type NodeExecutionContext_CurrentAttempt struct { *mock.Call } @@ -155,6 +192,40 @@ func (_m *NodeExecutionContext) EventsRecorder() events.TaskEventRecorder { return r0 } +type NodeExecutionContext_ExecutionContext struct { + *mock.Call +} + +func (_m NodeExecutionContext_ExecutionContext) Return(_a0 executors.ExecutionContext) *NodeExecutionContext_ExecutionContext { + return &NodeExecutionContext_ExecutionContext{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeExecutionContext) OnExecutionContext() *NodeExecutionContext_ExecutionContext { + c := _m.On("ExecutionContext") + return &NodeExecutionContext_ExecutionContext{Call: c} +} + +func (_m *NodeExecutionContext) OnExecutionContextMatch(matchers ...interface{}) *NodeExecutionContext_ExecutionContext { + c := _m.On("ExecutionContext", matchers...) + return &NodeExecutionContext_ExecutionContext{Call: c} +} + +// ExecutionContext provides a mock function with given fields: +func (_m *NodeExecutionContext) ExecutionContext() executors.ExecutionContext { + ret := _m.Called() + + var r0 executors.ExecutionContext + if rf, ok := ret.Get(0).(func() executors.ExecutionContext); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(executors.ExecutionContext) + } + } + + return r0 +} + type NodeExecutionContext_InputReader struct { *mock.Call } @@ -522,37 +593,3 @@ func (_m *NodeExecutionContext) TaskReader() handler.TaskReader { return r0 } - -type NodeExecutionContext_Workflow struct { - *mock.Call -} - -func (_m NodeExecutionContext_Workflow) Return(_a0 v1alpha1.ExecutableWorkflow) *NodeExecutionContext_Workflow { - return &NodeExecutionContext_Workflow{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeExecutionContext) OnWorkflow() *NodeExecutionContext_Workflow { - c := _m.On("Workflow") - return &NodeExecutionContext_Workflow{Call: c} -} - -func (_m *NodeExecutionContext) OnWorkflowMatch(matchers ...interface{}) *NodeExecutionContext_Workflow { - c := _m.On("Workflow", matchers...) - return &NodeExecutionContext_Workflow{Call: c} -} - -// Workflow provides a mock function with given fields: -func (_m *NodeExecutionContext) Workflow() v1alpha1.ExecutableWorkflow { - ret := _m.Called() - - var r0 v1alpha1.ExecutableWorkflow - if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableWorkflow); ok { - r0 = rf() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(v1alpha1.ExecutableWorkflow) - } - } - - return r0 -} diff --git a/pkg/controller/nodes/handler/mocks/node_execution_metadata.go b/pkg/controller/nodes/handler/mocks/node_execution_metadata.go index f3bc9c3b22..7e3b0be889 100644 --- a/pkg/controller/nodes/handler/mocks/node_execution_metadata.go +++ b/pkg/controller/nodes/handler/mocks/node_execution_metadata.go @@ -3,12 +3,13 @@ package mocks import ( + core "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + mock "github.com/stretchr/testify/mock" + types "k8s.io/apimachinery/pkg/types" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - v1alpha1 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" ) // NodeExecutionMetadata is an autogenerated mock type for the NodeExecutionMetadata type @@ -50,38 +51,6 @@ func (_m *NodeExecutionMetadata) GetAnnotations() map[string]string { return r0 } -type NodeExecutionMetadata_GetExecutionID struct { - *mock.Call -} - -func (_m NodeExecutionMetadata_GetExecutionID) Return(_a0 v1alpha1.WorkflowExecutionIdentifier) *NodeExecutionMetadata_GetExecutionID { - return &NodeExecutionMetadata_GetExecutionID{Call: _m.Call.Return(_a0)} -} - -func (_m *NodeExecutionMetadata) OnGetExecutionID() *NodeExecutionMetadata_GetExecutionID { - c := _m.On("GetExecutionID") - return &NodeExecutionMetadata_GetExecutionID{Call: c} -} - -func (_m *NodeExecutionMetadata) OnGetExecutionIDMatch(matchers ...interface{}) *NodeExecutionMetadata_GetExecutionID { - c := _m.On("GetExecutionID", matchers...) - return &NodeExecutionMetadata_GetExecutionID{Call: c} -} - -// GetExecutionID provides a mock function with given fields: -func (_m *NodeExecutionMetadata) GetExecutionID() v1alpha1.WorkflowExecutionIdentifier { - ret := _m.Called() - - var r0 v1alpha1.WorkflowExecutionIdentifier - if rf, ok := ret.Get(0).(func() v1alpha1.WorkflowExecutionIdentifier); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(v1alpha1.WorkflowExecutionIdentifier) - } - - return r0 -} - type NodeExecutionMetadata_GetK8sServiceAccount struct { *mock.Call } @@ -180,6 +149,40 @@ func (_m *NodeExecutionMetadata) GetNamespace() string { return r0 } +type NodeExecutionMetadata_GetNodeExecutionID struct { + *mock.Call +} + +func (_m NodeExecutionMetadata_GetNodeExecutionID) Return(_a0 *core.NodeExecutionIdentifier) *NodeExecutionMetadata_GetNodeExecutionID { + return &NodeExecutionMetadata_GetNodeExecutionID{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeExecutionMetadata) OnGetNodeExecutionID() *NodeExecutionMetadata_GetNodeExecutionID { + c := _m.On("GetNodeExecutionID") + return &NodeExecutionMetadata_GetNodeExecutionID{Call: c} +} + +func (_m *NodeExecutionMetadata) OnGetNodeExecutionIDMatch(matchers ...interface{}) *NodeExecutionMetadata_GetNodeExecutionID { + c := _m.On("GetNodeExecutionID", matchers...) + return &NodeExecutionMetadata_GetNodeExecutionID{Call: c} +} + +// GetNodeExecutionID provides a mock function with given fields: +func (_m *NodeExecutionMetadata) GetNodeExecutionID() *core.NodeExecutionIdentifier { + ret := _m.Called() + + var r0 *core.NodeExecutionIdentifier + if rf, ok := ret.Get(0).(func() *core.NodeExecutionIdentifier); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.NodeExecutionIdentifier) + } + } + + return r0 +} + type NodeExecutionMetadata_GetOwnerID struct { *mock.Call } diff --git a/pkg/controller/nodes/handler/node_exec_context.go b/pkg/controller/nodes/handler/node_exec_context.go index dd9224b8c9..187d8bb854 100644 --- a/pkg/controller/nodes/handler/node_exec_context.go +++ b/pkg/controller/nodes/handler/node_exec_context.go @@ -14,6 +14,7 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io" "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" ) type TaskReader interface { @@ -30,8 +31,7 @@ type SetupContext interface { type NodeExecutionMetadata interface { GetOwnerID() types.NamespacedName - // TODO we should covert this to a generic execution identifier instead of a workflow identifier - GetExecutionID() v1alpha1.WorkflowExecutionIdentifier + GetNodeExecutionID() *core.NodeExecutionIdentifier GetNamespace() string GetOwnerReference() v1.OwnerReference GetLabels() map[string]string @@ -66,8 +66,8 @@ type NodeExecutionContext interface { EnqueueOwnerFunc() func() error - // Deprecated - Workflow() v1alpha1.ExecutableWorkflow + ContextualNodeLookup() executors.NodeLookup + ExecutionContext() executors.ExecutionContext // TODO We should not need to pass NodeStatus, we probably only need it for DataDir, which should actually be sent using an OutputWriter interface // Deprecated NodeStatus() v1alpha1.ExecutableNodeStatus diff --git a/pkg/controller/nodes/node_exec_context.go b/pkg/controller/nodes/node_exec_context.go index 6099573f80..225d712fbd 100644 --- a/pkg/controller/nodes/node_exec_context.go +++ b/pkg/controller/nodes/node_exec_context.go @@ -6,6 +6,7 @@ import ( "strconv" "github.com/lyft/flyteidl/clients/go/events" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/lyft/flytestdlib/storage" "k8s.io/apimachinery/pkg/types" @@ -13,6 +14,7 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" "github.com/lyft/flytepropeller/pkg/utils" ) @@ -21,29 +23,34 @@ const NodeIDLabel = "node-id" const TaskNameLabel = "task-name" const NodeInterruptibleLabel = "interruptible" -type execMetadata struct { - v1alpha1.WorkflowMeta +type nodeExecMetadata struct { + v1alpha1.Meta + nodeExecID *core.NodeExecutionIdentifier interrutptible bool nodeLabels map[string]string } -func (e execMetadata) GetK8sServiceAccount() string { - return e.WorkflowMeta.GetServiceAccountName() +func (e nodeExecMetadata) GetNodeExecutionID() *core.NodeExecutionIdentifier { + return e.nodeExecID } -func (e execMetadata) GetOwnerID() types.NamespacedName { +func (e nodeExecMetadata) GetK8sServiceAccount() string { + return e.Meta.GetServiceAccountName() +} + +func (e nodeExecMetadata) GetOwnerID() types.NamespacedName { return types.NamespacedName{Name: e.GetName(), Namespace: e.GetNamespace()} } -func (e execMetadata) IsInterruptible() bool { +func (e nodeExecMetadata) IsInterruptible() bool { return e.interrutptible } -func (e execMetadata) GetLabels() map[string]string { +func (e nodeExecMetadata) GetLabels() map[string]string { return e.nodeLabels } -type execContext struct { +type nodeExecContext struct { store *storage.DataStore tr handler.TaskReader md handler.NodeExecutionMetadata @@ -54,81 +61,93 @@ type execContext struct { maxDatasetSizeBytes int64 nsm *nodeStateManager enqueueOwner func() error - w v1alpha1.ExecutableWorkflow rawOutputPrefix storage.DataReference shardSelector ioutils.ShardSelector + nl executors.NodeLookup + ic executors.ExecutionContext +} + +func (e nodeExecContext) ExecutionContext() executors.ExecutionContext { + return e.ic +} + +func (e nodeExecContext) ContextualNodeLookup() executors.NodeLookup { + return e.nl } -func (e execContext) OutputShardSelector() ioutils.ShardSelector { +func (e nodeExecContext) OutputShardSelector() ioutils.ShardSelector { return e.shardSelector } -func (e execContext) RawOutputPrefix() storage.DataReference { +func (e nodeExecContext) RawOutputPrefix() storage.DataReference { return e.rawOutputPrefix } -func (e execContext) EnqueueOwnerFunc() func() error { +func (e nodeExecContext) EnqueueOwnerFunc() func() error { return e.enqueueOwner } -func (e execContext) Workflow() v1alpha1.ExecutableWorkflow { - return e.w -} - -func (e execContext) TaskReader() handler.TaskReader { +func (e nodeExecContext) TaskReader() handler.TaskReader { return e.tr } -func (e execContext) NodeStateReader() handler.NodeStateReader { +func (e nodeExecContext) NodeStateReader() handler.NodeStateReader { return e.nsm } -func (e execContext) NodeStateWriter() handler.NodeStateWriter { +func (e nodeExecContext) NodeStateWriter() handler.NodeStateWriter { return e.nsm } -func (e execContext) DataStore() *storage.DataStore { +func (e nodeExecContext) DataStore() *storage.DataStore { return e.store } -func (e execContext) InputReader() io.InputReader { +func (e nodeExecContext) InputReader() io.InputReader { return e.inputs } -func (e execContext) EventsRecorder() events.TaskEventRecorder { +func (e nodeExecContext) EventsRecorder() events.TaskEventRecorder { return e.er } -func (e execContext) NodeID() v1alpha1.NodeID { +func (e nodeExecContext) NodeID() v1alpha1.NodeID { return e.node.GetID() } -func (e execContext) Node() v1alpha1.ExecutableNode { +func (e nodeExecContext) Node() v1alpha1.ExecutableNode { return e.node } -func (e execContext) CurrentAttempt() uint32 { +func (e nodeExecContext) CurrentAttempt() uint32 { return e.nodeStatus.GetAttempts() } -func (e execContext) NodeStatus() v1alpha1.ExecutableNodeStatus { +func (e nodeExecContext) NodeStatus() v1alpha1.ExecutableNodeStatus { return e.nodeStatus } -func (e execContext) NodeExecutionMetadata() handler.NodeExecutionMetadata { +func (e nodeExecContext) NodeExecutionMetadata() handler.NodeExecutionMetadata { return e.md } -func (e execContext) MaxDatasetSizeBytes() int64 { +func (e nodeExecContext) MaxDatasetSizeBytes() int64 { return e.maxDatasetSizeBytes } -func newNodeExecContext(_ context.Context, store *storage.DataStore, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, interruptible bool, maxDatasetSize int64, er events.TaskEventRecorder, tr handler.TaskReader, nsm *nodeStateManager, enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *execContext { - md := execMetadata{WorkflowMeta: w, interrutptible: interruptible} +func newNodeExecContext(_ context.Context, store *storage.DataStore, execContext executors.ExecutionContext, nl executors.NodeLookup, node v1alpha1.ExecutableNode, nodeStatus v1alpha1.ExecutableNodeStatus, inputs io.InputReader, interruptible bool, maxDatasetSize int64, er events.TaskEventRecorder, tr handler.TaskReader, nsm *nodeStateManager, enqueueOwner func() error, rawOutputPrefix storage.DataReference, outputShardSelector ioutils.ShardSelector) *nodeExecContext { + md := nodeExecMetadata{ + Meta: execContext, + nodeExecID: &core.NodeExecutionIdentifier{ + NodeId: node.GetID(), + ExecutionId: execContext.GetExecutionID().WorkflowExecutionIdentifier, + }, + interrutptible: interruptible, + } // Copy the wf labels before adding node specific labels. nodeLabels := make(map[string]string) - for k, v := range w.GetLabels() { + for k, v := range execContext.GetLabels() { nodeLabels[k] = v } nodeLabels[NodeIDLabel] = utils.SanitizeLabelValue(node.GetID()) @@ -138,7 +157,7 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, w v1alpha1. nodeLabels[NodeInterruptibleLabel] = strconv.FormatBool(interruptible) md.nodeLabels = nodeLabels - return &execContext{ + return &nodeExecContext{ md: md, store: store, node: node, @@ -149,42 +168,50 @@ func newNodeExecContext(_ context.Context, store *storage.DataStore, w v1alpha1. tr: tr, nsm: nsm, enqueueOwner: enqueueOwner, - w: w, rawOutputPrefix: rawOutputPrefix, shardSelector: outputShardSelector, + nl: nl, + ic: execContext, } } -func (c *nodeExecutor) newNodeExecContextDefault(ctx context.Context, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode, s v1alpha1.ExecutableNodeStatus) (*execContext, error) { +func (c *nodeExecutor) newNodeExecContextDefault(ctx context.Context, currentNodeID v1alpha1.NodeID, executionContext executors.ExecutionContext, nl executors.NodeLookup) (*nodeExecContext, error) { + n, ok := nl.GetNode(currentNodeID) + if !ok { + return nil, fmt.Errorf("failed to find node with ID [%s] in execution [%s]", currentNodeID, executionContext.GetID()) + } + var tr handler.TaskReader if n.GetKind() == v1alpha1.NodeKindTask { if n.GetTaskID() == nil { return nil, fmt.Errorf("bad state, no task-id defined for node [%s]", n.GetID()) } - tk, err := w.GetTask(*n.GetTaskID()) + tk, err := executionContext.GetTask(*n.GetTaskID()) if err != nil { return nil, err } - tr = &taskReader{TaskTemplate: tk.CoreTask()} + tr = taskReader{TaskTemplate: tk.CoreTask()} } workflowEnqueuer := func() error { - c.enqueueWorkflow(w.GetID()) + c.enqueueWorkflow(executionContext.GetID()) return nil } - interrutible := w.IsInterruptible() + interrutible := executionContext.IsInterruptible() if n.IsInterruptible() != nil { interrutible = *n.IsInterruptible() } + s := nl.GetNodeExecutionStatus(ctx, currentNodeID) + // a node is not considered interruptible if the system failures have exceeded the configured threshold if interrutible && s.GetSystemFailures() >= c.interruptibleFailureThreshold { interrutible = false c.metrics.InterruptedThresholdHit.Inc(ctx) } - return newNodeExecContext(ctx, c.store, w, n, s, + return newNodeExecContext(ctx, c.store, executionContext, nl, n, s, ioutils.NewCachedInputReader( ctx, ioutils.NewRemoteFileInputReader( diff --git a/pkg/controller/nodes/node_exec_context_test.go b/pkg/controller/nodes/node_exec_context_test.go index dc42eeedb2..9b8483c2ae 100644 --- a/pkg/controller/nodes/node_exec_context_test.go +++ b/pkg/controller/nodes/node_exec_context_test.go @@ -47,7 +47,7 @@ func Test_NodeContext(t *testing.T) { Kind: v1alpha1.NodeKindTask, } s, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) - nCtx := newNodeExecContext(context.TODO(), s, w1, n, nil, nil, false, 0, nil, TaskReader{}, nil, nil, "s3://bucket", ioutils.NewConstantShardSelector([]string{"x"})) + nCtx := newNodeExecContext(context.TODO(), s, w1, w1, n, nil, nil, false, 0, nil, TaskReader{}, nil, nil, "s3://bucket", ioutils.NewConstantShardSelector([]string{"x"})) assert.Equal(t, "id", nCtx.NodeExecutionMetadata().GetLabels()["node-id"]) assert.Equal(t, "false", nCtx.NodeExecutionMetadata().GetLabels()["interruptible"]) assert.Equal(t, "task-name", nCtx.NodeExecutionMetadata().GetLabels()["task-name"]) diff --git a/pkg/controller/nodes/output_resolver.go b/pkg/controller/nodes/output_resolver.go index 6bd586da95..cc84838fb1 100644 --- a/pkg/controller/nodes/output_resolver.go +++ b/pkg/controller/nodes/output_resolver.go @@ -7,6 +7,7 @@ import ( "github.com/lyft/flytestdlib/logger" "github.com/lyft/flytestdlib/storage" + "github.com/lyft/flytepropeller/pkg/controller/executors" "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" @@ -20,7 +21,7 @@ type VarName = string type OutputResolver interface { // Extracts a subset of node outputs to literals. - ExtractOutput(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, n v1alpha1.ExecutableNode, + ExtractOutput(ctx context.Context, nl executors.NodeLookup, n v1alpha1.ExecutableNode, bindToVar VarName) (values *core.Literal, err error) } @@ -37,9 +38,9 @@ type remoteFileOutputResolver struct { store *storage.DataStore } -func (r remoteFileOutputResolver) ExtractOutput(ctx context.Context, w v1alpha1.BaseWorkflowWithStatus, n v1alpha1.ExecutableNode, +func (r remoteFileOutputResolver) ExtractOutput(ctx context.Context, nl executors.NodeLookup, n v1alpha1.ExecutableNode, bindToVar VarName) (values *core.Literal, err error) { - nodeStatus := w.GetNodeExecutionStatus(ctx, n.GetID()) + nodeStatus := nl.GetNodeExecutionStatus(ctx, n.GetID()) outputsFileRef := v1alpha1.GetOutputsFile(nodeStatus.GetOutputDir()) index, actualVar, err := ParseVarName(bindToVar) diff --git a/pkg/controller/nodes/predicate.go b/pkg/controller/nodes/predicate.go index 499df73963..63b52fc0b8 100644 --- a/pkg/controller/nodes/predicate.go +++ b/pkg/controller/nodes/predicate.go @@ -8,6 +8,7 @@ import ( v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" ) @@ -37,36 +38,36 @@ func (p PredicatePhase) String() string { return "undefined" } -func CanExecute(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.BaseNode) (PredicatePhase, error) { +func CanExecute(ctx context.Context, dag executors.DAGStructure, nl executors.NodeLookup, node v1alpha1.BaseNode) (PredicatePhase, error) { nodeID := node.GetID() if nodeID == v1alpha1.StartNodeID { logger.Debugf(ctx, "Start Node id is assumed to be ready.") return PredicatePhaseReady, nil } - nodeStatus := w.GetNodeExecutionStatus(ctx, nodeID) + nodeStatus := nl.GetNodeExecutionStatus(ctx, nodeID) parentNodeID := nodeStatus.GetParentNodeID() - upstreamNodes, ok := w.GetConnections().UpstreamEdges[nodeID] - if !ok { + upstreamNodes, err := dag.ToNode(nodeID) + if err != nil { return PredicatePhaseUndefined, errors.Errorf(errors.BadSpecificationError, nodeID, "Unable to find upstream nodes for Node") } skipped := false for _, upstreamNodeID := range upstreamNodes { - upstreamNodeStatus := w.GetNodeExecutionStatus(ctx, upstreamNodeID) + upstreamNodeStatus := nl.GetNodeExecutionStatus(ctx, upstreamNodeID) if upstreamNodeStatus.IsDirty() { return PredicatePhaseNotReady, nil } if parentNodeID != nil && *parentNodeID == upstreamNodeID { - upstreamNode, ok := w.GetNode(upstreamNodeID) + upstreamNode, ok := nl.GetNode(upstreamNodeID) if !ok { return PredicatePhaseUndefined, errors.Errorf(errors.BadSpecificationError, nodeID, "Upstream node [%v] of node [%v] not defined", upstreamNodeID, nodeID) } // This only happens if current node is the child node of a branch node - if upstreamNode.GetBranchNode() == nil || upstreamNodeStatus.GetOrCreateBranchStatus().GetPhase() != v1alpha1.BranchNodeSuccess { + if upstreamNode.GetBranchNode() == nil || upstreamNodeStatus.GetBranchStatus().GetPhase() != v1alpha1.BranchNodeSuccess { logger.Debugf(ctx, "Branch sub node is expected to have parent branch node in succeeded state") return PredicatePhaseUndefined, errors.Errorf(errors.IllegalStateError, nodeID, "Upstream node [%v] is set as parent, but is not a branch node of [%v] or in illegal state.", upstreamNodeID, nodeID) } @@ -88,7 +89,7 @@ func CanExecute(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha return PredicatePhaseReady, nil } -func GetParentNodeMaxEndTime(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.BaseNode) (t v1.Time, err error) { +func GetParentNodeMaxEndTime(ctx context.Context, dag executors.DAGStructure, nl executors.NodeLookup, node v1alpha1.BaseNode) (t v1.Time, err error) { zeroTime := v1.NewTime(time.Time{}) nodeID := node.GetID() if nodeID == v1alpha1.StartNodeID { @@ -96,24 +97,24 @@ func GetParentNodeMaxEndTime(ctx context.Context, w v1alpha1.ExecutableWorkflow, return zeroTime, nil } - nodeStatus := w.GetNodeExecutionStatus(ctx, node.GetID()) + nodeStatus := nl.GetNodeExecutionStatus(ctx, node.GetID()) parentNodeID := nodeStatus.GetParentNodeID() - upstreamNodes, ok := w.GetConnections().UpstreamEdges[nodeID] - if !ok { + upstreamNodes, err := dag.ToNode(nodeID) + if err != nil { return zeroTime, errors.Errorf(errors.BadSpecificationError, nodeID, "Unable to find upstream nodes for Node") } var latest v1.Time for _, upstreamNodeID := range upstreamNodes { - upstreamNodeStatus := w.GetNodeExecutionStatus(ctx, upstreamNodeID) + upstreamNodeStatus := nl.GetNodeExecutionStatus(ctx, upstreamNodeID) if parentNodeID != nil && *parentNodeID == upstreamNodeID { - upstreamNode, ok := w.GetNode(upstreamNodeID) + upstreamNode, ok := nl.GetNode(upstreamNodeID) if !ok { return zeroTime, errors.Errorf(errors.BadSpecificationError, nodeID, "Upstream node [%v] of node [%v] not defined", upstreamNodeID, nodeID) } // This only happens if current node is the child node of a branch node - if upstreamNode.GetBranchNode() == nil || upstreamNodeStatus.GetOrCreateBranchStatus().GetPhase() != v1alpha1.BranchNodeSuccess { + if upstreamNode.GetBranchNode() == nil || upstreamNodeStatus.GetBranchStatus().GetPhase() != v1alpha1.BranchNodeSuccess { logger.Debugf(ctx, "Branch sub node is expected to have parent branch node in succeeded state") return zeroTime, errors.Errorf(errors.IllegalStateError, nodeID, "Upstream node [%v] is set as parent, but is not a branch node of [%v] or in illegal state.", upstreamNodeID, nodeID) } diff --git a/pkg/controller/nodes/predicate_test.go b/pkg/controller/nodes/predicate_test.go index 3519936062..a23abe3e22 100644 --- a/pkg/controller/nodes/predicate_test.go +++ b/pkg/controller/nodes/predicate_test.go @@ -2,6 +2,7 @@ package nodes import ( "context" + "fmt" "testing" "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" @@ -14,18 +15,14 @@ func TestCanExecute(t *testing.T) { nodeN1 := "n1" nodeN2 := "n2" ctx := context.Background() - connections := &v1alpha1.Connections{ - UpstreamEdges: map[v1alpha1.NodeID][]v1alpha1.NodeID{ - nodeN2: {nodeN0, nodeN1}, - }, - } + upstreamN2 := []v1alpha1.NodeID{nodeN0, nodeN1} // Table tests are not really helpful here, so we decided against it t.Run("startNode", func(t *testing.T) { mockNode := &mocks.BaseNode{} mockNode.On("GetID").Return(v1alpha1.StartNodeID) - p, err := CanExecute(ctx, nil, mockNode) + p, err := CanExecute(ctx, nil, nil, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseReady, p) }) @@ -34,15 +31,15 @@ func TestCanExecute(t *testing.T) { // Setup mockNodeStatus := &mocks.ExecutableNodeStatus{} // No parent node - mockNodeStatus.On("GetParentNodeID").Return(nil) + mockNodeStatus.OnGetParentNodeID().Return(nil) mockNode := &mocks.BaseNode{} - mockNode.On("GetID").Return(nodeN2) + mockNode.OnGetID().Return(nodeN2) mockWf := &mocks.ExecutableWorkflow{} mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockNodeStatus) - mockWf.On("GetConnections").Return(&v1alpha1.Connections{}) - mockWf.On("GetID").Return("w1") + mockWf.OnGetID().Return("w1") + mockWf.OnToNode("n2").Return(nil, fmt.Errorf("not found")) - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.Error(t, err) assert.Equal(t, PredicatePhaseUndefined, p) }) @@ -51,27 +48,27 @@ func TestCanExecute(t *testing.T) { // Setup mockN2Status := &mocks.ExecutableNodeStatus{} // No parent node - mockN2Status.On("GetParentNodeID").Return(nil) - mockN2Status.On("IsDirty").Return(false) + mockN2Status.OnGetParentNodeID().Return(nil) + mockN2Status.OnIsDirty().Return(false) mockNode := &mocks.BaseNode{} - mockNode.On("GetID").Return(nodeN2) + mockNode.OnGetID().Return(nodeN2) mockN0Status := &mocks.ExecutableNodeStatus{} - mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) - mockN0Status.On("IsDirty").Return(false) + mockN0Status.OnGetPhase().Return(v1alpha1.NodePhaseRunning) + mockN0Status.OnIsDirty().Return(false) mockN1Status := &mocks.ExecutableNodeStatus{} - mockN1Status.On("GetPhase").Return(v1alpha1.NodePhaseRunning) - mockN1Status.On("IsDirty").Return(false) + mockN1Status.OnGetPhase().Return(v1alpha1.NodePhaseRunning) + mockN1Status.OnIsDirty().Return(false) mockWf := &mocks.ExecutableWorkflow{} mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) - mockWf.On("GetID").Return("w1") + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) + mockWf.OnGetID().Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseNotReady, p) }) @@ -98,10 +95,10 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseNotReady, p) }) @@ -128,10 +125,11 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) + mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseReady, p) }) @@ -158,10 +156,10 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseNotReady, p) }) @@ -188,10 +186,10 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseNotReady, p) }) @@ -218,10 +216,10 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseSkip, p) }) @@ -248,10 +246,10 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseSkip, p) }) @@ -279,10 +277,10 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseNotReady, p) }) @@ -312,11 +310,11 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetNode", nodeN0).Return(nil, false) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.Error(t, err) assert.Equal(t, PredicatePhaseUndefined, p) }) @@ -346,11 +344,11 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.Error(t, err) assert.Equal(t, PredicatePhaseUndefined, p) }) @@ -373,7 +371,7 @@ func TestCanExecute(t *testing.T) { mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) mockN0Status := &mocks.ExecutableNodeStatus{} mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) - mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("GetBranchStatus").Return(mockN0BranchStatus) mockN0Status.On("IsDirty").Return(false) mockN1Status := &mocks.ExecutableNodeStatus{} @@ -384,11 +382,11 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.Error(t, err) assert.Equal(t, PredicatePhaseUndefined, p) }) @@ -411,7 +409,7 @@ func TestCanExecute(t *testing.T) { mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) mockN0Status := &mocks.ExecutableNodeStatus{} mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) - mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("GetBranchStatus").Return(mockN0BranchStatus) mockN0Status.On("IsDirty").Return(false) mockN1Status := &mocks.ExecutableNodeStatus{} @@ -422,11 +420,11 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.Error(t, err) assert.Equal(t, PredicatePhaseUndefined, p) }) @@ -450,7 +448,7 @@ func TestCanExecute(t *testing.T) { mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) mockN0Status := &mocks.ExecutableNodeStatus{} mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) - mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("GetBranchStatus").Return(mockN0BranchStatus) mockN0Status.On("IsDirty").Return(false) mockN1Status := &mocks.ExecutableNodeStatus{} @@ -461,11 +459,11 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseReady, p) }) @@ -489,7 +487,7 @@ func TestCanExecute(t *testing.T) { mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) mockN0Status := &mocks.ExecutableNodeStatus{} mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) - mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("GetBranchStatus").Return(mockN0BranchStatus) mockN0Status.On("IsDirty").Return(false) mockN1Status := &mocks.ExecutableNodeStatus{} @@ -500,11 +498,11 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseSkip, p) }) @@ -528,7 +526,7 @@ func TestCanExecute(t *testing.T) { mockN0Node.On("GetBranchNode").Return(mockN0BranchNode) mockN0Status := &mocks.ExecutableNodeStatus{} mockN0Status.On("GetPhase").Return(v1alpha1.NodePhaseSucceeded) - mockN0Status.On("GetOrCreateBranchStatus").Return(mockN0BranchStatus) + mockN0Status.On("GetBranchStatus").Return(mockN0BranchStatus) mockN0Status.On("IsDirty").Return(false) mockN1Status := &mocks.ExecutableNodeStatus{} @@ -539,11 +537,11 @@ func TestCanExecute(t *testing.T) { mockWf.OnGetNodeExecutionStatus(ctx, nodeN0).Return(mockN0Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN1).Return(mockN1Status) mockWf.OnGetNodeExecutionStatus(ctx, nodeN2).Return(mockN2Status) - mockWf.On("GetConnections").Return(connections) + mockWf.OnToNode(nodeN2).Return(upstreamN2, nil) mockWf.On("GetNode", nodeN0).Return(mockN0Node, true) mockWf.On("GetID").Return("w1") - p, err := CanExecute(ctx, mockWf, mockNode) + p, err := CanExecute(ctx, mockWf, mockWf, mockNode) assert.NoError(t, err) assert.Equal(t, PredicatePhaseNotReady, p) }) diff --git a/pkg/controller/nodes/resolve.go b/pkg/controller/nodes/resolve.go index 26db35feba..b95fc7f5af 100644 --- a/pkg/controller/nodes/resolve.go +++ b/pkg/controller/nodes/resolve.go @@ -5,11 +5,12 @@ import ( "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/controller/executors" "github.com/lyft/flytepropeller/pkg/controller/nodes/errors" "github.com/lyft/flytestdlib/logger" ) -func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, w v1alpha1.BaseWorkflowWithStatus, bindingData *core.BindingData) (*core.Literal, error) { +func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, nl executors.NodeLookup, bindingData *core.BindingData) (*core.Literal, error) { logger.Debugf(ctx, "Resolving binding data") literal := &core.Literal{} @@ -23,7 +24,7 @@ func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, w v1 logger.Debugf(ctx, "bindingData.GetValue() [%v] is of type Collection", bindingData.GetValue()) literalCollection := make([]*core.Literal, 0, len(bindingData.GetCollection().GetBindings())) for _, b := range bindingData.GetCollection().GetBindings() { - l, err := ResolveBindingData(ctx, outputResolver, w, b) + l, err := ResolveBindingData(ctx, outputResolver, nl, b) if err != nil { logger.Debugf(ctx, "Failed to resolve binding data. Error: [%v]", err) return nil, err @@ -41,7 +42,7 @@ func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, w v1 logger.Debugf(ctx, "bindingData.GetValue() [%v] is of type Map", bindingData.GetValue()) literalMap := make(map[string]*core.Literal, len(bindingData.GetMap().GetBindings())) for k, v := range bindingData.GetMap().GetBindings() { - l, err := ResolveBindingData(ctx, outputResolver, w, v) + l, err := ResolveBindingData(ctx, outputResolver, nl, v) if err != nil { logger.Debugf(ctx, "Failed to resolve binding data. Error: [%v]", err) return nil, err @@ -61,7 +62,7 @@ func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, w v1 upstreamNodeID := bindingData.GetPromise().GetNodeId() bindToVar := bindingData.GetPromise().GetVar() - if w == nil { + if nl == nil { return nil, errors.Errorf(errors.IllegalStateError, upstreamNodeID, "Trying to resolve output from previous node, without providing the workflow for variable [%s]", bindToVar) @@ -72,13 +73,13 @@ func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, w v1 "No nodeId (missing) specified for binding in Workflow.") } - n, ok := w.GetNode(upstreamNodeID) + n, ok := nl.GetNode(upstreamNodeID) if !ok { - return nil, errors.Errorf(errors.IllegalStateError, w.GetID(), upstreamNodeID, + return nil, errors.Errorf(errors.IllegalStateError, "id", upstreamNodeID, "Undefined node in Workflow") } - return outputResolver.ExtractOutput(ctx, w, n, bindToVar) + return outputResolver.ExtractOutput(ctx, nl, n, bindToVar) case *core.BindingData_Scalar: logger.Debugf(ctx, "bindingData.GetValue() [%v] is of type Scalar", bindingData.GetValue()) literal.Value = &core.Literal_Scalar{Scalar: bindingData.GetScalar()} @@ -86,15 +87,15 @@ func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, w v1 return literal, nil } -func Resolve(ctx context.Context, outputResolver OutputResolver, w v1alpha1.BaseWorkflowWithStatus, nodeID v1alpha1.NodeID, bindings []*v1alpha1.Binding) (*core.LiteralMap, error) { +func Resolve(ctx context.Context, outputResolver OutputResolver, nl executors.NodeLookup, nodeID v1alpha1.NodeID, bindings []*v1alpha1.Binding) (*core.LiteralMap, error) { logger.Debugf(ctx, "bindings: [%v]", bindings) literalMap := make(map[string]*core.Literal, len(bindings)) for _, binding := range bindings { logger.Debugf(ctx, "Resolving binding: [%v]", binding) varName := binding.GetVar() - l, err := ResolveBindingData(ctx, outputResolver, w, binding.GetBinding()) + l, err := ResolveBindingData(ctx, outputResolver, nl, binding.GetBinding()) if err != nil { - return nil, errors.Wrapf(errors.BindingResolutionError, nodeID, err, "Error binding Var [%v].[%v]", w.GetID(), binding.GetVar()) + return nil, errors.Wrapf(errors.BindingResolutionError, nodeID, err, "Error binding Var [%v].[%v]", "wf", binding.GetVar()) } literalMap[varName] = l diff --git a/pkg/controller/nodes/resolve_test.go b/pkg/controller/nodes/resolve_test.go index 817817456a..46649e86aa 100644 --- a/pkg/controller/nodes/resolve_test.go +++ b/pkg/controller/nodes/resolve_test.go @@ -22,6 +22,7 @@ var testScope = promutils.NewScope("test") type dummyBaseWorkflow struct { DummyStartNode v1alpha1.ExecutableNode ID v1alpha1.WorkflowID + ToNodeCb func(name v1alpha1.NodeID) ([]v1alpha1.NodeID, error) FromNodeCb func(name v1alpha1.NodeID) ([]v1alpha1.NodeID, error) GetNodeCb func(nodeId v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) Status map[v1alpha1.NodeID]*v1alpha1.NodeStatus @@ -29,6 +30,10 @@ type dummyBaseWorkflow struct { Interruptible bool } +func (d *dummyBaseWorkflow) ToNode(name v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + return d.ToNodeCb(name) +} + func (d *dummyBaseWorkflow) GetOutputBindings() []*v1alpha1.Binding { return []*v1alpha1.Binding{} } diff --git a/pkg/controller/nodes/subworkflow/handler.go b/pkg/controller/nodes/subworkflow/handler.go index 2b5e3852a8..e0d2b8aa93 100644 --- a/pkg/controller/nodes/subworkflow/handler.go +++ b/pkg/controller/nodes/subworkflow/handler.go @@ -80,9 +80,7 @@ func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecu } if wfNode.GetSubWorkflowRef() != nil { - wf := nCtx.Workflow() - status := wf.GetNodeExecutionStatus(ctx, nCtx.NodeID()) - return w.subWfHandler.CheckSubWorkflowStatus(ctx, nCtx, wf, status) + return w.subWfHandler.CheckSubWorkflowStatus(ctx, nCtx) } else if wfNode.GetLaunchPlanRefID() != nil { return w.lpHandler.CheckLaunchPlanStatus(ctx, nCtx) } @@ -91,14 +89,13 @@ func (w *workflowNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecu } func (w *workflowNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { - wf := nCtx.Workflow() wfNode := nCtx.Node().GetWorkflowNode() if wfNode.GetSubWorkflowRef() != nil { - return w.subWfHandler.HandleAbort(ctx, nCtx, wf, *wfNode.GetSubWorkflowRef(), reason) + return w.subWfHandler.HandleAbort(ctx, nCtx, reason) } if wfNode.GetLaunchPlanRefID() != nil { - return w.lpHandler.HandleAbort(ctx, wf, nCtx.Node(), reason) + return w.lpHandler.HandleAbort(ctx, nCtx, reason) } return nil } diff --git a/pkg/controller/nodes/subworkflow/handler_test.go b/pkg/controller/nodes/subworkflow/handler_test.go index 029fa3f7b0..7a1c871147 100644 --- a/pkg/controller/nodes/subworkflow/handler_test.go +++ b/pkg/controller/nodes/subworkflow/handler_test.go @@ -9,12 +9,6 @@ import ( "github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" mocks4 "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks" - "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - mocks2 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" - "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" - mocks3 "github.com/lyft/flytepropeller/pkg/controller/nodes/handler/mocks" - "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" - "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flytestdlib/promutils/labeled" @@ -23,6 +17,14 @@ import ( "github.com/stretchr/testify/mock" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + + "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mocks2 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + execMocks "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" + "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" + mocks3 "github.com/lyft/flytepropeller/pkg/controller/nodes/handler/mocks" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" + "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" ) type workflowNodeStateHolder struct { @@ -46,63 +48,58 @@ func (t workflowNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) panic("not implemented") } -func createNodeContext(phase v1alpha1.WorkflowNodePhase, w v1alpha1.ExecutableWorkflow, n v1alpha1.ExecutableNode) *mocks3.NodeExecutionContext { +var wfExecID = &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", +} - wfNodeState := handler.WorkflowNodeState{} - s := &workflowNodeStateHolder{s: wfNodeState} +func createNodeContext(phase v1alpha1.WorkflowNodePhase, n v1alpha1.ExecutableNode, s v1alpha1.ExecutableNodeStatus) *mocks3.NodeExecutionContext { - wfExecID := &core.WorkflowExecutionIdentifier{ - Project: "project", - Domain: "domain", - Name: "name", - } + wfNodeState := handler.WorkflowNodeState{} + state := &workflowNodeStateHolder{s: wfNodeState} nm := &mocks3.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + ExecutionId: wfExecID, + NodeId: n.GetID(), }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v1.OwnerReference{ + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v1.OwnerReference{ Kind: "sample", Name: "name", }) - ns := &mocks2.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetPhase").Return(v1alpha1.NodePhaseNotYetStarted) - ir := &mocks4.InputReader{} inputs := &core.LiteralMap{} - ir.On("Get", mock.Anything).Return(inputs, nil) + ir.OnGetMatch(mock.Anything).Return(inputs, nil) nCtx := &mocks3.NodeExecutionContext{} - nCtx.On("Node").Return(n) - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("InputReader").Return(ir) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeStatus").Return(ns) - nCtx.On("NodeID").Return("n1") - nCtx.On("EnqueueOwner").Return(nil) - nCtx.On("Workflow").Return(w) + nCtx.OnNode().Return(n) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnInputReader().Return(ir) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeID().Return(n.GetID()) + nCtx.OnEnqueueOwnerFunc().Return(nil) + nCtx.OnNodeStatus().Return(s) nr := &mocks3.NodeStateReader{} - nr.On("GetWorkflowNodeState").Return(handler.WorkflowNodeState{ + nr.OnGetWorkflowNodeState().Return(handler.WorkflowNodeState{ Phase: phase, }) - nCtx.On("NodeStateReader").Return(nr) - nCtx.On("NodeStateWriter").Return(s) + nCtx.OnNodeStateReader().Return(nr) + nCtx.OnNodeStateWriter().Return(state) return nCtx } func TestWorkflowNodeHandler_StartNode_Launchplan(t *testing.T) { ctx := context.TODO() - nodeID := "n1" attempts := uint32(1) lpID := &core.Identifier{ @@ -113,48 +110,38 @@ func TestWorkflowNodeHandler_StartNode_Launchplan(t *testing.T) { ResourceType: core.ResourceType_LAUNCH_PLAN, } mockWfNode := &mocks2.ExecutableWorkflowNode{} - mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + mockWfNode.OnGetLaunchPlanRefID().Return(&v1alpha1.Identifier{ Identifier: lpID, }) - mockWfNode.On("GetSubWorkflowRef").Return(nil) + mockWfNode.OnGetSubWorkflowRef().Return(nil) mockNode := &mocks2.ExecutableNode{} - mockNode.On("GetID").Return("n1") - mockNode.On("GetWorkflowNode").Return(mockWfNode) + mockNode.OnGetID().Return("n1") + mockNode.OnGetWorkflowNode().Return(mockWfNode) mockNodeStatus := &mocks2.ExecutableNodeStatus{} - mockNodeStatus.On("GetAttempts").Return(attempts) + mockNodeStatus.OnGetAttempts().Return(attempts) wfStatus := &mocks2.MutableWorkflowNodeStatus{} - mockNodeStatus.On("GetOrCreateWorkflowStatus").Return(wfStatus) - parentID := &core.WorkflowExecutionIdentifier{ - Name: "x", - Domain: "y", - Project: "z", - } - mockWf := &mocks2.ExecutableWorkflow{} - mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) - mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: parentID, - }) + mockNodeStatus.OnGetOrCreateWorkflowStatus().Return(wfStatus) t.Run("happy", func(t *testing.T) { mockLPExec := &mocks.Executor{} h := New(nil, mockLPExec, promutils.NewTestScope()) - mockLPExec.On("Launch", + mockLPExec.OnLaunchMatch( ctx, mock.MatchedBy(func(o launchplan.LaunchContext) bool { return o.ParentNodeExecution.NodeId == mockNode.GetID() && - o.ParentNodeExecution.ExecutionId == parentID + o.ParentNodeExecution.ExecutionId == wfExecID }), mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), mock.MatchedBy(func(o *core.LiteralMap) bool { return o.Literals == nil }), ).Return(nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseUndefined, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseUndefined, mockNode, mockNodeStatus) s, err := h.Handle(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, handler.EPhaseRunning, s.Info().GetPhase()) @@ -164,7 +151,6 @@ func TestWorkflowNodeHandler_StartNode_Launchplan(t *testing.T) { func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) { ctx := context.TODO() - nodeID := "n1" attempts := uint32(1) dataDir := storage.DataReference("data") @@ -176,45 +162,34 @@ func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) { ResourceType: core.ResourceType_LAUNCH_PLAN, } mockWfNode := &mocks2.ExecutableWorkflowNode{} - mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + mockWfNode.OnGetLaunchPlanRefID().Return(&v1alpha1.Identifier{ Identifier: lpID, }) - mockWfNode.On("GetSubWorkflowRef").Return(nil) + mockWfNode.OnGetSubWorkflowRef().Return(nil) mockNode := &mocks2.ExecutableNode{} - mockNode.On("GetID").Return("n1") - mockNode.On("GetWorkflowNode").Return(mockWfNode) + mockNode.OnGetID().Return("n1") + mockNode.OnGetWorkflowNode().Return(mockWfNode) mockNodeStatus := &mocks2.ExecutableNodeStatus{} - mockNodeStatus.On("GetAttempts").Return(attempts) - mockNodeStatus.On("GetDataDir").Return(dataDir) - - parentID := &core.WorkflowExecutionIdentifier{ - Name: "x", - Domain: "y", - Project: "z", - } - mockWf := &mocks2.ExecutableWorkflow{} - mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) - mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: parentID, - }) + mockNodeStatus.OnGetAttempts().Return(attempts) + mockNodeStatus.OnGetDataDir().Return(dataDir) t.Run("stillRunning", func(t *testing.T) { mockLPExec := &mocks.Executor{} h := New(nil, mockLPExec, promutils.NewTestScope()) - mockLPExec.On("GetStatus", + mockLPExec.OnGetStatusMatch( ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_RUNNING, }, nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) s, err := h.Handle(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, handler.EPhaseRunning, s.Info().GetPhase()) @@ -224,7 +199,6 @@ func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) { func TestWorkflowNodeHandler_AbortNode(t *testing.T) { ctx := context.TODO() - nodeID := "n1" attempts := uint32(1) dataDir := storage.DataReference("data") @@ -236,45 +210,36 @@ func TestWorkflowNodeHandler_AbortNode(t *testing.T) { ResourceType: core.ResourceType_LAUNCH_PLAN, } mockWfNode := &mocks2.ExecutableWorkflowNode{} - mockWfNode.On("GetLaunchPlanRefID").Return(&v1alpha1.Identifier{ + mockWfNode.OnGetLaunchPlanRefID().Return(&v1alpha1.Identifier{ Identifier: lpID, }) - mockWfNode.On("GetSubWorkflowRef").Return(nil) + mockWfNode.OnGetSubWorkflowRef().Return(nil) mockNode := &mocks2.ExecutableNode{} - mockNode.On("GetID").Return("n1") - mockNode.On("GetWorkflowNode").Return(mockWfNode) + mockNode.OnGetID().Return("n1") + mockNode.OnGetWorkflowNode().Return(mockWfNode) mockNodeStatus := &mocks2.ExecutableNodeStatus{} - mockNodeStatus.On("GetAttempts").Return(attempts) - mockNodeStatus.On("GetDataDir").Return(dataDir) - - parentID := &core.WorkflowExecutionIdentifier{ - Name: "x", - Domain: "y", - Project: "z", - } - mockWf := &mocks2.ExecutableWorkflow{} - mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) - mockWf.On("GetName").Return("test") - mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: parentID, - }) + mockNodeStatus.OnGetAttempts().Return(attempts) + mockNodeStatus.OnGetDataDir().Return(dataDir) t.Run("abort", func(t *testing.T) { mockLPExec := &mocks.Executor{} - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) h := New(nil, mockLPExec, promutils.NewTestScope()) - mockLPExec.On("Kill", + mockLPExec.OnKillMatch( ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), mock.AnythingOfType(reflect.String.String()), ).Return(nil) + eCtx := &execMocks.ExecutionContext{} + nCtx.OnExecutionContext().Return(eCtx) + eCtx.OnGetName().Return("test") err := h.Abort(ctx, nCtx, "test") assert.NoError(t, err) }) @@ -284,15 +249,19 @@ func TestWorkflowNodeHandler_AbortNode(t *testing.T) { mockLPExec := &mocks.Executor{} expectedErr := fmt.Errorf("fail") h := New(nil, mockLPExec, promutils.NewTestScope()) - mockLPExec.On("Kill", + mockLPExec.OnKillMatch( ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), mock.AnythingOfType(reflect.String.String()), ).Return(expectedErr) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) + eCtx := &execMocks.ExecutionContext{} + nCtx.OnExecutionContext().Return(eCtx) + eCtx.OnGetName().Return("test") + err := h.Abort(ctx, nCtx, "test") assert.Error(t, err) assert.Equal(t, err, expectedErr) diff --git a/pkg/controller/nodes/subworkflow/launchplan.go b/pkg/controller/nodes/subworkflow/launchplan.go index 90c1a98195..172757f2aa 100644 --- a/pkg/controller/nodes/subworkflow/launchplan.go +++ b/pkg/controller/nodes/subworkflow/launchplan.go @@ -25,12 +25,9 @@ func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx handler.No return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.RuntimeExecutionError, errMsg, nil)), nil } - w := nCtx.Workflow() - nodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) childID, err := GetChildWorkflowExecutionID( - w.GetExecutionID().WorkflowExecutionIdentifier, - nCtx.NodeID(), - nodeStatus.GetAttempts(), + nCtx.NodeExecutionMetadata().GetNodeExecutionID(), + nCtx.CurrentAttempt(), ) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.RuntimeExecutionError, "failed to create unique ID", nil)), nil @@ -38,12 +35,9 @@ func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx handler.No launchCtx := launchplan.LaunchContext{ // TODO we need to add principal and nestinglevel as annotations or labels? - Principal: "unknown", - NestingLevel: 0, - ParentNodeExecution: &core.NodeExecutionIdentifier{ - NodeId: nCtx.NodeID(), - ExecutionId: w.GetExecutionID().WorkflowExecutionIdentifier, - }, + Principal: "unknown", + NestingLevel: 0, + ParentNodeExecution: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), } err = l.launchPlan.Launch(ctx, launchCtx, childID, nCtx.Node().GetWorkflowNode().GetLaunchPlanRefID().Identifier, nodeInputs) if err != nil { @@ -68,12 +62,9 @@ func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx handler.No func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { // Handle launch plan - w := nCtx.Workflow() - nodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) childID, err := GetChildWorkflowExecutionID( - w.GetExecutionID().WorkflowExecutionIdentifier, - nCtx.NodeID(), - nodeStatus.GetAttempts(), + nCtx.NodeExecutionMetadata().GetNodeExecutionID(), + nCtx.CurrentAttempt(), ) if err != nil { @@ -123,7 +114,7 @@ func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx hand // nCtx.Node().GetOutputAlias() var oInfo *handler.OutputInfo if wfStatusClosure.GetOutputs() != nil { - outputFile := v1alpha1.GetOutputsFile(nodeStatus.GetOutputDir()) + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) if wfStatusClosure.GetOutputs().GetUri() != "" { uri := wfStatusClosure.GetOutputs().GetUri() store := nCtx.DataStore() @@ -149,16 +140,14 @@ func (l *launchPlanHandler) CheckLaunchPlanStatus(ctx context.Context, nCtx hand return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil } -func (l *launchPlanHandler) HandleAbort(ctx context.Context, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode, reason string) error { - nodeStatus := w.GetNodeExecutionStatus(ctx, node.GetID()) +func (l *launchPlanHandler) HandleAbort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { childID, err := GetChildWorkflowExecutionID( - w.GetExecutionID().WorkflowExecutionIdentifier, - node.GetID(), - nodeStatus.GetAttempts(), + nCtx.NodeExecutionMetadata().GetNodeExecutionID(), + nCtx.CurrentAttempt(), ) if err != nil { // THIS SHOULD NEVER HAPPEN return err } - return l.launchPlan.Kill(ctx, childID, fmt.Sprintf("parent execution id [%s] aborted, reason [%s]", w.GetName(), reason)) + return l.launchPlan.Kill(ctx, childID, fmt.Sprintf("cascading abort as parent execution id [%s] aborted, reason [%s]", nCtx.ExecutionContext().GetName(), reason)) } diff --git a/pkg/controller/nodes/subworkflow/launchplan_test.go b/pkg/controller/nodes/subworkflow/launchplan_test.go index e5ddefd0ca..861518de3f 100644 --- a/pkg/controller/nodes/subworkflow/launchplan_test.go +++ b/pkg/controller/nodes/subworkflow/launchplan_test.go @@ -17,6 +17,7 @@ import ( "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" mocks2 "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + execMocks "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/lyft/flytepropeller/pkg/controller/nodes/subworkflow/launchplan/mocks" @@ -37,7 +38,6 @@ func createInmemoryStore(t testing.TB) *storage.DataStore { func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { ctx := context.TODO() - nodeID := "n1" attempts := uint32(1) lpID := &core.Identifier{ @@ -59,17 +59,6 @@ func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { mockNodeStatus := &mocks2.ExecutableNodeStatus{} mockNodeStatus.On("GetAttempts").Return(attempts) - parentID := &core.WorkflowExecutionIdentifier{ - Name: "x", - Domain: "y", - Project: "z", - } - mockWf := &mocks2.ExecutableWorkflow{} - mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) - mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: parentID, - }) - t.Run("happy", func(t *testing.T) { mockLPExec := &mocks.Executor{} @@ -81,10 +70,10 @@ func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { ctx, mock.MatchedBy(func(o launchplan.LaunchContext) bool { return o.ParentNodeExecution.NodeId == mockNode.GetID() && - o.ParentNodeExecution.ExecutionId == parentID + o.ParentNodeExecution.ExecutionId == wfExecID }), mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), mock.MatchedBy(func(o *core.LiteralMap) bool { return o.Literals == nil }), @@ -93,7 +82,7 @@ func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { wfStatus := &mocks2.MutableWorkflowNodeStatus{} mockNodeStatus.On("GetOrCreateWorkflowStatus").Return(wfStatus) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseUndefined, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseUndefined, mockNode, mockNodeStatus) s, err := h.StartLaunchPlan(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, s.Info().GetPhase(), handler.EPhaseRunning) @@ -110,16 +99,16 @@ func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { ctx, mock.MatchedBy(func(o launchplan.LaunchContext) bool { return o.ParentNodeExecution.NodeId == mockNode.GetID() && - o.ParentNodeExecution.ExecutionId == parentID + o.ParentNodeExecution.ExecutionId == wfExecID }), mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), mock.MatchedBy(func(o *core.LiteralMap) bool { return o.Literals == nil }), ).Return(errors.Wrapf(launchplan.RemoteErrorAlreadyExists, fmt.Errorf("blah"), "failed")) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseUndefined, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseUndefined, mockNode, mockNodeStatus) s, err := h.StartLaunchPlan(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, s.Info().GetPhase(), handler.EPhaseRunning) @@ -136,16 +125,16 @@ func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { ctx, mock.MatchedBy(func(o launchplan.LaunchContext) bool { return o.ParentNodeExecution.NodeId == mockNode.GetID() && - o.ParentNodeExecution.ExecutionId == parentID + o.ParentNodeExecution.ExecutionId == wfExecID }), mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), mock.MatchedBy(func(o *core.LiteralMap) bool { return o.Literals == nil }), ).Return(errors.Wrapf(launchplan.RemoteErrorSystem, fmt.Errorf("blah"), "failed")) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) s, err := h.StartLaunchPlan(ctx, nCtx) assert.Error(t, err) assert.Equal(t, handler.EPhaseUndefined, s.Info().GetPhase()) @@ -162,16 +151,16 @@ func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { ctx, mock.MatchedBy(func(o launchplan.LaunchContext) bool { return o.ParentNodeExecution.NodeId == mockNode.GetID() && - o.ParentNodeExecution.ExecutionId == parentID + o.ParentNodeExecution.ExecutionId == wfExecID }), mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), mock.MatchedBy(func(o *core.Identifier) bool { return lpID == o }), mock.MatchedBy(func(o *core.LiteralMap) bool { return o.Literals == nil }), ).Return(errors.Wrapf(launchplan.RemoteErrorUser, fmt.Errorf("blah"), "failed")) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) s, err := h.StartLaunchPlan(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, handler.EPhaseFailed, s.Info().GetPhase()) @@ -182,7 +171,6 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { ctx := context.TODO() - nodeID := "n1" attempts := uint32(1) dataDir := storage.DataReference("data") @@ -207,17 +195,6 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockNodeStatus.On("GetDataDir").Return(dataDir) mockNodeStatus.On("GetOutputDir").Return(dataDir) - parentID := &core.WorkflowExecutionIdentifier{ - Name: "x", - Domain: "y", - Project: "z", - } - mockWf := &mocks2.ExecutableWorkflow{} - mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) - mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: parentID, - }) - t.Run("stillRunning", func(t *testing.T) { mockLPExec := &mocks.Executor{} @@ -228,13 +205,13 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_RUNNING, }, nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.NoError(t, err) @@ -251,13 +228,13 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_SUCCEEDED, }, nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, handler.EPhaseSuccess, s.Info().GetPhase()) @@ -285,7 +262,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_SUCCEEDED, @@ -298,14 +275,13 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { }, }, nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) nCtx.OnDataStore().Return(mockStore) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) - mockNodeStatus.AssertCalled(t, "GetOutputDir") assert.NoError(t, err) assert.Equal(t, handler.EPhaseSuccess, s.Info().GetPhase()) final := &core.LiteralMap{} - assert.NoError(t, mockStore.ReadProtobuf(ctx, v1alpha1.GetOutputsFile(dataDir), final)) + assert.NoError(t, mockStore.ReadProtobuf(ctx, v1alpha1.GetOutputsFile(dataDir), final), mockStore) v, ok := final.GetLiterals()["x"] assert.True(t, ok) assert.Equal(t, int64(1), v.GetScalar().GetPrimitive().GetInteger()) @@ -327,7 +303,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_SUCCEEDED, @@ -340,7 +316,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { }, }, nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) nCtx.OnDataStore().Return(mockStore) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.NoError(t, err) @@ -363,7 +339,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_FAILED, @@ -375,7 +351,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { }, }, nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, s.Info().GetPhase(), handler.EPhaseFailed) @@ -392,13 +368,13 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_FAILED, }, nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, s.Info().GetPhase(), handler.EPhaseFailed) @@ -415,13 +391,13 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_ABORTED, }, nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, s.Info().GetPhase(), handler.EPhaseFailed) @@ -438,11 +414,11 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(nil, errors.Wrapf(launchplan.RemoteErrorNotFound, fmt.Errorf("some error"), "not found")) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.NoError(t, err) assert.Equal(t, s.Info().GetPhase(), handler.EPhaseFailed) @@ -459,11 +435,11 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(nil, errors.Wrapf(launchplan.RemoteErrorSystem, fmt.Errorf("some error"), "not found")) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.Error(t, err) assert.Equal(t, s.Info().GetPhase(), handler.EPhaseUndefined) @@ -486,7 +462,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_SUCCEEDED, @@ -499,7 +475,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { }, }, nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) nCtx.OnDataStore().Return(mockStore) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.Error(t, err) @@ -519,7 +495,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_SUCCEEDED, @@ -532,7 +508,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { }, }, nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) nCtx.OnDataStore().Return(mockStore) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.NotNil(t, err) @@ -552,7 +528,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { mockLPExec.On("GetStatus", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), ).Return(&admin.ExecutionClosure{ Phase: core.WorkflowExecution_SUCCEEDED, @@ -565,7 +541,7 @@ func TestSubWorkflowHandler_CheckLaunchPlanStatus(t *testing.T) { }, }, nil) - nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockWf, mockNode) + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) nCtx.OnDataStore().Return(mockStore) s, err := h.CheckLaunchPlanStatus(ctx, nCtx) assert.Error(t, err) @@ -601,24 +577,12 @@ func TestLaunchPlanHandler_HandleAbort(t *testing.T) { mockNodeStatus.On("GetAttempts").Return(attempts) mockNodeStatus.On("GetDataDir").Return(dataDir) - parentID := &core.WorkflowExecutionIdentifier{ - Name: "x", - Domain: "y", - Project: "z", - } - mockWf := &mocks2.ExecutableWorkflow{} - mockWf.On("GetName").Return("test") - mockWf.OnGetNodeExecutionStatus(ctx, nodeID).Return(mockNodeStatus) - mockWf.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: parentID, - }) - t.Run("abort-success", func(t *testing.T) { mockLPExec := &mocks.Executor{} mockLPExec.On("Kill", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), mock.AnythingOfType(reflect.String.String()), ).Return(nil) @@ -626,7 +590,11 @@ func TestLaunchPlanHandler_HandleAbort(t *testing.T) { h := launchPlanHandler{ launchPlan: mockLPExec, } - err := h.HandleAbort(ctx, mockWf, mockNode, "some reason") + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) + eCtx := &execMocks.ExecutionContext{} + eCtx.OnGetName().Return("name") + nCtx.OnExecutionContext().Return(eCtx) + err := h.HandleAbort(ctx, nCtx, "some reason") assert.NoError(t, err) }) @@ -636,7 +604,7 @@ func TestLaunchPlanHandler_HandleAbort(t *testing.T) { mockLPExec.On("Kill", ctx, mock.MatchedBy(func(o *core.WorkflowExecutionIdentifier) bool { - return o.Project == parentID.Project && o.Domain == parentID.Domain + return assert.Equal(t, wfExecID.Project, o.Project) && assert.Equal(t, wfExecID.Domain, o.Domain) }), mock.AnythingOfType(reflect.String.String()), ).Return(expectedErr) @@ -644,7 +612,11 @@ func TestLaunchPlanHandler_HandleAbort(t *testing.T) { h := launchPlanHandler{ launchPlan: mockLPExec, } - err := h.HandleAbort(ctx, mockWf, mockNode, "reason") + nCtx := createNodeContext(v1alpha1.WorkflowNodePhaseExecuting, mockNode, mockNodeStatus) + eCtx := &execMocks.ExecutionContext{} + eCtx.OnGetName().Return("name") + nCtx.OnExecutionContext().Return(eCtx) + err := h.HandleAbort(ctx, nCtx, "reason") assert.Error(t, err) assert.Equal(t, err, expectedErr) }) diff --git a/pkg/controller/nodes/subworkflow/subworkflow.go b/pkg/controller/nodes/subworkflow/subworkflow.go index c11635b9cc..082425d919 100644 --- a/pkg/controller/nodes/subworkflow/subworkflow.go +++ b/pkg/controller/nodes/subworkflow/subworkflow.go @@ -12,25 +12,57 @@ import ( "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" ) -// TODO Add unit tests for subworkflow handler - -// Subworkflow handler handles inline subworkflows +// Subworkflow handler handles inline subWorkflows type subworkflowHandler struct { nodeExecutor executors.Node } -func (s *subworkflowHandler) DoInlineSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, w v1alpha1.ExecutableWorkflow, - parentNodeStatus v1alpha1.ExecutableNodeStatus, startNode v1alpha1.ExecutableNode) (handler.Transition, error) { +// Helper method that extracts the SubWorkflow from the ExecutionContext +func GetSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (v1alpha1.ExecutableSubWorkflow, error) { + node := nCtx.Node() + subID := *node.GetWorkflowNode().GetSubWorkflowRef() + subWorkflow := nCtx.ExecutionContext().FindSubWorkflow(subID) + if subWorkflow == nil { + return nil, fmt.Errorf("failed to find sub workflow with ID [%s]", subID) + } + return subWorkflow, nil +} + +// Performs an additional step of passing in and setting the inputs, before handling the execution of a SubWorkflow. +func (s *subworkflowHandler) startAndHandleSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, subWorkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { + // Before starting the subworkflow, lets set the inputs for the Workflow. The inputs for a SubWorkflow are essentially + // Copy of the inputs to the Node + nodeInputs, err := nCtx.InputReader().Get(ctx) + if err != nil { + errMsg := fmt.Sprintf("Failed to read input. Error [%s]", err) + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.RuntimeExecutionError, errMsg, nil)), nil + } + + startStatus, err := s.nodeExecutor.SetInputsForStartNode(ctx, nCtx.ExecutionContext(), subWorkflow, nl, nodeInputs) + if err != nil { + // NOTE: We are implicitly considering an error when setting the inputs as a system error and hence automatically retryable! + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err + } + + if startStatus.HasFailed() { + errorCode, _ := errors.GetErrorCode(startStatus.Err) + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errorCode, startStatus.Err.Error(), nil)), nil + } + return s.handleSubWorkflow(ctx, nCtx, subWorkflow, nl) +} + +// Calls the recursive node executor to handle the SubWorkflow and translates the results after the success +func (s *subworkflowHandler) handleSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { - // TODO we need to handle failing and success nodes - state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, w, startNode) + state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), subworkflow, nl, subworkflow.StartNode()) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err } if state.HasFailed() { - if w.GetOnFailureNode() != nil { - // TODO ssingh: this is supposed to be failing + if subworkflow.GetOnFailureNode() != nil { + // TODO Handle Failure node for subworkflows. We need to add new state to the executor so that, we can continue returning Running, but in the next round start executing DoInFailureHandling - NOTE1 + // https://github.com/lyft/flyte/issues/265 return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, state.Err.Error(), nil)), err } @@ -40,8 +72,8 @@ func (s *subworkflowHandler) DoInlineSubWorkflow(ctx context.Context, nCtx handl if state.IsComplete() { // If the WF interface has outputs, validate that the outputs file was written. var oInfo *handler.OutputInfo - if outputBindings := w.GetOutputBindings(); len(outputBindings) > 0 { - endNodeStatus := w.GetNodeExecutionStatus(ctx, v1alpha1.EndNodeID) + if outputBindings := subworkflow.GetOutputBindings(); len(outputBindings) > 0 { + endNodeStatus := nl.GetNodeExecutionStatus(ctx, v1alpha1.EndNodeID) store := nCtx.DataStore() if endNodeStatus == nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, "No end node found in subworkflow.", nil)), err @@ -58,7 +90,7 @@ func (s *subworkflowHandler) DoInlineSubWorkflow(ctx context.Context, nCtx handl } // TODO optimization, we could just point the outputInfo to the path of the subworkflows output - destinationPath := v1alpha1.GetOutputsFile(parentNodeStatus.GetOutputDir()) + destinationPath := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) if err := store.CopyRaw(ctx, sourcePath, destinationPath, storage.Options{}); err != nil { errMsg := fmt.Sprintf("Failed to copy subworkflow outputs from [%v] to [%v]", sourcePath, destinationPath) return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, errMsg, nil)), nil @@ -80,9 +112,11 @@ func (s *subworkflowHandler) DoInlineSubWorkflow(ctx context.Context, nCtx handl return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(nil)), nil } -func (s *subworkflowHandler) DoInFailureHandling(ctx context.Context, nCtx handler.NodeExecutionContext, w v1alpha1.ExecutableWorkflow) (handler.Transition, error) { - if w.GetOnFailureNode() != nil { - state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, w, w.GetOnFailureNode()) +// TODO related to NOTE1, this is not used currently, but should be used. For this we will need to clean up the state machine in the main handle function +// https://github.com/lyft/flyte/issues/265 +func (s *subworkflowHandler) HandleFailureNodeOfSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext, subworkflow v1alpha1.ExecutableSubWorkflow, nl executors.NodeLookup) (handler.Transition, error) { + if subworkflow.GetOnFailureNode() != nil { + state, err := s.nodeExecutor.RecursiveNodeHandler(ctx, nCtx.ExecutionContext(), subworkflow, nl, subworkflow.GetOnFailureNode()) if err != nil { return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err } @@ -103,93 +137,37 @@ func (s *subworkflowHandler) DoInFailureHandling(ctx context.Context, nCtx handl } func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { - node := nCtx.Node() - subID := *node.GetWorkflowNode().GetSubWorkflowRef() - subWorkflow := nCtx.Workflow().FindSubWorkflow(subID) - if subWorkflow == nil { - errMsg := fmt.Sprintf("No subWorkflow [%s], workflow.", subID) - return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, errMsg, nil)), nil - } - - w := nCtx.Workflow() - status := w.GetNodeExecutionStatus(ctx, node.GetID()) - contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, status) - startNode := contextualSubWorkflow.StartNode() - if startNode == nil { - errMsg := "No start node found in subworkflow." - return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, errMsg, nil)), nil - } - - // Before starting the subworkflow, lets set the inputs for the Workflow. The inputs for a SubWorkflow are essentially - // Copy of the inputs to the Node - nodeInputs, err := nCtx.InputReader().Get(ctx) + subWorkflow, err := GetSubWorkflow(ctx, nCtx) if err != nil { - errMsg := fmt.Sprintf("Failed to read input. Error [%s]", err) - return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.RuntimeExecutionError, errMsg, nil)), nil + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, err.Error(), nil)), nil } - startStatus, err := s.nodeExecutor.SetInputsForStartNode(ctx, contextualSubWorkflow, nodeInputs) - if err != nil { - // TODO we are considering an error when setting inputs are retryable - return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoUndefined), err - } - - if startStatus.HasFailed() { - errorCode, _ := errors.GetErrorCode(startStatus.Err) - return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errorCode, startStatus.Err.Error(), nil)), nil - } + status := nCtx.NodeStatus() + nodeLookup := executors.NewNodeLookup(subWorkflow, status) // assert startStatus.IsComplete() == true - return s.DoInlineSubWorkflow(ctx, nCtx, contextualSubWorkflow, status, startNode) + return s.startAndHandleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } -func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx handler.NodeExecutionContext, w v1alpha1.ExecutableWorkflow, status v1alpha1.ExecutableNodeStatus) (handler.Transition, error) { - // Handle subworkflow - subID := *nCtx.Node().GetWorkflowNode().GetSubWorkflowRef() - subWorkflow := w.FindSubWorkflow(subID) - if subWorkflow == nil { - errMsg := fmt.Sprintf("No subWorkflow [%s], workflow.", subID) - return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, errMsg, nil)), nil - } - - contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, status) - startNode := w.StartNode() - if startNode == nil { - errMsg := "No start node found in subworkflow" - return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, errMsg, nil)), nil +func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { + subWorkflow, err := GetSubWorkflow(ctx, nCtx) + if err != nil { + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, err.Error(), nil)), nil } - parentNodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) - return s.DoInlineSubWorkflow(ctx, nCtx, contextualSubWorkflow, parentNodeStatus, startNode) + status := nCtx.NodeStatus() + nodeLookup := executors.NewNodeLookup(subWorkflow, status) + return s.startAndHandleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } -func (s *subworkflowHandler) HandleSubWorkflowFailingNode(ctx context.Context, nCtx handler.NodeExecutionContext, w v1alpha1.ExecutableWorkflow, node v1alpha1.ExecutableNode) (handler.Transition, error) { - status := w.GetNodeExecutionStatus(ctx, node.GetID()) - subID := *node.GetWorkflowNode().GetSubWorkflowRef() - subWorkflow := w.FindSubWorkflow(subID) - if subWorkflow == nil { - errMsg := fmt.Sprintf("No subWorkflow [%s], workflow.", subID) - return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(errors.SubWorkflowExecutionFailed, errMsg, nil)), nil - } - contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, status) - return s.DoInFailureHandling(ctx, nCtx, contextualSubWorkflow) -} - -func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx handler.NodeExecutionContext, w v1alpha1.ExecutableWorkflow, workflowID v1alpha1.WorkflowID, reason string) error { - subWorkflow := w.FindSubWorkflow(workflowID) - if subWorkflow == nil { - return fmt.Errorf("no sub workflow [%s] found in node [%s]", workflowID, nCtx.NodeID()) - } - - nodeStatus := w.GetNodeExecutionStatus(ctx, nCtx.NodeID()) - contextualSubWorkflow := executors.NewSubContextualWorkflow(w, subWorkflow, nodeStatus) - - startNode := contextualSubWorkflow.StartNode() - if startNode == nil { - return fmt.Errorf("no sub workflow [%s] found in node [%s]", workflowID, nCtx.NodeID()) +func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { + subWorkflow, err := GetSubWorkflow(ctx, nCtx) + if err != nil { + return err } - - return s.nodeExecutor.AbortHandler(ctx, contextualSubWorkflow, startNode, reason) + status := nCtx.NodeStatus() + nodeLookup := executors.NewNodeLookup(subWorkflow, status) + return s.nodeExecutor.AbortHandler(ctx, nCtx.ExecutionContext(), subWorkflow, nodeLookup, subWorkflow.StartNode(), reason) } func newSubworkflowHandler(nodeExecutor executors.Node) subworkflowHandler { diff --git a/pkg/controller/nodes/subworkflow/subworkflow_test.go b/pkg/controller/nodes/subworkflow/subworkflow_test.go index aaab2c511e..3fe7c57d98 100644 --- a/pkg/controller/nodes/subworkflow/subworkflow_test.go +++ b/pkg/controller/nodes/subworkflow/subworkflow_test.go @@ -8,74 +8,148 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" coreMocks "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" execMocks "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" "github.com/lyft/flytepropeller/pkg/controller/nodes/handler/mocks" ) -func Test_subworkflowHandler_HandleAbort(t *testing.T) { +func TestGetSubWorkflow(t *testing.T) { ctx := context.TODO() + t.Run("subworkflow", func(t *testing.T) { + + wfNode := &coreMocks.ExecutableWorkflowNode{} + x := "x" + wfNode.OnGetSubWorkflowRef().Return(&x) + + node := &coreMocks.ExecutableNode{} + node.OnGetWorkflowNode().Return(wfNode) + + ectx := &execMocks.ExecutionContext{} + + swf := &coreMocks.ExecutableSubWorkflow{} + ectx.OnFindSubWorkflow("x").Return(swf) + + nCtx := &mocks.NodeExecutionContext{} + nCtx.OnNode().Return(node) + nCtx.OnExecutionContext().Return(ectx) + + w, err := GetSubWorkflow(ctx, nCtx) + assert.NoError(t, err) + assert.Equal(t, swf, w) + }) + t.Run("missing-subworkflow", func(t *testing.T) { + + wfNode := &coreMocks.ExecutableWorkflowNode{} + x := "x" + wfNode.OnGetSubWorkflowRef().Return(&x) + + node := &coreMocks.ExecutableNode{} + node.OnGetWorkflowNode().Return(wfNode) + + ectx := &execMocks.ExecutionContext{} + + ectx.OnFindSubWorkflow("x").Return(nil) + nCtx := &mocks.NodeExecutionContext{} - nodeExec := &execMocks.Node{} - s := newSubworkflowHandler(nodeExec) - wf := &coreMocks.ExecutableWorkflow{} - wf.OnFindSubWorkflow("x").Return(nil) - nCtx.OnNodeID().Return("n1") - assert.Error(t, s.HandleAbort(ctx, nCtx, wf, "x", "reason")) + nCtx.OnNode().Return(node) + nCtx.OnExecutionContext().Return(ectx) + + _, err := GetSubWorkflow(ctx, nCtx) + assert.Error(t, err) }) +} + +func Test_subworkflowHandler_HandleAbort(t *testing.T) { + ctx := context.TODO() t.Run("missing-startNode", func(t *testing.T) { + + wfNode := &coreMocks.ExecutableWorkflowNode{} + x := "x" + wfNode.OnGetSubWorkflowRef().Return(&x) + + node := &coreMocks.ExecutableNode{} + node.OnGetWorkflowNode().Return(wfNode) + + swf := &coreMocks.ExecutableSubWorkflow{} + ectx := &execMocks.ExecutionContext{} + ectx.OnFindSubWorkflow("x").Return(swf) + + ns := &coreMocks.ExecutableNodeStatus{} nCtx := &mocks.NodeExecutionContext{} + nCtx.OnNode().Return(node) + nCtx.OnExecutionContext().Return(ectx) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnNodeID().Return("n1") + nodeExec := &execMocks.Node{} s := newSubworkflowHandler(nodeExec) - wf := &coreMocks.ExecutableWorkflow{} - st := &coreMocks.ExecutableNodeStatus{} - swf := &coreMocks.ExecutableSubWorkflow{} - wf.OnFindSubWorkflow("x").Return(swf) - wf.OnGetNodeExecutionStatus(ctx, "n1").Return(st) - nCtx.OnNodeID().Return("n1") - swf.OnStartNode().Return(nil) - assert.Error(t, s.HandleAbort(ctx, nCtx, wf, "x", "reason")) + n := &coreMocks.ExecutableNode{} + swf.OnGetID().Return("swf") + nodeExec.OnAbortHandlerMatch(mock.Anything, ectx, swf, mock.Anything, n, "reason").Return(nil) + assert.Panics(t, func() { + _ = s.HandleAbort(ctx, nCtx, "reason") + }) }) t.Run("abort-error", func(t *testing.T) { + wfNode := &coreMocks.ExecutableWorkflowNode{} + x := "x" + wfNode.OnGetSubWorkflowRef().Return(&x) + + node := &coreMocks.ExecutableNode{} + node.OnGetWorkflowNode().Return(wfNode) + + swf := &coreMocks.ExecutableSubWorkflow{} + swf.OnStartNode().Return(&coreMocks.ExecutableNode{}) + + ectx := &execMocks.ExecutionContext{} + ectx.OnFindSubWorkflow("x").Return(swf) + + ns := &coreMocks.ExecutableNodeStatus{} nCtx := &mocks.NodeExecutionContext{} + nCtx.OnNode().Return(node) + nCtx.OnExecutionContext().Return(ectx) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnNodeID().Return("n1") + nodeExec := &execMocks.Node{} s := newSubworkflowHandler(nodeExec) - wf := &coreMocks.ExecutableWorkflow{} - st := &coreMocks.ExecutableNodeStatus{} - swf := &coreMocks.ExecutableSubWorkflow{} - wf.OnFindSubWorkflow("x").Return(swf) - wf.OnGetNodeExecutionStatus(ctx, "n1").Return(st) - nCtx.OnNodeID().Return("n1") n := &coreMocks.ExecutableNode{} - swf.OnStartNode().Return(n) swf.OnGetID().Return("swf") - nodeExec.OnAbortHandlerMatch(mock.Anything, mock.MatchedBy(func(wf v1alpha1.ExecutableWorkflow) bool { - return wf.GetID() == swf.GetID() - }), n, mock.Anything).Return(fmt.Errorf("err")) - assert.Error(t, s.HandleAbort(ctx, nCtx, wf, "x", "reason")) + nodeExec.OnAbortHandlerMatch(mock.Anything, ectx, swf, mock.Anything, n, "reason").Return(fmt.Errorf("err")) + assert.Error(t, s.HandleAbort(ctx, nCtx, "reason")) }) t.Run("abort-success", func(t *testing.T) { + + wfNode := &coreMocks.ExecutableWorkflowNode{} + x := "x" + wfNode.OnGetSubWorkflowRef().Return(&x) + + node := &coreMocks.ExecutableNode{} + node.OnGetWorkflowNode().Return(wfNode) + + swf := &coreMocks.ExecutableSubWorkflow{} + swf.OnStartNode().Return(&coreMocks.ExecutableNode{}) + + ectx := &execMocks.ExecutionContext{} + ectx.OnFindSubWorkflow("x").Return(swf) + + ns := &coreMocks.ExecutableNodeStatus{} nCtx := &mocks.NodeExecutionContext{} + nCtx.OnNode().Return(node) + nCtx.OnExecutionContext().Return(ectx) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnNodeID().Return("n1") + nodeExec := &execMocks.Node{} s := newSubworkflowHandler(nodeExec) - wf := &coreMocks.ExecutableWorkflow{} - st := &coreMocks.ExecutableNodeStatus{} - swf := &coreMocks.ExecutableSubWorkflow{} - wf.OnFindSubWorkflow("x").Return(swf) - wf.OnGetNodeExecutionStatus(ctx, "n1").Return(st) - nCtx.OnNodeID().Return("n1") n := &coreMocks.ExecutableNode{} - swf.OnStartNode().Return(n) swf.OnGetID().Return("swf") - nodeExec.OnAbortHandlerMatch(mock.Anything, mock.MatchedBy(func(wf v1alpha1.ExecutableWorkflow) bool { - return wf.GetID() == swf.GetID() - }), n, mock.Anything).Return(nil) - assert.NoError(t, s.HandleAbort(ctx, nCtx, wf, "x", "reason")) + nodeExec.OnAbortHandlerMatch(mock.Anything, ectx, swf, mock.Anything, n, "reason").Return(nil) + assert.NoError(t, s.HandleAbort(ctx, nCtx, "reason")) }) } diff --git a/pkg/controller/nodes/subworkflow/util.go b/pkg/controller/nodes/subworkflow/util.go index 973a2e0b93..39e33a867c 100644 --- a/pkg/controller/nodes/subworkflow/util.go +++ b/pkg/controller/nodes/subworkflow/util.go @@ -4,21 +4,21 @@ import ( "strconv" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/lyft/flytepropeller/pkg/utils" ) const maxLengthForSubWorkflow = 20 -func GetChildWorkflowExecutionID(parentID *core.WorkflowExecutionIdentifier, id v1alpha1.NodeID, attempt uint32) (*core.WorkflowExecutionIdentifier, error) { - name, err := utils.FixedLengthUniqueIDForParts(maxLengthForSubWorkflow, parentID.Name, id, strconv.Itoa(int(attempt))) +func GetChildWorkflowExecutionID(nodeExecID *core.NodeExecutionIdentifier, attempt uint32) (*core.WorkflowExecutionIdentifier, error) { + name, err := utils.FixedLengthUniqueIDForParts(maxLengthForSubWorkflow, nodeExecID.ExecutionId.Name, nodeExecID.NodeId, strconv.Itoa(int(attempt))) if err != nil { return nil, err } // Restriction on name is 20 chars return &core.WorkflowExecutionIdentifier{ - Project: parentID.Project, - Domain: parentID.Domain, + Project: nodeExecID.ExecutionId.Project, + Domain: nodeExecID.ExecutionId.Domain, Name: name, }, nil } diff --git a/pkg/controller/nodes/subworkflow/util_test.go b/pkg/controller/nodes/subworkflow/util_test.go index a3e126f94b..f65ddb4725 100644 --- a/pkg/controller/nodes/subworkflow/util_test.go +++ b/pkg/controller/nodes/subworkflow/util_test.go @@ -9,11 +9,15 @@ import ( func TestGetChildWorkflowExecutionID(t *testing.T) { id, err := GetChildWorkflowExecutionID( - &core.WorkflowExecutionIdentifier{ - Project: "project", - Domain: "domain", - Name: "first-name-is-pretty-large", - }, "hello-world", 1) + &core.NodeExecutionIdentifier{ + NodeId: "hello-world", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "first-name-is-pretty-large", + }, + }, + 1) assert.Equal(t, id.Name, "fav2uxxi") assert.NoError(t, err) } diff --git a/pkg/controller/nodes/task/handler_test.go b/pkg/controller/nodes/task/handler_test.go index 4fe39d4c0f..60740edfc3 100644 --- a/pkg/controller/nodes/task/handler_test.go +++ b/pkg/controller/nodes/task/handler_test.go @@ -33,7 +33,6 @@ import ( v12 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" - "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" "github.com/lyft/flytepropeller/pkg/controller/executors/mocks" "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" @@ -320,16 +319,19 @@ func Test_task_Handle_NoCatalog(t *testing.T) { Name: "name", } + nodeID := "n1" + nm := &nodeMocks.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + NodeId: nodeID, + ExecutionId: wfExecID, }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v12.OwnerReference{ + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v12.OwnerReference{ Kind: "sample", Name: "name", }) @@ -356,32 +358,39 @@ func Test_task_Handle_NoCatalog(t *testing.T) { } taskID := &core.Identifier{} tr := &nodeMocks.TaskReader{} - tr.On("GetTaskID").Return(taskID) - tr.On("GetTaskType").Return(ttype) - tr.On("Read", mock.Anything).Return(tk, nil) + tr.OnGetTaskID().Return(taskID) + tr.OnGetTaskType().Return(ttype) + tr.OnReadMatch(mock.Anything).Return(tk, nil) ns := &flyteMocks.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetOutputDir").Return(storage.DataReference("data-dir")) + ns.OnGetDataDir().Return("data-dir") + ns.OnGetOutputDir().Return("data-dir") res := &v1.ResourceRequirements{} n := &flyteMocks.ExecutableNode{} - n.On("GetResources").Return(res) + n.OnGetResources().Return(res) ir := &ioMocks.InputReader{} - ir.On("GetInputPath").Return(storage.DataReference("input")) + ir.OnGetInputPath().Return("input") nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("Node").Return(n) - nCtx.On("InputReader").Return(ir) - nCtx.On("DataStore").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("TaskReader").Return(tr) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeStatus").Return(ns) - nCtx.On("NodeID").Return("n1") - nCtx.On("EventsRecorder").Return(recorder) - nCtx.On("EnqueueOwner").Return(nil) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + ds, err := storage.NewDataStore( + &storage.Config{ + Type: storage.TypeMemory, + }, + promutils.NewTestScope(), + ) + assert.NoError(t, err) + nCtx.OnDataStore().Return(ds) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnTaskReader().Return(tr) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnNodeID().Return(nodeID) + nCtx.OnEventsRecorder().Return(recorder) + nCtx.OnEnqueueOwnerFunc().Return(nil) nCtx.OnRawOutputPrefix().Return("s3://sandbox/") nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"})) @@ -390,13 +399,13 @@ func Test_task_Handle_NoCatalog(t *testing.T) { cod := codex.GobStateCodec{} assert.NoError(t, cod.Encode(pluginResp, st)) nr := &nodeMocks.NodeStateReader{} - nr.On("GetTaskNodeState").Return(handler.TaskNodeState{ + nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginState: st.Bytes(), PluginPhase: pluginPhase, PluginPhaseVersion: pluginVer, }) - nCtx.On("NodeStateReader").Return(nr) - nCtx.On("NodeStateWriter").Return(s) + nCtx.OnNodeStateReader().Return(nr) + nCtx.OnNodeStateWriter().Return(s) return nCtx } @@ -624,16 +633,19 @@ func Test_task_Handle_Catalog(t *testing.T) { Name: "name", } + nodeID := "n1" + nm := &nodeMocks.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + NodeId: nodeID, + ExecutionId: wfExecID, }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v12.OwnerReference{ + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v12.OwnerReference{ Kind: "sample", Name: "name", }) @@ -660,32 +672,40 @@ func Test_task_Handle_Catalog(t *testing.T) { }, } tr := &nodeMocks.TaskReader{} - tr.On("GetTaskID").Return(taskID) - tr.On("GetTaskType").Return(ttype) - tr.On("Read", mock.Anything).Return(tk, nil) + tr.OnGetTaskID().Return(taskID) + tr.OnGetTaskType().Return(ttype) + tr.OnReadMatch(mock.Anything).Return(tk, nil) ns := &flyteMocks.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetOutputDir").Return(storage.DataReference("output-dir")) + ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) res := &v1.ResourceRequirements{} n := &flyteMocks.ExecutableNode{} - n.On("GetResources").Return(res) + n.OnGetResources().Return(res) ir := &ioMocks.InputReader{} - ir.On("GetInputPath").Return(storage.DataReference("input")) + ir.OnGetInputPath().Return(storage.DataReference("input")) nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("Node").Return(n) - nCtx.On("InputReader").Return(ir) - nCtx.On("DataStore").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("TaskReader").Return(tr) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeStatus").Return(ns) - nCtx.On("NodeID").Return("n1") - nCtx.On("EventsRecorder").Return(recorder) - nCtx.On("EnqueueOwner").Return(nil) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + nCtx.OnInputReader().Return(ir) + ds, err := storage.NewDataStore( + &storage.Config{ + Type: storage.TypeMemory, + }, + promutils.NewTestScope(), + ) + assert.NoError(t, err) + nCtx.OnDataStore().Return(ds) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnTaskReader().Return(tr) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnNodeID().Return(nodeID) + nCtx.OnEventsRecorder().Return(recorder) + nCtx.OnEnqueueOwnerFunc().Return(nil) nCtx.OnRawOutputPrefix().Return("s3://sandbox/") nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"})) @@ -697,11 +717,11 @@ func Test_task_Handle_Catalog(t *testing.T) { OutputExists: true, }, st)) nr := &nodeMocks.NodeStateReader{} - nr.On("GetTaskNodeState").Return(handler.TaskNodeState{ + nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginState: st.Bytes(), }) - nCtx.On("NodeStateReader").Return(nr) - nCtx.On("NodeStateWriter").Return(s) + nCtx.OnNodeStateReader().Return(nr) + nCtx.OnNodeStateWriter().Return(s) return nCtx } @@ -821,6 +841,7 @@ func Test_task_Handle_Catalog(t *testing.T) { } func Test_task_Handle_Barrier(t *testing.T) { + // NOTE: Caching is disabled for this test createNodeContext := func(recorder events.TaskEventRecorder, ttype string, s *taskNodeStateHolder, prevBarrierClockTick uint32) *nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ @@ -829,16 +850,19 @@ func Test_task_Handle_Barrier(t *testing.T) { Name: "name", } + nodeID := "n1" + nm := &nodeMocks.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + NodeId: nodeID, + ExecutionId: wfExecID, }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v12.OwnerReference{ + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v12.OwnerReference{ Kind: "sample", Name: "name", }) @@ -865,32 +889,39 @@ func Test_task_Handle_Barrier(t *testing.T) { }, } tr := &nodeMocks.TaskReader{} - tr.On("GetTaskID").Return(taskID) - tr.On("GetTaskType").Return(ttype) - tr.On("Read", mock.Anything).Return(tk, nil) + tr.OnGetTaskID().Return(taskID) + tr.OnGetTaskType().Return(ttype) + tr.OnReadMatch(mock.Anything).Return(tk, nil) ns := &flyteMocks.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetOutputDir").Return(storage.DataReference("output-dir")) + ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) res := &v1.ResourceRequirements{} n := &flyteMocks.ExecutableNode{} - n.On("GetResources").Return(res) + n.OnGetResources().Return(res) ir := &ioMocks.InputReader{} - ir.On("GetInputPath").Return(storage.DataReference("input")) + ir.OnGetInputPath().Return(storage.DataReference("input")) nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("Node").Return(n) - nCtx.On("InputReader").Return(ir) - nCtx.On("DataStore").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("TaskReader").Return(tr) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeStatus").Return(ns) - nCtx.On("NodeID").Return("n1") - nCtx.On("EventsRecorder").Return(recorder) - nCtx.On("EnqueueOwner").Return(nil) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + ds, err := storage.NewDataStore( + &storage.Config{ + Type: storage.TypeMemory, + }, + promutils.NewTestScope(), + ) + assert.NoError(t, err) + nCtx.OnDataStore().Return(ds) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnTaskReader().Return(tr) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnNodeID().Return("n1") + nCtx.OnEventsRecorder().Return(recorder) + nCtx.OnEnqueueOwnerFunc().Return(nil) nCtx.OnRawOutputPrefix().Return("s3://sandbox/") nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"})) @@ -902,12 +933,12 @@ func Test_task_Handle_Barrier(t *testing.T) { OutputExists: true, }, st)) nr := &nodeMocks.NodeStateReader{} - nr.On("GetTaskNodeState").Return(handler.TaskNodeState{ + nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginState: st.Bytes(), BarrierClockTick: prevBarrierClockTick, }) - nCtx.On("NodeStateReader").Return(nr) - nCtx.On("NodeStateWriter").Return(s) + nCtx.OnNodeStateReader().Return(nr) + nCtx.OnNodeStateWriter().Return(s) return nCtx } @@ -1107,46 +1138,56 @@ func Test_task_Abort(t *testing.T) { Name: "name", } + nodeID := "n1" + nm := &nodeMocks.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + NodeId: nodeID, + ExecutionId: wfExecID, }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v12.OwnerReference{ + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v12.OwnerReference{ Kind: "sample", Name: "name", }) taskID := &core.Identifier{} tr := &nodeMocks.TaskReader{} - tr.On("GetTaskID").Return(taskID) - tr.On("GetTaskType").Return("x") + tr.OnGetTaskID().Return(taskID) + tr.OnGetTaskType().Return("x") ns := &flyteMocks.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetOutputDir").Return(storage.DataReference("output-dir")) + ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) res := &v1.ResourceRequirements{} n := &flyteMocks.ExecutableNode{} - n.On("GetResources").Return(res) + n.OnGetResources().Return(res) ir := &ioMocks.InputReader{} nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("Node").Return(n) - nCtx.On("InputReader").Return(ir) - nCtx.On("DataStore").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("TaskReader").Return(tr) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeStatus").Return(ns) - nCtx.On("NodeID").Return("n1") - nCtx.On("EnqueueOwner").Return(nil) - nCtx.On("EventsRecorder").Return(ev) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + ds, err := storage.NewDataStore( + &storage.Config{ + Type: storage.TypeMemory, + }, + promutils.NewTestScope(), + ) + assert.NoError(t, err) + nCtx.OnDataStore().Return(ds) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnTaskReader().Return(tr) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnNodeID().Return("n1") + nCtx.OnEnqueueOwnerFunc().Return(nil) + nCtx.OnEventsRecorder().Return(ev) nCtx.OnRawOutputPrefix().Return("s3://sandbox/") nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"})) @@ -1159,10 +1200,10 @@ func Test_task_Abort(t *testing.T) { cod := codex.GobStateCodec{} assert.NoError(t, cod.Encode(test{A: a}, st)) nr := &nodeMocks.NodeStateReader{} - nr.On("GetTaskNodeState").Return(handler.TaskNodeState{ + nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginState: st.Bytes(), }) - nCtx.On("NodeStateReader").Return(nr) + nCtx.OnNodeStateReader().Return(nr) return nCtx } @@ -1231,46 +1272,56 @@ func Test_task_Finalize(t *testing.T) { Name: "name", } + nodeID := "n1" + nm := &nodeMocks.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + NodeId: nodeID, + ExecutionId: wfExecID, }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v12.OwnerReference{ + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v12.OwnerReference{ Kind: "sample", Name: "name", }) taskID := &core.Identifier{} tr := &nodeMocks.TaskReader{} - tr.On("GetTaskID").Return(taskID) - tr.On("GetTaskType").Return("x") + tr.OnGetTaskID().Return(taskID) + tr.OnGetTaskType().Return("x") ns := &flyteMocks.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetOutputDir").Return(storage.DataReference("output-dir")) + ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) res := &v1.ResourceRequirements{} n := &flyteMocks.ExecutableNode{} - n.On("GetResources").Return(res) + n.OnGetResources().Return(res) ir := &ioMocks.InputReader{} nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("Node").Return(n) - nCtx.On("InputReader").Return(ir) - nCtx.On("DataStore").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("TaskReader").Return(tr) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeStatus").Return(ns) - nCtx.On("NodeID").Return("n1") - nCtx.On("EventsRecorder").Return(nil) - nCtx.On("EnqueueOwner").Return(nil) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + ds, err := storage.NewDataStore( + &storage.Config{ + Type: storage.TypeMemory, + }, + promutils.NewTestScope(), + ) + assert.NoError(t, err) + nCtx.OnDataStore().Return(ds) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnTaskReader().Return(tr) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnNodeID().Return("n1") + nCtx.OnEventsRecorder().Return(nil) + nCtx.OnEnqueueOwnerFunc().Return(nil) nCtx.OnRawOutputPrefix().Return("s3://sandbox/") nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"})) diff --git a/pkg/controller/nodes/task/taskexec_context_test.go b/pkg/controller/nodes/task/taskexec_context_test.go index 356bdfe134..1e03305b19 100644 --- a/pkg/controller/nodes/task/taskexec_context_test.go +++ b/pkg/controller/nodes/task/taskexec_context_test.go @@ -16,7 +16,6 @@ import ( v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" - "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" flyteMocks "github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" "github.com/lyft/flytepropeller/pkg/controller/nodes/handler" nodeMocks "github.com/lyft/flytepropeller/pkg/controller/nodes/handler/mocks" @@ -31,45 +30,56 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { Name: "name", } + nodeID := "n1" + nm := &nodeMocks.NodeExecutionMetadata{} - nm.On("GetAnnotations").Return(map[string]string{}) - nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ - WorkflowExecutionIdentifier: wfExecID, + nm.OnGetAnnotations().Return(map[string]string{}) + nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ + NodeId: nodeID, + ExecutionId: wfExecID, }) - nm.On("GetK8sServiceAccount").Return("service-account") - nm.On("GetLabels").Return(map[string]string{}) - nm.On("GetNamespace").Return("namespace") - nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.On("GetOwnerReference").Return(v1.OwnerReference{ + nm.OnGetK8sServiceAccount().Return("service-account") + nm.OnGetLabels().Return(map[string]string{}) + nm.OnGetNamespace().Return("namespace") + nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.OnGetOwnerReference().Return(v1.OwnerReference{ Kind: "sample", Name: "name", }) taskID := &core.Identifier{} tr := &nodeMocks.TaskReader{} - tr.On("GetTaskID").Return(taskID) + tr.OnGetTaskID().Return(taskID) ns := &flyteMocks.ExecutableNodeStatus{} - ns.On("GetDataDir").Return(storage.DataReference("data-dir")) - ns.On("GetOutputDir").Return(storage.DataReference("output-dir")) + ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) res := &v12.ResourceRequirements{} n := &flyteMocks.ExecutableNode{} - n.On("GetResources").Return(res) + n.OnGetResources().Return(res) ir := &ioMocks.InputReader{} nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.On("NodeExecutionMetadata").Return(nm) - nCtx.On("Node").Return(n) - nCtx.On("InputReader").Return(ir) - nCtx.On("DataStore").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) - nCtx.On("CurrentAttempt").Return(uint32(1)) - nCtx.On("TaskReader").Return(tr) - nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) - nCtx.On("NodeStatus").Return(ns) - nCtx.On("NodeID").Return("n1") - nCtx.On("EventsRecorder").Return(nil) - nCtx.On("EnqueueOwner").Return(nil) + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnInputReader().Return(ir) + nCtx.OnCurrentAttempt().Return(uint32(1)) + nCtx.OnTaskReader().Return(tr) + nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnNodeID().Return(nodeID) + nCtx.OnEventsRecorder().Return(nil) + nCtx.OnEnqueueOwnerFunc().Return(nil) + + ds, err := storage.NewDataStore( + &storage.Config{ + Type: storage.TypeMemory, + }, + promutils.NewTestScope(), + ) + assert.NoError(t, err) + nCtx.OnDataStore().Return(ds) st := bytes.NewBuffer([]byte{}) a := 45 @@ -79,10 +89,10 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { codex := codex.GobStateCodec{} assert.NoError(t, codex.Encode(test{A: a}, st)) nr := &nodeMocks.NodeStateReader{} - nr.On("GetTaskNodeState").Return(handler.TaskNodeState{ + nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ PluginState: st.Bytes(), }) - nCtx.On("NodeStateReader").Return(nr) + nCtx.OnNodeStateReader().Return(nr) nCtx.OnRawOutputPrefix().Return("s3://sandbox/") nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"})) @@ -124,7 +134,7 @@ func TestHandler_newTaskExecutionContext(t *testing.T) { assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), "name-n1-1") assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetID().TaskId, taskID) assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt, uint32(1)) - assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetID().NodeExecutionId.GetNodeId(), "n1") + assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetID().NodeExecutionId.GetNodeId(), nodeID) assert.Equal(t, got.TaskExecutionMetadata().GetTaskExecutionID().GetID().NodeExecutionId.GetExecutionId(), wfExecID) // TODO @kumare fix this test diff --git a/pkg/controller/nodes/task/transformer.go b/pkg/controller/nodes/task/transformer.go index def576d0e3..4dcccb45bd 100644 --- a/pkg/controller/nodes/task/transformer.go +++ b/pkg/controller/nodes/task/transformer.go @@ -87,11 +87,8 @@ func ToTaskExecutionEvent(taskExecID *core.TaskExecutionIdentifier, in io.InputF func GetTaskExecutionIdentifier(nCtx handler.NodeExecutionContext) *core.TaskExecutionIdentifier { return &core.TaskExecutionIdentifier{ - TaskId: nCtx.TaskReader().GetTaskID(), - RetryAttempt: nCtx.CurrentAttempt(), - NodeExecutionId: &core.NodeExecutionIdentifier{ - NodeId: nCtx.NodeID(), - ExecutionId: nCtx.NodeExecutionMetadata().GetExecutionID().WorkflowExecutionIdentifier, - }, + TaskId: nCtx.TaskReader().GetTaskID(), + RetryAttempt: nCtx.CurrentAttempt(), + NodeExecutionId: nCtx.NodeExecutionMetadata().GetNodeExecutionID(), } } diff --git a/pkg/controller/workflow/executor.go b/pkg/controller/workflow/executor.go index 6be11f20c9..3c82d68635 100644 --- a/pkg/controller/workflow/executor.go +++ b/pkg/controller/workflow/executor.go @@ -103,7 +103,7 @@ func (c *workflowExecutor) handleReadyWorkflow(ctx context.Context, w *v1alpha1. logger.Infof(ctx, "Setting the MetadataDir for StartNode [%v]", dataDir) nodeStatus.SetDataDir(dataDir) nodeStatus.SetOutputDir(outputDir) - s, err := c.nodeExecutor.SetInputsForStartNode(ctx, w, inputs) + s, err := c.nodeExecutor.SetInputsForStartNode(ctx, w, w, executors.NewNodeLookup(w, w.GetExecutionStatus()), inputs) if err != nil { return StatusReady, err } @@ -115,12 +115,11 @@ func (c *workflowExecutor) handleReadyWorkflow(ctx context.Context, w *v1alpha1. } func (c *workflowExecutor) handleRunningWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) (Status, error) { - contextualWf := executors.NewBaseContextualWorkflow(w) - startNode := contextualWf.StartNode() + startNode := w.StartNode() if startNode == nil { return StatusFailed(errors.Errorf(errors.IllegalStateError, w.GetID(), "StartNode not found in running workflow?")), nil } - state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, contextualWf, startNode) + state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, w, w, w, startNode) if err != nil { return StatusRunning, err } @@ -136,21 +135,20 @@ func (c *workflowExecutor) handleRunningWorkflow(ctx context.Context, w *v1alpha return StatusSucceeding, nil } if state.PartiallyComplete() { - c.enqueueWorkflow(contextualWf.GetK8sWorkflowID().String()) + c.enqueueWorkflow(w.GetK8sWorkflowID().String()) } return StatusRunning, nil } func (c *workflowExecutor) handleFailingWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) (Status, error) { - contextualWf := executors.NewBaseContextualWorkflow(w) // Best effort clean-up. - if err := c.cleanupRunningNodes(ctx, contextualWf, "Some node execution failed, auto-abort."); err != nil { + if err := c.cleanupRunningNodes(ctx, w, "Some node execution failed, auto-abort."); err != nil { logger.Errorf(ctx, "Failed to propagate Abort for workflow:%v. Error: %v", w.ExecutionID.WorkflowExecutionIdentifier, err) } - errorNode := contextualWf.GetOnFailureNode() + errorNode := w.GetOnFailureNode() if errorNode != nil { - state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, contextualWf, errorNode) + state, err := c.nodeExecutor.RecursiveNodeHandler(ctx, w, w, w, errorNode) if err != nil { return StatusFailing(nil), err } @@ -163,12 +161,12 @@ func (c *workflowExecutor) handleFailingWorkflow(ctx context.Context, w *v1alpha } if state.PartiallyComplete() { // Re-enqueue the workflow - c.enqueueWorkflow(contextualWf.GetK8sWorkflowID().String()) + c.enqueueWorkflow(w.GetK8sWorkflowID().String()) return StatusFailing(nil), nil } // Fallthrough to handle state is complete } - return StatusFailed(errors.Errorf(errors.CausedByError, w.ID, contextualWf.GetExecutionStatus().GetMessage())), nil + return StatusFailed(errors.Errorf(errors.CausedByError, w.ID, w.GetExecutionStatus().GetMessage())), nil } func (c *workflowExecutor) handleSucceedingWorkflow(ctx context.Context, w *v1alpha1.FlyteWorkflow) Status { @@ -371,8 +369,7 @@ func (c *workflowExecutor) HandleAbortedWorkflow(ctx context.Context, w *v1alpha } // Best effort clean-up. - contextualWf := executors.NewBaseContextualWorkflow(w) - if err2 := c.cleanupRunningNodes(ctx, contextualWf, reason); err2 != nil { + if err2 := c.cleanupRunningNodes(ctx, w, reason); err2 != nil { logger.Errorf(ctx, "Failed to propagate Abort for workflow:%v. Error: %v", w.ExecutionID.WorkflowExecutionIdentifier, err2) } @@ -400,7 +397,7 @@ func (c *workflowExecutor) cleanupRunningNodes(ctx context.Context, w v1alpha1.E return errors.Errorf(errors.IllegalStateError, w.GetID(), "StartNode not found in running workflow?") } - if err := c.nodeExecutor.AbortHandler(ctx, w, startNode, reason); err != nil { + if err := c.nodeExecutor.AbortHandler(ctx, w, w, w, startNode, reason); err != nil { return errors.Errorf(errors.CausedByError, w.GetID(), "Failed to propagate Abort for workflow. Error: %v", err) } diff --git a/pkg/visualize/sort.go b/pkg/visualize/sort.go index 7b4f62559b..b8aded43dc 100644 --- a/pkg/visualize/sort.go +++ b/pkg/visualize/sort.go @@ -23,7 +23,7 @@ func NewNodeVisitor(nodes []v1alpha1.NodeID) NodeVisitor { return v } -func tsortHelper(g v1alpha1.ExecutableWorkflow, currentNode v1alpha1.ExecutableNode, visited NodeVisitor, reverseSortedNodes *[]v1alpha1.ExecutableNode) error { +func tsortHelper(g v1alpha1.ExecutableSubWorkflow, currentNode v1alpha1.ExecutableNode, visited NodeVisitor, reverseSortedNodes *[]v1alpha1.ExecutableNode) error { if visited[currentNode.GetID()] == NotVisited { visited[currentNode.GetID()] = Visited defer func() { @@ -62,7 +62,7 @@ func reverseSlice(sl []v1alpha1.ExecutableNode) []v1alpha1.ExecutableNode { return sl } -func TopologicalSort(g v1alpha1.ExecutableWorkflow) ([]v1alpha1.ExecutableNode, error) { +func TopologicalSort(g v1alpha1.ExecutableSubWorkflow) ([]v1alpha1.ExecutableNode, error) { reverseSortedNodes := make([]v1alpha1.ExecutableNode, 0, 25) visited := NewNodeVisitor(g.GetNodes()) if err := tsortHelper(g, g.StartNode(), visited, &reverseSortedNodes); err != nil {