Skip to content

Commit

Permalink
The status of the AWS batch job should become failed once the retry l…
Browse files Browse the repository at this point in the history
…imit exceeded (flyteorg#291)

* Turn PhaseRetryableFailure into PhaseRetryLimitExceededFailure

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* update

Signed-off-by: Kevin Su <[email protected]>

* update test

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* update

Signed-off-by: Kevin Su <[email protected]>

* update tests

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* wip

Signed-off-by: Kevin Su <[email protected]>

* udpate

Signed-off-by: Kevin Su <[email protected]>

* address comment

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* fix tests

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Dec 1, 2022
1 parent 675a167 commit 17b3010
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 19 deletions.
4 changes: 1 addition & 3 deletions flyteplugins/go/tasks/plugins/array/awsbatch/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
pluginState, err = LaunchSubTasks(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics)

case arrayCore.PhaseCheckingSubTaskExecutions:
pluginState, err = CheckSubTasksState(ctx, tCtx.TaskExecutionMetadata(),
tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(),
e.jobStore, tCtx.DataStore(), pluginConfig, pluginState, e.metrics)
pluginState, err = CheckSubTasksState(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics)

case arrayCore.PhaseAssembleFinalOutput:
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version, pluginState.State)
Expand Down
27 changes: 22 additions & 5 deletions flyteplugins/go/tasks/plugins/array/awsbatch/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"context"

core2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flytestdlib/storage"

"github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array"
Expand Down Expand Up @@ -34,19 +34,32 @@ func createSubJobList(count int) []*Job {
return res
}

func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata, outputPrefix, baseOutputSandbox storage.DataReference, jobStore *JobStore,
dataStore *storage.DataStore, cfg *config.Config, currentState *State, metrics ExecutorMetrics) (newState *State, err error) {
func CheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, jobStore *JobStore,
cfg *config.Config, currentState *State, metrics ExecutorMetrics) (newState *State, err error) {
newState = currentState
parentState := currentState.State
jobName := taskMeta.GetTaskExecutionID().GetGeneratedName()
jobName := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
job := jobStore.Get(jobName)
outputPrefix := tCtx.OutputWriter().GetOutputPrefixPath()
baseOutputSandbox := tCtx.OutputWriter().GetRawOutputPrefix()
dataStore := tCtx.DataStore()
// Check that the taskTemplate is valid
var taskTemplate *core2.TaskTemplate
taskTemplate, err = tCtx.TaskReader().Read(ctx)
if err != nil {
return nil, errors.Wrapf(errors.CorruptedPluginState, err, "Failed to read task template")
} else if taskTemplate == nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil")
}
retry := toRetryStrategy(ctx, toBackoffLimit(taskTemplate.Metadata), cfg.MinRetries, cfg.MaxRetries)

// If job isn't currently being monitored (recovering from a restart?), add it to the sync-cache and return
if job == nil {
logger.Info(ctx, "Job not found in cache, adding it. [%v]", jobName)

_, err = jobStore.GetOrCreate(jobName, &Job{
ID: *currentState.ExternalJobID,
OwnerReference: taskMeta.GetOwnerID(),
OwnerReference: tCtx.TaskExecutionMetadata().GetOwnerID(),
SubJobs: createSubJobList(currentState.GetExecutionArraySize()),
})

Expand Down Expand Up @@ -108,6 +121,10 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata
} else {
msg.Collect(childIdx, "Job failed")
}

