Skip to content

Commit

Permalink
Rename and cleanup TestBuildResourceContainer
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Dye <[email protected]>
  • Loading branch information
andrewwdye committed Oct 10, 2023
1 parent 94cbc22 commit b9f01c0
Showing 1 changed file with 113 additions and 98 deletions.
211 changes: 113 additions & 98 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,26 @@ func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob {

func dummyPodSpec() *corev1.PodSpec {
return &corev1.PodSpec{
InitContainers: []corev1.Container{
{
Name: "init",
Image: testImage,
Args: testArgs,
},
},
Containers: []corev1.Container{
{
Name: "primary",
Image: testImage,
Args: testArgs,
Env: flytek8s.ToK8sEnvVar(dummyEnvVars),
},
{
Name: "secondary",
Image: testImage,
Args: testArgs,
Env: flytek8s.ToK8sEnvVar(dummyEnvVars),
},
},
}
}
Expand Down Expand Up @@ -379,26 +392,7 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool)
return taskCtx
}

func TestBuildResourceSpark(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}

// Case1: Valid Spark Task-Template
taskTemplate := dummySparkTaskTemplateContainer("blah-1", dummySparkConf)

// Set spark custom feature config.
assert.NoError(t, setSparkConfig(&Config{
Features: []Feature{
{
Name: "feature1",
SparkConfig: map[string]string{"spark.hadoop.feature1": "true"},
},
{
Name: "feature2",
SparkConfig: map[string]string{"spark.hadoop.feature2": "true"},
},
},
}))

func defaultPluginConfig() *config.K8sPluginConfig {
// Set Interruptible Config
runAsUser := int64(1000)
dnsOptVal1 := "1"
Expand Down Expand Up @@ -448,7 +442,7 @@ func TestBuildResourceSpark(t *testing.T) {
},
}

// interruptible/non-interruptible nodeselector requirement
// Interruptible/non-interruptible nodeselector requirement
interruptibleNodeSelectorRequirement := &corev1.NodeSelectorRequirement{
Key: "x/interruptible",
Operator: corev1.NodeSelectorOpIn,
Expand All @@ -461,9 +455,7 @@ func TestBuildResourceSpark(t *testing.T) {
Values: []string{"true"},
}

// NonInterruptibleNodeSelectorRequirement

assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
config := &config.K8sPluginConfig{
DefaultAffinity: defaultAffinity,
DefaultPodSecurityContext: &corev1.PodSecurityContext{
RunAsUser: &runAsUser,
Expand Down Expand Up @@ -513,8 +505,32 @@ func TestBuildResourceSpark(t *testing.T) {
EnableHostNetworkingPod: &defaultPodHostNetwork,
DefaultEnvVars: defaultEnvVars,
DefaultEnvVarsFromEnv: defaultEnvVarsFromEnv,
}),
)
}
return config
}

func TestBuildResourceContainer(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}

// Case1: Valid Spark Task-Template
taskTemplate := dummySparkTaskTemplateContainer("blah-1", dummySparkConf)

// Set spark custom feature config.
assert.NoError(t, setSparkConfig(&Config{
Features: []Feature{
{
Name: "feature1",
SparkConfig: map[string]string{"spark.hadoop.feature1": "true"},
},
{
Name: "feature2",
SparkConfig: map[string]string{"spark.hadoop.feature2": "true"},
},
},
}))

defaultConfig := defaultPluginConfig()
assert.NoError(t, config.SetK8sPluginConfig(defaultConfig))
resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true))
assert.Nil(t, err)

