From 3f539418de1d8d64c44294defb750443db126d31 Mon Sep 17 00:00:00 2001 From: Dan Rammer Date: Fri, 8 Apr 2022 13:24:08 -0500 Subject: [PATCH] Supporting interruptible for map tasks (#253) * implemented IsInterruptible for SubTaskExecutionMetadata Signed-off-by: Daniel Rammer * fixed possible race condition Signed-off-by: Daniel Rammer * fixed unit tests Signed-off-by: Daniel Rammer * fixed lint issue Signed-off-by: Daniel Rammer * updated TODO documentation Signed-off-by: Daniel Rammer * changed context on NewCompactArray error log Signed-off-by: Daniel Rammer * fixing retry attempt calculation on abort Signed-off-by: Daniel Rammer --- .../pluginmachinery/core/exec_metadata.go | 1 + .../core/mocks/task_execution_metadata.go | 32 ++++++++++++++ .../pluginmachinery/flytek8s/pod_helper.go | 20 ++------- .../flytek8s/pod_helper_test.go | 38 ---------------- .../go/tasks/plugins/array/core/state.go | 3 ++ .../go/tasks/plugins/array/k8s/management.go | 43 +++++++++++++++++-- .../plugins/array/k8s/management_test.go | 1 + .../plugins/array/k8s/subtask_exec_context.go | 15 +++++-- .../array/k8s/subtask_exec_context_test.go | 3 +- flyteplugins/tests/end_to_end.go | 1 + 10 files changed, 95 insertions(+), 62 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go b/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go index 594dd6eef5..2e39dda140 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go @@ -44,4 +44,5 @@ type TaskExecutionMetadata interface { GetSecurityContext() core.SecurityContext IsInterruptible() bool GetPlatformResources() *v1.ResourceRequirements + GetInterruptibleFailureThreshold() uint32 } diff --git a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go index 28c05c2748..e851cb7aa9 100644 --- 
a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go @@ -54,6 +54,38 @@ func (_m *TaskExecutionMetadata) GetAnnotations() map[string]string { return r0 } +type TaskExecutionMetadata_GetInterruptibleFailureThreshold struct { + *mock.Call +} + +func (_m TaskExecutionMetadata_GetInterruptibleFailureThreshold) Return(_a0 uint32) *TaskExecutionMetadata_GetInterruptibleFailureThreshold { + return &TaskExecutionMetadata_GetInterruptibleFailureThreshold{Call: _m.Call.Return(_a0)} +} + +func (_m *TaskExecutionMetadata) OnGetInterruptibleFailureThreshold() *TaskExecutionMetadata_GetInterruptibleFailureThreshold { + c := _m.On("GetInterruptibleFailureThreshold") + return &TaskExecutionMetadata_GetInterruptibleFailureThreshold{Call: c} +} + +func (_m *TaskExecutionMetadata) OnGetInterruptibleFailureThresholdMatch(matchers ...interface{}) *TaskExecutionMetadata_GetInterruptibleFailureThreshold { + c := _m.On("GetInterruptibleFailureThreshold", matchers...) 
+ return &TaskExecutionMetadata_GetInterruptibleFailureThreshold{Call: c} +} + +// GetInterruptibleFailureThreshold provides a mock function with given fields: +func (_m *TaskExecutionMetadata) GetInterruptibleFailureThreshold() uint32 { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + return r0 +} + type TaskExecutionMetadata_GetK8sServiceAccount struct { *mock.Call } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 6fdca19ff6..0eb13f6a8f 100755 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -63,18 +63,11 @@ func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) { // UpdatePod updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options func UpdatePod(taskExecutionMetadata pluginsCore.TaskExecutionMetadata, resourceRequirements []v1.ResourceRequirements, podSpec *v1.PodSpec) { - UpdatePodWithInterruptibleFlag(taskExecutionMetadata, resourceRequirements, podSpec, false) -} - -// UpdatePodWithInterruptibleFlag updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options -func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecutionMetadata, - resourceRequirements []v1.ResourceRequirements, podSpec *v1.PodSpec, omitInterruptible bool) { - isInterruptible := !omitInterruptible && taskExecutionMetadata.IsInterruptible() if len(podSpec.RestartPolicy) == 0 { podSpec.RestartPolicy = v1.RestartPolicyNever } podSpec.Tolerations = append( - GetPodTolerations(isInterruptible, resourceRequirements...), podSpec.Tolerations...) + GetPodTolerations(taskExecutionMetadata.IsInterruptible(), resourceRequirements...), podSpec.Tolerations...) 
if len(podSpec.ServiceAccountName) == 0 { podSpec.ServiceAccountName = taskExecutionMetadata.GetK8sServiceAccount() @@ -83,7 +76,7 @@ func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecut podSpec.SchedulerName = config.GetK8sPluginConfig().SchedulerName } podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().DefaultNodeSelector) - if isInterruptible { + if taskExecutionMetadata.IsInterruptible() { podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().InterruptibleNodeSelector) } if podSpec.Affinity == nil && config.GetK8sPluginConfig().DefaultAffinity != nil { @@ -98,16 +91,11 @@ func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecut if podSpec.DNSConfig == nil && config.GetK8sPluginConfig().DefaultPodDNSConfig != nil { podSpec.DNSConfig = config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy() } - ApplyInterruptibleNodeAffinity(isInterruptible, podSpec) + ApplyInterruptibleNodeAffinity(taskExecutionMetadata.IsInterruptible(), podSpec) } // ToK8sPodSpec constructs a pod spec from the given TaskTemplate func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) { - return ToK8sPodSpecWithInterruptible(ctx, tCtx, false) -} - -// ToK8sPodSpecWithInterruptible constructs a pod spec from the gien TaskTemplate and optionally add (interruptible instance) support. 
-func ToK8sPodSpecWithInterruptible(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, omitInterruptible bool) (*v1.PodSpec, error) { task, err := tCtx.TaskReader().Read(ctx) if err != nil { logger.Warnf(ctx, "failed to read task information when trying to construct Pod, err: %s", err.Error()) @@ -138,7 +126,7 @@ func ToK8sPodSpecWithInterruptible(ctx context.Context, tCtx pluginsCore.TaskExe pod := &v1.PodSpec{ Containers: containers, } - UpdatePodWithInterruptibleFlag(tCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{c.Resources}, pod, omitInterruptible) + UpdatePod(tCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{c.Resources}, pod) if err := AddCoPilotToPod(ctx, config.GetK8sPluginConfig().CoPilot, pod, task.GetInterface(), tCtx.TaskExecutionMetadata(), tCtx.InputReader(), tCtx.OutputWriter(), task.GetContainer().GetDataConfig()); err != nil { return nil, err diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index a6d47e949d..a3b5380c16 100755 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -109,7 +109,6 @@ func TestPodSetup(t *testing.T) { t.Run("ApplyInterruptibleNodeAffinity", TestApplyInterruptibleNodeAffinity) t.Run("UpdatePod", updatePod) t.Run("ToK8sPodInterruptible", toK8sPodInterruptible) - t.Run("toK8sPodInterruptibleFalse", toK8sPodInterruptibleFalse) } func TestApplyInterruptibleNodeAffinity(t *testing.T) { @@ -349,43 +348,6 @@ func toK8sPodInterruptible(t *testing.T) { ) } -func toK8sPodInterruptibleFalse(t *testing.T) { - ctx := context.TODO() - - x := dummyExecContext(&v1.ResourceRequirements{ - Limits: v1.ResourceList{ - v1.ResourceCPU: resource.MustParse("1024m"), - v1.ResourceStorage: resource.MustParse("100M"), - ResourceNvidiaGPU: resource.MustParse("1"), - }, - Requests: v1.ResourceList{ - v1.ResourceCPU: 
resource.MustParse("1024m"), - v1.ResourceStorage: resource.MustParse("100M"), - }, - }) - - p, err := ToK8sPodSpecWithInterruptible(ctx, x, true) - assert.NoError(t, err) - assert.Len(t, p.Tolerations, 1) - assert.Equal(t, 0, len(p.NodeSelector)) - assert.Equal(t, "", p.NodeSelector["x/interruptible"]) - assert.NotEqualValues( - t, - []v1.NodeSelectorTerm{ - v1.NodeSelectorTerm{ - MatchExpressions: []v1.NodeSelectorRequirement{ - v1.NodeSelectorRequirement{ - Key: "x/interruptible", - Operator: v1.NodeSelectorOpIn, - Values: []string{"true"}, - }, - }, - }, - }, - p.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, - ) -} - func TestToK8sPod(t *testing.T) { ctx := context.TODO() diff --git a/flyteplugins/go/tasks/plugins/array/core/state.go b/flyteplugins/go/tasks/plugins/array/core/state.go index 66f4b7eba8..9a281c311c 100644 --- a/flyteplugins/go/tasks/plugins/array/core/state.go +++ b/flyteplugins/go/tasks/plugins/array/core/state.go @@ -53,6 +53,9 @@ type State struct { // Tracks the number of subtask retries using the execution index RetryAttempts bitarray.CompactArray `json:"retryAttempts"` + + // Tracks the number of system failures for each subtask using the execution index + SystemFailures bitarray.CompactArray `json:"systemFailures"` } func (s State) GetReason() string { diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management.go b/flyteplugins/go/tasks/plugins/array/k8s/management.go index 3e3eb042f8..1d9b98e506 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management.go @@ -93,7 +93,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon retryAttemptsArray, err := bitarray.NewCompactArray(count, maxValue) if err != nil { - logger.Errorf(context.Background(), "Failed to create attempts compact array with [count: %v, maxValue: %v]", count, maxValue) + logger.Errorf(ctx, "Failed to create attempts compact array with 
[count: %v, maxValue: %v]", count, maxValue) return currentState, externalResources, nil } @@ -106,6 +106,26 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon currentState.RetryAttempts = retryAttemptsArray } + // If the current State is newly minted then we must initialize SystemFailures to track how many + // times the subtask failed due to system issues, this is necessary to correctly evaluate + // interruptible subtasks. + if len(currentState.SystemFailures.GetItems()) == 0 { + count := uint(currentState.GetExecutionArraySize()) + maxValue := bitarray.Item(tCtx.TaskExecutionMetadata().GetInterruptibleFailureThreshold()) + + systemFailuresArray, err := bitarray.NewCompactArray(count, maxValue) + if err != nil { + logger.Errorf(ctx, "Failed to create system failures array with [count: %v, maxValue: %v]", count, maxValue) + return currentState, externalResources, err + } + + for i := 0; i < currentState.GetExecutionArraySize(); i++ { + systemFailuresArray.SetItem(i, 0) + } + + currentState.SystemFailures = systemFailuresArray + } + // initialize log plugin logPlugin, err := logs.InitializeLogPlugins(&config.LogConfig.Config) if err != nil { @@ -146,7 +166,8 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon } originalIdx := arrayCore.CalculateOriginalIndex(childIdx, newState.GetIndexesToCache()) - stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt) + systemFailures := currentState.SystemFailures.GetItem(childIdx) + stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, systemFailures) if err != nil { return currentState, externalResources, err } @@ -188,6 +209,16 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon return currentState, externalResources, perr } + if phaseInfo.Err() != nil { + messageCollector.Collect(childIdx, phaseInfo.Err().String()) + } + + if phaseInfo.Err() 
!= nil && phaseInfo.Err().GetKind() == idlCore.ExecutionError_SYSTEM { + newState.SystemFailures.SetItem(childIdx, systemFailures+1) + } else { + newState.SystemFailures.SetItem(childIdx, systemFailures) + } + // process subtask phase actualPhase := phaseInfo.Phase() if actualPhase.IsSuccess() { @@ -294,7 +325,11 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube messageCollector := errorcollector.NewErrorMessageCollector() for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() { existingPhase := core.Phases[existingPhaseIdx] - retryAttempt := currentState.RetryAttempts.GetItem(childIdx) + retryAttempt := uint64(0) + if childIdx < len(currentState.RetryAttempts.GetItems()) { + // we can use RetryAttempts if it has been initialized, otherwise stay with default 0 + retryAttempt = currentState.RetryAttempts.GetItem(childIdx) + } // return immediately if subtask has completed or not yet started if existingPhase.IsTerminal() || existingPhase == core.PhaseUndefined { @@ -302,7 +337,7 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube } originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache()) - stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt) + stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0) if err != nil { return err } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go index bc5be9e7c0..2aaf077fa6 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/management_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/management_test.go @@ -91,6 +91,7 @@ func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.Ta tMeta.OnGetAnnotations().Return(nil) tMeta.OnGetOwnerReference().Return(metav1.OwnerReference{}) 
tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) + tMeta.OnGetInterruptibleFailureThreshold().Return(2) ow := &mocks2.OutputWriter{} ow.OnGetOutputPrefixPath().Return("/prefix/") diff --git a/flyteplugins/go/tasks/plugins/array/k8s/subtask_exec_context.go b/flyteplugins/go/tasks/plugins/array/k8s/subtask_exec_context.go index fe49b1f59d..d5d4393101 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/subtask_exec_context.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/subtask_exec_context.go @@ -43,9 +43,9 @@ func (s SubTaskExecutionContext) TaskReader() pluginsCore.TaskReader { // NewSubtaskExecutionContext constructs a SubTaskExecutionContext using the provided parameters func NewSubTaskExecutionContext(tCtx pluginsCore.TaskExecutionContext, taskTemplate *core.TaskTemplate, - executionIndex, originalIndex int, retryAttempt uint64) (SubTaskExecutionContext, error) { + executionIndex, originalIndex int, retryAttempt uint64, systemFailures uint64) (SubTaskExecutionContext, error) { - subTaskExecutionMetadata, err := NewSubTaskExecutionMetadata(tCtx.TaskExecutionMetadata(), taskTemplate, executionIndex, retryAttempt) + subTaskExecutionMetadata, err := NewSubTaskExecutionMetadata(tCtx.TaskExecutionMetadata(), taskTemplate, executionIndex, retryAttempt, systemFailures) if err != nil { return SubTaskExecutionContext{}, err } @@ -135,6 +135,7 @@ type SubTaskExecutionMetadata struct { pluginsCore.TaskExecutionMetadata annotations map[string]string labels map[string]string + interruptible bool subtaskExecutionID SubTaskExecutionID } @@ -153,8 +154,14 @@ func (s SubTaskExecutionMetadata) GetTaskExecutionID() pluginsCore.TaskExecution return s.subtaskExecutionID } +// IsInterruptible overrides the embedded TaskExecutionMetadata to return the subtask-specific interruptible flag +func (s SubTaskExecutionMetadata) IsInterruptible() bool { + return s.interruptible +} + // NewSubtaskExecutionMetadata constructs a SubTaskExecutionMetadata using the provided parameters -func 
NewSubTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskTemplate *core.TaskTemplate, executionIndex int, retryAttempt uint64) (SubTaskExecutionMetadata, error) { +func NewSubTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskTemplate *core.TaskTemplate, + executionIndex int, retryAttempt uint64, systemFailures uint64) (SubTaskExecutionMetadata, error) { var err error secretsMap := make(map[string]string) @@ -171,10 +178,12 @@ func NewSubTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecution } subTaskExecutionID := NewSubTaskExecutionID(taskExecutionMetadata.GetTaskExecutionID(), executionIndex, retryAttempt) + interruptible := taskExecutionMetadata.IsInterruptible() && uint32(systemFailures) < taskExecutionMetadata.GetInterruptibleFailureThreshold() return SubTaskExecutionMetadata{ taskExecutionMetadata, utils.UnionMaps(taskExecutionMetadata.GetAnnotations(), secretsMap), utils.UnionMaps(taskExecutionMetadata.GetLabels(), injectSecretsLabel), + interruptible, subTaskExecutionID, }, nil } diff --git a/flyteplugins/go/tasks/plugins/array/k8s/subtask_exec_context_test.go b/flyteplugins/go/tasks/plugins/array/k8s/subtask_exec_context_test.go index efe547bad0..079ab82915 100644 --- a/flyteplugins/go/tasks/plugins/array/k8s/subtask_exec_context_test.go +++ b/flyteplugins/go/tasks/plugins/array/k8s/subtask_exec_context_test.go @@ -20,8 +20,9 @@ func TestSubTaskExecutionContext(t *testing.T) { executionIndex := 0 originalIndex := 5 retryAttempt := uint64(1) + systemFailures := uint64(0) - stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, executionIndex, originalIndex, retryAttempt) + stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, executionIndex, originalIndex, retryAttempt, systemFailures) assert.Nil(t, err) assert.Equal(t, fmt.Sprintf("notfound-%d-%d", executionIndex, retryAttempt), stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()) diff --git 
a/flyteplugins/tests/end_to_end.go b/flyteplugins/tests/end_to_end.go index df8d9e6537..9eb1638c08 100644 --- a/flyteplugins/tests/end_to_end.go +++ b/flyteplugins/tests/end_to_end.go @@ -176,6 +176,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i Name: execID, }) tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) + tMeta.OnGetInterruptibleFailureThreshold().Return(2) catClient := &catalogMocks.Client{} catData := sync.Map{}