if subJob.Status.Phase == core.PhaseRetryableFailure && *retry.Attempts == int64(len(subJob.Attempts)) {
actualPhase = core.PhasePermanentFailure
}
} else if subJob.Status.Phase.IsSuccess() {
actualPhase, err = array.CheckTaskOutput(ctx, dataStore, outputPrefix, baseOutputSandbox, childIdx, originalIdx)
if err != nil {
Expand Down
87 changes: 76 additions & 11 deletions flyteplugins/go/tasks/plugins/array/awsbatch/monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package awsbatch
import (
"testing"

"github.com/stretchr/testify/mock"

"github.com/flyteorg/flytestdlib/contextutils"
"github.com/flyteorg/flytestdlib/promutils/labeled"

Expand All @@ -11,6 +13,7 @@ import (

"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus"

flyteIdl "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"

"github.com/aws/aws-sdk-go/aws/request"
Expand All @@ -19,6 +22,7 @@ import (
arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
ioMocks "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/config"
batchMocks "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/awsbatch/mocks"
"github.com/flyteorg/flytestdlib/utils"
Expand All @@ -35,15 +39,39 @@ func init() {

func TestCheckSubTasksState(t *testing.T) {
ctx := context.Background()
tCtx := &mocks.TaskExecutionContext{}
tID := &mocks.TaskExecutionID{}
tID.OnGetGeneratedName().Return("generated-name")

tMeta := &mocks.TaskExecutionMetadata{}
tMeta.OnGetOwnerID().Return(types.NamespacedName{
Namespace: "domain",
Name: "name",
})
tMeta.OnGetTaskExecutionID().Return(tID)
inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

outputWriter := &ioMocks.OutputWriter{}
outputWriter.OnGetOutputPrefixPath().Return("")
outputWriter.OnGetRawOutputPrefix().Return("")

taskReader := &mocks.TaskReader{}
task := &flyteIdl.TaskTemplate{
Type: "test",
Target: &flyteIdl.TaskTemplate_Container{
Container: &flyteIdl.Container{
Command: []string{"command"},
Args: []string{"{{.Input}}"},
},
},
Metadata: &flyteIdl.TaskMetadata{Retries: &flyteIdl.RetryStrategy{Retries: 3}},
}
taskReader.On("Read", mock.Anything).Return(task, nil)

tCtx.OnOutputWriter().Return(outputWriter)
tCtx.OnTaskReader().Return(taskReader)
tCtx.OnDataStore().Return(inMemDatastore)
tCtx.OnTaskExecutionMetadata().Return(tMeta)

t.Run("Not in cache", func(t *testing.T) {
mBatchClient := batchMocks.NewMockAwsBatchClient()
Expand All @@ -52,7 +80,7 @@ func TestCheckSubTasksState(t *testing.T) {
utils.NewRateLimiter("", 10, 20))

jobStore := newJobsStore(t, batchClient)
newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 5,
Expand Down Expand Up @@ -98,7 +126,7 @@ func TestCheckSubTasksState(t *testing.T) {

assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, nil, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 5,
Expand Down Expand Up @@ -133,13 +161,10 @@ func TestCheckSubTasksState(t *testing.T) {

assert.NoError(t, err)

inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

retryAttemptsArray, err := bitarray.NewCompactArray(1, bitarray.Item(1))
assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 1,
Expand Down Expand Up @@ -181,13 +206,10 @@ func TestCheckSubTasksState(t *testing.T) {

assert.NoError(t, err)

inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
assert.NoError(t, err)

retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1))
assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{
newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions,
ExecutionArraySize: 2,
Expand All @@ -206,6 +228,49 @@ func TestCheckSubTasksState(t *testing.T) {
assert.NoError(t, err)
p, _ := newState.GetPhase()
assert.Equal(t, arrayCore.PhaseCheckingSubTaskExecutions.String(), p.String())
})

t.Run("retry limit exceeded", func(t *testing.T) {
mBatchClient := batchMocks.NewMockAwsBatchClient()
batchClient := NewCustomBatchClient(mBatchClient, "", "",
utils.NewRateLimiter("", 10, 20),
utils.NewRateLimiter("", 10, 20))

jobStore := newJobsStore(t, batchClient)
_, err := jobStore.GetOrCreate(tID.GetGeneratedName(), &Job{
ID: "job-id",
Status: JobStatus{
Phase: core.PhaseRunning,
},
SubJobs: []*Job{
{Status: JobStatus{Phase: core.PhaseRetryableFailure}, Attempts: []Attempt{{LogStream: "failed"}}},
{Status: JobStatus{Phase: core.PhaseSuccess}},
},
})

assert.NoError(t, err)

retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1))
assert.NoError(t, err)

newState, err := CheckSubTasksState(ctx, tCtx, jobStore, &config.Config{}, &State{
State: &arrayCore.State{
CurrentPhase: arrayCore.PhaseWriteToDiscoveryThenFail,
ExecutionArraySize: 2,
OriginalArraySize: 2,
OriginalMinSuccesses: 2,
ArrayStatus: arraystatus.ArrayStatus{
Detailed: arrayCore.NewPhasesCompactArray(2),
},
IndexesToCache: bitarray.NewBitSet(2),
RetryAttempts: retryAttemptsArray,
},
ExternalJobID: refStr("job-id"),
JobDefinitionArn: "",
}, getAwsBatchExecutorMetrics(promutils.NewTestScope()))

assert.NoError(t, err)
p, _ := newState.GetPhase()
assert.Equal(t, arrayCore.PhaseWriteToDiscoveryThenFail, p)
})
}
16 changes: 16 additions & 0 deletions flyteplugins/go/tasks/plugins/array/core/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,22 @@ func TestSummaryToPhase(t *testing.T) {
core.PhaseSuccess: 10,
},
},
{
"FailedToRetry",
PhaseWriteToDiscoveryThenFail,
map[core.Phase]int64{
core.PhaseSuccess: 5,
core.PhasePermanentFailure: 5,
},
},
{
"Retrying",
PhaseCheckingSubTaskExecutions,
map[core.Phase]int64{
core.PhaseSuccess: 5,
core.PhaseRetryableFailure: 5,
},
},
}

for _, tt := range tests {
Expand Down

0 comments on commit 17b3010

Please sign in to comment.