Skip to content

Commit

Permalink
Use monotonically increasing phase version on map tasks (flyteorg#254)
Browse files Browse the repository at this point in the history
* updating phase version state

Signed-off-by: Daniel Rammer <[email protected]>

* added HashCode function comment

Signed-off-by: Daniel Rammer <[email protected]>

* using arraystatus hash codes in aws_batch plugin

Signed-off-by: Daniel Rammer <[email protected]>

* updated map to running phase comment

Signed-off-by: Daniel Rammer <[email protected]>

* added ArrayStatus.HashCode unit tests

Signed-off-by: Daniel Rammer <[email protected]>

* removed unnecessary debugging logs

Signed-off-by: Daniel Rammer <[email protected]>

* fixed unit tests

Signed-off-by: Daniel Rammer <[email protected]>

* fixed lint issues

Signed-off-by: Daniel Rammer <[email protected]>
  • Loading branch information
hamersaw authored Apr 8, 2022
1 parent 8d33e2a commit cb9fdb5
Show file tree
Hide file tree
Showing 11 changed files with 166 additions and 66 deletions.
19 changes: 19 additions & 0 deletions flyteplugins/go/tasks/plugins/array/arraystatus/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
package arraystatus

import (
"encoding/binary"
"hash/fnv"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flytestdlib/bitarray"
)
Expand All @@ -20,6 +23,22 @@ type ArrayStatus struct {
Detailed bitarray.CompactArray `json:"details"`
}

// HashCode returns a 64-bit FNV-1 digest of the phase indices held in the Detailed array, fed to
// the hash in iteration order as little-endian 8-byte values. Two statuses whose Detailed arrays
// yield the same sequence of phase indices therefore produce the same hash code. A non-nil error
// is returned only if the underlying hash writer reports one, in which case the hash code is 0.
func (a ArrayStatus) HashCode() (uint64, error) {
	hasher := fnv.New64()

	// Scratch buffer reused for the encoding of every phase index.
	buf := make([]byte, 8)
	for _, item := range a.Detailed.GetItems() {
		binary.LittleEndian.PutUint64(buf, item)
		if _, err := hasher.Write(buf); err != nil {
			return 0, err
		}
	}

	return hasher.Sum64(), nil
}

// ArrayCachedStatus is the status object returned after Catalog calls are made to check whether subtasks are cached.
type ArrayCachedStatus struct {
CachedJobs *bitarray.BitSet `json:"cachedJobs"`
Expand Down
82 changes: 82 additions & 0 deletions flyteplugins/go/tasks/plugins/array/arraystatus/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,91 @@ import (
"testing"

types "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"

"github.com/flyteorg/flytestdlib/bitarray"

"github.com/stretchr/testify/assert"
)

func TestArrayStatus_HashCode(t *testing.T) {
size := uint(10)

t.Run("Empty Equal", func(t *testing.T) {
expected := ArrayStatus{}
expectedHashCode, err := expected.HashCode()
assert.Nil(t, err)

actual := ArrayStatus{}
actualHashCode, err := actual.HashCode()
assert.Nil(t, err)

assert.Equal(t, expectedHashCode, actualHashCode)
})

t.Run("Populated Equal", func(t *testing.T) {
expectedDetailed, err := bitarray.NewCompactArray(size, bitarray.Item(len(types.Phases)-1))
assert.Nil(t, err)
expected := ArrayStatus{
Detailed: expectedDetailed,
}
expectedHashCode, err := expected.HashCode()
assert.Nil(t, err)

actualDetailed, err := bitarray.NewCompactArray(size, bitarray.Item(len(types.Phases)-1))
assert.Nil(t, err)
actual := ArrayStatus{
Detailed: actualDetailed,
}
actualHashCode, err := actual.HashCode()
assert.Nil(t, err)

assert.Equal(t, expectedHashCode, actualHashCode)
})

t.Run("Updated Not Equal", func(t *testing.T) {
expectedDetailed, err := bitarray.NewCompactArray(size, bitarray.Item(len(types.Phases)-1))
assert.Nil(t, err)
expectedDetailed.SetItem(0, uint64(1))
expected := ArrayStatus{
Detailed: expectedDetailed,
}
expectedHashCode, err := expected.HashCode()
assert.Nil(t, err)

actualDetailed, err := bitarray.NewCompactArray(size, bitarray.Item(len(types.Phases)-1))
assert.Nil(t, err)
actual := ArrayStatus{
Detailed: actualDetailed,
}
actualHashCode, err := actual.HashCode()
assert.Nil(t, err)

assert.NotEqual(t, expectedHashCode, actualHashCode)
})

t.Run("Updated Equal", func(t *testing.T) {
expectedDetailed, err := bitarray.NewCompactArray(size, bitarray.Item(len(types.Phases)-1))
assert.Nil(t, err)
expectedDetailed.SetItem(0, uint64(1))
expected := ArrayStatus{
Detailed: expectedDetailed,
}
expectedHashCode, err := expected.HashCode()
assert.Nil(t, err)

actualDetailed, err := bitarray.NewCompactArray(size, bitarray.Item(len(types.Phases)-1))
actualDetailed.SetItem(0, uint64(1))
assert.Nil(t, err)
actual := ArrayStatus{
Detailed: actualDetailed,
}
actualHashCode, err := actual.HashCode()
assert.Nil(t, err)

assert.Equal(t, expectedHashCode, actualHashCode)
})
}

func TestArraySummary_MergeFrom(t *testing.T) {
t.Run("Update when not equal", func(t *testing.T) {
expected := ArraySummary{
Expand Down
10 changes: 5 additions & 5 deletions flyteplugins/go/tasks/plugins/array/awsbatch/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c

var err error

p, _ := pluginState.GetPhase()
p, version := pluginState.GetPhase()
logger.Infof(ctx, "Entering handle with phase [%v]", p)

switch p {
Expand All @@ -85,16 +85,16 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
e.jobStore, tCtx.DataStore(), pluginConfig, pluginState, e.metrics)

case arrayCore.PhaseAssembleFinalOutput:
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, pluginState.State)
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.outputAssembler, tCtx, arrayCore.PhaseSuccess, version, pluginState.State)

case arrayCore.PhaseWriteToDiscoveryThenFail:
pluginState.State, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalError)
pluginState.State, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalError, version)

case arrayCore.PhaseWriteToDiscovery:
pluginState.State, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalOutput)
pluginState.State, err = array.WriteToDiscovery(ctx, tCtx, pluginState.State, arrayCore.PhaseAssembleFinalOutput, version)

case arrayCore.PhaseAssembleFinalError:
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhaseRetryableFailure, pluginState.State)
pluginState.State, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhaseRetryableFailure, version, pluginState.State)
}

