Skip to content

Commit

Permalink
Support K8s Driver/Executor Cores Request Conf (flyteorg#148)
Browse files Browse the repository at this point in the history
* Support K8s Driver/Executor Cores Request Conf

* go-import

* PR comments
  • Loading branch information
akhurana001 authored Feb 3, 2021
1 parent 31b3ce3 commit 9963622
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
20 changes: 16 additions & 4 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,24 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
}

// Set pod limits.
if sparkConfig["spark.kubernetes.driver.limit.cores"] == "" && sparkConfig["spark.driver.cores"] != "" {
sparkConfig["spark.kubernetes.driver.limit.cores"] = sparkConfig["spark.driver.cores"]
if len(sparkConfig["spark.kubernetes.driver.limit.cores"]) == 0 {
// spark.kubernetes.driver.request.cores takes precedence over spark.driver.cores
if len(sparkConfig["spark.kubernetes.driver.request.cores"]) != 0 {
sparkConfig["spark.kubernetes.driver.limit.cores"] = sparkConfig["spark.kubernetes.driver.request.cores"]
} else if len(sparkConfig["spark.driver.cores"]) != 0 {
sparkConfig["spark.kubernetes.driver.limit.cores"] = sparkConfig["spark.driver.cores"]
}
}
if sparkConfig["spark.kubernetes.executor.limit.cores"] == "" && sparkConfig["spark.executor.cores"] != "" {
sparkConfig["spark.kubernetes.executor.limit.cores"] = sparkConfig["spark.executor.cores"]

if len(sparkConfig["spark.kubernetes.executor.limit.cores"]) == 0 {
// spark.kubernetes.executor.request.cores takes precedence over spark.executor.cores
if len(sparkConfig["spark.kubernetes.executor.request.cores"]) != 0 {
sparkConfig["spark.kubernetes.executor.limit.cores"] = sparkConfig["spark.kubernetes.executor.request.cores"]
} else if len(sparkConfig["spark.executor.cores"]) != 0 {
sparkConfig["spark.kubernetes.executor.limit.cores"] = sparkConfig["spark.executor.cores"]
}
}

sparkConfig["spark.kubernetes.executor.podNamePrefix"] = taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
sparkConfig["spark.kubernetes.driverEnv.FLYTE_START_TIME"] = strconv.FormatInt(time.Now().UnixNano()/1000000, 10)

Expand Down
35 changes: 27 additions & 8 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,19 +232,19 @@ func dummySparkApplication(state sj.ApplicationStateType) *sj.SparkApplication {
}
}

func dummySparkCustomObj() *plugins.SparkJob {
func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob {
sparkJob := plugins.SparkJob{}

sparkJob.MainClass = sparkMainClass
sparkJob.MainApplicationFile = sparkApplicationFile
sparkJob.SparkConf = dummySparkConf
sparkJob.SparkConf = sparkConf
sparkJob.ApplicationType = plugins.SparkApplication_PYTHON
return &sparkJob
}

func dummySparkTaskTemplate(id string) *core.TaskTemplate {
func dummySparkTaskTemplate(id string, sparkConf map[string]string) *core.TaskTemplate {

sparkJob := dummySparkCustomObj()
sparkJob := dummySparkCustomObj(sparkConf)
sparkJobJSON, err := utils.MarshalToString(sparkJob)
if err != nil {
panic(err)
Expand Down Expand Up @@ -321,7 +321,7 @@ func TestBuildResourceSpark(t *testing.T) {
sparkResourceHandler := sparkResourceHandler{}

// Case1: Valid Spark Task-Template
taskTemplate := dummySparkTaskTemplate("blah-1")
taskTemplate := dummySparkTaskTemplate("blah-1", dummySparkConf)

// Set spark custom feature config.
assert.NoError(t, setSparkConfig(&Config{
Expand Down Expand Up @@ -364,7 +364,6 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, testImage, *sparkApp.Spec.Image)

//Validate Driver/Executor Spec.

driverCores, _ := strconv.Atoi(dummySparkConf["spark.driver.cores"])
execCores, _ := strconv.Atoi(dummySparkConf["spark.executor.cores"])
execInstances, _ := strconv.Atoi(dummySparkConf["spark.executor.instances"])
Expand Down Expand Up @@ -421,7 +420,27 @@ func TestBuildResourceSpark(t *testing.T) {

assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1)

// Case 2: Interruptible False
// Case 2: Driver/Executor request cores set.
dummyConfWithRequest := make(map[string]string)

for k, v := range dummySparkConf {
dummyConfWithRequest[k] = v
}

dummyConfWithRequest["spark.kubernetes.driver.request.cores"] = "3"
dummyConfWithRequest["spark.kubernetes.executor.request.cores"] = "4"

taskTemplate = dummySparkTaskTemplate("blah-1", dummyConfWithRequest)
resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
assert.Nil(t, err)
assert.NotNil(t, resource)
sparkApp, ok = resource.(*sj.SparkApplication)
assert.True(t, ok)

assert.Equal(t, dummyConfWithRequest["spark.kubernetes.driver.request.cores"], sparkApp.Spec.SparkConf["spark.kubernetes.driver.limit.cores"])
assert.Equal(t, dummyConfWithRequest["spark.kubernetes.executor.request.cores"], sparkApp.Spec.SparkConf["spark.kubernetes.executor.limit.cores"])

// Case 3: Interruptible False
resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
assert.Nil(t, err)
assert.NotNil(t, resource)
Expand All @@ -434,7 +453,7 @@ func TestBuildResourceSpark(t *testing.T) {
assert.Equal(t, 0, len(sparkApp.Spec.Executor.Tolerations))
assert.Equal(t, 0, len(sparkApp.Spec.Executor.NodeSelector))

// Case2: Invalid Spark Task-Template
// Case 4: Invalid Spark Task-Template
taskTemplate.Custom = nil
resource, err = sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, false))
assert.NotNil(t, err)
Expand Down

0 comments on commit 9963622

Please sign in to comment.