diff --git a/flytepropeller/pkg/controller/nodes/node_exec_context.go b/flytepropeller/pkg/controller/nodes/node_exec_context.go index 736540e9632..1321903cfd7 100644 --- a/flytepropeller/pkg/controller/nodes/node_exec_context.go +++ b/flytepropeller/pkg/controller/nodes/node_exec_context.go @@ -3,9 +3,8 @@ package nodes import ( "context" "fmt" - "strconv" - "slices" + "strconv" _struct "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" diff --git a/flytepropeller/pkg/controller/nodes/task/k8s/event_watcher_test.go b/flytepropeller/pkg/controller/nodes/task/k8s/event_watcher_test.go index 53932eef01b..37e4ba11ffa 100644 --- a/flytepropeller/pkg/controller/nodes/task/k8s/event_watcher_test.go +++ b/flytepropeller/pkg/controller/nodes/task/k8s/event_watcher_test.go @@ -1,7 +1,6 @@ package k8s import ( - "k8s.io/client-go/tools/cache" "testing" "time" @@ -10,6 +9,7 @@ import ( eventsv1 "k8s.io/api/events/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/cache" ) func TestEventWatcher_OnAdd(t *testing.T) { diff --git a/flytepropeller/pkg/webhook/config/config.go b/flytepropeller/pkg/webhook/config/config.go index 66904c536e9..840d66d06d1 100644 --- a/flytepropeller/pkg/webhook/config/config.go +++ b/flytepropeller/pkg/webhook/config/config.go @@ -5,6 +5,7 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/flyteorg/flyte/flytestdlib/config" ) @@ -111,6 +112,9 @@ type Config struct { GCPSecretManagerConfig GCPSecretManagerConfig `json:"gcpSecretManager" pflag:",GCP Secret Manager config."` VaultSecretManagerConfig VaultSecretManagerConfig `json:"vaultSecretManager" pflag:",Vault Secret Manager config."` EmbeddedSecretManagerConfig EmbeddedSecretManagerConfig `json:"embeddedSecretManagerConfig" pflag:",Embedded Secret Manager config without sidecar and which calls into the supported providers directly."` + + // Ignore PFlag for Image Builder + ImageBuilderConfig *ImageBuilderConfig `json:"imageBuilderConfig,omitempty" pflag:"-,"` } //go:generate enumer --type=EmbeddedSecretManagerType -json -yaml -trimprefix=EmbeddedSecretManagerType @@ -155,6 +159,17 @@ type VaultSecretManagerConfig struct { Annotations map[string]string `json:"annotations" pflag:"-,Annotation to be added to user task pod. The annotation can also be used to override default annotations added by Flyte. Useful to customize Vault integration (https://developer.hashicorp.com/vault/docs/platform/k8s/injector/annotations)"` } +type HostnameReplacement struct { + Existing string `json:"existing" pflag:",The existing hostname to replace"` + Replacement string `json:"replacement" pflag:",The replacement hostname"` + DisableVerification bool `json:"disableVerification" pflag:",Allow disabling URI verification for development environments"` +} + +type ImageBuilderConfig struct { + HostnameReplacement HostnameReplacement `json:"hostnameReplacement"` + LabelSelector metav1.LabelSelector `json:"labelSelector"` +} + func GetConfig() *Config { return configSection.GetConfig().(*Config) } diff --git a/flytepropeller/pkg/webhook/config/config_test.go b/flytepropeller/pkg/webhook/config/config_test.go new file mode 100644 index 00000000000..c09cb55175f --- /dev/null +++ b/flytepropeller/pkg/webhook/config/config_test.go @@ -0,0 +1,103 @@ +package config + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Ensure HostnameReplacements resolves to non-nil empty list. +func TestConfig_DefaultImageBuilder(t *testing.T) { + assert.Nil(t, DefaultConfig.ImageBuilderConfig) +} + +func TestConfig_LoadSimpleJSON(t *testing.T) { + expectedJSON := `{ + "metrics-prefix": "test-prefix", + "certDir": "/test/cert/dir", + "localCert": true, + "listenPort": 8080, + "serviceName": "test-service", + "servicePort": 8081, + "secretName": "test-secret", + "secretManagerType": "K8s" + }` + + var config Config + err := json.Unmarshal([]byte(expectedJSON), &config) + assert.Nil(t, err) + + assert.Nil(t, config.ImageBuilderConfig) +} + +func TestConfig_ImageBuilderConfig(t *testing.T) { + + t.Run("With verification enabled", func(t *testing.T) { + expectedJSON := `{ + "metrics-prefix": "test-prefix", + "certDir": "/test/cert/dir", + "localCert": true, + "listenPort": 8080, + "serviceName": "test-service", + "servicePort": 8081, + "secretName": "test-secret", + "secretManagerType": "K8s", + "imageBuilderConfig": { + "hostnameReplacement": { + "existing": "test.existing.hostname", + "replacement": "test.replacement.hostname" + }, + "labelSelector": { + "matchLabels": { + "test-key": "test-value" + } + } + } + }` + + var config Config + err := json.Unmarshal([]byte(expectedJSON), &config) + assert.Nil(t, err) + + assert.Equal(t, "test.existing.hostname", config.ImageBuilderConfig.HostnameReplacement.Existing) + assert.Equal(t, "test.replacement.hostname", config.ImageBuilderConfig.HostnameReplacement.Replacement) + assert.Equal(t, false, config.ImageBuilderConfig.HostnameReplacement.DisableVerification) + assert.Equal(t, "test-value", config.ImageBuilderConfig.LabelSelector.MatchLabels["test-key"]) + }) + + t.Run("With verification disabled", func(t *testing.T) { + expectedJSON := `{ + "metrics-prefix": "test-prefix", + "certDir": "/test/cert/dir", + "localCert": true, + "listenPort": 8080, + "serviceName": "test-service", + "servicePort": 8081, + "secretName": "test-secret", + "secretManagerType": "K8s", + "imageBuilderConfig": { + "hostnameReplacement": { + "existing": "test.existing.hostname", + "replacement": "test.replacement.hostname", + "disableVerification": true + }, + "labelSelector": { + "matchLabels": { + "test-key": "test-value" + } + } + } + }` + + var config Config + err := json.Unmarshal([]byte(expectedJSON), &config) + assert.Nil(t, err) + + assert.Equal(t, "test.existing.hostname", config.ImageBuilderConfig.HostnameReplacement.Existing) + assert.Equal(t, "test.replacement.hostname", config.ImageBuilderConfig.HostnameReplacement.Replacement) + assert.Equal(t, true, config.ImageBuilderConfig.HostnameReplacement.DisableVerification) + assert.Equal(t, "test-value", config.ImageBuilderConfig.LabelSelector.MatchLabels["test-key"]) + }) + +} diff --git a/flytepropeller/pkg/webhook/entrypoint.go b/flytepropeller/pkg/webhook/entrypoint.go index 62ebdb5fa88..5085883344b 100644 --- a/flytepropeller/pkg/webhook/entrypoint.go +++ b/flytepropeller/pkg/webhook/entrypoint.go @@ -54,7 +54,7 @@ func RunWebhook(ctx context.Context, propellerCfg *config.Config, cfg *config2.C webhookScope := (*scope).NewSubScope("webhook") - secretsWebhook, err := NewPodMutator(ctx, cfg, mgr.GetScheme(), webhookScope) + secretsWebhook, err := NewPodCreationWebhookConfig(ctx, cfg, mgr.GetScheme(), webhookScope) if err != nil { return err } @@ -65,7 +65,7 @@ func RunWebhook(ctx context.Context, propellerCfg *config.Config, cfg *config2.C return err } - err = secretsWebhook.Register(ctx, mgr) + err = secretsWebhook.Register(ctx, K8sRuntimeHTTPHookRegisterer{mgr: mgr}) if err != nil { logger.Fatalf(ctx, "Failed to register webhook with manager. Error: %v", err) } @@ -76,7 +76,7 @@ func RunWebhook(ctx context.Context, propellerCfg *config.Config, cfg *config2.C return nil } -func createMutationConfig(ctx context.Context, kubeClient *kubernetes.Clientset, webhookObj *PodMutator, defaultNamespace string) error { +func createMutationConfig(ctx context.Context, kubeClient *kubernetes.Clientset, webhookObj *PodCreationWebhookConfig, defaultNamespace string) error { shouldAddOwnerRef := true podName, found := os.LookupEnv(PodNameEnvVar) if !found { diff --git a/flytepropeller/pkg/webhook/http_hook_registerer.go b/flytepropeller/pkg/webhook/http_hook_registerer.go new file mode 100644 index 00000000000..2fc70a69dc2 --- /dev/null +++ b/flytepropeller/pkg/webhook/http_hook_registerer.go @@ -0,0 +1,21 @@ +package webhook + +import ( + "net/http" + + "sigs.k8s.io/controller-runtime/pkg/manager" +) + +//go:generate mockery --output=./mocks --case=underscore -name=HTTPHookRegistererIface + +type HTTPHookRegistererIface interface { + Register(path string, hook http.Handler) +} + +type K8sRuntimeHTTPHookRegisterer struct { + mgr manager.Manager +} + +func (k K8sRuntimeHTTPHookRegisterer) Register(path string, hook http.Handler) { + k.mgr.GetWebhookServer().Register(path, hook) +} diff --git a/flytepropeller/pkg/webhook/image_builder_mutator_v1.go b/flytepropeller/pkg/webhook/image_builder_mutator_v1.go new file mode 100644 index 00000000000..fb7c1a730fe --- /dev/null +++ b/flytepropeller/pkg/webhook/image_builder_mutator_v1.go @@ -0,0 +1,218 @@ +package webhook + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/prometheus/client_golang/prometheus" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + "github.com/flyteorg/flyte/flytepropeller/pkg/webhook/config" + "github.com/flyteorg/flyte/flytestdlib/logger" + "github.com/flyteorg/flyte/flytestdlib/promutils" +) + +const ( + // Convention established with image builder that all orgs images are places under orgs namespace + orgsNamespace = "orgs" + // Convention where cr.union.ai/unionai/image-name map to /union/image-name + unionaiOrgPlaceholder = "unionai" + unionPathReplacement = "union" + // Version 1 URI format part that is expected in URI + version1URIPart = "v1" + // Currently excluding prefix to support versionless URIs + v1CloudTaskPart = "cloud/task" // Build-image container task + v1OrgsPart = "orgs/%s/" // Any of the org's images, orgs/ + v1UnionPublicPart = unionPathReplacement + "/" // Any publicly accessible image + + ImageBuilderV1ID = "image-builder" +) + +type metrics struct { + Scope promutils.Scope + RoundTime promutils.StopWatch + Attempts prometheus.Counter + Failures prometheus.Counter + V1ContainerValidationAttempts prometheus.Counter + V1ContainerValidationFailures prometheus.Counter +} + +type ImageBuilderMutatorV1 struct { + hostnameReplacement config.HostnameReplacement + metrics metrics + labelSelector metav1.LabelSelector + + // Prefixes used as replacement sources + initialUnionPublicPrefixes []string // For union public images + initialOrgPrefixes []string // For Org based images + + // Prefixes to be used for replacing hostnames + targetPublicPrefix string + targetOrgPrefix string + + // Acceptable prefixes is a list of known prefixes + // used for post replacement validation. + acceptablePrefixes []string +} + +func (i ImageBuilderMutatorV1) ID() string { + return ImageBuilderV1ID +} + +func (i *ImageBuilderMutatorV1) Mutate(ctx context.Context, pod *corev1.Pod) (newP *corev1.Pod, podChanged bool, errResponse *admission.Response) { + t := i.metrics.RoundTime.Start() + defer t.Stop() + hr := i.hostnameReplacement + + logger.Debugf(ctx, "Replacing hostname [%v] with [%v] for Pod [%v/%v]", hr.Existing, hr.Replacement, pod.Namespace, pod.Name) + i.metrics.Attempts.Inc() + newContainers, changed, err := i.replaceHostnames(ctx, pod.Name, pod.Namespace, pod.Spec.Containers, hr) + if err != nil { // Failed to replace or validate + logger.Warnf(ctx, "Failed to replace container image names for pod [%v/%v] due to error: %w", pod.Namespace, pod.Name, err) + i.metrics.Failures.Inc() + admissionResponse := admission.Errored(http.StatusForbidden, err) + return nil, false, &admissionResponse + } + pod.Spec.Containers = *newContainers + logger.Debugf(ctx, "Finished replacing hostname [%v] with [%v] for relevant Pod [%v/%v] containers", + hr.Existing, hr.Replacement, pod.Namespace, pod.Name) + return pod, changed, nil +} + +func (i *ImageBuilderMutatorV1) LabelSelector() *metav1.LabelSelector { + return &i.labelSelector +} + +func newMetrics(scope promutils.Scope) metrics { + return metrics{ + Scope: scope, + RoundTime: scope.MustNewStopWatch("round_time", "Time taken to complete a round of image builder mutator", time.Millisecond), + Attempts: scope.MustNewCounter("attempts", "Number of Image Builder webhook mutation attempts"), + Failures: scope.MustNewCounter("failures", "Number of Image Builder webhook mutation failures"), + V1ContainerValidationAttempts: scope.MustNewCounter("v1_container_validation_attempts", "Number of attempts to validate image URI format for a specific container"), + V1ContainerValidationFailures: scope.MustNewCounter("v1_container_validation_failures", "Number of failures to validate image URI format for a specific container"), + } +} + +func replaceByPrefix(c *corev1.Container, prefix string, replacementPrefix string) bool { + originalImage := c.Image + if strings.HasPrefix(originalImage, prefix) { + c.Image = replacementPrefix + strings.TrimPrefix(originalImage, prefix) + return true + } + return false +} + +// Replaces hostnames in the container image names +// Adheres to version 1 URI format, assumes versionless URIs are version 1 and valid. +func (i ImageBuilderMutatorV1) replaceV1Hostname(c *corev1.Container) bool { + for _, prefix := range i.initialUnionPublicPrefixes { + containerChanged := replaceByPrefix(c, prefix, i.targetPublicPrefix) + if containerChanged { + return true + } + } + + // If the image is not a publicly accessible image, assume it is an org based image + for _, prefix := range i.initialOrgPrefixes { + containerChanged := replaceByPrefix(c, prefix, i.targetOrgPrefix) + if containerChanged { + return true + } + } + + return false // Container image name did not match any of the expected prefixes +} + +// Validates hostname replacement to validate against v1 URI format. +// Validates non versioned URI formats as version 1 for backwards compatibility +// Returns true if the image is authorized, false otherwise +func (i ImageBuilderMutatorV1) verifyV1Prefix(ctx context.Context, containerImage string, targetHostname string, podNamespace string) bool { + // Note, theoeretically, Prefix Tree is a better datastructure to use here, but the number of prefixes is small + // Consider using if the number of prefixes grows significantly. + orgPart := fmt.Sprintf(v1OrgsPart, podNamespace) + validPrefixes := []string{ + // Support versionless URIs for V1. This is for backward compatibility with pre-versioned URI formats. + fmt.Sprintf("%s/%s", targetHostname, orgPart), + fmt.Sprintf("%s/%s/%s", targetHostname, version1URIPart, orgPart), + } + validPrefixes = append(validPrefixes, i.acceptablePrefixes...) + + i.metrics.V1ContainerValidationAttempts.Inc() + // Must match one of the valid prefixes + for _, prefix := range validPrefixes { + if strings.HasPrefix(containerImage, prefix) { + return true + } + } + logger.Warnf(ctx, "Container Image name %s forbidden", containerImage) + i.metrics.V1ContainerValidationFailures.Inc() + return false +} + +func (i ImageBuilderMutatorV1) replaceHostnames(ctx context.Context, podName string, podNameSpace string, containers []corev1.Container, hr config.HostnameReplacement) (newContainers *[]corev1.Container, anyContainerChanged bool, err error) { + anyContainerChanged = false + res := make([]corev1.Container, 0, len(containers)) + + for ic := range containers { + container := containers[ic] + label := fmt.Sprintf("Pod [%v/%v] container [%v]", podNameSpace, podName, container.Name) + logger.Debugf(ctx, "%v - Replacing hostname [%v] with [%v]", label, hr.Existing, hr.Replacement) + originalImage := container.Image + + containerChanged := i.replaceV1Hostname(&container) + if !i.hostnameReplacement.DisableVerification { + canUse := i.verifyV1Prefix(ctx, container.Image, hr.Replacement, podNameSpace) + if !canUse { + logger.Warnf(ctx, "%v - Original image name [%v] resulted in unauthorized image name [%v]", label, originalImage, container.Image) + return nil, false, fmt.Errorf("access to %s is not authorized", originalImage) + } + } + + if containerChanged { + logger.Debugf(ctx, "%v - Replaced hostname [%v] with [%v]", label, hr.Existing, hr.Replacement) + } else { + logger.Debugf(ctx, "%v - Not replacing hostname in [%v] with [%v]", label, hr.Replacement, hr.Existing) + } + anyContainerChanged = anyContainerChanged || containerChanged + res = append(res, container) + } + return &res, anyContainerChanged, nil +} + +func NewImageBuilderMutator(hostnameReplacement config.HostnameReplacement, labelSelector metav1.LabelSelector, scope promutils.Scope) *ImageBuilderMutatorV1 { + parts := []string{v1CloudTaskPart, v1UnionPublicPart} + validPrefixes := make([]string, len(parts)*2) + for ip, part := range parts { + // Support pre-versioned URI formats for backward compatibility + validPrefixes[ip] = fmt.Sprintf("%s/%s", hostnameReplacement.Replacement, part) + // Support version 1 URI formats + validPrefixes[ip+len(parts)] = fmt.Sprintf("%s/%s/%s", hostnameReplacement.Replacement, version1URIPart, part) + } + + return &ImageBuilderMutatorV1{ + hostnameReplacement: hostnameReplacement, + metrics: newMetrics(scope), + labelSelector: labelSelector, + initialUnionPublicPrefixes: []string{ + // Versioned URI formats first + fmt.Sprintf("%s/%s/%s/", hostnameReplacement.Existing, version1URIPart, unionaiOrgPlaceholder), + // Legacy non-versioned URI formats + fmt.Sprintf("%s/%s/", hostnameReplacement.Existing, unionaiOrgPlaceholder), + }, + initialOrgPrefixes: []string{ + // Versioned URI formats first + fmt.Sprintf("%s/%s", hostnameReplacement.Existing, version1URIPart), + // Legacy non-versioned URI formats + hostnameReplacement.Existing, + }, + targetPublicPrefix: fmt.Sprintf("%s/%s/", hostnameReplacement.Replacement, unionPathReplacement), + targetOrgPrefix: fmt.Sprintf("%s/%s", hostnameReplacement.Replacement, orgsNamespace), + acceptablePrefixes: validPrefixes, + } +} diff --git a/flytepropeller/pkg/webhook/image_builder_mutator_v1_test.go b/flytepropeller/pkg/webhook/image_builder_mutator_v1_test.go new file mode 100644 index 00000000000..ed9a9f0cc53 --- /dev/null +++ b/flytepropeller/pkg/webhook/image_builder_mutator_v1_test.go @@ -0,0 +1,416 @@ +package webhook + +import ( + "context" + "fmt" + "net/http" + "testing" + + promtestutil "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/flyteorg/flyte/flytepropeller/pkg/webhook/config" + "github.com/flyteorg/flyte/flytestdlib/promutils" +) + +const ( + org = "test-org" + otherOrg = "test-other-org" + existingHostname = "test.original.hostname" + replacementHostname = "test.replacement.hostname" +) + +var ( + differentImage = fmt.Sprintf("%s/orgs/%s/other-image", replacementHostname, org) + unusedLabelSelector = metav1.LabelSelector{} + invalidImageNames = []string{ + // Orgs that share similar prefixes and suffixes + fmt.Sprintf("%s/%s-suffix/image", existingHostname, org), + fmt.Sprintf("%s/prefix-%s/image", existingHostname, org), + fmt.Sprintf("%s/%s/image", existingHostname, otherOrg), + } + replacedInvalidImageNames = []string{ + fmt.Sprintf("%s/orgs/%s-suffix/image", replacementHostname, org), + fmt.Sprintf("%s/orgs/prefix-%s/image", replacementHostname, org), + fmt.Sprintf("%s/orgs/%s/image", replacementHostname, otherOrg), + } +) + +func defaultTestImageBuilderMutator() *ImageBuilderMutatorV1 { + return NewImageBuilderMutator(config.HostnameReplacement{ + Existing: existingHostname, + Replacement: replacementHostname, + }, unusedLabelSelector, promutils.NewTestScope()) +} + +func TestImageBuilderWebhook_Mutate(t *testing.T) { + + t.Run("Valid hostnames", func(t *testing.T) { + + validImageNames := []string{ + // URL Paths without version. Backwards compatible with older + // versions of unionai SDK + // Build-image container task + fmt.Sprintf("%s/cloud/task", replacementHostname), + // Any image in users org + fmt.Sprintf("%s/orgs/%s/image", replacementHostname, org), + fmt.Sprintf("%s/orgs/%s/other-image", replacementHostname, org), + // Any image in union publicly accessible namespace + fmt.Sprintf("%s/union/image", replacementHostname), + fmt.Sprintf("%s/union/other-image", replacementHostname), + + // Version 1 URI paths + fmt.Sprintf("%s/%s/cloud/task", replacementHostname, version1URIPart), + // Any image in users org + fmt.Sprintf("%s/%s/orgs/%s/image", replacementHostname, version1URIPart, org), + fmt.Sprintf("%s/%s/orgs/%s/other-image", replacementHostname, version1URIPart, org), + // Any image in union publicly accessible namespace + fmt.Sprintf("%s/%s/union/image", replacementHostname, version1URIPart), + fmt.Sprintf("%s/%s/union/other-image", replacementHostname, version1URIPart), + } + for _, validImageName := range validImageNames { + m := defaultTestImageBuilderMutator() + pod := corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: org, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "testcontainer1", + Image: validImageName, + }, + }, + }, + } + returnedPod, changed, err := m.Mutate(context.Background(), &pod) + assert.Nil(t, err) + assert.False(t, changed, fmt.Sprintf("Expected no change for %s", validImageName)) + assert.Equal(t, &pod, returnedPod, fmt.Sprintf("Expected no Pod differences for %s", validImageName)) + assert.Equal(t, validImageName, returnedPod.Spec.Containers[0].Image, fmt.Sprintf("Expected image name to be left at %s", validImageName)) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Attempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.Failures))) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationAttempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationFailures))) + } + }) + + t.Run("Replaces hostname for pre-versioned union public images", func(t *testing.T) { + m := defaultTestImageBuilderMutator() + existingImage := fmt.Sprintf("%s/%s/image", existingHostname, unionaiOrgPlaceholder) + expectedImage := fmt.Sprintf("%s/%s/image", replacementHostname, unionPathReplacement) + otherImage := differentImage + pod := corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: org, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "testcontainer1", + Image: existingImage, + }, + { + Name: "testcontainer2", + Image: otherImage, + }, + }, + }, + } + returnedPod, changed, err := m.Mutate(context.Background(), &pod) + assert.Nil(t, err) + assert.True(t, changed) + assert.Equal(t, &pod, returnedPod) + assert.Equal(t, expectedImage, returnedPod.Spec.Containers[0].Image) + assert.Equal(t, otherImage, returnedPod.Spec.Containers[1].Image) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Attempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.Failures))) + assert.Equal(t, 2, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationAttempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationFailures))) + }) + + t.Run("Replaces hostname for versioned union public images", func(t *testing.T) { + m := defaultTestImageBuilderMutator() + existingImage := fmt.Sprintf("%s/%s/%s/image", existingHostname, version1URIPart, unionaiOrgPlaceholder) + expectedImage := fmt.Sprintf("%s/%s/image", replacementHostname, unionPathReplacement) + otherImage := differentImage + pod := corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: org, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "testcontainer1", + Image: existingImage, + }, + { + Name: "testcontainer2", + Image: otherImage, + }, + }, + }, + } + returnedPod, changed, err := m.Mutate(context.Background(), &pod) + assert.Nil(t, err) + assert.True(t, changed) + assert.Equal(t, &pod, returnedPod) + assert.Equal(t, expectedImage, returnedPod.Spec.Containers[0].Image) + assert.Equal(t, otherImage, returnedPod.Spec.Containers[1].Image) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Attempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.Failures))) + assert.Equal(t, 2, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationAttempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationFailures))) + }) + + t.Run("Replaces hostname for pre-versioned org specific image", func(t *testing.T) { + m := defaultTestImageBuilderMutator() + existingImage := fmt.Sprintf("%s/%s/image", existingHostname, org) + expectedImage := fmt.Sprintf("%s/orgs/%s/image", replacementHostname, org) + otherImage := differentImage + pod := corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: org, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "testcontainer1", + Image: existingImage, + }, + { + Name: "testcontainer2", + Image: otherImage, + }, + }, + }, + } + returnedPod, changed, err := m.Mutate(context.Background(), &pod) + assert.Nil(t, err) + assert.True(t, changed) + assert.Equal(t, &pod, returnedPod) + assert.Equal(t, expectedImage, returnedPod.Spec.Containers[0].Image) + assert.Equal(t, otherImage, returnedPod.Spec.Containers[1].Image) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Attempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.Failures))) + assert.Equal(t, 2, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationAttempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationFailures))) + }) + + t.Run("Replaces hostname for versioned org specific image", func(t *testing.T) { + m := defaultTestImageBuilderMutator() + existingImage := fmt.Sprintf("%s/%s/%s/image", existingHostname, version1URIPart, org) + expectedImage := fmt.Sprintf("%s/orgs/%s/image", replacementHostname, org) + otherImage := differentImage + pod := corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: org, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "testcontainer1", + Image: existingImage, + }, + { + Name: "testcontainer2", + Image: otherImage, + }, + }, + }, + } + returnedPod, changed, err := m.Mutate(context.Background(), &pod) + assert.Nil(t, err) + assert.True(t, changed) + assert.Equal(t, &pod, returnedPod) + assert.Equal(t, expectedImage, returnedPod.Spec.Containers[0].Image) + assert.Equal(t, otherImage, returnedPod.Spec.Containers[1].Image) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Attempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.Failures))) + assert.Equal(t, 2, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationAttempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationFailures))) + }) + + t.Run("Replaces multiple hostname match occurrence", func(t *testing.T) { + m := defaultTestImageBuilderMutator() + originalImages := make([]string, 3) + expectedImages := make([]string, 3) + for i := 0; i < 3; i++ { + originalImages[i] = fmt.Sprintf("%s/%s/image-%d", existingHostname, org, i) + expectedImages[i] = fmt.Sprintf("%s/orgs/%s/image-%d", replacementHostname, org, i) + } + + otherImage := differentImage + pod := corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: org, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "testcontainer1", + Image: originalImages[0], + }, + { + Name: "testcontainer2", + Image: originalImages[1], + }, + { + Name: "testcontainer3", + Image: originalImages[2], + }, + { + Name: "testunchangedcontainer", + Image: otherImage, + }, + }, + }, + } + returnedPod, changed, err := m.Mutate(context.Background(), &pod) + assert.Nil(t, err) + assert.True(t, changed) + assert.Equal(t, &pod, returnedPod) + for i := 0; i < 3; i++ { + assert.Equal(t, expectedImages[i], returnedPod.Spec.Containers[i].Image) + } + assert.Equal(t, otherImage, returnedPod.Spec.Containers[3].Image) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Attempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.Failures))) + assert.Equal(t, 4, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationAttempts))) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationFailures))) + }) + + t.Run("Rejects different organization", func(t *testing.T) { + otherOrg := "other-org" + m := defaultTestImageBuilderMutator() + originalImage := fmt.Sprintf("%s/%s/image", otherOrg, existingHostname) + pod := corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: org, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "testcontainer1", + Image: originalImage, + }, + }, + }, + } + returnedPod, changed, err := m.Mutate(context.Background(), &pod) + assert.NotNil(t, err) + assert.Equal(t, http.StatusForbidden, int(err.Result.Code)) + assert.Equal(t, fmt.Sprintf("access to %s is not authorized", originalImage), err.Result.Message) + assert.False(t, changed) + assert.Nil(t, returnedPod) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Attempts))) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Failures))) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationAttempts))) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationFailures))) + }) + + t.Run("Rejects unauthorized paths", func(t *testing.T) { + for _, invalidImageName := range invalidImageNames { + m := defaultTestImageBuilderMutator() + pod := corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: org, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "testcontainer1", + Image: invalidImageName, + }, + }, + }, + } + returnedPod, changed, err := m.Mutate(context.Background(), &pod) + assert.NotNil(t, err, fmt.Sprintf("Expected error for %s", invalidImageName)) + assert.Equal(t, http.StatusForbidden, int(err.Result.Code), fmt.Sprintf("Expected forbidden error code for %s", invalidImageName)) + assert.Equal(t, fmt.Sprintf("access to %s is not authorized", invalidImageName), err.Result.Message) + assert.False(t, changed, fmt.Sprintf("Expected no change for %s", invalidImageName)) + assert.Nil(t, returnedPod, fmt.Sprintf("Expected nil returned pod pointer for %s", invalidImageName)) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Attempts))) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Failures))) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationAttempts))) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationFailures))) + } + }) + + t.Run("Skips verification for invalid URI paths", func(t *testing.T) { + for i, unVerifiedImageName := range invalidImageNames { + m := NewImageBuilderMutator(config.HostnameReplacement{ + Existing: existingHostname, + Replacement: replacementHostname, + DisableVerification: true, + }, unusedLabelSelector, promutils.NewTestScope()) + pod := corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: org, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "testcontainer1", + Image: unVerifiedImageName, + }, + }, + }, + } + returnedPod, changed, err := m.Mutate(context.Background(), &pod) + assert.Nil(t, err, fmt.Sprintf("Expected no error for %s", unVerifiedImageName)) + assert.Equal(t, &pod, returnedPod, fmt.Sprintf("Expected no Pod address change for %s", unVerifiedImageName)) + assert.Equal(t, replacedInvalidImageNames[i], returnedPod.Spec.Containers[0].Image) + assert.True(t, changed, fmt.Sprintf("Expected image still changed for %s", unVerifiedImageName)) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Attempts)), fmt.Sprintf("Expected 1 attempt for %s", unVerifiedImageName)) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.Failures)), fmt.Sprintf("Expected 0 failures for %s", unVerifiedImageName)) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationAttempts)), fmt.Sprintf("Expected 1 attempt for %s", unVerifiedImageName)) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationFailures)), fmt.Sprintf("Expected 0 failures for %s", unVerifiedImageName)) + } + }) + + t.Run("Skips verification for different host", func(t *testing.T) { + // Pod uses a hostname different from hostnameReplacement.Existing + otherHostname := "test.other.hostname" + otherHostImageNames := []string{ + fmt.Sprintf("%s/cloud/task", otherHostname), + fmt.Sprintf("%s/orgs/%s/image", otherHostname, org), + fmt.Sprintf("%s/union/image", otherHostname), + fmt.Sprintf("%s/%s/cloud/task", otherHostname, version1URIPart), + fmt.Sprintf("%s/%s/orgs/%s/image", otherHostname, version1URIPart, org), + fmt.Sprintf("%s/%s/union/image", otherHostname, version1URIPart), + } + for _, unVerifiedImageName := range otherHostImageNames { + m := NewImageBuilderMutator(config.HostnameReplacement{ + Existing: existingHostname, + Replacement: replacementHostname, + DisableVerification: true, + }, unusedLabelSelector, promutils.NewTestScope()) + pod := corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: org, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "testcontainer1", + Image: unVerifiedImageName, + }, + }, + }, + } + returnedPod, changed, err := m.Mutate(context.Background(), &pod) + assert.Nil(t, err, fmt.Sprintf("Expected no error for %s", unVerifiedImageName)) + assert.Equal(t, &pod, returnedPod, fmt.Sprintf("Expected no Pod address change for %s", unVerifiedImageName)) + assert.Equal(t, unVerifiedImageName, returnedPod.Spec.Containers[0].Image) + assert.False(t, changed, fmt.Sprintf("Expected no change for %s", unVerifiedImageName)) + assert.Equal(t, 1, int(promtestutil.ToFloat64(m.metrics.Attempts)), fmt.Sprintf("Expected 1 attempt for %s", unVerifiedImageName)) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.Failures)), fmt.Sprintf("Expected 0 failures for %s", unVerifiedImageName)) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationAttempts)), fmt.Sprintf("Expected 1 attempt for %s", unVerifiedImageName)) + assert.Equal(t, 0, int(promtestutil.ToFloat64(m.metrics.V1ContainerValidationFailures)), fmt.Sprintf("Expected 0 failures for %s", unVerifiedImageName)) + } + }) +} diff --git a/flytepropeller/pkg/webhook/mocks/http_hook_registerer_iface.go b/flytepropeller/pkg/webhook/mocks/http_hook_registerer_iface.go new file mode 100644 index 00000000000..6e4db208caa --- /dev/null +++ b/flytepropeller/pkg/webhook/mocks/http_hook_registerer_iface.go @@ -0,0 +1,19 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + http "net/http" + + mock "github.com/stretchr/testify/mock" +) + +// HTTPHookRegistererIface is an autogenerated mock type for the HTTPHookRegistererIface type +type HTTPHookRegistererIface struct { + mock.Mock +} + +// Register provides a mock function with given fields: path, hook +func (_m *HTTPHookRegistererIface) Register(path string, hook http.Handler) { + _m.Called(path, hook) +} diff --git a/flytepropeller/pkg/webhook/mocks/mutator.go b/flytepropeller/pkg/webhook/mocks/mutator.go deleted file mode 100644 index 5376621f46c..00000000000 --- a/flytepropeller/pkg/webhook/mocks/mutator.go +++ /dev/null @@ -1,95 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import ( - context "context" - - mock "github.com/stretchr/testify/mock" - v1 "k8s.io/api/core/v1" -) - -// Mutator is an autogenerated mock type for the Mutator type -type Mutator struct { - mock.Mock -} - -type Mutator_ID struct { - *mock.Call -} - -func (_m Mutator_ID) Return(_a0 string) *Mutator_ID { - return &Mutator_ID{Call: _m.Call.Return(_a0)} -} - -func (_m *Mutator) OnID() *Mutator_ID { - c_call := _m.On("ID") - return &Mutator_ID{Call: c_call} -} - -func (_m *Mutator) OnIDMatch(matchers ...interface{}) *Mutator_ID { - c_call := _m.On("ID", matchers...) - return &Mutator_ID{Call: c_call} -} - -// ID provides a mock function with given fields: -func (_m *Mutator) ID() string { - ret := _m.Called() - - var r0 string - if rf, ok := ret.Get(0).(func() string); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(string) - } - - return r0 -} - -type Mutator_Mutate struct { - *mock.Call -} - -func (_m Mutator_Mutate) Return(newP *v1.Pod, changed bool, err error) *Mutator_Mutate { - return &Mutator_Mutate{Call: _m.Call.Return(newP, changed, err)} -} - -func (_m *Mutator) OnMutate(ctx context.Context, p *v1.Pod) *Mutator_Mutate { - c_call := _m.On("Mutate", ctx, p) - return &Mutator_Mutate{Call: c_call} -} - -func (_m *Mutator) OnMutateMatch(matchers ...interface{}) *Mutator_Mutate { - c_call := _m.On("Mutate", matchers...) - return &Mutator_Mutate{Call: c_call} -} - -// Mutate provides a mock function with given fields: ctx, p -func (_m *Mutator) Mutate(ctx context.Context, p *v1.Pod) (*v1.Pod, bool, error) { - ret := _m.Called(ctx, p) - - var r0 *v1.Pod - if rf, ok := ret.Get(0).(func(context.Context, *v1.Pod) *v1.Pod); ok { - r0 = rf(ctx, p) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*v1.Pod) - } - } - - var r1 bool - if rf, ok := ret.Get(1).(func(context.Context, *v1.Pod) bool); ok { - r1 = rf(ctx, p) - } else { - r1 = ret.Get(1).(bool) - } - - var r2 error - if rf, ok := ret.Get(2).(func(context.Context, *v1.Pod) error); ok { - r2 = rf(ctx, p) - } else { - r2 = ret.Error(2) - } - - return r0, r1, r2 -} diff --git a/flytepropeller/pkg/webhook/mocks/pod_mutator.go b/flytepropeller/pkg/webhook/mocks/pod_mutator.go new file mode 100644 index 00000000000..d898d015463 --- /dev/null +++ b/flytepropeller/pkg/webhook/mocks/pod_mutator.go @@ -0,0 +1,136 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + admission "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + + corev1 "k8s.io/api/core/v1" + + mock "github.com/stretchr/testify/mock" + + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// PodMutator is an autogenerated mock type for the PodMutator type +type PodMutator struct { + mock.Mock +} + +type PodMutator_ID struct { + *mock.Call +} + +func (_m PodMutator_ID) Return(_a0 string) *PodMutator_ID { + return &PodMutator_ID{Call: _m.Call.Return(_a0)} +} + +func (_m *PodMutator) OnID() *PodMutator_ID { + c_call := _m.On("ID") + return &PodMutator_ID{Call: c_call} +} + +func (_m *PodMutator) OnIDMatch(matchers ...interface{}) *PodMutator_ID { + c_call := _m.On("ID", matchers...) + return &PodMutator_ID{Call: c_call} +} + +// ID provides a mock function with given fields: +func (_m *PodMutator) ID() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +type PodMutator_LabelSelector struct { + *mock.Call +} + +func (_m PodMutator_LabelSelector) Return(_a0 *v1.LabelSelector) *PodMutator_LabelSelector { + return &PodMutator_LabelSelector{Call: _m.Call.Return(_a0)} +} + +func (_m *PodMutator) OnLabelSelector() *PodMutator_LabelSelector { + c_call := _m.On("LabelSelector") + return &PodMutator_LabelSelector{Call: c_call} +} + +func (_m *PodMutator) OnLabelSelectorMatch(matchers ...interface{}) *PodMutator_LabelSelector { + c_call := _m.On("LabelSelector", matchers...) + return &PodMutator_LabelSelector{Call: c_call} +} + +// LabelSelector provides a mock function with given fields: +func (_m *PodMutator) LabelSelector() *v1.LabelSelector { + ret := _m.Called() + + var r0 *v1.LabelSelector + if rf, ok := ret.Get(0).(func() *v1.LabelSelector); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.LabelSelector) + } + } + + return r0 +} + +type PodMutator_Mutate struct { + *mock.Call +} + +func (_m PodMutator_Mutate) Return(newP *corev1.Pod, changed bool, err *admission.Response) *PodMutator_Mutate { + return &PodMutator_Mutate{Call: _m.Call.Return(newP, changed, err)} +} + +func (_m *PodMutator) OnMutate(ctx context.Context, p *corev1.Pod) *PodMutator_Mutate { + c_call := _m.On("Mutate", ctx, p) + return &PodMutator_Mutate{Call: c_call} +} + +func (_m *PodMutator) OnMutateMatch(matchers ...interface{}) *PodMutator_Mutate { + c_call := _m.On("Mutate", matchers...) + return &PodMutator_Mutate{Call: c_call} +} + +// Mutate provides a mock function with given fields: ctx, p +func (_m *PodMutator) Mutate(ctx context.Context, p *corev1.Pod) (*corev1.Pod, bool, *admission.Response) { + ret := _m.Called(ctx, p) + + var r0 *corev1.Pod + if rf, ok := ret.Get(0).(func(context.Context, *corev1.Pod) *corev1.Pod); ok { + r0 = rf(ctx, p) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*corev1.Pod) + } + } + + var r1 bool + if rf, ok := ret.Get(1).(func(context.Context, *corev1.Pod) bool); ok { + r1 = rf(ctx, p) + } else { + r1 = ret.Get(1).(bool) + } + + var r2 *admission.Response + if rf, ok := ret.Get(2).(func(context.Context, *corev1.Pod) *admission.Response); ok { + r2 = rf(ctx, p) + } else { + if ret.Get(2) != nil { + r2 = ret.Get(2).(*admission.Response) + } + } + + return r0, r1, r2 +} diff --git a/flytepropeller/pkg/webhook/pod.go b/flytepropeller/pkg/webhook/pod.go index 52b48e02f1f..6a8f01458e3 100644 --- a/flytepropeller/pkg/webhook/pod.go +++ b/flytepropeller/pkg/webhook/pod.go @@ -41,46 +41,81 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" - "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils/secrets" "github.com/flyteorg/flyte/flytepropeller/pkg/webhook/config" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" ) -const webhookName = "flyte-pod-webhook.flyte.org" +const ( + secretsWebhookName = "flyte-pod-webhook.flyte.org" // #nosec G101 + unionWebhookNameDomain = "union.ai" +) -// PodMutator implements controller-runtime WebHook interface. -type PodMutator struct { - decoder *admission.Decoder - cfg *config.Config - Mutators []MutatorConfig +var ( + admissionRegistrationRules = []admissionregistrationv1.RuleWithOperations{ + { + Operations: []admissionregistrationv1.OperationType{ + admissionregistrationv1.Create, + }, + Rule: admissionregistrationv1.Rule{ + APIGroups: []string{"*"}, + APIVersions: []string{"v1"}, + Resources: []string{"pods"}, + }, + }, + } + admissionRegistrationVersions = []string{ + "v1", + "v1beta1", + } + admissionRegistrationFailurePolicy = admissionregistrationv1.Fail + admissionRegistrationSideEffects = admissionregistrationv1.SideEffectClassNoneOnDryRun +) + +// PodCreationWebhookConfig maps one to one to Kubernetes MutatingWebhookConfiguration +// but specifically tagetting Pod creation. Kubernetes MutatingWebhookConfiguration supports +// multiple webhooks. This class is responsible for converting Union specific configuration into a +// Kubernetes MutatingWebhookConfiguration with potentially multiple webhooks. +type PodCreationWebhookConfig struct { + cfg *config.Config + httpHandlers []httpHandler + caBytes []byte } -type MutatorConfig struct { - Mutator Mutator - Required bool +// Internal struct type to consolidate handling of the HTTP layer. +// Every http handler shares the same decode, encoding logic but has a different Mutator. +type httpHandler struct { + decoder *admission.Decoder + mutator PodMutator + // The unique name of the Mutating Webhook + mutatingWebhookName string + // The complete URI Path to register webhook with. + path string } -type Mutator interface { +// PodMutator contains the business logic for a unique type of mutation or validation. +type PodMutator interface { ID() string - Mutate(ctx context.Context, p *corev1.Pod) (newP *corev1.Pod, changed bool, err error) + // Conducts the act of mutating the pod. + Mutate(ctx context.Context, p *corev1.Pod) (newP *corev1.Pod, changed bool, err *admission.Response) + // Defines how to select which Pods to apply the webhook to. + LabelSelector() *metav1.LabelSelector } -func (pm PodMutator) Handle(ctx context.Context, request admission.Request) admission.Response { +func (h httpHandler) Handle(ctx context.Context, request admission.Request) admission.Response { // Get the object in the request obj := &corev1.Pod{} - err := pm.decoder.Decode(request, obj) + err := h.decoder.Decode(request, obj) if err != nil { return admission.Errored(http.StatusBadRequest, err) } - newObj, changed, err := pm.Mutate(ctx, obj) - if err != nil { - return admission.Errored(http.StatusBadRequest, err) + newObj, changed, admissionError := h.mutator.Mutate(ctx, obj) + if admissionError != nil { + return *admissionError } if changed { @@ -96,59 +131,44 @@ func (pm PodMutator) Handle(ctx context.Context, request admission.Request) admi return admission.Allowed("No changes") } -func (pm PodMutator) Mutate(ctx context.Context, p *corev1.Pod) (newP *corev1.Pod, changed bool, err error) { - newP = p - for _, m := range pm.Mutators { - tempP := newP - tempChanged := false - tempP, tempChanged, err = m.Mutator.Mutate(ctx, tempP) - if err != nil { - if m.Required { - err = fmt.Errorf("failed to mutate using [%v]. Since it's a required mutator, failing early. Error: %v", m.Mutator.ID(), err) - logger.Info(ctx, err) - return p, false, err - } - - logger.Infof(ctx, "Failed to mutate using [%v]. Since it's not a required mutator, skipping. Error: %v", m.Mutator.ID(), err) - continue - } - - newP = tempP - if tempChanged { - changed = true - } +func (pm PodCreationWebhookConfig) Register(ctx context.Context, registerer HTTPHookRegistererIface) error { + for _, httpHandler := range pm.httpHandlers { + wh := &admission.Webhook{Handler: httpHandler} + logger.Infof(ctx, "Registering path [%v]", httpHandler.path) + registerer.Register(httpHandler.path, wh) } - - return newP, changed, nil -} - -func (pm PodMutator) Register(ctx context.Context, mgr manager.Manager) error { - wh := &admission.Webhook{ - Handler: pm, - } - - mutatePath := getPodMutatePath() - logger.Infof(ctx, "Registering path [%v]", mutatePath) - mgr.GetWebhookServer().Register(mutatePath, wh) return nil } -func (pm PodMutator) GetMutatePath() string { - return getPodMutatePath() -} - -func getPodMutatePath() string { +func getPodMutatePath(subpath string) string { pod := flytek8s.BuildIdentityPod() - return generateMutatePath(pod.GroupVersionKind()) + return generateMutatePath(pod.GroupVersionKind(), subpath) } -func generateMutatePath(gvk schema.GroupVersionKind) string { +func generateMutatePath(gvk schema.GroupVersionKind, subpath string) string { return "/mutate-" + strings.Replace(gvk.Group, ".", "-", -1) + "-" + - gvk.Version + "-" + strings.ToLower(gvk.Kind) + gvk.Version + "-" + strings.ToLower(gvk.Kind) + "/" + subpath +} + +func (pm PodCreationWebhookConfig) CreateMutationWebhookConfiguration(namespace string) (*admissionregistrationv1.MutatingWebhookConfiguration, error) { + webhooks := make([]admissionregistrationv1.MutatingWebhook, 0, len(pm.httpHandlers)) + for _, httpHandler := range pm.httpHandlers { + webhooks = append(webhooks, pm.getMutatingWebhook(namespace, httpHandler)) + } + + mutateConfig := &admissionregistrationv1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: pm.cfg.ServiceName, + Namespace: namespace, + }, + Webhooks: webhooks, + } + + return mutateConfig, nil } -func (pm PodMutator) CreateMutationWebhookConfiguration(namespace string) (*admissionregistrationv1.MutatingWebhookConfiguration, error) { - caBytes, err := os.ReadFile(filepath.Join(pm.cfg.ExpandCertDir(), "ca.crt")) +func NewPodCreationWebhookConfig(ctx context.Context, cfg *config.Config, scheme *runtime.Scheme, scope promutils.Scope) (*PodCreationWebhookConfig, error) { + caBytes, err := os.ReadFile(filepath.Join(cfg.ExpandCertDir(), "ca.crt")) if err != nil { // ca.crt is optional. If not provided, API Server will assume the webhook is serving SSL using a certificate // issued by a known Cert Authority. @@ -159,71 +179,97 @@ func (pm PodMutator) CreateMutationWebhookConfiguration(namespace string) (*admi } } - path := pm.GetMutatePath() - fail := admissionregistrationv1.Fail - sideEffects := admissionregistrationv1.SideEffectClassNoneOnDryRun + secretsMutator, err := NewSecretsMutator(ctx, cfg, scope.NewSubScope("secrets")) + if err != nil { + return nil, err + } - mutateConfig := &admissionregistrationv1.MutatingWebhookConfiguration{ - ObjectMeta: metav1.ObjectMeta{ - Name: pm.cfg.ServiceName, - Namespace: namespace, - }, + decoder := admission.NewDecoder(scheme) - Webhooks: []admissionregistrationv1.MutatingWebhook{ - { - Name: webhookName, - ClientConfig: admissionregistrationv1.WebhookClientConfig{ - CABundle: caBytes, // CA bundle created earlier - Service: &admissionregistrationv1.ServiceReference{ - Name: pm.cfg.ServiceName, - Namespace: namespace, - Path: &path, - Port: &pm.cfg.ServicePort, - }, - }, - Rules: []admissionregistrationv1.RuleWithOperations{ - { - Operations: []admissionregistrationv1.OperationType{ - admissionregistrationv1.Create, - }, - Rule: admissionregistrationv1.Rule{ - APIGroups: []string{"*"}, - APIVersions: []string{"v1"}, - Resources: []string{"pods"}, - }, - }, - }, - FailurePolicy: &fail, - SideEffects: &sideEffects, - AdmissionReviewVersions: []string{ - "v1", - "v1beta1", - }, - ObjectSelector: &metav1.LabelSelector{ - MatchLabels: map[string]string{ - secrets.PodLabel: secrets.PodLabelValue, - }, - }, - }}, + httpHandlers := []httpHandler{ + { + decoder: decoder, + mutator: secretsMutator, + mutatingWebhookName: secretsWebhookName, + path: getPodMutatePath(secretsMutator.ID()), + }, } - return mutateConfig, nil -} + if cfg.ImageBuilderConfig != nil { + imageBuilderMutator := NewImageBuilderMutator(cfg.ImageBuilderConfig.HostnameReplacement, cfg.ImageBuilderConfig.LabelSelector, scope.NewSubScope("image-builder")) + httpHandlers = append(httpHandlers, httpHandler{ + decoder: decoder, + mutator: imageBuilderMutator, + mutatingWebhookName: getMutatingWebhookName(imageBuilderMutator.ID()), + path: getPodMutatePath(imageBuilderMutator.ID()), + }) + } -func NewPodMutator(ctx context.Context, cfg *config.Config, scheme *runtime.Scheme, scope promutils.Scope) (*PodMutator, error) { - secretsMutator, err := NewSecretsMutator(ctx, cfg, scope.NewSubScope("secrets")) + err = verifyHTTPHandlers(ctx, httpHandlers) if err != nil { return nil, err } - return &PodMutator{ - decoder: admission.NewDecoder(scheme), - cfg: cfg, - Mutators: []MutatorConfig{ - { - Mutator: secretsMutator, - Required: true, + return &PodCreationWebhookConfig{ + cfg: cfg, + httpHandlers: httpHandlers, + caBytes: caBytes, + }, nil +} + +func getMutatingWebhookName(id string) string { + return fmt.Sprintf("%s-webhook.%s", id, unionWebhookNameDomain) +} + +func (pm PodCreationWebhookConfig) getMutatingWebhook(namespace string, httpHandler httpHandler) admissionregistrationv1.MutatingWebhook { + return admissionregistrationv1.MutatingWebhook{ + Name: httpHandler.mutatingWebhookName, + ClientConfig: admissionregistrationv1.WebhookClientConfig{ + CABundle: pm.caBytes, // CA bundle created earlier + Service: &admissionregistrationv1.ServiceReference{ + Name: pm.cfg.ServiceName, + Namespace: namespace, + Path: &httpHandler.path, + Port: &pm.cfg.ServicePort, }, }, - }, nil + Rules: admissionRegistrationRules, + FailurePolicy: &admissionRegistrationFailurePolicy, + SideEffects: &admissionRegistrationSideEffects, + AdmissionReviewVersions: admissionRegistrationVersions, + ObjectSelector: httpHandler.mutator.LabelSelector(), + } +} + +// Verify that there aren't any duplicate webhook names or URI paths. +func verifyHTTPHandlers(ctx context.Context, httpHandlers []httpHandler) error { + webhookNameOccurrences := make(map[string]bool) + pathOccurrences := make(map[string]bool) + + duplicateWebhookNames := make([]string, 0) + duplicatePaths := make([]string, 0) + + for _, handler := range httpHandlers { + if webhookNameOccurrences[handler.mutatingWebhookName] { + duplicateWebhookNames = append(duplicateWebhookNames, handler.mutatingWebhookName) + } + webhookNameOccurrences[handler.mutatingWebhookName] = true + if pathOccurrences[handler.path] { + duplicatePaths = append(duplicatePaths, handler.path) + } + pathOccurrences[handler.path] = true + } + + e := "" + if len(duplicateWebhookNames) > 0 { + e += fmt.Sprintf("Duplicate webhook names found: [%v]. ", strings.Join(duplicateWebhookNames, ",")) + } + if len(duplicatePaths) > 0 { + e += fmt.Sprintf("Duplicate paths found: [%v]", strings.Join(duplicatePaths, ",")) + } + if len(e) > 0 { + logger.Errorf(ctx, "Invalid webhook configuration: %v", e) + return fmt.Errorf("Invalid webhook configuration: %v", e) + } + return nil } diff --git a/flytepropeller/pkg/webhook/pod_test.go b/flytepropeller/pkg/webhook/pod_test.go index c99fd522ab6..ae22dadeab0 100644 --- a/flytepropeller/pkg/webhook/pod_test.go +++ b/flytepropeller/pkg/webhook/pod_test.go @@ -3,21 +3,80 @@ package webhook import ( "context" "fmt" + "net/http" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" admissionv1 "k8s.io/api/admission/v1" corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/client-go/tools/clientcmd/api/latest" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils/secrets" "github.com/flyteorg/flyte/flytepropeller/pkg/webhook/config" "github.com/flyteorg/flyte/flytepropeller/pkg/webhook/mocks" "github.com/flyteorg/flyte/flytestdlib/promutils" ) +var ( + expectedSecretsLabelSelector = metav1.LabelSelector{ + MatchLabels: map[string]string{ + secrets.PodLabel: secrets.PodLabelValue, + }, + } + expectedImageBuilderLabelSelector = metav1.LabelSelector{ + MatchLabels: map[string]string{ + "test-arbitrary-label": "test-arbitrary-value", + }, + } + testImageBuilderConfig = config.ImageBuilderConfig{ + HostnameReplacement: config.HostnameReplacement{ + Existing: "test.existing.hostname", + Replacement: "test.hostname", + }, + LabelSelector: expectedImageBuilderLabelSelector, + } +) + +func TestNewPodCreationWebhookConfig_NewPodCreationWebhookConfig(t *testing.T) { + + t.Run("Defaults with Secrets", func(t *testing.T) { + ctx := context.Background() + + pm, err := NewPodCreationWebhookConfig(ctx, &config.Config{ + CertDir: "testdata", + ServiceName: "my-service", + }, latest.Scheme, promutils.NewTestScope()) + + assert.NoError(t, err) + assert.NotNil(t, pm) + assert.Equal(t, 1, len(pm.httpHandlers)) + secretsHTTPHandler := pm.httpHandlers[0] + assert.Equal(t, expectedSecretsLabelSelector, *secretsHTTPHandler.mutator.LabelSelector()) + }) + + t.Run("With additional Image Builder config", func(t *testing.T) { + ctx := context.Background() + + pm, err := NewPodCreationWebhookConfig(ctx, &config.Config{ + CertDir: "testdata", + ServiceName: "my-service", + ImageBuilderConfig: &testImageBuilderConfig, + }, latest.Scheme, promutils.NewTestScope()) + + assert.NoError(t, err) + assert.NotNil(t, pm) + assert.Equal(t, 2, len(pm.httpHandlers)) + secretsHTTPHandler := pm.httpHandlers[0] + assert.Equal(t, expectedSecretsLabelSelector, *secretsHTTPHandler.mutator.LabelSelector()) + imageBuilderHTTPHandler := pm.httpHandlers[1] + assert.Equal(t, expectedImageBuilderLabelSelector, *imageBuilderHTTPHandler.mutator.LabelSelector()) + }) +} + func TestPodMutator_Mutate(t *testing.T) { inputPod := &corev1.Pod{ Spec: corev1.PodSpec{ @@ -29,69 +88,52 @@ func TestPodMutator_Mutate(t *testing.T) { }, } - successMutator := &mocks.Mutator{} + successMutator := &mocks.PodMutator{} successMutator.OnID().Return("SucceedingMutator") successMutator.OnMutateMatch(mock.Anything, mock.Anything).Return(nil, false, nil) - failedMutator := &mocks.Mutator{} + failedMutator := &mocks.PodMutator{} failedMutator.OnID().Return("FailingMutator") - failedMutator.OnMutateMatch(mock.Anything, mock.Anything).Return(nil, false, fmt.Errorf("failing mock")) + admissionError := admission.Errored(http.StatusBadRequest, fmt.Errorf("failing mock")) + failedMutator.OnMutateMatch(mock.Anything, mock.Anything).Return(nil, false, &admissionError) t.Run("Required Mutator Succeeded", func(t *testing.T) { - pm := &PodMutator{ - Mutators: []MutatorConfig{ - { - Mutator: successMutator, - Required: true, - }, - }, - } ctx := context.Background() - _, changed, err := pm.Mutate(ctx, inputPod.DeepCopy()) - assert.NoError(t, err) + _, changed, err := successMutator.Mutate(ctx, inputPod.DeepCopy()) + assert.Nil(t, err) assert.False(t, changed) }) t.Run("Required Mutator Failed", func(t *testing.T) { - pm := &PodMutator{ - Mutators: []MutatorConfig{ - { - Mutator: failedMutator, - Required: true, - }, - }, - } - ctx := context.Background() - _, _, err := pm.Mutate(ctx, inputPod.DeepCopy()) - assert.Error(t, err) - }) - - t.Run("Non-required Mutator Failed", func(t *testing.T) { - pm := &PodMutator{ - Mutators: []MutatorConfig{ - { - Mutator: failedMutator, - Required: false, - }, - }, - } ctx := context.Background() - _, _, err := pm.Mutate(ctx, inputPod.DeepCopy()) - assert.NoError(t, err) + _, _, err := failedMutator.Mutate(ctx, inputPod.DeepCopy()) + assert.NotNil(t, err) }) } func Test_CreateMutationWebhookConfiguration(t *testing.T) { ctx := context.Background() - pm, err := NewPodMutator(ctx, &config.Config{ + serviceName := "test-service" + pm, err := NewPodCreationWebhookConfig(ctx, &config.Config{ CertDir: "testdata", - ServiceName: "my-service", + ServiceName: serviceName, }, latest.Scheme, promutils.NewTestScope()) assert.NoError(t, err) + t.Run("Empty namespace", func(t *testing.T) { - c, err := pm.CreateMutationWebhookConfiguration("") + namespace := "" + c, err := pm.CreateMutationWebhookConfiguration(namespace) assert.NoError(t, err) assert.NotNil(t, c) + + assert.Equal(t, 1, len(c.Webhooks)) + assert.Equal(t, "flyte-pod-webhook.flyte.org", c.Webhooks[0].Name) + assert.Equal(t, serviceName, c.Webhooks[0].ClientConfig.Service.Name) + assert.Equal(t, namespace, c.Webhooks[0].ClientConfig.Service.Namespace) + assert.Equal(t, getPodMutatePath(SecretsID), *c.Webhooks[0].ClientConfig.Service.Path) + assert.Equal(t, 1, len(c.Webhooks[0].ObjectSelector.MatchLabels)) + assert.Equal(t, 0, len(c.Webhooks[0].ObjectSelector.DeepCopy().MatchExpressions)) + assert.Equal(t, expectedSecretsLabelSelector, *c.Webhooks[0].ObjectSelector) }) t.Run("With namespace", func(t *testing.T) { @@ -99,56 +141,130 @@ func Test_CreateMutationWebhookConfiguration(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, c) }) + + t.Run("With image builder", func(t *testing.T) { + pm, err := NewPodCreationWebhookConfig(ctx, &config.Config{ + CertDir: "testdata", + ServiceName: serviceName, + ImageBuilderConfig: &testImageBuilderConfig, + }, latest.Scheme, promutils.NewTestScope()) + assert.NoError(t, err) + namespace := "test-namespace" + c, err := pm.CreateMutationWebhookConfiguration(namespace) + assert.NoError(t, err) + assert.NotNil(t, c) + + assert.Equal(t, 2, len(c.Webhooks)) + assert.Equal(t, secretsWebhookName, c.Webhooks[0].Name) + assert.Equal(t, serviceName, c.Webhooks[0].ClientConfig.Service.Name) + assert.Equal(t, namespace, c.Webhooks[0].ClientConfig.Service.Namespace) + assert.Equal(t, getPodMutatePath(SecretsID), *c.Webhooks[0].ClientConfig.Service.Path) + assert.Equal(t, 1, len(c.Webhooks[0].ObjectSelector.MatchLabels)) + assert.Equal(t, 0, len(c.Webhooks[0].ObjectSelector.DeepCopy().MatchExpressions)) + assert.Equal(t, expectedSecretsLabelSelector, *c.Webhooks[0].ObjectSelector) + + assert.Equal(t, getMutatingWebhookName(ImageBuilderV1ID), c.Webhooks[1].Name) + assert.Equal(t, serviceName, c.Webhooks[1].ClientConfig.Service.Name) + assert.Equal(t, namespace, c.Webhooks[1].ClientConfig.Service.Namespace) + assert.Equal(t, getPodMutatePath(ImageBuilderV1ID), *c.Webhooks[1].ClientConfig.Service.Path) + assert.Equal(t, expectedImageBuilderLabelSelector, *c.Webhooks[1].ObjectSelector) + }) +} + +func Test_GetMutatePath(t *testing.T) { + assert.Equal(t, "/mutate--v1-pod/secrets", getPodMutatePath(SecretsID)) + assert.Equal(t, "/mutate--v1-pod/image-builder", getPodMutatePath(ImageBuilderV1ID)) } -func Test_Handle(t *testing.T) { +func Test_Register(t *testing.T) { ctx := context.Background() - pm, err := NewPodMutator(ctx, &config.Config{ - CertDir: "testdata", - ServiceName: "my-service", + + t.Run("Defaults", func(t *testing.T) { + pm, err := NewPodCreationWebhookConfig(context.Background(), &config.Config{ + CertDir: "testdata", + ServiceName: "my-service", + }, latest.Scheme, promutils.NewTestScope()) + assert.NoError(t, err) + + mockRegister := &mocks.HTTPHookRegistererIface{} + wh := &admission.Webhook{Handler: pm.httpHandlers[0]} + mockRegister.On("Register", "/mutate--v1-pod/secrets", wh) + err = pm.Register(ctx, mockRegister) + assert.Nil(t, err) + }) + + t.Run("With Image Builder", func(t *testing.T) { + pm, err := NewPodCreationWebhookConfig(context.Background(), &config.Config{ + CertDir: "testdata", + ServiceName: "my-service", + ImageBuilderConfig: &testImageBuilderConfig, + }, latest.Scheme, promutils.NewTestScope()) + assert.NoError(t, err) + + mockRegister := &mocks.HTTPHookRegistererIface{} + secretWH := &admission.Webhook{Handler: pm.httpHandlers[0]} + mockRegister.On("Register", getPodMutatePath(SecretsID), secretWH) + imageBuilderWH := &admission.Webhook{Handler: pm.httpHandlers[1]} + mockRegister.On("Register", getPodMutatePath(ImageBuilderV1ID), imageBuilderWH) + + err = pm.Register(ctx, mockRegister) + assert.Nil(t, err) + }) +} + +func Test_MutatorConfigHandle(t *testing.T) { + ctx := context.Background() + pm, err := NewPodCreationWebhookConfig(ctx, &config.Config{ + CertDir: "testdata", + ServiceName: "my-service", + ImageBuilderConfig: &testImageBuilderConfig, }, latest.Scheme, promutils.NewTestScope()) assert.NoError(t, err) + req := admission.Request{ AdmissionRequest: admissionv1.AdmissionRequest{ Object: runtime.RawExtension{ Raw: []byte(`{ - "apiVersion": "v1", - "kind": "Pod", - "metadata": { - "name": "foo", - "namespace": "default" - }, - "spec": { - "containers": [ - { - "image": "bar:v2", - "name": "bar" - } - ] - } + "apiVersion": "v1", + "kind": "Pod", + "metadata": { + "name": "foo", + "namespace": "default" + }, + "spec": { + "containers": [ + { + "image": "test.hostname/orgs/default/bar:v2", + "name": "bar" + } + ] + } }`), }, OldObject: runtime.RawExtension{ Raw: []byte(`{ - "apiVersion": "v1", - "kind": "Pod", - "metadata": { - "name": "foo", - "namespace": "default" - }, - "spec": { - "containers": [ - { - "image": "bar:v1", - "name": "bar" - } - ] - } + "apiVersion": "v1", + "kind": "Pod", + "metadata": { + "name": "foo", + "namespace": "default" + }, + "spec": { + "containers": [ + { + "image": "test.hostname/orgs/default/bar:v1", + "name": "bar" + } + ] + } }`), }, }, } - resp := pm.Handle(context.Background(), req) - assert.True(t, resp.Allowed) + assert.Equal(t, 2, len(pm.httpHandlers)) + for _, mutator := range pm.httpHandlers { + resp := mutator.Handle(context.Background(), req) + assert.True(t, resp.Allowed) + } } diff --git a/flytepropeller/pkg/webhook/secrets.go b/flytepropeller/pkg/webhook/secrets.go index 8aae5dc70c1..8ebeb336667 100644 --- a/flytepropeller/pkg/webhook/secrets.go +++ b/flytepropeller/pkg/webhook/secrets.go @@ -4,8 +4,11 @@ import ( "context" "errors" "fmt" + "net/http" corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" secretUtils "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils/secrets" @@ -19,6 +22,7 @@ const ( SecretPathDefaultDirEnvVar = "FLYTE_SECRETS_DEFAULT_DIR" // #nosec SecretPathFilePrefixEnvVar = "FLYTE_SECRETS_FILE_PREFIX" // #nosec SecretEnvVarPrefix = "FLYTE_SECRETS_ENV_PREFIX" // #nosec + SecretsID = "secrets" ) type SecretsMutator struct { @@ -35,13 +39,14 @@ type SecretsInjector interface { } func (s SecretsMutator) ID() string { - return "secrets" + return SecretsID } -func (s *SecretsMutator) Mutate(ctx context.Context, pod *corev1.Pod) (*corev1.Pod, bool /* injected */, error) { +func (s *SecretsMutator) Mutate(ctx context.Context, pod *corev1.Pod) (newP *corev1.Pod, podChanged bool, errResponse *admission.Response) { secrets, err := secretUtils.UnmarshalStringMapToSecrets(pod.GetAnnotations()) if err != nil { - return pod, false, fmt.Errorf("failed to unmarshal secrets from pod annotations: %w", err) + admissionError := admission.Errored(http.StatusBadRequest, fmt.Errorf("failed to unmarshal secrets from pod annotations: %w", err)) + return pod, false, &admissionError } for _, secret := range secrets { @@ -52,7 +57,8 @@ func (s *SecretsMutator) Mutate(ctx context.Context, pod *corev1.Pod) (*corev1.P } else { err = fmt.Errorf("none of the secret managers injected secret [%v]: %w", secret, err) } - return pod, false, err + admissionError := admission.Errored(http.StatusBadRequest, err) + return pod, false, &admissionError } pod = mutatedPod @@ -61,6 +67,14 @@ func (s *SecretsMutator) Mutate(ctx context.Context, pod *corev1.Pod) (*corev1.P return pod, len(secrets) > 0, nil } +func (s *SecretsMutator) LabelSelector() *metav1.LabelSelector { + return &metav1.LabelSelector{ + MatchLabels: map[string]string{ + secretUtils.PodLabel: secretUtils.PodLabelValue, + }, + } +} + func (s *SecretsMutator) injectSecret(ctx context.Context, secret *core.Secret, pod *corev1.Pod) (*corev1.Pod, bool /*injected*/, error) { errs := make([]error, 0) diff --git a/flytepropeller/pkg/webhook/secrets_test.go b/flytepropeller/pkg/webhook/secrets_test.go index 5b1f4f4e5cd..991256f8332 100644 --- a/flytepropeller/pkg/webhook/secrets_test.go +++ b/flytepropeller/pkg/webhook/secrets_test.go @@ -8,8 +8,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + secretUtils "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils/secrets" "github.com/flyteorg/flyte/flytepropeller/pkg/webhook/config" "github.com/flyteorg/flyte/flytepropeller/pkg/webhook/mocks" ) @@ -18,7 +20,7 @@ func TestSecretsWebhook_Mutate(t *testing.T) { t.Run("No injectors", func(t *testing.T) { m := SecretsMutator{} _, changed, err := m.Mutate(context.Background(), &corev1.Pod{}) - assert.NoError(t, err) + assert.Nil(t, err) assert.False(t, changed) }) @@ -43,7 +45,7 @@ func TestSecretsWebhook_Mutate(t *testing.T) { } _, changed, err := m.Mutate(context.Background(), podWithAnnotations.DeepCopy()) - assert.Error(t, err) + assert.NotNil(t, err) assert.False(t, changed) }) @@ -60,7 +62,17 @@ func TestSecretsWebhook_Mutate(t *testing.T) { } _, changed, err := m.Mutate(context.Background(), podWithAnnotations.DeepCopy()) - assert.NoError(t, err) + assert.Nil(t, err) assert.True(t, changed) }) } + +func TestSecrets_LabelSelector(t *testing.T) { + m := SecretsMutator{} + expected := metav1.LabelSelector{ + MatchLabels: map[string]string{ + secretUtils.PodLabel: secretUtils.PodLabelValue, + }, + } + assert.Equal(t, expected, *m.LabelSelector()) +}