if err != nil {
Expand Down
21 changes: 15 additions & 6 deletions flyteplugins/go/tasks/plugins/array/awsbatch/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata
Detailed: arrayCore.NewPhasesCompactArray(uint(currentState.GetExecutionArraySize())),
}

currentSubTaskPhaseHash, err := currentState.GetArrayStatus().HashCode()
if err != nil {
return currentState, err
}

queued := 0
for childIdx, subJob := range job.SubJobs {
actualPhase := subJob.Status.Phase
Expand Down Expand Up @@ -132,16 +137,20 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata
errorMsg := msg.Summary(cfg.MaxErrorStringLength)
parentState = parentState.SetReason(errorMsg)
}
_, version := currentState.GetPhase()
if phase == arrayCore.PhaseCheckingSubTaskExecutions {
newPhaseVersion := uint32(0)
// For now, the only changes to PhaseVersion and PreviousSummary occur for running array jobs.
for phase, count := range parentState.GetArrayStatus().Summary {
newPhaseVersion += uint32(phase) * uint32(count)
newSubTaskPhaseHash, err := parentState.GetArrayStatus().HashCode()
if err != nil {
return currentState, err
}

if newSubTaskPhaseHash != currentSubTaskPhaseHash {
version++
}

parentState = parentState.SetPhase(phase, newPhaseVersion).SetReason("Task is still running.")
parentState = parentState.SetPhase(phase, version).SetReason("Task is still running")
} else {
parentState = parentState.SetPhase(phase, core.DefaultPhaseVersion)
parentState = parentState.SetPhase(phase, version)
}

p, v := parentState.GetPhase()
Expand Down
8 changes: 4 additions & 4 deletions flyteplugins/go/tasks/plugins/array/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func DetermineDiscoverability(ctx context.Context, tCtx core.TaskExecutionContex
return state, nil
}

func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state *arrayCore.State, phaseOnSuccess arrayCore.Phase) (*arrayCore.State, error) {
func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state *arrayCore.State, phaseOnSuccess arrayCore.Phase, versionOnSuccess uint32) (*arrayCore.State, error) {

// Check that the taskTemplate is valid
taskTemplate, err := tCtx.TaskReader().Read(ctx)
Expand All @@ -205,7 +205,7 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state

if tMeta := taskTemplate.Metadata; tMeta == nil || !tMeta.Discoverable {
logger.Debugf(ctx, "Task is not marked as discoverable. Moving to [%v] phase.", phaseOnSuccess)
return state.SetPhase(phaseOnSuccess, core.DefaultPhaseVersion).SetReason("Task is not discoverable."), nil
return state.SetPhase(phaseOnSuccess, versionOnSuccess).SetReason("Task is not discoverable."), nil
}

var inputReaders []io.InputReader
Expand Down Expand Up @@ -263,7 +263,7 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state
}

