Skip to content

Commit

Permalink
Disable Interruptible for K8s array tasks (flyteorg#214)
Browse files Browse the repository at this point in the history
* Disable Interruptible for K8s array tasks
  • Loading branch information
anandswaminathan authored Sep 29, 2021
1 parent 565b532 commit 2cbcc27
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 5 deletions.
18 changes: 14 additions & 4 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,28 +61,38 @@ func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) {
// Updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
func UpdatePod(taskExecutionMetadata pluginsCore.TaskExecutionMetadata,
resourceRequirements []v1.ResourceRequirements, podSpec *v1.PodSpec) {
UpdatePodWithInterruptibleFlag(taskExecutionMetadata, resourceRequirements, podSpec, false)
}

// Updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
func UpdatePodWithInterruptibleFlag(taskExecutionMetadata pluginsCore.TaskExecutionMetadata,
resourceRequirements []v1.ResourceRequirements, podSpec *v1.PodSpec, omitInterruptible bool) {
isInterruptible := !omitInterruptible && taskExecutionMetadata.IsInterruptible()
if len(podSpec.RestartPolicy) == 0 {
podSpec.RestartPolicy = v1.RestartPolicyNever
}
podSpec.Tolerations = append(
GetPodTolerations(taskExecutionMetadata.IsInterruptible(), resourceRequirements...), podSpec.Tolerations...)
GetPodTolerations(isInterruptible, resourceRequirements...), podSpec.Tolerations...)
if len(podSpec.ServiceAccountName) == 0 {
podSpec.ServiceAccountName = taskExecutionMetadata.GetK8sServiceAccount()
}
if len(podSpec.SchedulerName) == 0 {
podSpec.SchedulerName = config.GetK8sPluginConfig().SchedulerName
}
podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().DefaultNodeSelector)
if taskExecutionMetadata.IsInterruptible() {
if isInterruptible {
podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().InterruptibleNodeSelector)
}
if podSpec.Affinity == nil && config.GetK8sPluginConfig().DefaultAffinity != nil {
podSpec.Affinity = config.GetK8sPluginConfig().DefaultAffinity.DeepCopy()
}
ApplyInterruptibleNodeAffinity(taskExecutionMetadata.IsInterruptible(), podSpec)
ApplyInterruptibleNodeAffinity(isInterruptible, podSpec)
}

func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) {
return ToK8sPodSpecWithInterruptible(ctx, tCtx, false)
}
func ToK8sPodSpecWithInterruptible(ctx context.Context, tCtx pluginsCore.TaskExecutionContext, omitInterruptible bool) (*v1.PodSpec, error) {
task, err := tCtx.TaskReader().Read(ctx)
if err != nil {
logger.Warnf(ctx, "failed to read task information when trying to construct Pod, err: %s", err.Error())
Expand Down Expand Up @@ -113,7 +123,7 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*
pod := &v1.PodSpec{
Containers: containers,
}
UpdatePod(tCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{c.Resources}, pod)
UpdatePodWithInterruptibleFlag(tCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{c.Resources}, pod, omitInterruptible)

if err := AddCoPilotToPod(ctx, config.GetK8sPluginConfig().CoPilot, pod, task.GetInterface(), tCtx.TaskExecutionMetadata(), tCtx.InputReader(), tCtx.OutputWriter(), task.GetContainer().GetDataConfig()); err != nil {
return nil, err
Expand Down
38 changes: 38 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ func TestPodSetup(t *testing.T) {
t.Run("ApplyInterruptibleNodeAffinity", TestApplyInterruptibleNodeAffinity)
t.Run("UpdatePod", updatePod)
t.Run("ToK8sPodInterruptible", toK8sPodInterruptible)
t.Run("toK8sPodInterruptibleFalse", toK8sPodInterruptibleFalse)
}

func TestApplyInterruptibleNodeAffinity(t *testing.T) {
Expand Down Expand Up @@ -344,6 +345,43 @@ func toK8sPodInterruptible(t *testing.T) {
)
}

func toK8sPodInterruptibleFalse(t *testing.T) {
ctx := context.TODO()

x := dummyExecContext(&v1.ResourceRequirements{
Limits: v1.ResourceList{
v1.ResourceCPU: resource.MustParse("1024m"),
v1.ResourceStorage: resource.MustParse("100M"),
ResourceNvidiaGPU: resource.MustParse("1"),
},
Requests: v1.ResourceList{
v1.ResourceCPU: resource.MustParse("1024m"),
v1.ResourceStorage: resource.MustParse("100M"),
},
})

p, err := ToK8sPodSpecWithInterruptible(ctx, x, true)
assert.NoError(t, err)
assert.Len(t, p.Tolerations, 1)
assert.Equal(t, 0, len(p.NodeSelector))
assert.Equal(t, "", p.NodeSelector["x/interruptible"])
assert.NotEqualValues(
t,
[]v1.NodeSelectorTerm{
v1.NodeSelectorTerm{
MatchExpressions: []v1.NodeSelectorRequirement{
v1.NodeSelectorRequirement{
Key: "x/interruptible",
Operator: v1.NodeSelectorOpIn,
Values: []string{"true"},
},
},
},
},
p.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms,
)
}

func TestToK8sPod(t *testing.T) {
ctx := context.TODO()

Expand Down
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/plugins/array/k8s/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionC
},
}
if taskTemplate.GetContainer() != nil {
podSpec, err := flytek8s.ToK8sPodSpec(ctx, arrTCtx)
podSpec, err := flytek8s.ToK8sPodSpecWithInterruptible(ctx, arrTCtx, true)
if err != nil {
return v1.Pod{}, nil, err
}
Expand Down

0 comments on commit 2cbcc27

Please sign in to comment.