Skip to content

Commit

Permalink
Use correct node ID when recovering subworkflow nodes (flyteorg#481)
Browse files Browse the repository at this point in the history
* using fully qualified node id on recovery

Signed-off-by: Daniel Rammer <[email protected]>

* fixed unit tests

Signed-off-by: Daniel Rammer <[email protected]>

Signed-off-by: Daniel Rammer <[email protected]>
  • Loading branch information
hamersaw authored Sep 8, 2022
1 parent 98ee0f6 commit 31c6efa
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 159 deletions.
14 changes: 12 additions & 2 deletions flytepropeller/pkg/controller/nodes/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,17 @@ func (c *nodeExecutor) IdempotentRecordEvent(ctx context.Context, nodeEvent *eve
}

func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx handler.NodeExecutionContext) (handler.PhaseInfo, error) {
recovered, err := c.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, nCtx.NodeExecutionMetadata().GetNodeExecutionID())
fullyQualifiedNodeID := nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId
if nCtx.ExecutionContext().GetEventVersion() != v1alpha1.EventVersion0 {
// compute fully qualified node id (prefixed with parent id and retry attempt) to ensure uniqueness
var err error
fullyQualifiedNodeID, err = common.GenerateUniqueID(nCtx.ExecutionContext().GetParentInfo(), nCtx.NodeExecutionMetadata().GetNodeExecutionID().NodeId)
if err != nil {
return handler.PhaseInfoUndefined, err
}
}

