diff --git a/go/tasks/plugins/k8s/spark/spark.go b/go/tasks/plugins/k8s/spark/spark.go
index 27ac0eda8b..b95a689811 100755
--- a/go/tasks/plugins/k8s/spark/spark.go
+++ b/go/tasks/plugins/k8s/spark/spark.go
@@ -137,12 +137,24 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
 	}
 
 	// Set pod limits.
-	if sparkConfig["spark.kubernetes.driver.limit.cores"] == "" && sparkConfig["spark.driver.cores"] != "" {
-		sparkConfig["spark.kubernetes.driver.limit.cores"] = sparkConfig["spark.driver.cores"]
+	if len(sparkConfig["spark.kubernetes.driver.limit.cores"]) == 0 {
+		// spark.kubernetes.driver.request.cores takes precedence over spark.driver.cores
+		if len(sparkConfig["spark.kubernetes.driver.request.cores"]) != 0 {
+			sparkConfig["spark.kubernetes.driver.limit.cores"] = sparkConfig["spark.kubernetes.driver.request.cores"]
+		} else if len(sparkConfig["spark.driver.cores"]) != 0 {
+			sparkConfig["spark.kubernetes.driver.limit.cores"] = sparkConfig["spark.driver.cores"]
+		}
 	}
-	if sparkConfig["spark.kubernetes.executor.limit.cores"] == "" && sparkConfig["spark.executor.cores"] != "" {
-		sparkConfig["spark.kubernetes.executor.limit.cores"] = sparkConfig["spark.executor.cores"]
+
+	if len(sparkConfig["spark.kubernetes.executor.limit.cores"]) == 0 {
+		// spark.kubernetes.executor.request.cores takes precedence over spark.executor.cores
+		if len(sparkConfig["spark.kubernetes.executor.request.cores"]) != 0 {
+			sparkConfig["spark.kubernetes.executor.limit.cores"] = sparkConfig["spark.kubernetes.executor.request.cores"]
+		} else if len(sparkConfig["spark.executor.cores"]) != 0 {
+			sparkConfig["spark.kubernetes.executor.limit.cores"] = sparkConfig["spark.executor.cores"]
+		}
 	}
+
 	sparkConfig["spark.kubernetes.executor.podNamePrefix"] = taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
 	sparkConfig["spark.kubernetes.driverEnv.FLYTE_START_TIME"] = strconv.FormatInt(time.Now().UnixNano()/1000000, 10)
 
diff --git a/go/tasks/plugins/k8s/spark/spark_test.go b/go/tasks/plugins/k8s/spark/spark_test.go
index 3ad0dc8fee..fc3586c11b 100755
--- a/go/tasks/plugins/k8s/spark/spark_test.go
+++ b/go/tasks/plugins/k8s/spark/spark_test.go
@@ -232,19 +232,19 @@ func dummySparkApplication(state sj.ApplicationStateType) *sj.SparkApplication {
 	}
 }
 
-func dummySparkCustomObj() *plugins.SparkJob {
+func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob {
 	sparkJob := plugins.SparkJob{}
 	sparkJob.MainClass = sparkMainClass
 	sparkJob.MainApplicationFile = sparkApplicationFile
-	sparkJob.SparkConf = dummySparkConf
+	sparkJob.SparkConf = sparkConf
 	sparkJob.ApplicationType = plugins.SparkApplication_PYTHON
 	return &sparkJob
 }
 
-func dummySparkTaskTemplate(id string) *core.TaskTemplate {
+func dummySparkTaskTemplate(id string, sparkConf map[string]string) *core.TaskTemplate {
 
-	sparkJob := dummySparkCustomObj()
+	sparkJob := dummySparkCustomObj(sparkConf)
 	sparkJobJSON, err := utils.MarshalToString(sparkJob)
 	if err != nil {
 		panic(err)
 	}
@@ -321,7 +321,7 @@ func TestBuildResourceSpark(t *testing.T) {
 	sparkResourceHandler := sparkResourceHandler{}
 
 	// Case1: Valid Spark Task-Template
-	taskTemplate := dummySparkTaskTemplate("blah-1")
+	taskTemplate := dummySparkTaskTemplate("blah-1", dummySparkConf)
 
 	// Set spark custom feature config.
 	assert.NoError(t, setSparkConfig(&Config{
@@ -364,7 +364,6 @@ func TestBuildResourceSpark(t *testing.T) {
 	assert.Equal(t, testImage, *sparkApp.Spec.Image)
 
 	//Validate Driver/Executor Spec.
-
 	driverCores, _ := strconv.Atoi(dummySparkConf["spark.driver.cores"])
 	execCores, _ := strconv.Atoi(dummySparkConf["spark.executor.cores"])
 	execInstances, _ := strconv.Atoi(dummySparkConf["spark.executor.instances"])
@@ -421,7 +420,27 @@ func TestBuildResourceSpark(t *testing.T) {
 
 	assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1)
 
-	// Case 2: Interruptible False
+	// Case 2: Driver/Executor request cores set.
+	dummyConfWithRequest := make(map[string]string)
+
+	for k, v := range dummySparkConf {
+		dummyConfWithRequest[k] = v
+	}
+
+	dummyConfWithRequest["spark.kubernetes.driver.request.cores"] = "3"
+	dummyConfWithRequest["spark.kubernetes.executor.request.cores"] = "4"
+
+	taskTemplate = dummySparkTaskTemplate("blah-1", dummyConfWithRequest)
+	resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
+	assert.Nil(t, err)
+	assert.NotNil(t, resource)
+	sparkApp, ok = resource.(*sj.SparkApplication)
+	assert.True(t, ok)
+
+	assert.Equal(t, dummyConfWithRequest["spark.kubernetes.driver.request.cores"], sparkApp.Spec.SparkConf["spark.kubernetes.driver.limit.cores"])
+	assert.Equal(t, dummyConfWithRequest["spark.kubernetes.executor.request.cores"], sparkApp.Spec.SparkConf["spark.kubernetes.executor.limit.cores"])
+
+	// Case 3: Interruptible False
 	resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
 	assert.Nil(t, err)
 	assert.NotNil(t, resource)
@@ -434,7 +453,7 @@ func TestBuildResourceSpark(t *testing.T) {
 	assert.Equal(t, 0, len(sparkApp.Spec.Executor.Tolerations))
 	assert.Equal(t, 0, len(sparkApp.Spec.Executor.NodeSelector))
 
-	// Case2: Invalid Spark Task-Template
+	// Case 4: Invalid Spark Task-Template
 	taskTemplate.Custom = nil
 	resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
 	assert.NotNil(t, err)
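For reference, the precedence rule the spark.go hunk introduces can be read as a small standalone helper. The sketch below only mirrors the fallback chain in the patch; setLimitCores is a hypothetical name, not part of the plugin, while the spark.* keys are the ones used above.

```go
package main

import "fmt"

// setLimitCores is a hypothetical helper mirroring the diff's fallback chain:
// an explicit limit wins; otherwise the Kubernetes request value is used;
// otherwise the generic Spark cores setting.
func setLimitCores(conf map[string]string, limitKey, requestKey, coresKey string) {
	if len(conf[limitKey]) != 0 {
		return // explicit limit already set
	}
	if len(conf[requestKey]) != 0 {
		conf[limitKey] = conf[requestKey] // request.cores takes precedence
	} else if len(conf[coresKey]) != 0 {
		conf[limitKey] = conf[coresKey] // fall back to *.cores
	}
}

func main() {
	conf := map[string]string{
		"spark.driver.cores":                    "2",
		"spark.kubernetes.driver.request.cores": "3",
	}
	setLimitCores(conf,
		"spark.kubernetes.driver.limit.cores",
		"spark.kubernetes.driver.request.cores",
		"spark.driver.cores")
	fmt.Println(conf["spark.kubernetes.driver.limit.cores"]) // prints "3"
}
```

This matches what the new test case asserts: with request cores of "3" and "4" set, the driver and executor limit.cores in the built SparkApplication take those values rather than spark.driver.cores/spark.executor.cores.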