Skip to content

Commit

Permalink
sending correct external resources for k8s-array plugin (flyteorg#300)
Browse files Browse the repository at this point in the history
* sending correct external resources for k8s-array plugin

Signed-off-by: Dan Rammer <[email protected]>

* bumping phase versions

Signed-off-by: Dan Rammer <[email protected]>

* using externalResources in aws-batch plugin

Signed-off-by: Dan Rammer <[email protected]>

* lint issues

Signed-off-by: Dan Rammer <[email protected]>

* fixed unit tests

Signed-off-by: Dan Rammer <[email protected]>

* removed unnecessary comments

Signed-off-by: Dan Rammer <[email protected]>

* reverted permanent failure computation

Signed-off-by: Dan Rammer <[email protected]>

---------

Signed-off-by: Dan Rammer <[email protected]>
  • Loading branch information
hamersaw authored Feb 8, 2023
1 parent 07bc0a9 commit e79acfb
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 60 deletions.
21 changes: 12 additions & 9 deletions go/tasks/plugins/array/awsbatch/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
}

var err error
var externalResources []*core.ExternalResource

p, version := pluginState.GetPhase()
logger.Infof(ctx, "Entering handle with phase [%v]", p)
Expand All @@ -71,28 +72,28 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
pluginState.State, err = array.DetermineDiscoverability(ctx, tCtx, pluginConfig.MaxArrayJobSize, pluginState.State)

case arrayCore.PhasePreLaunch:
pluginState, err = EnsureJobDefinition(ctx, tCtx, pluginConfig, e.jobStore.Client, e.jobDefinitionCache, pluginState)
pluginState, err = EnsureJobDefinition(ctx, tCtx, pluginConfig, e.jobStore.Client, e.jobDefinitionCache, pluginState, version+1)

case arrayCore.PhaseWaitingForResources:
fallthrough

case arrayCore.PhaseLaunch:
pluginState, err = LaunchSubTasks(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics)
pluginState, err = LaunchSubTasks(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics, version+1)

case arrayCore.PhaseCheckingSubTaskExecutions:
pluginState, err = CheckSubTasksState(ctx, tCtx, e.jobStore, pluginConfig, pluginState, e.metrics)

case arrayCore.PhaseAssembleFinalOutput:
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version, pluginState.State)
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version+1, pluginState.State)

case arrayCore.PhaseWriteToDiscoveryThenFail:
pluginState.State, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalError, version)
pluginState.State, externalResources, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalError, version+1)

case arrayCore.PhaseWriteToDiscovery:
pluginState.State, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalOutput, version)
pluginState.State, externalResources, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalOutput, version+1)

case arrayCore.PhaseAssembleFinalError:
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhaseRetryableFailure, version, pluginState.State)
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhaseRetryableFailure, version+1, pluginState.State)
}

if err != nil {
Expand All @@ -105,17 +106,19 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c

// Always attempt to augment phase with task logs.
var logLinks []*idlCore.TaskLog
var externalResources []*core.ExternalResource

if p == arrayCore.PhasePreLaunch {
nextPhase, _ := pluginState.GetPhase()
if p == arrayCore.PhaseStart && nextPhase != arrayCore.PhaseStart {
// if transitioning from PhaseStart to another phase then cache lookups have completed
externalResources, err = arrayCore.InitializeExternalResources(ctx, tCtx, pluginState.State,
func(tCtx core.TaskExecutionContext, childIndex int) string {
// subTaskIDs for the aws_batch plugin are generated based on the job ID, therefore
// to initialize we default to an empty string which will be updated later.
return ""
},
)
} else if p != arrayCore.PhaseStart {
} else if p != arrayCore.PhaseStart && p != arrayCore.PhaseWriteToDiscovery && p != arrayCore.PhaseWriteToDiscoveryThenFail {
// if externalResources is not otherwise being populated then attempt to get task log links
logLinks, externalResources, err = GetTaskLinks(ctx, tCtx.TaskExecutionMetadata(), e.jobStore, pluginState)
}

Expand Down
6 changes: 3 additions & 3 deletions go/tasks/plugins/array/awsbatch/job_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func containerImageRepository(containerImage string) string {
}

func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionContext, cfg *config.Config, client Client,
definitionCache definition.Cache, currentState *State) (nextState *State, err error) {
definitionCache definition.Cache, currentState *State, terminalVersion uint32) (nextState *State, err error) {

taskTemplate, err := tCtx.TaskReader().Read(ctx)
if err != nil {
Expand All @@ -65,7 +65,7 @@ func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionConte
containerImage, role, platformCapabilities, existingArn)

nextState = currentState.SetJobDefinitionArn(existingArn)
nextState.State = nextState.SetPhase(arrayCore.PhaseLaunch, 0).SetReason("AWS job definition already exist.")
nextState.State = nextState.SetPhase(arrayCore.PhaseLaunch, terminalVersion).SetReason("AWS job definition already exist.")
return nextState, nil
}

Expand All @@ -83,7 +83,7 @@ func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionConte
}

