diff --git a/flyteplugins/go/tasks/pluginmachinery/google/config.go b/flyteplugins/go/tasks/pluginmachinery/google/config.go index 445cb9efdf..ecd154e0d5 100644 --- a/flyteplugins/go/tasks/pluginmachinery/google/config.go +++ b/flyteplugins/go/tasks/pluginmachinery/google/config.go @@ -3,13 +3,17 @@ package google type TokenSourceFactoryType = string const ( - TokenSourceTypeDefault = "default" + TokenSourceTypeDefault = "default" + TokenSourceTypeGkeTaskWorkloadIdentity = "gke-task-workload-identity" // #nosec ) type TokenSourceFactoryConfig struct { - // Type is type of TokenSourceFactory, possible values are 'default' or 'gke'. + // Type is type of TokenSourceFactory, possible values are 'default' or 'gke-task-workload-identity'. // - 'default' uses default credentials, see https://cloud.google.com/iam/docs/service-accounts#default - Type TokenSourceFactoryType `json:"type" pflag:",Defines type of TokenSourceFactory, possible values are 'default'"` + Type TokenSourceFactoryType `json:"type" pflag:",Defines type of TokenSourceFactory, possible values are 'default' and 'gke-task-workload-identity'"` + + // Configuration for GKE task workload identity token source factory + GkeTaskWorkloadIdentityTokenSourceFactoryConfig GkeTaskWorkloadIdentityTokenSourceFactoryConfig `json:"gke-task-workload-identity" pflag:"Extra configuration for GKE task workload identity token source factory"` } func GetDefaultConfig() TokenSourceFactoryConfig { diff --git a/flyteplugins/go/tasks/pluginmachinery/google/default_token_source_factory.go b/flyteplugins/go/tasks/pluginmachinery/google/default_token_source_factory.go index 430e208791..358202f605 100644 --- a/flyteplugins/go/tasks/pluginmachinery/google/default_token_source_factory.go +++ b/flyteplugins/go/tasks/pluginmachinery/google/default_token_source_factory.go @@ -9,7 +9,10 @@ import ( type defaultTokenSource struct{} -func (m *defaultTokenSource) GetTokenSource(ctx context.Context, identity Identity) (oauth2.TokenSource, error) { +func (m *defaultTokenSource) GetTokenSource( + ctx context.Context, + identity Identity, +) (oauth2.TokenSource, error) { return google.DefaultTokenSource(ctx) } diff --git a/flyteplugins/go/tasks/pluginmachinery/google/gke_task_workload_identity_token_source_factory.go b/flyteplugins/go/tasks/pluginmachinery/google/gke_task_workload_identity_token_source_factory.go new file mode 100644 index 0000000000..649401fe1b --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/google/gke_task_workload_identity_token_source_factory.go @@ -0,0 +1,111 @@ +package google + +import ( + "context" + + pluginmachinery "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/pkg/errors" + "golang.org/x/oauth2" + "google.golang.org/api/impersonate" + "google.golang.org/grpc/credentials/oauth" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" +) + +const ( + gcpServiceAccountAnnotationKey = "iam.gke.io/gcp-service-account" + workflowIdentityDocURL = "https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity" +) + +var impersonationScopes = []string{"https://www.googleapis.com/auth/bigquery"} + +type GkeTaskWorkloadIdentityTokenSourceFactoryConfig struct { + RemoteClusterConfig pluginmachinery.ClusterConfig `json:"remoteClusterConfig" pflag:"Configuration of remote GKE cluster"` +} + +type gkeTaskWorkloadIdentityTokenSourceFactory struct { + kubeClient kubernetes.Interface +} + +func (m *gkeTaskWorkloadIdentityTokenSourceFactory) getGcpServiceAccount( + ctx context.Context, + identity Identity, +) (string, error) { + if identity.K8sServiceAccount == "" { + identity.K8sServiceAccount = "default" + } + serviceAccount, err := m.kubeClient.CoreV1().ServiceAccounts(identity.K8sNamespace).Get( + ctx, + identity.K8sServiceAccount, + metav1.GetOptions{}, + ) + if err != nil { + return "", errors.Wrapf(err, "failed to retrieve task k8s service account") + } + + for key, value := range serviceAccount.Annotations { + if key == gcpServiceAccountAnnotationKey { + return value, nil + } + } + + return "", errors.Errorf( + "[%v] annotation doesn't exist on k8s service account [%v/%v], read more at %v", + gcpServiceAccountAnnotationKey, + identity.K8sNamespace, + identity.K8sServiceAccount, + workflowIdentityDocURL) +} + +func (m *gkeTaskWorkloadIdentityTokenSourceFactory) GetTokenSource( + ctx context.Context, + identity Identity, +) (oauth2.TokenSource, error) { + gcpServiceAccount, err := m.getGcpServiceAccount(ctx, identity) + if err != nil { + return oauth.TokenSource{}, err + } + + return impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{ + TargetPrincipal: gcpServiceAccount, + Scopes: impersonationScopes, + }) +} + +func getKubeClient( + config *GkeTaskWorkloadIdentityTokenSourceFactoryConfig, +) (*kubernetes.Clientset, error) { + var kubeCfg *rest.Config + var err error + if config.RemoteClusterConfig.Enabled { + kubeCfg, err = pluginmachinery.KubeClientConfig( + config.RemoteClusterConfig.Endpoint, + config.RemoteClusterConfig.Auth, + ) + if err != nil { + return nil, errors.Wrapf(err, "Error building kubeconfig") + } + } else { + kubeCfg, err = rest.InClusterConfig() + if err != nil { + return nil, errors.Wrapf(err, "Cannot get InCluster kubeconfig") + } + } + + kubeClient, err := kubernetes.NewForConfig(kubeCfg) + if err != nil { + return nil, errors.Wrapf(err, "Error building kubernetes clientset") + } + return kubeClient, err +} + +func NewGkeTaskWorkloadIdentityTokenSourceFactory( + config *GkeTaskWorkloadIdentityTokenSourceFactoryConfig, +) (TokenSourceFactory, error) { + kubeClient, err := getKubeClient(config) + if err != nil { + return nil, err + } + return &gkeTaskWorkloadIdentityTokenSourceFactory{kubeClient: kubeClient}, nil +} diff --git a/flyteplugins/go/tasks/pluginmachinery/google/gke_task_workload_identity_token_source_factory_test.go b/flyteplugins/go/tasks/pluginmachinery/google/gke_task_workload_identity_token_source_factory_test.go new file mode 100644 index 0000000000..ae88eb5451 --- /dev/null +++ b/flyteplugins/go/tasks/pluginmachinery/google/gke_task_workload_identity_token_source_factory_test.go @@ -0,0 +1,64 @@ +package google + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +func TestGetGcpServiceAccount(t *testing.T) { + ctx := context.TODO() + + t.Run("get GCP service account", func(t *testing.T) { + kubeClient := fake.NewSimpleClientset(&corev1.ServiceAccount{ + ObjectMeta: v1.ObjectMeta{ + Name: "name", + Namespace: "namespace", + Annotations: map[string]string{ + "owner": "abc", + "iam.gke.io/gcp-service-account": "gcp-service-account", + }, + }}) + ts := gkeTaskWorkloadIdentityTokenSourceFactory{kubeClient: kubeClient} + gcpServiceAccount, err := ts.getGcpServiceAccount(ctx, Identity{ + K8sNamespace: "namespace", + K8sServiceAccount: "name", + }) + + assert.NoError(t, err) + assert.Equal(t, "gcp-service-account", gcpServiceAccount) + }) + + t.Run("no GCP service account", func(t *testing.T) { + kubeClient := fake.NewSimpleClientset() + ts := gkeTaskWorkloadIdentityTokenSourceFactory{kubeClient: kubeClient} + _, err := ts.getGcpServiceAccount(ctx, Identity{ + K8sNamespace: "namespace", + K8sServiceAccount: "name", + }) + + assert.ErrorContains(t, err, "failed to retrieve task k8s service account") + }) + + t.Run("no GCP service account annotation", func(t *testing.T) { + kubeClient := fake.NewSimpleClientset(&corev1.ServiceAccount{ + ObjectMeta: v1.ObjectMeta{ + Name: "name", + Namespace: "namespace", + Annotations: map[string]string{ + "owner": "abc", + }, + }}) + ts := gkeTaskWorkloadIdentityTokenSourceFactory{kubeClient: kubeClient} + _, err := ts.getGcpServiceAccount(ctx, Identity{ + K8sNamespace: "namespace", + K8sServiceAccount: "name", + }) + + assert.ErrorContains(t, err, "annotation doesn't exist on k8s service account") + }) +} diff --git a/flyteplugins/go/tasks/pluginmachinery/google/token_source_factory.go b/flyteplugins/go/tasks/pluginmachinery/google/token_source_factory.go index 05207e25c9..18cd1b0a7d 100644 --- a/flyteplugins/go/tasks/pluginmachinery/google/token_source_factory.go +++ b/flyteplugins/go/tasks/pluginmachinery/google/token_source_factory.go @@ -17,9 +17,17 @@ type TokenSourceFactory interface { } func NewTokenSourceFactory(config TokenSourceFactoryConfig) (TokenSourceFactory, error) { - if config.Type == TokenSourceTypeDefault { + switch config.Type { + case TokenSourceTypeDefault: return NewDefaultTokenSourceFactory() + case TokenSourceTypeGkeTaskWorkloadIdentity: + return NewGkeTaskWorkloadIdentityTokenSourceFactory( + &config.GkeTaskWorkloadIdentityTokenSourceFactoryConfig, + ) } - return nil, errors.Errorf("unknown token source type [%v], possible values are: 'default'", config.Type) + return nil, errors.Errorf( + "unknown token source type [%v], possible values are: 'default' and 'gke-task-workload-identity'", + config.Type, + ) }