diff --git a/cmd/controller/cmd/init_certs.go b/cmd/controller/cmd/init_certs.go index 8848176d0e..f39edac1cd 100644 --- a/cmd/controller/cmd/init_certs.go +++ b/cmd/controller/cmd/init_certs.go @@ -1,42 +1,17 @@ package cmd import ( - "bytes" "context" - cryptorand "crypto/rand" - - "github.com/flyteorg/flytepropeller/pkg/webhook" - - webhookConfig "github.com/flyteorg/flytepropeller/pkg/webhook/config" - - "github.com/flyteorg/flytestdlib/logger" - kubeErrors "k8s.io/apimachinery/pkg/api/errors" "github.com/flyteorg/flytepropeller/pkg/controller/config" - "github.com/flyteorg/flytepropeller/pkg/utils" - corev1 "k8s.io/api/core/v1" - v12 "k8s.io/apimachinery/pkg/apis/meta/v1" - v1 "k8s.io/client-go/kubernetes/typed/core/v1" + webhookConfig "github.com/flyteorg/flytepropeller/pkg/webhook/config" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "fmt" - "math/big" - "os" - "time" + "github.com/flyteorg/flytepropeller/pkg/webhook" "github.com/spf13/cobra" ) -const ( - CaCertKey = "ca.crt" - ServerCertKey = "tls.crt" - ServerCertPrivateKey = "tls.key" -) - // initCertsCmd initializes x509 TLS Certificates and saves them to a secret. var initCertsCmd = &cobra.Command{ Use: "init-certs", @@ -58,207 +33,10 @@ for the Webhook command to mount and read correctly. `, Example: "flytepropeller webhook init-certs", RunE: func(cmd *cobra.Command, args []string) error { - return runCertsCmd(context.Background(), config.GetConfig(), webhookConfig.GetConfig()) + return webhook.InitCerts(context.Background(), config.GetConfig(), webhookConfig.GetConfig()) }, } -type webhookCerts struct { - // base64 Encoded CA Cert - CaPEM *bytes.Buffer - // base64 Encoded Server Cert - ServerPEM *bytes.Buffer - // base64 Encoded Server Cert Key - PrivateKeyPEM *bytes.Buffer -} - func init() { webhookCmd.AddCommand(initCertsCmd) } - -func runCertsCmd(ctx context.Context, propellerCfg *config.Config, cfg *webhookConfig.Config) error { - podNamespace, found := os.LookupEnv(webhook.PodNamespaceEnvVar) - if !found { - podNamespace = podDefaultNamespace - } - - logger.Infof(ctx, "Issuing certs") - certs, err := createCerts(podNamespace) - if err != nil { - return err - } - - kubeClient, _, err := utils.GetKubeConfig(ctx, propellerCfg) - if err != nil { - return err - } - - logger.Infof(ctx, "Creating secret [%v] in Namespace [%v]", cfg.SecretName, podNamespace) - err = createWebhookSecret(ctx, podNamespace, cfg, certs, kubeClient.CoreV1().Secrets(podNamespace)) - if err != nil { - return err - } - - return nil -} - -func createWebhookSecret(ctx context.Context, namespace string, cfg *webhookConfig.Config, certs webhookCerts, secretsClient v1.SecretInterface) error { - isImmutable := true - secretData := map[string][]byte{ - CaCertKey: certs.CaPEM.Bytes(), - ServerCertKey: certs.ServerPEM.Bytes(), - ServerCertPrivateKey: certs.PrivateKeyPEM.Bytes(), - } - - secret := &corev1.Secret{ - ObjectMeta: v12.ObjectMeta{ - Name: cfg.SecretName, - Namespace: namespace, - }, - Type: corev1.SecretTypeOpaque, - Data: secretData, - Immutable: &isImmutable, - } - - _, err := secretsClient.Create(ctx, secret, v12.CreateOptions{}) - if err == nil { - logger.Infof(ctx, "Created secret [%v]", cfg.SecretName) - return nil - } - - if kubeErrors.IsAlreadyExists(err) { - logger.Infof(ctx, "A secret already exists with the same name. Validating.") - s, err := secretsClient.Get(ctx, cfg.SecretName, v12.GetOptions{}) - if err != nil { - return err - } - - // If ServerCertKey or ServerCertPrivateKey are missing, update - requiresUpdate := false - for key := range secretData { - if key == CaCertKey { - continue - } - - if _, exists := s.Data[key]; !exists { - requiresUpdate = true - break - } - } - - if requiresUpdate { - logger.Infof(ctx, "The existing secret is missing one or more keys.") - secret.Annotations = map[string]string{ - "flyteLastUpdate": "system-updated", - "flyteUpdatedAt": time.Now().String(), - } - - _, err = secretsClient.Update(ctx, secret, v12.UpdateOptions{}) - if err != nil && kubeErrors.IsConflict(err) { - logger.Infof(ctx, "Another instance of flyteadmin has updated the same secret. Ignoring this update") - err = nil - } - - return err - } - - return nil - } - - return err -} - -func createCerts(serviceNamespace string) (certs webhookCerts, err error) { - // CA config - caRequest := &x509.Certificate{ - SerialNumber: big.NewInt(2021), - Subject: pkix.Name{ - Organization: []string{"flyte.org"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(1, 0, 0), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - - // CA private key - caPrivateKey, err := rsa.GenerateKey(cryptorand.Reader, 4096) - if err != nil { - return webhookCerts{}, err - } - - // Self signed CA certificate - caCert, err := x509.CreateCertificate(cryptorand.Reader, caRequest, caRequest, &caPrivateKey.PublicKey, caPrivateKey) - if err != nil { - return webhookCerts{}, err - } - - // PEM encode CA cert - caPEM := new(bytes.Buffer) - err = pem.Encode(caPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: caCert, - }) - if err != nil { - return webhookCerts{}, err - } - - dnsNames := []string{"flyte-pod-webhook", - "flyte-pod-webhook." + serviceNamespace, "flyte-pod-webhook." + serviceNamespace + ".svc"} - commonName := "flyte-pod-webhook." + serviceNamespace + ".svc" - - // server cert config - certRequest := &x509.Certificate{ - DNSNames: dnsNames, - SerialNumber: big.NewInt(1658), - Subject: pkix.Name{ - CommonName: commonName, - Organization: []string{"flyte.org"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(1, 0, 0), - SubjectKeyId: []byte{1, 2, 3, 4, 6}, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature, - } - - // server private key - serverPrivateKey, err := rsa.GenerateKey(cryptorand.Reader, 4096) - if err != nil { - return webhookCerts{}, err - } - - // sign the server cert - cert, err := x509.CreateCertificate(cryptorand.Reader, certRequest, caRequest, &serverPrivateKey.PublicKey, caPrivateKey) - if err != nil { - return webhookCerts{}, err - } - - // PEM encode the server cert and key - serverCertPEM := new(bytes.Buffer) - err = pem.Encode(serverCertPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: cert, - }) - - if err != nil { - return webhookCerts{}, fmt.Errorf("failed to Encode CertPEM. Error: %w", err) - } - - serverPrivKeyPEM := new(bytes.Buffer) - err = pem.Encode(serverPrivKeyPEM, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(serverPrivateKey), - }) - - if err != nil { - return webhookCerts{}, fmt.Errorf("failed to Encode Cert Private Key. Error: %w", err) - } - - return webhookCerts{ - CaPEM: caPEM, - ServerPEM: serverCertPEM, - PrivateKeyPEM: serverPrivKeyPEM, - }, nil -} diff --git a/cmd/controller/cmd/root.go b/cmd/controller/cmd/root.go index cb83526d4b..85b21f83e0 100644 --- a/cmd/controller/cmd/root.go +++ b/cmd/controller/cmd/root.go @@ -8,6 +8,8 @@ import ( "runtime" "github.com/flyteorg/flytestdlib/profutils" + "github.com/flyteorg/flytestdlib/promutils" + "golang.org/x/sync/errgroup" "k8s.io/klog" config2 "github.com/flyteorg/flytepropeller/pkg/controller/config" @@ -103,12 +105,38 @@ func executeRootCmd(baseCtx context.Context, cfg *config2.Config) error { // set up signals so we handle the first shutdown signal gracefully ctx := signals.SetupSignalHandler(baseCtx) - go func() { - err := profutils.StartProfilingServerWithDefaultHandlers(ctx, cfg.ProfilerPort.Port, nil) + // Add the propeller subscope because the MetricsPrefix only has "flyte:" to get uniform collection of metrics. + propellerScope := promutils.NewScope(cfg.MetricsPrefix).NewSubScope("propeller").NewSubScope(cfg.LimitNamespace) + mgr, err := controller.CreateControllerManager(ctx, cfg, defaultNamespace, &propellerScope) + if err != nil { + logger.Fatalf(ctx, "Failed to create controller manager. Error: %v", err) + return err + } + + g, childCtx := errgroup.WithContext(ctx) + g.Go(func() error { + err := profutils.StartProfilingServerWithDefaultHandlers(childCtx, cfg.ProfilerPort.Port, nil) + if err != nil { + logger.Fatalf(childCtx, "Failed to Start profiling and metrics server. Error: %v", err) + } + return err + }) + + g.Go(func() error { + err := controller.StartControllerManager(childCtx, mgr) + if err != nil { + logger.Fatalf(childCtx, "Failed to start controller manager. Error: %v", err) + } + return err + }) + + g.Go(func() error { + err := controller.StartController(childCtx, cfg, defaultNamespace, mgr, &propellerScope) if err != nil { - logger.Fatalf(ctx, "Failed to Start profiling and metrics server. Error: %v", err) + logger.Fatalf(childCtx, "Failed to start controller. Error: %v", err) } - }() + return err + }) - return controller.StartController(ctx, cfg, defaultNamespace) + return g.Wait() } diff --git a/cmd/controller/cmd/webhook.go b/cmd/controller/cmd/webhook.go index 40d94cab99..ff8b81ea11 100644 --- a/cmd/controller/cmd/webhook.go +++ b/cmd/controller/cmd/webhook.go @@ -3,6 +3,10 @@ package cmd import ( "context" + "github.com/flyteorg/flytepropeller/pkg/controller" + "github.com/flyteorg/flytestdlib/promutils" + "golang.org/x/sync/errgroup" + webhookConfig "github.com/flyteorg/flytepropeller/pkg/webhook/config" "github.com/flyteorg/flytestdlib/profutils" @@ -13,10 +17,6 @@ import ( "github.com/spf13/cobra" ) -const ( - podDefaultNamespace = "default" -) - var webhookCmd = &cobra.Command{ Use: "webhook", Aliases: []string{"webhooks"}, @@ -82,11 +82,37 @@ func runWebhook(origContext context.Context, propellerCfg *config.Config, cfg *w // set up signals so we handle the first shutdown signal gracefully ctx := signals.SetupSignalHandler(origContext) - go func() { - err := profutils.StartProfilingServerWithDefaultHandlers(ctx, propellerCfg.ProfilerPort.Port, nil) + propellerScope := promutils.NewScope(cfg.MetricsPrefix).NewSubScope("propeller").NewSubScope(propellerCfg.LimitNamespace) + mgr, err := controller.CreateControllerManager(ctx, propellerCfg, defaultNamespace, &propellerScope) + if err != nil { + logger.Fatalf(ctx, "Failed to create controller manager. Error: %v", err) + return err + } + + g, childCtx := errgroup.WithContext(ctx) + g.Go(func() error { + err := profutils.StartProfilingServerWithDefaultHandlers(childCtx, propellerCfg.ProfilerPort.Port, nil) if err != nil { - logger.Panicf(ctx, "Failed to Start profiling and metrics server. Error: %v", err) + logger.Fatalf(childCtx, "Failed to Start profiling and metrics server. Error: %v", err) } - }() - return webhook.Run(ctx, propellerCfg, cfg, defaultNamespace) + return err + }) + + g.Go(func() error { + err := controller.StartControllerManager(childCtx, mgr) + if err != nil { + logger.Fatalf(childCtx, "Failed to start controller manager. Error: %v", err) + } + return err + }) + + g.Go(func() error { + err := webhook.Run(childCtx, propellerCfg, cfg, defaultNamespace, &propellerScope, mgr) + if err != nil { + logger.Fatalf(childCtx, "Failed to start webhook. Error: %v", err) + } + return err + }) + + return g.Wait() } diff --git a/go.mod b/go.mod index b1b120a76a..17332f0169 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/spf13/cobra v1.1.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.7.0 + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba golang.org/x/tools v0.1.10 // indirect google.golang.org/grpc v1.36.0 diff --git a/go.sum b/go.sum index 36a82886b5..5401a8f887 100644 --- a/go.sum +++ b/go.sum @@ -937,6 +937,7 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go b/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go index b4ebb861fe..d4485535d5 100644 --- a/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go +++ b/pkg/apis/flyteworkflow/v1alpha1/zz_generated.deepcopy.go @@ -1,3 +1,4 @@ +//go:build !ignore_autogenerated // +build !ignore_autogenerated // Code generated by deepcopy-gen. DO NOT EDIT. diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index d75800f985..7e4400924b 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -493,8 +493,42 @@ func SharedInformerOptions(cfg *config.Config, defaultNamespace string) []inform return opts } +func CreateControllerManager(ctx context.Context, cfg *config.Config, + defaultNamespace string, scope *promutils.Scope) (*manager.Manager, error) { + + _, kubecfg, err := utils.GetKubeConfig(ctx, cfg) + if err != nil { + return nil, errors.Wrapf(err, "error building Kubernetes Clientset") + } + + limitNamespace := "" + if cfg.LimitNamespace != defaultNamespace { + limitNamespace = cfg.LimitNamespace + } + mgr, err := manager.New(kubecfg, manager.Options{ + Namespace: limitNamespace, + SyncPeriod: &cfg.DownstreamEval.Duration, + ClientBuilder: executors.NewFallbackClientBuilder((*scope).NewSubScope("kube")), + }) + if err != nil { + return nil, errors.Wrapf(err, "failed to initialize controller-runtime manager") + } + return &mgr, nil +} + +// StartControllerManager Start controller runtime manager to start listening to resource changes. +// K8sPluginManager uses controller runtime to create informers for the CRDs being monitored by plugins. The informer +// EventHandler enqueues the owner workflow for reevaluation. These informer events allow propeller to detect +// workflow changes faster than the default sync interval for workflow CRDs. +func StartControllerManager(ctx context.Context, mgr *manager.Manager) error { + ctx = contextutils.WithGoroutineLabel(ctx, "controller-runtime-manager") + pprof.SetGoroutineLabels(ctx) + logger.Infof(ctx, "Starting controller-runtime manager") + return (*mgr).Start(ctx) +} + // StartController creates a new FlytePropeller Controller and starts it -func StartController(ctx context.Context, cfg *config.Config, defaultNamespace string) error { +func StartController(ctx context.Context, cfg *config.Config, defaultNamespace string, mgr *manager.Manager, scope *promutils.Scope) error { // Setup cancel on the context ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -532,38 +566,7 @@ func StartController(ctx context.Context, cfg *config.Config, defaultNamespace s informerFactory := k8sInformers.NewSharedInformerFactoryWithOptions(kubeClient, flyteK8sConfig.GetK8sPluginConfig().DefaultPodTemplateResync.Duration) - // Add the propeller subscope because the MetricsPrefix only has "flyte:" to get uniform collection of metrics. - propellerScope := promutils.NewScope(cfg.MetricsPrefix).NewSubScope("propeller").NewSubScope(cfg.LimitNamespace) - - limitNamespace := "" - if cfg.LimitNamespace != defaultNamespace { - limitNamespace = cfg.LimitNamespace - } - - mgr, err := manager.New(kubecfg, manager.Options{ - Namespace: limitNamespace, - SyncPeriod: &cfg.DownstreamEval.Duration, - ClientBuilder: executors.NewFallbackClientBuilder(propellerScope.NewSubScope("kube")), - }) - if err != nil { - return errors.Wrapf(err, "failed to initialize controller-runtime manager") - } - - // Start controller runtime manager to start listening to resource changes. - // K8sPluginManager uses controller runtime to create informers for the CRDs being monitored by plugins. The informer - // EventHandler enqueues the owner workflow for reevaluation. These informer events allow propeller to detect - // workflow changes faster than the default sync interval for workflow CRDs. - go func(ctx context.Context) { - ctx = contextutils.WithGoroutineLabel(ctx, "controller-runtime-manager") - pprof.SetGoroutineLabels(ctx) - logger.Infof(ctx, "Starting controller-runtime manager") - err := mgr.Start(ctx) - if err != nil { - logger.Fatalf(ctx, "Failed to start manager. Error: %v", err) - } - }(ctx) - - c, err := New(ctx, cfg, kubeClient, flyteworkflowClient, flyteworkflowInformerFactory, informerFactory, mgr, propellerScope) + c, err := New(ctx, cfg, kubeClient, flyteworkflowClient, flyteworkflowInformerFactory, informerFactory, *mgr, *scope) if err != nil { return errors.Wrap(err, "failed to start FlytePropeller") } else if c == nil { diff --git a/pkg/webhook/config/config.go b/pkg/webhook/config/config.go index 7d78e99996..61c598c1d5 100644 --- a/pkg/webhook/config/config.go +++ b/pkg/webhook/config/config.go @@ -14,8 +14,10 @@ var ( DefaultConfig = &Config{ SecretName: "flyte-pod-webhook", ServiceName: "flyte-pod-webhook", + ServicePort: 443, MetricsPrefix: "flyte:", CertDir: "/etc/webhook/certs", + LocalCert: false, ListenPort: 9443, SecretManagerType: SecretManagerTypeK8s, AWSSecretManagerConfig: AWSSecretManagerConfig{ @@ -72,8 +74,10 @@ const ( type Config struct { MetricsPrefix string `json:"metrics-prefix" pflag:",An optional prefix for all published metrics."` CertDir string `json:"certDir" pflag:",Certificate directory to use to write generated certs. Defaults to /etc/webhook/certs/"` + LocalCert bool `json:"localCert" pflag:",write certs locally. Defaults to false"` ListenPort int `json:"listenPort" pflag:",The port to use to listen to webhook calls. Defaults to 9443"` ServiceName string `json:"serviceName" pflag:",The name of the webhook service."` + ServicePort int32 `json:"servicePort" pflag:",The port on the service that hosting webhook."` SecretName string `json:"secretName" pflag:",Secret name to write generated certs to."` SecretManagerType SecretManagerType `json:"secretManagerType" pflag:"-,Secret manager type to use if secrets are not found in global secrets."` AWSSecretManagerConfig AWSSecretManagerConfig `json:"awsSecretManager" pflag:",AWS Secret Manager config."` diff --git a/pkg/webhook/config/config_flags.go b/pkg/webhook/config/config_flags.go index 6dea588048..7ef9575d79 100755 --- a/pkg/webhook/config/config_flags.go +++ b/pkg/webhook/config/config_flags.go @@ -52,8 +52,10 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) cmdFlags.String(fmt.Sprintf("%v%v", prefix, "metrics-prefix"), DefaultConfig.MetricsPrefix, "An optional prefix for all published metrics.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "certDir"), DefaultConfig.CertDir, "Certificate directory to use to write generated certs. Defaults to /etc/webhook/certs/") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "localCert"), DefaultConfig.LocalCert, "write certs locally. Defaults to false") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "listenPort"), DefaultConfig.ListenPort, "The port to use to listen to webhook calls. Defaults to 9443") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "serviceName"), DefaultConfig.ServiceName, "The name of the webhook service.") + cmdFlags.Int32(fmt.Sprintf("%v%v", prefix, "servicePort"), DefaultConfig.ServicePort, "The port on the service that hosting webhook.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "secretName"), DefaultConfig.SecretName, "Secret name to write generated certs to.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "awsSecretManager.sidecarImage"), DefaultConfig.AWSSecretManagerConfig.SidecarImage, "Specifies the sidecar docker image to use") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "vaultSecretManager.role"), DefaultConfig.VaultSecretManagerConfig.Role, "Specifies the vault role to use") diff --git a/pkg/webhook/config/config_flags_test.go b/pkg/webhook/config/config_flags_test.go index 10b69e8455..e68b5af131 100755 --- a/pkg/webhook/config/config_flags_test.go +++ b/pkg/webhook/config/config_flags_test.go @@ -127,6 +127,20 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_localCert", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("localCert", testValue) + if vBool, err := cmdFlags.GetBool("localCert"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.LocalCert) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_listenPort", func(t *testing.T) { t.Run("Override", func(t *testing.T) { @@ -155,6 +169,20 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_servicePort", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("servicePort", testValue) + if vInt32, err := cmdFlags.GetInt32("servicePort"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt32), &actual.ServicePort) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_secretName", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/pkg/webhook/entrypoint.go b/pkg/webhook/entrypoint.go index da4555946c..4f5a068251 100644 --- a/pkg/webhook/entrypoint.go +++ b/pkg/webhook/entrypoint.go @@ -8,7 +8,6 @@ import ( "os" "github.com/flyteorg/flytepropeller/pkg/controller/config" - "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/utils" config2 "github.com/flyteorg/flytepropeller/pkg/webhook/config" "github.com/flyteorg/flytestdlib/logger" @@ -24,7 +23,8 @@ const ( PodNamespaceEnvVar = "POD_NAMESPACE" ) -func Run(ctx context.Context, propellerCfg *config.Config, cfg *config2.Config, defaultNamespace string) error { +func Run(ctx context.Context, propellerCfg *config.Config, cfg *config2.Config, + defaultNamespace string, scope *promutils.Scope, mgr *manager.Manager) error { raw, err := json.Marshal(cfg) if err != nil { return err @@ -32,19 +32,12 @@ func Run(ctx context.Context, propellerCfg *config.Config, cfg *config2.Config, fmt.Println(string(raw)) - kubeClient, kubecfg, err := utils.GetKubeConfig(ctx, propellerCfg) + kubeClient, _, err := utils.GetKubeConfig(ctx, propellerCfg) if err != nil { return err } - // Add the propeller subscope because the MetricsPrefix only has "flyte:" to get uniform collection of metrics. - propellerScope := promutils.NewScope(cfg.MetricsPrefix).NewSubScope("propeller").NewSubScope(propellerCfg.LimitNamespace) - webhookScope := propellerScope.NewSubScope("webhook") - - limitNamespace := "" - if propellerCfg.LimitNamespace != defaultNamespace { - limitNamespace = propellerCfg.LimitNamespace - } + webhookScope := (*scope).NewSubScope("webhook") secretsWebhook := NewPodMutator(cfg, webhookScope) @@ -54,25 +47,13 @@ func Run(ctx context.Context, propellerCfg *config.Config, cfg *config2.Config, return err } - mgr, err := manager.New(kubecfg, manager.Options{ - Port: cfg.ListenPort, - CertDir: cfg.CertDir, - Namespace: limitNamespace, - SyncPeriod: &propellerCfg.DownstreamEval.Duration, - ClientBuilder: executors.NewFallbackClientBuilder(webhookScope), - }) - - if err != nil { - logger.Fatalf(ctx, "Failed to initialize controller run-time manager. Error: %v", err) - } - - err = secretsWebhook.Register(ctx, mgr) + err = secretsWebhook.Register(ctx, *mgr) if err != nil { logger.Fatalf(ctx, "Failed to register webhook with manager. Error: %v", err) } logger.Infof(ctx, "Starting controller-runtime manager") - return mgr.Start(ctx) + return (*mgr).Start(ctx) } func createMutationConfig(ctx context.Context, kubeClient *kubernetes.Clientset, webhookObj *PodMutator, defaultNamespace string) error { diff --git a/pkg/webhook/init_cert.go b/pkg/webhook/init_cert.go new file mode 100644 index 0000000000..24bfd957b8 --- /dev/null +++ b/pkg/webhook/init_cert.go @@ -0,0 +1,250 @@ +package webhook + +import ( + "bytes" + "context" + cryptorand "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "path" + "time" + + "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/utils" + webhookConfig "github.com/flyteorg/flytepropeller/pkg/webhook/config" + "github.com/flyteorg/flytestdlib/logger" + corev1 "k8s.io/api/core/v1" + kubeErrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v1 "k8s.io/client-go/kubernetes/typed/core/v1" +) + +type webhookCerts struct { + // base64 Encoded CA Cert + CaPEM *bytes.Buffer + // base64 Encoded Server Cert + ServerPEM *bytes.Buffer + // base64 Encoded Server Cert Key + PrivateKeyPEM *bytes.Buffer +} + +const ( + CaCertKey = "ca.crt" + ServerCertKey = "tls.crt" + ServerCertPrivateKey = "tls.key" + podDefaultNamespace = "flyte" + permission = 0644 +) + +func InitCerts(ctx context.Context, propellerCfg *config.Config, cfg *webhookConfig.Config) error { + podNamespace, found := os.LookupEnv(PodNamespaceEnvVar) + if !found { + podNamespace = podDefaultNamespace + } + + logger.Infof(ctx, "Issuing certs") + certs, err := createCerts(podNamespace) + if err != nil { + return err + } + + kubeClient, _, err := utils.GetKubeConfig(ctx, propellerCfg) + if err != nil { + return err + } + + logger.Infof(ctx, "Creating secret [%v] in Namespace [%v]", cfg.SecretName, podNamespace) + err = createWebhookSecret(ctx, podNamespace, cfg, certs, kubeClient.CoreV1().Secrets(podNamespace)) + if err != nil { + return err + } + + return nil +} + +func createWebhookSecret(ctx context.Context, namespace string, cfg *webhookConfig.Config, certs webhookCerts, secretsClient v1.SecretInterface) error { + isImmutable := true + secretData := map[string][]byte{ + CaCertKey: certs.CaPEM.Bytes(), + ServerCertKey: certs.ServerPEM.Bytes(), + ServerCertPrivateKey: certs.PrivateKeyPEM.Bytes(), + } + + if cfg.LocalCert { + if _, err := os.Stat(cfg.CertDir); os.IsNotExist(err) { + if err := os.Mkdir(cfg.CertDir, permission); err != nil { + return err + } + } + + if err := os.WriteFile(path.Join(cfg.CertDir, CaCertKey), certs.CaPEM.Bytes(), permission); err != nil { + return err + } + + if err := os.WriteFile(path.Join(cfg.CertDir, ServerCertKey), certs.ServerPEM.Bytes(), permission); err != nil { + return err + } + + if err := os.WriteFile(path.Join(cfg.CertDir, ServerCertPrivateKey), certs.PrivateKeyPEM.Bytes(), permission); err != nil { + return err + } + } + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: cfg.SecretName, + Namespace: namespace, + }, + Type: corev1.SecretTypeOpaque, + Data: secretData, + Immutable: &isImmutable, + } + + _, err := secretsClient.Create(ctx, secret, metav1.CreateOptions{}) + if err == nil { + logger.Infof(ctx, "Created secret [%v]", cfg.SecretName) + return nil + } + + if kubeErrors.IsAlreadyExists(err) { + logger.Infof(ctx, "A secret already exists with the same name. Validating.") + s, err := secretsClient.Get(ctx, cfg.SecretName, metav1.GetOptions{}) + if err != nil { + return err + } + + // If ServerCertKey or ServerCertPrivateKey are missing, update + requiresUpdate := false + for key := range secretData { + if key == CaCertKey { + continue + } + + if _, exists := s.Data[key]; !exists { + requiresUpdate = true + break + } + } + + if requiresUpdate { + logger.Infof(ctx, "The existing secret is missing one or more keys.") + secret.Annotations = map[string]string{ + "flyteLastUpdate": "system-updated", + "flyteUpdatedAt": time.Now().String(), + } + + _, err = secretsClient.Update(ctx, secret, metav1.UpdateOptions{}) + if err != nil && kubeErrors.IsConflict(err) { + logger.Infof(ctx, "Another instance of flyteadmin has updated the same secret. Ignoring this update") + err = nil + } + + return err + } + + return nil + } + + return err +} + +func createCerts(serviceNamespace string) (certs webhookCerts, err error) { + // CA config + caRequest := &x509.Certificate{ + SerialNumber: big.NewInt(2021), + Subject: pkix.Name{ + Organization: []string{"flyte.org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + // CA private key + caPrivateKey, err := rsa.GenerateKey(cryptorand.Reader, 4096) + if err != nil { + return webhookCerts{}, err + } + + // Self signed CA certificate + caCert, err := x509.CreateCertificate(cryptorand.Reader, caRequest, caRequest, &caPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return webhookCerts{}, err + } + + // PEM encode CA cert + caPEM := new(bytes.Buffer) + err = pem.Encode(caPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: caCert, + }) + if err != nil { + return webhookCerts{}, err + } + + dnsNames := []string{"flyte-pod-webhook", + "flyte-pod-webhook." + serviceNamespace, "flyte-pod-webhook." + serviceNamespace + ".svc"} + commonName := "flyte-pod-webhook." + serviceNamespace + ".svc" + + // server cert config + certRequest := &x509.Certificate{ + DNSNames: dnsNames, + SerialNumber: big.NewInt(1658), + Subject: pkix.Name{ + CommonName: commonName, + Organization: []string{"flyte.org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + // server private key + serverPrivateKey, err := rsa.GenerateKey(cryptorand.Reader, 4096) + if err != nil { + return webhookCerts{}, err + } + + // sign the server cert + cert, err := x509.CreateCertificate(cryptorand.Reader, certRequest, caRequest, &serverPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return webhookCerts{}, err + } + + // PEM encode the server cert and key + serverCertPEM := new(bytes.Buffer) + err = pem.Encode(serverCertPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cert, + }) + + if err != nil { + return webhookCerts{}, fmt.Errorf("failed to Encode CertPEM. Error: %w", err) + } + + serverPrivKeyPEM := new(bytes.Buffer) + err = pem.Encode(serverPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(serverPrivateKey), + }) + + if err != nil { + return webhookCerts{}, fmt.Errorf("failed to Encode Cert Private Key. Error: %w", err) + } + + return webhookCerts{ + CaPEM: caPEM, + ServerPEM: serverCertPEM, + PrivateKeyPEM: serverPrivKeyPEM, + }, nil +} diff --git a/pkg/webhook/pod.go b/pkg/webhook/pod.go index 49f6fd0eae..ec940a2799 100644 --- a/pkg/webhook/pod.go +++ b/pkg/webhook/pod.go @@ -181,6 +181,7 @@ func (pm PodMutator) CreateMutationWebhookConfiguration(namespace string) (*admi Name: pm.cfg.ServiceName, Namespace: namespace, }, + Webhooks: []admissionregistrationv1.MutatingWebhook{ { Name: webhookName, @@ -190,6 +191,7 @@ func (pm PodMutator) CreateMutationWebhookConfiguration(namespace string) (*admi Name: pm.cfg.ServiceName, Namespace: namespace, Path: &path, + Port: &pm.cfg.ServicePort, }, }, Rules: []admissionregistrationv1.RuleWithOperations{