diff --git a/go/tasks/plugins/k8s/spark/spark.go b/go/tasks/plugins/k8s/spark/spark.go index 48d72c33c..41fe12cda 100755 --- a/go/tasks/plugins/k8s/spark/spark.go +++ b/go/tasks/plugins/k8s/spark/spark.go @@ -175,7 +175,8 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo if sparkJob.MainApplicationFile != "" { j.Spec.MainApplicationFile = &sparkJob.MainApplicationFile - } else if sparkJob.MainClass != "" { + } + if sparkJob.MainClass != "" { j.Spec.MainClass = &sparkJob.MainClass } diff --git a/go/tasks/plugins/k8s/spark/spark_test.go b/go/tasks/plugins/k8s/spark/spark_test.go index 9cc60894c..c74cc3cc9 100755 --- a/go/tasks/plugins/k8s/spark/spark_test.go +++ b/go/tasks/plugins/k8s/spark/spark_test.go @@ -27,6 +27,7 @@ import ( v1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +const sparkMainClass = "MainClass" const sparkApplicationFile = "local:///spark_app.py" const testImage = "image://" const sparkUIAddress = "spark-ui.flyte" @@ -176,6 +177,7 @@ func dummySparkApplication(state sj.ApplicationStateType) *sj.SparkApplication { func dummySparkCustomObj() *plugins.SparkJob { sparkJob := plugins.SparkJob{} + sparkJob.MainClass = sparkMainClass sparkJob.MainApplicationFile = sparkApplicationFile sparkJob.SparkConf = dummySparkConf sparkJob.ApplicationType = plugins.SparkApplication_PYTHON @@ -266,6 +268,7 @@ func TestBuildResourceSpark(t *testing.T) { assert.NotNil(t, resource) sparkApp, ok := resource.(*sj.SparkApplication) assert.True(t, ok) + assert.Equal(t, sparkMainClass, *sparkApp.Spec.MainClass) assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile) assert.Equal(t, sj.PythonApplicationType, sparkApp.Spec.Type) assert.Equal(t, testArgs, sparkApp.Spec.Arguments)