recovered, err := c.recoveryClient.RecoverNodeExecution(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID)
if err != nil {
st, ok := status.FromError(err)
if !ok || st.Code() != codes.NotFound {
Expand Down Expand Up @@ -184,7 +194,7 @@ func (c *nodeExecutor) attemptRecovery(ctx context.Context, nCtx handler.NodeExe
return handler.PhaseInfoUndefined, nil
}

recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, nCtx.NodeExecutionMetadata().GetNodeExecutionID())
recoveredData, err := c.recoveryClient.RecoverNodeExecutionData(ctx, nCtx.ExecutionContext().GetExecutionConfig().RecoveryExecution.WorkflowExecutionIdentifier, fullyQualifiedNodeID)
if err != nil {
st, ok := status.FromError(err)
if !ok || st.Code() != codes.NotFound {
Expand Down
41 changes: 19 additions & 22 deletions flytepropeller/pkg/controller/nodes/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ import (

var fakeKubeClient = mocks4.NewFakeKubeClient()
var catalogClient = catalog.NOOPCatalog{}
var recoveryClient = &recoveryMocks.RecoveryClient{}
var recoveryClient = &recoveryMocks.Client{}

const taskID = "tID"
const inputsPath = "inputs.pb"
Expand Down Expand Up @@ -2028,10 +2028,6 @@ func TestRecover(t *testing.T) {
Name: "name",
}
nodeID := "recovering"
nodeExecID := &core.NodeExecutionIdentifier{
ExecutionId: wfExecID,
NodeId: nodeID,
}

fullInputs := &core.LiteralMap{
Literals: map[string]*core.Literal{
Expand Down Expand Up @@ -2074,6 +2070,7 @@ func TestRecover(t *testing.T) {
WorkflowExecutionIdentifier: recoveryID,
},
})
execContext.OnGetEventVersion().Return(v1alpha1.EventVersion0)

nm := &nodeHandlerMocks.NodeExecutionMetadata{}
nm.OnGetNodeExecutionID().Return(&core.NodeExecutionIdentifier{
Expand All @@ -2094,8 +2091,8 @@ func TestRecover(t *testing.T) {
nCtx.OnNodeStatus().Return(ns)

t.Run("recover task node successfully", func(t *testing.T) {
recoveryClient := &recoveryMocks.RecoveryClient{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return(
recoveryClient := &recoveryMocks.Client{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return(
&admin.NodeExecution{
Closure: &admin.NodeExecutionClosure{
Phase: core.NodeExecution_SUCCEEDED,
Expand All @@ -2105,7 +2102,7 @@ func TestRecover(t *testing.T) {
},
}, nil)

recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return(
recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeID).Return(
&admin.NodeExecutionGetDataResponse{
FullInputs: fullInputs,
FullOutputs: fullOutputs,
Expand Down Expand Up @@ -2139,8 +2136,8 @@ func TestRecover(t *testing.T) {
assert.Equal(t, phaseInfo.GetPhase(), handler.EPhaseRecovered)
})
t.Run("recover cached, dynamic task node successfully", func(t *testing.T) {
recoveryClient := &recoveryMocks.RecoveryClient{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return(
recoveryClient := &recoveryMocks.Client{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return(
&admin.NodeExecution{
Closure: &admin.NodeExecutionClosure{
Phase: core.NodeExecution_SUCCEEDED,
Expand Down Expand Up @@ -2178,7 +2175,7 @@ func TestRecover(t *testing.T) {
},
},
}
recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return(
recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeID).Return(
&admin.NodeExecutionGetDataResponse{
FullInputs: fullInputs,
FullOutputs: fullOutputs,
Expand Down Expand Up @@ -2223,8 +2220,8 @@ func TestRecover(t *testing.T) {
}, phaseInfo.GetInfo().TaskNodeInfo.TaskNodeMetadata))
})
t.Run("recover workflow node successfully", func(t *testing.T) {
recoveryClient := &recoveryMocks.RecoveryClient{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return(
recoveryClient := &recoveryMocks.Client{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return(
&admin.NodeExecution{
Closure: &admin.NodeExecutionClosure{
Phase: core.NodeExecution_SUCCEEDED,
Expand All @@ -2243,7 +2240,7 @@ func TestRecover(t *testing.T) {
},
}, nil)

recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return(
recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeID).Return(
&admin.NodeExecutionGetDataResponse{
FullInputs: fullInputs,
FullOutputs: fullOutputs,
Expand Down Expand Up @@ -2280,8 +2277,8 @@ func TestRecover(t *testing.T) {
})

t.Run("nothing to recover", func(t *testing.T) {
recoveryClient := &recoveryMocks.RecoveryClient{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return(
recoveryClient := &recoveryMocks.Client{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return(
&admin.NodeExecution{
Closure: &admin.NodeExecutionClosure{
Phase: core.NodeExecution_FAILED,
Expand All @@ -2298,8 +2295,8 @@ func TestRecover(t *testing.T) {
})

t.Run("Fetch inputs", func(t *testing.T) {
recoveryClient := &recoveryMocks.RecoveryClient{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return(
recoveryClient := &recoveryMocks.Client{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return(
&admin.NodeExecution{
InputUri: "inputuri",
Closure: &admin.NodeExecutionClosure{
Expand All @@ -2310,7 +2307,7 @@ func TestRecover(t *testing.T) {
},
}, nil)

recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return(
recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeID).Return(
&admin.NodeExecutionGetDataResponse{
FullOutputs: fullOutputs,
}, nil)
Expand Down Expand Up @@ -2344,8 +2341,8 @@ func TestRecover(t *testing.T) {
mockPBStore.AssertNumberOfCalls(t, "ReadProtobuf", 1)
})
t.Run("Fetch outputs", func(t *testing.T) {
recoveryClient := &recoveryMocks.RecoveryClient{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeExecID).Return(
recoveryClient := &recoveryMocks.Client{}
recoveryClient.On("RecoverNodeExecution", mock.Anything, recoveryID, nodeID).Return(
&admin.NodeExecution{
Closure: &admin.NodeExecutionClosure{
Phase: core.NodeExecution_SUCCEEDED,
Expand All @@ -2355,7 +2352,7 @@ func TestRecover(t *testing.T) {
},
}, nil)

recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeExecID).Return(
recoveryClient.On("RecoverNodeExecutionData", mock.Anything, recoveryID, nodeID).Return(
&admin.NodeExecutionGetDataResponse{
FullInputs: fullInputs,
}, nil)
Expand Down
12 changes: 6 additions & 6 deletions flytepropeller/pkg/controller/nodes/recovery/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,28 @@ import (
//go:generate mockery -name Client -output=mocks -case=underscore

// Client looks up a previously recorded node execution, and the data
// associated with it, for a node within the workflow execution identified by
// execID.
type Client interface {
RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecution, error)
RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, id *core.NodeExecutionIdentifier) (*admin.NodeExecutionGetDataResponse, error)
RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) (*admin.NodeExecution, error)
RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) (*admin.NodeExecutionGetDataResponse, error)
}

// recoveryClient carries the methods declared by Client in this file; each of
// them forwards its request to adminClient.
type recoveryClient struct {
adminClient service.AdminServiceClient
}

// RecoverNodeExecution fetches the node execution for the given node within
// the workflow execution execID by sending a GetNodeExecution request through
// the admin client. The admin client's response and error are returned
// unchanged.
func (c *recoveryClient) RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID *core.NodeExecutionIdentifier) (*admin.NodeExecution, error) {
func (c *recoveryClient) RecoverNodeExecution(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) (*admin.NodeExecution, error) {
// The lookup identifier pairs the supplied execID with the node ID.
origNodeID := &core.NodeExecutionIdentifier{
ExecutionId: execID,
NodeId: nodeID.NodeId,
NodeId: nodeID,
}
return c.adminClient.GetNodeExecution(ctx, &admin.NodeExecutionGetRequest{
Id: origNodeID,
})
}

func (c *recoveryClient) RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID *core.NodeExecutionIdentifier) (*admin.NodeExecutionGetDataResponse, error) {
func (c *recoveryClient) RecoverNodeExecutionData(ctx context.Context, execID *core.WorkflowExecutionIdentifier, nodeID string) (*admin.NodeExecutionGetDataResponse, error) {
origNodeID := &core.NodeExecutionIdentifier{
ExecutionId: execID,
NodeId: nodeID.NodeId,
NodeId: nodeID,
}
return c.adminClient.GetNodeExecutionData(ctx, &admin.NodeExecutionGetDataRequest{
Id: origNodeID,
Expand Down
36 changes: 18 additions & 18 deletions flytepropeller/pkg/controller/nodes/recovery/mocks/client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

100 changes: 0 additions & 100 deletions flytepropeller/pkg/controller/nodes/recovery/mocks/recovery_client.go

This file was deleted.

Loading

0 comments on commit 31c6efa

Please sign in to comment.