if len(catalogWriterItems) == 0 {
state.SetPhase(phaseOnSuccess, core.DefaultPhaseVersion).SetReason("No outputs need to be cached.")
state.SetPhase(phaseOnSuccess, versionOnSuccess).SetReason("No outputs need to be cached.")
return state, nil
}

Expand All @@ -273,7 +273,7 @@ func WriteToDiscovery(ctx context.Context, tCtx core.TaskExecutionContext, state
}

if allWritten {
state.SetPhase(phaseOnSuccess, core.DefaultPhaseVersion).SetReason("Finished writing catalog cache.")
state.SetPhase(phaseOnSuccess, versionOnSuccess).SetReason("Finished writing catalog cache.")
}

return state, nil
Expand Down
33 changes: 10 additions & 23 deletions flyteplugins/go/tasks/plugins/array/core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,6 @@ func ToArrayJob(structObj *structpb.Struct, taskTypeVersion int32) (*idlPlugins.
return arrayJob, err
}

// GetPhaseVersionOffset returns length * (int64(core.PhasePermanentFailure) + 1) * int64(currentPhase),
// converted to uint32, for use as a per-phase offset when deriving phase versions.
func GetPhaseVersionOffset(currentPhase Phase, length int64) uint32 {
	// NB: core.PhasePermanentFailure must be the last/highest value of the core Phase for this
	// multiplier to hold.
	phaseMultiplier := int64(core.PhasePermanentFailure) + 1
	offset := length * phaseMultiplier * int64(currentPhase)
	return uint32(offset)
}

// Any state of the plugin needs to map to a core.PhaseInfo (which in turn will map to Admin events) so that the rest
// of the Flyte platform can understand what's happening. That is, each possible state that our plugin state
// machine returns should map to a unique (core.Phase, core.PhaseInfo.version).
Expand All @@ -189,20 +184,16 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl
case PhaseStart:
phaseInfo = core.PhaseInfoInitializing(t, core.DefaultPhaseVersion, state.GetReason(), nowTaskInfo)

case PhaseWaitingForResources:
phaseInfo = core.PhaseInfoWaitingForResourcesInfo(t, version, state.GetReason(), nowTaskInfo)

case PhasePreLaunch:
version := GetPhaseVersionOffset(p, 1) + version
phaseInfo = core.PhaseInfoRunning(version, nowTaskInfo)
fallthrough

case PhaseLaunch:
// The first time we return a Running core.Phase, we can just use the version inside the state object itself.
phaseInfo = core.PhaseInfoRunning(version, nowTaskInfo)

case PhaseWaitingForResources:
phaseInfo = core.PhaseInfoWaitingForResourcesInfo(t, version, state.GetReason(), nowTaskInfo)
fallthrough

case PhaseCheckingSubTaskExecutions:
// For future Running core.Phases, we have to make sure we don't use an earlier Admin version number,
// which means we need to offset things.
fallthrough

case PhaseAssembleFinalOutput:
Expand All @@ -215,15 +206,11 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl
fallthrough

case PhaseWriteToDiscovery:
// If the array task has 0 inputs we need to ensure the phaseVersion changes so that the
// task can progress. Therefore we default to task length 1 to ensure phase updates.
length := int64(1)
if state.GetOriginalArraySize() != 0 {
length = state.GetOriginalArraySize()
}

version := GetPhaseVersionOffset(p, length) + version
phaseInfo = core.PhaseInfoRunning(version, nowTaskInfo)
// The state version is only incremented in PhaseCheckingSubTaskExecutions when subtask
// phases are updated. Therefore by adding the phase to the state version we ensure that
// (1) all phase changes will have a new phase version and (2) all subtask phase updates
// result in monotonically increasing phase version.
phaseInfo = core.PhaseInfoRunning(version+uint32(p), nowTaskInfo)

case PhaseSuccess:
phaseInfo = core.PhaseInfoSuccess(nowTaskInfo)
Expand Down
12 changes: 2 additions & 10 deletions flyteplugins/go/tasks/plugins/array/core/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@ import (
"github.com/stretchr/testify/assert"
)

