diff --git a/flytepropeller/pkg/controller/nodes/executor.go b/flytepropeller/pkg/controller/nodes/executor.go index 7863d7ba54..f26b7d2cf8 100644 --- a/flytepropeller/pkg/controller/nodes/executor.go +++ b/flytepropeller/pkg/controller/nodes/executor.go @@ -155,7 +155,17 @@ func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *eve } func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.PhaseInfo, error) { - recovered, err := c.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + fullyQualifiedNodeID := nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId + if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { + // compute fully qualified node id (prefixed with parent id and retry attempt) to ensure uniqueness + var err error + fullyQualifiedNodeID, err = common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId) + if err != nil { + return handler.PhaseInfoUndefined, err + } + } + + recovered, err := c.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) if err != nil { st, ok := status.FromError(err) if !ok || st.Code() != codes.NotFound { @@ -184,7 +194,7 @@ func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx handler.NodeExe return handler.PhaseInfoUndefined, nil } - recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) if err != nil { st, ok := status.FromError(err) if !ok || st.Code() != codes.NotFound { diff --git a/flytepropeller/pkg/controller/nodes/executor_test.go b/flytepropeller/pkg/controller/nodes/executor_test.go index 547a51da90..19e5c3a562 100644 --- a/flytepropeller/pkg/controller/nodes/executor_test.go +++ b/flytepropeller/pkg/controller/nodes/executor_test.go @@ -49,7 +49,7 @@ import ( var fakeKubeClient = mocks4.NewFakeKubeClient() var catalogClient = catalog.NOOPCatalog{} -var recoveryClient = &recoveryMocks.RecoveryClient{} +var recoveryClient = &recoveryMocks.Client{} const taskID = "tID" const inputsPath = "inputs.pb" @@ -2028,10 +2028,6 @@ func TestRecover(t *testing.T) { Name: "name", } nodeID := "recovering" - nodeExecID := &core.NodeExecutionIdentifier{ - ExecutionId: wfExecID, - NodeId: nodeID, - } fullInputs := &core.LiteralMap{ Literals: map[string]*core.Literal{ @@ -2074,6 +2070,7 @@ func TestRecover(t *testing.T) { WorkflowExecutionIdentifier: recoveryID, }, }) + execContext.OnGetEventVersion().Return(v1alpha1.EventVersion0) nm := &nodeHandlerMocks.NodeExecutionMetadata{} nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{ @@ -2094,8 +2091,8 @@ func TestRecover(t *testing.T) { nCtx.OnNodeStatus().Return(ns) t.Run("recover task node successfully", func(t *testing.T) { - recoveryClient := &recoveryMocks.RecoveryClient{} - recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + recoveryClient := &recoveryMocks.Client{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return( &admin.NodeExecution{ Closure: &admin.NodeExecutionClosure{ Phase: core.NodeExecution_SUCCEEDED, @@ -2105,7 +2102,7 @@ func TestRecover(t *testing.T) { }, }, nil) - recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return( + recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeID).Return( &admin.NodeExecutionGetDataResponse{ FullInputs: fullInputs, FullOutputs: fullOutputs, @@ -2139,8 +2136,8 @@ func TestRecover(t *testing.T) { assert.Equal(t, phaseInfo.GetPhase(), handler.EPhaseRecovered) }) t.Run("recover cached, dynamic task node successfully", func(t *testing.T) { - recoveryClient := &recoveryMocks.RecoveryClient{} - recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + recoveryClient := &recoveryMocks.Client{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return( &admin.NodeExecution{ Closure: &admin.NodeExecutionClosure{ Phase: core.NodeExecution_SUCCEEDED, @@ -2178,7 +2175,7 @@ func TestRecover(t *testing.T) { }, }, } - recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return( + recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeID).Return( &admin.NodeExecutionGetDataResponse{ FullInputs: fullInputs, FullOutputs: fullOutputs, @@ -2223,8 +2220,8 @@ func TestRecover(t *testing.T) { }, phaseInfo.GetInfo().TaskNodeInfo.TaskNodeMetadata)) }) t.Run("recover workflow node successfully", func(t *testing.T) { - recoveryClient := &recoveryMocks.RecoveryClient{} - recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + recoveryClient := &recoveryMocks.Client{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return( &admin.NodeExecution{ Closure: &admin.NodeExecutionClosure{ Phase: core.NodeExecution_SUCCEEDED, @@ -2243,7 +2240,7 @@ func TestRecover(t *testing.T) { }, }, nil) - recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return( + recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeID).Return( &admin.NodeExecutionGetDataResponse{ FullInputs: fullInputs, FullOutputs: fullOutputs, @@ -2280,8 +2277,8 @@ func TestRecover(t *testing.T) { }) t.Run("nothing to recover", func(t *testing.T) { - recoveryClient := &recoveryMocks.RecoveryClient{} - recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + recoveryClient := &recoveryMocks.Client{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return( &admin.NodeExecution{ Closure: &admin.NodeExecutionClosure{ Phase: core.NodeExecution_FAILED, @@ -2298,8 +2295,8 @@ func TestRecover(t *testing.T) { }) t.Run("Fetch inputs", func(t *testing.T) { - recoveryClient := &recoveryMocks.RecoveryClient{} - recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + recoveryClient := &recoveryMocks.Client{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return( &admin.NodeExecution{ InputUri: "inputuri", Closure: &admin.NodeExecutionClosure{ @@ -2310,7 +2307,7 @@ func TestRecover(t *testing.T) { }, }, nil) - recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return( + recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeID).Return( &admin.NodeExecutionGetDataResponse{ FullOutputs: fullOutputs, }, nil) @@ -2344,8 +2341,8 @@ func TestRecover(t *testing.T) { mockPBStore.AssertNumberOfCalls(t, "ReadProtobuf", 1) }) t.Run("Fetch outputs", func(t *testing.T) { - recoveryClient := &recoveryMocks.RecoveryClient{} - recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return( + recoveryClient := &recoveryMocks.Client{} + recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return( &admin.NodeExecution{ Closure: &admin.NodeExecutionClosure{ Phase: core.NodeExecution_SUCCEEDED, @@ -2355,7 +2352,7 @@ func TestRecover(t *testing.T) { }, }, nil) - recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return( + recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeID).Return( &admin.NodeExecutionGetDataResponse{ FullInputs: fullInputs, }, nil) diff --git a/flytepropeller/pkg/controller/nodes/recovery/client.go b/flytepropeller/pkg/controller/nodes/recovery/client.go index d4a1954395..f1cdd52ca1 100644 --- a/flytepropeller/pkg/controller/nodes/recovery/client.go +++ b/flytepropeller/pkg/controller/nodes/recovery/client.go @@ -11,28 +11,28 @@ import ( //go:generate mockery -name Client -output=mocks -case=underscore type Client interface { - RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecution, error) - RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecutionGetDataResponse, error) + RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) (*admin.NodeExecution, error) + RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) (*admin.NodeExecutionGetDataResponse, error) } type recoveryClient struct { adminClient service.AdminServiceClient } -func (c *recoveryClient) RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID *core.NodeExecutionIdentifier) (*admin.NodeExecution, error) { +func (c *recoveryClient) RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) (*admin.NodeExecution, error) { origNodeID := &core.NodeExecutionIdentifier{ ExecutionId: execID, - NodeId: nodeID.NodeId, + NodeId: nodeID, } return c.adminClient.GetNodeExecution(ctx, &admin.NodeExecutionGetRequest{ Id: origNodeID, }) } -func (c *recoveryClient) RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID *core.NodeExecutionIdentifier) (*admin.NodeExecutionGetDataResponse, error) { +func (c *recoveryClient) RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) (*admin.NodeExecutionGetDataResponse, error) { origNodeID := &core.NodeExecutionIdentifier{ ExecutionId: execID, - NodeId: nodeID.NodeId, + NodeId: nodeID, } return c.adminClient.GetNodeExecutionData(ctx, &admin.NodeExecutionGetDataRequest{ Id: origNodeID, diff --git a/flytepropeller/pkg/controller/nodes/recovery/mocks/client.go b/flytepropeller/pkg/controller/nodes/recovery/mocks/client.go index e55dc7add0..23a4770140 100644 --- a/flytepropeller/pkg/controller/nodes/recovery/mocks/client.go +++ b/flytepropeller/pkg/controller/nodes/recovery/mocks/client.go @@ -25,8 +25,8 @@ func (_m Client_RecoverNodeExecution) Return(_a0 *admin.NodeExecution, _a1 error return &Client_RecoverNodeExecution{Call: _m.Call.Return(_a0, _a1)} } -func (_m *Client) OnRecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) *Client_RecoverNodeExecution { - c_call := _m.On("RecoverNodeExecution", ctx, execID, id) +func (_m *Client) OnRecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) *Client_RecoverNodeExecution { + c_call := _m.On("RecoverNodeExecution", ctx, execID, nodeID) return &Client_RecoverNodeExecution{Call: c_call} } @@ -35,13 +35,13 @@ func (_m *Client) OnRecoverNodeExecutionMatch(matchers ...interface{}) *Client_R return &Client_RecoverNodeExecution{Call: c_call} } -// RecoverNodeExecution provides a mock function with given fields: ctx, execID, id -func (_m *Client) RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecution, error) { - ret := _m.Called(ctx, execID, id) +// RecoverNodeExecution provides a mock function with given fields: ctx, execID, nodeID +func (_m *Client) RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) (*admin.NodeExecution, error) { + ret := _m.Called(ctx, execID, nodeID) var r0 *admin.NodeExecution - if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) *admin.NodeExecution); ok { - r0 = rf(ctx, execID, id) + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, string) *admin.NodeExecution); ok { + r0 = rf(ctx, execID, nodeID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*admin.NodeExecution) @@ -49,8 +49,8 @@ func (_m *Client) RecoverNodeExecution(ctx context.Context, execID *core.Workflo } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) error); ok { - r1 = rf(ctx, execID, id) + if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, string) error); ok { + r1 = rf(ctx, execID, nodeID) } else { r1 = ret.Error(1) } @@ -66,8 +66,8 @@ func (_m Client_RecoverNodeExecutionData) Return(_a0 *admin.NodeExecutionGetData return &Client_RecoverNodeExecutionData{Call: _m.Call.Return(_a0, _a1)} } -func (_m *Client) OnRecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) *Client_RecoverNodeExecutionData { - c_call := _m.On("RecoverNodeExecutionData", ctx, execID, id) +func (_m *Client) OnRecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) *Client_RecoverNodeExecutionData { + c_call := _m.On("RecoverNodeExecutionData", ctx, execID, nodeID) return &Client_RecoverNodeExecutionData{Call: c_call} } @@ -76,13 +76,13 @@ func (_m *Client) OnRecoverNodeExecutionDataMatch(matchers ...interface{}) *Clie return &Client_RecoverNodeExecutionData{Call: c_call} } -// RecoverNodeExecutionData provides a mock function with given fields: ctx, execID, id -func (_m *Client) RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecutionGetDataResponse, error) { - ret := _m.Called(ctx, execID, id) +// RecoverNodeExecutionData provides a mock function with given fields: ctx, execID, nodeID +func (_m *Client) RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) (*admin.NodeExecutionGetDataResponse, error) { + ret := _m.Called(ctx, execID, nodeID) var r0 *admin.NodeExecutionGetDataResponse - if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) *admin.NodeExecutionGetDataResponse); ok { - r0 = rf(ctx, execID, id) + if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, string) *admin.NodeExecutionGetDataResponse); ok { + r0 = rf(ctx, execID, nodeID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*admin.NodeExecutionGetDataResponse) @@ -90,8 +90,8 @@ func (_m *Client) RecoverNodeExecutionData(ctx context.Context, execID *core.Wor } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) error); ok { - r1 = rf(ctx, execID, id) + if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, string) error); ok { + r1 = rf(ctx, execID, nodeID) } else { r1 = ret.Error(1) } diff --git a/flytepropeller/pkg/controller/nodes/recovery/mocks/recovery_client.go b/flytepropeller/pkg/controller/nodes/recovery/mocks/recovery_client.go deleted file mode 100644 index f52b65474e..0000000000 --- a/flytepropeller/pkg/controller/nodes/recovery/mocks/recovery_client.go +++ /dev/null @@ -1,100 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import ( - context "context" - - admin "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" - - core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - - mock "github.com/stretchr/testify/mock" -) - -// RecoveryClient is an autogenerated mock type for the RecoveryClient type -type RecoveryClient struct { - mock.Mock -} - -type RecoveryClient_RecoverNodeExecution struct { - *mock.Call -} - -func (_m RecoveryClient_RecoverNodeExecution) Return(_a0 *admin.NodeExecution, _a1 error) *RecoveryClient_RecoverNodeExecution { - return &RecoveryClient_RecoverNodeExecution{Call: _m.Call.Return(_a0, _a1)} -} - -func (_m *RecoveryClient) OnRecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) *RecoveryClient_RecoverNodeExecution { - c := _m.On("RecoverNodeExecution", ctx, execID, id) - return &RecoveryClient_RecoverNodeExecution{Call: c} -} - -func (_m *RecoveryClient) OnRecoverNodeExecutionMatch(matchers ...interface{}) *RecoveryClient_RecoverNodeExecution { - c := _m.On("RecoverNodeExecution", matchers...) - return &RecoveryClient_RecoverNodeExecution{Call: c} -} - -// RecoverNodeExecution provides a mock function with given fields: ctx, execID, id -func (_m *RecoveryClient) RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecution, error) { - ret := _m.Called(ctx, execID, id) - - var r0 *admin.NodeExecution - if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) *admin.NodeExecution); ok { - r0 = rf(ctx, execID, id) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*admin.NodeExecution) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) error); ok { - r1 = rf(ctx, execID, id) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -type RecoveryClient_RecoverNodeExecutionData struct { - *mock.Call -} - -func (_m RecoveryClient_RecoverNodeExecutionData) Return(_a0 *admin.NodeExecutionGetDataResponse, _a1 error) *RecoveryClient_RecoverNodeExecutionData { - return &RecoveryClient_RecoverNodeExecutionData{Call: _m.Call.Return(_a0, _a1)} -} - -func (_m *RecoveryClient) OnRecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) *RecoveryClient_RecoverNodeExecutionData { - c := _m.On("RecoverNodeExecutionData", ctx, execID, id) - return &RecoveryClient_RecoverNodeExecutionData{Call: c} -} - -func (_m *RecoveryClient) OnRecoverNodeExecutionDataMatch(matchers ...interface{}) *RecoveryClient_RecoverNodeExecutionData { - c := _m.On("RecoverNodeExecutionData", matchers...) - return &RecoveryClient_RecoverNodeExecutionData{Call: c} -} - -// RecoverNodeExecutionData provides a mock function with given fields: ctx, execID, id -func (_m *RecoveryClient) RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecutionGetDataResponse, error) { - ret := _m.Called(ctx, execID, id) - - var r0 *admin.NodeExecutionGetDataResponse - if rf, ok := ret.Get(0).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) *admin.NodeExecutionGetDataResponse); ok { - r0 = rf(ctx, execID, id) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*admin.NodeExecutionGetDataResponse) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *core.WorkflowExecutionIdentifier, *core.NodeExecutionIdentifier) error); ok { - r1 = rf(ctx, execID, id) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go index 7af36b89b4..13a997a7ce 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/handler_test.go @@ -153,7 +153,7 @@ func TestWorkflowNodeHandler_StartNode_Launchplan(t *testing.T) { mockNodeStatus.OnGetAttempts().Return(attempts) wfStatus := &mocks2.MutableWorkflowNodeStatus{} mockNodeStatus.OnGetOrCreateWorkflowStatus().Return(wfStatus) - recoveryClient := &mocks5.RecoveryClient{} + recoveryClient := &mocks5.Client{} t.Run("happy v0", func(t *testing.T) { @@ -232,7 +232,7 @@ func TestWorkflowNodeHandler_CheckNodeStatus(t *testing.T) { mockNodeStatus := &mocks2.ExecutableNodeStatus{} mockNodeStatus.OnGetAttempts().Return(attempts) mockNodeStatus.OnGetDataDir().Return(dataDir) - recoveryClient := &mocks5.RecoveryClient{} + recoveryClient := &mocks5.Client{} t.Run("stillRunning V0", func(t *testing.T) { @@ -304,7 +304,7 @@ func TestWorkflowNodeHandler_AbortNode(t *testing.T) { mockNodeStatus := &mocks2.ExecutableNodeStatus{} mockNodeStatus.OnGetAttempts().Return(attempts) mockNodeStatus.OnGetDataDir().Return(dataDir) - recoveryClient := &mocks5.RecoveryClient{} + recoveryClient := &mocks5.Client{} t.Run("abort v0", func(t *testing.T) { diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go index 8d14968d0f..ffcc5b9e18 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan.go @@ -74,7 +74,17 @@ func (l *launchPlanHandler) StartLaunchPlan(ctx context.Context, nCtx handler.No } if nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier != nil { - recovered, err := l.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, nCtx.NodeExecutionMetadata().GetNodeExecutionID()) + fullyQualifiedNodeID := nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId + if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 { + // compute fully qualified node id (prefixed with parent id and retry attempt) to ensure uniqueness + var err error + fullyQualifiedNodeID, err = common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId) + if err != nil { + return handler.UnknownTransition, err + } + } + + recovered, err := l.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID) if err != nil { st, ok := status.FromError(err) if !ok || st.Code() != codes.NotFound { diff --git a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go index cfabdf515e..6022ccf2e3 100644 --- a/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go +++ b/flytepropeller/pkg/controller/nodes/subworkflow/launchplan_test.go @@ -192,7 +192,7 @@ func TestSubWorkflowHandler_StartLaunchPlan(t *testing.T) { RecoveryExecution: recoveredExecID, }, mock.Anything, mock.Anything, mock.Anything).Return(nil) - recoveryClient := recoveryMocks.RecoveryClient{} + recoveryClient := recoveryMocks.Client{} recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveredExecID, mock.Anything).Return(&admin.NodeExecution{ Closure: &admin.NodeExecutionClosure{ Phase: core.NodeExecution_SUCCEEDED, diff --git a/flytepropeller/pkg/controller/workflow/executor_test.go b/flytepropeller/pkg/controller/workflow/executor_test.go index 02e4a9d0b0..46c8be36b4 100644 --- a/flytepropeller/pkg/controller/workflow/executor_test.go +++ b/flytepropeller/pkg/controller/workflow/executor_test.go @@ -240,7 +240,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Error(t *testing.T) { eventSink := eventMocks.NewMockEventSink() catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) - recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, @@ -320,7 +320,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow(t *testing.T) { eventSink := eventMocks.NewMockEventSink() catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) - recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, @@ -384,7 +384,7 @@ func BenchmarkWorkflowExecutor(b *testing.B) { eventSink := eventMocks.NewMockEventSink() catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(b, err) - recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, scope) @@ -485,7 +485,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Failing(t *testing.T) { } catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) - recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) @@ -582,7 +582,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_Events(t *testing.T) { catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) adminClient := launchplan.NewFailFastLaunchPlanExecutor() - recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient := &recoveryMocks.Client{} nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, eventSink, adminClient, adminClient, maxOutputSize, "s3://bucket", fakeKubeClient, catalogClient, recoveryClient, eventConfig, testClusterID, promutils.NewTestScope()) assert.NoError(t, err) @@ -638,7 +638,7 @@ func TestWorkflowExecutor_HandleFlyteWorkflow_EventFailure(t *testing.T) { nodeEventSink := eventMocks.NewMockEventSink() catalogClient, err := catalog.NewCatalogClient(ctx, nil) assert.NoError(t, err) - recoveryClient := &recoveryMocks.RecoveryClient{} + recoveryClient := &recoveryMocks.Client{} adminClient := launchplan.NewFailFastLaunchPlanExecutor() nodeExec, err := nodes.NewExecutor(ctx, config.GetConfig().NodeConfig, store, enqueueWorkflow, nodeEventSink, adminClient,