nextState = currentState.SetJobDefinitionArn(arn)
nextState.State = nextState.SetPhase(arrayCore.PhaseLaunch, 0).SetReason("Created AWS job definition")
nextState.State = nextState.SetPhase(arrayCore.PhaseLaunch, terminalVersion).SetReason("Created AWS job definition")

return nextState, nil
}
8 changes: 4 additions & 4 deletions go/tasks/plugins/array/awsbatch/job_definition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestEnsureJobDefinition(t *testing.T) {

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
})
}, 0)

assert.NoError(t, err)
assert.NotNil(t, nextState)
Expand All @@ -95,7 +95,7 @@ func TestEnsureJobDefinition(t *testing.T) {

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
})
}, 0)
assert.NoError(t, err)
assert.NotNil(t, nextState)
assert.Equal(t, "their-arn", nextState.JobDefinitionArn)
Expand Down Expand Up @@ -152,7 +152,7 @@ func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) {

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
})
}, 0)

assert.NoError(t, err)
assert.NotNil(t, nextState)
Expand All @@ -168,7 +168,7 @@ func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) {

nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{
State: &arrayCore.State{},
})
}, 0)
assert.NoError(t, err)
assert.NotNil(t, nextState)
assert.Equal(t, "their-arn", nextState.JobDefinitionArn)
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/array/awsbatch/launcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

func LaunchSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, batchClient Client, pluginConfig *config.Config,
currentState *State, metrics ExecutorMetrics) (nextState *State, err error) {
currentState *State, metrics ExecutorMetrics, terminalVersion uint32) (nextState *State, err error) {
size := currentState.GetExecutionArraySize()

jobDefinition := currentState.GetJobDefinitionArn()
Expand Down Expand Up @@ -56,7 +56,7 @@ func LaunchSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, batchCl
}

parentState := currentState.
SetPhase(arrayCore.PhaseCheckingSubTaskExecutions, 0).
SetPhase(arrayCore.PhaseCheckingSubTaskExecutions, terminalVersion).
SetArrayStatus(arraystatus.ArrayStatus{
Summary: arraystatus.ArraySummary{
core.PhaseQueued: int64(size),
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/array/awsbatch/launcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func TestLaunchSubTasks(t *testing.T) {
JobDefinitionArn: "arn",
}

newState, err := LaunchSubTasks(context.TODO(), tCtx, batchClient, &config.Config{MaxArrayJobSize: 10}, currentState, getAwsBatchExecutorMetrics(promutils.NewTestScope()))
newState, err := LaunchSubTasks(context.TODO(), tCtx, batchClient, &config.Config{MaxArrayJobSize: 10}, currentState, getAwsBatchExecutorMetrics(promutils.NewTestScope()), 0)
assert.NoError(t, err)
assertEqual(t, expectedState, newState)
})
Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/array/awsbatch/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func CheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, job

parentState = parentState.SetPhase(phase, version).SetReason("Task is still running")
} else {
parentState = parentState.SetPhase(phase, version)
parentState = parentState.SetPhase(phase, version+1)
}

p, v := parentState.GetPhase()
Expand Down
49 changes: 35 additions & 14 deletions go/tasks/plugins/array/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math"
"strconv"

idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
idlPlugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"

arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"
Expand All @@ -19,8 +20,6 @@ import (
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils"

idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
)

const AwsBatchTaskType = "aws-batch"
Expand Down Expand Up @@ -212,19 +211,20 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
return state, nil
}

func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state *arrayCore.State, phaseOnSuccess arrayCore.Phase, versionOnSuccess uint32) (*arrayCore.State, error) {
func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state *arrayCore.State, phaseOnSuccess arrayCore.Phase, versionOnSuccess uint32) (*arrayCore.State, []*core.ExternalResource, error) {
var externalResources []*core.ExternalResource

// Check that the taskTemplate is valid
taskTemplate, err := tCtx.TaskReader().Read(ctx)
if err != nil {
return state, err
return state, externalResources, err
} else if taskTemplate == nil {
return state, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil")
return state, externalResources, errors.Errorf(errors.BadTaskSpecification, "Required value not set, taskTemplate is nil")
}

if tMeta := taskTemplate.Metadata; tMeta == nil || !tMeta.Discoverable {
logger.Debugf(ctx, "Task is not marked as discoverable. Moving to [%v] phase.", phaseOnSuccess)
return state.SetPhase(phaseOnSuccess, versionOnSuccess).SetReason("Task is not discoverable."), nil
return state.SetPhase(phaseOnSuccess, versionOnSuccess).SetReason("Task is not discoverable."), externalResources, nil
}

var inputReaders []io.InputReader
Expand All @@ -233,12 +233,12 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state
// input readers
inputReaders, err = ConstructRemoteFileInputReaders(ctx, tCtx.DataStore(), tCtx.InputReader().GetInputPrefixPath(), arrayJobSize)
if err != nil {
return nil, err
return nil, externalResources, err
}
} else {
inputs, err := tCtx.InputReader().Get(ctx)
if err != nil {
return state, errors.Errorf(errors.MetadataAccessFailed, "Could not read inputs and therefore failed to determine array job size")
return state, externalResources, errors.Errorf(errors.MetadataAccessFailed, "Could not read inputs and therefore failed to determine array job size")
}

var literalCollection *idlCore.LiteralCollection
Expand All @@ -257,18 +257,21 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state
// output reader
outputReaders, err := ConstructOutputReaders(ctx, tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), arrayJobSize)
if err != nil {
return nil, err
return nil, externalResources, err
}