Expand All @@ -527,28 +543,16 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, testArgs, sparkApp.Spec.Arguments)
assert.Equal(t, testImage, *sparkApp.Spec.Image)
assert.NotNil(t, sparkApp.Spec.Driver.SparkPodSpec.SecurityContenxt)
assert.Equal(t, *sparkApp.Spec.Driver.SparkPodSpec.SecurityContenxt.RunAsUser, runAsUser)
assert.Equal(t, *sparkApp.Spec.Driver.SparkPodSpec.SecurityContenxt.RunAsUser, *defaultConfig.DefaultPodSecurityContext.RunAsUser)
assert.NotNil(t, sparkApp.Spec.Driver.DNSConfig)
assert.Equal(t, []string{"8.8.8.8", "8.8.4.4"}, sparkApp.Spec.Driver.DNSConfig.Nameservers)
assert.Equal(t, "ndots", sparkApp.Spec.Driver.DNSConfig.Options[0].Name)
assert.Equal(t, dnsOptVal1, *sparkApp.Spec.Driver.DNSConfig.Options[0].Value)
assert.Equal(t, "single-request-reopen", sparkApp.Spec.Driver.DNSConfig.Options[1].Name)
assert.Equal(t, "timeout", sparkApp.Spec.Driver.DNSConfig.Options[2].Name)
assert.Equal(t, dnsOptVal2, *sparkApp.Spec.Driver.DNSConfig.Options[2].Value)
assert.Equal(t, "attempts", sparkApp.Spec.Driver.DNSConfig.Options[3].Name)
assert.Equal(t, dnsOptVal3, *sparkApp.Spec.Driver.DNSConfig.Options[3].Value)
assert.ElementsMatch(t, defaultConfig.DefaultPodDNSConfig.Options, sparkApp.Spec.Driver.DNSConfig.Options)
assert.Equal(t, []string{"ns1.svc.cluster-domain.example", "my.dns.search.suffix"}, sparkApp.Spec.Driver.DNSConfig.Searches)
assert.NotNil(t, sparkApp.Spec.Executor.SparkPodSpec.SecurityContenxt)
assert.Equal(t, *sparkApp.Spec.Executor.SparkPodSpec.SecurityContenxt.RunAsUser, runAsUser)
assert.Equal(t, *sparkApp.Spec.Executor.SparkPodSpec.SecurityContenxt.RunAsUser, *defaultConfig.DefaultPodSecurityContext.RunAsUser)
assert.NotNil(t, sparkApp.Spec.Executor.DNSConfig)
assert.NotNil(t, sparkApp.Spec.Executor.DNSConfig)
assert.Equal(t, "ndots", sparkApp.Spec.Executor.DNSConfig.Options[0].Name)
assert.Equal(t, dnsOptVal1, *sparkApp.Spec.Executor.DNSConfig.Options[0].Value)
assert.Equal(t, "single-request-reopen", sparkApp.Spec.Executor.DNSConfig.Options[1].Name)
assert.Equal(t, "timeout", sparkApp.Spec.Executor.DNSConfig.Options[2].Name)
assert.Equal(t, dnsOptVal2, *sparkApp.Spec.Executor.DNSConfig.Options[2].Value)
assert.Equal(t, "attempts", sparkApp.Spec.Executor.DNSConfig.Options[3].Name)
assert.Equal(t, dnsOptVal3, *sparkApp.Spec.Executor.DNSConfig.Options[3].Value)
assert.ElementsMatch(t, defaultConfig.DefaultPodDNSConfig.Options, sparkApp.Spec.Executor.DNSConfig.Options)
assert.Equal(t, []string{"ns1.svc.cluster-domain.example", "my.dns.search.suffix"}, sparkApp.Spec.Executor.DNSConfig.Searches)

//Validate Driver/Executor Spec.
Expand All @@ -563,19 +567,19 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)
assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
assert.Equal(t, dummySparkConf["spark.batchScheduler"], *sparkApp.Spec.BatchScheduler)
assert.Equal(t, schedulerName, *sparkApp.Spec.Executor.SchedulerName)
assert.Equal(t, schedulerName, *sparkApp.Spec.Driver.SchedulerName)
assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Executor.HostNetwork)
assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Driver.HostNetwork)
assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Executor.SchedulerName)
assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName)
assert.Equal(t, *defaultConfig.EnableHostNetworkingPod, *sparkApp.Spec.Executor.HostNetwork)
assert.Equal(t, *defaultConfig.EnableHostNetworkingPod, *sparkApp.Spec.Driver.HostNetwork)

// Validate
// * Interruptible Toleration and NodeSelector set for Executor but not Driver.
// * Validate Default NodeSelector set for Driver but overwritten with Interruptible NodeSelector for Executor.
// * Default Tolerations set for both Driver and Executor.
// * Interruptible/Non-Interruptible NodeSelectorRequirements set for Executor Affinity but not Driver Affinity.
// * Default tolerations set for both Driver and Executor.
// * Interruptible tolerations and node selector set for Executor but not Driver.
// * Default node selector set for both Driver and Executor.
// * Interruptible node selector requirements set for Executor Affinity, non-interruptiblefir Driver Affinity.
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
assert.Equal(t, defaultConfig.DefaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
tolDriverDefault := sparkApp.Spec.Driver.Tolerations[0]
assert.Equal(t, tolDriverDefault.Key, "x/flyte")
assert.Equal(t, tolDriverDefault.Value, "default")
Expand Down Expand Up @@ -633,31 +637,36 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.flyteorg.feature3.enabled"], sparkApp.Spec.SparkConf["spark.flyteorg.feature3.enabled"])

assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1)
assert.Equal(t, sparkApp.Spec.Driver.EnvVars["foo"], defaultEnvVars["foo"])
assert.Equal(t, sparkApp.Spec.Executor.EnvVars["foo"], defaultEnvVars["foo"])
assert.Equal(t, sparkApp.Spec.Driver.EnvVars["fooEnv"], targetValueFromEnv)
assert.Equal(t, sparkApp.Spec.Executor.EnvVars["fooEnv"], targetValueFromEnv)

