From bf00d9ff1e8f6a14722f1eafdedfb2841cd62496 Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Thu, 1 Dec 2022 12:31:29 -0600 Subject: [PATCH] Add support for GateNode with signal and sleep condition (#436) * Update flyteidl version Signed-off-by: Flyte-Bot * Update flyteidl version Signed-off-by: Flyte-Bot * Fix build break Signed-off-by: Haytham Abuelfutuh * Update flyteidl version Signed-off-by: Flyte-Bot * added GateNode to compiler Signed-off-by: Daniel Rammer * added gate node handler Signed-off-by: Daniel Rammer * enable reading and setting gate node state Signed-off-by: Daniel Rammer * gate nodes working Signed-off-by: Daniel Rammer * changed Conditional to Condition in proto naming Signed-off-by: Daniel Rammer * passing admin client to gate node handler Signed-off-by: Daniel Rammer * using signal service client to check for signal in admin and write output Signed-off-by: Daniel Rammer * updated comments Signed-off-by: Daniel Rammer * completed implementation Signed-off-by: Daniel Rammer * added unit tests for gate node Signed-off-by: Daniel Rammer * fixed unit tests with missing signal mocks Signed-off-by: Daniel Rammer * added docs on gate node handler Signed-off-by: Daniel Rammer * fixed lint issues Signed-off-by: Daniel Rammer * updating flyteidl dependency Signed-off-by: Daniel Rammer * fixed lint issue Signed-off-by: Daniel Rammer * added output variable name to signal condition Signed-off-by: Daniel Rammer * using last attempt started at timestamp on node context rather than tracking in gate node status Signed-off-by: Daniel Rammer * updated GateNodeStatus mocks Signed-off-by: Daniel Rammer * fixed lint issue Signed-off-by: Daniel Rammer * fixed unit tests Signed-off-by: Daniel Rammer * updated flyteidl deps Signed-off-by: Daniel Rammer * update flyteidl deps Signed-off-by: Daniel Rammer * added interface validation for approve condition Signed-off-by: Daniel Rammer * added approve condition unit tests Signed-off-by: Daniel Rammer * fixed missed merge conflict updating to slice of dial options Signed-off-by: Daniel Rammer * update generated mocks Signed-off-by: Dan Rammer Signed-off-by: Flyte-Bot Signed-off-by: Haytham Abuelfutuh Signed-off-by: Daniel Rammer Signed-off-by: Dan Rammer Co-authored-by: flyte-bot Co-authored-by: Haytham Abuelfutuh --- .../cmd/kubectl-flyte/cmd/create.go | 6 +- .../testdata/gate-node-approve.yaml.golden | 20 ++ .../cmd/testdata/gate-node-signal.yaml.golden | 16 + .../cmd/testdata/gate-node-sleep.yaml.golden | 13 + .../cmd/testdata/launchplan.yaml.golden | 10 + .../pkg/apis/flyteworkflow/v1alpha1/gate.go | 106 ++++++ .../pkg/apis/flyteworkflow/v1alpha1/iface.go | 23 ++ .../v1alpha1/mocks/ExecutableGateNode.go | 149 +++++++++ .../mocks/ExecutableGateNodeStatus.go | 45 +++ .../v1alpha1/mocks/ExecutableNode.go | 34 ++ .../v1alpha1/mocks/ExecutableNodeStatus.go | 73 +++++ .../v1alpha1/mocks/MutableGateNodeStatus.go | 82 +++++ .../v1alpha1/mocks/MutableNodeStatus.go | 73 +++++ .../flyteworkflow/v1alpha1/node_status.go | 47 +++ .../pkg/apis/flyteworkflow/v1alpha1/nodes.go | 8 + .../apis/flyteworkflow/v1alpha1/workflow.go | 2 +- .../pkg/compiler/common/mocks/node.go | 34 ++ .../pkg/compiler/common/mocks/node_builder.go | 34 ++ flytepropeller/pkg/compiler/common/reader.go | 1 + .../pkg/compiler/errors/compiler_errors.go | 11 + .../pkg/compiler/transformers/k8s/node.go | 26 ++ .../compiler/transformers/k8s/node_test.go | 57 +++- .../pkg/compiler/validators/interface.go | 30 ++ .../pkg/compiler/validators/interface_test.go | 97 ++++++ flytepropeller/pkg/controller/controller.go | 10 +- .../controller/nodes/branch/handler_test.go | 4 + .../controller/nodes/dynamic/handler_test.go | 4 + .../pkg/controller/nodes/executor.go | 5 +- .../pkg/controller/nodes/executor_test.go | 89 ++--- .../pkg/controller/nodes/gate/handler.go | 231 +++++++++++++ .../pkg/controller/nodes/gate/handler_test.go | 305 ++++++++++++++++++ .../nodes/gate/mocks/signal_service_client.go | 162 ++++++++++ .../nodes/handler/mocks/node_state_reader.go | 32 ++ .../nodes/handler/mocks/node_state_writer.go | 32 ++ .../pkg/controller/nodes/handler/state.go | 7 + .../nodes/handler/transition_info.go | 4 + .../pkg/controller/nodes/handler_factory.go | 6 +- .../controller/nodes/node_state_manager.go | 16 + .../nodes/subworkflow/handler_test.go | 4 + .../pkg/controller/nodes/task/handler_test.go | 4 + .../pkg/controller/nodes/transformers.go | 6 + .../pkg/controller/workflow/executor_test.go | 26 +- 42 files changed, 1876 insertions(+), 68 deletions(-) create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-approve.yaml.golden create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-signal.yaml.golden create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-sleep.yaml.golden create mode 100755 flytepropeller/cmd/kubectl-flyte/cmd/testdata/launchplan.yaml.golden create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableGateNode.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableGateNodeStatus.go create mode 100644 flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableGateNodeStatus.go create mode 100644 flytepropeller/pkg/controller/nodes/gate/handler.go create mode 100644 flytepropeller/pkg/controller/nodes/gate/handler_test.go create mode 100644 flytepropeller/pkg/controller/nodes/gate/mocks/signal_service_client.go diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/create.go b/flytepropeller/cmd/kubectl-flyte/cmd/create.go index 286596419d..91ea38255c 100644 --- a/flytepropeller/cmd/kubectl-flyte/cmd/create.go +++ b/flytepropeller/cmd/kubectl-flyte/cmd/create.go @@ -14,6 +14,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/compiler" "github.com/flyteorg/flytepropeller/pkg/compiler/common" compilerErrors "github.com/flyteorg/flytepropeller/pkg/compiler/errors" @@ -200,6 +201,9 @@ func (c *CreateOpts) createWorkflowFromProto() error { if err != nil { return err } + flyteWf.ExecutionID = v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: executionID, + } if flyteWf.Annotations == nil { flyteWf.Annotations = *c.annotations.value } else { @@ -209,7 +213,7 @@ func (c *CreateOpts) createWorkflowFromProto() error { } if c.dryRun { - fmt.Printf("Dry Run mode enabled. Printing the compiled workflow.") + fmt.Printf("Dry Run mode enabled. Printing the compiled workflow.\n") j, err := json.Marshal(flyteWf) if err != nil { return errors.Wrapf(err, "Failed to marshal final workflow to Propeller format.") diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-approve.yaml.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-approve.yaml.golden new file mode 100755 index 0000000000..1f049bff84 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-approve.yaml.golden @@ -0,0 +1,20 @@ +workflow: + id: + name: workflow-id-123 + domain: development + project: flytesnacks + interface: + inputs: + variables: + x: + type: + simple: INTEGER + "y": + type: + collectionType: + simple: STRING + nodes: + - id: node-1 + gateNode: + approve: + signalId: foo diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-signal.yaml.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-signal.yaml.golden new file mode 100755 index 0000000000..159626d346 --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-signal.yaml.golden @@ -0,0 +1,16 @@ +workflow: + id: + name: workflow-id-123 + domain: development + project: flytesnacks + interface: + inputs: + variables: {} + nodes: + - id: node-1 + gateNode: + signal: + signalId: foo + type: + simple: BOOLEAN + outputVariableName: o0 diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-sleep.yaml.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-sleep.yaml.golden new file mode 100755 index 0000000000..bf0724eb4f --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/gate-node-sleep.yaml.golden @@ -0,0 +1,13 @@ +workflow: + id: + name: workflow-id-123 + domain: development + project: flytesnacks + interface: + inputs: + variables: {} + nodes: + - id: node-1 + gateNode: + sleep: + duration: 10s diff --git a/flytepropeller/cmd/kubectl-flyte/cmd/testdata/launchplan.yaml.golden b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/launchplan.yaml.golden new file mode 100755 index 0000000000..b78eefcadd --- /dev/null +++ b/flytepropeller/cmd/kubectl-flyte/cmd/testdata/launchplan.yaml.golden @@ -0,0 +1,10 @@ +workflow: + id: + name: missing-launchplan + nodes: + - id: node-1 + workflowNode: + launchplanRef: + project: foo + domain: bar + name: baz diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate.go new file mode 100644 index 0000000000..621cf94aed --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/gate.go @@ -0,0 +1,106 @@ +package v1alpha1 + +import ( + "bytes" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/golang/protobuf/jsonpb" +) + +type ConditionKind string + +func (n ConditionKind) String() string { + return string(n) +} + +const ( + ConditionKindApprove ConditionKind = "approve" + ConditionKindSignal ConditionKind = "signal" + ConditionKindSleep ConditionKind = "sleep" +) + +type ApproveCondition struct { + *core.ApproveCondition +} + +func (in ApproveCondition) MarshalJSON() ([]byte, error) { + if in.ApproveCondition == nil { + return nilJSON, nil + } + + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.ApproveCondition); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (in *ApproveCondition) UnmarshalJSON(b []byte) error { + in.ApproveCondition = &core.ApproveCondition{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.ApproveCondition) +} + +type SignalCondition struct { + *core.SignalCondition +} + +func (in SignalCondition) MarshalJSON() ([]byte, error) { + if in.SignalCondition == nil { + return nilJSON, nil + } + + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.SignalCondition); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (in *SignalCondition) UnmarshalJSON(b []byte) error { + in.SignalCondition = &core.SignalCondition{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.SignalCondition) +} + +type SleepCondition struct { + *core.SleepCondition +} + +func (in SleepCondition) MarshalJSON() ([]byte, error) { + if in.SleepCondition == nil { + return nilJSON, nil + } + + var buf bytes.Buffer + if err := marshaler.Marshal(&buf, in.SleepCondition); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func (in *SleepCondition) UnmarshalJSON(b []byte) error { + in.SleepCondition = &core.SleepCondition{} + return jsonpb.Unmarshal(bytes.NewReader(b), in.SleepCondition) +} + +type GateNodeSpec struct { + Kind ConditionKind `json:"kind"` + Approve *ApproveCondition `json:"approve,omitempty"` + Signal *SignalCondition `json:"signal,omitempty"` + Sleep *SleepCondition `json:"sleep,omitempty"` +} + +func (g *GateNodeSpec) GetKind() ConditionKind { + return g.Kind +} + +func (g *GateNodeSpec) GetApprove() *core.ApproveCondition { + return g.Approve.ApproveCondition +} + +func (g *GateNodeSpec) GetSignal() *core.SignalCondition { + return g.Signal.SignalCondition +} + +func (g *GateNodeSpec) GetSleep() *core.SleepCondition { + return g.Sleep.SleepCondition +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go index 65d46bca2b..c52361b239 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/iface.go @@ -44,6 +44,7 @@ const ( NodeKindTask NodeKind = "task" NodeKindBranch NodeKind = "branch" // A Branch node with conditions NodeKindWorkflow NodeKind = "workflow" // Either an inline workflow or a remote workflow definition + NodeKindGate NodeKind = "gate" // A Gate node with a condition NodeKindStart NodeKind = "start" // Start node is a special node NodeKindEnd NodeKind = "end" ) @@ -245,6 +246,13 @@ type ExecutableBranchNode interface { GetElseFail() *core.Error } +type ExecutableGateNode interface { + GetKind() ConditionKind + GetApprove() *core.ApproveCondition + GetSignal() *core.SignalCondition + GetSleep() *core.SleepCondition +} + type ExecutableWorkflowNodeStatus interface { GetWorkflowNodePhase() WorkflowNodePhase GetExecutionError() *core.ExecutionError @@ -257,6 +265,16 @@ type MutableWorkflowNodeStatus interface { SetExecutionError(executionError *core.ExecutionError) } +type ExecutableGateNodeStatus interface { + GetGateNodePhase() GateNodePhase +} + +type MutableGateNodeStatus interface { + Mutable + ExecutableGateNodeStatus + SetGateNodePhase(phase GateNodePhase) +} + type Mutable interface { IsDirty() bool } @@ -288,6 +306,10 @@ type MutableNodeStatus interface { ClearDynamicNodeStatus() ClearLastAttemptStartedAt() ClearSubNodeStatus() + + GetGateNodeStatus() MutableGateNodeStatus + GetOrCreateGateNodeStatus() MutableGateNodeStatus + ClearGateNodeStatus() } type ExecutionTimeInfo interface { @@ -370,6 +392,7 @@ type ExecutableNode interface { GetTaskID() *TaskID GetBranchNode() ExecutableBranchNode GetWorkflowNode() ExecutableWorkflowNode + GetGateNode() ExecutableGateNode GetOutputAlias() []Alias GetInputBindings() []*Binding GetResources() *v1.ResourceRequirements diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableGateNode.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableGateNode.go new file mode 100644 index 0000000000..c7889632b1 --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableGateNode.go @@ -0,0 +1,149 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + mock "github.com/stretchr/testify/mock" + + v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" +) + +// ExecutableGateNode is an autogenerated mock type for the ExecutableGateNode type +type ExecutableGateNode struct { + mock.Mock +} + +type ExecutableGateNode_GetApprove struct { + *mock.Call +} + +func (_m ExecutableGateNode_GetApprove) Return(_a0 *core.ApproveCondition) *ExecutableGateNode_GetApprove { + return &ExecutableGateNode_GetApprove{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableGateNode) OnGetApprove() *ExecutableGateNode_GetApprove { + c_call := _m.On("GetApprove") + return &ExecutableGateNode_GetApprove{Call: c_call} +} + +func (_m *ExecutableGateNode) OnGetApproveMatch(matchers ...interface{}) *ExecutableGateNode_GetApprove { + c_call := _m.On("GetApprove", matchers...) + return &ExecutableGateNode_GetApprove{Call: c_call} +} + +// GetApprove provides a mock function with given fields: +func (_m *ExecutableGateNode) GetApprove() *core.ApproveCondition { + ret := _m.Called() + + var r0 *core.ApproveCondition + if rf, ok := ret.Get(0).(func() *core.ApproveCondition); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.ApproveCondition) + } + } + + return r0 +} + +type ExecutableGateNode_GetKind struct { + *mock.Call +} + +func (_m ExecutableGateNode_GetKind) Return(_a0 v1alpha1.ConditionKind) *ExecutableGateNode_GetKind { + return &ExecutableGateNode_GetKind{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableGateNode) OnGetKind() *ExecutableGateNode_GetKind { + c_call := _m.On("GetKind") + return &ExecutableGateNode_GetKind{Call: c_call} +} + +func (_m *ExecutableGateNode) OnGetKindMatch(matchers ...interface{}) *ExecutableGateNode_GetKind { + c_call := _m.On("GetKind", matchers...) + return &ExecutableGateNode_GetKind{Call: c_call} +} + +// GetKind provides a mock function with given fields: +func (_m *ExecutableGateNode) GetKind() v1alpha1.ConditionKind { + ret := _m.Called() + + var r0 v1alpha1.ConditionKind + if rf, ok := ret.Get(0).(func() v1alpha1.ConditionKind); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.ConditionKind) + } + + return r0 +} + +type ExecutableGateNode_GetSignal struct { + *mock.Call +} + +func (_m ExecutableGateNode_GetSignal) Return(_a0 *core.SignalCondition) *ExecutableGateNode_GetSignal { + return &ExecutableGateNode_GetSignal{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableGateNode) OnGetSignal() *ExecutableGateNode_GetSignal { + c_call := _m.On("GetSignal") + return &ExecutableGateNode_GetSignal{Call: c_call} +} + +func (_m *ExecutableGateNode) OnGetSignalMatch(matchers ...interface{}) *ExecutableGateNode_GetSignal { + c_call := _m.On("GetSignal", matchers...) + return &ExecutableGateNode_GetSignal{Call: c_call} +} + +// GetSignal provides a mock function with given fields: +func (_m *ExecutableGateNode) GetSignal() *core.SignalCondition { + ret := _m.Called() + + var r0 *core.SignalCondition + if rf, ok := ret.Get(0).(func() *core.SignalCondition); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.SignalCondition) + } + } + + return r0 +} + +type ExecutableGateNode_GetSleep struct { + *mock.Call +} + +func (_m ExecutableGateNode_GetSleep) Return(_a0 *core.SleepCondition) *ExecutableGateNode_GetSleep { + return &ExecutableGateNode_GetSleep{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableGateNode) OnGetSleep() *ExecutableGateNode_GetSleep { + c_call := _m.On("GetSleep") + return &ExecutableGateNode_GetSleep{Call: c_call} +} + +func (_m *ExecutableGateNode) OnGetSleepMatch(matchers ...interface{}) *ExecutableGateNode_GetSleep { + c_call := _m.On("GetSleep", matchers...) + return &ExecutableGateNode_GetSleep{Call: c_call} +} + +// GetSleep provides a mock function with given fields: +func (_m *ExecutableGateNode) GetSleep() *core.SleepCondition { + ret := _m.Called() + + var r0 *core.SleepCondition + if rf, ok := ret.Get(0).(func() *core.SleepCondition); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.SleepCondition) + } + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableGateNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableGateNodeStatus.go new file mode 100644 index 0000000000..ba9ac5e69f --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableGateNodeStatus.go @@ -0,0 +1,45 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mock "github.com/stretchr/testify/mock" +) + +// ExecutableGateNodeStatus is an autogenerated mock type for the ExecutableGateNodeStatus type +type ExecutableGateNodeStatus struct { + mock.Mock +} + +type ExecutableGateNodeStatus_GetGateNodePhase struct { + *mock.Call +} + +func (_m ExecutableGateNodeStatus_GetGateNodePhase) Return(_a0 v1alpha1.GateNodePhase) *ExecutableGateNodeStatus_GetGateNodePhase { + return &ExecutableGateNodeStatus_GetGateNodePhase{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableGateNodeStatus) OnGetGateNodePhase() *ExecutableGateNodeStatus_GetGateNodePhase { + c_call := _m.On("GetGateNodePhase") + return &ExecutableGateNodeStatus_GetGateNodePhase{Call: c_call} +} + +func (_m *ExecutableGateNodeStatus) OnGetGateNodePhaseMatch(matchers ...interface{}) *ExecutableGateNodeStatus_GetGateNodePhase { + c_call := _m.On("GetGateNodePhase", matchers...) + return &ExecutableGateNodeStatus_GetGateNodePhase{Call: c_call} +} + +// GetGateNodePhase provides a mock function with given fields: +func (_m *ExecutableGateNodeStatus) GetGateNodePhase() v1alpha1.GateNodePhase { + ret := _m.Called() + + var r0 v1alpha1.GateNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.GateNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.GateNodePhase) + } + + return r0 +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go index 883e8eca66..5fbd946fae 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNode.go @@ -153,6 +153,40 @@ func (_m *ExecutableNode) GetExecutionDeadline() *time.Duration { return r0 } +type ExecutableNode_GetGateNode struct { + *mock.Call +} + +func (_m ExecutableNode_GetGateNode) Return(_a0 v1alpha1.ExecutableGateNode) *ExecutableNode_GetGateNode { + return &ExecutableNode_GetGateNode{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableNode) OnGetGateNode() *ExecutableNode_GetGateNode { + c_call := _m.On("GetGateNode") + return &ExecutableNode_GetGateNode{Call: c_call} +} + +func (_m *ExecutableNode) OnGetGateNodeMatch(matchers ...interface{}) *ExecutableNode_GetGateNode { + c_call := _m.On("GetGateNode", matchers...) + return &ExecutableNode_GetGateNode{Call: c_call} +} + +// GetGateNode provides a mock function with given fields: +func (_m *ExecutableNode) GetGateNode() v1alpha1.ExecutableGateNode { + ret := _m.Called() + + var r0 v1alpha1.ExecutableGateNode + if rf, ok := ret.Get(0).(func() v1alpha1.ExecutableGateNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.ExecutableGateNode) + } + } + + return r0 +} + type ExecutableNode_GetID struct { *mock.Call } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go index a441a9cd4b..346680cfa9 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/ExecutableNodeStatus.go @@ -25,6 +25,11 @@ func (_m *ExecutableNodeStatus) ClearDynamicNodeStatus() { _m.Called() } +// ClearGateNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) ClearGateNodeStatus() { + _m.Called() +} + // ClearLastAttemptStartedAt provides a mock function with given fields: func (_m *ExecutableNodeStatus) ClearLastAttemptStartedAt() { _m.Called() @@ -211,6 +216,40 @@ func (_m *ExecutableNodeStatus) GetExecutionError() *core.ExecutionError { return r0 } +type ExecutableNodeStatus_GetGateNodeStatus struct { + *mock.Call +} + +func (_m ExecutableNodeStatus_GetGateNodeStatus) Return(_a0 v1alpha1.MutableGateNodeStatus) *ExecutableNodeStatus_GetGateNodeStatus { + return &ExecutableNodeStatus_GetGateNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableNodeStatus) OnGetGateNodeStatus() *ExecutableNodeStatus_GetGateNodeStatus { + c_call := _m.On("GetGateNodeStatus") + return &ExecutableNodeStatus_GetGateNodeStatus{Call: c_call} +} + +func (_m *ExecutableNodeStatus) OnGetGateNodeStatusMatch(matchers ...interface{}) *ExecutableNodeStatus_GetGateNodeStatus { + c_call := _m.On("GetGateNodeStatus", matchers...) + return &ExecutableNodeStatus_GetGateNodeStatus{Call: c_call} +} + +// GetGateNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetGateNodeStatus() v1alpha1.MutableGateNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableGateNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableGateNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableGateNodeStatus) + } + } + + return r0 +} + type ExecutableNodeStatus_GetLastAttemptStartedAt struct { *mock.Call } @@ -413,6 +452,40 @@ func (_m *ExecutableNodeStatus) GetOrCreateDynamicNodeStatus() v1alpha1.MutableD return r0 } +type ExecutableNodeStatus_GetOrCreateGateNodeStatus struct { + *mock.Call +} + +func (_m ExecutableNodeStatus_GetOrCreateGateNodeStatus) Return(_a0 v1alpha1.MutableGateNodeStatus) *ExecutableNodeStatus_GetOrCreateGateNodeStatus { + return &ExecutableNodeStatus_GetOrCreateGateNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *ExecutableNodeStatus) OnGetOrCreateGateNodeStatus() *ExecutableNodeStatus_GetOrCreateGateNodeStatus { + c_call := _m.On("GetOrCreateGateNodeStatus") + return &ExecutableNodeStatus_GetOrCreateGateNodeStatus{Call: c_call} +} + +func (_m *ExecutableNodeStatus) OnGetOrCreateGateNodeStatusMatch(matchers ...interface{}) *ExecutableNodeStatus_GetOrCreateGateNodeStatus { + c_call := _m.On("GetOrCreateGateNodeStatus", matchers...) + return &ExecutableNodeStatus_GetOrCreateGateNodeStatus{Call: c_call} +} + +// GetOrCreateGateNodeStatus provides a mock function with given fields: +func (_m *ExecutableNodeStatus) GetOrCreateGateNodeStatus() v1alpha1.MutableGateNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableGateNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableGateNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableGateNodeStatus) + } + } + + return r0 +} + type ExecutableNodeStatus_GetOrCreateTaskStatus struct { *mock.Call } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableGateNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableGateNodeStatus.go new file mode 100644 index 0000000000..498940ffff --- /dev/null +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableGateNodeStatus.go @@ -0,0 +1,82 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + v1alpha1 "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + mock "github.com/stretchr/testify/mock" +) + +// MutableGateNodeStatus is an autogenerated mock type for the MutableGateNodeStatus type +type MutableGateNodeStatus struct { + mock.Mock +} + +type MutableGateNodeStatus_GetGateNodePhase struct { + *mock.Call +} + +func (_m MutableGateNodeStatus_GetGateNodePhase) Return(_a0 v1alpha1.GateNodePhase) *MutableGateNodeStatus_GetGateNodePhase { + return &MutableGateNodeStatus_GetGateNodePhase{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableGateNodeStatus) OnGetGateNodePhase() *MutableGateNodeStatus_GetGateNodePhase { + c_call := _m.On("GetGateNodePhase") + return &MutableGateNodeStatus_GetGateNodePhase{Call: c_call} +} + +func (_m *MutableGateNodeStatus) OnGetGateNodePhaseMatch(matchers ...interface{}) *MutableGateNodeStatus_GetGateNodePhase { + c_call := _m.On("GetGateNodePhase", matchers...) + return &MutableGateNodeStatus_GetGateNodePhase{Call: c_call} +} + +// GetGateNodePhase provides a mock function with given fields: +func (_m *MutableGateNodeStatus) GetGateNodePhase() v1alpha1.GateNodePhase { + ret := _m.Called() + + var r0 v1alpha1.GateNodePhase + if rf, ok := ret.Get(0).(func() v1alpha1.GateNodePhase); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(v1alpha1.GateNodePhase) + } + + return r0 +} + +type MutableGateNodeStatus_IsDirty struct { + *mock.Call +} + +func (_m MutableGateNodeStatus_IsDirty) Return(_a0 bool) *MutableGateNodeStatus_IsDirty { + return &MutableGateNodeStatus_IsDirty{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableGateNodeStatus) OnIsDirty() *MutableGateNodeStatus_IsDirty { + c_call := _m.On("IsDirty") + return &MutableGateNodeStatus_IsDirty{Call: c_call} +} + +func (_m *MutableGateNodeStatus) OnIsDirtyMatch(matchers ...interface{}) *MutableGateNodeStatus_IsDirty { + c_call := _m.On("IsDirty", matchers...) + return &MutableGateNodeStatus_IsDirty{Call: c_call} +} + +// IsDirty provides a mock function with given fields: +func (_m *MutableGateNodeStatus) IsDirty() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// SetGateNodePhase provides a mock function with given fields: phase +func (_m *MutableGateNodeStatus) SetGateNodePhase(phase v1alpha1.GateNodePhase) { + _m.Called(phase) +} diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go index 95001591c5..9bb0f59b2e 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks/MutableNodeStatus.go @@ -23,6 +23,11 @@ func (_m *MutableNodeStatus) ClearDynamicNodeStatus() { _m.Called() } +// ClearGateNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) ClearGateNodeStatus() { + _m.Called() +} + // ClearLastAttemptStartedAt provides a mock function with given fields: func (_m *MutableNodeStatus) ClearLastAttemptStartedAt() { _m.Called() @@ -111,6 +116,40 @@ func (_m *MutableNodeStatus) GetDynamicNodeStatus() v1alpha1.MutableDynamicNodeS return r0 } +type MutableNodeStatus_GetGateNodeStatus struct { + *mock.Call +} + +func (_m MutableNodeStatus_GetGateNodeStatus) Return(_a0 v1alpha1.MutableGateNodeStatus) *MutableNodeStatus_GetGateNodeStatus { + return &MutableNodeStatus_GetGateNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableNodeStatus) OnGetGateNodeStatus() *MutableNodeStatus_GetGateNodeStatus { + c_call := _m.On("GetGateNodeStatus") + return &MutableNodeStatus_GetGateNodeStatus{Call: c_call} +} + +func (_m *MutableNodeStatus) OnGetGateNodeStatusMatch(matchers ...interface{}) *MutableNodeStatus_GetGateNodeStatus { + c_call := _m.On("GetGateNodeStatus", matchers...) + return &MutableNodeStatus_GetGateNodeStatus{Call: c_call} +} + +// GetGateNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetGateNodeStatus() v1alpha1.MutableGateNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableGateNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableGateNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableGateNodeStatus) + } + } + + return r0 +} + type MutableNodeStatus_GetOrCreateBranchStatus struct { *mock.Call } @@ -179,6 +218,40 @@ func (_m *MutableNodeStatus) GetOrCreateDynamicNodeStatus() v1alpha1.MutableDyna return r0 } +type MutableNodeStatus_GetOrCreateGateNodeStatus struct { + *mock.Call +} + +func (_m MutableNodeStatus_GetOrCreateGateNodeStatus) Return(_a0 v1alpha1.MutableGateNodeStatus) *MutableNodeStatus_GetOrCreateGateNodeStatus { + return &MutableNodeStatus_GetOrCreateGateNodeStatus{Call: _m.Call.Return(_a0)} +} + +func (_m *MutableNodeStatus) OnGetOrCreateGateNodeStatus() *MutableNodeStatus_GetOrCreateGateNodeStatus { + c_call := _m.On("GetOrCreateGateNodeStatus") + return &MutableNodeStatus_GetOrCreateGateNodeStatus{Call: c_call} +} + +func (_m *MutableNodeStatus) OnGetOrCreateGateNodeStatusMatch(matchers ...interface{}) *MutableNodeStatus_GetOrCreateGateNodeStatus { + c_call := _m.On("GetOrCreateGateNodeStatus", matchers...) + return &MutableNodeStatus_GetOrCreateGateNodeStatus{Call: c_call} +} + +// GetOrCreateGateNodeStatus provides a mock function with given fields: +func (_m *MutableNodeStatus) GetOrCreateGateNodeStatus() v1alpha1.MutableGateNodeStatus { + ret := _m.Called() + + var r0 v1alpha1.MutableGateNodeStatus + if rf, ok := ret.Get(0).(func() v1alpha1.MutableGateNodeStatus); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(v1alpha1.MutableGateNodeStatus) + } + } + + return r0 +} + type MutableNodeStatus_GetOrCreateTaskStatus struct { *mock.Call } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go index 7f2c7ead9c..7aea3f2b8e 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/node_status.go @@ -184,6 +184,29 @@ func (in *WorkflowNodeStatus) SetWorkflowNodePhase(phase WorkflowNodePhase) { } } +type GateNodePhase int + +const ( + GateNodePhaseUndefined GateNodePhase = iota + GateNodePhaseExecuting +) + +type GateNodeStatus struct { + MutableStruct + Phase GateNodePhase `json:"phase,omitempty"` +} + +func (in *GateNodeStatus) GetGateNodePhase() GateNodePhase { + return in.Phase +} + +func (in *GateNodeStatus) SetGateNodePhase(phase GateNodePhase) { + if in.Phase != phase { + in.SetDirty() + in.Phase = phase + } +} + type NodeStatus struct { MutableStruct Phase NodePhase `json:"phase,omitempty"` @@ -211,6 +234,7 @@ type NodeStatus struct { TaskNodeStatus *TaskNodeStatus `json:",omitempty"` DynamicNodeStatus *DynamicNodeStatus `json:"dynamicNodeStatus,omitempty"` + GateNodeStatus *GateNodeStatus `json:"gateNodeStatus,omitempty"` // In case of Failing/Failed Phase, an execution error can be optionally associated with the Node Error *ExecutionError `json:"error,omitempty"` @@ -284,6 +308,13 @@ func (in *NodeStatus) GetTaskStatus() MutableTaskNodeStatus { return in.TaskNodeStatus } +func (in *NodeStatus) GetGateNodeStatus() MutableGateNodeStatus { + if in.GateNodeStatus == nil { + return nil + } + return in.GateNodeStatus +} + func (in NodeStatus) VisitNodeStatuses(visitor NodeStatusVisitFn) { for n, s := range in.SubNodeStatus { visitor(n, s) @@ -317,6 +348,11 @@ func (in *NodeStatus) ClearSubNodeStatus() { in.SetDirty() } +func (in *NodeStatus) ClearGateNodeStatus() { + in.GateNodeStatus = nil + in.SetDirty() +} + func (in *NodeStatus) GetLastUpdatedAt() *metav1.Time { return in.LastUpdatedAt } @@ -412,6 +448,17 @@ func (in *NodeStatus) GetOrCreateTaskStatus() MutableTaskNodeStatus { return in.TaskNodeStatus } +func (in *NodeStatus) GetOrCreateGateNodeStatus() MutableGateNodeStatus { + if in.GateNodeStatus == nil { + in.SetDirty() + in.GateNodeStatus = &GateNodeStatus{ + MutableStruct: MutableStruct{}, + } + } + + return in.GateNodeStatus +} + func (in *NodeStatus) UpdatePhase(p NodePhase, occurredAt metav1.Time, reason string, err *core.ExecutionError) { if in.Phase == p { // We will not update the phase multiple times. This prevents the comparison from returning false positive diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go index ad114a11a3..682af365d8 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/nodes.go @@ -100,6 +100,7 @@ type NodeSpec struct { BranchNode *BranchNodeSpec `json:"branch,omitempty"` TaskRef *TaskID `json:"task,omitempty"` WorkflowNode *WorkflowNodeSpec `json:"workflow,omitempty"` + GateNode *GateNodeSpec `json:"gate,omitempty"` InputBindings []*Binding `json:"inputBindings,omitempty"` Config *typesv1.ConfigMap `json:"config,omitempty"` RetryStrategy *RetryStrategy `json:"retry,omitempty"` @@ -198,6 +199,13 @@ func (in *NodeSpec) GetBranchNode() ExecutableBranchNode { return in.BranchNode } +func (in *NodeSpec) GetGateNode() ExecutableGateNode { + if in.GateNode == nil { + return nil + } + return in.GateNode +} + func (in *NodeSpec) GetTaskID() *TaskID { return in.TaskRef } diff --git a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go index 4aac6ef0fc..d9cc4ab586 100644 --- a/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go +++ b/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/workflow.go @@ -196,7 +196,7 @@ func (in *Inputs) UnmarshalJSON(b []byte) error { } func (in *Inputs) MarshalJSON() ([]byte, error) { - if in == nil { + if in == nil || in.LiteralMap == nil { return nilJSON, nil } diff --git a/flytepropeller/pkg/compiler/common/mocks/node.go b/flytepropeller/pkg/compiler/common/mocks/node.go index bf601c4ad2..364a1921dc 100644 --- a/flytepropeller/pkg/compiler/common/mocks/node.go +++ b/flytepropeller/pkg/compiler/common/mocks/node.go @@ -82,6 +82,40 @@ func (_m *Node) GetCoreNode() *core.Node { return r0 } +type Node_GetGateNode struct { + *mock.Call +} + +func (_m Node_GetGateNode) Return(_a0 *core.GateNode) *Node_GetGateNode { + return &Node_GetGateNode{Call: _m.Call.Return(_a0)} +} + +func (_m *Node) OnGetGateNode() *Node_GetGateNode { + c_call := _m.On("GetGateNode") + return &Node_GetGateNode{Call: c_call} +} + +func (_m *Node) OnGetGateNodeMatch(matchers ...interface{}) *Node_GetGateNode { + c_call := _m.On("GetGateNode", matchers...) + return &Node_GetGateNode{Call: c_call} +} + +// GetGateNode provides a mock function with given fields: +func (_m *Node) GetGateNode() *core.GateNode { + ret := _m.Called() + + var r0 *core.GateNode + if rf, ok := ret.Get(0).(func() *core.GateNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.GateNode) + } + } + + return r0 +} + type Node_GetId struct { *mock.Call } diff --git a/flytepropeller/pkg/compiler/common/mocks/node_builder.go b/flytepropeller/pkg/compiler/common/mocks/node_builder.go index a87ed756cc..44b320dc9e 100644 --- a/flytepropeller/pkg/compiler/common/mocks/node_builder.go +++ b/flytepropeller/pkg/compiler/common/mocks/node_builder.go @@ -82,6 +82,40 @@ func (_m *NodeBuilder) GetCoreNode() *core.Node { return r0 } +type NodeBuilder_GetGateNode struct { + *mock.Call +} + +func (_m NodeBuilder_GetGateNode) Return(_a0 *core.GateNode) *NodeBuilder_GetGateNode { + return &NodeBuilder_GetGateNode{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeBuilder) OnGetGateNode() *NodeBuilder_GetGateNode { + c_call := _m.On("GetGateNode") + return &NodeBuilder_GetGateNode{Call: c_call} +} + +func (_m *NodeBuilder) OnGetGateNodeMatch(matchers ...interface{}) *NodeBuilder_GetGateNode { + c_call := _m.On("GetGateNode", matchers...) + return &NodeBuilder_GetGateNode{Call: c_call} +} + +// GetGateNode provides a mock function with given fields: +func (_m *NodeBuilder) GetGateNode() *core.GateNode { + ret := _m.Called() + + var r0 *core.GateNode + if rf, ok := ret.Get(0).(func() *core.GateNode); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*core.GateNode) + } + } + + return r0 +} + type NodeBuilder_GetId struct { *mock.Call } diff --git a/flytepropeller/pkg/compiler/common/reader.go b/flytepropeller/pkg/compiler/common/reader.go index 8c8f24fe5e..d0ea361724 100644 --- a/flytepropeller/pkg/compiler/common/reader.go +++ b/flytepropeller/pkg/compiler/common/reader.go @@ -40,6 +40,7 @@ type Node interface { GetMetadata() *core.NodeMetadata GetTask() Task GetSubWorkflow() Workflow + GetGateNode() *core.GateNode } // An immutable task that represents the final output of the compiler. diff --git a/flytepropeller/pkg/compiler/errors/compiler_errors.go b/flytepropeller/pkg/compiler/errors/compiler_errors.go index eb8199dc91..b73b1927d6 100755 --- a/flytepropeller/pkg/compiler/errors/compiler_errors.go +++ b/flytepropeller/pkg/compiler/errors/compiler_errors.go @@ -90,6 +90,9 @@ const ( // Given value cannot be assigned to any union variant in a binding IncompatibleBindingUnionValue ErrorCode = "IncompatibleBindingUnionValue" + + // A gate node is missing a condition. + NoConditionFound ErrorCode = "NoConditionFound" ) func NewBranchNodeNotSpecified(branchNodeID string) *CompileError { @@ -316,6 +319,14 @@ func NewIncompatibleBindingUnionValue(nodeID, sinkParam, expectedType, binding s ) } +func NewNoConditionFound(nodeID string) *CompileError { + return newError( + NoConditionFound, + fmt.Sprintf("Can't find any condition in gate node [%v].", nodeID), + nodeID, + ) +} + func newError(code ErrorCode, description, nodeID string) (err *CompileError) { err = &CompileError{ code: code, diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node.go b/flytepropeller/pkg/compiler/transformers/k8s/node.go index aece64a9aa..0b9d0e244c 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/node.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/node.go @@ -128,6 +128,32 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile // as the first element in the list. That way list[0] will always be the first node actualNode := []*v1alpha1.NodeSpec{nodeSpec} return append(actualNode, ns...), !errs.HasErrors() + case *core.Node_GateNode: + nodeSpec.Kind = v1alpha1.NodeKindGate + gateNode := n.GetGateNode() + switch gateNode.Condition.(type) { + case *core.GateNode_Approve: + nodeSpec.GateNode = &v1alpha1.GateNodeSpec{ + Kind: v1alpha1.ConditionKindApprove, + Approve: &v1alpha1.ApproveCondition{ + ApproveCondition: gateNode.GetApprove(), + }, + } + case *core.GateNode_Signal: + nodeSpec.GateNode = &v1alpha1.GateNodeSpec{ + Kind: v1alpha1.ConditionKindSignal, + Signal: &v1alpha1.SignalCondition{ + SignalCondition: gateNode.GetSignal(), + }, + } + case *core.GateNode_Sleep: + nodeSpec.GateNode = &v1alpha1.GateNodeSpec{ + Kind: v1alpha1.ConditionKindSleep, + Sleep: &v1alpha1.SleepCondition{ + SleepCondition: gateNode.GetSleep(), + }, + } + } default: if n.GetId() == v1alpha1.StartNodeID { nodeSpec.Kind = v1alpha1.NodeKindStart diff --git a/flytepropeller/pkg/compiler/transformers/k8s/node_test.go b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go index a5b7ef275e..010eb0c3d0 100644 --- a/flytepropeller/pkg/compiler/transformers/k8s/node_test.go +++ b/flytepropeller/pkg/compiler/transformers/k8s/node_test.go @@ -2,14 +2,19 @@ package k8s import ( "testing" - - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "k8s.io/apimachinery/pkg/api/resource" + "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/compiler/common" "github.com/flyteorg/flytepropeller/pkg/compiler/errors" + "github.com/stretchr/testify/assert" + + "google.golang.org/protobuf/types/known/durationpb" + + "k8s.io/apimachinery/pkg/api/resource" ) func createNodeWithTask() *core.Node { @@ -208,6 +213,52 @@ func TestBuildNodeSpec(t *testing.T) { mustBuild(t, n, 2, errs.NewScope()) }) + t.Run("GateNodeApprove", func(t *testing.T) { + n.Node.Target = &core.Node_GateNode{ + GateNode: &core.GateNode{ + Condition: &core.GateNode_Approve{ + Approve: &core.ApproveCondition{ + SignalId: "foo", + }, + }, + }, + } + + mustBuild(t, n, 1, errs.NewScope()) + }) + + t.Run("GateNodeSignal", func(t *testing.T) { + n.Node.Target = &core.Node_GateNode{ + GateNode: &core.GateNode{ + Condition: &core.GateNode_Signal{ + Signal: &core.SignalCondition{ + SignalId: "foo", + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + }, + }, + }, + }, + } + + mustBuild(t, n, 1, errs.NewScope()) + }) + + t.Run("GateNodeSleep", func(t *testing.T) { + n.Node.Target = &core.Node_GateNode{ + GateNode: &core.GateNode{ + Condition: &core.GateNode_Sleep{ + Sleep: &core.SleepCondition{ + Duration: durationpb.New(time.Minute), + }, + }, + }, + } + + mustBuild(t, n, 1, errs.NewScope()) + }) } func TestBuildTasks(t *testing.T) { diff --git a/flytepropeller/pkg/compiler/validators/interface.go b/flytepropeller/pkg/compiler/validators/interface.go index 3f05ecf5ac..1ab613241e 100644 --- a/flytepropeller/pkg/compiler/validators/interface.go +++ b/flytepropeller/pkg/compiler/validators/interface.go @@ -118,6 +118,36 @@ func ValidateUnderlyingInterface(w c.WorkflowBuilder, node c.NodeBuilder, errs e } case *core.Node_BranchNode: iface, _ = validateBranchInterface(w, node, errs.NewScope()) + case *core.Node_GateNode: + gateNode := node.GetGateNode() + if approve := gateNode.GetApprove(); approve != nil { + iface = &core.TypedInterface{ + Inputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + Outputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + } + } else if signal := gateNode.GetSignal(); signal != nil { + if signal.GetType() == nil { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "GateNode.Signal.Type")) + } else if len(signal.GetOutputVariableName()) == 0 { + errs.Collect(errors.NewValueRequiredErr(node.GetId(), "GateNode.Signal.OutputVariableName")) + } else { + iface = &core.TypedInterface{ + Inputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + Outputs: &core.VariableMap{Variables: map[string]*core.Variable{ + signal.GetOutputVariableName(): &core.Variable{ + Type: signal.GetType(), + }, + }}, + } + } + } else if sleep := gateNode.GetSleep(); sleep != nil { + iface = &core.TypedInterface{ + Inputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + Outputs: &core.VariableMap{Variables: map[string]*core.Variable{}}, + } + } else { + errs.Collect(errors.NewNoConditionFound(node.GetId())) + } default: errs.Collect(errors.NewValueRequiredErr(node.GetId(), "Target")) } diff --git a/flytepropeller/pkg/compiler/validators/interface_test.go b/flytepropeller/pkg/compiler/validators/interface_test.go index 2d4742f0fe..cac2737815 100644 --- a/flytepropeller/pkg/compiler/validators/interface_test.go +++ b/flytepropeller/pkg/compiler/validators/interface_test.go @@ -2,14 +2,19 @@ package validators import ( "testing" + "time" "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + c "github.com/flyteorg/flytepropeller/pkg/compiler/common" "github.com/flyteorg/flytepropeller/pkg/compiler/common/mocks" "github.com/flyteorg/flytepropeller/pkg/compiler/errors" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + + "google.golang.org/protobuf/types/known/durationpb" ) func TestValidateInterface(t *testing.T) { @@ -276,6 +281,98 @@ func TestValidateUnderlyingInterface(t *testing.T) { assertNonEmptyInterface(t, iface, ifaceOk, errs) }) }) + + t.Run("GateNode", func(t *testing.T) { + t.Run("Approve", func(t *testing.T) { + wfBuilder := mocks.WorkflowBuilder{} + + gateNode := &core.GateNode{ + Condition: &core.GateNode_Approve{ + Approve: &core.ApproveCondition{ + SignalId: "foo", + }, + }, + } + + nodeBuilder := mocks.NodeBuilder{} + nodeBuilder.On("GetCoreNode").Return(&core.Node{ + Target: &core.Node_GateNode{ + GateNode: gateNode, + }, + }) + nodeBuilder.OnGetInterface().Return(nil) + + nodeBuilder.On("GetGateNode").Return(gateNode) + nodeBuilder.On("GetId").Return("node_1") + nodeBuilder.On("SetInterface", mock.Anything).Return() + + errs := errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + }) + + t.Run("Signal", func(t *testing.T) { + wfBuilder := mocks.WorkflowBuilder{} + + gateNode := &core.GateNode{ + Condition: &core.GateNode_Signal{ + Signal: &core.SignalCondition{ + SignalId: "foo", + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + }, + OutputVariableName: "foo", + }, + }, + } + + nodeBuilder := mocks.NodeBuilder{} + nodeBuilder.On("GetCoreNode").Return(&core.Node{ + Target: &core.Node_GateNode{ + GateNode: gateNode, + }, + }) + nodeBuilder.OnGetInterface().Return(nil) + + nodeBuilder.On("GetGateNode").Return(gateNode) + nodeBuilder.On("GetId").Return("node_1") + nodeBuilder.On("SetInterface", mock.Anything).Return() + + errs := errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + }) + + t.Run("Sleep", func(t *testing.T) { + wfBuilder := mocks.WorkflowBuilder{} + + gateNode := &core.GateNode{ + Condition: &core.GateNode_Sleep{ + Sleep: &core.SleepCondition{ + Duration: durationpb.New(time.Minute), + }, + }, + } + + nodeBuilder := mocks.NodeBuilder{} + nodeBuilder.On("GetCoreNode").Return(&core.Node{ + Target: &core.Node_GateNode{ + GateNode: gateNode, + }, + }) + nodeBuilder.OnGetInterface().Return(nil) + + nodeBuilder.On("GetGateNode").Return(gateNode) + nodeBuilder.On("GetId").Return("node_1") + nodeBuilder.On("SetInterface", mock.Anything).Return() + + errs := errors.NewCompileErrors() + iface, ifaceOk := ValidateUnderlyingInterface(&wfBuilder, &nodeBuilder, errs.NewScope()) + assertNonEmptyInterface(t, iface, ifaceOk, errs) + }) + }) } func matchIdentifier(id core.Identifier) interface{} { diff --git a/flytepropeller/pkg/controller/controller.go b/flytepropeller/pkg/controller/controller.go index c536e019af..a2a6a432a3 100644 --- a/flytepropeller/pkg/controller/controller.go +++ b/flytepropeller/pkg/controller/controller.go @@ -305,11 +305,11 @@ func newControllerMetrics(scope promutils.Scope) *metrics { } } -func getAdminClient(ctx context.Context) (client service.AdminServiceClient, opt []grpc.DialOption, err error) { +func getAdminClient(ctx context.Context) (client service.AdminServiceClient, signalClient service.SignalServiceClient, opt []grpc.DialOption, err error) { cfg := admin.GetConfig(ctx) clients, err := admin.NewClientsetBuilder().WithConfig(cfg).Build(ctx) if err != nil { - return nil, nil, fmt.Errorf("failed to initialize clientset. Error: %w", err) + return nil, nil, nil, fmt.Errorf("failed to initialize clientset. Error: %w", err) } credentialsFuture := admin.NewPerRPCCredentialsFuture() @@ -318,7 +318,7 @@ func getAdminClient(ctx context.Context) (client service.AdminServiceClient, opt grpc.WithPerRPCCredentials(credentialsFuture), } - return clients.AdminClient(), opts, nil + return clients.AdminClient(), clients.SignalServiceClient(), opts, nil } // New returns a new FlyteWorkflow controller @@ -326,7 +326,7 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter flyteworkflowInformerFactory informers.SharedInformerFactory, informerFactory k8sInformers.SharedInformerFactory, kubeClient executors.Client, scope promutils.Scope) (*Controller, error) { - adminClient, authOpts, err := getAdminClient(ctx) + adminClient, signalClient, authOpts, err := getAdminClient(ctx) if err != nil { logger.Errorf(ctx, "failed to initialize Admin client, err :%s", err.Error()) return nil, err @@ -439,7 +439,7 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter nodeExecutor, err := nodes.NewExecutor(ctx, cfg.NodeConfig, store, controller.enqueueWorkflowForNodeUpdates, eventSink, launchPlanActor, launchPlanActor, cfg.MaxDatasetSizeBytes, - storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, catalogClient, recovery.NewClient(adminClient), &cfg.EventConfig, cfg.ClusterID, scope) + storage.DataReference(cfg.DefaultRawOutputPrefix), kubeClient, catalogClient, recovery.NewClient(adminClient), &cfg.EventConfig, cfg.ClusterID, signalClient, scope) if err != nil { return nil, errors.Wrapf(err, "Failed to create Controller.") } diff --git a/flytepropeller/pkg/controller/nodes/branch/handler_test.go b/flytepropeller/pkg/controller/nodes/branch/handler_test.go index b6f4ab0d66..bc39c1b243 100644 --- a/flytepropeller/pkg/controller/nodes/branch/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/branch/handler_test.go @@ -54,6 +54,10 @@ func (t branchNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) e panic("not implemented") } +func (t branchNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { + panic("not implemented") +} + type parentInfo struct { } diff --git a/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go b/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go index 6c048ec178..64c2eb336c 100644 --- a/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/dynamic/handler_test.go @@ -55,6 +55,10 @@ func (t *dynamicNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) return nil } +func (t dynamicNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { + panic("not implemented") +} + var tID = "task-1" var eventConfig = &config.EventConfig{ diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go index 678f7e446c..20a4bcdbd4 100644 --- a/flytepropeller/pkg/controller/nodes/executor.go +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -33,6 +33,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flytepropeller/events" eventsErr "github.com/flyteorg/flytepropeller/events/errors" "github.com/flyteorg/flytestdlib/contextutils" @@ -1148,7 +1149,7 @@ func (c *nodeExecutor) Initialize(ctx context.Context) error { func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *storage.DataStore, enQWorkflow v1alpha1.EnqueueWorkflow, eventSink events.EventSink, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, maxDatasetSize int64, defaultRawOutputPrefix storage.DataReference, kubeClient executors.Client, - catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, scope promutils.Scope) (executors.Node, error) { + catalogClient catalog.Client, recoveryClient recovery.Client, eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (executors.Node, error) { // TODO we may want to make this configurable. shardSelector, err := ioutils.NewBase36PrefixShardSelector(ctx) @@ -1196,7 +1197,7 @@ func NewExecutor(ctx context.Context, nodeConfig config.NodeConfig, store *stora eventConfig: eventConfig, clusterID: clusterID, } - nodeHandlerFactory, err := NewHandlerFactory(ctx, exec, workflowLauncher, launchPlanReader, kubeClient, catalogClient, recoveryClient, eventConfig, clusterID, nodeScope) + nodeHandlerFactory, err := NewHandlerFactory(ctx, exec, workflowLauncher, launchPlanReader, kubeClient, catalogClient, recoveryClient, eventConfig, clusterID, signalClient, nodeScope) exec.nodeHandlerFactory = nodeHandlerFactory return exec, err } diff --git a/flytepropeller/pkg/controller/nodes/executor_test.go b/flytepropeller/pkg/controller/nodes/executor_test.go index e1e6db59de..2c3d62931c 100644 --- a/flytepropeller/pkg/controller/nodes/executor_test.go +++ b/flytepropeller/pkg/controller/nodes/executor_test.go @@ -8,48 +8,51 @@ import ( "testing" "time" - "github.com/flyteorg/flytestdlib/contextutils" - - "github.com/golang/protobuf/proto" - - mocks3 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" - storageMocks "github.com/flyteorg/flytestdlib/storage/mocks" - "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flytestdlib/promutils/labeled" - "github.com/flyteorg/flytestdlib/storage" - "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/mock" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flytestdlib/contextutils" + + mocks3 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flytepropeller/events" eventsErr "github.com/flyteorg/flytepropeller/events/errors" eventMocks "github.com/flyteorg/flytepropeller/events/mocks" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/executors" mocks4 "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" + gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" nodeHandlerMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" mocks2 "github.com/flyteorg/flytepropeller/pkg/controller/nodes/mocks" + recoveryMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/catalog" + flyteassert "github.com/flyteorg/flytepropeller/pkg/utils/assert" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" + storageMocks "github.com/flyteorg/flytestdlib/storage/mocks" + + "github.com/golang/protobuf/proto" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" - "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" - "github.com/flyteorg/flytepropeller/pkg/controller/config" - "github.com/flyteorg/flytepropeller/pkg/controller/executors" - recoveryMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" - flyteassert "github.com/flyteorg/flytepropeller/pkg/utils/assert" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) var fakeKubeClient = mocks4.NewFakeKubeClient() var catalogClient = catalog.NOOPCatalog{} var recoveryClient = &recoveryMocks.Client{} +var signalClient = &gatemocks.SignalServiceClient{} const taskID = "tID" const inputsPath = "inputs.pb" @@ -68,7 +71,7 @@ func TestSetInputsForStartNode(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() exec, err := NewExecutor(ctx, config.GetConfig().NodeConfig, mockStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, 10, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + adminClient, 10, "s3://bucket/", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) inputs := &core.LiteralMap{ Literals: map[string]*core.Literal{ @@ -115,7 +118,7 @@ func TestSetInputsForStartNode(t *testing.T) { failStorage := createFailingDatastore(t, testScope.NewSubScope("failing")) execFail, err := NewExecutor(ctx, config.GetConfig().NodeConfig, failStorage, enQWf, eventMocks.NewMockEventSink(), adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) t.Run("StorageFailure", func(t *testing.T) { w := createDummyBaseWorkflow(mockStorage) @@ -140,8 +143,8 @@ func TestNodeExecutor_Initialize(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() t.Run("happy", func(t *testing.T) { - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -154,8 +157,8 @@ func TestNodeExecutor_Initialize(t *testing.T) { }) t.Run("error", func(t *testing.T) { - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, memStore, enQWf, mockEventSink, adminClient, adminClient, + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -178,7 +181,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseStartNodes(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -282,7 +285,7 @@ func TestNodeExecutor_RecursiveNodeHandler_RecurseEndNode(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -687,8 +690,8 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, - adminClient, adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf @@ -762,8 +765,8 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf @@ -874,8 +877,8 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { hf := &mocks2.HandlerFactory{} store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf @@ -938,8 +941,8 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { hf := &mocks2.HandlerFactory{} store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf @@ -969,8 +972,8 @@ func TestNodeExecutor_RecursiveNodeHandler_Recurse(t *testing.T) { hf := &mocks2.HandlerFactory{} store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) exec.nodeHandlerFactory = hf @@ -1003,8 +1006,8 @@ func TestNodeExecutor_RecursiveNodeHandler_NoDownstream(t *testing.T) { store := createInmemoryDataStore(t, promutils.NewTestScope()) adminClient := launchplan.NewFailFastLaunchPlanExecutor() - execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, - adminClient, 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -1115,7 +1118,7 @@ func TestNodeExecutor_RecursiveNodeHandler_UpstreamNotReady(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) @@ -1231,7 +1234,7 @@ func TestNodeExecutor_RecursiveNodeHandler_BranchNode(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) // Node not yet started @@ -1829,7 +1832,7 @@ func TestNodeExecutor_RecursiveNodeHandler_ParallelismLimit(t *testing.T) { adminClient := launchplan.NewFailFastLaunchPlanExecutor() execIface, err := NewExecutor(ctx, config.GetConfig().NodeConfig, store, enQWf, mockEventSink, adminClient, adminClient, - 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + 10, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) exec := execIface.(*nodeExecutor) diff --git a/flytepropeller/pkg/controller/nodes/gate/handler.go b/flytepropeller/pkg/controller/nodes/gate/handler.go new file mode 100644 index 0000000000..5c6d1a512f --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/gate/handler.go @@ -0,0 +1,231 @@ +package gate + +import ( + "context" + "fmt" + "time" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/errors" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/storage" +) + +//go:generate mockery -all -case=underscore + +// SignalServiceClient is a SignalServiceClient wrapper interface used specifically for generating +// mocks for testing +type SignalServiceClient interface { + service.SignalServiceClient +} + +// gateNodeHandler is a handle implementation for processing gate nodes +type gateNodeHandler struct { + signalClient SignalServiceClient + metrics metrics +} + +// metrics encapsulates the prometheus metrics for this handler +type metrics struct { + scope promutils.Scope +} + +// newMetrics initializes a new metrics struct +func newMetrics(scope promutils.Scope) metrics { + return metrics{ + scope: scope, + } +} + +// Abort stops the gate node defined in the NodeExecutionContext +func (g *gateNodeHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, reason string) error { + return nil +} + +// Finalize completes the gate node defined in the NodeExecutionContext +func (g *gateNodeHandler) Finalize(ctx context.Context, _ handler.NodeExecutionContext) error { + return nil +} + +// FinalizeRequired defines whether or not this handler requires finalize to be called on +// node completion +func (g *gateNodeHandler) FinalizeRequired() bool { + return false +} + +// Handle is responsible for transitioning and reporting node state to complete the node defined +// by the NodeExecutionContext +func (g *gateNodeHandler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.Transition, error) { + gateNode := nCtx.Node().GetGateNode() + gateNodeState := nCtx.NodeStateReader().GetGateNodeState() + + if gateNodeState.Phase == v1alpha1.GateNodePhaseUndefined { + gateNodeState.Phase = v1alpha1.GateNodePhaseExecuting + } + + switch gateNode.GetKind() { + case v1alpha1.ConditionKindApprove: + // retrieve approve condition + approveCondition := gateNode.GetApprove() + if approveCondition == nil { + errMsg := "gateNode approve condition is nil" + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, + errors.BadSpecificationError, errMsg, nil)), nil + } + + // use admin client to query for signal + request := &admin.SignalGetOrCreateRequest{ + Id: &core.SignalIdentifier{ + ExecutionId: nCtx.ExecutionContext().GetExecutionID().WorkflowExecutionIdentifier, + SignalId: approveCondition.SignalId, + }, + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + }, + } + + signal, err := g.signalClient.GetOrCreateSignal(ctx, request) + if err != nil { + return handler.UnknownTransition, err + } + + // if signal has value then check for approval + if signal.Value != nil && signal.Value.Value != nil { + approved, ok := getBoolean(signal.Value) + if !ok { + errMsg := fmt.Sprintf("received a non-boolean approve signal value [%v]", signal.Value) + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_UNKNOWN, + errors.RuntimeExecutionError, errMsg, nil)), nil + } + + if !approved { + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_USER, + "ReceivedRejectSignal", "received a reject signal to disapprove the node input values", nil)), nil + } + + // copy input values to outputs + inputs, err := nCtx.InputReader().Get(ctx) + if err != nil { + errMsg := fmt.Sprintf("failed to read input with error [%s]", err) + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, errors.RuntimeExecutionError, errMsg, nil)), nil + } + + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + if err := nCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, inputs); err != nil { + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "WriteOutputsFailed", + fmt.Sprintf("failed to write signal value to [%v] with error [%s]", outputFile, err.Error()), nil)), nil + } + + o := &handler.OutputInfo{OutputURI: outputFile} + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{ + OutputInfo: o, + })), nil + } + case v1alpha1.ConditionKindSignal: + // retrieve signal condition + signalCondition := gateNode.GetSignal() + if signalCondition == nil { + errMsg := "gateNode signal condition is nil" + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, + errors.BadSpecificationError, errMsg, nil)), nil + } + + // use admin client to query for signal + request := &admin.SignalGetOrCreateRequest{ + Id: &core.SignalIdentifier{ + ExecutionId: nCtx.ExecutionContext().GetExecutionID().WorkflowExecutionIdentifier, + SignalId: signalCondition.SignalId, + }, + Type: signalCondition.Type, + } + + signal, err := g.signalClient.GetOrCreateSignal(ctx, request) + if err != nil { + return handler.UnknownTransition, err + } + + // if signal has value then write to output and transition to success + if signal.Value != nil && signal.Value.Value != nil { + outputs := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + signalCondition.OutputVariableName: signal.Value, + }, + } + + outputFile := v1alpha1.GetOutputsFile(nCtx.NodeStatus().GetOutputDir()) + if err := nCtx.DataStore().WriteProtobuf(ctx, outputFile, storage.Options{}, outputs); err != nil { + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, "WriteOutputsFailed", + fmt.Sprintf("failed to write signal value to [%v] with error: %s", outputFile, err.Error()), nil)), nil + } + + o := &handler.OutputInfo{OutputURI: outputFile} + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{ + OutputInfo: o, + })), nil + } + case v1alpha1.ConditionKindSleep: + // retrieve sleep duration + sleepCondition := gateNode.GetSleep() + if sleepCondition == nil { + errMsg := "gateNode sleep condition is nil" + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, + errors.BadSpecificationError, errMsg, nil)), nil + } + + sleepDuration := sleepCondition.GetDuration().AsDuration() + + // check duration of node sleep + now := time.Now() + if sleepDuration <= now.Sub(nCtx.NodeStatus().GetLastAttemptStartedAt().Time) { + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), nil + } + default: + errMsg := "gateNode does not have a supported condition reference" + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(core.ExecutionError_SYSTEM, + errors.BadSpecificationError, errMsg, nil)), nil + } + + // update gate node status + if err := nCtx.NodeStateWriter().PutGateNodeState(gateNodeState); err != nil { + logger.Errorf(ctx, "failed to store TaskNode state with err [%s]", err.Error()) + return handler.UnknownTransition, err + } + + return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), nil +} + +// Setup handles any initialization requirements for this handler +func (g *gateNodeHandler) Setup(_ context.Context, _ handler.SetupContext) error { + return nil +} + +// New initializes a new gateNodeHandler +func New(eventConfig *config.EventConfig, signalClient service.SignalServiceClient, scope promutils.Scope) handler.Node { + gateScope := scope.NewSubScope("gate") + return &gateNodeHandler{ + signalClient: signalClient, + metrics: newMetrics(gateScope), + } +} + +func getBoolean(literal *core.Literal) (bool, bool) { + if scalarValue, ok := literal.Value.(*core.Literal_Scalar); ok { + if primitiveValue, ok := scalarValue.Scalar.Value.(*core.Scalar_Primitive); ok { + if booleanValue, ok := primitiveValue.Primitive.Value.(*core.Primitive_Boolean); ok { + return booleanValue.Boolean, true + } + } + } + + return false, false +} diff --git a/flytepropeller/pkg/controller/nodes/gate/handler_test.go b/flytepropeller/pkg/controller/nodes/gate/handler_test.go new file mode 100644 index 0000000000..ca7e8e9dd5 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/gate/handler_test.go @@ -0,0 +1,305 @@ +package gate + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + flyteMocks "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + executormocks "github.com/flyteorg/flytepropeller/pkg/controller/executors/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + nodeMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler/mocks" + + "github.com/flyteorg/flytestdlib/contextutils" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/promutils/labeled" + "github.com/flyteorg/flytestdlib/storage" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "google.golang.org/protobuf/types/known/durationpb" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +var ( + eventConfig = &config.EventConfig{ + RawOutputPolicy: config.RawOutputPolicyReference, + } + + approveGateNode = &v1alpha1.GateNodeSpec{ + Kind: v1alpha1.ConditionKindApprove, + Approve: &v1alpha1.ApproveCondition{ + ApproveCondition: &core.ApproveCondition{ + SignalId: "foo", + }, + }, + } + + signalGateNode = &v1alpha1.GateNodeSpec{ + Kind: v1alpha1.ConditionKindSignal, + Signal: &v1alpha1.SignalCondition{ + SignalCondition: &core.SignalCondition{ + SignalId: "foo", + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_BOOLEAN, + }, + }, + OutputVariableName: "bar", + }, + }, + } + + sleepMinuteGateNode = &v1alpha1.GateNodeSpec{ + Kind: v1alpha1.ConditionKindSleep, + Sleep: &v1alpha1.SleepCondition{ + SleepCondition: &core.SleepCondition{ + Duration: durationpb.New(time.Minute), + }, + }, + } + + sleepNowGateNode = &v1alpha1.GateNodeSpec{ + Kind: v1alpha1.ConditionKindSleep, + Sleep: &v1alpha1.SleepCondition{ + SleepCondition: &core.SleepCondition{ + Duration: durationpb.New(time.Minute * 0), + }, + }, + } +) + +func init() { + labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, + contextutils.TaskIDKey) +} + +func createNodeExecutionContext(gateNode *v1alpha1.GateNodeSpec) *nodeMocks.NodeExecutionContext { + wfExecID := v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + } + + n := &flyteMocks.ExecutableNode{} + n.OnGetGateNode().Return(gateNode) + + nm := &nodeMocks.NodeExecutionMetadata{} + + ns := &flyteMocks.ExecutableNodeStatus{} + ns.OnGetDataDir().Return(storage.DataReference("data-dir")) + ns.OnGetOutputDir().Return(storage.DataReference("data-dir")) + + t := v1.NewTime(time.Now()) + ns.OnGetLastAttemptStartedAt().Return(&t) + + inputReader := &ioMocks.InputReader{} + inputReader.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) + dataStore, _ := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + + eCtx := &executormocks.ExecutionContext{} + eCtx.OnGetExecutionID().Return(wfExecID) + + nCtx := &nodeMocks.NodeExecutionContext{} + nCtx.OnNodeExecutionMetadata().Return(nm) + nCtx.OnNode().Return(n) + nCtx.OnNodeStatus().Return(ns) + nCtx.OnDataStore().Return(dataStore) + nCtx.OnExecutionContext().Return(eCtx) + nCtx.OnInputReader().Return(inputReader) + + r := &nodeMocks.NodeStateReader{} + r.OnGetGateNodeState().Return(handler.GateNodeState{}) + nCtx.OnNodeStateReader().Return(r) + + w := &nodeMocks.NodeStateWriter{} + w.OnPutGateNodeStateMatch(mock.Anything).Return(nil) + nCtx.OnNodeStateWriter().Return(w) + return nCtx +} + +func TestAbort(t *testing.T) { + ctx := context.TODO() + signalClient := mocks.SignalServiceClient{} + scope := promutils.NewTestScope() + + handler := New(eventConfig, &signalClient, scope) + + assert.NoError(t, handler.Abort(ctx, nil, "")) +} + +func TestFinalize(t *testing.T) { + ctx := context.TODO() + signalClient := mocks.SignalServiceClient{} + scope := promutils.NewTestScope() + + handler := New(eventConfig, &signalClient, scope) + + assert.NoError(t, handler.Finalize(ctx, nil)) +} + +func TestHandle(t *testing.T) { + ctx := context.TODO() + scope := promutils.NewTestScope() + + t.Run("ApproveCheck", func(t *testing.T) { + nCtx := createNodeExecutionContext(approveGateNode) + signalClient := mocks.SignalServiceClient{} + signalClient.OnGetOrCreateSignalMatch(mock.Anything, mock.Anything).Return(&admin.Signal{}, nil) + + gateNodeHandler := New(eventConfig, &signalClient, scope) + + transition, err := gateNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + assert.Equal(t, handler.EPhaseRunning, transition.Info().GetPhase()) + }) + + t.Run("ApproveComplete", func(t *testing.T) { + nCtx := createNodeExecutionContext(approveGateNode) + signalClient := mocks.SignalServiceClient{} + signalClient.OnGetOrCreateSignalMatch(mock.Anything, mock.Anything).Return(&admin.Signal{ + Value: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Boolean{ + Boolean: true, + }, + }, + }, + }, + }, + }, + }, nil) + + gateNodeHandler := New(eventConfig, &signalClient, scope) + + transition, err := gateNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + assert.Equal(t, handler.EPhaseSuccess, transition.Info().GetPhase()) + }) + + t.Run("ApproveRejected", func(t *testing.T) { + nCtx := createNodeExecutionContext(approveGateNode) + signalClient := mocks.SignalServiceClient{} + signalClient.OnGetOrCreateSignalMatch(mock.Anything, mock.Anything).Return(&admin.Signal{ + Value: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Boolean{ + Boolean: false, + }, + }, + }, + }, + }, + }, + }, nil) + + gateNodeHandler := New(eventConfig, &signalClient, scope) + + transition, err := gateNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + assert.Equal(t, handler.EPhaseFailed, transition.Info().GetPhase()) + }) + + t.Run("ApproveError", func(t *testing.T) { + nCtx := createNodeExecutionContext(approveGateNode) + signalClient := mocks.SignalServiceClient{} + signalClient.OnGetOrCreateSignalMatch(mock.Anything, mock.Anything).Return(&admin.Signal{}, errors.New("foo")) + + gateNodeHandler := New(eventConfig, &signalClient, scope) + + transition, err := gateNodeHandler.Handle(ctx, nCtx) + assert.Error(t, err) + assert.Equal(t, handler.EPhaseUndefined, transition.Info().GetPhase()) + }) + + t.Run("SignalCheck", func(t *testing.T) { + nCtx := createNodeExecutionContext(signalGateNode) + signalClient := mocks.SignalServiceClient{} + signalClient.OnGetOrCreateSignalMatch(mock.Anything, mock.Anything).Return(&admin.Signal{}, nil) + + gateNodeHandler := New(eventConfig, &signalClient, scope) + + transition, err := gateNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + assert.Equal(t, handler.EPhaseRunning, transition.Info().GetPhase()) + }) + + t.Run("SignalComplete", func(t *testing.T) { + nCtx := createNodeExecutionContext(signalGateNode) + signalClient := mocks.SignalServiceClient{} + signalClient.OnGetOrCreateSignalMatch(mock.Anything, mock.Anything).Return(&admin.Signal{ + Value: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Boolean{ + Boolean: false, + }, + }, + }, + }, + }, + }, + }, nil) + + gateNodeHandler := New(eventConfig, &signalClient, scope) + + transition, err := gateNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + assert.Equal(t, handler.EPhaseSuccess, transition.Info().GetPhase()) + }) + + t.Run("SignalError", func(t *testing.T) { + nCtx := createNodeExecutionContext(signalGateNode) + signalClient := mocks.SignalServiceClient{} + signalClient.OnGetOrCreateSignalMatch(mock.Anything, mock.Anything).Return(&admin.Signal{}, errors.New("foo")) + + gateNodeHandler := New(eventConfig, &signalClient, scope) + + transition, err := gateNodeHandler.Handle(ctx, nCtx) + assert.Error(t, err) + assert.Equal(t, handler.EPhaseUndefined, transition.Info().GetPhase()) + }) + + t.Run("SleepCheck", func(t *testing.T) { + nCtx := createNodeExecutionContext(sleepMinuteGateNode) + signalClient := mocks.SignalServiceClient{} + + gateNodeHandler := New(eventConfig, &signalClient, scope) + + transition, err := gateNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + assert.Equal(t, handler.EPhaseRunning, transition.Info().GetPhase()) + }) + + t.Run("SleepComplete", func(t *testing.T) { + nCtx := createNodeExecutionContext(sleepNowGateNode) + signalClient := mocks.SignalServiceClient{} + + gateNodeHandler := New(eventConfig, &signalClient, scope) + + transition, err := gateNodeHandler.Handle(ctx, nCtx) + assert.NoError(t, err) + assert.Equal(t, handler.EPhaseSuccess, transition.Info().GetPhase()) + }) +} diff --git a/flytepropeller/pkg/controller/nodes/gate/mocks/signal_service_client.go b/flytepropeller/pkg/controller/nodes/gate/mocks/signal_service_client.go new file mode 100644 index 0000000000..d0a819e2b0 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/gate/mocks/signal_service_client.go @@ -0,0 +1,162 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + admin "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + + grpc "google.golang.org/grpc" + + mock "github.com/stretchr/testify/mock" +) + +// SignalServiceClient is an autogenerated mock type for the SignalServiceClient type +type SignalServiceClient struct { + mock.Mock +} + +type SignalServiceClient_GetOrCreateSignal struct { + *mock.Call +} + +func (_m SignalServiceClient_GetOrCreateSignal) Return(_a0 *admin.Signal, _a1 error) *SignalServiceClient_GetOrCreateSignal { + return &SignalServiceClient_GetOrCreateSignal{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SignalServiceClient) OnGetOrCreateSignal(ctx context.Context, in *admin.SignalGetOrCreateRequest, opts ...grpc.CallOption) *SignalServiceClient_GetOrCreateSignal { + c_call := _m.On("GetOrCreateSignal", ctx, in, opts) + return &SignalServiceClient_GetOrCreateSignal{Call: c_call} +} + +func (_m *SignalServiceClient) OnGetOrCreateSignalMatch(matchers ...interface{}) *SignalServiceClient_GetOrCreateSignal { + c_call := _m.On("GetOrCreateSignal", matchers...) + return &SignalServiceClient_GetOrCreateSignal{Call: c_call} +} + +// GetOrCreateSignal provides a mock function with given fields: ctx, in, opts +func (_m *SignalServiceClient) GetOrCreateSignal(ctx context.Context, in *admin.SignalGetOrCreateRequest, opts ...grpc.CallOption) (*admin.Signal, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.Signal + if rf, ok := ret.Get(0).(func(context.Context, *admin.SignalGetOrCreateRequest, ...grpc.CallOption) *admin.Signal); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.Signal) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.SignalGetOrCreateRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type SignalServiceClient_ListSignals struct { + *mock.Call +} + +func (_m SignalServiceClient_ListSignals) Return(_a0 *admin.SignalList, _a1 error) *SignalServiceClient_ListSignals { + return &SignalServiceClient_ListSignals{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SignalServiceClient) OnListSignals(ctx context.Context, in *admin.SignalListRequest, opts ...grpc.CallOption) *SignalServiceClient_ListSignals { + c_call := _m.On("ListSignals", ctx, in, opts) + return &SignalServiceClient_ListSignals{Call: c_call} +} + +func (_m *SignalServiceClient) OnListSignalsMatch(matchers ...interface{}) *SignalServiceClient_ListSignals { + c_call := _m.On("ListSignals", matchers...) + return &SignalServiceClient_ListSignals{Call: c_call} +} + +// ListSignals provides a mock function with given fields: ctx, in, opts +func (_m *SignalServiceClient) ListSignals(ctx context.Context, in *admin.SignalListRequest, opts ...grpc.CallOption) (*admin.SignalList, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.SignalList + if rf, ok := ret.Get(0).(func(context.Context, *admin.SignalListRequest, ...grpc.CallOption) *admin.SignalList); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.SignalList) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.SignalListRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type SignalServiceClient_SetSignal struct { + *mock.Call +} + +func (_m SignalServiceClient_SetSignal) Return(_a0 *admin.SignalSetResponse, _a1 error) *SignalServiceClient_SetSignal { + return &SignalServiceClient_SetSignal{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *SignalServiceClient) OnSetSignal(ctx context.Context, in *admin.SignalSetRequest, opts ...grpc.CallOption) *SignalServiceClient_SetSignal { + c_call := _m.On("SetSignal", ctx, in, opts) + return &SignalServiceClient_SetSignal{Call: c_call} +} + +func (_m *SignalServiceClient) OnSetSignalMatch(matchers ...interface{}) *SignalServiceClient_SetSignal { + c_call := _m.On("SetSignal", matchers...) + return &SignalServiceClient_SetSignal{Call: c_call} +} + +// SetSignal provides a mock function with given fields: ctx, in, opts +func (_m *SignalServiceClient) SetSignal(ctx context.Context, in *admin.SignalSetRequest, opts ...grpc.CallOption) (*admin.SignalSetResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.SignalSetResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.SignalSetRequest, ...grpc.CallOption) *admin.SignalSetResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.SignalSetResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.SignalSetRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_reader.go b/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_reader.go index f8ee782c15..ef86e64cc3 100644 --- a/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_reader.go +++ b/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_reader.go @@ -76,6 +76,38 @@ func (_m *NodeStateReader) GetDynamicNodeState() handler.DynamicNodeState { return r0 } +type NodeStateReader_GetGateNodeState struct { + *mock.Call +} + +func (_m NodeStateReader_GetGateNodeState) Return(_a0 handler.GateNodeState) *NodeStateReader_GetGateNodeState { + return &NodeStateReader_GetGateNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateReader) OnGetGateNodeState() *NodeStateReader_GetGateNodeState { + c_call := _m.On("GetGateNodeState") + return &NodeStateReader_GetGateNodeState{Call: c_call} +} + +func (_m *NodeStateReader) OnGetGateNodeStateMatch(matchers ...interface{}) *NodeStateReader_GetGateNodeState { + c_call := _m.On("GetGateNodeState", matchers...) + return &NodeStateReader_GetGateNodeState{Call: c_call} +} + +// GetGateNodeState provides a mock function with given fields: +func (_m *NodeStateReader) GetGateNodeState() handler.GateNodeState { + ret := _m.Called() + + var r0 handler.GateNodeState + if rf, ok := ret.Get(0).(func() handler.GateNodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(handler.GateNodeState) + } + + return r0 +} + type NodeStateReader_GetTaskNodeState struct { *mock.Call } diff --git a/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_writer.go b/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_writer.go index 6820f4504d..ec5359550a 100644 --- a/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_writer.go +++ b/flytepropeller/pkg/controller/nodes/handler/mocks/node_state_writer.go @@ -76,6 +76,38 @@ func (_m *NodeStateWriter) PutDynamicNodeState(s handler.DynamicNodeState) error return r0 } +type NodeStateWriter_PutGateNodeState struct { + *mock.Call +} + +func (_m NodeStateWriter_PutGateNodeState) Return(_a0 error) *NodeStateWriter_PutGateNodeState { + return &NodeStateWriter_PutGateNodeState{Call: _m.Call.Return(_a0)} +} + +func (_m *NodeStateWriter) OnPutGateNodeState(s handler.GateNodeState) *NodeStateWriter_PutGateNodeState { + c_call := _m.On("PutGateNodeState", s) + return &NodeStateWriter_PutGateNodeState{Call: c_call} +} + +func (_m *NodeStateWriter) OnPutGateNodeStateMatch(matchers ...interface{}) *NodeStateWriter_PutGateNodeState { + c_call := _m.On("PutGateNodeState", matchers...) + return &NodeStateWriter_PutGateNodeState{Call: c_call} +} + +// PutGateNodeState provides a mock function with given fields: s +func (_m *NodeStateWriter) PutGateNodeState(s handler.GateNodeState) error { + ret := _m.Called(s) + + var r0 error + if rf, ok := ret.Get(0).(func(handler.GateNodeState) error); ok { + r0 = rf(s) + } else { + r0 = ret.Error(0) + } + + return r0 +} + type NodeStateWriter_PutTaskNodeState struct { *mock.Call } diff --git a/flytepropeller/pkg/controller/nodes/handler/state.go b/flytepropeller/pkg/controller/nodes/handler/state.go index b15b900d22..d404590998 100644 --- a/flytepropeller/pkg/controller/nodes/handler/state.go +++ b/flytepropeller/pkg/controller/nodes/handler/state.go @@ -41,11 +41,17 @@ type WorkflowNodeState struct { Error *core.ExecutionError } +type GateNodeState struct { + Phase v1alpha1.GateNodePhase + StartedAt time.Time +} + type NodeStateWriter interface { PutTaskNodeState(s TaskNodeState) error PutBranchNode(s BranchNodeState) error PutDynamicNodeState(s DynamicNodeState) error PutWorkflowNodeState(s WorkflowNodeState) error + PutGateNodeState(s GateNodeState) error } type NodeStateReader interface { @@ -53,4 +59,5 @@ type NodeStateReader interface { GetBranchNode() BranchNodeState GetDynamicNodeState() DynamicNodeState GetWorkflowNodeState() WorkflowNodeState + GetGateNodeState() GateNodeState } diff --git a/flytepropeller/pkg/controller/nodes/handler/transition_info.go b/flytepropeller/pkg/controller/nodes/handler/transition_info.go index d12ff9f8a6..3bac510e64 100644 --- a/flytepropeller/pkg/controller/nodes/handler/transition_info.go +++ b/flytepropeller/pkg/controller/nodes/handler/transition_info.go @@ -49,6 +49,9 @@ type TaskNodeInfo struct { TaskNodeMetadata *event.TaskNodeMetadata } +type GateNodeInfo struct { +} + type OutputInfo struct { OutputURI storage.DataReference DeckURI *storage.DataReference @@ -60,6 +63,7 @@ type ExecutionInfo struct { BranchNodeInfo *BranchNodeInfo OutputInfo *OutputInfo TaskNodeInfo *TaskNodeInfo + GateNodeInfo *GateNodeInfo } type PhaseInfo struct { diff --git a/flytepropeller/pkg/controller/nodes/handler_factory.go b/flytepropeller/pkg/controller/nodes/handler_factory.go index 72e361070a..e13143e6b2 100644 --- a/flytepropeller/pkg/controller/nodes/handler_factory.go +++ b/flytepropeller/pkg/controller/nodes/handler_factory.go @@ -3,6 +3,8 @@ package nodes import ( "context" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" @@ -19,6 +21,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/branch" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/end" + "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/start" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow" @@ -56,7 +59,7 @@ func (f handlerFactory) Setup(ctx context.Context, setup handler.SetupContext) e func NewHandlerFactory(ctx context.Context, executor executors.Node, workflowLauncher launchplan.Executor, launchPlanReader launchplan.Reader, kubeClient executors.Client, client catalog.Client, recoveryClient recovery.Client, - eventConfig *config.EventConfig, clusterID string, scope promutils.Scope) (HandlerFactory, error) { + eventConfig *config.EventConfig, clusterID string, signalClient service.SignalServiceClient, scope promutils.Scope) (HandlerFactory, error) { t, err := task.New(ctx, kubeClient, client, eventConfig, clusterID, scope) if err != nil { @@ -68,6 +71,7 @@ func NewHandlerFactory(ctx context.Context, executor executors.Node, workflowLau v1alpha1.NodeKindBranch: branch.New(executor, eventConfig, scope), v1alpha1.NodeKindTask: dynamic.New(t, executor, launchPlanReader, eventConfig, scope), v1alpha1.NodeKindWorkflow: subworkflow.New(executor, workflowLauncher, recoveryClient, eventConfig, scope), + v1alpha1.NodeKindGate: gate.New(eventConfig, signalClient, scope), v1alpha1.NodeKindStart: start.New(), v1alpha1.NodeKindEnd: end.New(), }, diff --git a/flytepropeller/pkg/controller/nodes/node_state_manager.go b/flytepropeller/pkg/controller/nodes/node_state_manager.go index 4906a6da29..73baf4ddae 100644 --- a/flytepropeller/pkg/controller/nodes/node_state_manager.go +++ b/flytepropeller/pkg/controller/nodes/node_state_manager.go @@ -15,6 +15,7 @@ type nodeStateManager struct { b *handler.BranchNodeState d *handler.DynamicNodeState w *handler.WorkflowNodeState + g *handler.GateNodeState } func (n *nodeStateManager) PutTaskNodeState(s handler.TaskNodeState) error { @@ -37,6 +38,11 @@ func (n *nodeStateManager) PutWorkflowNodeState(s handler.WorkflowNodeState) err return nil } +func (n *nodeStateManager) PutGateNodeState(s handler.GateNodeState) error { + n.g = &s + return nil +} + func (n nodeStateManager) GetTaskNodeState() handler.TaskNodeState { tn := n.nodeStatus.GetTaskNodeStatus() if tn != nil { @@ -85,11 +91,21 @@ func (n nodeStateManager) GetWorkflowNodeState() handler.WorkflowNodeState { return ws } +func (n nodeStateManager) GetGateNodeState() handler.GateNodeState { + gn := n.nodeStatus.GetGateNodeStatus() + gs := handler.GateNodeState{} + if gn != nil { + gs.Phase = gn.GetGateNodePhase() + } + return gs +} + func (n *nodeStateManager) clearNodeStatus() { n.t = nil n.b = nil n.d = nil n.w = nil + n.g = nil n.nodeStatus.ClearLastAttemptStartedAt() } diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go index 17d9ee2d71..0599953f52 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go @@ -56,6 +56,10 @@ func (t workflowNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) panic("not implemented") } +func (t workflowNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { + panic("not implemented") +} + var wfExecID = &core.WorkflowExecutionIdentifier{ Project: "project", Domain: "domain", diff --git a/flytepropeller/pkg/controller/nodes/task/handler_test.go b/flytepropeller/pkg/controller/nodes/task/handler_test.go index 0b5bf6da7b..aaaad1dac9 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/task/handler_test.go @@ -380,6 +380,10 @@ func (t taskNodeStateHolder) PutDynamicNodeState(s handler.DynamicNodeState) err panic("not implemented") } +func (t taskNodeStateHolder) PutGateNodeState(s handler.GateNodeState) error { + panic("not implemented") +} + func CreateNoopResourceManager(ctx context.Context, scope promutils.Scope) resourcemanager.BaseResourceManager { rmBuilder, _ := resourcemanager.GetResourceManagerBuilderByType(ctx, rmConfig.TypeNoop, scope) rm, _ := rmBuilder.BuildResourceManager(ctx) diff --git a/flytepropeller/pkg/controller/nodes/transformers.go b/flytepropeller/pkg/controller/nodes/transformers.go index 6d2d1a95e5..76d1c9331c 100644 --- a/flytepropeller/pkg/controller/nodes/transformers.go +++ b/flytepropeller/pkg/controller/nodes/transformers.go @@ -254,4 +254,10 @@ func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n *nodeStateMa t.SetWorkflowNodePhase(n.w.Phase) t.SetExecutionError(n.w.Error) } + + // Update gate node status + if n.g != nil { + t := s.GetOrCreateGateNodeStatus() + t.SetGateNodePhase(n.g.Phase) + } } diff --git a/flytepropeller/pkg/controller/workflow/executor_test.go b/flytepropeller/pkg/controller/workflow/executor_test.go index 46c8be36b4..47cf83ebec 100644 --- a/flytepropeller/pkg/controller/workflow/executor_test.go +++ b/flytepropeller/pkg/controller/workflow/executor_test.go @@ -49,6 +49,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/nodes" + gatemocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/gate/mocks" recoveryMocks "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery/mocks" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" ) @@ -56,6 +57,7 @@ import ( var ( testScope = promutils.NewScope("test_wfexec") fakeKubeClient = mocks2.NewFakeKubeClient() + signalClient = &gatemocks.SignalServiceClient{} ) const ( @@ -243,8 +245,8 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() - nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, - adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) assert.NoError(t, err) @@ -323,8 +325,8 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() - nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, - adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) @@ -386,8 +388,8 @@ func BenchmarkWorkflowExecutor(b *testing.B) { assert.NoError(b, err) recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() - nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, - adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, scope) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, scope) assert.NoError(b, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) @@ -487,8 +489,8 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { assert.NoError(t, err) recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() - nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, - adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) assert.NoError(t, err) @@ -583,8 +585,8 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { assert.NoError(t, err) adminClient := launchplan.NewFailFastLaunchPlanExecutor() recoveryClient := &recoveryMocks.Client{} - nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, - adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) executor, err := NewExecutor(ctx, store, enqueueWorkflow, eventSink, recorder, "metadata", nodeExec, eventConfig, testClusterID, promutils.NewTestScope()) assert.NoError(t, err) @@ -641,8 +643,8 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() - nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, nodeEventSink, adminClient, - adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) + nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, nodeEventSink, adminClient, adminClient, + maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, signalClient, promutils.NewTestScope()) assert.NoError(t, err) t.Run("EventAlreadyInTerminalStateError", func(t *testing.T) {