diff --git a/Makefile b/Makefile index 0e653fa93..be00efaaf 100644 --- a/Makefile +++ b/Makefile @@ -3,17 +3,18 @@ include boilerplate/flyte/docker_build/Makefile include boilerplate/flyte/golang_test_targets/Makefile include boilerplate/flyte/end2end/Makefile - .PHONY: update_boilerplate update_boilerplate: @curl https://raw.githubusercontent.com/flyteorg/boilerplate/master/boilerplate/update.sh -o boilerplate/update.sh @boilerplate/update.sh .PHONY: linux_compile +linux_compile: export CGO_ENABLED ?= 0 +linux_compile: export GOOS ?= linux linux_compile: - GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/flytepropeller ./cmd/controller/main.go - GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/flytepropeller-manager ./cmd/manager/main.go - GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/kubectl-flyte ./cmd/kubectl-flyte/main.go + go build -o /artifacts/flytepropeller ./cmd/controller/main.go + go build -o /artifacts/flytepropeller-manager ./cmd/manager/main.go + go build -o /artifacts/kubectl-flyte ./cmd/kubectl-flyte/main.go .PHONY: compile compile: @@ -25,9 +26,9 @@ compile: cross_compile: @glide install @mkdir -p ./bin/cross - GOOS=linux GOARCH=amd64 go build -o bin/cross/flytepropeller ./cmd/controller/main.go - GOOS=linux GOARCH=amd64 go build -o bin/cross/flytepropeller-manager ./cmd/manager/main.go - GOOS=linux GOARCH=amd64 go build -o bin/cross/kubectl-flyte ./cmd/kubectl-flyte/main.go + go build -o bin/cross/flytepropeller ./cmd/controller/main.go + go build -o bin/cross/flytepropeller-manager ./cmd/manager/main.go + go build -o bin/cross/kubectl-flyte ./cmd/kubectl-flyte/main.go op_code_generate: @RESOURCE_NAME=flyteworkflow OPERATOR_PKG=github.com/flyteorg/flytepropeller ./hack/update-codegen.sh @@ -53,4 +54,3 @@ clean: golden: go test ./cmd/kubectl-flyte/cmd -update go test ./pkg/compiler/test -update - diff --git a/cmd/controller/cmd/webhook.go b/cmd/controller/cmd/webhook.go index d22da0f28..87e19448d 100644 --- a/cmd/controller/cmd/webhook.go +++ b/cmd/controller/cmd/webhook.go @@ -130,6 +130,14 @@ func runWebhook(origContext context.Context, propellerCfg *config.Config, cfg *w return err }) + g.Go(func() error { + err := controller.StartControllerManager(childCtx, mgr) + if err != nil { + logger.Fatalf(childCtx, "Failed to start controller manager. 
Error: %v", err) + } + return err + }) + g.Go(func() error { err := webhook.Run(childCtx, propellerCfg, cfg, defaultNamespace, &webhookScope, mgr) if err != nil { diff --git a/go.mod b/go.mod index fbfeb47c7..7a41b456a 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.13.0 github.com/flyteorg/flyteidl v1.3.14 - github.com/flyteorg/flyteplugins v1.0.40 + github.com/flyteorg/flyteplugins v1.0.49 github.com/flyteorg/flytestdlib v1.0.15 github.com/ghodss/yaml v1.0.0 github.com/go-redis/redis v6.15.7+incompatible diff --git a/pkg/compiler/transformers/k8s/node.go b/pkg/compiler/transformers/k8s/node.go index f8d5947c9..d9ac41dd4 100644 --- a/pkg/compiler/transformers/k8s/node.go +++ b/pkg/compiler/transformers/k8s/node.go @@ -45,8 +45,6 @@ func buildNodeSpec(n *core.Node, tasks []*core.CompiledTask, errs errors.Compile if n.GetTaskNode().Overrides != nil && n.GetTaskNode().Overrides.Resources != nil { resources = n.GetTaskNode().Overrides.Resources - } else { - resources = getResources(task) } } diff --git a/pkg/compiler/transformers/k8s/node_test.go b/pkg/compiler/transformers/k8s/node_test.go index 010eb0c3d..a9732d9d7 100644 --- a/pkg/compiler/transformers/k8s/node_test.go +++ b/pkg/compiler/transformers/k8s/node_test.go @@ -93,22 +93,6 @@ func TestBuildNodeSpec(t *testing.T) { mustBuild(t, n, 1, errs.NewScope()) }) - t.Run("Task with resources", func(t *testing.T) { - expectedCPU := resource.MustParse("10Mi") - n.Node.Target = &core.Node_TaskNode{ - TaskNode: &core.TaskNode{ - Reference: &core.TaskNode_ReferenceId{ - ReferenceId: &core.Identifier{Name: "ref_2"}, - }, - }, - } - - spec := mustBuild(t, n, 1, errs.NewScope()) - assert.NotNil(t, spec.Resources) - assert.NotNil(t, spec.Resources.Requests.Cpu()) - assert.Equal(t, expectedCPU.Value(), spec.Resources.Requests.Cpu().Value()) - }) - t.Run("node with resource overrides", func(t *testing.T) { expectedCPU := resource.MustParse("20Mi") n.Node.Target = &core.Node_TaskNode{ diff --git a/pkg/compiler/transformers/k8s/utils.go b/pkg/compiler/transformers/k8s/utils.go index 1cf9a521c..5f8a0f85a 100644 --- a/pkg/compiler/transformers/k8s/utils.go +++ b/pkg/compiler/transformers/k8s/utils.go @@ -48,18 +48,6 @@ func computeDeadline(n *core.Node) (*v1.Duration, error) { return deadline, nil } -func getResources(task *core.TaskTemplate) *core.Resources { - if task == nil { - return nil - } - - if task.GetContainer() == nil { - return nil - } - - return task.GetContainer().Resources -} - func toAliasValueArray(aliases []*core.Alias) []v1alpha1.Alias { if aliases == nil { return nil diff --git a/pkg/controller/executors/mocks/node_lookup.go b/pkg/controller/executors/mocks/node_lookup.go index 036a0400d..eac909a11 100644 --- a/pkg/controller/executors/mocks/node_lookup.go +++ b/pkg/controller/executors/mocks/node_lookup.go @@ -15,6 +15,47 @@ type NodeLookup struct { mock.Mock } +type NodeLookup_FromNode struct { + *mock.Call +} + +func (_m NodeLookup_FromNode) Return(_a0 []string, _a1 error) *NodeLookup_FromNode { + return &NodeLookup_FromNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeLookup) OnFromNode(id string) *NodeLookup_FromNode { + c_call := _m.On("FromNode", id) + return &NodeLookup_FromNode{Call: c_call} +} + +func (_m *NodeLookup) OnFromNodeMatch(matchers ...interface{}) *NodeLookup_FromNode { + c_call := _m.On("FromNode", matchers...) 
+ return &NodeLookup_FromNode{Call: c_call} +} + +// FromNode provides a mock function with given fields: id +func (_m *NodeLookup) FromNode(id string) ([]string, error) { + ret := _m.Called(id) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + type NodeLookup_GetNode struct { *mock.Call } @@ -89,3 +130,44 @@ func (_m *NodeLookup) GetNodeExecutionStatus(ctx context.Context, id string) v1a return r0 } + +type NodeLookup_ToNode struct { + *mock.Call +} + +func (_m NodeLookup_ToNode) Return(_a0 []string, _a1 error) *NodeLookup_ToNode { + return &NodeLookup_ToNode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *NodeLookup) OnToNode(id string) *NodeLookup_ToNode { + c_call := _m.On("ToNode", id) + return &NodeLookup_ToNode{Call: c_call} +} + +func (_m *NodeLookup) OnToNodeMatch(matchers ...interface{}) *NodeLookup_ToNode { + c_call := _m.On("ToNode", matchers...) + return &NodeLookup_ToNode{Call: c_call} +} + +// ToNode provides a mock function with given fields: id +func (_m *NodeLookup) ToNode(id string) ([]string, error) { + ret := _m.Called(id) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/controller/executors/node_lookup.go b/pkg/controller/executors/node_lookup.go index 9b49dc4ff..381b832c0 100644 --- a/pkg/controller/executors/node_lookup.go +++ b/pkg/controller/executors/node_lookup.go @@ -12,21 +12,27 @@ import ( type NodeLookup interface { GetNode(nodeID v1alpha1.NodeID) (v1alpha1.ExecutableNode, bool) GetNodeExecutionStatus(ctx context.Context, id v1alpha1.NodeID) v1alpha1.ExecutableNodeStatus + // Lookup for upstream edges, find all node ids from which this node can be reached. + ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) + // Lookup for downstream edges, find all node ids that can be reached from the given node id. + FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) } // Implements a contextual NodeLookup that can be composed of a disparate NodeGetter and a NodeStatusGetter type contextualNodeLookup struct { v1alpha1.NodeGetter v1alpha1.NodeStatusGetter + DAGStructure } // Returns a Contextual NodeLookup using the given NodeGetter and a separate NodeStatusGetter. // Very useful in Subworkflows where the Subworkflow is the reservoir of the nodes, but the status for these nodes // maybe stored int he Top-level workflow node itself. -func NewNodeLookup(n v1alpha1.NodeGetter, s v1alpha1.NodeStatusGetter) NodeLookup { +func NewNodeLookup(n v1alpha1.NodeGetter, s v1alpha1.NodeStatusGetter, d DAGStructure) NodeLookup { return contextualNodeLookup{ NodeGetter: n, NodeStatusGetter: s, + DAGStructure: d, } } @@ -45,6 +51,14 @@ func (s staticNodeLookup) GetNodeExecutionStatus(_ context.Context, id v1alpha1. return s.status[id] } +func (s staticNodeLookup) ToNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + return nil, nil +} + +func (s staticNodeLookup) FromNode(id v1alpha1.NodeID) ([]v1alpha1.NodeID, error) { + return nil, nil +} + // Returns a new NodeLookup useful in Testing. 
Not recommended to be used in production func NewTestNodeLookup(nodes map[v1alpha1.NodeID]v1alpha1.ExecutableNode, status map[v1alpha1.NodeID]v1alpha1.ExecutableNodeStatus) NodeLookup { return staticNodeLookup{ diff --git a/pkg/controller/executors/node_lookup_test.go b/pkg/controller/executors/node_lookup_test.go index a86b00b08..4bce76138 100644 --- a/pkg/controller/executors/node_lookup_test.go +++ b/pkg/controller/executors/node_lookup_test.go @@ -18,14 +18,20 @@ type nsg struct { v1alpha1.NodeStatusGetter } +type dag struct { + DAGStructure +} + func TestNewNodeLookup(t *testing.T) { n := ng{} ns := nsg{} - nl := NewNodeLookup(n, ns) + d := dag{} + nl := NewNodeLookup(n, ns, d) assert.NotNil(t, nl) typed := nl.(contextualNodeLookup) assert.Equal(t, n, typed.NodeGetter) assert.Equal(t, ns, typed.NodeStatusGetter) + assert.Equal(t, d, typed.DAGStructure) } func TestNewTestNodeLookup(t *testing.T) { diff --git a/pkg/controller/nodes/branch/handler.go b/pkg/controller/nodes/branch/handler.go index 9b0cd7f59..109290b90 100644 --- a/pkg/controller/nodes/branch/handler.go +++ b/pkg/controller/nodes/branch/handler.go @@ -136,7 +136,11 @@ func (b *branchHandler) recurseDownstream(ctx context.Context, nCtx handler.Node childNodeStatus := nl.GetNodeExecutionStatus(ctx, branchTakenNode.GetID()) childNodeStatus.SetDataDir(nodeStatus.GetDataDir()) childNodeStatus.SetOutputDir(nodeStatus.GetOutputDir()) - dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID()) + upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID()) + if err != nil { + return handler.UnknownTransition, err + } + dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), append(upstreamNodeIds, nCtx.NodeID())...) execContext, err := b.getExecutionContextForDownstream(nCtx) if err != nil { return handler.UnknownTransition, err @@ -196,7 +200,11 @@ func (b *branchHandler) Abort(ctx context.Context, nCtx handler.NodeExecutionCon // TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time // There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc // The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node. - dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID()) + upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID()) + if err != nil { + return err + } + dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), append(upstreamNodeIds, nCtx.NodeID())...) execContext, err := b.getExecutionContextForDownstream(nCtx) if err != nil { return err @@ -236,7 +244,11 @@ func (b *branchHandler) Finalize(ctx context.Context, nCtx handler.NodeExecution // TODO we should replace the call to RecursiveNodeHandler with a call to SingleNode Handler. The inputs are also already known ahead of time // There is no DAGStructure for the branch nodes, the branch taken node is the leaf node. The node itself may be arbitrarily complex, but in that case the node should reference a subworkflow etc // The parent of the BranchTaken Node is the actual Branch Node and all the data is just forwarded from the Branch to the executed node. 
- dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), nCtx.NodeID()) + upstreamNodeIds, err := nCtx.ContextualNodeLookup().ToNode(branchTakenNode.GetID()) + if err != nil { + return err + } + dag := executors.NewLeafNodeDAGStructure(branchTakenNode.GetID(), append(upstreamNodeIds, nCtx.NodeID())...) execContext, err := b.getExecutionContextForDownstream(nCtx) if err != nil { return err diff --git a/pkg/controller/nodes/branch/handler_test.go b/pkg/controller/nodes/branch/handler_test.go index bc39c1b24..5711de5d4 100644 --- a/pkg/controller/nodes/branch/handler_test.go +++ b/pkg/controller/nodes/branch/handler_test.go @@ -158,24 +158,34 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { isErr bool expectedPhase handler.EPhase childPhase v1alpha1.NodePhase - nl *execMocks.NodeLookup + upstreamNodeID string }{ + {"upstreamNodeExists", executors.NodeStatusPending, nil, + &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, "n2"}, {"childNodeError", executors.NodeStatusUndefined, fmt.Errorf("err"), - &mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, &execMocks.NodeLookup{}}, + &mocks2.ExecutableNodeStatus{}, bn, true, handler.EPhaseUndefined, v1alpha1.NodePhaseFailed, ""}, {"childPending", executors.NodeStatusPending, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, &execMocks.NodeLookup{}}, + &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseQueued, ""}, {"childStillRunning", executors.NodeStatusRunning, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, &execMocks.NodeLookup{}}, + &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseRunning, v1alpha1.NodePhaseRunning, ""}, {"childFailure", executors.NodeStatusFailed(expectedError), nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, &execMocks.NodeLookup{}}, + &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseFailed, v1alpha1.NodePhaseFailed, ""}, {"childComplete", executors.NodeStatusComplete, nil, - &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, &execMocks.NodeLookup{}}, + &mocks2.ExecutableNodeStatus{}, bn, false, handler.EPhaseSuccess, v1alpha1.NodePhaseSucceeded, ""}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { eCtx := &execMocks.ExecutionContext{} eCtx.OnGetParentInfo().Return(parentInfo{}) - nCtx, _ := createNodeContext(v1alpha1.BranchNodeNotYetEvaluated, &childNodeID, n, nil, test.nl, eCtx) + + mockNodeLookup := &execMocks.NodeLookup{} + if len(test.upstreamNodeID) > 0 { + mockNodeLookup.OnToNodeMatch(childNodeID).Return([]string{test.upstreamNodeID}, nil) + } else { + mockNodeLookup.OnToNodeMatch(childNodeID).Return(nil, nil) + } + + nCtx, _ := createNodeContext(v1alpha1.BranchNodeNotYetEvaluated, &childNodeID, n, nil, mockNodeLookup, eCtx) newParentInfo, _ := common.CreateParentInfo(parentInfo{}, nCtx.NodeID(), nCtx.CurrentAttempt()) expectedExecContext := executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo) mockNodeExecutor := &execMocks.Node{} @@ -187,23 +197,27 @@ func TestBranchHandler_RecurseDownstream(t *testing.T) { fList, err1 := d.FromNode("x") dList, err2 := d.ToNode(childNodeID) b := assert.NoError(t, err1) - b = b && assert.Equal(t, fList, []v1alpha1.NodeID{}) + b = b && assert.Equal(t, []v1alpha1.NodeID{}, fList) b = b && 
assert.NoError(t, err2) - b = b && assert.Equal(t, dList, []v1alpha1.NodeID{nodeID}) + dListExpected := []v1alpha1.NodeID{nodeID} + if len(test.upstreamNodeID) > 0 { + dListExpected = append([]string{test.upstreamNodeID}, dListExpected...) + } + b = b && assert.Equal(t, dListExpected, dList) return b } return false }), - mock.MatchedBy(func(lookup executors.NodeLookup) bool { return assert.Equal(t, lookup, test.nl) }), + mock.MatchedBy(func(lookup executors.NodeLookup) bool { return assert.Equal(t, lookup, mockNodeLookup) }), mock.MatchedBy(func(n v1alpha1.ExecutableNode) bool { return assert.Equal(t, n.GetID(), childNodeID) }), ).Return(test.ns, test.err) childNodeStatus := &mocks2.ExecutableNodeStatus{} - if test.nl != nil { + if mockNodeLookup != nil { childNodeStatus.OnGetOutputDir().Return("parent-output-dir") test.nodeStatus.OnGetDataDir().Return("parent-data-dir") test.nodeStatus.OnGetOutputDir().Return("parent-output-dir") - test.nl.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus) + mockNodeLookup.OnGetNodeExecutionStatus(ctx, childNodeID).Return(childNodeStatus) childNodeStatus.On("SetDataDir", storage.DataReference("parent-data-dir")).Once() childNodeStatus.On("SetOutputDir", storage.DataReference("parent-output-dir")).Once() } @@ -295,17 +309,18 @@ func TestBranchHandler_AbortNode(t *testing.T) { t.Run("BranchNodeSuccess", func(t *testing.T) { mockNodeExecutor := &execMocks.Node{} - nl := &execMocks.NodeLookup{} + mockNodeLookup := &execMocks.NodeLookup{} + mockNodeLookup.OnToNodeMatch(mock.Anything).Return(nil, nil) eCtx := &execMocks.ExecutionContext{} eCtx.OnGetParentInfo().Return(parentInfo{}) - nCtx, s := createNodeContext(v1alpha1.BranchNodeSuccess, &n1, n, nil, nl, eCtx) + nCtx, s := createNodeContext(v1alpha1.BranchNodeSuccess, &n1, n, nil, mockNodeLookup, eCtx) newParentInfo, _ := common.CreateParentInfo(parentInfo{}, nCtx.NodeID(), nCtx.CurrentAttempt()) expectedExecContext := executors.NewExecutionContextWithParentInfo(nCtx.ExecutionContext(), newParentInfo) mockNodeExecutor.OnAbortHandlerMatch(mock.Anything, mock.MatchedBy(func(e executors.ExecutionContext) bool { return assert.Equal(t, e, expectedExecContext) }), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - nl.OnGetNode(*s.s.FinalizedNodeID).Return(n, true) + mockNodeLookup.OnGetNode(*s.s.FinalizedNodeID).Return(n, true) branch := New(mockNodeExecutor, eventConfig, promutils.NewTestScope()) err := branch.Abort(ctx, nCtx, "") assert.NoError(t, err) diff --git a/pkg/controller/nodes/dynamic/dynamic_workflow.go b/pkg/controller/nodes/dynamic/dynamic_workflow.go index ef1fc51c3..eb891aa27 100644 --- a/pkg/controller/nodes/dynamic/dynamic_workflow.go +++ b/pkg/controller/nodes/dynamic/dynamic_workflow.go @@ -183,7 +183,7 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C subWorkflow: compiledWf, subWorkflowClosure: workflowCacheContents.CompiledWorkflow, execContext: executors.NewExecutionContext(nCtx.ExecutionContext(), compiledWf, compiledWf, newParentInfo, nCtx.ExecutionContext()), - nodeLookup: executors.NewNodeLookup(compiledWf, dynamicNodeStatus), + nodeLookup: executors.NewNodeLookup(compiledWf, dynamicNodeStatus, compiledWf), dynamicJobSpecURI: string(f.GetLoc()), }, nil } @@ -216,7 +216,7 @@ func (d dynamicNodeTaskNodeHandler) buildContextualDynamicWorkflow(ctx context.C subWorkflow: dynamicWf, subWorkflowClosure: closure, execContext: executors.NewExecutionContext(nCtx.ExecutionContext(), dynamicWf, dynamicWf, newParentInfo, 
nCtx.ExecutionContext()), - nodeLookup: executors.NewNodeLookup(dynamicWf, dynamicNodeStatus), + nodeLookup: executors.NewNodeLookup(dynamicWf, dynamicNodeStatus, dynamicWf), dynamicJobSpecURI: string(f.GetLoc()), }, nil } diff --git a/pkg/controller/nodes/executor.go b/pkg/controller/nodes/executor.go index 7e42e0ee3..c447b779c 100644 --- a/pkg/controller/nodes/executor.go +++ b/pkg/controller/nodes/executor.go @@ -216,7 +216,7 @@ func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx handler.NodeExe // A recoverable node execution should always be in a terminal phase switch recovered.Closure.Phase { case core.NodeExecution_SKIPPED: - return handler.PhaseInfoSkip(nil, "node execution recovery indicated original node was skipped"), nil + return handler.PhaseInfoUndefined, nil case core.NodeExecution_SUCCEEDED: fallthrough case core.NodeExecution_RECOVERED: @@ -524,6 +524,7 @@ func (c *nodeExecutor) finalize(ctx context.Context, h handler.Node, nCtx handle func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executors.DAGStructure, nCtx *nodeExecContext, _ handler.Node) (executors.NodeStatus, error) { logger.Debugf(ctx, "Node not yet started, running pre-execute") defer logger.Debugf(ctx, "Node pre-execute completed") + occurredAt := time.Now() p, err := c.preExecute(ctx, dag, nCtx) if err != nil { logger.Errorf(ctx, "failed preExecute for node. Error: %s", err.Error()) @@ -547,6 +548,7 @@ func (c *nodeExecutor) handleNotYetStartedNode(ctx context.Context, dag executor if np != nodeStatus.GetPhase() { // assert np == Queued! logger.Infof(ctx, "Change in node state detected from [%s] -> [%s]", nodeStatus.GetPhase().String(), np.String()) + p = p.WithOccuredAt(occurredAt) nev, err := ToNodeExecutionEvent(nCtx.NodeExecutionMetadata().GetNodeExecutionID(), p, nCtx.InputReader().GetInputPath().String(), nodeStatus, nCtx.ExecutionContext().GetEventVersion(), @@ -691,6 +693,7 @@ func (c *nodeExecutor) handleQueuedOrRunningNode(ctx context.Context, nCtx *node Message: err.Error(), }, }, + ReportedAt: ptypes.TimestampNow(), }) if err != nil { @@ -1152,6 +1155,7 @@ func (c *nodeExecutor) AbortHandler(ctx context.Context, execContext executors.E }, }, ProducerId: c.clusterID, + ReportedAt: ptypes.TimestampNow(), }) if err != nil && !eventsErr.IsNotFound(err) && !eventsErr.IsEventIncompatibleClusterError(err) { if errors2.IsCausedBy(err, errors.IllegalStateError) { diff --git a/pkg/controller/nodes/handler/state.go b/pkg/controller/nodes/handler/state.go index c34901eab..2ca4fb015 100644 --- a/pkg/controller/nodes/handler/state.go +++ b/pkg/controller/nodes/handler/state.go @@ -18,7 +18,6 @@ type TaskNodeState struct { PluginPhaseVersion uint32 PluginState []byte PluginStateVersion uint32 - BarrierClockTick uint32 LastPhaseUpdatedAt time.Time PreviousNodeExecutionCheckpointURI storage.DataReference CleanupOnFailure bool diff --git a/pkg/controller/nodes/handler/transition.go b/pkg/controller/nodes/handler/transition.go index 335076b47..8d145102d 100644 --- a/pkg/controller/nodes/handler/transition.go +++ b/pkg/controller/nodes/handler/transition.go @@ -4,6 +4,7 @@ type TransitionType int const ( TransitionTypeEphemeral TransitionType = iota + // @deprecated support for Barrier type transitions has been deprecated TransitionTypeBarrier ) diff --git a/pkg/controller/nodes/handler/transition_info.go b/pkg/controller/nodes/handler/transition_info.go index 400b00935..5d302f4fa 100644 --- a/pkg/controller/nodes/handler/transition_info.go +++ 
b/pkg/controller/nodes/handler/transition_info.go @@ -105,6 +105,16 @@ func (p PhaseInfo) WithInfo(i *ExecutionInfo) PhaseInfo { } } +func (p PhaseInfo) WithOccuredAt(t time.Time) PhaseInfo { + return PhaseInfo{ + p: p.p, + occurredAt: t, + err: p.err, + info: p.info, + reason: p.reason, + } +} + var PhaseInfoUndefined = PhaseInfo{p: EPhaseUndefined} func phaseInfo(p EPhase, err *core.ExecutionError, info *ExecutionInfo, reason string) PhaseInfo { diff --git a/pkg/controller/nodes/node_state_manager.go b/pkg/controller/nodes/node_state_manager.go index cd51c73bc..89347f79b 100644 --- a/pkg/controller/nodes/node_state_manager.go +++ b/pkg/controller/nodes/node_state_manager.go @@ -51,7 +51,6 @@ func (n nodeStateManager) GetTaskNodeState() handler.TaskNodeState { PluginPhaseVersion: tn.GetPhaseVersion(), PluginStateVersion: tn.GetPluginStateVersion(), PluginState: tn.GetPluginState(), - BarrierClockTick: tn.GetBarrierClockTick(), LastPhaseUpdatedAt: tn.GetLastPhaseUpdatedAt(), PreviousNodeExecutionCheckpointURI: tn.GetPreviousNodeExecutionCheckpointPath(), CleanupOnFailure: tn.GetCleanupOnFailure(), diff --git a/pkg/controller/nodes/subworkflow/launchplan/admin.go b/pkg/controller/nodes/subworkflow/launchplan/admin.go index 251bd207d..0329f3aef 100644 --- a/pkg/controller/nodes/subworkflow/launchplan/admin.go +++ b/pkg/controller/nodes/subworkflow/launchplan/admin.go @@ -2,14 +2,17 @@ package launchplan import ( "context" + "errors" "fmt" "time" + evtErr "github.com/flyteorg/flytepropeller/events/errors" + "github.com/flyteorg/flytestdlib/cache" "golang.org/x/time/rate" "k8s.io/client-go/util/workqueue" - "github.com/flyteorg/flytestdlib/errors" + stdErr "github.com/flyteorg/flytestdlib/errors" "github.com/flyteorg/flytestdlib/logger" @@ -61,11 +64,11 @@ func (a *adminLaunchPlanExecutor) handleLaunchError(ctx context.Context, isRecov logger.Errorf(ctx, "Failed to add ExecID [%v] to auto refresh cache", executionID) } - return errors.Wrapf(RemoteErrorAlreadyExists, err, "ExecID %s already exists", executionID.Name) + return stdErr.Wrapf(RemoteErrorAlreadyExists, err, "ExecID %s already exists", executionID.Name) case codes.DataLoss, codes.DeadlineExceeded, codes.Internal, codes.Unknown, codes.Canceled: - return errors.Wrapf(RemoteErrorSystem, err, "failed to launch workflow [%s], system error", launchPlanRef.Name) + return stdErr.Wrapf(RemoteErrorSystem, err, "failed to launch workflow [%s], system error", launchPlanRef.Name) default: - return errors.Wrapf(RemoteErrorUser, err, "failed to launch workflow") + return stdErr.Wrapf(RemoteErrorUser, err, "failed to launch workflow") } } @@ -154,9 +157,9 @@ func (a *adminLaunchPlanExecutor) GetLaunchPlan(ctx context.Context, launchPlanR lp, err := a.adminClient.GetLaunchPlan(ctx, &getObjectRequest) if err != nil { if status.Code(err) == codes.NotFound { - return nil, errors.Wrapf(RemoteErrorNotFound, err, "No launch plan retrieved from Admin") + return nil, stdErr.Wrapf(RemoteErrorNotFound, err, "No launch plan retrieved from Admin") } - return nil, errors.Wrapf(RemoteErrorSystem, err, "Could not fetch launch plan definition from Admin") + return nil, stdErr.Wrapf(RemoteErrorSystem, err, "Could not fetch launch plan definition from Admin") } return lp, nil @@ -172,7 +175,16 @@ func (a *adminLaunchPlanExecutor) Kill(ctx context.Context, executionID *core.Wo if status.Code(err) == codes.NotFound { return nil } - return errors.Wrapf(RemoteErrorSystem, err, "system error") + + err = evtErr.WrapError(err) + eventErr := &evtErr.EventError{} + if 
errors.As(err, eventErr) { + if eventErr.Code == evtErr.EventAlreadyInTerminalStateError { + return nil + } + } + + return stdErr.Wrapf(RemoteErrorSystem, err, "system error") } return nil } @@ -207,12 +219,12 @@ func (a *adminLaunchPlanExecutor) syncItem(ctx context.Context, batch cache.Batc res, err := a.adminClient.GetExecution(ctx, req) if err != nil { - // TODO: Define which error codes are system errors (and return the error) vs user errors. + // TODO: Define which error codes are system errors (and return the error) vs user stdErr. if status.Code(err) == codes.NotFound { - err = errors.Wrapf(RemoteErrorNotFound, err, "execID [%s] not found on remote", exec.WorkflowExecutionIdentifier.Name) + err = stdErr.Wrapf(RemoteErrorNotFound, err, "execID [%s] not found on remote", exec.WorkflowExecutionIdentifier.Name) } else { - err = errors.Wrapf(RemoteErrorSystem, err, "system error") + err = stdErr.Wrapf(RemoteErrorSystem, err, "system error") } resp = append(resp, cache.ItemSyncResponse{ diff --git a/pkg/controller/nodes/subworkflow/subworkflow.go b/pkg/controller/nodes/subworkflow/subworkflow.go index 24d74473f..74beeaf79 100644 --- a/pkg/controller/nodes/subworkflow/subworkflow.go +++ b/pkg/controller/nodes/subworkflow/subworkflow.go @@ -209,7 +209,7 @@ func (s *subworkflowHandler) HandleFailingSubWorkflow(ctx context.Context, nCtx } status := nCtx.NodeStatus() - nodeLookup := executors.NewNodeLookup(subWorkflow, status) + nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow) return s.HandleFailureNodeOfSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } @@ -220,7 +220,7 @@ func (s *subworkflowHandler) StartSubWorkflow(ctx context.Context, nCtx handler. } status := nCtx.NodeStatus() - nodeLookup := executors.NewNodeLookup(subWorkflow, status) + nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow) // assert startStatus.IsComplete() == true return s.startAndHandleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) @@ -233,7 +233,7 @@ func (s *subworkflowHandler) CheckSubWorkflowStatus(ctx context.Context, nCtx ha } status := nCtx.NodeStatus() - nodeLookup := executors.NewNodeLookup(subWorkflow, status) + nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow) return s.handleSubWorkflow(ctx, nCtx, subWorkflow, nodeLookup) } @@ -243,7 +243,7 @@ func (s *subworkflowHandler) HandleAbort(ctx context.Context, nCtx handler.NodeE return err } status := nCtx.NodeStatus() - nodeLookup := executors.NewNodeLookup(subWorkflow, status) + nodeLookup := executors.NewNodeLookup(subWorkflow, status, subWorkflow) execContext, err := s.getExecutionContextForDownstream(nCtx) if err != nil { return err diff --git a/pkg/controller/nodes/task/barrier.go b/pkg/controller/nodes/task/barrier.go deleted file mode 100644 index 0b0f84b6e..000000000 --- a/pkg/controller/nodes/task/barrier.go +++ /dev/null @@ -1,61 +0,0 @@ -package task - -import ( - "context" - "time" - - "github.com/flyteorg/flytestdlib/logger" - "k8s.io/apimachinery/pkg/util/cache" - - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/task/config" -) - -type BarrierKey = string - -type PluginCallLog struct { - PluginTransition *pluginRequestedTransition -} - -type BarrierTransition struct { - BarrierClockTick uint32 - CallLog PluginCallLog -} - -var NoBarrierTransition = BarrierTransition{BarrierClockTick: 0} - -type barrier struct { - barrierCacheExpiration time.Duration - barrierTransitions *cache.LRUExpireCache - barrierEnabled bool -} - -func (b *barrier) RecordBarrierTransition(ctx 
context.Context, k BarrierKey, bt BarrierTransition) { - if b.barrierEnabled { - b.barrierTransitions.Add(k, bt, b.barrierCacheExpiration) - } -} - -func (b *barrier) GetPreviousBarrierTransition(ctx context.Context, k BarrierKey) BarrierTransition { - if b.barrierEnabled { - if v, ok := b.barrierTransitions.Get(k); ok { - f, casted := v.(BarrierTransition) - if !casted { - logger.Errorf(ctx, "Failed to cast recorded value to BarrierTransition") - return NoBarrierTransition - } - return f - } - } - return NoBarrierTransition -} - -func newLRUBarrier(_ context.Context, cfg config.BarrierConfig) *barrier { - b := &barrier{ - barrierEnabled: cfg.Enabled, - } - if cfg.Enabled { - b.barrierCacheExpiration = cfg.CacheTTL.Duration - b.barrierTransitions = cache.NewLRUExpireCache(cfg.CacheSize) - } - return b -} diff --git a/pkg/controller/nodes/task/config/config.go b/pkg/controller/nodes/task/config/config.go index 4bc2937c5..020795675 100644 --- a/pkg/controller/nodes/task/config/config.go +++ b/pkg/controller/nodes/task/config/config.go @@ -20,11 +20,6 @@ var ( defaultConfig = &Config{ TaskPlugins: TaskPluginConfig{EnabledPlugins: []string{}, DefaultForTaskTypes: map[string]string{}}, MaxPluginPhaseVersions: 100000, - BarrierConfig: BarrierConfig{ - Enabled: true, - CacheSize: 10000, - CacheTTL: config.Duration{Duration: time.Minute * 30}, - }, BackOffConfig: BackOffConfig{ BaseSecond: 2, MaxDuration: config.Duration{Duration: time.Second * 20}, @@ -37,17 +32,10 @@ var ( type Config struct { TaskPlugins TaskPluginConfig `json:"task-plugins" pflag:",Task plugin configuration"` MaxPluginPhaseVersions int32 `json:"max-plugin-phase-versions" pflag:",Maximum number of plugin phase versions allowed for one phase."` - BarrierConfig BarrierConfig `json:"barrier" pflag:",Config for Barrier implementation"` BackOffConfig BackOffConfig `json:"backoff" pflag:",Config for Exponential BackOff implementation"` MaxErrorMessageLength int `json:"maxLogMessageLength" pflag:",Deprecated!!! Max length of error message."` } -type BarrierConfig struct { - Enabled bool `json:"enabled" pflag:",Enable Barrier transitions using inmemory context"` - CacheSize int `json:"cache-size" pflag:",Max number of barrier to preserve in memory"` - CacheTTL config.Duration `json:"cache-ttl" pflag:", Max duration that a barrier would be respected if the process is not restarted. This should account for time required to store the record into persistent storage (across multiple rounds."` -} - type TaskPluginConfig struct { EnabledPlugins []string `json:"enabled-plugins" pflag:",Plugins enabled currently"` // Maps task types to their plugin handler (by ID). 
diff --git a/pkg/controller/nodes/task/config/config_flags.go b/pkg/controller/nodes/task/config/config_flags.go index a77a6f58e..540d0214d 100755 --- a/pkg/controller/nodes/task/config/config_flags.go +++ b/pkg/controller/nodes/task/config/config_flags.go @@ -52,9 +52,6 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "task-plugins.enabled-plugins"), defaultConfig.TaskPlugins.EnabledPlugins, "Plugins enabled currently") cmdFlags.Int32(fmt.Sprintf("%v%v", prefix, "max-plugin-phase-versions"), defaultConfig.MaxPluginPhaseVersions, "Maximum number of plugin phase versions allowed for one phase.") - cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "barrier.enabled"), defaultConfig.BarrierConfig.Enabled, "Enable Barrier transitions using inmemory context") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "barrier.cache-size"), defaultConfig.BarrierConfig.CacheSize, "Max number of barrier to preserve in memory") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "barrier.cache-ttl"), defaultConfig.BarrierConfig.CacheTTL.String(), " Max duration that a barrier would be respected if the process is not restarted. This should account for time required to store the record into persistent storage (across multiple rounds.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "backoff.base-second"), defaultConfig.BackOffConfig.BaseSecond, "The number of seconds representing the base duration of the exponential backoff") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "backoff.max-duration"), defaultConfig.BackOffConfig.MaxDuration.String(), "The cap of the backoff duration") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "maxLogMessageLength"), defaultConfig.MaxErrorMessageLength, "Deprecated!!! 
Max length of error message.") diff --git a/pkg/controller/nodes/task/config/config_flags_test.go b/pkg/controller/nodes/task/config/config_flags_test.go index ef4f327d6..cc2f02534 100755 --- a/pkg/controller/nodes/task/config/config_flags_test.go +++ b/pkg/controller/nodes/task/config/config_flags_test.go @@ -127,48 +127,6 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_barrier.enabled", func(t *testing.T) { - - t.Run("Override", func(t *testing.T) { - testValue := "1" - - cmdFlags.Set("barrier.enabled", testValue) - if vBool, err := cmdFlags.GetBool("barrier.enabled"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.BarrierConfig.Enabled) - - } else { - assert.FailNow(t, err.Error()) - } - }) - }) - t.Run("Test_barrier.cache-size", func(t *testing.T) { - - t.Run("Override", func(t *testing.T) { - testValue := "1" - - cmdFlags.Set("barrier.cache-size", testValue) - if vInt, err := cmdFlags.GetInt("barrier.cache-size"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.BarrierConfig.CacheSize) - - } else { - assert.FailNow(t, err.Error()) - } - }) - }) - t.Run("Test_barrier.cache-ttl", func(t *testing.T) { - - t.Run("Override", func(t *testing.T) { - testValue := defaultConfig.BarrierConfig.CacheTTL.String() - - cmdFlags.Set("barrier.cache-ttl", testValue) - if vString, err := cmdFlags.GetString("barrier.cache-ttl"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.BarrierConfig.CacheTTL) - - } else { - assert.FailNow(t, err.Error()) - } - }) - }) t.Run("Test_backoff.base-second", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index b95cdcb9b..d7300818b 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -207,7 +207,6 @@ type Handler struct { kubeClient pluginCore.KubeClient secretManager pluginCore.SecretManager resourceManager resourcemanager.BaseResourceManager - barrierCache *barrier cfg *config.Config pluginScope promutils.Scope eventConfig *controllerConfig.EventConfig @@ -658,47 +657,19 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) } } - barrierTick := uint32(0) + occurredAt := time.Now() // STEP 2: If no cache-hit and not transitioning to PhaseWaitingForCache, then lets invoke the plugin and wait for a transition out of undefined if pluginTrns.execInfo.TaskNodeInfo == nil || (pluginTrns.pInfo.Phase() != pluginCore.PhaseWaitingForCache && pluginTrns.execInfo.TaskNodeInfo.TaskNodeMetadata.CacheStatus != core.CatalogCacheStatus_CACHE_HIT) { - prevBarrier := t.barrierCache.GetPreviousBarrierTransition(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) - // Lets start with the current barrierTick (the value to be stored) same as the barrierTick in the cache - barrierTick = prevBarrier.BarrierClockTick - // Lets check if this value in cache is less than or equal to one in the store - if barrierTick <= ts.BarrierClockTick { - var err error - pluginTrns, err = t.invokePlugin(ctx, p, tCtx, ts) - if err != nil { - return handler.UnknownTransition, errors.Wrapf(errors.RuntimeExecutionError, nCtx.NodeID(), err, "failed during plugin execution") - } - if pluginTrns.IsPreviouslyObserved() { - logger.Debugf(ctx, "No state change for Task, previously observed same transition. 
Short circuiting.") - return pluginTrns.FinalTransition(ctx) - } - // Now no matter what we should update the barrierTick (stored in state) - // This is because the state is ahead of the inmemory representation - // This can happen in the case where the process restarted or the barrier cache got reset - barrierTick = ts.BarrierClockTick - // Now if the transition is of type barrier, lets tick the clock by one from the prev known value - // store that in the cache - if pluginTrns.ttype == handler.TransitionTypeBarrier { - logger.Infof(ctx, "Barrier transition observed for Plugin [%s], TaskExecID [%s]. recording: [%s]", p.GetID(), tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), pluginTrns.pInfo.String()) - barrierTick = barrierTick + 1 - t.barrierCache.RecordBarrierTransition(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), BarrierTransition{ - BarrierClockTick: barrierTick, - CallLog: PluginCallLog{ - PluginTransition: pluginTrns, - }, - }) - } - } else { - // Barrier tick will remain to be the one in cache. - // Now it may happen that the cache may get reset before we store the barrier tick - // this will cause us to lose that information and potentially replaying. - logger.Infof(ctx, "Replaying Barrier transition for cache tick [%d] < stored tick [%d], Plugin [%s], TaskExecID [%s]. recording: [%s]", barrierTick, ts.BarrierClockTick, p.GetID(), tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), prevBarrier.CallLog.PluginTransition.pInfo.String()) - pluginTrns = prevBarrier.CallLog.PluginTransition + var err error + pluginTrns, err = t.invokePlugin(ctx, p, tCtx, ts) + if err != nil { + return handler.UnknownTransition, errors.Wrapf(errors.RuntimeExecutionError, nCtx.NodeID(), err, "failed during plugin execution") + } + if pluginTrns.IsPreviouslyObserved() { + logger.Debugf(ctx, "No state change for Task, previously observed same transition. Short circuiting.") + return pluginTrns.FinalTransition(ctx) } } @@ -724,6 +695,7 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) PluginID: p.GetID(), ResourcePoolInfo: tCtx.rm.GetResourcePoolInfo(), ClusterID: t.clusterID, + OccurredAt: occurredAt, }) if err != nil { return handler.UnknownTransition, err @@ -750,6 +722,7 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) PluginID: p.GetID(), ResourcePoolInfo: tCtx.rm.GetResourcePoolInfo(), ClusterID: t.clusterID, + OccurredAt: occurredAt, }) if err != nil { logger.Errorf(ctx, "failed to convert plugin transition to TaskExecutionEvent. 
Error: %s", err.Error()) @@ -772,7 +745,6 @@ func (t Handler) Handle(ctx context.Context, nCtx handler.NodeExecutionContext) PluginStateVersion: pluginTrns.pluginStateVersion, PluginPhase: pluginTrns.pInfo.Phase(), PluginPhaseVersion: pluginTrns.pInfo.Version(), - BarrierClockTick: barrierTick, LastPhaseUpdatedAt: time.Now(), PreviousNodeExecutionCheckpointURI: ts.PreviousNodeExecutionCheckpointURI, CleanupOnFailure: ts.CleanupOnFailure || pluginTrns.pInfo.CleanupOnFailure(), @@ -931,7 +903,6 @@ func New(ctx context.Context, kubeClient executors.Client, client catalog.Client asyncCatalog: async, resourceManager: nil, secretManager: secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig()), - barrierCache: newLRUBarrier(ctx, cfg.BarrierConfig), cfg: cfg, eventConfig: eventConfig, clusterID: clusterID, diff --git a/pkg/controller/nodes/task/handler_test.go b/pkg/controller/nodes/task/handler_test.go index c1aabf0f9..6da4762cb 100644 --- a/pkg/controller/nodes/task/handler_test.go +++ b/pkg/controller/nodes/task/handler_test.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "testing" - "time" "github.com/flyteorg/flyteidl/clients/go/coreutils" "github.com/golang/protobuf/proto" @@ -699,11 +698,8 @@ func Test_task_Handle_NoCatalog(t *testing.T) { defaultPlugins: map[pluginCore.TaskType]pluginCore.Plugin{ "test": fakeplugins.NewPhaseBasedPlugin(), }, - pluginScope: promutils.NewTestScope(), - catalog: c, - barrierCache: newLRUBarrier(context.TODO(), config.BarrierConfig{ - Enabled: false, - }), + pluginScope: promutils.NewTestScope(), + catalog: c, resourceManager: noopRm, taskMetricsMap: make(map[MetricKey]*taskMetrics), eventConfig: eventConfig, @@ -1272,310 +1268,6 @@ func Test_task_Handle_Reservation(t *testing.T) { } } -func Test_task_Handle_Barrier(t *testing.T) { - // NOTE: Caching is disabled for this test - - createNodeContext := func(recorder events.TaskEventRecorder, ttype string, s *taskNodeStateHolder, prevBarrierClockTick uint32) *nodeMocks.NodeExecutionContext { - wfExecID := &core.WorkflowExecutionIdentifier{ - Project: "project", - Domain: "domain", - Name: "name", - } - - nodeID := "n1" - - nm := &nodeMocks.NodeExecutionMetadata{} - nm.OnGetAnnotations().Return(map[string]string{}) - nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ - NodeId: nodeID, - ExecutionId: wfExecID, - }) - nm.OnGetK8sServiceAccount().Return("service-account") - nm.OnGetLabels().Return(map[string]string{}) - nm.OnGetNamespace().Return("namespace") - nm.OnGetOwnerID().Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) - nm.OnGetOwnerReference().Return(v12.OwnerReference{ - Kind: "sample", - Name: "name", - }) - nm.OnIsInterruptible().Return(true) - - taskID := &core.Identifier{} - tk := &core.TaskTemplate{ - Id: taskID, - Type: "test", - Metadata: &core.TaskMetadata{ - Discoverable: false, - }, - Interface: &core.TypedInterface{ - Outputs: &core.VariableMap{ - Variables: map[string]*core.Variable{ - "x": { - Type: &core.LiteralType{ - Type: &core.LiteralType_Simple{ - Simple: core.SimpleType_BOOLEAN, - }, - }, - }, - }, - }, - }, - } - tr := &nodeMocks.TaskReader{} - tr.OnGetTaskID().Return(taskID) - tr.OnGetTaskType().Return(ttype) - tr.OnReadMatch(mock.Anything).Return(tk, nil) - - ns := &flyteMocks.ExecutableNodeStatus{} - ns.OnGetDataDir().Return(storage.DataReference("data-dir")) - ns.OnGetOutputDir().Return(storage.DataReference("output-dir")) - - res := &v1.ResourceRequirements{} - n := &flyteMocks.ExecutableNode{} - ma := 5 - 
n.OnGetRetryStrategy().Return(&v1alpha1.RetryStrategy{MinAttempts: &ma}) - n.OnGetResources().Return(res) - - ir := &ioMocks.InputReader{} - ir.OnGetInputPath().Return(storage.DataReference("input")) - ir.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) - nCtx := &nodeMocks.NodeExecutionContext{} - nCtx.OnNodeExecutionMetadata().Return(nm) - nCtx.OnNode().Return(n) - nCtx.OnInputReader().Return(ir) - ds, err := storage.NewDataStore( - &storage.Config{ - Type: storage.TypeMemory, - }, - promutils.NewTestScope(), - ) - assert.NoError(t, err) - nCtx.OnDataStore().Return(ds) - nCtx.OnCurrentAttempt().Return(uint32(1)) - nCtx.OnTaskReader().Return(tr) - nCtx.OnMaxDatasetSizeBytes().Return(int64(1)) - nCtx.OnNodeStatus().Return(ns) - nCtx.OnNodeID().Return("n1") - nCtx.OnEventsRecorder().Return(recorder) - nCtx.OnEnqueueOwnerFunc().Return(nil) - - executionContext := &mocks.ExecutionContext{} - executionContext.OnGetExecutionConfig().Return(v1alpha1.ExecutionConfig{}) - executionContext.OnGetEventVersion().Return(v1alpha1.EventVersion0) - executionContext.OnGetParentInfo().Return(nil) - executionContext.OnIncrementParallelism().Return(1) - nCtx.OnExecutionContext().Return(executionContext) - - nCtx.OnRawOutputPrefix().Return("s3://sandbox/") - nCtx.OnOutputShardSelector().Return(ioutils.NewConstantShardSelector([]string{"x"})) - - st := bytes.NewBuffer([]byte{}) - cod := codex.GobStateCodec{} - assert.NoError(t, cod.Encode(&fakeplugins.NextPhaseState{ - Phase: pluginCore.PhaseSuccess, - OutputExists: true, - }, st)) - nr := &nodeMocks.NodeStateReader{} - nr.OnGetTaskNodeState().Return(handler.TaskNodeState{ - PluginState: st.Bytes(), - BarrierClockTick: prevBarrierClockTick, - }) - nCtx.OnNodeStateReader().Return(nr) - nCtx.OnNodeStateWriter().Return(s) - return nCtx - } - - noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope()) - - trns := pluginCore.DoTransitionType(pluginCore.TransitionTypeBarrier, pluginCore.PhaseInfoQueued(time.Now(), 1, "z")) - type args struct { - prevTick uint32 - btrnsTick uint32 - bTrns *pluginCore.Transition - res []fakeplugins.HandleResponse - } - type wantBarrier struct { - hit bool - tick uint32 - } - type want struct { - wantBarrer wantBarrier - handlerPhase handler.EPhase - wantErr bool - eventPhase core.TaskExecution_Phase - pluginPhase pluginCore.Phase - pluginVer uint32 - } - tests := []struct { - name string - args args - want want - }{ - { - "ephemeral-trns", - args{ - res: []fakeplugins.HandleResponse{ - {T: pluginCore.DoTransitionType(pluginCore.TransitionTypeEphemeral, pluginCore.PhaseInfoRunning(1, &pluginCore.TaskInfo{}))}, - }, - }, - want{ - handlerPhase: handler.EPhaseRunning, - eventPhase: core.TaskExecution_RUNNING, - pluginPhase: pluginCore.PhaseRunning, - pluginVer: 1, - }, - }, - { - "first-barrier-trns", - args{ - res: []fakeplugins.HandleResponse{ - {T: pluginCore.DoTransitionType(pluginCore.TransitionTypeBarrier, pluginCore.PhaseInfoRunning(1, &pluginCore.TaskInfo{}))}, - }, - }, - want{ - wantBarrer: wantBarrier{ - hit: true, - tick: 1, - }, - handlerPhase: handler.EPhaseRunning, - eventPhase: core.TaskExecution_RUNNING, - pluginPhase: pluginCore.PhaseRunning, - pluginVer: 1, - }, - }, - { - "barrier-trns-replay", - args{ - prevTick: 0, - btrnsTick: 1, - bTrns: &trns, - }, - want{ - wantBarrer: wantBarrier{ - hit: true, - tick: 1, - }, - handlerPhase: handler.EPhaseRunning, - eventPhase: core.TaskExecution_QUEUED, - pluginPhase: pluginCore.PhaseQueued, - pluginVer: 1, - }, - }, - { - "barrier-trns-next", - 
args{ - prevTick: 1, - btrnsTick: 1, - bTrns: &trns, - res: []fakeplugins.HandleResponse{ - {T: pluginCore.DoTransitionType(pluginCore.TransitionTypeBarrier, pluginCore.PhaseInfoRunning(1, &pluginCore.TaskInfo{}))}, - }, - }, - want{ - wantBarrer: wantBarrier{ - hit: true, - tick: 2, - }, - handlerPhase: handler.EPhaseRunning, - eventPhase: core.TaskExecution_RUNNING, - pluginPhase: pluginCore.PhaseRunning, - pluginVer: 1, - }, - }, - { - "barrier-trns-restart-case", - args{ - prevTick: 2, - res: []fakeplugins.HandleResponse{ - {T: pluginCore.DoTransitionType(pluginCore.TransitionTypeBarrier, pluginCore.PhaseInfoRunning(1, &pluginCore.TaskInfo{}))}, - }, - }, - want{ - wantBarrer: wantBarrier{ - hit: true, - tick: 3, - }, - handlerPhase: handler.EPhaseRunning, - eventPhase: core.TaskExecution_RUNNING, - pluginPhase: pluginCore.PhaseRunning, - pluginVer: 1, - }, - }, - { - "barrier-trns-restart-case-ephemeral", - args{ - prevTick: 2, - res: []fakeplugins.HandleResponse{ - {T: pluginCore.DoTransitionType(pluginCore.TransitionTypeEphemeral, pluginCore.PhaseInfoRunning(1, &pluginCore.TaskInfo{}))}, - }, - }, - want{ - wantBarrer: wantBarrier{ - hit: false, - }, - handlerPhase: handler.EPhaseRunning, - eventPhase: core.TaskExecution_RUNNING, - pluginPhase: pluginCore.PhaseRunning, - pluginVer: 1, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - state := &taskNodeStateHolder{} - ev := &fakeBufferedTaskEventRecorder{} - nCtx := createNodeContext(ev, "test", state, tt.args.prevTick) - c := &pluginCatalogMocks.Client{} - - tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), c, eventConfig, testClusterID, promutils.NewTestScope()) - assert.NoError(t, err) - tk.resourceManager = noopRm - - p := &pluginCoreMocks.Plugin{} - p.On("GetID").Return("plugin1") - p.OnGetProperties().Return(pluginCore.PluginProperties{}) - tctx, err := tk.newTaskExecutionContext(context.TODO(), nCtx, p) - assert.NoError(t, err) - id := tctx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - - if tt.args.bTrns != nil { - x := &pluginRequestedTransition{} - x.ObservedTransitionAndState(*tt.args.bTrns, 0, nil) - tk.barrierCache.RecordBarrierTransition(context.TODO(), id, BarrierTransition{tt.args.btrnsTick, PluginCallLog{x}}) - } - - tk.defaultPlugins = map[pluginCore.TaskType]pluginCore.Plugin{ - "test": fakeplugins.NewReplayer("test", pluginCore.PluginProperties{}, - tt.args.res, nil, nil), - } - - got, err := tk.Handle(context.TODO(), nCtx) - if (err != nil) != tt.want.wantErr { - t.Errorf("Handler.Handle() error = %v, wantErr %v", err, tt.want.wantErr) - return - } - if err == nil { - assert.Equal(t, tt.want.handlerPhase.String(), got.Info().GetPhase().String()) - if assert.Equal(t, 1, len(ev.evs)) { - e := ev.evs[0] - assert.Equal(t, tt.want.eventPhase.String(), e.Phase.String()) - } - assert.Equal(t, tt.want.pluginPhase.String(), state.s.PluginPhase.String()) - assert.Equal(t, tt.want.pluginVer, state.s.PluginPhaseVersion) - if tt.want.wantBarrer.hit { - assert.Len(t, tk.barrierCache.barrierTransitions.Keys(), 1) - bt := tk.barrierCache.GetPreviousBarrierTransition(context.TODO(), id) - assert.Equal(t, bt.BarrierClockTick, tt.want.wantBarrer.tick) - assert.Equal(t, tt.want.wantBarrer.tick, state.s.BarrierClockTick) - } else { - assert.Len(t, tk.barrierCache.barrierTransitions.Keys(), 0) - assert.Equal(t, tt.args.prevTick, state.s.BarrierClockTick) - } - } - }) - } -} - func Test_task_Abort(t *testing.T) { createNodeCtx := func(ev events.TaskEventRecorder) 
*nodeMocks.NodeExecutionContext { wfExecID := &core.WorkflowExecutionIdentifier{ diff --git a/pkg/controller/nodes/task/k8s/plugin_context.go b/pkg/controller/nodes/task/k8s/plugin_context.go index cb90edfb3..aed5bc468 100644 --- a/pkg/controller/nodes/task/k8s/plugin_context.go +++ b/pkg/controller/nodes/task/k8s/plugin_context.go @@ -2,6 +2,7 @@ package k8s import ( "context" + "fmt" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" @@ -15,7 +16,8 @@ var _ k8s.PluginContext = &pluginContext{} type pluginContext struct { pluginsCore.TaskExecutionContext // Lazily creates a buffered outputWriter, overriding the input outputWriter. - ow *ioutils.BufferedOutputWriter + ow *ioutils.BufferedOutputWriter + k8sPluginState *k8s.PluginState } // Provides an output sync of type io.OutputWriter @@ -26,9 +28,38 @@ func (p *pluginContext) OutputWriter() io.OutputWriter { return buf } -func newPluginContext(tCtx pluginsCore.TaskExecutionContext) *pluginContext { +// pluginStateReader overrides the default PluginStateReader to return a pre-assigned PluginState. This allows us to +// encapsulate plugin state persistence in the existing k8s PluginManager and only expose the ability to read the +// previous Phase, PhaseVersion, and Reason for all k8s plugins. +type pluginStateReader struct { + k8sPluginState *k8s.PluginState +} + +func (p pluginStateReader) GetStateVersion() uint8 { + return 0 +} + +func (p pluginStateReader) Get(t interface{}) (stateVersion uint8, err error) { + if pointer, ok := t.(*k8s.PluginState); ok { + *pointer = *p.k8sPluginState + } else { + return 0, fmt.Errorf("unexpected type when reading plugin state") + } + + return 0, nil +} + +// PluginStateReader overrides the default behavior to return our k8s plugin specific reader. 
+func (p *pluginContext) PluginStateReader() pluginsCore.PluginStateReader { + return pluginStateReader{ + k8sPluginState: p.k8sPluginState, + } +} + +func newPluginContext(tCtx pluginsCore.TaskExecutionContext, k8sPluginState *k8s.PluginState) *pluginContext { return &pluginContext{ TaskExecutionContext: tCtx, ow: nil, + k8sPluginState: k8sPluginState, } } diff --git a/pkg/controller/nodes/task/k8s/plugin_manager.go b/pkg/controller/nodes/task/k8s/plugin_manager.go index e0d786858..67b0356a3 100644 --- a/pkg/controller/nodes/task/k8s/plugin_manager.go +++ b/pkg/controller/nodes/task/k8s/plugin_manager.go @@ -59,7 +59,8 @@ const ( ) type PluginState struct { - Phase PluginPhase + Phase PluginPhase + K8sPluginState k8s.PluginState } type PluginMetrics struct { @@ -247,7 +248,7 @@ func (e *PluginManager) LaunchResource(ctx context.Context, tCtx pluginsCore.Tas return pluginsCore.DoTransition(pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "task submitted to K8s")), nil } -func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (pluginsCore.Transition, error) { +func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, k8sPluginState *k8s.PluginState) (pluginsCore.Transition, error) { o, err := e.plugin.BuildIdentityResource(ctx, tCtx.TaskExecutionMetadata()) if err != nil { @@ -274,7 +275,7 @@ func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore e.metrics.ResourceDeleted.Inc(ctx) } - pCtx := newPluginContext(tCtx) + pCtx := newPluginContext(tCtx, k8sPluginState) p, err := e.plugin.GetTaskPhase(ctx, pCtx, o) if err != nil { logger.Warnf(ctx, "failed to check status of resource in plugin [%s], with error: %s", e.GetID(), err.Error()) @@ -311,6 +312,7 @@ func (e *PluginManager) CheckResourcePhase(ctx context.Context, tCtx pluginsCore } func (e PluginManager) Handle(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (pluginsCore.Transition, error) { + // read phase state ps := PluginState{} if v, err := tCtx.PluginStateReader().Get(&ps); err != nil { if v != pluginStateVersion { @@ -318,16 +320,44 @@ func (e PluginManager) Handle(ctx context.Context, tCtx pluginsCore.TaskExecutio } return pluginsCore.UnknownTransition, errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state") } + + // evaluate plugin + var err error + var transition pluginsCore.Transition + pluginPhase := ps.Phase if ps.Phase == PluginPhaseNotStarted { - t, err := e.LaunchResource(ctx, tCtx) - if err == nil && t.Info().Phase() == pluginsCore.PhaseQueued { - if err := tCtx.PluginStateWriter().Put(pluginStateVersion, &PluginState{Phase: PluginPhaseStarted}); err != nil { - return pluginsCore.UnknownTransition, err - } + transition, err = e.LaunchResource(ctx, tCtx) + if err == nil && transition.Info().Phase() == pluginsCore.PhaseQueued { + pluginPhase = PluginPhaseStarted } - return t, err + } else { + transition, err = e.CheckResourcePhase(ctx, tCtx, &ps.K8sPluginState) + } + + if err != nil { + return transition, err } - return e.CheckResourcePhase(ctx, tCtx) + + // persist any changes in phase state + k8sPluginState := ps.K8sPluginState + if ps.Phase != pluginPhase || k8sPluginState.Phase != transition.Info().Phase() || + k8sPluginState.PhaseVersion != transition.Info().Version() || k8sPluginState.Reason != transition.Info().Reason() { + + newPluginState := PluginState{ + Phase: pluginPhase, + K8sPluginState: k8s.PluginState{ + Phase: 
transition.Info().Phase(), + PhaseVersion: transition.Info().Version(), + Reason: transition.Info().Reason(), + }, + } + + if err := tCtx.PluginStateWriter().Put(pluginStateVersion, &newPluginState); err != nil { + return pluginsCore.UnknownTransition, err + } + } + + return transition, err } func (e PluginManager) Abort(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) error { diff --git a/pkg/controller/nodes/task/k8s/plugin_manager_test.go b/pkg/controller/nodes/task/k8s/plugin_manager_test.go index 160dc335f..94b6b5524 100644 --- a/pkg/controller/nodes/task/k8s/plugin_manager_test.go +++ b/pkg/controller/nodes/task/k8s/plugin_manager_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "reflect" "testing" "k8s.io/client-go/kubernetes/scheme" @@ -715,6 +716,157 @@ func TestPluginManager_Handle_CheckResourceStatus(t *testing.T) { } } +func TestPluginManager_Handle_PluginState(t *testing.T) { + ctx := context.TODO() + tm := getMockTaskExecutionMetadata() + res := &v1.Pod{ + ObjectMeta: v12.ObjectMeta{ + Name: tm.GetTaskExecutionID().GetGeneratedName(), + Namespace: tm.GetNamespace(), + }, + } + + pluginStateQueued := PluginState{ + Phase: PluginPhaseStarted, + K8sPluginState: k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: 0, + Reason: "foo", + }, + } + pluginStateQueuedVersion1 := PluginState{ + Phase: PluginPhaseStarted, + K8sPluginState: k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: 1, + Reason: "foo", + }, + } + pluginStateQueuedReasonBar := PluginState{ + Phase: PluginPhaseStarted, + K8sPluginState: k8s.PluginState{ + Phase: pluginsCore.PhaseQueued, + PhaseVersion: 0, + Reason: "bar", + }, + } + pluginStateRunning := PluginState{ + Phase: PluginPhaseStarted, + K8sPluginState: k8s.PluginState{ + Phase: pluginsCore.PhaseRunning, + PhaseVersion: 0, + Reason: "", + }, + } + + phaseInfoQueued := pluginsCore.PhaseInfoQueuedWithTaskInfo(pluginStateQueued.K8sPluginState.PhaseVersion, pluginStateQueued.K8sPluginState.Reason, nil) + phaseInfoQueuedVersion1 := pluginsCore.PhaseInfoQueuedWithTaskInfo( + pluginStateQueuedVersion1.K8sPluginState.PhaseVersion, + pluginStateQueuedVersion1.K8sPluginState.Reason, + nil, + ) + phaseInfoQueuedReasonBar := pluginsCore.PhaseInfoQueuedWithTaskInfo( + pluginStateQueuedReasonBar.K8sPluginState.PhaseVersion, + pluginStateQueuedReasonBar.K8sPluginState.Reason, + nil, + ) + phaseInfoRunning := pluginsCore.PhaseInfoRunning(0, nil) + + tests := []struct { + name string + startPluginState PluginState + reportedPhaseInfo pluginsCore.PhaseInfo + expectedPluginState PluginState + }{ + { + "NoChange", + pluginStateQueued, + phaseInfoQueued, + pluginStateQueued, + }, + { + "K8sPhaseChange", + pluginStateQueued, + phaseInfoRunning, + pluginStateRunning, + }, + { + "PhaseVersionChange", + pluginStateQueued, + phaseInfoQueuedVersion1, + pluginStateQueuedVersion1, + }, + { + "ReasonChange", + pluginStateQueued, + phaseInfoQueuedReasonBar, + pluginStateQueuedReasonBar, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // mock TaskExecutionContext + tCtx := &pluginsCoreMock.TaskExecutionContext{} + tCtx.OnTaskExecutionMetadata().Return(getMockTaskExecutionMetadata()) + + tReader := &pluginsCoreMock.TaskReader{} + tReader.OnReadMatch(mock.Anything).Return(&core.TaskTemplate{}, nil) + tCtx.OnTaskReader().Return(tReader) + + // mock state reader / writer to use local pluginState variable + pluginState := &tt.startPluginState + customStateReader := &pluginsCoreMock.PluginStateReader{} + 
customStateReader.OnGetMatch(mock.MatchedBy(func(i interface{}) bool { + ps, ok := i.(*PluginState) + if ok { + *ps = *pluginState + return true + } + return false + })).Return(uint8(0), nil) + tCtx.OnPluginStateReader().Return(customStateReader) + + customStateWriter := &pluginsCoreMock.PluginStateWriter{} + customStateWriter.OnPutMatch(mock.Anything, mock.MatchedBy(func(i interface{}) bool { + ps, ok := i.(*PluginState) + if ok { + *pluginState = *ps + } + return ok + })).Return(nil) + tCtx.OnPluginStateWriter().Return(customStateWriter) + tCtx.OnOutputWriter().Return(&dummyOutputWriter{}) + + fc := extendedFakeClient{Client: fake.NewFakeClient(res)} + + mockResourceHandler := &pluginsk8sMock.Plugin{} + mockResourceHandler.OnGetProperties().Return(k8s.PluginProperties{}) + mockResourceHandler.On("BuildIdentityResource", mock.Anything, tCtx.TaskExecutionMetadata()).Return(&v1.Pod{}, nil) + mockResourceHandler.On("GetTaskPhase", mock.Anything, mock.Anything, mock.Anything).Return(tt.reportedPhaseInfo, nil) + + // create new PluginManager + pluginManager, err := NewPluginManager(ctx, dummySetupContext(fc), k8s.PluginEntry{ + ID: "x", + ResourceToWatch: &v1.Pod{}, + Plugin: mockResourceHandler, + }, NewResourceMonitorIndex()) + assert.NoError(t, err) + + // handle plugin + _, err = pluginManager.Handle(ctx, tCtx) + assert.NoError(t, err) + + // verify expected PluginState + newPluginState := PluginState{} + _, err = tCtx.PluginStateReader().Get(&newPluginState) + assert.NoError(t, err) + + assert.True(t, reflect.DeepEqual(newPluginState, tt.expectedPluginState)) + }) + } +} + func TestPluginManager_CustomKubeClient(t *testing.T) { ctx := context.TODO() tctx := getMockTaskContext(PluginPhaseNotStarted, PluginPhaseStarted) diff --git a/pkg/controller/nodes/task/transformer.go b/pkg/controller/nodes/task/transformer.go index db4e8a5a9..6faa93f70 100644 --- a/pkg/controller/nodes/task/transformer.go +++ b/pkg/controller/nodes/task/transformer.go @@ -1,6 +1,8 @@ package task import ( + "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" pluginCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -9,9 +11,10 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/config" "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/common" - "github.com/golang/protobuf/ptypes" - "github.com/flyteorg/flytepropeller/pkg/controller/nodes/handler" + + "github.com/golang/protobuf/ptypes" + timestamppb "github.com/golang/protobuf/ptypes/timestamp" ) // This is used by flyteadmin to indicate that map tasks now report subtask metadata individually. 
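Note on the transformer hunk below: OccurredAt is meant to carry the time the phase change actually happened (taken from the plugin's PhaseInfo, or the caller-supplied fallback), while the new ReportedAt carries the time propeller emits the event. A minimal, standalone sketch of the ptypes conversions involved follows; it is not part of this diff and the example values are made up.

package main

import (
	"fmt"
	"time"

	"github.com/golang/protobuf/ptypes"
)

func main() {
	// occurredAt: when the phase change happened (here, 30s ago as a stand-in).
	occurredAt, err := ptypes.TimestampProto(time.Now().Add(-30 * time.Second))
	if err != nil {
		panic(err)
	}

	// reportedAt: when the event is actually sent; defaults to the wall clock.
	reportedAt := ptypes.TimestampNow()

	// Round-trip back to time.Time to show the two values stay distinct.
	o, _ := ptypes.Timestamp(occurredAt)
	r, _ := ptypes.Timestamp(reportedAt)
	fmt.Println("occurredAt:", o, "reportedAt:", r)
}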
@@ -78,15 +81,27 @@ type ToTaskExecutionEventInputs struct {
 	PluginID         string
 	ResourcePoolInfo []*event.ResourcePoolInfo
 	ClusterID        string
+	OccurredAt       time.Time
 }
 
 func ToTaskExecutionEvent(input ToTaskExecutionEventInputs) (*event.TaskExecutionEvent, error) {
 	// Transitions to a new phase
-	tm := ptypes.TimestampNow()
 	var err error
+	var occurredAt *timestamppb.Timestamp
 	if i := input.Info.Info(); i != nil && i.OccurredAt != nil {
-		tm, err = ptypes.TimestampProto(*i.OccurredAt)
+		occurredAt, err = ptypes.TimestampProto(*i.OccurredAt)
+	} else {
+		occurredAt, err = ptypes.TimestampProto(input.OccurredAt)
+	}
+
+	if err != nil {
+		return nil, err
+	}
+
+	reportedAt := ptypes.TimestampNow()
+	if i := input.Info.Info(); i != nil && i.ReportedAt != nil {
+		reportedAt, err = ptypes.TimestampProto(*i.ReportedAt)
 		if err != nil {
 			return nil, err
 		}
@@ -127,11 +142,12 @@ func ToTaskExecutionEvent(input ToTaskExecutionEventInputs) (*event.TaskExecutio
 		Phase:        ToTaskEventPhase(input.Info.Phase()),
 		PhaseVersion: input.Info.Version(),
 		ProducerId:   input.ClusterID,
-		OccurredAt:   tm,
+		OccurredAt:   occurredAt,
 		TaskType:     input.TaskType,
 		Reason:       input.Info.Reason(),
 		Metadata:     metadata,
 		EventVersion: taskExecutionEventVersion,
+		ReportedAt:   reportedAt,
 	}
 
 	if input.Info.Phase().IsSuccess() && input.OutputWriter != nil {
diff --git a/pkg/controller/nodes/transformers.go b/pkg/controller/nodes/transformers.go
index ed3af8638..ae615a9f3 100644
--- a/pkg/controller/nodes/transformers.go
+++ b/pkg/controller/nodes/transformers.go
@@ -114,6 +114,7 @@ func ToNodeExecutionEvent(nodeExecID *core.NodeExecutionIdentifier,
 			OccurredAt:   occurredTime,
 			ProducerId:   clusterID,
 			EventVersion: nodeExecutionEventVersion,
+			ReportedAt:   ptypes.TimestampNow(),
 		}
 	} else {
 		nev = &event.NodeExecutionEvent{
@@ -122,6 +123,7 @@
 			OccurredAt:   occurredTime,
 			ProducerId:   clusterID,
 			EventVersion: nodeExecutionEventVersion,
+			ReportedAt:   ptypes.TimestampNow(),
 		}
 	}
 
@@ -237,7 +239,6 @@ func UpdateNodeStatus(np v1alpha1.NodePhase, p handler.PhaseInfo, n *nodeStateMa
 		t.SetLastPhaseUpdatedAt(n.t.LastPhaseUpdatedAt)
 		t.SetPluginState(n.t.PluginState)
 		t.SetPluginStateVersion(n.t.PluginStateVersion)
-		t.SetBarrierClockTick(n.t.BarrierClockTick)
 		t.SetPreviousNodeExecutionCheckpointPath(n.t.PreviousNodeExecutionCheckpointURI)
 		t.SetCleanupOnFailure(n.t.CleanupOnFailure)
 	}
diff --git a/pkg/controller/workflow/executor.go b/pkg/controller/workflow/executor.go
index 92a870c1c..e3eac1e37 100644
--- a/pkg/controller/workflow/executor.go
+++ b/pkg/controller/workflow/executor.go
@@ -124,7 +124,7 @@ func (c *workflowExecutor) handleReadyWorkflow(ctx context.Context, w *v1alpha1.
nodeStatus.SetDataDir(dataDir) nodeStatus.SetOutputDir(outputDir) execcontext := executors.NewExecutionContext(w, w, w, nil, executors.InitializeControlFlow()) - s, err := c.nodeExecutor.SetInputsForStartNode(ctx, execcontext, w, executors.NewNodeLookup(w, w.GetExecutionStatus()), inputs) + s, err := c.nodeExecutor.SetInputsForStartNode(ctx, execcontext, w, executors.NewNodeLookup(w, w.GetExecutionStatus(), w), inputs) if err != nil { return StatusReady, err } diff --git a/pkg/webhook/config/config.go b/pkg/webhook/config/config.go index 61c598c1d..c3cf9f2d8 100644 --- a/pkg/webhook/config/config.go +++ b/pkg/webhook/config/config.go @@ -33,6 +33,19 @@ var ( }, }, }, + GCPSecretManagerConfig: GCPSecretManagerConfig{ + SidecarImage: "gcr.io/google.com/cloudsdktool/cloud-sdk:alpine", + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("500Mi"), + corev1.ResourceCPU: resource.MustParse("200m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("500Mi"), + corev1.ResourceCPU: resource.MustParse("200m"), + }, + }, + }, VaultSecretManagerConfig: VaultSecretManagerConfig{ Role: "flyte", KVVersion: KVVersion2, @@ -57,6 +70,10 @@ const ( // Manager and mount them to a local file system (in memory) and share that mount with other containers in the pod. SecretManagerTypeAWS + // SecretManagerTypeGCP defines a secret manager webhook that injects a side car to pull secrets from GCP Secret + // Manager and mount them to a local file system (in memory) and share that mount with other containers in the pod. + SecretManagerTypeGCP + // SecretManagerTypeVault defines a secret manager webhook that pulls secrets from Hashicorp Vault. SecretManagerTypeVault ) @@ -81,6 +98,7 @@ type Config struct { SecretName string `json:"secretName" pflag:",Secret name to write generated certs to."` SecretManagerType SecretManagerType `json:"secretManagerType" pflag:"-,Secret manager type to use if secrets are not found in global secrets."` AWSSecretManagerConfig AWSSecretManagerConfig `json:"awsSecretManager" pflag:",AWS Secret Manager config."` + GCPSecretManagerConfig GCPSecretManagerConfig `json:"gcpSecretManager" pflag:",GCP Secret Manager config."` VaultSecretManagerConfig VaultSecretManagerConfig `json:"vaultSecretManager" pflag:",Vault Secret Manager config."` } @@ -89,6 +107,11 @@ type AWSSecretManagerConfig struct { Resources corev1.ResourceRequirements `json:"resources" pflag:"-,Specifies resource requirements for the init container."` } +type GCPSecretManagerConfig struct { + SidecarImage string `json:"sidecarImage" pflag:",Specifies the sidecar docker image to use"` + Resources corev1.ResourceRequirements `json:"resources" pflag:"-,Specifies resource requirements for the init container."` +} + type VaultSecretManagerConfig struct { Role string `json:"role" pflag:",Specifies the vault role to use"` KVVersion KVVersion `json:"kvVersion" pflag:"-,The KV Engine Version. Defaults to 2. Use 1 for unversioned secrets. 
Refer to - https://www.vaultproject.io/docs/secrets/kv#kv-secrets-engine."` diff --git a/pkg/webhook/config/config_flags.go b/pkg/webhook/config/config_flags.go index 7ef9575d7..089bc0064 100755 --- a/pkg/webhook/config/config_flags.go +++ b/pkg/webhook/config/config_flags.go @@ -58,6 +58,7 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int32(fmt.Sprintf("%v%v", prefix, "servicePort"), DefaultConfig.ServicePort, "The port on the service that hosting webhook.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "secretName"), DefaultConfig.SecretName, "Secret name to write generated certs to.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "awsSecretManager.sidecarImage"), DefaultConfig.AWSSecretManagerConfig.SidecarImage, "Specifies the sidecar docker image to use") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "gcpSecretManager.sidecarImage"), DefaultConfig.GCPSecretManagerConfig.SidecarImage, "Specifies the sidecar docker image to use") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "vaultSecretManager.role"), DefaultConfig.VaultSecretManagerConfig.Role, "Specifies the vault role to use") return cmdFlags } diff --git a/pkg/webhook/config/config_flags_test.go b/pkg/webhook/config/config_flags_test.go index e68b5af13..613a0f6a3 100755 --- a/pkg/webhook/config/config_flags_test.go +++ b/pkg/webhook/config/config_flags_test.go @@ -211,6 +211,20 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_gcpSecretManager.sidecarImage", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("gcpSecretManager.sidecarImage", testValue) + if vString, err := cmdFlags.GetString("gcpSecretManager.sidecarImage"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.GCPSecretManagerConfig.SidecarImage) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_vaultSecretManager.role", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/pkg/webhook/config/secretmanagertype_enumer.go b/pkg/webhook/config/secretmanagertype_enumer.go index ce33f910f..986b8b135 100644 --- a/pkg/webhook/config/secretmanagertype_enumer.go +++ b/pkg/webhook/config/secretmanagertype_enumer.go @@ -7,9 +7,9 @@ import ( "fmt" ) -const _SecretManagerTypeName = "GlobalK8sAWSVault" +const _SecretManagerTypeName = "GlobalK8sAWSGCPVault" -var _SecretManagerTypeIndex = [...]uint8{0, 6, 9, 12, 17} +var _SecretManagerTypeIndex = [...]uint8{0, 6, 9, 12, 15, 20} func (i SecretManagerType) String() string { if i < 0 || i >= SecretManagerType(len(_SecretManagerTypeIndex)-1) { @@ -18,13 +18,14 @@ func (i SecretManagerType) String() string { return _SecretManagerTypeName[_SecretManagerTypeIndex[i]:_SecretManagerTypeIndex[i+1]] } -var _SecretManagerTypeValues = []SecretManagerType{0, 1, 2, 3} +var _SecretManagerTypeValues = []SecretManagerType{0, 1, 2, 3, 4} var _SecretManagerTypeNameToValueMap = map[string]SecretManagerType{ _SecretManagerTypeName[0:6]: 0, _SecretManagerTypeName[6:9]: 1, _SecretManagerTypeName[9:12]: 2, - _SecretManagerTypeName[12:17]: 3, + _SecretManagerTypeName[12:15]: 3, + _SecretManagerTypeName[15:20]: 4, } // SecretManagerTypeString retrieves an enum value from the enum constants string name. 
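Aside on the regenerated enumer file above: all enum labels live in one concatenated string plus a byte-offset table, so inserting GCP between AWS and Vault shifts every offset past position 12. A self-contained sketch of that layout, illustrative only and not the generated code itself:

package main

import "fmt"

// Label i is names[idx[i]:idx[i+1]]; the values mirror the
// _SecretManagerTypeName/_SecretManagerTypeIndex pair in the diff.
const names = "GlobalK8sAWSGCPVault"

var idx = [...]uint8{0, 6, 9, 12, 15, 20}

func main() {
	for i := 0; i < len(idx)-1; i++ {
		fmt.Printf("SecretManagerType(%d) = %q\n", i, names[idx[i]:idx[i+1]])
	}
	// Prints Global, K8s, AWS, GCP, Vault in order.
}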
diff --git a/pkg/webhook/entrypoint.go b/pkg/webhook/entrypoint.go
index 4f5a06825..50c4099c2 100644
--- a/pkg/webhook/entrypoint.go
+++ b/pkg/webhook/entrypoint.go
@@ -52,8 +52,10 @@ func Run(ctx context.Context, propellerCfg *config.Config, cfg *config2.Config,
 		logger.Fatalf(ctx, "Failed to register webhook with manager. Error: %v", err)
 	}
 
-	logger.Infof(ctx, "Starting controller-runtime manager")
-	return (*mgr).Start(ctx)
+	logger.Infof(ctx, "Started propeller webhook")
+	<-ctx.Done()
+
+	return nil
 }
 
 func createMutationConfig(ctx context.Context, kubeClient *kubernetes.Clientset, webhookObj *PodMutator, defaultNamespace string) error {
diff --git a/pkg/webhook/gcp_secret_manager.go b/pkg/webhook/gcp_secret_manager.go
new file mode 100644
index 000000000..f17c1509a
--- /dev/null
+++ b/pkg/webhook/gcp_secret_manager.go
@@ -0,0 +1,156 @@
+package webhook
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
+	"github.com/flyteorg/flytepropeller/pkg/webhook/config"
+	"github.com/flyteorg/flytestdlib/logger"
+	corev1 "k8s.io/api/core/v1"
+)
+
+const (
+	// GCPSecretsVolumeName defines the static name of the volume used for mounting/sharing secrets between init-container
+	// sidecar and the rest of the containers in the pod.
+	GCPSecretsVolumeName = "gcp-secret-vol" // #nosec
+)
+
+var (
+	// GCPSecretMountPath defines the default mount path for secrets
+	GCPSecretMountPath = filepath.Join(string(os.PathSeparator), "etc", "flyte", "secrets")
+)
+
+// GCPSecretManagerInjector allows injecting secrets from GCP Secret Manager as files. It uses a Google Cloud
+// SDK sidecar as an init-container to download the secret and save it to a local volume shared with all other
+// containers in the pod. Multiple secrets can be mounted, but each one adds another init
+// container. The Google service account (GSA) associated with the Pod, either via Workload
+// Identity (recommended) or the underlying node's service account, must have permission to pull the secret
+// from GCP Secret Manager. Currently, the secret must also exist in the same project. Otherwise, the Pod will
+// fail with an init-error.
+// Files will be mounted under /etc/flyte/secrets/<group>/<group_version>
+type GCPSecretManagerInjector struct {
+	cfg config.GCPSecretManagerConfig
+}
+
+func formatGCPSecretAccessCommand(secret *core.Secret) []string {
+	// `gcloud` writes this file with permission 0600.
+	// This will cause permission issues in the main container when using non-root
+	// users, so we fix the file permissions with `chmod`.
+ secretDir := strings.ToLower(filepath.Join(GCPSecretMountPath, secret.Group)) + secretPath := strings.ToLower(filepath.Join(secretDir, secret.GroupVersion)) + args := []string{ + "gcloud", + "secrets", + "versions", + "access", + secret.GroupVersion, + fmt.Sprintf("--secret=%s", secret.Group), + fmt.Sprintf( + "--out-file=%s", + secretPath, + ), + "&&", + "chmod", + "+rX", + secretDir, + secretPath, + } + return []string{"sh", "-c", strings.Join(args, " ")} +} + +func formatGCPInitContainerName(index int) string { + return fmt.Sprintf("gcp-pull-secret-%v", index) +} + +func (i GCPSecretManagerInjector) Type() config.SecretManagerType { + return config.SecretManagerTypeGCP +} + +func (i GCPSecretManagerInjector) Inject(ctx context.Context, secret *core.Secret, p *corev1.Pod) (newP *corev1.Pod, injected bool, err error) { + if len(secret.Group) == 0 || len(secret.GroupVersion) == 0 { + return nil, false, fmt.Errorf("GCP Secrets Webhook require both group and group version to be set. "+ + "Secret: [%v]", secret) + } + + switch secret.MountRequirement { + case core.Secret_ANY: + fallthrough + case core.Secret_FILE: + // A Volume with a static name so that if we try to inject multiple secrets, we won't mount multiple volumes. + // We use Memory as the storage medium for volume source to avoid + vol := corev1.Volume{ + Name: GCPSecretsVolumeName, + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{ + Medium: corev1.StorageMediumMemory, + }, + }, + } + + p.Spec.Volumes = appendVolumeIfNotExists(p.Spec.Volumes, vol) + p.Spec.InitContainers = append(p.Spec.InitContainers, createGCPSidecarContainer(i.cfg, p, secret)) + + secretVolumeMount := corev1.VolumeMount{ + Name: GCPSecretsVolumeName, + ReadOnly: true, + MountPath: GCPSecretMountPath, + } + + p.Spec.Containers = AppendVolumeMounts(p.Spec.Containers, secretVolumeMount) + p.Spec.InitContainers = AppendVolumeMounts(p.Spec.InitContainers, secretVolumeMount) + + // Inject GCP secret-inject webhook annotations to mount the secret in a predictable location. + envVars := []corev1.EnvVar{ + // Set environment variable to let the container know where to find the mounted files. + { + Name: SecretPathDefaultDirEnvVar, + Value: GCPSecretMountPath, + }, + // Sets an empty prefix to let the containers know the file names will match the secret keys as-is. + { + Name: SecretPathFilePrefixEnvVar, + Value: "", + }, + } + + for _, envVar := range envVars { + p.Spec.InitContainers = AppendEnvVars(p.Spec.InitContainers, envVar) + p.Spec.Containers = AppendEnvVars(p.Spec.Containers, envVar) + } + case core.Secret_ENV_VAR: + fallthrough + default: + err := fmt.Errorf("unrecognized mount requirement [%v] for secret [%v]", secret.MountRequirement.String(), secret.Key) + logger.Error(ctx, err) + return p, false, err + } + + return p, true, nil +} + +func createGCPSidecarContainer(cfg config.GCPSecretManagerConfig, p *corev1.Pod, secret *core.Secret) corev1.Container { + return corev1.Container{ + Image: cfg.SidecarImage, + // Create a unique name to allow multiple secrets to be mounted. + Name: formatGCPInitContainerName(len(p.Spec.InitContainers)), + Command: formatGCPSecretAccessCommand(secret), + VolumeMounts: []corev1.VolumeMount{ + { + Name: GCPSecretsVolumeName, + MountPath: GCPSecretMountPath, + }, + }, + Resources: cfg.Resources, + } +} + +// NewGCPSecretManagerInjector creates a SecretInjector that's able to mount secrets from GCP Secret Manager. 
+func NewGCPSecretManagerInjector(cfg config.GCPSecretManagerConfig) GCPSecretManagerInjector { + return GCPSecretManagerInjector{ + cfg: cfg, + } +} diff --git a/pkg/webhook/gcp_secret_manager_test.go b/pkg/webhook/gcp_secret_manager_test.go new file mode 100644 index 000000000..26805eafc --- /dev/null +++ b/pkg/webhook/gcp_secret_manager_test.go @@ -0,0 +1,79 @@ +package webhook + +import ( + "context" + "testing" + + "github.com/flyteorg/flytepropeller/pkg/webhook/config" + + "github.com/go-test/deep" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" +) + +func TestGCPSecretManagerInjector_Inject(t *testing.T) { + injector := NewGCPSecretManagerInjector(config.DefaultConfig.GCPSecretManagerConfig) + inputSecret := &core.Secret{ + Group: "TestSecret", + GroupVersion: "2", + } + + expected := &corev1.Pod{ + Spec: corev1.PodSpec{ + Volumes: []corev1.Volume{ + { + Name: "gcp-secret-vol", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{ + Medium: corev1.StorageMediumMemory, + }, + }, + }, + }, + + InitContainers: []corev1.Container{ + { + Name: "gcp-pull-secret-0", + Image: "gcr.io/google.com/cloudsdktool/cloud-sdk:alpine", + Command: []string{ + "sh", + "-c", + "gcloud secrets versions access 2 --secret=TestSecret --out-file=/etc/flyte/secrets/testsecret/2 && chmod +rX /etc/flyte/secrets/testsecret /etc/flyte/secrets/testsecret/2", + }, + Env: []corev1.EnvVar{ + { + Name: "FLYTE_SECRETS_DEFAULT_DIR", + Value: "/etc/flyte/secrets", + }, + { + Name: "FLYTE_SECRETS_FILE_PREFIX", + Value: "", + }, + }, + VolumeMounts: []corev1.VolumeMount{ + { + Name: "gcp-secret-vol", + MountPath: "/etc/flyte/secrets", + }, + }, + Resources: config.DefaultConfig.GCPSecretManagerConfig.Resources, + }, + }, + Containers: []corev1.Container{}, + }, + } + + p := &corev1.Pod{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{}, + }, + } + actualP, injected, err := injector.Inject(context.Background(), inputSecret, p) + assert.NoError(t, err) + assert.True(t, injected) + if diff := deep.Equal(actualP, expected); diff != nil { + assert.Fail(t, "actual != expected", "Diff: %v", diff) + } +} diff --git a/pkg/webhook/secrets.go b/pkg/webhook/secrets.go index ffffc53cd..eae878cef 100644 --- a/pkg/webhook/secrets.go +++ b/pkg/webhook/secrets.go @@ -74,6 +74,7 @@ func NewSecretsMutator(cfg *config.Config, _ promutils.Scope) *SecretsMutator { NewGlobalSecrets(secretmanager.NewFileEnvSecretManager(secretmanager.GetConfig())), NewK8sSecretsInjector(), NewAWSSecretManagerInjector(cfg.AWSSecretManagerConfig), + NewGCPSecretManagerInjector(cfg.GCPSecretManagerConfig), NewVaultSecretManagerInjector(cfg.VaultSecretManagerConfig), }, }
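For context, a rough out-of-tree sketch of how the new GCP injector is exercised, mirroring the unit test above. The secret name "my-api-key" and version "latest" are invented values; the constructor, config fields, and Inject signature come from this diff.

package main

import (
	"context"
	"fmt"

	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
	"github.com/flyteorg/flytepropeller/pkg/webhook"
	"github.com/flyteorg/flytepropeller/pkg/webhook/config"
	corev1 "k8s.io/api/core/v1"
)

func main() {
	// Build the injector from the package defaults (cloud-sdk:alpine sidecar image).
	injector := webhook.NewGCPSecretManagerInjector(config.DefaultConfig.GCPSecretManagerConfig)

	// Group is the Secret Manager secret name; GroupVersion is the version to access.
	secret := &core.Secret{
		Group:            "my-api-key",
		GroupVersion:     "latest",
		MountRequirement: core.Secret_FILE,
	}

	pod := &corev1.Pod{
		Spec: corev1.PodSpec{
			Containers: []corev1.Container{{Name: "main"}},
		},
	}

	mutated, injected, err := injector.Inject(context.Background(), secret, pod)
	if err != nil {
		panic(err)
	}

	// The mutated pod now carries the in-memory volume, the gcloud init container,
	// and the FLYTE_SECRETS_* env vars on every container.
	fmt.Println(injected, mutated.Spec.InitContainers[0].Command)
}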