Supporting interruptible for map tasks (flyteorg#253)
* implemented IsInterruptible for SubTaskExecutionMetadata

Signed-off-by: Daniel Rammer <[email protected]>

* fixed possible race condition

Signed-off-by: Daniel Rammer <[email protected]>

* fixed unit tests

Signed-off-by: Daniel Rammer <[email protected]>

* fixed lint issue

Signed-off-by: Daniel Rammer <[email protected]>

* updated TODO documentation

Signed-off-by: Daniel Rammer <[email protected]>

* changed context on NewCompactArray error log

Signed-off-by: Daniel Rammer <[email protected]>

* fixing retry attempt calculation on abort

Signed-off-by: Daniel Rammer <[email protected]>
hamersaw authored Apr 8, 2022
1 parent cb9fdb5 commit 3f53941
Showing 10 changed files with 95 additions and 62 deletions.
@@ -44,4 +44,5 @@ type TaskExecutionMetadata interface {
GetSecurityContext() core.SecurityContext
IsInterruptible() bool
GetPlatformResources() *v1.ResourceRequirements
GetInterruptibleFailureThreshold() uint32
}
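For orientation (not part of the diff): the new GetInterruptibleFailureThreshold value is meant to be read together with IsInterruptible, the way the map-task plugin below derives per-subtask metadata. A minimal sketch, where meta and systemFailures are placeholder names:

// A subtask only stays eligible for interruptible (e.g. spot) nodes while the
// number of system failures recorded for it is below the configured threshold.
interruptible := meta.IsInterruptible() &&
	uint32(systemFailures) < meta.GetInterruptibleFailureThreshold()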

Some generated files are not rendered by default.

20 changes: 4 additions & 16 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go
@@ -63,18 +63,11 @@ func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) {
// UpdatePod updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
func UpdatePod(taskExecutionMetadata pluginsCore.TaskExecutionMetadata,
resourceRequirements []v1.ResourceRequirements, podSpec *v1.PodSpec) {
UpdatePodWithInterruptibleFlag(taskExecutionMetadata, resourceRequirements, podSpec, false)
}

// UpdatePodWithInterruptibleFlag updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecutionMetadata,
resourceRequirements []v1.ResourceRequirements, podSpec *v1.PodSpec, omitInterruptible bool) {
isInterruptible := !omitInterruptible && taskExecutionMetadata.IsInterruptible()
if len(podSpec.RestartPolicy) == 0 {
podSpec.RestartPolicy = v1.RestartPolicyNever
}
podSpec.Tolerations = append(
GetPodTolerations(isInterruptible, resourceRequirements...), podSpec.Tolerations...)
GetPodTolerations(taskExecutionMetadata.IsInterruptible(), resourceRequirements...), podSpec.Tolerations...)

if len(podSpec.ServiceAccountName) == 0 {
podSpec.ServiceAccountName = taskExecutionMetadata.GetK8sServiceAccount()
@@ -83,7 +76,7 @@ func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecut
podSpec.SchedulerName = config.GetK8sPluginConfig().SchedulerName
}
podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().DefaultNodeSelector)
if isInterruptible {
if taskExecutionMetadata.IsInterruptible() {
podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().InterruptibleNodeSelector)
}
if podSpec.Affinity == nil && config.GetK8sPluginConfig().DefaultAffinity != nil {
@@ -98,16 +91,11 @@ func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecut
if podSpec.DNSConfig == nil && config.GetK8sPluginConfig().DefaultPodDNSConfig != nil {
podSpec.DNSConfig = config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy()
}
ApplyInterruptibleNodeAffinity(isInterruptible, podSpec)
ApplyInterruptibleNodeAffinity(taskExecutionMetadata.IsInterruptible(), podSpec)
}

// ToK8sPodSpec constructs a pod spec from the given TaskTemplate
func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) {
return ToK8sPodSpecWithInterruptible(ctx, tCtx, false)
}

// ToK8sPodSpecWithInterruptible constructs a pod spec from the gien TaskTemplate and optionally add (interruptible instance) support.
func ToK8sPodSpecWithInterruptible(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, omitInterruptible bool) (*v1.PodSpec, error) {
task, err := tCtx.TaskReader().Read(ctx)
if err != nil {
logger.Warnf(ctx, "failed to read task information when trying to construct Pod, err: %s", err.Error())
@@ -138,7 +126,7 @@ func ToK8sPodSpecWithInterruptible(ctx context.Context, tCtx pluginsCore.TaskExe
pod := &v1.PodSpec{
Containers: containers,
}
UpdatePodWithInterruptibleFlag(tCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{c.Resources}, pod, omitInterruptible)
UpdatePod(tCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{c.Resources}, pod)

if err := AddCoPilotToPod(ctx, config.GetK8sPluginConfig().CoPilot, pod, task.GetInterface(), tCtx.TaskExecutionMetadata(), tCtx.InputReader(), tCtx.OutputWriter(), task.GetContainer().GetDataConfig()); err != nil {
return nil, err
38 changes: 0 additions & 38 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
@@ -109,7 +109,6 @@ func TestPodSetup(t *testing.T) {
t.Run("ApplyInterruptibleNodeAffinity", TestApplyInterruptibleNodeAffinity)
t.Run("UpdatePod", updatePod)
t.Run("ToK8sPodInterruptible", toK8sPodInterruptible)
t.Run("toK8sPodInterruptibleFalse", toK8sPodInterruptibleFalse)
}

func TestApplyInterruptibleNodeAffinity(t *testing.T) {
@@ -349,43 +348,6 @@ func toK8sPodInterruptible(t *testing.T) {
)
}

func toK8sPodInterruptibleFalse(t *testing.T) {
ctx := context.TODO()

x := dummyExecContext(&v1.ResourceRequirements{
Limits: v1.ResourceList{
v1.ResourceCPU: resource.MustParse("1024m"),
v1.ResourceStorage: resource.MustParse("100M"),
ResourceNvidiaGPU: resource.MustParse("1"),
},
Requests: v1.ResourceList{
v1.ResourceCPU: resource.MustParse("1024m"),
v1.ResourceStorage: resource.MustParse("100M"),
},
})

p, err := ToK8sPodSpecWithInterruptible(ctx, x, true)
assert.NoError(t, err)
assert.Len(t, p.Tolerations, 1)
assert.Equal(t, 0, len(p.NodeSelector))
assert.Equal(t, "", p.NodeSelector["x/interruptible"])
assert.NotEqualValues(
t,
[]v1.NodeSelectorTerm{
v1.NodeSelectorTerm{
MatchExpressions: []v1.NodeSelectorRequirement{
v1.NodeSelectorRequirement{
Key: "x/interruptible",
Operator: v1.NodeSelectorOpIn,
Values: []string{"true"},
},
},
},
},
p.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms,
)
}

func TestToK8sPod(t *testing.T) {
ctx := context.TODO()

3 changes: 3 additions & 0 deletions flyteplugins/go/tasks/plugins/array/core/state.go
@@ -53,6 +53,9 @@ type State struct {

// Tracks the number of subtask retries using the execution index
RetryAttempts bitarray.CompactArray `json:"retryAttempts"`

// Tracks the number of system failures for each subtask using the execution index
SystemFailures bitarray.CompactArray `json:"systemFailures"`
}

func (s State) GetReason() string {
43 changes: 39 additions & 4 deletions flyteplugins/go/tasks/plugins/array/k8s/management.go
@@ -93,7 +93,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon

retryAttemptsArray, err := bitarray.NewCompactArray(count, maxValue)
if err != nil {
logger.Errorf(context.Background(), "Failed to create attempts compact array with [count: %v, maxValue: %v]", count, maxValue)
logger.Errorf(ctx, "Failed to create attempts compact array with [count: %v, maxValue: %v]", count, maxValue)
return currentState, externalResources, nil
}

@@ -106,6 +106,26 @@
currentState.RetryAttempts = retryAttemptsArray
}

// If the current State is newly minted then we must initialize SystemFailures to track how many
// times the subtask failed due to system issues, this is necessary to correctly evaluate
// interruptible subtasks.
if len(currentState.SystemFailures.GetItems()) == 0 {
count := uint(currentState.GetExecutionArraySize())
maxValue := bitarray.Item(tCtx.TaskExecutionMetadata().GetInterruptibleFailureThreshold())

systemFailuresArray, err := bitarray.NewCompactArray(count, maxValue)
if err != nil {
logger.Errorf(ctx, "Failed to create system failures array with [count: %v, maxValue: %v]", count, maxValue)
return currentState, externalResources, err
}

for i := 0; i < currentState.GetExecutionArraySize(); i++ {
systemFailuresArray.SetItem(i, 0)
}

currentState.SystemFailures = systemFailuresArray
}

// initialize log plugin
logPlugin, err := logs.InitializeLogPlugins(&config.LogConfig.Config)
if err != nil {
@@ -146,7 +166,8 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
}

originalIdx := arrayCore.CalculateOriginalIndex(childIdx, newState.GetIndexesToCache())
stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt)
systemFailures := currentState.SystemFailures.GetItem(childIdx)
stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, systemFailures)
if err != nil {
return currentState, externalResources, err
}
@@ -188,6 +209,16 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon
return currentState, externalResources, perr
}

if phaseInfo.Err() != nil {
messageCollector.Collect(childIdx, phaseInfo.Err().String())
}

if phaseInfo.Err() != nil && phaseInfo.Err().GetKind() == idlCore.ExecutionError_SYSTEM {
newState.SystemFailures.SetItem(childIdx, systemFailures+1)
} else {
newState.SystemFailures.SetItem(childIdx, systemFailures)
}

// process subtask phase
actualPhase := phaseInfo.Phase()
if actualPhase.IsSuccess() {
@@ -294,15 +325,19 @@ func TerminateSubTasks(ctx context.Context, tCtx core.TaskExecutionContext, kube
messageCollector := errorcollector.NewErrorMessageCollector()
for childIdx, existingPhaseIdx := range currentState.GetArrayStatus().Detailed.GetItems() {
existingPhase := core.Phases[existingPhaseIdx]
retryAttempt := currentState.RetryAttempts.GetItem(childIdx)
retryAttempt := uint64(0)
if childIdx < len(currentState.RetryAttempts.GetItems()) {
// we can use RetryAttempts if it has been initialized, otherwise stay with default 0
retryAttempt = currentState.RetryAttempts.GetItem(childIdx)
}

// return immediately if subtask has completed or not yet started
if existingPhase.IsTerminal() || existingPhase == core.PhaseUndefined {
continue
}

originalIdx := arrayCore.CalculateOriginalIndex(childIdx, currentState.GetIndexesToCache())
stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt)
stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, 0)
if err != nil {
return err
}
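The hunks above are easier to read side by side. A condensed sketch of the per-subtask bookkeeping they add to the monitoring loop, assembled from the lines in this diff (the launch/monitor code in between is elided):

// Read the recorded system-failure count and let the subtask context decide
// whether this attempt may still run on interruptible nodes.
systemFailures := currentState.SystemFailures.GetItem(childIdx)
stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, childIdx, originalIdx, retryAttempt, systemFailures)
if err != nil {
	return currentState, externalResources, err
}

// ... launch/monitor the subtask, producing phaseInfo ...

// Bump the counter only for system-level errors so a later attempt can fall
// back to non-interruptible nodes once the threshold is reached.
if phaseInfo.Err() != nil && phaseInfo.Err().GetKind() == idlCore.ExecutionError_SYSTEM {
	newState.SystemFailures.SetItem(childIdx, systemFailures+1)
} else {
	newState.SystemFailures.SetItem(childIdx, systemFailures)
}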
1 change: 1 addition & 0 deletions flyteplugins/go/tasks/plugins/array/k8s/management_test.go
@@ -91,6 +91,7 @@ func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.Ta
tMeta.OnGetAnnotations().Return(nil)
tMeta.OnGetOwnerReference().Return(metav1.OwnerReference{})
tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
tMeta.OnGetInterruptibleFailureThreshold().Return(2)

ow := &mocks2.OutputWriter{}
ow.OnGetOutputPrefixPath().Return("/prefix/")
15 changes: 12 additions & 3 deletions flyteplugins/go/tasks/plugins/array/k8s/subtask_exec_context.go
@@ -43,9 +43,9 @@ func (s SubTaskExecutionContext) TaskReader() pluginsCore.TaskReader {

// NewSubtaskExecutionContext constructs a SubTaskExecutionContext using the provided parameters
func NewSubTaskExecutionContext(tCtx pluginsCore.TaskExecutionContext, taskTemplate *core.TaskTemplate,
executionIndex, originalIndex int, retryAttempt uint64) (SubTaskExecutionContext, error) {
executionIndex, originalIndex int, retryAttempt uint64, systemFailures uint64) (SubTaskExecutionContext, error) {

subTaskExecutionMetadata, err := NewSubTaskExecutionMetadata(tCtx.TaskExecutionMetadata(), taskTemplate, executionIndex, retryAttempt)
subTaskExecutionMetadata, err := NewSubTaskExecutionMetadata(tCtx.TaskExecutionMetadata(), taskTemplate, executionIndex, retryAttempt, systemFailures)
if err != nil {
return SubTaskExecutionContext{}, err
}
@@ -135,6 +135,7 @@ type SubTaskExecutionMetadata struct {
pluginsCore.TaskExecutionMetadata
annotations map[string]string
labels map[string]string
interruptible bool
subtaskExecutionID SubTaskExecutionID
}

@@ -153,8 +154,14 @@ func (s SubTaskExecutionMetadata) GetTaskExecutionID() pluginsCore.TaskExecution
return s.subtaskExecutionID
}

// IsInterruptbile overrides the base NodeExecutionMetadata to return a subtask specific identifier
func (s SubTaskExecutionMetadata) IsInterruptible() bool {
return s.interruptible
}

// NewSubtaskExecutionMetadata constructs a SubTaskExecutionMetadata using the provided parameters
func NewSubTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskTemplate *core.TaskTemplate, executionIndex int, retryAttempt uint64) (SubTaskExecutionMetadata, error) {
func NewSubTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskTemplate *core.TaskTemplate,
executionIndex int, retryAttempt uint64, systemFailures uint64) (SubTaskExecutionMetadata, error) {

var err error
secretsMap := make(map[string]string)
@@ -171,10 +178,12 @@ func NewSubTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecution
}

subTaskExecutionID := NewSubTaskExecutionID(taskExecutionMetadata.GetTaskExecutionID(), executionIndex, retryAttempt)
interruptible := taskExecutionMetadata.IsInterruptible() && uint32(systemFailures) < taskExecutionMetadata.GetInterruptibleFailureThreshold()
return SubTaskExecutionMetadata{
taskExecutionMetadata,
utils.UnionMaps(taskExecutionMetadata.GetAnnotations(), secretsMap),
utils.UnionMaps(taskExecutionMetadata.GetLabels(), injectSecretsLabel),
interruptible,
subTaskExecutionID,
}, nil
}
@@ -20,8 +20,9 @@ func TestSubTaskExecutionContext(t *testing.T) {
executionIndex := 0
originalIndex := 5
retryAttempt := uint64(1)
systemFailures := uint64(0)

stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, executionIndex, originalIndex, retryAttempt)
stCtx, err := NewSubTaskExecutionContext(tCtx, taskTemplate, executionIndex, originalIndex, retryAttempt, systemFailures)
assert.Nil(t, err)

assert.Equal(t, fmt.Sprintf("notfound-%d-%d", executionIndex, retryAttempt), stCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName())
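Illustrative only, not part of this commit: with the failure threshold of 2 that the test mocks return, the derived subtask metadata flips from interruptible to non-interruptible once two system failures have been recorded. A rough sketch in the style of the test above, where tMeta and taskTemplate stand in for the test fixtures:

// Assumes the parent mock is set up with OnIsInterruptible().Return(true)
// and OnGetInterruptibleFailureThreshold().Return(2), as in the tests.
m0, err := NewSubTaskExecutionMetadata(tMeta, taskTemplate, 0, 0, 0)
assert.Nil(t, err)
assert.True(t, m0.IsInterruptible()) // 0 system failures < threshold of 2

m2, err := NewSubTaskExecutionMetadata(tMeta, taskTemplate, 0, 2, 2)
assert.Nil(t, err)
assert.False(t, m2.IsInterruptible()) // threshold reached: no longer interruptible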
1 change: 1 addition & 0 deletions flyteplugins/tests/end_to_end.go
@@ -176,6 +176,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i
Name: execID,
})
tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
tMeta.OnGetInterruptibleFailureThreshold().Return(2)

catClient := &catalogMocks.Client{}
catData := sync.Map{}
