diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go
index c6ce78c1e..66600709a 100755
--- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go
+++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go
@@ -236,6 +236,11 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
         j.Spec.MainClass = &sparkJob.MainClass
     }
 
+    // Spark driver pods should always run as non-interruptible. As such, we hardcode
+    // `interruptible=false` to explicitly add non-interruptible node selector
+    // requirements to the driver pods
+    flytek8s.ApplyInterruptibleNodeSelectorRequirement(false, j.Spec.Driver.Affinity)
+
     // Add Interruptible Tolerations/NodeSelector to only Executor pods.
     // The Interruptible NodeSelector takes precedence over the DefaultNodeSelector
     if taskCtx.TaskExecutionMetadata().IsInterruptible() {
diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
index 5ac764910..6b6ed18d8 100755
--- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
+++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
@@ -585,8 +585,17 @@ func TestBuildResourceSpark(t *testing.T) {
     assert.Equal(t, sparkApp.Spec.Executor.EnvVars["foo"], defaultEnvVars["foo"])
     assert.Equal(t, sparkApp.Spec.Driver.EnvVars["fooEnv"], targetValueFromEnv)
     assert.Equal(t, sparkApp.Spec.Executor.EnvVars["fooEnv"], targetValueFromEnv)
-    assert.Equal(t, sparkApp.Spec.Driver.Affinity, defaultAffinity)
+    assert.Equal(
+        t,
+        sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
+        defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
+    )
+    assert.Equal(
+        t,
+        sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1],
+        *nonInterruptibleNodeSelectorRequirement,
+    )
 
     assert.Equal(
         t,
         sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
@@ -639,7 +648,16 @@ func TestBuildResourceSpark(t *testing.T) {
     assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Value, "default")
 
     // Validate correct affinity and nodeselector requirements are set for both Driver and Executors.
-    assert.Equal(t, sparkApp.Spec.Driver.Affinity, defaultAffinity)
+    assert.Equal(
+        t,
+        sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
+        defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
+    )
+    assert.Equal(
+        t,
+        sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1],
+        *nonInterruptibleNodeSelectorRequirement,
+    )
     assert.Equal(
         t,
         sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
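
For context on the effect the patch asserts on the driver affinity, here is a minimal, self-contained Go sketch. It is an illustration only, not the plugin's actual implementation of flytek8s.ApplyInterruptibleNodeSelectorRequirement: the applyNonInterruptibleRequirement helper, the label key, and the operator below are assumptions made for the example.

package main

import (
	"fmt"

	v1 "k8s.io/api/core/v1"
)

// applyNonInterruptibleRequirement is a hypothetical stand-in for the
// flytek8s.ApplyInterruptibleNodeSelectorRequirement(false, affinity) call in
// the patch: it appends a node selector requirement to every required node
// selector term of the given affinity, creating the nested structs as needed.
func applyNonInterruptibleRequirement(affinity *v1.Affinity, req v1.NodeSelectorRequirement) {
	if affinity == nil {
		return
	}
	if affinity.NodeAffinity == nil {
		affinity.NodeAffinity = &v1.NodeAffinity{}
	}
	if affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution == nil {
		affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution = &v1.NodeSelector{
			NodeSelectorTerms: []v1.NodeSelectorTerm{{}},
		}
	}
	terms := affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms
	for i := range terms {
		terms[i].MatchExpressions = append(terms[i].MatchExpressions, req)
	}
}

func main() {
	// Stand-in for the default affinity that plugin configuration would supply.
	driverAffinity := &v1.Affinity{}

	// Assumed shape of the "non-interruptible" requirement; the key and
	// operator are illustrative, not the values configured in flytek8s.
	nonInterruptible := v1.NodeSelectorRequirement{
		Key:      "example.com/interruptible",
		Operator: v1.NodeSelectorOpDoesNotExist,
	}

	applyNonInterruptibleRequirement(driverAffinity, nonInterruptible)
	fmt.Printf("driver node affinity: %+v\n", driverAffinity.NodeAffinity)
}

Under these assumptions, a driver that starts from the default affinity ends up with the default match expression at index 0 and the non-interruptible requirement at index 1 of the first node selector term, which is what the updated test assertions check.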