// TestGetPhaseVersionOffset checks the gap between the offsets computed for
// PhaseAssembleFinalOutput and PhaseWriteToDiscovery for a fixed array length.
func TestGetPhaseVersionOffset(t *testing.T) {
	const arrayLength = int64(100)

	assembleFinalOutputOffset := GetPhaseVersionOffset(PhaseAssembleFinalOutput, arrayLength)
	writeToDiscoveryOffset := GetPhaseVersionOffset(PhaseWriteToDiscovery, arrayLength)

	// There are 9 possible core.Phases, from PhaseUndefined to PhasePermanentFailure
	expectedGap := uint32(arrayLength * 9)
	assert.Equal(t, expectedGap, writeToDiscoveryOffset-assembleFinalOutputOffset)
}

func TestInvertBitSet(t *testing.T) {
input := bitarray.NewBitSet(4)
input.Set(0)
Expand Down Expand Up @@ -105,7 +97,7 @@ func TestMapArrayStateToPluginPhase(t *testing.T) {
phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, nil)
assert.NoError(t, err)
assert.Equal(t, core.PhaseRunning, phaseInfo.Phase())
assert.Equal(t, uint32(368), phaseInfo.Version())
assert.Equal(t, uint32(12), phaseInfo.Version())
})

t.Run("write to discovery", func(t *testing.T) {
Expand All @@ -124,7 +116,7 @@ func TestMapArrayStateToPluginPhase(t *testing.T) {
phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, nil)
assert.NoError(t, err)
assert.Equal(t, core.PhaseRunning, phaseInfo.Phase())
assert.Equal(t, uint32(548), phaseInfo.Version())
assert.Equal(t, uint32(14), phaseInfo.Version())
})

t.Run("success", func(t *testing.T) {
Expand Down
8 changes: 4 additions & 4 deletions flyteplugins/go/tasks/plugins/array/k8s/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,16 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
tCtx.DataStore(), tCtx.OutputWriter().GetOutputPrefixPath(), tCtx.OutputWriter().GetRawOutputPrefix(), pluginState)

case arrayCore.PhaseAssembleFinalOutput:
nextState, err = array.AssembleFinalOutputs(ctx, e.outputsAssembler, tCtx, arrayCore.PhaseSuccess, pluginState)
nextState, err = array.AssembleFinalOutputs(ctx, e.outputsAssembler, tCtx, arrayCore.PhaseSuccess, version, pluginState)

case arrayCore.PhaseWriteToDiscoveryThenFail:
nextState, err = array.WriteToDiscovery(ctx, tCtx, pluginState, arrayCore.PhaseAssembleFinalError)
nextState, err = array.WriteToDiscovery(ctx, tCtx, pluginState, arrayCore.PhaseAssembleFinalError, version)

case arrayCore.PhaseWriteToDiscovery:
nextState, err = array.WriteToDiscovery(ctx, tCtx, pluginState, arrayCore.PhaseAssembleFinalOutput)
nextState, err = array.WriteToDiscovery(ctx, tCtx, pluginState, arrayCore.PhaseAssembleFinalOutput, version)

case arrayCore.PhaseAssembleFinalError:
nextState, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhasePermanentFailure, pluginState)
nextState, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhasePermanentFailure, version, pluginState)

default:
nextState = pluginState
Expand Down
20 changes: 14 additions & 6 deletions flyteplugins/go/tasks/plugins/array/k8s/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
currentParallelism := 0
maxParallelism := int(arrayJob.Parallelism)

currentSubTaskPhaseHash, err := currentState.GetArrayStatus().HashCode()
if err != nil {
return currentState, externalResources, err
}

for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() {
existingPhase := core.Phases[existingPhaseIdx]
retryAttempt := currentState.RetryAttempts.GetItem(childIdx)
Expand Down Expand Up @@ -255,17 +260,20 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
newState = newState.SetReason(errorMsg)
}

_, version := currentState.GetPhase()
if phase == arrayCore.PhaseCheckingSubTaskExecutions {
newPhaseVersion := uint32(0)
newSubTaskPhaseHash, err := newState.GetArrayStatus().HashCode()
if err != nil {
return currentState, externalResources, err
}

// For now, the only changes to PhaseVersion and PreviousSummary occur for running array jobs.
for phase, count := range newState.GetArrayStatus().Summary {
newPhaseVersion += uint32(phase) * uint32(count)
if newSubTaskPhaseHash != currentSubTaskPhaseHash {
version++
}

newState = newState.SetPhase(phase, newPhaseVersion).SetReason("Task is still running.")
newState = newState.SetPhase(phase, version).SetReason("Task is still running")
} else {
newState = newState.SetPhase(phase, core.DefaultPhaseVersion)
newState = newState.SetPhase(phase, version)
}

return newState, externalResources, nil
Expand Down
Loading

0 comments on commit cb9fdb5

Please sign in to comment.