Skip to content

Commit

Permalink
Abort AWS Batch Job on Finalize (flyteorg#58)
Browse files Browse the repository at this point in the history
The AWS Batch Plugin can determine that a task has failed if it detects that the failure ratio requirement can no longer be satisfied. When this happens, the array task node is terminated and then retried (if there are retries available). This should trigger a cancellation of the AWS Job as well.

It was originally written this way to match the old version of this code, which, upon discussion, was deemed a bad design.

- [X] Need to ensure the cancellation API is idempotent (or add an error check/ignore if it's not).
         Validated by running a Workflow, waiting for it to kick off an AWS Job, and then aborting it. That ended up calling Abort() and then Finalize(), and the second call was a no-op.
  • Loading branch information
EngHabu authored Feb 20, 2020
1 parent 1a10fb5 commit a931a40
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 21 deletions.
21 changes: 2 additions & 19 deletions go/tasks/plugins/array/awsbatch/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,28 +118,11 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
}

// Abort is the executor's abort hook. Its final return statement delegates to TerminateSubTasks with the
// reason "Aborted".
// NOTE(review): this listing looks like a diff rendered without +/- markers. The hunk header reports
// "2 additions & 19 deletions" for this file, so the phase-based body down to the first `return nil` is
// presumably the removed implementation and the trailing return its replacement — confirm against the repository.
func (e Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error {
// Presumably-removed implementation: read the plugin state and terminate the external job only while the
// array task is in PhaseCheckingSubTaskExecutions.
pluginState := &State{}
if _, err := tCtx.PluginStateReader().Get(pluginState); err != nil {
return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read unmarshal custom state")
}

// Substitute an empty array state when none has been stored.
if pluginState.State == nil {
pluginState.State = &arrayCore.State{}
}

// The second value returned by GetPhase is discarded.
p, _ := pluginState.GetPhase()
logger.Infof(ctx, "Abort is called with phase [%v]", p)

switch p {
case arrayCore.PhaseCheckingSubTaskExecutions:
return TerminateSubTasks(ctx, e.jobStore.Client, *pluginState.GetExternalJobID())
}

return nil
// Presumably-added implementation: hand the whole decision to TerminateSubTasks, passing the task context,
// the job store's client and the reason "Aborted".
return TerminateSubTasks(ctx, tCtx, e.jobStore.Client, "Aborted")
}

// Finalize is the executor's finalize hook. Its final return statement delegates to TerminateSubTasks with the
// reason "Finalized".
// NOTE(review): this listing looks like a diff rendered without +/- markers; the bare `return nil` is presumably
// the removed line and the TerminateSubTasks call its replacement — confirm against the repository.
func (e Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
return nil
return TerminateSubTasks(ctx, tCtx, e.jobStore.Client, "Finalized")
}

func NewExecutor(ctx context.Context, awsClient aws.Client, cfg *batchConfig.Config,
Expand Down
28 changes: 26 additions & 2 deletions go/tasks/plugins/array/awsbatch/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"fmt"

"github.com/lyft/flyteplugins/go/tasks/errors"

"github.com/lyft/flytestdlib/logger"

arrayCore "github.com/lyft/flyteplugins/go/tasks/plugins/array/core"
Expand Down Expand Up @@ -67,6 +69,28 @@ func LaunchSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, batchCl
return nextState, nil
}

func TerminateSubTasks(ctx context.Context, batchClient Client, jobID string) error {
return batchClient.TerminateJob(ctx, jobID, "aborted")
// TerminateSubTasks cancels the AWS Job referenced by the task's plugin state, if one has been recorded. The
// underlying API is expected to be idempotent, so calling this repeatedly for the same job should be safe; each
// call still results in a request to AWS Batch.
//
// It returns a CorruptedPluginState error when the plugin state cannot be read, the error from TerminateJob when
// a job ID is present, and nil when there is no job to cancel.
func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, batchClient Client, reason string) error {
	state := &State{}
	_, err := tCtx.PluginStateReader().Get(state)
	if err != nil {
		return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to unmarshal custom state")
	}

	// An absent inner state only makes sense when the task has "just" been kicked off. Filling in an empty one
	// keeps the code that follows simpler.
	if state.State == nil {
		state.State = &arrayCore.State{}
	}

	phase, _ := state.GetPhase()
	logger.Infof(ctx, "TerminateSubTasks is called with phase [%v] and reason [%v]", phase, reason)

	// Without a recorded job ID there is nothing to cancel.
	if state.GetExternalJobID() == nil {
		return nil
	}

	jobID := *state.GetExternalJobID()
	logger.Infof(ctx, "Cancelling AWS Job [%v] because [%v].", jobID, reason)
	return batchClient.TerminateJob(ctx, jobID, reason)
}
25 changes: 25 additions & 0 deletions go/tasks/plugins/array/awsbatch/launcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package awsbatch
import (
"testing"

"github.com/stretchr/testify/mock"

"k8s.io/apimachinery/pkg/api/resource"

core3 "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
Expand Down Expand Up @@ -126,6 +128,29 @@ func TestLaunchSubTasks(t *testing.T) {
})
}

// TestTerminateSubTasks checks that TerminateSubTasks forwards the job ID found in the plugin state, together
// with the caller-supplied reason, to the batch client's TerminateJob.
func TestTerminateSubTasks(t *testing.T) {
	const (
		jobID  = "abc-123"
		reason = "Test terminate"
	)

	ctx := context.Background()

	// The mocked state reader populates the caller's State with an external job ID whenever Get is invoked.
	stateReader := &mocks.PluginStateReader{}
	stateReader.OnGetMatch(mock.Anything).Return(0, nil).Run(func(args mock.Arguments) {
		state := args.Get(0).(*State)
		state.ExternalJobID = refStr(jobID)
	})

	taskCtx := &mocks.TaskExecutionContext{}
	taskCtx.OnPluginStateReader().Return(stateReader)

	// The client expects a single TerminateJob call carrying the job ID and reason above.
	client := &mocks2.Client{}
	client.OnTerminateJob(ctx, jobID, reason).Return(nil).Once()

	t.Run("Simple", func(t *testing.T) {
		err := TerminateSubTasks(ctx, taskCtx, client, reason)
		assert.NoError(t, err)
	})

	client.AssertExpectations(t)
	taskCtx.AssertExpectations(t)
	stateReader.AssertExpectations(t)
}

func assertEqual(t testing.TB, a, b interface{}) {
if diff := deep.Equal(a, b); diff != nil {
t.Error(diff)
Expand Down

0 comments on commit a931a40

Please sign in to comment.