iface := *taskTemplate.Interface
iface.Outputs = makeSingularTaskInterface(iface.Outputs)

// Do not cache failed tasks. Retrieve the final phase from array status and unset the non-successful ones.

tasksToCache := state.GetIndexesToCache().DeepCopy()
for idx, phaseIdx := range state.ArrayStatus.Detailed.GetItems() {
phase := core.Phases[phaseIdx]
if !phase.IsSuccess() {
tasksToCache.Clear(uint(idx))
// tasksToCache is indexed by the original array size, whereas ArrayStatus.Detailed is indexed by
// the execution array size, so translate idx to the original index before clearing the bit.
originalIdx := arrayCore.CalculateOriginalIndex(idx, state.GetIndexesToCache())
tasksToCache.Clear(uint(originalIdx))
}
}

Expand All @@ -278,24 +281,42 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state
iface, &tasksToCache, inputReaders, outputReaders)

if err != nil {
return nil, err
return nil, externalResources, err
}

if len(catalogWriterItems) == 0 {
state.SetPhase(phaseOnSuccess, versionOnSuccess).SetReason("No outputs need to be cached.")
return state, nil
return state, externalResources, nil
}

allWritten, err := WriteToCatalog(ctx, tCtx.TaskRefreshIndicator(), tCtx.Catalog(), catalogWriterItems)
if err != nil {
return nil, err
return nil, externalResources, err
}

if allWritten {
state.SetPhase(phaseOnSuccess, versionOnSuccess).SetReason("Finished writing catalog cache.")

// set CACHE_POPULATED CacheStatus on all cached subtasks
externalResources = make([]*core.ExternalResource, 0)
for idx, phaseIdx := range state.ArrayStatus.Detailed.GetItems() {
originalIdx := arrayCore.CalculateOriginalIndex(idx, state.GetIndexesToCache())
if !tasksToCache.IsSet(uint(originalIdx)) {
continue
}

externalResources = append(externalResources,
&core.ExternalResource{
CacheStatus: idlCore.CatalogCacheStatus_CACHE_POPULATED,
Index: uint32(originalIdx),
RetryAttempt: uint32(state.RetryAttempts.GetItem(idx)),
Phase: core.Phases[phaseIdx],
},
)
}
}

return state, nil
return state, externalResources, nil
}

func WriteToCatalog(ctx context.Context, ownerSignal core.SignalAsync, catalogClient catalog.AsyncClient,
Expand Down
6 changes: 1 addition & 5 deletions go/tasks/plugins/array/core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,7 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl
fallthrough

case PhaseWriteToDiscovery:
// The state version is only incremented in PhaseCheckingSubTaskExecutions when subtask
// phases are updated. Therefore by adding the phase to the state version we ensure that
// (1) all phase changes will have a new phase version and (2) all subtask phase updates
// result in monotonically increasing phase version.
phaseInfo = core.PhaseInfoRunning(version+uint32(p), nowTaskInfo)
phaseInfo = core.PhaseInfoRunning(version, nowTaskInfo)

case PhaseSuccess:
phaseInfo = core.PhaseInfoSuccess(nowTaskInfo)
Expand Down
8 changes: 4 additions & 4 deletions go/tasks/plugins/array/core/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestMapArrayStateToPluginPhase(t *testing.T) {
phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, nil)
assert.NoError(t, err)
assert.Equal(t, core.PhaseRunning, phaseInfo.Phase())
assert.Equal(t, uint32(12), phaseInfo.Version())
assert.Equal(t, uint32(8), phaseInfo.Version())
})

t.Run("write to discovery", func(t *testing.T) {
Expand All @@ -116,7 +116,7 @@ func TestMapArrayStateToPluginPhase(t *testing.T) {
phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, nil)
assert.NoError(t, err)
assert.Equal(t, core.PhaseRunning, phaseInfo.Phase())
assert.Equal(t, uint32(14), phaseInfo.Version())
assert.Equal(t, uint32(8), phaseInfo.Version())
})

t.Run("success", func(t *testing.T) {
Expand Down Expand Up @@ -304,11 +304,11 @@ func TestSummaryToPhase(t *testing.T) {
},
},
{
"FailOnToManyPermanentFailures",
"FailOnTooManyPermanentFailures",
PhaseWriteToDiscoveryThenFail,
map[core.Phase]int64{
core.PhasePermanentFailure: 1,
core.PhaseUndefined: 9,
core.PhaseSuccess: 9,
},
},
{
Expand Down
Loading

0 comments on commit e79acfb

Please sign in to comment.