This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Retry map task subtasks #236

Merged
merged 9 commits on Feb 3, 2022
35 changes: 17 additions & 18 deletions go/tasks/plugins/array/core/state.go
@@ -256,24 +256,23 @@ func MapArrayStateToPluginPhase(_ context.Context, state *State, logLinks []*idl
func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus.ArraySummary) Phase {
totalCount := int64(0)
totalSuccesses := int64(0)
totalFailures := int64(0)
totalPermanentFailures := int64(0)
totalRetryableFailures := int64(0)
totalRunning := int64(0)
totalWaitingForResources := int64(0)
for phase, count := range summary {
totalCount += count
if phase.IsTerminal() {
if phase.IsSuccess() {
totalSuccesses += count
} else {
// TODO: Split out retryable failures to be retried without doing the entire array task.
// TODO: Other option: array tasks are only retryable as a full set and to get single task retriability
// TODO: dynamic_task must be updated to not auto-combine to array tasks. For scale reasons, it is
// TODO: preferable to auto-combine to array tasks for now.
totalFailures += count
}
} else if phase.IsWaitingForResources() {

switch phase {
case core.PhaseSuccess:
totalSuccesses += count
case core.PhasePermanentFailure:
totalPermanentFailures += count
case core.PhaseRetryableFailure:
totalRetryableFailures += count
case core.PhaseWaitingForResources:
totalWaitingForResources += count
} else {
default:
totalRunning += count
}
}
@@ -284,9 +283,9 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus
}

// No chance to reach the required success numbers.
if totalRunning+totalSuccesses+totalWaitingForResources < minSuccesses {
logger.Infof(ctx, "Array failed early because total failures > minSuccesses[%v]. Snapshot totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResource[%v]",
minSuccesses, totalRunning, totalSuccesses, totalWaitingForResources)
if totalRunning+totalSuccesses+totalWaitingForResources+totalRetryableFailures < minSuccesses {
logger.Infof(ctx, "Array failed early because total failures > minSuccesses[%v]. Snapshot totalRunning[%v] + totalSuccesses[%v] + totalWaitingForResource[%v] + totalRetryableFailures[%v]",
minSuccesses, totalRunning, totalSuccesses, totalWaitingForResources, totalRetryableFailures)
return PhaseWriteToDiscoveryThenFail
}

@@ -299,8 +298,8 @@ func SummaryToPhase(ctx context.Context, minSuccesses int64, summary arraystatus
return PhaseWriteToDiscovery
}

logger.Debugf(ctx, "Array is still running [Successes: %v, Failures: %v, Total: %v, MinSuccesses: %v]",
totalSuccesses, totalFailures, totalCount, minSuccesses)
logger.Debugf(ctx, "Array is still running [Successes: %v, PermanentFailures: %v, RetryableFailures: %v, Total: %v, MinSuccesses: %v]",
totalSuccesses, totalPermanentFailures, totalRetryableFailures, totalCount, minSuccesses)
return PhaseCheckingSubTaskExecutions
}

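To make the new phase accounting concrete, here is a minimal, self-contained sketch of the aggregation idea. The Phase type and the returned phase names are illustrative stand-ins, not the plugin's actual core types:

```go
package main

import "fmt"

// Phase is a simplified stand-in for the plugin's core.Phase type.
type Phase int

const (
	PhaseRunning Phase = iota
	PhaseWaitingForResources
	PhaseSuccess
	PhaseRetryableFailure
	PhasePermanentFailure
)

// summaryToPhase mirrors the shape of the new switch in SummaryToPhase:
// each subtask phase is tallied into its own bucket.
func summaryToPhase(minSuccesses int64, summary map[Phase]int64) string {
	var successes, permanentFailures, retryableFailures, waiting, running int64
	for phase, count := range summary {
		switch phase {
		case PhaseSuccess:
			successes += count
		case PhasePermanentFailure:
			permanentFailures += count
		case PhaseRetryableFailure:
			retryableFailures += count
		case PhaseWaitingForResources:
			waiting += count
		default:
			running += count
		}
	}

	// Mirrors the Debugf log in the real code.
	fmt.Printf("successes=%d permanentFailures=%d retryableFailures=%d waiting=%d running=%d\n",
		successes, permanentFailures, retryableFailures, waiting, running)

	if waiting > 0 {
		return "waiting-for-resources"
	}
	// Retryable failures still count toward the pool that can reach minSuccesses,
	// because those subtasks will be retried in place; only permanent failures
	// can make the target unreachable.
	if running+successes+waiting+retryableFailures < minSuccesses {
		return "write-to-discovery-then-fail"
	}
	if successes >= minSuccesses && running == 0 {
		return "write-to-discovery"
	}
	return "checking-subtask-executions"
}

func main() {
	summary := map[Phase]int64{PhaseSuccess: 3, PhaseRetryableFailure: 2, PhasePermanentFailure: 1}
	// 3 successes plus 2 pending retries can still reach 5, so the array keeps running.
	fmt.Println(summaryToPhase(5, summary))
}
```

The key behavioral change is that retryable failures are added to the pool that can still reach minSuccesses, since those subtasks will be relaunched instead of failing the whole array.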
2 changes: 1 addition & 1 deletion go/tasks/plugins/array/k8s/executor.go
@@ -121,7 +121,7 @@ func (e Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (c
nextState, err = array.WriteToDiscovery(ctx, tCtx, pluginState, arrayCore.PhaseAssembleFinalOutput)

case arrayCore.PhaseAssembleFinalError:
nextState, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhaseRetryableFailure, pluginState)
nextState, err = array.AssembleFinalOutputs(ctx, e.errorAssembler, tCtx, arrayCore.PhasePermanentFailure, pluginState)

default:
nextState = pluginState
4 changes: 2 additions & 2 deletions go/tasks/plugins/array/k8s/launcher.go
@@ -33,8 +33,8 @@ var arrayJobEnvVars = []corev1.EnvVar{
},
}

func formatSubTaskName(_ context.Context, parentName, suffix string) (subTaskName string) {
return utils.ConvertToDNS1123SubdomainCompatibleString(fmt.Sprintf("%v-%v", parentName, suffix))
func formatSubTaskName(_ context.Context, parentName, indexStr, retryAttemptStr string) (subTaskName string) {
return utils.ConvertToDNS1123SubdomainCompatibleString(fmt.Sprintf("%v-%v-%v", parentName, indexStr, retryAttemptStr))
}

func ApplyPodPolicies(_ context.Context, cfg *Config, pod *corev1.Pod) *corev1.Pod {
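The new signature threads the subtask's retry attempt into the pod name, so each retry launches under a distinct, DNS-1123-safe name. A small sketch of the naming scheme, using a simplified stand-in for utils.ConvertToDNS1123SubdomainCompatibleString:

```go
package main

import (
	"fmt"
	"strings"
)

// dns1123 is a simplified stand-in for utils.ConvertToDNS1123SubdomainCompatibleString:
// lower-case the name and trim it to the 253-character subdomain limit.
func dns1123(name string) string {
	name = strings.ToLower(name)
	if len(name) > 253 {
		name = name[:253]
	}
	return name
}

// formatSubTaskName mirrors the new naming scheme: parent name, subtask index,
// and the subtask's own retry attempt, so every retry launches a pod with a fresh name.
func formatSubTaskName(parentName, indexStr, retryAttemptStr string) string {
	return dns1123(fmt.Sprintf("%v-%v-%v", parentName, indexStr, retryAttemptStr))
}

func main() {
	fmt.Println(formatSubTaskName("notfound", "1", "0")) // first attempt: notfound-1-0
	fmt.Println(formatSubTaskName("notfound", "1", "1")) // after one retry: notfound-1-1
}
```

The trailing attempt suffix is what the updated tests expect in names like notfound-1-0.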
38 changes: 27 additions & 11 deletions go/tasks/plugins/array/k8s/monitor.go
@@ -73,13 +73,8 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
return currentState, logLinks, subTaskIDs, nil
}

// Currently if any subtask fails then all subtasks are retried up to MaxAttempts. Therefore, all
// subtasks have an identical RetryAttempt, namely that of the map task execution metadata. Once
// retries over individual subtasks are implemented we should revisit this logic and instead
// increment the RetryAttempt for each subtask everytime a new pod is created.
retryAttempt := bitarray.Item(tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt)
for i := 0; i < currentState.GetExecutionArraySize(); i++ {
retryAttemptsArray.SetItem(i, retryAttempt)
retryAttemptsArray.SetItem(i, 0)
}

currentState.RetryAttempts = retryAttemptsArray
@@ -93,20 +88,40 @@

for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() {
existingPhase := core.Phases[existingPhaseIdx]
indexStr := strconv.Itoa(childIdx)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr)
originalIdx := arrayCore.CalculateOriginalIndex(childIdx, newState.GetIndexesToCache())

indexStr := strconv.Itoa(childIdx)
retryAttempt := currentState.RetryAttempts.GetItem(childIdx)
retryAttemptStr := strconv.FormatUint(retryAttempt, 10)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr, retryAttemptStr)

if existingPhase.IsTerminal() {
// If we get here it means we have already "processed" this terminal phase since we will only persist
// the phase after all processing is done (e.g. check outputs/errors file, record events... etc.).

// Since we know we have already "processed" this terminal phase we can safely deallocate resource
err = deallocateResource(ctx, tCtx, config, childIdx)
err = deallocateResource(ctx, tCtx, config, podName)
if err != nil {
logger.Errorf(ctx, "Error releasing allocation token [%s] in LaunchAndCheckSubTasks [%s]", podName, err)
return currentState, logLinks, subTaskIDs, errors2.Wrapf(ErrCheckPodStatus, err, "Error releasing allocation token.")
}

// If a subtask is marked as a retryable failure we check if the number of retries
// exceeds the maximum attempts. If so, transition the task to a permanent failure
// so that it is not attempted again. If it can be retried, increment the retry attempts
// value and transition the task to "Undefined" so that it is reevaluated.
if existingPhase == core.PhaseRetryableFailure {
if uint32(retryAttempt+1) < tCtx.TaskExecutionMetadata().GetMaxAttempts() {
newState.RetryAttempts.SetItem(childIdx, retryAttempt+1)

newArrayStatus.Summary.Inc(core.PhaseUndefined)
newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(core.PhaseUndefined))
continue
} else {
existingPhase = core.PhasePermanentFailure
}
}

newArrayStatus.Summary.Inc(existingPhase)
newArrayStatus.Detailed.SetItem(childIdx, bitarray.Item(existingPhase))

@@ -117,6 +132,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
},
originalIdx,
tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt,
retryAttempt,
logPlugin)

if err != nil {
@@ -209,7 +225,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
return newState, logLinks, subTaskIDs, nil
}

func FetchPodStatusAndLogs(ctx context.Context, client core.KubeClient, name k8sTypes.NamespacedName, index int, retryAttempt uint32, logPlugin tasklog.Plugin) (
func FetchPodStatusAndLogs(ctx context.Context, client core.KubeClient, name k8sTypes.NamespacedName, index int, retryAttempt uint32, subtaskRetryAttempt uint64, logPlugin tasklog.Plugin) (
info core.PhaseInfo, err error) {

pod := &v1.Pod{
@@ -249,7 +265,7 @@ func FetchPodStatusAndLogs(ctx context.Context, client core.KubeClient, name k8s
o, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: pod.Name,
Namespace: pod.Namespace,
LogName: fmt.Sprintf(" #%d-%d", retryAttempt, index),
LogName: fmt.Sprintf(" #%d-%d-%d", retryAttempt, index, subtaskRetryAttempt),
PodUnixStartTime: pod.CreationTimestamp.Unix(),
})

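A condensed sketch of the retry decision added above: a retryable subtask is relaunched with an incremented retry attempt until the maximum attempts (taken from the task execution metadata in the real code) are exhausted, after which it becomes a permanent failure. The phase constants below are stand-ins, not the plugin's core types:

```go
package main

import "fmt"

// Illustrative stand-ins for the plugin's core phases.
type Phase int

const (
	PhaseUndefined Phase = iota
	PhaseSuccess
	PhaseRetryableFailure
	PhasePermanentFailure
)

// nextSubTaskPhase sketches the per-subtask retry decision: a retryable failure
// either bumps that subtask's retry attempt and resets it to Undefined so it is
// relaunched, or becomes a permanent failure once maxAttempts is exhausted.
func nextSubTaskPhase(existing Phase, retryAttempt uint64, maxAttempts uint32) (Phase, uint64) {
	if existing != PhaseRetryableFailure {
		return existing, retryAttempt
	}
	if uint32(retryAttempt+1) < maxAttempts {
		return PhaseUndefined, retryAttempt + 1
	}
	return PhasePermanentFailure, retryAttempt
}

func main() {
	phase, attempt := nextSubTaskPhase(PhaseRetryableFailure, 0, 3)
	fmt.Println(phase == PhaseUndefined, attempt) // true 1: subtask will run again

	phase, attempt = nextSubTaskPhase(PhaseRetryableFailure, 2, 3)
	fmt.Println(phase == PhasePermanentFailure, attempt) // true 2: attempts exhausted
}
```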
10 changes: 5 additions & 5 deletions go/tasks/plugins/array/k8s/monitor_test.go
@@ -113,7 +113,7 @@ func TestGetNamespaceForExecution(t *testing.T) {
func testSubTaskIDs(t *testing.T, actual []*string) {
var expected = make([]*string, 5)
for i := 0; i < len(expected); i++ {
subTaskID := fmt.Sprintf("notfound-%d", i)
subTaskID := fmt.Sprintf("notfound-%d-0", i)
expected[i] = &subTaskID
}
assert.EqualValues(t, expected, actual)
@@ -173,11 +173,11 @@ func TestCheckSubTasksState(t *testing.T) {
assert.NotEmpty(t, logLinks)
assert.Equal(t, 10, len(logLinks))
for i := 0; i < 10; i = i + 2 {
assert.Equal(t, fmt.Sprintf("Kubernetes Logs #0-%d (PhaseRunning)", i/2), logLinks[i].Name)
assert.Equal(t, fmt.Sprintf("k8s/log/a-n-b/notfound-%d/pod?namespace=a-n-b", i/2), logLinks[i].Uri)
assert.Equal(t, fmt.Sprintf("Kubernetes Logs #0-%d-0 (PhaseRunning)", i/2), logLinks[i].Name)
assert.Equal(t, fmt.Sprintf("k8s/log/a-n-b/notfound-%d-0/pod?namespace=a-n-b", i/2), logLinks[i].Uri)

assert.Equal(t, fmt.Sprintf("Cloudwatch Logs #0-%d (PhaseRunning)", i/2), logLinks[i+1].Name)
assert.Equal(t, fmt.Sprintf("https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.notfound-%d;streamFilter=typeLogStreamPrefix", i/2), logLinks[i+1].Uri)
assert.Equal(t, fmt.Sprintf("Cloudwatch Logs #0-%d-0 (PhaseRunning)", i/2), logLinks[i+1].Name)
assert.Equal(t, fmt.Sprintf("https://console.aws.amazon.com/cloudwatch/home?region=us-east-1#logStream:group=/kubernetes/flyte;prefix=var.log.containers.notfound-%d-0;streamFilter=typeLogStreamPrefix", i/2), logLinks[i+1].Uri)
}

p, _ := newState.GetPhase()
20 changes: 12 additions & 8 deletions go/tasks/plugins/array/k8s/task.go
@@ -115,7 +115,8 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl
}

indexStr := strconv.Itoa(t.ChildIdx)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr)
retryAttemptStr := strconv.FormatUint(t.State.RetryAttempts.GetItem(t.ChildIdx), 10)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr, retryAttemptStr)
allocationStatus, err := allocateResource(ctx, tCtx, t.Config, podName)
if err != nil {
return LaunchError, err
@@ -187,7 +188,9 @@ func (t Task) Launch(ctx context.Context, tCtx core.TaskExecutionContext, kubeCl
func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient, dataStore *storage.DataStore, outputPrefix, baseOutputDataSandbox storage.DataReference,
logPlugin tasklog.Plugin) (MonitorResult, []*idlCore.TaskLog, error) {
indexStr := strconv.Itoa(t.ChildIdx)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr)
retryAttempt := t.State.RetryAttempts.GetItem(t.ChildIdx)
retryAttemptStr := strconv.FormatUint(retryAttempt, 10)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr, retryAttemptStr)
t.SubTaskIDs = append(t.SubTaskIDs, &podName)
var loglinks []*idlCore.TaskLog

@@ -200,6 +203,7 @@ func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kube
},
originalIdx,
tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID().RetryAttempt,
retryAttempt,
logPlugin)
if err != nil {
return MonitorError, loglinks, errors2.Wrapf(ErrCheckPodStatus, err, "Failed to check pod status.")
@@ -229,7 +233,8 @@ func (t *Task) Monitor(ctx context.Context, tCtx core.TaskExecutionContext, kube

func (t Task) Abort(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) error {
indexStr := strconv.Itoa(t.ChildIdx)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr)
retryAttemptStr := strconv.FormatUint(t.State.RetryAttempts.GetItem(t.ChildIdx), 10)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr, retryAttemptStr)
pod := &corev1.Pod{
TypeMeta: metav1.TypeMeta{
Kind: PodKind,
@@ -256,7 +261,8 @@

func (t Task) Finalize(ctx context.Context, tCtx core.TaskExecutionContext, kubeClient core.KubeClient) error {
indexStr := strconv.Itoa(t.ChildIdx)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr)
retryAttemptStr := strconv.FormatUint(t.State.RetryAttempts.GetItem(t.ChildIdx), 10)
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr, retryAttemptStr)

pod := &v1.Pod{
TypeMeta: metaV1.TypeMeta{
@@ -285,7 +291,7 @@ func (t Task) Finalize(ctx context.Context, tCtx core.TaskExecutionContext, kube
}

// Deallocate Resource
err = deallocateResource(ctx, tCtx, t.Config, t.ChildIdx)
err = deallocateResource(ctx, tCtx, t.Config, podName)
if err != nil {
logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", podName, err)
return err
@@ -317,12 +323,10 @@ func allocateResource(ctx context.Context, tCtx core.TaskExecutionContext, confi
return allocationStatus, nil
}

func deallocateResource(ctx context.Context, tCtx core.TaskExecutionContext, config *Config, childIdx int) error {
func deallocateResource(ctx context.Context, tCtx core.TaskExecutionContext, config *Config, podName string) error {
if !IsResourceConfigSet(config.ResourceConfig) {
return nil
}
indexStr := strconv.Itoa((childIdx))
podName := formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), indexStr)
resourceNamespace := core.ResourceNamespace(config.ResourceConfig.PrimaryLabel)

err := tCtx.ResourceManager().ReleaseResource(ctx, resourceNamespace, podName)
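Passing the fully qualified pod name into deallocateResource keys allocation tokens to individual attempts rather than to the child index. A toy sketch of that keying, with a hypothetical in-memory pool standing in for the real ResourceManager:

```go
package main

import "fmt"

// tokenPool is a toy stand-in for the plugin's ResourceManager; tokens are
// keyed by pod name, which now embeds the subtask's retry attempt.
type tokenPool struct{ held map[string]bool }

func (p *tokenPool) allocate(podName string) bool {
	if p.held[podName] {
		return false
	}
	p.held[podName] = true
	return true
}

func (p *tokenPool) release(podName string) { delete(p.held, podName) }

func main() {
	pool := &tokenPool{held: map[string]bool{}}

	// The first attempt of subtask 1 takes a token under its own pod name.
	fmt.Println(pool.allocate("notfound-1-0")) // true

	// When that attempt terminates, deallocateResource releases by pod name.
	pool.release("notfound-1-0")

	// The retry runs under a new pod name and therefore gets its own token.
	fmt.Println(pool.allocate("notfound-1-1")) // true
}
```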
17 changes: 14 additions & 3 deletions go/tasks/plugins/array/k8s/task_test.go
@@ -9,6 +9,9 @@ import (
"github.com/stretchr/testify/mock"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/plugins/array/core"

"github.com/flyteorg/flytestdlib/bitarray"

"github.com/stretchr/testify/assert"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -24,8 +27,8 @@ func TestFinalize(t *testing.T) {
resourceManager := mocks.ResourceManager{}
podTemplate, _, _ := FlyteArrayJobToK8sPodTemplate(ctx, tCtx, "")
pod := addPodFinalizer(&podTemplate)
pod.Name = formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), "1")
assert.Equal(t, "notfound-1", pod.Name)
pod.Name = formatSubTaskName(ctx, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), "1", "0")
assert.Equal(t, "notfound-1-0", pod.Name)
assert.NoError(t, kubeClient.GetClient().Create(ctx, pod))

resourceManager.OnReleaseResourceMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
@@ -39,12 +42,20 @@
},
}

retryAttemptsArray, err := bitarray.NewCompactArray(2, 1)
assert.NoError(t, err)

state := core.State{
RetryAttempts: retryAttemptsArray,
}

task := &Task{
State: &state,
Config: &config,
ChildIdx: 1,
}

err := task.Finalize(ctx, tCtx, &kubeClient)
err = task.Finalize(ctx, tCtx, &kubeClient)
assert.NoError(t, err)
}

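For reference, a short sketch of how the per-subtask retry counters are held, using flytestdlib/bitarray with the constructor arguments assumed from the test above (item count, then the maximum storable value):

```go
package main

import (
	"fmt"

	"github.com/flyteorg/flytestdlib/bitarray"
)

func main() {
	// One compact slot per subtask; the maximum storable value (here 1) bounds
	// the retry counter, matching NewCompactArray(2, 1) in the test above.
	retryAttempts, err := bitarray.NewCompactArray(2, 1)
	if err != nil {
		panic(err)
	}

	// Subtask 0 has been retried once; subtask 1 has not.
	retryAttempts.SetItem(0, bitarray.Item(1))
	retryAttempts.SetItem(1, bitarray.Item(0))

	fmt.Println(retryAttempts.GetItem(0), retryAttempts.GetItem(1)) // 1 0
}
```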