assert.Equal(
t,
sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
)
assert.Equal(
t,
sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1],
*nonInterruptibleNodeSelectorRequirement,
)
assert.Equal(
t,
sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
)
assert.Equal(
t,
sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1],
*interruptibleNodeSelectorRequirement,
)
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Driver.EnvVars["foo"])
assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], sparkApp.Spec.Executor.EnvVars["foo"])
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Driver.EnvVars["fooEnv"])
assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], sparkApp.Spec.Executor.EnvVars["fooEnv"])

assert.Equal(t, &corev1.NodeAffinity{
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
NodeSelectorTerms: []corev1.NodeSelectorTerm{
{
MatchExpressions: []corev1.NodeSelectorRequirement{
defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
*defaultConfig.NonInterruptibleNodeSelectorRequirement,
},
},
},
},
}, sparkApp.Spec.Driver.Affinity.NodeAffinity)

assert.Equal(t, &corev1.NodeAffinity{
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
NodeSelectorTerms: []corev1.NodeSelectorTerm{
{
MatchExpressions: []corev1.NodeSelectorRequirement{
defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
*defaultConfig.InterruptibleNodeSelectorRequirement,
},
},
},
},
}, sparkApp.Spec.Executor.Affinity.NodeAffinity)

// Case 2: Driver/Executor request cores set.
dummyConfWithRequest := make(map[string]string)
Expand Down Expand Up @@ -690,36 +699,41 @@ func TestBuildResourceSpark(t *testing.T) {
// Validate that the default Toleration and NodeSelector are set for both Driver and Executors.
assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
assert.Equal(t, defaultConfig.DefaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector))
assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Executor.NodeSelector)
assert.Equal(t, defaultConfig.DefaultNodeSelector, sparkApp.Spec.Executor.NodeSelector)
assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Key, "x/flyte")
assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Value, "default")
assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Key, "x/flyte")
assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Value, "default")

// Validate correct affinity and nodeselector requirements are set for both Driver and Executors.
assert.Equal(
t,
sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
)
assert.Equal(
t,
sparkApp.Spec.Driver.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1],
*nonInterruptibleNodeSelectorRequirement,
)
assert.Equal(
t,
sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
)
assert.Equal(
t,
sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1],
*nonInterruptibleNodeSelectorRequirement,
)
assert.Equal(t, &corev1.NodeAffinity{
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
NodeSelectorTerms: []corev1.NodeSelectorTerm{
{
MatchExpressions: []corev1.NodeSelectorRequirement{
defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
*defaultConfig.NonInterruptibleNodeSelectorRequirement,
},
},
},
},
}, sparkApp.Spec.Driver.Affinity.NodeAffinity)

assert.Equal(t, &corev1.NodeAffinity{
RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
NodeSelectorTerms: []corev1.NodeSelectorTerm{
{
MatchExpressions: []corev1.NodeSelectorRequirement{
defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
*defaultConfig.NonInterruptibleNodeSelectorRequirement,
},
},
},
},
}, sparkApp.Spec.Executor.Affinity.NodeAffinity)

// Case 4: Invalid Spark Task-Template
taskTemplate.Custom = nil
Expand Down Expand Up @@ -748,6 +762,7 @@ func TestBuildResourcePodTemplate(t *testing.T) {
}
podSpec := dummyPodSpec()
podSpec.Tolerations = append(podSpec.Tolerations, extraToleration)
podSpec.NodeSelector["x/custom"] = "foo"
taskTemplate := dummySparkTaskTemplatePod("blah-1", dummySparkConf, podSpec)
taskTemplate.GetK8SPod()
sparkResourceHandler := sparkResourceHandler{}
Expand Down

0 comments on commit b9f01c0

Please sign in to comment.