From 0637c345b54ec62277fbd6e47d5b04cbc7ee5929 Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Mon, 24 Jan 2022 14:35:43 -0600 Subject: [PATCH] Track RetryAttempt and Phase of ExternalResources (#231) * replaced Metadata proto in TaskInfo with ExternalResource array Signed-off-by: Daniel Rammer * added ExternalResource documentation comments Signed-off-by: Daniel Rammer * setting retry attempt on external resources Signed-off-by: Daniel Rammer * fixed unit tests and lint issues Signed-off-by: Daniel Rammer * tracking RetryAttempt for k8s array plugin using a CompactArray Signed-off-by: Daniel Rammer * added a few comments Signed-off-by: Daniel Rammer * setting RetryAttempts for awsbatch subtasks Signed-off-by: Daniel Rammer * fixed unit tests and lint issues Signed-off-by: Daniel Rammer * populating external resource index with original index Signed-off-by: Daniel Rammer * updated comments Signed-off-by: Daniel Rammer * updated flyteidl version Signed-off-by: Daniel Rammer --- go/tasks/pluginmachinery/core/phase.go | 18 +++- .../pluginmachinery/webapi/example/plugin.go | 10 +-- go/tasks/plugins/array/awsbatch/launcher.go | 11 ++- .../plugins/array/awsbatch/launcher_test.go | 5 ++ go/tasks/plugins/array/awsbatch/monitor.go | 1 + .../plugins/array/awsbatch/monitor_test.go | 8 ++ go/tasks/plugins/array/core/state.go | 34 +++++--- go/tasks/plugins/array/core/state_test.go | 83 +++++++++++++++---- go/tasks/plugins/array/k8s/monitor.go | 24 ++++++ go/tasks/plugins/array/k8s/monitor_test.go | 13 +++ go/tasks/plugins/hive/execution_state.go | 16 ++-- go/tasks/plugins/hive/execution_state_test.go | 12 +-- .../k8s/sagemaker/builtin_training_test.go | 12 +-- .../k8s/sagemaker/custom_training_test.go | 12 +-- .../sagemaker/hyperparameter_tuning_test.go | 12 +-- go/tasks/plugins/k8s/sagemaker/utils.go | 9 +- go/tasks/plugins/presto/execution_state.go | 10 +-- .../plugins/presto/execution_state_test.go | 4 +- go/tasks/plugins/webapi/athena/plugin.go | 10 +-- 
go/tasks/plugins/webapi/athena/plugin_test.go | 11 +-- tests/end_to_end.go | 1 + 21 files changed, 193 insertions(+), 123 deletions(-) diff --git a/go/tasks/pluginmachinery/core/phase.go b/go/tasks/pluginmachinery/core/phase.go index fe4aed060f..4e0c15dc53 100644 --- a/go/tasks/pluginmachinery/core/phase.go +++ b/go/tasks/pluginmachinery/core/phase.go @@ -4,8 +4,6 @@ import ( "fmt" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" structpb "github.com/golang/protobuf/ptypes/struct" ) @@ -69,6 +67,18 @@ func (p Phase) IsWaitingForResources() bool { return p == PhaseWaitingForResources } +type ExternalResource struct { + // A unique identifier for the external resource + ExternalID string + // A unique index for the external resource. Although the ID may change, this will remain the same + // throughout task event reports and retries. + Index uint32 + // The nubmer of times this external resource has been attempted + RetryAttempt uint32 + // Phase (if exists) associated with the external resource + Phase Phase +} + type TaskInfo struct { // log information for the task execution Logs []*core.TaskLog @@ -77,8 +87,8 @@ type TaskInfo struct { OccurredAt *time.Time // Custom Event information that the plugin would like to expose to the front-end CustomInfo *structpb.Struct - // Metadata around how a task was executed - Metadata *event.TaskExecutionMetadata + // A collection of information about external resources launched by this task + ExternalResources []*ExternalResource } func (t *TaskInfo) String() string { diff --git a/go/tasks/pluginmachinery/webapi/example/plugin.go b/go/tasks/pluginmachinery/webapi/example/plugin.go index 401cb90643..2c300e92f3 100644 --- a/go/tasks/pluginmachinery/webapi/example/plugin.go +++ b/go/tasks/pluginmachinery/webapi/example/plugin.go @@ -4,8 +4,6 @@ import ( "context" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - idlCore 
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/errors" @@ -96,11 +94,9 @@ func (p Plugin) Status(ctx context.Context, tCtx webapi.StatusContext) (phase co }, }, OccurredAt: &tNow, - Metadata: &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "abc", - }, + ExternalResources: []*core.ExternalResource{ + { + ExternalID: "abc", }, }, }), nil diff --git a/go/tasks/plugins/array/awsbatch/launcher.go b/go/tasks/plugins/array/awsbatch/launcher.go index 32f2c5d1d6..8959af2671 100644 --- a/go/tasks/plugins/array/awsbatch/launcher.go +++ b/go/tasks/plugins/array/awsbatch/launcher.go @@ -6,6 +6,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/errors" + "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flytestdlib/logger" arrayCore "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core" @@ -53,6 +54,13 @@ func LaunchSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, batchCl } metrics.SubTasksSubmitted.Add(ctx, float64(size)) + + retryAttemptsArray, err := bitarray.NewCompactArray(uint(size), bitarray.Item(pluginConfig.MaxRetries)) + if err != nil { + logger.Errorf(context.Background(), "Failed to create attempts compact array with [count: %v, maxValue: %v]", size, pluginConfig.MaxRetries) + return nil, err + } + parentState := currentState. SetPhase(arrayCore.PhaseCheckingSubTaskExecutions, 0). SetArrayStatus(arraystatus.ArrayStatus{ @@ -61,7 +69,8 @@ func LaunchSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, batchCl }, Detailed: arrayCore.NewPhasesCompactArray(uint(size)), }). - SetReason("Successfully launched subtasks.") + SetReason("Successfully launched subtasks."). 
+ SetRetryAttempts(retryAttemptsArray) nextState = currentState.SetExternalJobID(j) nextState.State = parentState diff --git a/go/tasks/plugins/array/awsbatch/launcher_test.go b/go/tasks/plugins/array/awsbatch/launcher_test.go index 5520a9b3df..d135500b16 100644 --- a/go/tasks/plugins/array/awsbatch/launcher_test.go +++ b/go/tasks/plugins/array/awsbatch/launcher_test.go @@ -3,6 +3,7 @@ package awsbatch import ( "testing" + "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flytestdlib/promutils" "github.com/stretchr/testify/mock" @@ -110,6 +111,9 @@ func TestLaunchSubTasks(t *testing.T) { JobDefinitionArn: "arn", } + retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(0)) + assert.NoError(t, err) + expectedState := &State{ State: &core2.State{ CurrentPhase: core2.PhaseCheckingSubTaskExecutions, @@ -123,6 +127,7 @@ func TestLaunchSubTasks(t *testing.T) { }, Detailed: arrayCore.NewPhasesCompactArray(5), }, + RetryAttempts: retryAttemptsArray, }, ExternalJobID: refStr("qpxyarq"), diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index 87ef4ade73..a7d033aa15 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ b/go/tasks/plugins/array/awsbatch/monitor.go @@ -112,6 +112,7 @@ func CheckSubTasksState(ctx context.Context, taskMeta core.TaskExecutionMetadata newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(actualPhase)) newArrayStatus.Summary.Inc(actualPhase) + parentState.RetryAttempts.SetItem(childIdx, bitarray.Item(len(subJob.Attempts))) } if queued > 0 { diff --git a/go/tasks/plugins/array/awsbatch/monitor_test.go b/go/tasks/plugins/array/awsbatch/monitor_test.go index b54af44659..621512b4a0 100644 --- a/go/tasks/plugins/array/awsbatch/monitor_test.go +++ b/go/tasks/plugins/array/awsbatch/monitor_test.go @@ -136,6 +136,9 @@ func TestCheckSubTasksState(t *testing.T) { inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, 
promutils.NewTestScope()) assert.NoError(t, err) + retryAttemptsArray, err := bitarray.NewCompactArray(1, bitarray.Item(1)) + assert.NoError(t, err) + newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, @@ -146,6 +149,7 @@ func TestCheckSubTasksState(t *testing.T) { Detailed: arrayCore.NewPhasesCompactArray(1), }, IndexesToCache: bitarray.NewBitSet(1), + RetryAttempts: retryAttemptsArray, }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", @@ -180,6 +184,9 @@ func TestCheckSubTasksState(t *testing.T) { inMemDatastore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) assert.NoError(t, err) + retryAttemptsArray, err := bitarray.NewCompactArray(2, bitarray.Item(1)) + assert.NoError(t, err) + newState, err := CheckSubTasksState(ctx, tMeta, "", "", jobStore, inMemDatastore, &config.Config{}, &State{ State: &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, @@ -190,6 +197,7 @@ func TestCheckSubTasksState(t *testing.T) { Detailed: arrayCore.NewPhasesCompactArray(2), }, IndexesToCache: bitarray.NewBitSet(2), + RetryAttempts: retryAttemptsArray, }, ExternalJobID: refStr("job-id"), JobDefinitionArn: "", diff --git a/go/tasks/plugins/array/core/state.go b/go/tasks/plugins/array/core/state.go index 3703d6480b..a8908fec04 100644 --- a/go/tasks/plugins/array/core/state.go +++ b/go/tasks/plugins/array/core/state.go @@ -5,8 +5,6 @@ import ( "fmt" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flytestdlib/errors" "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" @@ -52,6 +50,9 @@ type State struct { // Which sub-tasks to cache, (using the original index, that is, the length is ArrayJob.size) IndexesToCache *bitarray.BitSet `json:"indexesToCache"` + + // Tracks the number of subtask retries using the execution 
index + RetryAttempts bitarray.CompactArray `json:"retryAttempts"` } func (s State) GetReason() string { @@ -111,6 +112,11 @@ func (s *State) SetReason(reason string) *State { return s } +func (s *State) SetRetryAttempts(retryAttempts bitarray.CompactArray) *State { + s.RetryAttempts = retryAttempts + return s +} + func (s *State) SetExecutionArraySize(size int) *State { s.ExecutionArraySize = size return s @@ -171,20 +177,24 @@ func GetPhaseVersionOffset(currentPhase Phase, length int64) uint32 { // handling as we don't have to keep an ever growing list of log links (our batch jobs can be 5000 sub-tasks, keeping // all the log links takes up a lot of space). func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idlCore.TaskLog, subTaskIDs []*string) (core.PhaseInfo, error) { - phaseInfo := core.PhaseInfoUndefined t := time.Now() + nowTaskInfo := &core.TaskInfo{ - OccurredAt: &t, - Logs: logLinks, - } - if nowTaskInfo.Metadata == nil { - nowTaskInfo.Metadata = &event.TaskExecutionMetadata{} + OccurredAt: &t, + Logs: logLinks, + ExternalResources: make([]*core.ExternalResource, len(subTaskIDs)), } - for _, subTaskID := range subTaskIDs { - nowTaskInfo.Metadata.ExternalResources = append(nowTaskInfo.Metadata.ExternalResources, &event.ExternalResourceInfo{ - ExternalId: *subTaskID, - }) + + for childIndex, subTaskID := range subTaskIDs { + originalIndex := CalculateOriginalIndex(childIndex, state.GetIndexesToCache()) + + nowTaskInfo.ExternalResources[childIndex] = &core.ExternalResource{ + ExternalID: *subTaskID, + Index: uint32(originalIndex), + RetryAttempt: uint32(state.RetryAttempts.GetItem(childIndex)), + Phase: core.Phases[state.ArrayStatus.Detailed.GetItem(childIndex)], + } } switch p, version := state.GetPhase(); p { diff --git a/go/tasks/plugins/array/core/state_test.go b/go/tasks/plugins/array/core/state_test.go index 8766126239..1f477a874f 100644 --- a/go/tasks/plugins/array/core/state_test.go +++ 
b/go/tasks/plugins/array/core/state_test.go @@ -5,14 +5,13 @@ import ( "fmt" "testing" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/golang/protobuf/proto" "github.com/flyteorg/flytestdlib/bitarray" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/plugins/array/arraystatus" "github.com/stretchr/testify/assert" ) @@ -51,30 +50,43 @@ func assertBitSetsEqual(t testing.TB, b1, b2 *bitarray.BitSet, len int) { } } -func assertTaskExecutionMetadata(t *testing.T, subTaskIDs []*string, metadata *event.TaskExecutionMetadata) { - assert.NotNil(t, metadata) - var externalResources = make([]*event.ExternalResourceInfo, len(subTaskIDs)) +func assertTaskExternalResources(t *testing.T, subTaskIDs []*string, retryAttemptsArray *bitarray.CompactArray, detailedArray *bitarray.CompactArray, externalResources []*core.ExternalResource) { + assert.NotNil(t, externalResources) for i, subTaskID := range subTaskIDs { - externalResources[i] = &event.ExternalResourceInfo{ - ExternalId: *subTaskID, - } + externalResource := externalResources[i] + assert.Equal(t, *subTaskID, externalResource.ExternalID) + assert.Equal(t, retryAttemptsArray.GetItem(i), bitarray.Item(externalResource.RetryAttempt)) + assert.Equal(t, core.Phases[detailedArray.GetItem(i)], externalResource.Phase) } - assert.True(t, proto.Equal(&event.TaskExecutionMetadata{ - ExternalResources: externalResources, - }, metadata)) } func TestMapArrayStateToPluginPhase(t *testing.T) { ctx := context.Background() - var subTaskIDs = make([]*string, 3) - for i := 0; i < 3; i++ { + + subTaskCount := 3 + + var subTaskIDs = make([]*string, subTaskCount) + detailedArray := NewPhasesCompactArray(uint(subTaskCount)) + indexesToCache := InvertBitSet(bitarray.NewBitSet(uint(subTaskCount)), uint(subTaskCount)) + retryAttemptsArray, err := bitarray.NewCompactArray(uint(subTaskCount), bitarray.Item(1)) + 
assert.NoError(t, err) + + for i := 0; i < subTaskCount; i++ { subTaskID := fmt.Sprintf("sub_task_%d", i) subTaskIDs[i] = &subTaskID + + detailedArray.SetItem(i, bitarray.Item(core.PhaseRunning)) + retryAttemptsArray.SetItem(i, bitarray.Item(1)) } t.Run("start", func(t *testing.T) { s := State{ CurrentPhase: PhaseStart, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) @@ -85,6 +97,11 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { s := State{ CurrentPhase: PhaseLaunch, PhaseVersion: 0, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) @@ -98,13 +115,18 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { PhaseVersion: 8, OriginalArraySize: 10, ExecutionArraySize: 5, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) assert.Equal(t, uint32(368), phaseInfo.Version()) - assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) + assertTaskExternalResources(t, subTaskIDs, &retryAttemptsArray, &detailedArray, phaseInfo.Info().ExternalResources) }) t.Run("write to discovery", func(t *testing.T) { @@ -113,55 +135,80 @@ func TestMapArrayStateToPluginPhase(t *testing.T) { PhaseVersion: 8, OriginalArraySize: 10, ExecutionArraySize: 5, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, 
err) assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) assert.Equal(t, uint32(548), phaseInfo.Version()) - assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) + assertTaskExternalResources(t, subTaskIDs, &retryAttemptsArray, &detailedArray, phaseInfo.Info().ExternalResources) }) t.Run("success", func(t *testing.T) { s := State{ CurrentPhase: PhaseSuccess, PhaseVersion: 0, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseSuccess, phaseInfo.Phase()) - assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) + assertTaskExternalResources(t, subTaskIDs, &retryAttemptsArray, &detailedArray, phaseInfo.Info().ExternalResources) }) t.Run("retryable failure", func(t *testing.T) { s := State{ CurrentPhase: PhaseRetryableFailure, PhaseVersion: 0, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhaseRetryableFailure, phaseInfo.Phase()) - assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) + assertTaskExternalResources(t, subTaskIDs, &retryAttemptsArray, &detailedArray, phaseInfo.Info().ExternalResources) }) t.Run("permanent failure", func(t *testing.T) { s := State{ CurrentPhase: PhasePermanentFailure, PhaseVersion: 0, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) assert.NoError(t, err) assert.Equal(t, core.PhasePermanentFailure, phaseInfo.Phase()) - assertTaskExecutionMetadata(t, subTaskIDs, phaseInfo.Info().Metadata) + 
assertTaskExternalResources(t, subTaskIDs, &retryAttemptsArray, &detailedArray, phaseInfo.Info().ExternalResources) }) t.Run("All phases", func(t *testing.T) { for _, p := range PhaseValues() { s := State{ CurrentPhase: p, + ArrayStatus: arraystatus.ArrayStatus{ + Detailed: detailedArray, + }, + IndexesToCache: indexesToCache, + RetryAttempts: retryAttemptsArray, } phaseInfo, err := MapArrayStateToPluginPhase(ctx, &s, nil, subTaskIDs) diff --git a/go/tasks/plugins/array/k8s/monitor.go b/go/tasks/plugins/array/k8s/monitor.go index cea1351815..f7d1bfcfa5 100644 --- a/go/tasks/plugins/array/k8s/monitor.go +++ b/go/tasks/plugins/array/k8s/monitor.go @@ -61,6 +61,30 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon currentState.ArrayStatus = *newArrayStatus } + // If the current State is newly minted then we must initialize RetryAttempts to track how many + // times each subtask is executed. + if len(currentState.RetryAttempts.GetItems()) == 0 { + count := uint(currentState.GetExecutionArraySize()) + maxValue := bitarray.Item(tCtx.TaskExecutionMetadata().GetMaxAttempts()) + + retryAttemptsArray, err := bitarray.NewCompactArray(count, maxValue) + if err != nil { + logger.Errorf(context.Background(), "Failed to create attempts compact array with [count: %v, maxValue: %v]", count, maxValue) + return currentState, logLinks, subTaskIDs, nil + } + + // Currently if any subtask fails then all subtasks are retried up to MaxAttempts. Therefore, all + // subtasks have an identical RetryAttempt, namely that of the map task execution metadata. Once + // retries over individual subtasks are implemented we should revisit this logic and instead + // increment the RetryAttempt for each subtask every time a new pod is created. 
+ retryAttempt := bitarray.Item(tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt) + for i := 0; i < currentState.GetExecutionArraySize(); i++ { + retryAttemptsArray.SetItem(i, retryAttempt) + } + + currentState.RetryAttempts = retryAttemptsArray + } + logPlugin, err := logs.InitializeLogPlugins(&config.LogConfig.Config) if err != nil { logger.Errorf(ctx, "Error initializing LogPlugins: [%s]", err) diff --git a/go/tasks/plugins/array/k8s/monitor_test.go b/go/tasks/plugins/array/k8s/monitor_test.go index 8f7c3414c1..fd7dc7d0e1 100644 --- a/go/tasks/plugins/array/k8s/monitor_test.go +++ b/go/tasks/plugins/array/k8s/monitor_test.go @@ -75,6 +75,7 @@ func getMockTaskExecutionContext(ctx context.Context) *mocks.TaskExecutionContex tMeta.OnIsInterruptible().Return(false) tMeta.OnGetK8sServiceAccount().Return("s") + tMeta.OnGetMaxAttempts().Return(2) tMeta.OnGetNamespace().Return("n") tMeta.OnGetLabels().Return(nil) tMeta.OnGetAnnotations().Return(nil) @@ -194,6 +195,9 @@ func TestCheckSubTasksState(t *testing.T) { }, } + retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(0)) + assert.NoError(t, err) + newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, @@ -203,6 +207,7 @@ func TestCheckSubTasksState(t *testing.T) { Detailed: arrayCore.NewPhasesCompactArray(uint(5)), }, IndexesToCache: bitarray.NewBitSet(5), + RetryAttempts: retryAttemptsArray, }) assert.Nil(t, err) @@ -236,6 +241,9 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { }, } + retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(0)) + assert.NoError(t, err) + cacheIndexes := bitarray.NewBitSet(5) newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: 
arrayCore.PhaseCheckingSubTaskExecutions, @@ -246,6 +254,7 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { ArrayStatus: arraystatus.ArrayStatus{ Detailed: arrayCore.NewPhasesCompactArray(uint(5)), }, + RetryAttempts: retryAttemptsArray, }) assert.Nil(t, err) @@ -273,6 +282,9 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { } cacheIndexes := bitarray.NewBitSet(5) + retryAttemptsArray, err := bitarray.NewCompactArray(5, bitarray.Item(0)) + assert.NoError(t, err) + newState, _, subTaskIDs, err := LaunchAndCheckSubTasksState(ctx, tCtx, &kubeClient, &config, nil, "/prefix/", "/prefix-sand/", &arrayCore.State{ CurrentPhase: arrayCore.PhaseCheckingSubTaskExecutions, ExecutionArraySize: 5, @@ -280,6 +292,7 @@ func TestCheckSubTasksStateResourceGranted(t *testing.T) { OriginalMinSuccesses: 5, ArrayStatus: *arrayStatus, IndexesToCache: cacheIndexes, + RetryAttempts: retryAttemptsArray, }) assert.Nil(t, err) diff --git a/go/tasks/plugins/hive/execution_state.go b/go/tasks/plugins/hive/execution_state.go index cbc45cc06d..ed6f17cfbb 100644 --- a/go/tasks/plugins/hive/execution_state.go +++ b/go/tasks/plugins/hive/execution_state.go @@ -6,8 +6,6 @@ import ( "strconv" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" @@ -149,20 +147,18 @@ func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { logs := make([]*idlCore.TaskLog, 0, 1) t := time.Now() - metadata := &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: e.CommandID, - }, + externalResources := []*core.ExternalResource{ + { + ExternalID: e.CommandID, }, } if e.CommandID != "" { logs = append(logs, ConstructTaskLog(e)) return &core.TaskInfo{ - Logs: logs, - OccurredAt: &t, - Metadata: metadata, + Logs: logs, + OccurredAt: &t, + ExternalResources: externalResources, } } diff --git 
a/go/tasks/plugins/hive/execution_state_test.go b/go/tasks/plugins/hive/execution_state_test.go index cd5cd868e6..749e23b46b 100644 --- a/go/tasks/plugins/hive/execution_state_test.go +++ b/go/tasks/plugins/hive/execution_state_test.go @@ -7,9 +7,6 @@ import ( "testing" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/golang/protobuf/proto" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io" ioMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks" @@ -128,13 +125,8 @@ func TestConstructTaskInfo(t *testing.T) { taskInfo := ConstructTaskInfo(e) assert.Equal(t, "https://wellness.qubole.com/v2/analyze?command_id=123", taskInfo.Logs[0].Uri) - assert.True(t, proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "123", - }, - }, - })) + assert.Len(t, taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "123") } func TestMapExecutionStateToPhaseInfo(t *testing.T) { diff --git a/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go b/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go index df02c97a34..1634ff7b76 100644 --- a/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go +++ b/go/tasks/plugins/k8s/sagemaker/builtin_training_test.go @@ -5,9 +5,6 @@ import ( "fmt" "testing" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/golang/protobuf/proto" - "github.com/go-test/deep" flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -273,12 +270,7 @@ func Test_awsSagemakerPlugin_getEventInfoForTrainingJob(t *testing.T) { if diff := deep.Equal(expectedCustomInfo, taskInfo.CustomInfo); diff != nil { assert.FailNow(t, "Should be equal.", "Diff: %v", diff) } - assert.True(t, proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "some-acceptable-name", - }, - }, - })) + assert.Len(t, 
taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "some-acceptable-name") }) } diff --git a/go/tasks/plugins/k8s/sagemaker/custom_training_test.go b/go/tasks/plugins/k8s/sagemaker/custom_training_test.go index 91709b5071..748fa7d35b 100644 --- a/go/tasks/plugins/k8s/sagemaker/custom_training_test.go +++ b/go/tasks/plugins/k8s/sagemaker/custom_training_test.go @@ -6,9 +6,6 @@ import ( "strconv" "testing" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/golang/protobuf/proto" - "github.com/go-test/deep" flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -297,12 +294,7 @@ func Test_awsSagemakerPlugin_getEventInfoForCustomTrainingJob(t *testing.T) { if diff := deep.Equal(expectedCustomInfo, taskInfo.CustomInfo); diff != nil { assert.FailNow(t, "Should be equal.", "Diff: %v", diff) } - assert.True(t, proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "some-acceptable-name", - }, - }, - })) + assert.Len(t, taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "some-acceptable-name") }) } diff --git a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go index 796949d869..e710dc564c 100644 --- a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go +++ b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning_test.go @@ -5,9 +5,6 @@ import ( "fmt" "testing" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/golang/protobuf/proto" - "github.com/go-test/deep" flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" @@ -129,12 +126,7 @@ func Test_awsSagemakerPlugin_getEventInfoForHyperparameterTuningJob(t *testing.T if diff := deep.Equal(expectedCustomInfo, taskInfo.CustomInfo); diff != nil { assert.FailNow(t, "Should be equal.", "Diff: %v", diff) } - assert.True(t, 
proto.Equal(taskInfo.Metadata, &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "some-acceptable-name", - }, - }, - })) + assert.Len(t, taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "some-acceptable-name") }) } diff --git a/go/tasks/plugins/k8s/sagemaker/utils.go b/go/tasks/plugins/k8s/sagemaker/utils.go index 7dc26ac8b5..de1354ad91 100644 --- a/go/tasks/plugins/k8s/sagemaker/utils.go +++ b/go/tasks/plugins/k8s/sagemaker/utils.go @@ -6,7 +6,6 @@ import ( "sort" "strings" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" @@ -400,11 +399,9 @@ func createTaskInfo(_ context.Context, jobRegion string, jobName string, jobType return &pluginsCore.TaskInfo{ Logs: taskLogs, CustomInfo: customInfo, - Metadata: &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: jobName, - }, + ExternalResources: []*pluginsCore.ExternalResource{ + { + ExternalID: jobName, }, }, }, nil diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 3370b0b951..7d241cb09a 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -3,8 +3,6 @@ package presto import ( "context" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" @@ -506,11 +504,9 @@ func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { return &core.TaskInfo{ Logs: logs, OccurredAt: &t, - Metadata: &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: e.CommandID, - }, + ExternalResources: []*core.ExternalResource{ + { + ExternalID: e.CommandID, }, }, } diff 
--git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go index d6caad0b8b..9a2c4c15bc 100644 --- a/go/tasks/plugins/presto/execution_state_test.go +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -106,8 +106,8 @@ func TestConstructTaskInfo(t *testing.T) { taskInfo := ConstructTaskInfo(e) assert.Equal(t, "https://prestoproxy-internal.flyteorg.net:443", taskInfo.Logs[0].Uri) - assert.Len(t, taskInfo.Metadata.ExternalResources, 1) - assert.Equal(t, taskInfo.Metadata.ExternalResources[0].ExternalId, "123") + assert.Len(t, taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "123") } func TestMapExecutionStateToPhaseInfo(t *testing.T) { diff --git a/go/tasks/plugins/webapi/athena/plugin.go b/go/tasks/plugins/webapi/athena/plugin.go index f2cbba4367..c39d9104a8 100644 --- a/go/tasks/plugins/webapi/athena/plugin.go +++ b/go/tasks/plugins/webapi/athena/plugin.go @@ -5,8 +5,6 @@ import ( "fmt" "time" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - errors2 "github.com/flyteorg/flyteplugins/go/tasks/errors" awsSdk "github.com/aws/aws-sdk-go-v2/aws" @@ -186,11 +184,9 @@ func createTaskInfo(queryID string, cfg awsSdk.Config) *core.TaskInfo { Name: "Athena Query Console", }, }, - Metadata: &event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: queryID, - }, + ExternalResources: []*core.ExternalResource{ + { + ExternalID: queryID, }, }, } diff --git a/go/tasks/plugins/webapi/athena/plugin_test.go b/go/tasks/plugins/webapi/athena/plugin_test.go index d85b425734..5f821bb679 100644 --- a/go/tasks/plugins/webapi/athena/plugin_test.go +++ b/go/tasks/plugins/webapi/athena/plugin_test.go @@ -5,8 +5,6 @@ import ( awsSdk "github.com/aws/aws-sdk-go-v2/aws" idlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/event" - "github.com/golang/protobuf/proto" 
"github.com/stretchr/testify/assert" ) @@ -20,11 +18,6 @@ func TestCreateTaskInfo(t *testing.T) { Name: "Athena Query Console", }, }, taskInfo.Logs) - assert.True(t, proto.Equal(&event.TaskExecutionMetadata{ - ExternalResources: []*event.ExternalResourceInfo{ - { - ExternalId: "query_id", - }, - }, - }, taskInfo.Metadata)) + assert.Len(t, taskInfo.ExternalResources, 1) + assert.Equal(t, taskInfo.ExternalResources[0].ExternalID, "query_id") } diff --git a/tests/end_to_end.go b/tests/end_to_end.go index b4ff1a19bd..df8d9e6537 100644 --- a/tests/end_to_end.go +++ b/tests/end_to_end.go @@ -161,6 +161,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i tMeta.OnGetOverrides().Return(overrides) tMeta.OnGetK8sServiceAccount().Return("s") tMeta.OnGetNamespace().Return("fake-development") + tMeta.OnGetMaxAttempts().Return(2) tMeta.OnGetSecurityContext().Return(idlCore.SecurityContext{ RunAs: &idlCore.Identity{ K8SServiceAccount: "s",