From 93ed2085dd28bd8c58e56ea1c6c6f179cf1b4e2d Mon Sep 17 00:00:00 2001 From: Anand Swaminathan Date: Wed, 17 Mar 2021 14:49:15 -0700 Subject: [PATCH] Inject and Use values from Security Context (#153) * Inject and Use values from Security Context Signed-off-by: Anand Swaminathan --- .../pluginmachinery/core/exec_metadata.go | 1 + .../core/mocks/task_execution_metadata.go | 34 +++++++++ .../tasks/pluginmachinery/flytek8s/utils.go | 16 +++++ .../pluginmachinery/flytek8s/utils_test.go | 26 +++++++ .../plugins/array/awsbatch/config/config.go | 7 +- .../plugins/array/awsbatch/job_definition.go | 2 +- .../array/awsbatch/job_definition_test.go | 69 ++++++++++++++++++- .../go/tasks/plugins/awsutils/awsutils.go | 24 +++++-- .../tasks/plugins/k8s/container/container.go | 3 +- .../plugins/k8s/container/container_test.go | 3 + .../plugins/k8s/sagemaker/builtin_training.go | 5 +- .../k8s/sagemaker/hyperparameter_tuning.go | 5 +- .../k8s/sagemaker/plugin_test_utils.go | 9 +++ .../go/tasks/plugins/k8s/sidecar/sidecar.go | 3 +- .../tasks/plugins/k8s/sidecar/sidecar_test.go | 3 + .../go/tasks/plugins/k8s/spark/spark.go | 9 ++- .../go/tasks/plugins/k8s/spark/spark_test.go | 4 ++ 17 files changed, 204 insertions(+), 19 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go b/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go index 64ec7ce5d8..ccde9c2c78 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/exec_metadata.go @@ -33,5 +33,6 @@ type TaskExecutionMetadata interface { GetMaxAttempts() uint32 GetAnnotations() map[string]string GetK8sServiceAccount() string + GetSecurityContext() core.SecurityContext IsInterruptible() bool } diff --git a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go index 24b692c471..7ec516337a 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go @@ -3,7 +3,9 @@ package mocks import ( + flyteidlcore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" core "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + mock "github.com/stretchr/testify/mock" types "k8s.io/apimachinery/pkg/types" @@ -278,6 +280,38 @@ func (_m *TaskExecutionMetadata) GetOwnerReference() v1.OwnerReference { return r0 } +type TaskExecutionMetadata_GetSecurityContext struct { + *mock.Call +} + +func (_m TaskExecutionMetadata_GetSecurityContext) Return(_a0 flyteidlcore.SecurityContext) *TaskExecutionMetadata_GetSecurityContext { + return &TaskExecutionMetadata_GetSecurityContext{Call: _m.Call.Return(_a0)} +} + +func (_m *TaskExecutionMetadata) OnGetSecurityContext() *TaskExecutionMetadata_GetSecurityContext { + c := _m.On("GetSecurityContext") + return &TaskExecutionMetadata_GetSecurityContext{Call: c} +} + +func (_m *TaskExecutionMetadata) OnGetSecurityContextMatch(matchers ...interface{}) *TaskExecutionMetadata_GetSecurityContext { + c := _m.On("GetSecurityContext", matchers...) + return &TaskExecutionMetadata_GetSecurityContext{Call: c} +} + +// GetSecurityContext provides a mock function with given fields: +func (_m *TaskExecutionMetadata) GetSecurityContext() flyteidlcore.SecurityContext { + ret := _m.Called() + + var r0 flyteidlcore.SecurityContext + if rf, ok := ret.Get(0).(func() flyteidlcore.SecurityContext); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(flyteidlcore.SecurityContext) + } + + return r0 +} + type TaskExecutionMetadata_GetTaskExecutionID struct { *mock.Call } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils.go index f6fcb61a6c..70e2d7d579 100755 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils.go @@ -2,6 +2,7 @@ package flytek8s import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + pluginmachinery_core "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" v1 "k8s.io/api/core/v1" ) @@ -12,3 +13,18 @@ func ToK8sEnvVar(env []*core.KeyValuePair) []v1.EnvVar { } return envVars } + +func GetServiceAccountNameFromTaskExecutionMetadata(taskExecutionMetadata pluginmachinery_core.TaskExecutionMetadata) string { + var serviceAccount string + securityContext := taskExecutionMetadata.GetSecurityContext() + if securityContext.GetRunAs() != nil { + serviceAccount = securityContext.GetRunAs().GetK8SServiceAccount() + } + + // TO BE DEPRECATED + if len(serviceAccount) == 0 { + serviceAccount = taskExecutionMetadata.GetK8sServiceAccount() + } + + return serviceAccount +} diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils_test.go index 6ac2bcac9d..52473ff890 100755 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/utils_test.go @@ -1 +1,27 @@ package flytek8s + +import ( + "testing" + + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" + + "github.com/stretchr/testify/assert" +) + +func TestGetServiceAccountNameFromTaskExecutionMetadata(t *testing.T) { + mockTaskExecMetadata := mocks.TaskExecutionMetadata{} + mockTaskExecMetadata.OnGetSecurityContext().Return(core.SecurityContext{ + RunAs: &core.Identity{K8SServiceAccount: "service-account"}, + }) + result := GetServiceAccountNameFromTaskExecutionMetadata(&mockTaskExecMetadata) + assert.Equal(t, "service-account", result) +} + +func TestGetServiceAccountNameFromServiceAccount(t *testing.T) { + mockTaskExecMetadata := mocks.TaskExecutionMetadata{} + mockTaskExecMetadata.OnGetSecurityContext().Return(core.SecurityContext{}) + mockTaskExecMetadata.OnGetK8sServiceAccount().Return("service-account") + result := GetServiceAccountNameFromTaskExecutionMetadata(&mockTaskExecMetadata) + assert.Equal(t, "service-account", result) +} diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/config/config.go b/flyteplugins/go/tasks/plugins/array/awsbatch/config/config.go index 557789a75a..6fd84d6ffe 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/config/config.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/config/config.go @@ -22,9 +22,10 @@ type Config struct { // Provide additional environment variable pairs that plugin authors will provide to containers DefaultEnvVars map[string]string `json:"defaultEnvVars" pflag:"-,Additional environment variable that should be injected into every resource"` MaxErrorStringLength int `json:"maxErrLength" pflag:",Determines the maximum length of the error string returned for the array."` - RoleAnnotationKey string `json:"roleAnnotationKey" pflag:",Map key to use to lookup role from task annotations."` - OutputAssembler workqueue.Config `json:"outputAssembler"` - ErrorAssembler workqueue.Config `json:"errorAssembler"` + // This can be deprecated. Just having it for backward compatibility + RoleAnnotationKey string `json:"roleAnnotationKey" pflag:",Map key to use to lookup role from task annotations."` + OutputAssembler workqueue.Config `json:"outputAssembler"` + ErrorAssembler workqueue.Config `json:"errorAssembler"` } type JobStoreConfig struct { diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition.go b/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition.go index 27b36bdbec..9809073370 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition.go @@ -50,7 +50,7 @@ func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionConte return nil, errors.Errorf(pluginErrors.BadTaskSpecification, "Tasktemplate does not contain a container image.") } - role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata().GetAnnotations()) + role := awsUtils.GetRoleFromSecurityContext(cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata()) cacheKey := definition.NewCacheKey(role, containerImage) if existingArn, found := definitionCache.Get(cacheKey); found { diff --git a/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition_test.go b/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition_test.go index 3fb9ee306a..a7dbd97754 100644 --- a/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition_test.go +++ b/flyteplugins/go/tasks/plugins/array/awsbatch/job_definition_test.go @@ -64,7 +64,7 @@ func TestEnsureJobDefinition(t *testing.T) { tMeta.OnGetTaskExecutionID().Return(tID) tMeta.OnGetOverrides().Return(overrides) tMeta.OnGetAnnotations().Return(map[string]string{}) - + tMeta.OnGetSecurityContext().Return(core.SecurityContext{}) tCtx := &mocks.TaskExecutionContext{} tCtx.OnTaskReader().Return(tReader) tCtx.OnTaskExecutionMetadata().Return(tMeta) @@ -101,3 +101,70 @@ func TestEnsureJobDefinition(t *testing.T) { assert.Equal(t, "their-arn", nextState.JobDefinitionArn) }) } + +func TestEnsureJobDefinitionWithSecurityContext(t *testing.T) { + ctx := context.Background() + + tReader := &mocks.TaskReader{} + tReader.OnReadMatch(mock.Anything).Return(&core.TaskTemplate{ + Interface: &core.TypedInterface{ + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{"var1": {Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}}}, + }, + }, + Target: &core.TaskTemplate_Container{ + Container: createSampleContainerTask(), + }, + }, nil) + + overrides := &mocks.TaskOverrides{} + overrides.OnGetConfig().Return(&v1.ConfigMap{Data: map[string]string{ + DynamicTaskQueueKey: "queue1", + }}) + + tID := &mocks.TaskExecutionID{} + tID.OnGetGeneratedName().Return("found") + + tMeta := &mocks.TaskExecutionMetadata{} + tMeta.OnGetTaskExecutionID().Return(tID) + tMeta.OnGetOverrides().Return(overrides) + tMeta.OnGetAnnotations().Return(map[string]string{}) + tMeta.OnGetSecurityContext().Return(core.SecurityContext{ + RunAs: &core.Identity{IamRole: "new-role"}, + }) + tCtx := &mocks.TaskExecutionContext{} + tCtx.OnTaskReader().Return(tReader) + tCtx.OnTaskExecutionMetadata().Return(tMeta) + + cfg := &config.Config{} + batchClient := NewCustomBatchClient(batchMocks.NewMockAwsBatchClient(), "", "", + utils.NewRateLimiter("", 10, 20), + utils.NewRateLimiter("", 10, 20)) + + t.Run("Not Found", func(t *testing.T) { + dCache := definition.NewCache(10) + + nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{ + State: &arrayCore.State{}, + }) + + assert.NoError(t, err) + assert.NotNil(t, nextState) + assert.Equal(t, "my-arn", nextState.JobDefinitionArn) + p, v := nextState.GetPhase() + assert.Equal(t, arrayCore.PhaseLaunch, p) + assert.Zero(t, v) + }) + + t.Run("Found", func(t *testing.T) { + dCache := definition.NewCache(10) + assert.NoError(t, dCache.Put(definition.NewCacheKey("new-role", "img1"), "their-arn")) + + nextState, err := EnsureJobDefinition(ctx, tCtx, cfg, batchClient, dCache, &State{ + State: &arrayCore.State{}, + }) + assert.NoError(t, err) + assert.NotNil(t, nextState) + assert.Equal(t, "their-arn", nextState.JobDefinitionArn) + }) +} diff --git a/flyteplugins/go/tasks/plugins/awsutils/awsutils.go b/flyteplugins/go/tasks/plugins/awsutils/awsutils.go index 2c5b8a496b..641b3e4280 100644 --- a/flyteplugins/go/tasks/plugins/awsutils/awsutils.go +++ b/flyteplugins/go/tasks/plugins/awsutils/awsutils.go @@ -1,10 +1,26 @@ package awsutils -import "context" +import ( + core2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" +) -func GetRole(_ context.Context, roleAnnotationKey string, annotations map[string]string) string { - if len(roleAnnotationKey) > 0 { - return annotations[roleAnnotationKey] +func GetRoleFromSecurityContext(roleKey string, taskExecutionMetadata core2.TaskExecutionMetadata) string { + var role string + securityContext := taskExecutionMetadata.GetSecurityContext() + if securityContext.GetRunAs() != nil { + role = securityContext.GetRunAs().GetIamRole() + } + + // Continue this for backward compatibility + if len(role) == 0 { + role = getRole(roleKey, taskExecutionMetadata.GetAnnotations()) + } + return role +} + +func getRole(roleKey string, keyValueMap map[string]string) string { + if len(roleKey) > 0 { + return keyValueMap[roleKey] } return "" diff --git a/flyteplugins/go/tasks/plugins/k8s/container/container.go b/flyteplugins/go/tasks/plugins/k8s/container/container.go index 918258b053..10667d5c23 100755 --- a/flyteplugins/go/tasks/plugins/k8s/container/container.go +++ b/flyteplugins/go/tasks/plugins/k8s/container/container.go @@ -64,8 +64,7 @@ func (Plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecuti pod := flytek8s.BuildPodWithSpec(podSpec) - // We want to Also update the serviceAccount to the serviceaccount of the workflow - pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount() + pod.Spec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) return pod, nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/container/container_test.go b/flyteplugins/go/tasks/plugins/k8s/container/container_test.go index 7374b88897..c3f905ff97 100755 --- a/flyteplugins/go/tasks/plugins/k8s/container/container_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/container/container_test.go @@ -39,6 +39,9 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore. Name: "blah", }) taskMetadata.On("GetK8sServiceAccount").Return("service-account") + taskMetadata.On("GetSecurityContext").Return(core.SecurityContext{ + RunAs: &core.Identity{K8SServiceAccount: "service-account"}, + }) taskMetadata.On("GetOwnerID").Return(types.NamespacedName{ Namespace: "test-namespace", Name: "test-owner-name", diff --git a/flyteplugins/go/tasks/plugins/k8s/sagemaker/builtin_training.go b/flyteplugins/go/tasks/plugins/k8s/sagemaker/builtin_training.go index 5cb1539ead..e27e7994bf 100644 --- a/flyteplugins/go/tasks/plugins/k8s/sagemaker/builtin_training.go +++ b/flyteplugins/go/tasks/plugins/k8s/sagemaker/builtin_training.go @@ -116,8 +116,9 @@ func (m awsSagemakerPlugin) buildResourceForTrainingJob( inputModeString := strings.Title(strings.ToLower(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputMode().String())) - role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations()) - if role == "" { + role := awsUtils.GetRoleFromSecurityContext(cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata()) + + if len(role) == 0 { role = cfg.RoleArn } diff --git a/flyteplugins/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go b/flyteplugins/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go index 18e897d07f..b1fba80988 100644 --- a/flyteplugins/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go +++ b/flyteplugins/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go @@ -137,8 +137,9 @@ func (m awsSagemakerPlugin) buildResourceForHyperparameterTuningJob( tuningObjectiveTypeString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningObjective().GetObjectiveType().String())) trainingJobEarlyStoppingTypeString := strings.Title(strings.ToLower(hpoJobConfig.TrainingJobEarlyStoppingType.String())) - role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations()) - if role == "" { + role := awsUtils.GetRoleFromSecurityContext(cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata()) + + if len(role) == 0 { role = cfg.RoleArn } diff --git a/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go b/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go index 6f6fac3813..206eac0ead 100644 --- a/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go +++ b/flyteplugins/go/tasks/plugins/k8s/sagemaker/plugin_test_utils.go @@ -195,6 +195,10 @@ func generateMockCustomTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTem taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) taskExecutionMetadata.OnGetNamespace().Return("test-namespace") taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"iam.amazonaws.com/role": "metadata_role"}) + taskExecutionMetadata.OnGetSecurityContext().Return(flyteIdlCore.SecurityContext{ + RunAs: &flyteIdlCore.Identity{IamRole: "new-role"}, + }) + taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ Kind: "node", @@ -270,6 +274,7 @@ func generateMockTrainingJobTaskContext(taskTemplate *flyteIdlCore.TaskTemplate, taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) taskExecutionMetadata.OnGetNamespace().Return("test-namespace") taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"iam.amazonaws.com/role": "metadata_role"}) + taskExecutionMetadata.OnGetSecurityContext().Return(flyteIdlCore.SecurityContext{}) taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ Kind: "node", @@ -353,6 +358,7 @@ func generateMockHyperparameterTuningJobTaskContext(taskTemplate *flyteIdlCore.T outputReader.OnGetOutputPath().Return(storage.DataReference("/data/outputs.pb")) outputReader.OnGetOutputPrefixPath().Return(storage.DataReference("/data/")) outputReader.OnGetRawOutputPrefix().Return(storage.DataReference("/raw/")) + taskCtx.OnOutputWriter().Return(outputReader) taskReader := &mocks.TaskReader{} @@ -384,6 +390,9 @@ func genMockTaskExecutionMetadata() *mocks.TaskExecutionMetadata { taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) taskExecutionMetadata.OnGetNamespace().Return("test-namespace") taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"iam.amazonaws.com/role": "metadata_role"}) + taskExecutionMetadata.OnGetSecurityContext().Return(flyteIdlCore.SecurityContext{ + RunAs: &flyteIdlCore.Identity{IamRole: "default_role"}, + }) taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ Kind: "node", diff --git a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go index fecec11697..905dbd12b5 100755 --- a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go +++ b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go @@ -118,8 +118,7 @@ func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx plugins // CrashLoopBackoff after the initial job completion. pod.Spec.RestartPolicy = k8sv1.RestartPolicyNever - // We want to also update the serviceAccount to the serviceaccount of the workflow - pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount() + pod.Spec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) pod, err = validateAndFinalizePod(ctx, taskCtx, primaryContainerName, *pod) if err != nil { diff --git a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go index 00ca6c9e83..da67a3613b 100755 --- a/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go @@ -56,12 +56,14 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore. taskMetadata := &pluginsCoreMock.TaskExecutionMetadata{} taskMetadata.On("GetNamespace").Return("test-namespace") taskMetadata.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) + taskMetadata.On("GetLabels").Return(map[string]string{"label-1": "val1"}) taskMetadata.On("GetOwnerReference").Return(metav1.OwnerReference{ Kind: "node", Name: "blah", }) taskMetadata.On("IsInterruptible").Return(true) + taskMetadata.On("GetSecurityContext").Return(core.SecurityContext{}) taskMetadata.On("GetK8sServiceAccount").Return("service-account") taskMetadata.On("GetOwnerID").Return(types.NamespacedName{ Namespace: "test-namespace", @@ -319,6 +321,7 @@ func TestBuildSidecarResource(t *testing.T) { assert.Len(t, res.(*v1.Pod).Spec.Containers[0].VolumeMounts, 1) assert.Equal(t, "volume mount", res.(*v1.Pod).Spec.Containers[0].VolumeMounts[0].Name) + assert.Equal(t, "service-account", res.(*v1.Pod).Spec.ServiceAccountName) // Assert user-specified tolerations don't get overridden assert.Len(t, res.(*v1.Pod).Spec.Tolerations, 1) for _, tol := range res.(*v1.Pod).Spec.Tolerations { diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 3e5608c4b4..6f9b62683b 100755 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -92,6 +92,11 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo } sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts())) + serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) + + if len(serviceAccountName) == 0 { + serviceAccountName = sparkTaskType + } driverSpec := sparkOp.DriverSpec{ SparkPodSpec: sparkOp.SparkPodSpec{ Annotations: annotations, @@ -99,7 +104,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo EnvVars: sparkEnvVars, Image: &container.Image, }, - ServiceAccount: &sparkTaskType, + ServiceAccount: &serviceAccountName, } executorSpec := sparkOp.ExecutorSpec{ @@ -184,7 +189,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo APIVersion: sparkOp.SchemeGroupVersion.String(), }, Spec: sparkOp.SparkApplicationSpec{ - ServiceAccount: &sparkTaskType, + ServiceAccount: &serviceAccountName, Type: getApplicationType(sparkJob.GetApplicationType()), Image: &container.Image, Arguments: modifiedArgs, diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index bb7ad420a5..8070f5b2f0 100755 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -317,6 +317,9 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) Kind: "node", Name: "blah", }) + taskExecutionMetadata.On("GetSecurityContext").Return(core.SecurityContext{ + RunAs: &core.Identity{K8SServiceAccount: "new-val"}, + }) taskExecutionMetadata.On("IsInterruptible").Return(interruptible) taskExecutionMetadata.On("GetMaxAttempts").Return(uint32(1)) taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) @@ -374,6 +377,7 @@ func TestBuildResourceSpark(t *testing.T) { execCores, _ := strconv.ParseInt(dummySparkConf["spark.executor.cores"], 10, 32) execInstances, _ := strconv.ParseInt(dummySparkConf["spark.executor.instances"], 10, 32) + assert.Equal(t, "new-val", *sparkApp.Spec.ServiceAccount) assert.Equal(t, int32(driverCores), *sparkApp.Spec.Driver.Cores) assert.Equal(t, int32(execCores), *sparkApp.Spec.Executor.Cores) assert.Equal(t, int32(execInstances), *sparkApp.Spec.Executor.Instances)