From 991d724e1c44f753b05021cc980acc4ad81d2766 Mon Sep 17 00:00:00 2001 From: akhurana001 <34587798+akhurana001@users.noreply.github.com> Date: Fri, 6 Nov 2020 14:31:57 -0800 Subject: [PATCH] Update Spark Plugin (#138) --- .../go/tasks/plugins/k8s/spark/spark.go | 32 +++++++++++++++++++ .../go/tasks/plugins/k8s/spark/spark_test.go | 18 +++++++++-- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 985026a964..51f69a0c5a 100755 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -169,6 +169,24 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo } sparkConfig["spark.kubernetes.executor.podNamePrefix"] = taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + // Add driver/executor defaults to CRD Driver/Executor Spec as well. + cores, err := strconv.Atoi(sparkConfig["spark.driver.cores"]) + if err == nil { + driverSpec.Cores = intPtr(int32(cores)) + } + driverSpec.Memory = strPtr(sparkConfig["spark.driver.memory"]) + + execCores, err := strconv.Atoi(sparkConfig["spark.executor.cores"]) + if err == nil { + executorSpec.Cores = intPtr(int32(execCores)) + } + + execCount, err := strconv.Atoi(sparkConfig["spark.executor.instances"]) + if err == nil { + executorSpec.Instances = intPtr(int32(execCount)) + } + executorSpec.Memory = strPtr(sparkConfig["spark.executor.memory"]) + j := &sparkOp.SparkApplication{ TypeMeta: metav1.TypeMeta{ Kind: KindSparkApplication, @@ -392,3 +410,17 @@ func init() { DefaultForTaskTypes: []pluginsCore.TaskType{sparkTaskType}, }) } + +func strPtr(str string) *string { + if str == "" { + return nil + } + return &str +} + +func intPtr(val int32) *int32 { + if val == 0 { + return nil + } + return &val +} diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index a970d80bb4..f538b8ed6e 100755 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -3,6 +3,7 @@ package spark import ( "context" "fmt" + "strconv" "testing" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" @@ -37,9 +38,9 @@ const sparkUIAddress = "spark-ui.flyte" var ( dummySparkConf = map[string]string{ - "spark.driver.memory": "500M", + "spark.driver.memory": "200M", "spark.driver.cores": "1", - "spark.executor.cores": "1", + "spark.executor.cores": "2", "spark.executor.instances": "3", "spark.executor.memory": "500M", "spark.flyte.feature1.enabled": "true", @@ -316,6 +317,19 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, sj.PythonApplicationType, sparkApp.Spec.Type) assert.Equal(t, testArgs, sparkApp.Spec.Arguments) assert.Equal(t, testImage, *sparkApp.Spec.Image) + + //Validate Driver/Executor Spec. + + driverCores, _ := strconv.Atoi(dummySparkConf["spark.driver.cores"]) + execCores, _ := strconv.Atoi(dummySparkConf["spark.executor.cores"]) + execInstances, _ := strconv.Atoi(dummySparkConf["spark.executor.instances"]) + + assert.Equal(t, int32(driverCores), *sparkApp.Spec.Driver.Cores) + assert.Equal(t, int32(execCores), *sparkApp.Spec.Executor.Cores) + assert.Equal(t, int32(execInstances), *sparkApp.Spec.Executor.Instances) + assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory) + assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory) + // Validate Interruptible Toleration and NodeSelector set for Executor but not Driver. assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations)) assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector))