From 321ab21a02d76a6a30dd9f9a892b2e2c7d93addd Mon Sep 17 00:00:00 2001 From: Kamal <54046807+kamaleybov@users.noreply.github.com> Date: Sun, 23 Jun 2024 23:02:47 -0700 Subject: [PATCH] Mount embedded secrets as files (#293) ## Overview Adds support for mounting secrets as files. This change is part of series to enable mounting secrets as files: - flyte fork: https://github.com/unionai/flyte/pull/293 (this change) - cloud: https://github.com/unionai/cloud/pull/7975 - unionai: https://github.com/unionai/unionai/pull/237 ## Test Plan 1. Deploy to dogfood. 2. Create secret using `-f` flag and a file. 3. Run a workflow using the file. Verify that secret value was read correctly. ## Rollout Plan Run `managed-cluster--sync-all`. ## Upstream Changes Should this change be upstreamed to OSS (flyteorg/flyte)? If so, please check this box for auditing. Note, this is the responsibility of each developer. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F). - [ ] To be upstreamed ## Issue ref COR-811 ## Checklist * [ ] Added tests * [ ] Ran a deploy dry run and shared the terraform plan * [ ] Added logging and metrics * [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list) * [ ] Updated documentation --- Makefile | 1 + cmd/single/config_flags.go | 62 +++++ cmd/single/config_flags_test.go | 214 ++++++++++++++++ fasttask/plugin/go.mod | 16 ++ fasttask/plugin/go.sum | 30 +++ flyteadmin/go.mod | 16 ++ flyteadmin/go.sum | 30 +++ .../pluginmachinery/core/secret_manager.go | 40 ++- .../pkg/controller/nodes/task/handler.go | 3 +- .../pkg/secret/aws_secret_fetcher.go | 36 ++- .../pkg/secret/aws_secret_fetcher_test.go | 16 +- .../pkg/secret/aws_secret_manager.go | 11 +- flytepropeller/pkg/secret/config/config.go | 27 +- .../pkg/secret/config/config_flags.go | 1 + .../pkg/secret/config/config_flags_test.go | 14 + .../pkg/secret/embedded_secret_manager.go | 241 +++++++++++------- .../secret/embedded_secret_manager_test.go | 2 +- .../pkg/secret/gcp_secret_fetcher.go | 25 +- .../pkg/secret/gcp_secret_fetcher_test.go | 10 +- .../pkg/secret/gcp_secret_manager.go | 8 +- flytepropeller/pkg/secret/global_secrets.go | 3 +- ..._iface.go => aws_secret_manager_client.go} | 20 +- ..._iface.go => gcp_secret_manager_client.go} | 20 +- .../mocks/http_hook_registerer_iface.go | 19 -- .../pkg/secret/mocks/secret_fetcher.go | 53 ---- flytepropeller/pkg/secret/secret_fetcher.go | 42 +++ ...cret_iface.go => secret_manager_client.go} | 12 +- flytepropeller/pkg/secret/secrets_injector.go | 46 ++++ .../{secrets.go => secrets_pod_mutator.go} | 49 +--- flytepropeller/pkg/secret/secrets_test.go | 8 +- flytepropeller/pkg/secret/utils.go | 17 +- .../{secret => webhook}/mocks/pod_mutator.go | 0 flytepropeller/pkg/webhook/pod.go | 2 + flytepropeller/pkg/webhook/pod_test.go | 10 +- go.mod | 4 +- 35 files changed, 801 insertions(+), 307 deletions(-) create mode 100755 cmd/single/config_flags.go create mode 100755 cmd/single/config_flags_test.go rename flytepropeller/pkg/secret/mocks/{aws_secrets_iface.go => aws_secret_manager_client.go} (52%) rename flytepropeller/pkg/secret/mocks/{gcp_secrets_iface.go => gcp_secret_manager_client.go} (53%) delete mode 100644 flytepropeller/pkg/secret/mocks/http_hook_registerer_iface.go delete mode 100644 flytepropeller/pkg/secret/mocks/secret_fetcher.go create mode 100644 flytepropeller/pkg/secret/secret_fetcher.go rename flytepropeller/pkg/secret/{embedded_secret_iface.go => secret_manager_client.go} (62%) create mode 100644 flytepropeller/pkg/secret/secrets_injector.go rename flytepropeller/pkg/secret/{secrets.go => secrets_pod_mutator.go} (64%) rename flytepropeller/pkg/{secret => webhook}/mocks/pod_mutator.go (100%) diff --git a/Makefile b/Makefile index d718013ffb0..cd80cfdf7cd 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ export REPOSITORY=flyte +export REPO_ROOT=. include boilerplate/flyte/end2end/Makefile include boilerplate/flyte/golang_test_targets/Makefile diff --git a/cmd/single/config_flags.go b/cmd/single/config_flags.go new file mode 100755 index 00000000000..a808b9af565 --- /dev/null +++ b/cmd/single/config_flags.go @@ -0,0 +1,62 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package single + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "propeller.disabled"), DefaultConfig.Propeller.Disabled, "Disables flytepropeller in the single binary mode") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "propeller.disableWebhook"), DefaultConfig.Propeller.DisableWebhook, "Disables webhook only") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "admin.disabled"), DefaultConfig.Admin.Disabled, "Disables flyteadmin in the single binary mode") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "admin.disableScheduler"), DefaultConfig.Admin.DisableScheduler, "Disables Native scheduler in the single binary mode") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "admin.disableClusterResourceManager"), DefaultConfig.Admin.DisableClusterResourceManager, "Disables Cluster resource manager") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "admin.seedProjects"), DefaultConfig.Admin.SeedProjects, "flyte projects to create by default.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "dataCatalog.disabled"), DefaultConfig.DataCatalog.Disabled, "Disables datacatalog in the single binary mode") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "cacheService.disabled"), DefaultConfig.CacheService.Disabled, "Disables cacheservice in the single binary mode") + return cmdFlags +} diff --git a/cmd/single/config_flags_test.go b/cmd/single/config_flags_test.go new file mode 100755 index 00000000000..1c0a040d6c0 --- /dev/null +++ b/cmd/single/config_flags_test.go @@ -0,0 +1,214 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package single + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_propeller.disabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("propeller.disabled", testValue) + if vBool, err := cmdFlags.GetBool("propeller.disabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Propeller.Disabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_propeller.disableWebhook", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("propeller.disableWebhook", testValue) + if vBool, err := cmdFlags.GetBool("propeller.disableWebhook"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Propeller.DisableWebhook) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_admin.disabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("admin.disabled", testValue) + if vBool, err := cmdFlags.GetBool("admin.disabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Admin.Disabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_admin.disableScheduler", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("admin.disableScheduler", testValue) + if vBool, err := cmdFlags.GetBool("admin.disableScheduler"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Admin.DisableScheduler) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_admin.disableClusterResourceManager", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("admin.disableClusterResourceManager", testValue) + if vBool, err := cmdFlags.GetBool("admin.disableClusterResourceManager"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.Admin.DisableClusterResourceManager) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_admin.seedProjects", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config(DefaultConfig.Admin.SeedProjects, ",") + + cmdFlags.Set("admin.seedProjects", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("admin.seedProjects"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.Admin.SeedProjects) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_dataCatalog.disabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("dataCatalog.disabled", testValue) + if vBool, err := cmdFlags.GetBool("dataCatalog.disabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.DataCatalog.Disabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_cacheService.disabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("cacheService.disabled", testValue) + if vBool, err := cmdFlags.GetBool("cacheService.disabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.CacheService.Disabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/fasttask/plugin/go.mod b/fasttask/plugin/go.mod index e65c8300095..1dd711116ee 100644 --- a/fasttask/plugin/go.mod +++ b/fasttask/plugin/go.mod @@ -25,6 +25,7 @@ require ( cloud.google.com/go/compute v1.23.3 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect cloud.google.com/go/iam v1.1.5 // indirect + cloud.google.com/go/secretmanager v1.11.4 // indirect cloud.google.com/go/storage v1.36.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.2 // indirect github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1 // indirect @@ -32,6 +33,20 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.0 // indirect github.com/aws/aws-sdk-go v1.44.2 // indirect + github.com/aws/aws-sdk-go-v2 v1.24.1 // indirect + github.com/aws/aws-sdk-go-v2/config v1.26.1 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.16.12 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9 // indirect + github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.26.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.18.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.26.5 // indirect + github.com/aws/smithy-go v1.19.0 // indirect github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect @@ -108,6 +123,7 @@ require ( golang.org/x/term v0.19.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.5.0 // indirect + gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect google.golang.org/api v0.155.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 // indirect diff --git a/fasttask/plugin/go.sum b/fasttask/plugin/go.sum index 07cc54cdffa..d79e12032c9 100644 --- a/fasttask/plugin/go.sum +++ b/fasttask/plugin/go.sum @@ -7,6 +7,8 @@ cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGB cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= cloud.google.com/go/iam v1.1.5 h1:1jTsCu4bcsNsE4iiqNT5SHwrDRCfRmIaaaVFhRveTJI= cloud.google.com/go/iam v1.1.5/go.mod h1:rB6P/Ic3mykPbFio+vo7403drjlgvoWfYpJhMXEbzv8= +cloud.google.com/go/secretmanager v1.11.4 h1:krnX9qpG2kR2fJ+u+uNyNo+ACVhplIAS4Pu7u+4gd+k= +cloud.google.com/go/secretmanager v1.11.4/go.mod h1:wreJlbS9Zdq21lMzWmJ0XhWW2ZxgPeahsqeV/vZoJ3w= cloud.google.com/go/storage v1.36.0 h1:P0mOkAcaJxhCTvAkMhxMfrTKiNcub4YmmPBtlhAyTr8= cloud.google.com/go/storage v1.36.0/go.mod h1:M6M/3V/D3KpzMTJyPOR/HU6n2Si5QdaXYEsng2xgOs8= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.2 h1:t5+QXLCK9SVi0PPdaY0PrFvYUo24KwA0QwxnaHRSVd4= @@ -26,6 +28,34 @@ github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/aws/aws-sdk-go v1.44.2 h1:5VBk5r06bgxgRKVaUtm1/4NT/rtrnH2E4cnAYv5zgQc= github.com/aws/aws-sdk-go v1.44.2/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= +github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= +github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2/config v1.26.1 h1:z6DqMxclFGL3Zfo+4Q0rLnAZ6yVkzCRxhRMsiRQnD1o= +github.com/aws/aws-sdk-go-v2/config v1.26.1/go.mod h1:ZB+CuKHRbb5v5F0oJtGdhFTelmrxd4iWO1lf0rQwSAg= +github.com/aws/aws-sdk-go-v2/credentials v1.16.12 h1:v/WgB8NxprNvr5inKIiVVrXPuuTegM+K8nncFkr1usU= +github.com/aws/aws-sdk-go-v2/credentials v1.16.12/go.mod h1:X21k0FjEJe+/pauud82HYiQbEr9jRKY3kXEIQ4hXeTQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 h1:w98BT5w+ao1/r5sUuiH6JkVzjowOKeOJRHERyy1vh58= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10/go.mod h1:K2WGI7vUvkIv1HoNbfBA1bvIZ+9kL3YVmWxeKuLQsiw= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9 h1:Nf2sHxjMJR8CSImIVCONRi4g0Su3J+TSTbS7G0pUeMU= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9/go.mod h1:idky4TER38YIjr2cADF1/ugFMKvZV7p//pVeV5LZbF0= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.26.1 h1:Sn3MAV9YeACCULaxNWWYFH1a6G4wYFwBn3/TA5MwE2Q= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.26.1/go.mod h1:qutL00aW8GSo2D0I6UEOqMvRS3ZyuBrOC1BLe5D2jPc= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.5 h1:ldSFWz9tEHAwHNmjx2Cvy1MjP5/L9kNoR0skc6wyOOM= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.5/go.mod h1:CaFfXLYL376jgbP7VKC96uFcU8Rlavak0UlAwk1Dlhc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 h1:2k9KmFawS63euAkY4/ixVNsYYwrwnd5fIvgEKkfZFNM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5/go.mod h1:W+nd4wWDVkSUIox9bacmkBP5NMFQeTJ/xqNabpzSR38= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.5 h1:5UYvv8JUvllZsRnfrcMQ+hJ9jNICmcgKPAO1CER25Wg= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.5/go.mod h1:XX5gh4CB7wAs4KhcF46G6C8a2i7eupU19dcAAE+EydU= +github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= +github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 h1:VRtJdDi2lqc3MFwmouppm2jlm6icF+7H3WYKpLENMTo= github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1/go.mod h1:jvdWlw8vowVGnZqSDC7yhPd7AifQeQbRDkZcQXV2nRg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index a7b4eb41a89..ed9027ded31 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -78,12 +78,27 @@ require ( cloud.google.com/go/compute v1.23.3 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect cloud.google.com/go/pubsub v1.34.0 // indirect + cloud.google.com/go/secretmanager v1.11.4 // indirect github.com/Azure/azure-sdk-for-go/sdk/azcore v1.8.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.2.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.2.0 // indirect github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 // indirect + github.com/aws/aws-sdk-go-v2 v1.24.1 // indirect + github.com/aws/aws-sdk-go-v2/config v1.26.1 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.16.12 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9 // indirect + github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.26.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.18.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.26.5 // indirect + github.com/aws/smithy-go v1.19.0 // indirect github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bradfitz/gomemcache v0.0.0-20180710155616-bc664df96737 // indirect @@ -208,6 +223,7 @@ require ( golang.org/x/term v0.19.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.20.0 // indirect + gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240123012728-ef4313101c80 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 // indirect diff --git a/flyteadmin/go.sum b/flyteadmin/go.sum index 45841c54753..53707a45cf4 100644 --- a/flyteadmin/go.sum +++ b/flyteadmin/go.sum @@ -45,6 +45,8 @@ cloud.google.com/go/pubsub v1.2.0/go.mod h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIA cloud.google.com/go/pubsub v1.3.1/go.mod h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU= cloud.google.com/go/pubsub v1.34.0 h1:ZtPbfwfi5rLaPeSvDC29fFoE20/tQvGrUS6kVJZJvkU= cloud.google.com/go/pubsub v1.34.0/go.mod h1:alj4l4rBg+N3YTFDDC+/YyFTs6JAjam2QfYsddcAW4c= +cloud.google.com/go/secretmanager v1.11.4 h1:krnX9qpG2kR2fJ+u+uNyNo+ACVhplIAS4Pu7u+4gd+k= +cloud.google.com/go/secretmanager v1.11.4/go.mod h1:wreJlbS9Zdq21lMzWmJ0XhWW2ZxgPeahsqeV/vZoJ3w= cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= cloud.google.com/go/storage v1.5.0/go.mod h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos= cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk= @@ -114,7 +116,35 @@ github.com/aws/aws-sdk-go v1.23.20/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpi github.com/aws/aws-sdk-go v1.31.3/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go v1.46.2 h1:XZbOmjtN1VCfEtQq7QNFsbxIqO+bB+bRhiOBjp6AzWc= github.com/aws/aws-sdk-go v1.46.2/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= +github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= +github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2/config v1.26.1 h1:z6DqMxclFGL3Zfo+4Q0rLnAZ6yVkzCRxhRMsiRQnD1o= +github.com/aws/aws-sdk-go-v2/config v1.26.1/go.mod h1:ZB+CuKHRbb5v5F0oJtGdhFTelmrxd4iWO1lf0rQwSAg= +github.com/aws/aws-sdk-go-v2/credentials v1.16.12 h1:v/WgB8NxprNvr5inKIiVVrXPuuTegM+K8nncFkr1usU= +github.com/aws/aws-sdk-go-v2/credentials v1.16.12/go.mod h1:X21k0FjEJe+/pauud82HYiQbEr9jRKY3kXEIQ4hXeTQ= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10 h1:w98BT5w+ao1/r5sUuiH6JkVzjowOKeOJRHERyy1vh58= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.10/go.mod h1:K2WGI7vUvkIv1HoNbfBA1bvIZ+9kL3YVmWxeKuLQsiw= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9 h1:Nf2sHxjMJR8CSImIVCONRi4g0Su3J+TSTbS7G0pUeMU= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.9/go.mod h1:idky4TER38YIjr2cADF1/ugFMKvZV7p//pVeV5LZbF0= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.26.1 h1:Sn3MAV9YeACCULaxNWWYFH1a6G4wYFwBn3/TA5MwE2Q= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.26.1/go.mod h1:qutL00aW8GSo2D0I6UEOqMvRS3ZyuBrOC1BLe5D2jPc= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.5 h1:ldSFWz9tEHAwHNmjx2Cvy1MjP5/L9kNoR0skc6wyOOM= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.5/go.mod h1:CaFfXLYL376jgbP7VKC96uFcU8Rlavak0UlAwk1Dlhc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5 h1:2k9KmFawS63euAkY4/ixVNsYYwrwnd5fIvgEKkfZFNM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.5/go.mod h1:W+nd4wWDVkSUIox9bacmkBP5NMFQeTJ/xqNabpzSR38= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.5 h1:5UYvv8JUvllZsRnfrcMQ+hJ9jNICmcgKPAO1CER25Wg= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.5/go.mod h1:XX5gh4CB7wAs4KhcF46G6C8a2i7eupU19dcAAE+EydU= github.com/aws/aws-xray-sdk-go v0.9.4/go.mod h1:XtMKdBQfpVut+tJEwI7+dJFRxxRdxHDyVNp2tHXRq04= +github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= +github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= diff --git a/flyteplugins/go/tasks/pluginmachinery/core/secret_manager.go b/flyteplugins/go/tasks/pluginmachinery/core/secret_manager.go index aff5397922d..092d982d497 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/secret_manager.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/secret_manager.go @@ -1,7 +1,45 @@ package core -import "context" +import ( + "context" + "fmt" + "unicode/utf8" + + "github.com/flyteorg/flyte/flytepropeller/pkg/secret" + "github.com/flyteorg/flyte/flytepropeller/pkg/secret/config" +) type SecretManager interface { Get(ctx context.Context, key string) (string, error) } + +type EmbeddedSecretManager struct { + secretFetcher secret.SecretFetcher +} + +func (e *EmbeddedSecretManager) Get(ctx context.Context, key string) (string, error) { + secretValue, err := e.secretFetcher.GetSecretValue(ctx, key) + if err != nil { + return "", err + } + + if secretValue.StringValue != "" { + return secretValue.StringValue, nil + } + + // GCP secrets store values as binary only. We could fail this path for AWS, but for + // consistent behaviour between AWS and GCP we will allow this path for AWS as well. + if !utf8.Valid(secretValue.BinaryValue) { + return "", fmt.Errorf("secret %q has a binary value that is not a valid UTF-8 string", key) + } + return string(secretValue.BinaryValue), nil +} + +func NewEmbeddedSecretManager(ctx context.Context, cfg config.EmbeddedSecretManagerConfig) (SecretManager, error) { + secretFetcher, err := secret.NewSecretFetcher(ctx, cfg) + if err != nil { + return nil, err + } + + return &EmbeddedSecretManager{secretFetcher}, nil +} diff --git a/flytepropeller/pkg/controller/nodes/task/handler.go b/flytepropeller/pkg/controller/nodes/task/handler.go index 8a80138eaad..0965ef7d72b 100644 --- a/flytepropeller/pkg/controller/nodes/task/handler.go +++ b/flytepropeller/pkg/controller/nodes/task/handler.go @@ -29,7 +29,6 @@ import ( "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/config" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/resourcemanager" rmConfig "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/resourcemanager/config" - "github.com/flyteorg/flyte/flytepropeller/pkg/secret" secretConfig "github.com/flyteorg/flyte/flytepropeller/pkg/secret/config" "github.com/flyteorg/flyte/flytepropeller/pkg/utils" "github.com/flyteorg/flyte/flytestdlib/contextutils" @@ -891,7 +890,7 @@ func New(ctx context.Context, kubeClient executors.Client, kubeClientset kuberne return nil, err } - secretManager, err := secret.NewSecretFetcherManager(ctx, secretConfig.GetConfig().EmbeddedSecretManagerConfig) + secretManager, err := pluginCore.NewEmbeddedSecretManager(ctx, secretConfig.GetConfig().EmbeddedSecretManagerConfig) if err != nil { logger.Errorf(ctx, "Failed to create secret manager with err %v", err) return nil, err diff --git a/flytepropeller/pkg/secret/aws_secret_fetcher.go b/flytepropeller/pkg/secret/aws_secret_fetcher.go index d2f637ce5e6..17d7e89943d 100644 --- a/flytepropeller/pkg/secret/aws_secret_fetcher.go +++ b/flytepropeller/pkg/secret/aws_secret_fetcher.go @@ -14,41 +14,51 @@ import ( "github.com/flyteorg/flyte/flytestdlib/logger" ) +const ( + AWSSecretLatestVersion = "AWSCURRENT" +) + type AWSSecretFetcher struct { - client AWSSecretsIface + client AWSSecretManagerClient cfg config.AWSConfig } -func (a AWSSecretFetcher) Get(ctx context.Context, key string) (string, error) { - return a.GetSecretValue(ctx, key) -} - -func (a AWSSecretFetcher) GetSecretValue(ctx context.Context, secretID string) (string, error) { +func (a AWSSecretFetcher) GetSecretValue(ctx context.Context, secretID string) (*SecretValue, error) { logger.Infof(ctx, "Got fetch secret Request for %v!", secretID) resp, err := a.client.GetSecretValue(ctx, &awssm.GetSecretValueInput{ SecretId: aws.String(secretID), - VersionStage: aws.String(AWSSecretLatesVersion), + VersionStage: aws.String(AWSSecretLatestVersion), }) + if err != nil { var notFound *types.ResourceNotFoundException if errors.As(err, ¬Found) { wrappedErr := stdlibErrors.Wrapf(ErrCodeSecretNotFound, err, fmt.Sprintf(SecretNotFoundErrorFormat, secretID)) logger.Warn(ctx, wrappedErr) - return "", wrappedErr + return nil, wrappedErr } wrappedErr := stdlibErrors.Wrapf(ErrCodeSecretReadFailure, err, fmt.Sprintf(SecretReadFailureErrorFormat, secretID)) logger.Error(ctx, wrappedErr) - return "", wrappedErr + return nil, wrappedErr } - if resp.SecretString == nil || *resp.SecretString == "" { + + if (resp.SecretString == nil || *resp.SecretString == "") && resp.SecretBinary == nil { wrappedErr := stdlibErrors.Wrapf(ErrCodeSecretNil, err, fmt.Sprintf(SecretNilErrorFormat, secretID)) logger.Error(ctx, wrappedErr) - return "", wrappedErr + return nil, wrappedErr + } + + secretValue := &SecretValue{} + if resp.SecretString != nil { + secretValue.StringValue = *resp.SecretString + } else { + secretValue.BinaryValue = resp.SecretBinary } - return *resp.SecretString, nil + + return secretValue, nil } // NewAWSSecretFetcher creates a secret value fetcher for AWS -func NewAWSSecretFetcher(cfg config.AWSConfig, client AWSSecretsIface) SecretFetcher { +func NewAWSSecretFetcher(cfg config.AWSConfig, client AWSSecretManagerClient) SecretFetcher { return AWSSecretFetcher{cfg: cfg, client: client} } diff --git a/flytepropeller/pkg/secret/aws_secret_fetcher_test.go b/flytepropeller/pkg/secret/aws_secret_fetcher_test.go index ca79f7ca6a5..af36f5e2ced 100644 --- a/flytepropeller/pkg/secret/aws_secret_fetcher_test.go +++ b/flytepropeller/pkg/secret/aws_secret_fetcher_test.go @@ -19,7 +19,7 @@ import ( var ( ctx context.Context scope promutils.Scope - awsClient *mocks.AWSSecretsIface + awsClient *mocks.AWSSecretManagerClient ) const secretID = "secretID" @@ -27,7 +27,7 @@ const secretID = "secretID" func SetupTest() { scope = promutils.NewTestScope() ctx = context.Background() - awsClient = &mocks.AWSSecretsIface{} + awsClient = &mocks.AWSSecretManagerClient{} } func TestGetSecretValueAWS(t *testing.T) { @@ -36,12 +36,12 @@ func TestGetSecretValueAWS(t *testing.T) { awsSecretsFetcher := NewAWSSecretFetcher(config.AWSConfig{}, awsClient) awsClient.OnGetSecretValueMatch(ctx, &secretsmanager.GetSecretValueInput{ SecretId: aws.String(secretID), - VersionStage: aws.String(AWSSecretLatesVersion), + VersionStage: aws.String(AWSSecretLatestVersion), }).Return(&secretsmanager.GetSecretValueOutput{ SecretString: aws.String("secretValue"), }, nil) - _, err := awsSecretsFetcher.Get(ctx, "secretID") + _, err := awsSecretsFetcher.GetSecretValue(ctx, "secretID") assert.NoError(t, err) }) @@ -51,10 +51,10 @@ func TestGetSecretValueAWS(t *testing.T) { cause := &types.ResourceNotFoundException{} awsClient.OnGetSecretValueMatch(ctx, &secretsmanager.GetSecretValueInput{ SecretId: aws.String(secretID), - VersionStage: aws.String(AWSSecretLatesVersion), + VersionStage: aws.String(AWSSecretLatestVersion), }).Return(nil, cause) - _, err := awsSecretsFetcher.Get(ctx, "secretID") + _, err := awsSecretsFetcher.GetSecretValue(ctx, "secretID") assert.Equal(t, stdlibErrors.Wrapf(ErrCodeSecretNotFound, cause, fmt.Sprintf(SecretNotFoundErrorFormat, secretID)), err) }) @@ -64,10 +64,10 @@ func TestGetSecretValueAWS(t *testing.T) { cause := fmt.Errorf("some error") awsClient.OnGetSecretValueMatch(ctx, &secretsmanager.GetSecretValueInput{ SecretId: aws.String(secretID), - VersionStage: aws.String(AWSSecretLatesVersion), + VersionStage: aws.String(AWSSecretLatestVersion), }).Return(nil, cause) - _, err := awsSecretsFetcher.Get(ctx, "secretID") + _, err := awsSecretsFetcher.GetSecretValue(ctx, "secretID") assert.Equal(t, stdlibErrors.Wrapf(ErrCodeSecretReadFailure, cause, fmt.Sprintf(SecretReadFailureErrorFormat, secretID)), err) }) } diff --git a/flytepropeller/pkg/secret/aws_secret_manager.go b/flytepropeller/pkg/secret/aws_secret_manager.go index f4bd64b82cf..963a244a7c1 100644 --- a/flytepropeller/pkg/secret/aws_secret_manager.go +++ b/flytepropeller/pkg/secret/aws_secret_manager.go @@ -3,7 +3,6 @@ package secret import ( "context" "fmt" - "os" "path/filepath" "strings" @@ -29,11 +28,9 @@ const ( // AWS SideCar Docker Container expects the mount to always be under /tmp AWSInitContainerMountPath = "/tmp" -) -var ( - // AWSSecretMountPathPrefix defines the default mount path for secrets - AWSSecretMountPathPrefix = []string{string(os.PathSeparator), "etc", "flyte", "secrets"} + // AWSSecretMountPath defines the default mount path for secrets + AWSSecretMountPath = "/etc/flyte/secrets" // #nosec G101 ) // AWSSecretManagerInjector allows injecting of secrets from AWS Secret Manager as files. It uses AWS-provided SideCar @@ -85,7 +82,7 @@ func (i AWSSecretManagerInjector) Inject(ctx context.Context, secret *core.Secre secretVolumeMount := corev1.VolumeMount{ Name: AWSSecretsVolumeName, ReadOnly: true, - MountPath: filepath.Join(AWSSecretMountPathPrefix...), + MountPath: AWSSecretMountPath, } p.Spec.Containers = AppendVolumeMounts(p.Spec.Containers, secretVolumeMount) @@ -96,7 +93,7 @@ func (i AWSSecretManagerInjector) Inject(ctx context.Context, secret *core.Secre // Set environment variable to let the container know where to find the mounted files. { Name: SecretPathDefaultDirEnvVar, - Value: filepath.Join(AWSSecretMountPathPrefix...), + Value: AWSSecretMountPath, }, // Sets an empty prefix to let the containers know the file names will match the secret keys as-is. { diff --git a/flytepropeller/pkg/secret/config/config.go b/flytepropeller/pkg/secret/config/config.go index 840d66d06d1..9ee5297e535 100644 --- a/flytepropeller/pkg/secret/config/config.go +++ b/flytepropeller/pkg/secret/config/config.go @@ -54,6 +54,21 @@ var ( Role: "flyte", KVVersion: KVVersion2, }, + EmbeddedSecretManagerConfig: EmbeddedSecretManagerConfig{ + FileMountInitContainer: FileMountInitContainerConfig{ + Image: "busybox:1.28", + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("100Mi"), + corev1.ResourceCPU: resource.MustParse("100m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("100Mi"), + corev1.ResourceCPU: resource.MustParse("100m"), + }, + }, + }, + }, } configSection = config.MustRegisterSection("webhook", DefaultConfig) @@ -126,9 +141,10 @@ const ( ) type EmbeddedSecretManagerConfig struct { - Type EmbeddedSecretManagerType `json:"type" pflags:"-,Type of embedded secret manager to initialize"` - AWSConfig AWSConfig `json:"awsConfig" pflag:",Config for AWS settings"` - GCPConfig GCPConfig `json:"gcpConfig" pflag:",Config for GCP settings"` + Type EmbeddedSecretManagerType `json:"type" pflags:"-,Type of embedded secret manager to initialize"` + AWSConfig AWSConfig `json:"awsConfig" pflag:",Config for AWS settings"` + GCPConfig GCPConfig `json:"gcpConfig" pflag:",Config for GCP settings"` + FileMountInitContainer FileMountInitContainerConfig `json:"fileMountInitContainer" pflag:",Init container configuration to use for mounting secrets as files."` } type AWSConfig struct { @@ -139,6 +155,11 @@ type GCPConfig struct { Project string `json:"project" pflag:",GCP project to be used for secret manager"` } +type FileMountInitContainerConfig struct { + Image string `json:"image" pflag:",Specifies init container image to use for mounting secrets as files."` + Resources corev1.ResourceRequirements `json:"resources" pflag:"-,Specifies resource requirements for the init container."` +} + func (c Config) ExpandCertDir() string { return os.ExpandEnv(c.CertDir) } diff --git a/flytepropeller/pkg/secret/config/config_flags.go b/flytepropeller/pkg/secret/config/config_flags.go index 3349fd97820..5f0e50a581e 100755 --- a/flytepropeller/pkg/secret/config/config_flags.go +++ b/flytepropeller/pkg/secret/config/config_flags.go @@ -63,5 +63,6 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "embeddedSecretManagerConfig.type"), DefaultConfig.EmbeddedSecretManagerConfig.Type.String(), "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "embeddedSecretManagerConfig.awsConfig.region"), DefaultConfig.EmbeddedSecretManagerConfig.AWSConfig.Region, "AWS region") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "embeddedSecretManagerConfig.gcpConfig.project"), DefaultConfig.EmbeddedSecretManagerConfig.GCPConfig.Project, "GCP project to be used for secret manager") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "embeddedSecretManagerConfig.fileMountInitContainer.image"), DefaultConfig.EmbeddedSecretManagerConfig.FileMountInitContainer.Image, "Specifies init container image to use for mounting secrets as files.") return cmdFlags } diff --git a/flytepropeller/pkg/secret/config/config_flags_test.go b/flytepropeller/pkg/secret/config/config_flags_test.go index da59a4346f8..5f3e767fa88 100755 --- a/flytepropeller/pkg/secret/config/config_flags_test.go +++ b/flytepropeller/pkg/secret/config/config_flags_test.go @@ -281,4 +281,18 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_embeddedSecretManagerConfig.fileMountInitContainer.image", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("embeddedSecretManagerConfig.fileMountInitContainer.image", testValue) + if vString, err := cmdFlags.GetString("embeddedSecretManagerConfig.fileMountInitContainer.image"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.EmbeddedSecretManagerConfig.FileMountInitContainer.Image) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/flytepropeller/pkg/secret/embedded_secret_manager.go b/flytepropeller/pkg/secret/embedded_secret_manager.go index 818c16b46ab..b69735dd666 100644 --- a/flytepropeller/pkg/secret/embedded_secret_manager.go +++ b/flytepropeller/pkg/secret/embedded_secret_manager.go @@ -2,12 +2,11 @@ package secret import ( "context" + "encoding/base64" "fmt" "strings" + "unicode/utf8" - gcpsm "cloud.google.com/go/secretmanager/apiv1" - awsConfig "github.com/aws/aws-sdk-go-v2/config" - awssm "github.com/aws/aws-sdk-go-v2/service/secretsmanager" corev1 "k8s.io/api/core/v1" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" @@ -18,25 +17,28 @@ import ( ) const ( - UnionSecretEnvVarPrefix = "_UNION_" - SecretFieldSeparator = "__" - ValueFormatter = "%s" - SecretsStorageUnionPrefix = "u" - SecretsStorageOrgPrefixFormat = SecretsStorageUnionPrefix + SecretFieldSeparator + "org" + SecretFieldSeparator + ValueFormatter - SecretsStorageDomainPrefixFormat = SecretsStorageOrgPrefixFormat + SecretFieldSeparator + "domain" + SecretFieldSeparator + ValueFormatter - SecretsStorageProjectPrefixFormat = SecretsStorageDomainPrefixFormat + SecretFieldSeparator + "project" + SecretFieldSeparator + ValueFormatter - SecretsStorageFormat = SecretsStorageProjectPrefixFormat + SecretFieldSeparator + "key" + SecretFieldSeparator + ValueFormatter - ProjectLabel = "project" - DomainLabel = "domain" - OrganizationLabel = "organization" - EmptySecretScope = "" - AWSSecretLatesVersion = "AWSCURRENT" - GCPSecretNameFormat = "projects/%s/secrets/%s/versions/latest" // #nosec G101 + UnionSecretEnvVarPrefix = "_UNION_" + // Static name of the volume used for mounting secrets with file mount requirement. + EmbeddedSecretsVolumeName = "embedded-secret-vol" // #nosec G101 + EmbeddedSecretsMountPath = "/etc/flyte/secrets" // #nosec G101 + + SecretFieldSeparator = "__" + ValueFormatter = "%s" + SecretsStorageUnionPrefix = "u" + SecretsStorageOrgPrefixFormat = SecretsStorageUnionPrefix + SecretFieldSeparator + "org" + SecretFieldSeparator + ValueFormatter + SecretsStorageDomainPrefixFormat = SecretsStorageOrgPrefixFormat + SecretFieldSeparator + "domain" + SecretFieldSeparator + ValueFormatter + SecretsStorageProjectPrefixFormat = SecretsStorageDomainPrefixFormat + SecretFieldSeparator + "project" + SecretFieldSeparator + ValueFormatter + SecretsStorageFormat = SecretsStorageProjectPrefixFormat + SecretFieldSeparator + "key" + SecretFieldSeparator + ValueFormatter + ProjectLabel = "project" + DomainLabel = "domain" + OrganizationLabel = "organization" + EmptySecretScope = "" + SecretNotFoundErrorFormat = "secret %v not found in the secret manager" // #nosec G101 SecretReadFailureErrorFormat = "secret %v failed to be read from secret manager" // #nosec G101 SecretNilErrorFormat = "secret %v read as empty from the secret manager" // #nosec G101 SecretRequirementsErrorFormat = "secret read requirements not met due to empty %v field in the pod labels" // #nosec G101 - SecretSecretNotFoundAcrossAllScopes = "secret not found across all scope" // #nosec G101 + SecretSecretNotFoundAcrossAllScopes = "secret not found across all scopes" // #nosec G101 ErrCodeSecretRequirementsError stdlibErrors.ErrorCode = "SecretRequirementsError" // #nosec G101 ErrCodeSecretNotFound stdlibErrors.ErrorCode = "SecretNotFound" // #nosec G101 ErrCodeSecretNotFoundAcrossAllScopes stdlibErrors.ErrorCode = "SecretNotFoundAcrossAllScopes" // #nosec G101 @@ -44,11 +46,6 @@ const ( ErrCodeSecretNil stdlibErrors.ErrorCode = "SecretNil" // #nosec G101 ) -//go:generate mockery --output=./mocks --case=underscore -name=SecretFetcher -type SecretFetcher interface { - Get(ctx context.Context, key string) (string, error) -} - // AWSSecretManagerInjector allows injecting of secrets from AWS Secret Manager as environment variable. It uses AWS-provided SideCar // as an init-container to download the secret and save it to a local volume shared with all other containers in the pod. // It supports multiple secrets to be mounted but that will result into adding an init container for each secret. @@ -94,90 +91,172 @@ func validateRequiredFieldsExist(labels map[string]string) error { return nil } -func (i EmbeddedSecretManagerInjector) lookUpSecret(ctx context.Context, secret *core.Secret, labels map[string]string) (string, error) { +func (i EmbeddedSecretManagerInjector) lookUpSecret(ctx context.Context, secret *core.Secret, labels map[string]string) (*SecretValue, error) { // Fetch the secret from configured secrets manager err := validateRequiredFieldsExist(labels) if err != nil { - return "", err + return nil, err } // Fetch project-domain scoped secret projectDomainScopedSecret := fmt.Sprintf(SecretsStorageFormat, labels[OrganizationLabel], labels[DomainLabel], labels[ProjectLabel], secret.Key) - secretValue, err := i.secretFetcher.Get(ctx, projectDomainScopedSecret) - if err != nil && !stdlibErrors.IsCausedBy(err, ErrCodeSecretNotFound) { - return "", err - } - if len(secretValue) > 0 { + secretValue, err := i.secretFetcher.GetSecretValue(ctx, projectDomainScopedSecret) + if err == nil { return secretValue, nil } + if !stdlibErrors.IsCausedBy(err, ErrCodeSecretNotFound) { + return nil, err + } // Fetch domain scoped secret domainScopedSecret := fmt.Sprintf(SecretsStorageFormat, labels[OrganizationLabel], labels[DomainLabel], EmptySecretScope, secret.Key) - secretValue, err = i.secretFetcher.Get(ctx, domainScopedSecret) - if err != nil && !stdlibErrors.IsCausedBy(err, ErrCodeSecretNotFound) { - return "", err - } - if len(secretValue) > 0 { + secretValue, err = i.secretFetcher.GetSecretValue(ctx, domainScopedSecret) + if err == nil { return secretValue, nil } + if !stdlibErrors.IsCausedBy(err, ErrCodeSecretNotFound) { + return nil, err + } // Fetch organization scoped secret orgScopedSecret := fmt.Sprintf(SecretsStorageFormat, labels[OrganizationLabel], EmptySecretScope, EmptySecretScope, secret.Key) - secretValue, err = i.secretFetcher.Get(ctx, orgScopedSecret) - if err != nil && !stdlibErrors.IsCausedBy(err, ErrCodeSecretNotFound) { - return "", err + secretValue, err = i.secretFetcher.GetSecretValue(ctx, orgScopedSecret) + if err != nil { + return secretValue, err } - if len(secretValue) > 0 { - return secretValue, nil + if !stdlibErrors.IsCausedBy(err, ErrCodeSecretNotFound) { + return nil, err } - return "", stdlibErrors.Errorf(ErrCodeSecretNotFoundAcrossAllScopes, SecretSecretNotFoundAcrossAllScopes) + return nil, stdlibErrors.Errorf(ErrCodeSecretNotFoundAcrossAllScopes, SecretSecretNotFoundAcrossAllScopes) } -func (i EmbeddedSecretManagerInjector) Inject(ctx context.Context, secret *core.Secret, p *corev1.Pod) (newP *corev1.Pod, injected bool, err error) { + +func (i EmbeddedSecretManagerInjector) Inject( + ctx context.Context, + secret *core.Secret, + pod *corev1.Pod, +) (*corev1.Pod, bool /*injected*/, error) { if len(secret.Key) == 0 { - return p, false, fmt.Errorf("EmbeddedSecretManager requires key to be set. "+ - "Secret: [%v]", secret) + return pod, false, fmt.Errorf("EmbeddedSecretManager requires key to be set. Secret: [%v]", secret) + } + + secretValue, err := i.lookUpSecret(ctx, secret, pod.Labels) + if err != nil { + return pod, false, err } switch secret.MountRequirement { case core.Secret_ANY: fallthrough case core.Secret_ENV_VAR: - // Fetch the secret from secrets manager - secretValue, err := i.lookUpSecret(ctx, secret, p.Labels) - if err != nil { - return p, false, err + var stringValue string + if secretValue.StringValue != "" { + stringValue = secretValue.StringValue + } else { + // GCP secrets store values as binary only. This means a secret could be + // defined as a file, but mounted as an environment variable. + // We could fail this path for AWS, but for consistent behaviour between + // AWS and GCP we will allow this path for AWS as well. + if !utf8.Valid(secretValue.BinaryValue) { + return pod, false, fmt.Errorf( + "secret %q is attempted to be mounted as an environment variable, "+ + "but has a binary value that is not a valid UTF-8 string; mount "+ + "as a file instead", secret.Key) + } + stringValue = string(secretValue.BinaryValue) + } + i.injectAsEnvVar(secret.Key, stringValue, pod) + case core.Secret_FILE: + if secretValue.BinaryValue == nil { + return pod, false, fmt.Errorf( + "secret %q is attempted to be mounted as a file, but has no binary "+ + "value; mount as an environment variable instead", secret.Key) } + i.injectAsFile(secret.Key, secretValue.BinaryValue, pod) + default: + err := fmt.Errorf("unrecognized mount requirement [%v] for secret [%v]", secret.MountRequirement.String(), secret.Key) + logger.Error(ctx, err) + return pod, false, err + } + + return pod, true, nil +} - prefixEnvVar := corev1.EnvVar{ +func (i EmbeddedSecretManagerInjector) injectAsEnvVar(secretKey string, secretValue string, pod *corev1.Pod) { + envVars := []corev1.EnvVar{ + { Name: SecretEnvVarPrefix, Value: UnionSecretEnvVarPrefix, - } - // Inject secret-inject webhook annotations to mount the secret in a predictable location. - envVars := []corev1.EnvVar{ - prefixEnvVar, - // Set environment variable to let the container know where to find the mounted files. - { - Name: UnionSecretEnvVarPrefix + strings.ToUpper(secret.Key), - Value: secretValue, + }, + { + Name: UnionSecretEnvVarPrefix + strings.ToUpper(secretKey), + Value: secretValue, + }, + } + pod.Spec.InitContainers = AppendEnvVars(pod.Spec.InitContainers, envVars...) + pod.Spec.Containers = AppendEnvVars(pod.Spec.Containers, envVars...) +} + +func (i EmbeddedSecretManagerInjector) injectAsFile(secretKey string, secretValue []byte, pod *corev1.Pod) { + // A volume with a static name so that if we try to inject multiple secrets, we won't mount multiple volumes. + volume := corev1.Volume{ + Name: EmbeddedSecretsVolumeName, + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{ + Medium: corev1.StorageMediumMemory, }, - } + }, + } + pod.Spec.Volumes = appendVolumeIfNotExists(pod.Spec.Volumes, volume) - for _, envVar := range envVars { - p.Spec.InitContainers = AppendEnvVars(p.Spec.InitContainers, envVar) - p.Spec.Containers = AppendEnvVars(p.Spec.Containers, envVar) - } + secretFilePath := EmbeddedSecretsMountPath + "/" + secretKey + secretInitContainer := corev1.Container{ + Name: "init-embedded-secret-" + secretKey, + Image: i.cfg.FileMountInitContainer.Image, + Env: []corev1.EnvVar{ + { + Name: "SECRET_VALUE", + Value: base64.StdEncoding.EncodeToString(secretValue), + }, + }, + Command: []string{ + "sh", + "-c", + fmt.Sprintf("printf \"%%s\" \"$SECRET_VALUE\" | base64 -d > \"%s\"", secretFilePath), + }, + Resources: i.cfg.FileMountInitContainer.Resources, + VolumeMounts: []corev1.VolumeMount{ + { + Name: EmbeddedSecretsVolumeName, + ReadOnly: false, + MountPath: EmbeddedSecretsMountPath, + }, + }, + } + pod.Spec.InitContainers = append(pod.Spec.InitContainers, secretInitContainer) - case core.Secret_FILE: - err := fmt.Errorf("secret [%v] requirement is not supported for secret [%v]", secret.MountRequirement.String(), secret.Key) - logger.Error(ctx, err) - return p, false, err - default: - err := fmt.Errorf("unrecognized mount requirement [%v] for secret [%v]", secret.MountRequirement.String(), secret.Key) - logger.Error(ctx, err) - return p, false, err + secretVolumeMount := corev1.VolumeMount{ + Name: EmbeddedSecretsVolumeName, + ReadOnly: true, + MountPath: EmbeddedSecretsMountPath, } + pod.Spec.InitContainers = AppendVolumeMounts(pod.Spec.InitContainers, secretVolumeMount) + pod.Spec.Containers = AppendVolumeMounts(pod.Spec.Containers, secretVolumeMount) - return p, true, nil + // Inject AWS secret-inject webhook annotations to mount the secret in a predictable location. + envVars := []corev1.EnvVar{ + // Set environment variable to let the containers know where to find the mounted files. + { + Name: SecretPathDefaultDirEnvVar, + Value: EmbeddedSecretsMountPath, + }, + // Sets an empty prefix to let the containers know the file names will match the secret keys as-is. + { + Name: SecretPathFilePrefixEnvVar, + Value: "", + }, + } + pod.Spec.InitContainers = AppendEnvVars(pod.Spec.InitContainers, envVars...) + pod.Spec.Containers = AppendEnvVars(pod.Spec.Containers, envVars...) } func NewEmbeddedSecretManagerInjector(cfg config.EmbeddedSecretManagerConfig, secretFetcher SecretFetcher) SecretsInjector { @@ -186,23 +265,3 @@ func NewEmbeddedSecretManagerInjector(cfg config.EmbeddedSecretManagerConfig, se secretFetcher: secretFetcher, } } - -func NewSecretFetcherManager(ctx context.Context, cfg config.EmbeddedSecretManagerConfig) (SecretFetcher, error) { - switch cfg.Type { - case config.EmbeddedSecretManagerTypeAWS: - awsCfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithRegion(cfg.AWSConfig.Region)) - if err != nil { - logger.Errorf(ctx, "failed to start secret manager service due to %v", err) - return nil, fmt.Errorf("failed to start secret manager service due to %v", err) - } - return NewAWSSecretFetcher(cfg.AWSConfig, awssm.NewFromConfig(awsCfg)), nil - case config.EmbeddedSecretManagerTypeGCP: - gcpSmClient, err := gcpsm.NewClient(ctx) - if err != nil { - logger.Errorf(ctx, "failed to start secret manager service due to %v", err) - return nil, fmt.Errorf("failed to start secret manager service due to %v", err) - } - return NewGCPSecretFetcher(cfg.GCPConfig, gcpSmClient), nil - } - return nil, fmt.Errorf("failed to start secret fetcher service due to unsupported type %v. Only supported for aws and gcp right now", cfg.Type) -} diff --git a/flytepropeller/pkg/secret/embedded_secret_manager_test.go b/flytepropeller/pkg/secret/embedded_secret_manager_test.go index f5b811b5705..b07fba80c2f 100644 --- a/flytepropeller/pkg/secret/embedded_secret_manager_test.go +++ b/flytepropeller/pkg/secret/embedded_secret_manager_test.go @@ -20,7 +20,7 @@ import ( func TestEmbeddedSecretManagerInjector_Inject(t *testing.T) { ctx = context.Background() - gcpClient = &mocks.GCPSecretsIface{} + gcpClient = &mocks.GCPSecretManagerClient{} gcpProject = "project" secretIDKey := "secretID" secretValue := "secretValue" diff --git a/flytepropeller/pkg/secret/gcp_secret_fetcher.go b/flytepropeller/pkg/secret/gcp_secret_fetcher.go index 610679d567c..2f0a7ff340d 100644 --- a/flytepropeller/pkg/secret/gcp_secret_fetcher.go +++ b/flytepropeller/pkg/secret/gcp_secret_fetcher.go @@ -13,16 +13,16 @@ import ( "github.com/flyteorg/flyte/flytestdlib/logger" ) +const ( + GCPSecretNameFormat = "projects/%s/secrets/%s/versions/latest" // #nosec G101 +) + type GCPSecretFetcher struct { - client GCPSecretsIface + client GCPSecretManagerClient cfg config.GCPConfig } -func (g GCPSecretFetcher) Get(ctx context.Context, key string) (string, error) { - return g.GetSecretValue(ctx, key) -} - -func (g GCPSecretFetcher) GetSecretValue(ctx context.Context, secretID string) (string, error) { +func (g GCPSecretFetcher) GetSecretValue(ctx context.Context, secretID string) (*SecretValue, error) { logger.Infof(ctx, "Got fetch secret Request for %v!", secretID) resp, err := g.client.AccessSecretVersion(ctx, &gcpsmpb.AccessSecretVersionRequest{ Name: fmt.Sprintf(GCPSecretNameFormat, g.cfg.Project, secretID), @@ -31,21 +31,24 @@ func (g GCPSecretFetcher) GetSecretValue(ctx context.Context, secretID string) ( if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { wrappedErr := stdlibErrors.Wrapf(ErrCodeSecretNotFound, err, fmt.Sprintf(SecretNotFoundErrorFormat, secretID)) logger.Warn(ctx, wrappedErr) - return "", wrappedErr + return nil, wrappedErr } wrappedErr := stdlibErrors.Wrapf(ErrCodeSecretReadFailure, err, fmt.Sprintf(SecretReadFailureErrorFormat, secretID)) logger.Error(ctx, wrappedErr) - return "", wrappedErr + return nil, wrappedErr } if resp.GetPayload() == nil { wrappedErr := stdlibErrors.Wrapf(ErrCodeSecretNil, err, fmt.Sprintf(SecretNilErrorFormat, secretID)) logger.Error(ctx, wrappedErr) - return "", wrappedErr + return nil, wrappedErr } - return string(resp.GetPayload().GetData()), nil + + return &SecretValue{ + BinaryValue: resp.GetPayload().GetData(), + }, nil } // NewGCPSecretFetcher creates a secret value fetcher for GCP -func NewGCPSecretFetcher(cfg config.GCPConfig, client GCPSecretsIface) SecretFetcher { +func NewGCPSecretFetcher(cfg config.GCPConfig, client GCPSecretManagerClient) SecretFetcher { return GCPSecretFetcher{cfg: cfg, client: client} } diff --git a/flytepropeller/pkg/secret/gcp_secret_fetcher_test.go b/flytepropeller/pkg/secret/gcp_secret_fetcher_test.go index 5bc0764154d..3aacb899664 100644 --- a/flytepropeller/pkg/secret/gcp_secret_fetcher_test.go +++ b/flytepropeller/pkg/secret/gcp_secret_fetcher_test.go @@ -17,14 +17,14 @@ import ( ) var ( - gcpClient *mocks.GCPSecretsIface + gcpClient *mocks.GCPSecretManagerClient gcpProject string ) func SetupGCPTest() { scope = promutils.NewTestScope() ctx = context.Background() - gcpClient = &mocks.GCPSecretsIface{} + gcpClient = &mocks.GCPSecretManagerClient{} gcpProject = "project" } @@ -42,7 +42,7 @@ func TestGetSecretValueGCP(t *testing.T) { }, }, nil) - _, err := gcpSecretsFetcher.Get(ctx, "secretID") + _, err := gcpSecretsFetcher.GetSecretValue(ctx, "secretID") assert.NoError(t, err) }) @@ -56,7 +56,7 @@ func TestGetSecretValueGCP(t *testing.T) { Name: fmt.Sprintf(GCPSecretNameFormat, gcpProject, secretID), }).Return(nil, cause) - _, err := gcpSecretsFetcher.Get(ctx, "secretID") + _, err := gcpSecretsFetcher.GetSecretValue(ctx, "secretID") assert.Equal(t, stdlibErrors.Wrapf(ErrCodeSecretNotFound, cause, fmt.Sprintf(SecretNotFoundErrorFormat, secretID)), err) }) @@ -70,7 +70,7 @@ func TestGetSecretValueGCP(t *testing.T) { Name: fmt.Sprintf(GCPSecretNameFormat, gcpProject, secretID), }).Return(nil, cause) - _, err := gcpSecretsFetcher.Get(ctx, "secretID") + _, err := gcpSecretsFetcher.GetSecretValue(ctx, "secretID") assert.Equal(t, stdlibErrors.Wrapf(ErrCodeSecretReadFailure, cause, fmt.Sprintf(SecretReadFailureErrorFormat, secretID)), err) }) } diff --git a/flytepropeller/pkg/secret/gcp_secret_manager.go b/flytepropeller/pkg/secret/gcp_secret_manager.go index bea22ba725c..f1a9ad66f72 100644 --- a/flytepropeller/pkg/secret/gcp_secret_manager.go +++ b/flytepropeller/pkg/secret/gcp_secret_manager.go @@ -3,7 +3,6 @@ package secret import ( "context" "fmt" - "os" "path/filepath" "strings" @@ -17,12 +16,9 @@ import ( const ( // GCPSecretsVolumeName defines the static name of the volume used for mounting/sharing secrets between init-container // sidecar and the rest of the containers in the pod. - GCPSecretsVolumeName = "gcp-secret-vol" // #nosec -) + GCPSecretsVolumeName = "gcp-secret-vol" // #nosec G101 -var ( - // GCPSecretMountPath defines the default mount path for secrets - GCPSecretMountPath = filepath.Join(string(os.PathSeparator), "etc", "flyte", "secrets") + GCPSecretMountPath = "/etc/flyte/secrets" // #nosec G101 ) // GCPSecretManagerInjector allows injecting of secrets from GCP Secret Manager as files. It uses a Google Cloud diff --git a/flytepropeller/pkg/secret/global_secrets.go b/flytepropeller/pkg/secret/global_secrets.go index be64f32b36e..d9314316e39 100644 --- a/flytepropeller/pkg/secret/global_secrets.go +++ b/flytepropeller/pkg/secret/global_secrets.go @@ -12,8 +12,7 @@ import ( "github.com/flyteorg/flyte/flytestdlib/logger" ) -//go:generate mockery -all -case=underscore - +//go:generate mockery --output=./mocks --case=underscore --name=GlobalSecretProvider type GlobalSecretProvider interface { GetForSecret(ctx context.Context, secret *coreIdl.Secret) (string, error) } diff --git a/flytepropeller/pkg/secret/mocks/aws_secrets_iface.go b/flytepropeller/pkg/secret/mocks/aws_secret_manager_client.go similarity index 52% rename from flytepropeller/pkg/secret/mocks/aws_secrets_iface.go rename to flytepropeller/pkg/secret/mocks/aws_secret_manager_client.go index 0365c610056..345623b599f 100644 --- a/flytepropeller/pkg/secret/mocks/aws_secrets_iface.go +++ b/flytepropeller/pkg/secret/mocks/aws_secret_manager_client.go @@ -10,31 +10,31 @@ import ( secretsmanager "github.com/aws/aws-sdk-go-v2/service/secretsmanager" ) -// AWSSecretsIface is an autogenerated mock type for the AWSSecretsIface type -type AWSSecretsIface struct { +// AWSSecretManagerClient is an autogenerated mock type for the AWSSecretManagerClient type +type AWSSecretManagerClient struct { mock.Mock } -type AWSSecretsIface_GetSecretValue struct { +type AWSSecretManagerClient_GetSecretValue struct { *mock.Call } -func (_m AWSSecretsIface_GetSecretValue) Return(_a0 *secretsmanager.GetSecretValueOutput, _a1 error) *AWSSecretsIface_GetSecretValue { - return &AWSSecretsIface_GetSecretValue{Call: _m.Call.Return(_a0, _a1)} +func (_m AWSSecretManagerClient_GetSecretValue) Return(_a0 *secretsmanager.GetSecretValueOutput, _a1 error) *AWSSecretManagerClient_GetSecretValue { + return &AWSSecretManagerClient_GetSecretValue{Call: _m.Call.Return(_a0, _a1)} } -func (_m *AWSSecretsIface) OnGetSecretValue(_a0 context.Context, _a1 *secretsmanager.GetSecretValueInput, _a2 ...func(*secretsmanager.Options)) *AWSSecretsIface_GetSecretValue { +func (_m *AWSSecretManagerClient) OnGetSecretValue(_a0 context.Context, _a1 *secretsmanager.GetSecretValueInput, _a2 ...func(*secretsmanager.Options)) *AWSSecretManagerClient_GetSecretValue { c_call := _m.On("GetSecretValue", _a0, _a1, _a2) - return &AWSSecretsIface_GetSecretValue{Call: c_call} + return &AWSSecretManagerClient_GetSecretValue{Call: c_call} } -func (_m *AWSSecretsIface) OnGetSecretValueMatch(matchers ...interface{}) *AWSSecretsIface_GetSecretValue { +func (_m *AWSSecretManagerClient) OnGetSecretValueMatch(matchers ...interface{}) *AWSSecretManagerClient_GetSecretValue { c_call := _m.On("GetSecretValue", matchers...) - return &AWSSecretsIface_GetSecretValue{Call: c_call} + return &AWSSecretManagerClient_GetSecretValue{Call: c_call} } // GetSecretValue provides a mock function with given fields: _a0, _a1, _a2 -func (_m *AWSSecretsIface) GetSecretValue(_a0 context.Context, _a1 *secretsmanager.GetSecretValueInput, _a2 ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) { +func (_m *AWSSecretManagerClient) GetSecretValue(_a0 context.Context, _a1 *secretsmanager.GetSecretValueInput, _a2 ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) { _va := make([]interface{}, len(_a2)) for _i := range _a2 { _va[_i] = _a2[_i] diff --git a/flytepropeller/pkg/secret/mocks/gcp_secrets_iface.go b/flytepropeller/pkg/secret/mocks/gcp_secret_manager_client.go similarity index 53% rename from flytepropeller/pkg/secret/mocks/gcp_secrets_iface.go rename to flytepropeller/pkg/secret/mocks/gcp_secret_manager_client.go index 14361afcb3d..c366d7976e1 100644 --- a/flytepropeller/pkg/secret/mocks/gcp_secrets_iface.go +++ b/flytepropeller/pkg/secret/mocks/gcp_secret_manager_client.go @@ -11,31 +11,31 @@ import ( secretmanagerpb "cloud.google.com/go/secretmanager/apiv1/secretmanagerpb" ) -// GCPSecretsIface is an autogenerated mock type for the GCPSecretsIface type -type GCPSecretsIface struct { +// GCPSecretManagerClient is an autogenerated mock type for the GCPSecretManagerClient type +type GCPSecretManagerClient struct { mock.Mock } -type GCPSecretsIface_AccessSecretVersion struct { +type GCPSecretManagerClient_AccessSecretVersion struct { *mock.Call } -func (_m GCPSecretsIface_AccessSecretVersion) Return(_a0 *secretmanagerpb.AccessSecretVersionResponse, _a1 error) *GCPSecretsIface_AccessSecretVersion { - return &GCPSecretsIface_AccessSecretVersion{Call: _m.Call.Return(_a0, _a1)} +func (_m GCPSecretManagerClient_AccessSecretVersion) Return(_a0 *secretmanagerpb.AccessSecretVersionResponse, _a1 error) *GCPSecretManagerClient_AccessSecretVersion { + return &GCPSecretManagerClient_AccessSecretVersion{Call: _m.Call.Return(_a0, _a1)} } -func (_m *GCPSecretsIface) OnAccessSecretVersion(ctx context.Context, req *secretmanagerpb.AccessSecretVersionRequest, opts ...gax.CallOption) *GCPSecretsIface_AccessSecretVersion { +func (_m *GCPSecretManagerClient) OnAccessSecretVersion(ctx context.Context, req *secretmanagerpb.AccessSecretVersionRequest, opts ...gax.CallOption) *GCPSecretManagerClient_AccessSecretVersion { c_call := _m.On("AccessSecretVersion", ctx, req, opts) - return &GCPSecretsIface_AccessSecretVersion{Call: c_call} + return &GCPSecretManagerClient_AccessSecretVersion{Call: c_call} } -func (_m *GCPSecretsIface) OnAccessSecretVersionMatch(matchers ...interface{}) *GCPSecretsIface_AccessSecretVersion { +func (_m *GCPSecretManagerClient) OnAccessSecretVersionMatch(matchers ...interface{}) *GCPSecretManagerClient_AccessSecretVersion { c_call := _m.On("AccessSecretVersion", matchers...) - return &GCPSecretsIface_AccessSecretVersion{Call: c_call} + return &GCPSecretManagerClient_AccessSecretVersion{Call: c_call} } // AccessSecretVersion provides a mock function with given fields: ctx, req, opts -func (_m *GCPSecretsIface) AccessSecretVersion(ctx context.Context, req *secretmanagerpb.AccessSecretVersionRequest, opts ...gax.CallOption) (*secretmanagerpb.AccessSecretVersionResponse, error) { +func (_m *GCPSecretManagerClient) AccessSecretVersion(ctx context.Context, req *secretmanagerpb.AccessSecretVersionRequest, opts ...gax.CallOption) (*secretmanagerpb.AccessSecretVersionResponse, error) { _va := make([]interface{}, len(opts)) for _i := range opts { _va[_i] = opts[_i] diff --git a/flytepropeller/pkg/secret/mocks/http_hook_registerer_iface.go b/flytepropeller/pkg/secret/mocks/http_hook_registerer_iface.go deleted file mode 100644 index 6e4db208caa..00000000000 --- a/flytepropeller/pkg/secret/mocks/http_hook_registerer_iface.go +++ /dev/null @@ -1,19 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import ( - http "net/http" - - mock "github.com/stretchr/testify/mock" -) - -// HTTPHookRegistererIface is an autogenerated mock type for the HTTPHookRegistererIface type -type HTTPHookRegistererIface struct { - mock.Mock -} - -// Register provides a mock function with given fields: path, hook -func (_m *HTTPHookRegistererIface) Register(path string, hook http.Handler) { - _m.Called(path, hook) -} diff --git a/flytepropeller/pkg/secret/mocks/secret_fetcher.go b/flytepropeller/pkg/secret/mocks/secret_fetcher.go deleted file mode 100644 index b8671748a72..00000000000 --- a/flytepropeller/pkg/secret/mocks/secret_fetcher.go +++ /dev/null @@ -1,53 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import ( - context "context" - - mock "github.com/stretchr/testify/mock" -) - -// SecretFetcher is an autogenerated mock type for the SecretFetcher type -type SecretFetcher struct { - mock.Mock -} - -type SecretFetcher_Get struct { - *mock.Call -} - -func (_m SecretFetcher_Get) Return(_a0 string, _a1 error) *SecretFetcher_Get { - return &SecretFetcher_Get{Call: _m.Call.Return(_a0, _a1)} -} - -func (_m *SecretFetcher) OnGet(ctx context.Context, key string) *SecretFetcher_Get { - c_call := _m.On("Get", ctx, key) - return &SecretFetcher_Get{Call: c_call} -} - -func (_m *SecretFetcher) OnGetMatch(matchers ...interface{}) *SecretFetcher_Get { - c_call := _m.On("Get", matchers...) - return &SecretFetcher_Get{Call: c_call} -} - -// Get provides a mock function with given fields: ctx, key -func (_m *SecretFetcher) Get(ctx context.Context, key string) (string, error) { - ret := _m.Called(ctx, key) - - var r0 string - if rf, ok := ret.Get(0).(func(context.Context, string) string); ok { - r0 = rf(ctx, key) - } else { - r0 = ret.Get(0).(string) - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, key) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} diff --git a/flytepropeller/pkg/secret/secret_fetcher.go b/flytepropeller/pkg/secret/secret_fetcher.go new file mode 100644 index 00000000000..80270522145 --- /dev/null +++ b/flytepropeller/pkg/secret/secret_fetcher.go @@ -0,0 +1,42 @@ +package secret + +import ( + "context" + "fmt" + + gcpsm "cloud.google.com/go/secretmanager/apiv1" + awsConfig "github.com/aws/aws-sdk-go-v2/config" + awssm "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + + "github.com/flyteorg/flyte/flytepropeller/pkg/secret/config" + "github.com/flyteorg/flyte/flytestdlib/logger" +) + +type SecretFetcher interface { + GetSecretValue(ctx context.Context, secretID string) (*SecretValue, error) +} + +type SecretValue struct { + StringValue string + BinaryValue []byte +} + +func NewSecretFetcher(ctx context.Context, cfg config.EmbeddedSecretManagerConfig) (SecretFetcher, error) { + switch cfg.Type { + case config.EmbeddedSecretManagerTypeAWS: + awsCfg, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithRegion(cfg.AWSConfig.Region)) + if err != nil { + logger.Errorf(ctx, "failed to start secret manager service due to %v", err) + return nil, fmt.Errorf("failed to start secret manager service due to %v", err) + } + return NewAWSSecretFetcher(cfg.AWSConfig, awssm.NewFromConfig(awsCfg)), nil + case config.EmbeddedSecretManagerTypeGCP: + gcpSmClient, err := gcpsm.NewClient(ctx) + if err != nil { + logger.Errorf(ctx, "failed to start secret manager service due to %v", err) + return nil, fmt.Errorf("failed to start secret manager service due to %v", err) + } + return NewGCPSecretFetcher(cfg.GCPConfig, gcpSmClient), nil + } + return nil, fmt.Errorf("failed to start secret fetcher service due to unsupported type %v. Only supported for aws and gcp right now", cfg.Type) +} diff --git a/flytepropeller/pkg/secret/embedded_secret_iface.go b/flytepropeller/pkg/secret/secret_manager_client.go similarity index 62% rename from flytepropeller/pkg/secret/embedded_secret_iface.go rename to flytepropeller/pkg/secret/secret_manager_client.go index 4771c5d426d..84722609636 100644 --- a/flytepropeller/pkg/secret/embedded_secret_iface.go +++ b/flytepropeller/pkg/secret/secret_manager_client.go @@ -8,16 +8,16 @@ import ( "github.com/googleapis/gax-go/v2" ) -//go:generate mockery --output=./mocks --case=underscore -name=AWSSecretsIface +//go:generate mockery --output=./mocks --case=underscore -name=AWSSecretManagerClient -// AWSSecretsIface AWS Secret Manager API interface used in the webhook for looking up the secret to mount on the user pod. -type AWSSecretsIface interface { +// AWSSecretManagerClient AWS Secret Manager API interface used in the webhook for looking up the secret to mount on the user pod. +type AWSSecretManagerClient interface { GetSecretValue(context.Context, *secretsmanager.GetSecretValueInput, ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) } -// GCPSecretsIface GCP Secret Manager API interface used in the webhook for looking up the secret to mount on the user pod. +// GCPSecretManagerClient GCP Secret Manager API interface used in the webhook for looking up the secret to mount on the user pod. // -//go:generate mockery --output=./mocks --case=underscore -name=GCPSecretsIface -type GCPSecretsIface interface { +//go:generate mockery --output=./mocks --case=underscore -name=GCPSecretManagerClient +type GCPSecretManagerClient interface { AccessSecretVersion(ctx context.Context, req *secretmanagerpb.AccessSecretVersionRequest, opts ...gax.CallOption) (*secretmanagerpb.AccessSecretVersionResponse, error) } diff --git a/flytepropeller/pkg/secret/secrets_injector.go b/flytepropeller/pkg/secret/secrets_injector.go new file mode 100644 index 00000000000..24d4d3daba4 --- /dev/null +++ b/flytepropeller/pkg/secret/secrets_injector.go @@ -0,0 +1,46 @@ +package secret + +import ( + "context" + "fmt" + + corev1 "k8s.io/api/core/v1" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/secretmanager" + "github.com/flyteorg/flyte/flytepropeller/pkg/secret/config" +) + +//go:generate mockery --output=./mocks --case=underscore --name=SecretsInjector +type SecretsInjector interface { + Type() config.SecretManagerType + Inject(ctx context.Context, secrets *core.Secret, p *corev1.Pod) (newP *corev1.Pod, injected bool, err error) +} + +func newSecretsInjector( + ctx context.Context, + secretManagerType config.SecretManagerType, + webhookConfig *config.Config, + globalSecretManagerConfig *secretmanager.Config, +) (SecretsInjector, error) { + switch secretManagerType { + case config.SecretManagerTypeGlobal: + return NewGlobalSecrets(secretmanager.NewFileEnvSecretManager(globalSecretManagerConfig)), nil + case config.SecretManagerTypeK8s: + return NewK8sSecretsInjector(), nil + case config.SecretManagerTypeAWS: + return NewAWSSecretManagerInjector(webhookConfig.AWSSecretManagerConfig), nil + case config.SecretManagerTypeGCP: + return NewGCPSecretManagerInjector(webhookConfig.GCPSecretManagerConfig), nil + case config.SecretManagerTypeVault: + return NewVaultSecretManagerInjector(webhookConfig.VaultSecretManagerConfig), nil + case config.SecretManagerTypeEmbedded: + secretFetcher, err := NewSecretFetcher(ctx, webhookConfig.EmbeddedSecretManagerConfig) + if err != nil { + return nil, err + } + return NewEmbeddedSecretManagerInjector(webhookConfig.EmbeddedSecretManagerConfig, secretFetcher), nil + default: + return nil, fmt.Errorf("unrecognized secret manager type [%v]", secretManagerType) + } +} diff --git a/flytepropeller/pkg/secret/secrets.go b/flytepropeller/pkg/secret/secrets_pod_mutator.go similarity index 64% rename from flytepropeller/pkg/secret/secrets.go rename to flytepropeller/pkg/secret/secrets_pod_mutator.go index 59310b727c6..d71c49db548 100644 --- a/flytepropeller/pkg/secret/secrets.go +++ b/flytepropeller/pkg/secret/secrets_pod_mutator.go @@ -25,7 +25,7 @@ const ( SecretsID = "secrets" ) -type SecretsMutator struct { +type SecretsPodMutator struct { // Secret manager types in order that they should be used. enabledSecretManagerTypes []config.SecretManagerType @@ -33,16 +33,11 @@ type SecretsMutator struct { injectors map[config.SecretManagerType]SecretsInjector } -type SecretsInjector interface { - Type() config.SecretManagerType - Inject(ctx context.Context, secrets *core.Secret, p *corev1.Pod) (newP *corev1.Pod, injected bool, err error) -} - -func (s SecretsMutator) ID() string { +func (s SecretsPodMutator) ID() string { return SecretsID } -func (s *SecretsMutator) Mutate(ctx context.Context, pod *corev1.Pod) (newP *corev1.Pod, podChanged bool, errResponse *admission.Response) { +func (s *SecretsPodMutator) Mutate(ctx context.Context, pod *corev1.Pod) (newP *corev1.Pod, podChanged bool, errResponse *admission.Response) { secrets, err := secretUtils.UnmarshalStringMapToSecrets(pod.GetAnnotations()) if err != nil { admissionError := admission.Errored(http.StatusBadRequest, fmt.Errorf("failed to unmarshal secrets from pod annotations: %w", err)) @@ -67,7 +62,7 @@ func (s *SecretsMutator) Mutate(ctx context.Context, pod *corev1.Pod) (newP *cor return pod, len(secrets) > 0, nil } -func (s *SecretsMutator) LabelSelector() *metav1.LabelSelector { +func (s *SecretsPodMutator) LabelSelector() *metav1.LabelSelector { return &metav1.LabelSelector{ MatchLabels: map[string]string{ secretUtils.PodLabel: secretUtils.PodLabelValue, @@ -75,7 +70,7 @@ func (s *SecretsMutator) LabelSelector() *metav1.LabelSelector { } } -func (s *SecretsMutator) injectSecret(ctx context.Context, secret *core.Secret, pod *corev1.Pod) (*corev1.Pod, bool /*injected*/, error) { +func (s *SecretsPodMutator) injectSecret(ctx context.Context, secret *core.Secret, pod *corev1.Pod) (*corev1.Pod, bool /*injected*/, error) { errs := make([]error, 0) logger.Debugf(ctx, "Injecting secret [%v].", secret) @@ -102,7 +97,7 @@ func (s *SecretsMutator) injectSecret(ctx context.Context, secret *core.Secret, } // NewSecretsMutator creates a new SecretsMutator with all available plugins. -func NewSecretsMutator(ctx context.Context, cfg *config.Config, _ promutils.Scope) (*SecretsMutator, error) { +func NewSecretsMutator(ctx context.Context, cfg *config.Config, _ promutils.Scope) (*SecretsPodMutator, error) { enabledSecretManagerTypes := []config.SecretManagerType{ config.SecretManagerTypeGlobal, } @@ -115,43 +110,15 @@ func NewSecretsMutator(ctx context.Context, cfg *config.Config, _ promutils.Scop injectors := make(map[config.SecretManagerType]SecretsInjector, len(enabledSecretManagerTypes)) globalSecretManagerConfig := secretmanager.GetConfig() for _, secretManagerType := range enabledSecretManagerTypes { - injector, err := newSecretManager(ctx, secretManagerType, cfg, globalSecretManagerConfig) + injector, err := newSecretsInjector(ctx, secretManagerType, cfg, globalSecretManagerConfig) if err != nil { return nil, err } injectors[secretManagerType] = injector } - return &SecretsMutator{ + return &SecretsPodMutator{ enabledSecretManagerTypes, injectors, }, nil } - -func newSecretManager( - ctx context.Context, - secretManagerType config.SecretManagerType, - webhookConfig *config.Config, - globalSecretManagerConfig *secretmanager.Config, -) (SecretsInjector, error) { - switch secretManagerType { - case config.SecretManagerTypeGlobal: - return NewGlobalSecrets(secretmanager.NewFileEnvSecretManager(globalSecretManagerConfig)), nil - case config.SecretManagerTypeK8s: - return NewK8sSecretsInjector(), nil - case config.SecretManagerTypeAWS: - return NewAWSSecretManagerInjector(webhookConfig.AWSSecretManagerConfig), nil - case config.SecretManagerTypeGCP: - return NewGCPSecretManagerInjector(webhookConfig.GCPSecretManagerConfig), nil - case config.SecretManagerTypeVault: - return NewVaultSecretManagerInjector(webhookConfig.VaultSecretManagerConfig), nil - case config.SecretManagerTypeEmbedded: - secretFetcher, err := NewSecretFetcherManager(ctx, webhookConfig.EmbeddedSecretManagerConfig) - if err != nil { - return nil, err - } - return NewEmbeddedSecretManagerInjector(webhookConfig.EmbeddedSecretManagerConfig, secretFetcher), nil - default: - return nil, fmt.Errorf("unrecognized secret manager type [%v]", secretManagerType) - } -} diff --git a/flytepropeller/pkg/secret/secrets_test.go b/flytepropeller/pkg/secret/secrets_test.go index 31311356f4c..a1ae4b77ab4 100644 --- a/flytepropeller/pkg/secret/secrets_test.go +++ b/flytepropeller/pkg/secret/secrets_test.go @@ -18,7 +18,7 @@ import ( func TestSecretsWebhook_Mutate(t *testing.T) { t.Run("No injectors", func(t *testing.T) { - m := SecretsMutator{} + m := SecretsPodMutator{} _, changed, err := m.Mutate(context.Background(), &corev1.Pod{}) assert.Nil(t, err) assert.False(t, changed) @@ -37,7 +37,7 @@ func TestSecretsWebhook_Mutate(t *testing.T) { mutator.OnInjectMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil, false, fmt.Errorf("failed")) mutator.OnType().Return(config.SecretManagerTypeGlobal) - m := SecretsMutator{ + m := SecretsPodMutator{ enabledSecretManagerTypes: []config.SecretManagerType{config.SecretManagerTypeGlobal}, injectors: map[config.SecretManagerType]SecretsInjector{ config.SecretManagerTypeGlobal: mutator, @@ -54,7 +54,7 @@ func TestSecretsWebhook_Mutate(t *testing.T) { mutator.OnInjectMatch(mock.Anything, mock.Anything, mock.Anything).Return(&corev1.Pod{}, true, nil) mutator.OnType().Return(config.SecretManagerTypeGlobal) - m := SecretsMutator{ + m := SecretsPodMutator{ enabledSecretManagerTypes: []config.SecretManagerType{config.SecretManagerTypeGlobal}, injectors: map[config.SecretManagerType]SecretsInjector{ config.SecretManagerTypeGlobal: mutator, @@ -68,7 +68,7 @@ func TestSecretsWebhook_Mutate(t *testing.T) { } func TestSecrets_LabelSelector(t *testing.T) { - m := SecretsMutator{} + m := SecretsPodMutator{} expected := metav1.LabelSelector{ MatchLabels: map[string]string{ secretUtils.PodLabel: secretUtils.PodLabelValue, diff --git a/flytepropeller/pkg/secret/utils.go b/flytepropeller/pkg/secret/utils.go index e3fb04105b7..36c054c6ebd 100644 --- a/flytepropeller/pkg/secret/utils.go +++ b/flytepropeller/pkg/secret/utils.go @@ -78,16 +78,19 @@ func AppendVolumeMounts(containers []corev1.Container, mount corev1.VolumeMount) return res } -func AppendEnvVars(containers []corev1.Container, envVar corev1.EnvVar) []corev1.Container { +func AppendEnvVars(containers []corev1.Container, envVars ...corev1.EnvVar) []corev1.Container { res := make([]corev1.Container, 0, len(containers)) for _, c := range containers { - if foundIndex := hasEnvVar(c.Env, envVar.Name); foundIndex >= 0 { - // This would be someone adding a duplicate key to what the webhook is trying to add.We should delete the existing one and then add the new at the beginning - c.Env = append(c.Env[:foundIndex], c.Env[foundIndex+1:]...) + for _, envVar := range envVars { + if foundIndex := hasEnvVar(c.Env, envVar.Name); foundIndex >= 0 { + // This would be someone adding a duplicate key to what the webhook is trying to add. + // We should delete the existing one and then add the new at the beginning + c.Env = append(c.Env[:foundIndex], c.Env[foundIndex+1:]...) + } + + // Append the passed in environment variable to the start of the list. + c.Env = append([]corev1.EnvVar{envVar}, c.Env...) } - // Append the passed in environment variable to the start of the list. - // With multiple calls to this function too, eg : in case of injecting multiple secrets, the same premise holds. - c.Env = append([]corev1.EnvVar{envVar}, c.Env...) res = append(res, c) } diff --git a/flytepropeller/pkg/secret/mocks/pod_mutator.go b/flytepropeller/pkg/webhook/mocks/pod_mutator.go similarity index 100% rename from flytepropeller/pkg/secret/mocks/pod_mutator.go rename to flytepropeller/pkg/webhook/mocks/pod_mutator.go diff --git a/flytepropeller/pkg/webhook/pod.go b/flytepropeller/pkg/webhook/pod.go index 1029761d372..f0527a466f2 100644 --- a/flytepropeller/pkg/webhook/pod.go +++ b/flytepropeller/pkg/webhook/pod.go @@ -97,6 +97,8 @@ type httpHandler struct { path string } +//go:generate mockery --output=./mocks --case=underscore --name=PodMutator + // PodMutator contains the business logic for a unique type of mutation or validation. type PodMutator interface { ID() string diff --git a/flytepropeller/pkg/webhook/pod_test.go b/flytepropeller/pkg/webhook/pod_test.go index c8fd5ca42d5..9d3ca5be476 100644 --- a/flytepropeller/pkg/webhook/pod_test.go +++ b/flytepropeller/pkg/webhook/pod_test.go @@ -18,7 +18,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils/secrets" "github.com/flyteorg/flyte/flytepropeller/pkg/secret" "github.com/flyteorg/flyte/flytepropeller/pkg/secret/config" - secretMocks "github.com/flyteorg/flyte/flytepropeller/pkg/secret/mocks" + "github.com/flyteorg/flyte/flytepropeller/pkg/webhook/mocks" "github.com/flyteorg/flyte/flytestdlib/promutils" ) @@ -89,11 +89,11 @@ func TestPodMutator_Mutate(t *testing.T) { }, } - successMutator := &secretMocks.PodMutator{} + successMutator := &mocks.PodMutator{} successMutator.OnID().Return("SucceedingMutator") successMutator.OnMutateMatch(mock.Anything, mock.Anything).Return(nil, false, nil) - failedMutator := &secretMocks.PodMutator{} + failedMutator := &mocks.PodMutator{} failedMutator.OnID().Return("FailingMutator") admissionError := admission.Errored(http.StatusBadRequest, fmt.Errorf("failing mock")) failedMutator.OnMutateMatch(mock.Anything, mock.Anything).Return(nil, false, &admissionError) @@ -187,7 +187,7 @@ func Test_Register(t *testing.T) { }, latest.Scheme, promutils.NewTestScope()) assert.NoError(t, err) - mockRegister := &secretMocks.HTTPHookRegistererIface{} + mockRegister := &mocks.HTTPHookRegistererIface{} wh := &admission.Webhook{Handler: pm.httpHandlers[0]} mockRegister.On("Register", "/mutate--v1-pod/secrets", wh) err = pm.Register(ctx, mockRegister) @@ -202,7 +202,7 @@ func Test_Register(t *testing.T) { }, latest.Scheme, promutils.NewTestScope()) assert.NoError(t, err) - mockRegister := &secretMocks.HTTPHookRegistererIface{} + mockRegister := &mocks.HTTPHookRegistererIface{} secretWH := &admission.Webhook{Handler: pm.httpHandlers[0]} mockRegister.On("Register", getPodMutatePath(secret.SecretsID), secretWH) imageBuilderWH := &admission.Webhook{Handler: pm.httpHandlers[1]} diff --git a/go.mod b/go.mod index e9fbbe02062..76bf6f58b80 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,11 @@ require ( github.com/flyteorg/flyte/flytepropeller v0.0.0-00010101000000-000000000000 github.com/flyteorg/flyte/flytestdlib v0.0.0-00010101000000-000000000000 github.com/golang/glog v1.2.0 + github.com/mitchellh/mapstructure v1.5.0 github.com/prometheus/client_golang v1.17.0 github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.9.0 golang.org/x/sync v0.7.0 gorm.io/driver/postgres v1.5.3 sigs.k8s.io/controller-runtime v0.16.3 @@ -156,7 +158,6 @@ require ( github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect github.com/mattn/goveralls v0.0.6 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect - github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect @@ -191,7 +192,6 @@ require ( github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/viper v1.11.0 // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/stretchr/testify v1.9.0 // indirect github.com/subosito/gotenv v1.2.0 // indirect github.com/tidwall/gjson v1.17.0 // indirect github.com/tidwall/match v1.1.1 // indirect