Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix failure cases in which abort fails for a task #138

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions pkg/controller/nodes/task/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,6 @@ func (t Handler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, r
currentPhase := nCtx.NodeStateReader().GetTaskNodeState().PluginPhase
logger.Debugf(ctx, "Abort invoked with phase [%v]", currentPhase)

if currentPhase.IsTerminal() {
logger.Debugf(ctx, "Returning immediately from Abort since task is already in terminal phase.", currentPhase)
return nil
}

ttype := nCtx.TaskReader().GetTaskType()
p, err := t.ResolvePlugin(ctx, ttype)
if err != nil {
Expand Down Expand Up @@ -535,6 +530,14 @@ func (t Handler) Abort(ctx context.Context, nCtx handler.NodeExecutionContext, r
logger.Errorf(ctx, "Abort failed when calling plugin abort.")
return err
}

// We should not try to send an event if we are already in a terminal phase, as we have probably already sent the event.
// Only if we are in a non-terminal phase do we send a failure event.
if currentPhase.IsTerminal() {
logger.Debugf(ctx, "Returning immediately from Abort since task is already in terminal phase.", currentPhase)
return nil
}

taskExecID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID()
evRecorder := nCtx.EventsRecorder()
if err := evRecorder.RecordTaskEvent(ctx, &event.TaskExecutionEvent{
Expand Down
40 changes: 30 additions & 10 deletions pkg/controller/nodes/task/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ func Test_task_Handle_Barrier(t *testing.T) {
}

func Test_task_Abort(t *testing.T) {
createNodeCtx := func(ev *fakeBufferedTaskEventRecorder) *nodeMocks.NodeExecutionContext {
createNodeCtx := func(ev *fakeBufferedTaskEventRecorder, p pluginCore.Phase) *nodeMocks.NodeExecutionContext {
wfExecID := &core.WorkflowExecutionIdentifier{
Project: "project",
Domain: "domain",
Expand Down Expand Up @@ -1201,6 +1201,7 @@ func Test_task_Abort(t *testing.T) {
assert.NoError(t, cod.Encode(test{A: a}, st))
nr := &nodeMocks.NodeStateReader{}
nr.OnGetTaskNodeState().Return(handler.TaskNodeState{
PluginPhase: p,
PluginState: st.Bytes(),
})
nCtx.OnNodeStateReader().Return(nr)
Expand All @@ -1210,6 +1211,7 @@ func Test_task_Abort(t *testing.T) {
noopRm := CreateNoopResourceManager(context.TODO(), promutils.NewTestScope())

type fields struct {
p pluginCore.Phase
defaultPluginCallback func() pluginCore.Plugin
}
type args struct {
Expand All @@ -1222,22 +1224,40 @@ func Test_task_Abort(t *testing.T) {
wantErr bool
abortCalled bool
}{
{"no-plugin", fields{defaultPluginCallback: func() pluginCore.Plugin {
{"no-plugin", fields{p: pluginCore.PhaseInitializing, defaultPluginCallback: func() pluginCore.Plugin {
return nil
}}, args{nil}, true, false},

{"abort-fails", fields{defaultPluginCallback: func() pluginCore.Plugin {
{"abort-fails", fields{p: pluginCore.PhaseQueued, defaultPluginCallback: func() pluginCore.Plugin {
p := &pluginCoreMocks.Plugin{}
p.On("GetID").Return("id")
p.On("Abort", mock.Anything, mock.Anything).Return(fmt.Errorf("error"))
p.OnGetID().Return("id")
p.OnAbortMatch(mock.Anything, mock.Anything).Return(fmt.Errorf("error"))
return p
}}, args{nil}, true, true},
{"abort-success", fields{defaultPluginCallback: func() pluginCore.Plugin {
{"abort-success", fields{p: pluginCore.PhaseWaitingForResources, defaultPluginCallback: func() pluginCore.Plugin {
p := &pluginCoreMocks.Plugin{}
p.On("GetID").Return("id")
p.On("Abort", mock.Anything, mock.Anything).Return(nil)
p.OnGetID().Return("id")
p.OnAbortMatch(mock.Anything, mock.Anything).Return(nil)
return p
}}, args{ev: &fakeBufferedTaskEventRecorder{}}, false, true},
{"abort-terminal-event", fields{p: pluginCore.PhaseRetryableFailure, defaultPluginCallback: func() pluginCore.Plugin {
p := &pluginCoreMocks.Plugin{}
p.OnGetID().Return("id")
p.OnAbortMatch(mock.Anything, mock.Anything).Return(nil)
return p
}}, args{ev: nil}, false, true},
{"abort-terminal-event2", fields{p: pluginCore.PhaseSuccess, defaultPluginCallback: func() pluginCore.Plugin {
p := &pluginCoreMocks.Plugin{}
p.OnGetID().Return("id")
p.OnAbortMatch(mock.Anything, mock.Anything).Return(nil)
return p
}}, args{ev: nil}, false, true},
{"abort-terminal-event3", fields{p: pluginCore.PhasePermanentFailure, defaultPluginCallback: func() pluginCore.Plugin {
p := &pluginCoreMocks.Plugin{}
p.OnGetID().Return("id")
p.OnAbortMatch(mock.Anything, mock.Anything).Return(nil)
return p
}}, args{ev: nil}, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -1246,14 +1266,14 @@ func Test_task_Abort(t *testing.T) {
defaultPlugin: m,
resourceManager: noopRm,
}
nCtx := createNodeCtx(tt.args.ev)
nCtx := createNodeCtx(tt.args.ev, tt.fields.p)
if err := tk.Abort(context.TODO(), nCtx, "reason"); (err != nil) != tt.wantErr {
t.Errorf("Handler.Abort() error = %v, wantErr %v", err, tt.wantErr)
}
c := 0
if tt.abortCalled {
c = 1
if !tt.wantErr {
if !tt.wantErr && tt.args.ev != nil {
assert.Len(t, tt.args.ev.evs, 1)
}
}
Expand Down