diff --git a/Makefile b/Makefile index 84938afdd..b2f078723 100644 --- a/Makefile +++ b/Makefile @@ -12,18 +12,21 @@ update_boilerplate: .PHONY: linux_compile linux_compile: GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/flytepropeller ./cmd/controller/main.go + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/flytepropeller-manager ./cmd/manager/main.go GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o /artifacts/kubectl-flyte ./cmd/kubectl-flyte/main.go .PHONY: compile compile: mkdir -p ./bin go build -o bin/flytepropeller ./cmd/controller/main.go + go build -o bin/flytepropeller-manager ./cmd/manager/main.go go build -o bin/kubectl-flyte ./cmd/kubectl-flyte/main.go && cp bin/kubectl-flyte ${GOPATH}/bin cross_compile: @glide install @mkdir -p ./bin/cross GOOS=linux GOARCH=amd64 go build -o bin/cross/flytepropeller ./cmd/controller/main.go + GOOS=linux GOARCH=amd64 go build -o bin/cross/flytepropeller-manager ./cmd/manager/main.go GOOS=linux GOARCH=amd64 go build -o bin/cross/kubectl-flyte ./cmd/kubectl-flyte/main.go op_code_generate: @@ -38,6 +41,11 @@ benchmark: server: @go run ./cmd/controller/main.go --alsologtostderr --propeller.kube-config=$(HOME)/.kube/config +# manager starts the manager service in development mode +.PHONY: manager +manager: + @go run ./cmd/manager/main.go --alsologtostderr --propeller.kube-config=$(HOME)/.kube/config + clean: rm -rf bin diff --git a/cmd/controller/cmd/init_certs.go b/cmd/controller/cmd/init_certs.go index 101181580..9e2167729 100644 --- a/cmd/controller/cmd/init_certs.go +++ b/cmd/controller/cmd/init_certs.go @@ -11,6 +11,7 @@ import ( kubeErrors "k8s.io/apimachinery/pkg/api/errors" "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/utils" corev1 "k8s.io/api/core/v1" v12 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -84,7 +85,7 @@ func runCertsCmd(ctx context.Context, propellerCfg *config.Config, cfg *webhookC return err } - kubeClient, _, err := 
getKubeConfig(ctx, propellerCfg) + kubeClient, _, err := utils.GetKubeConfig(ctx, propellerCfg) if err != nil { return err } diff --git a/cmd/controller/cmd/root.go b/cmd/controller/cmd/root.go index 1236d68dc..497d4680e 100644 --- a/cmd/controller/cmd/root.go +++ b/cmd/controller/cmd/root.go @@ -11,6 +11,7 @@ import ( "github.com/flyteorg/flytestdlib/contextutils" + transformers "github.com/flyteorg/flytepropeller/pkg/compiler/transformers/k8s" "github.com/flyteorg/flytepropeller/pkg/controller/executors" "k8s.io/klog" @@ -27,20 +28,15 @@ import ( "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/profutils" "github.com/flyteorg/flytestdlib/promutils" - "github.com/pkg/errors" "github.com/spf13/pflag" "github.com/spf13/cobra" - "k8s.io/client-go/kubernetes" - "k8s.io/client-go/tools/clientcmd" - - restclient "k8s.io/client-go/rest" - clientset "github.com/flyteorg/flytepropeller/pkg/client/clientset/versioned" informers "github.com/flyteorg/flytepropeller/pkg/client/informers/externalversions" "github.com/flyteorg/flytepropeller/pkg/controller" "github.com/flyteorg/flytepropeller/pkg/signals" + "github.com/flyteorg/flytepropeller/pkg/utils" ) const ( @@ -116,39 +112,39 @@ func logAndExit(err error) { os.Exit(-1) } -func getKubeConfig(_ context.Context, cfg *config2.Config) (*kubernetes.Clientset, *restclient.Config, error) { - var kubecfg *restclient.Config - var err error - if cfg.KubeConfigPath != "" { - kubeConfigPath := os.ExpandEnv(cfg.KubeConfigPath) - kubecfg, err = clientcmd.BuildConfigFromFlags(cfg.MasterURL, kubeConfigPath) - if err != nil { - return nil, nil, errors.Wrapf(err, "Error building kubeconfig") - } - } else { - kubecfg, err = restclient.InClusterConfig() - if err != nil { - return nil, nil, errors.Wrapf(err, "Cannot get InCluster kubeconfig") - } +func sharedInformerOptions(cfg *config2.Config) []informers.SharedInformerOption { + selectors := []struct { + label string + operation v1.LabelSelectorOperator + values 
[]string + }{ + {transformers.ShardKeyLabel, v1.LabelSelectorOpIn, cfg.IncludeShardKeyLabel}, + {transformers.ShardKeyLabel, v1.LabelSelectorOpNotIn, cfg.ExcludeShardKeyLabel}, + {transformers.ProjectLabel, v1.LabelSelectorOpIn, cfg.IncludeProjectLabel}, + {transformers.ProjectLabel, v1.LabelSelectorOpNotIn, cfg.ExcludeProjectLabel}, + {transformers.DomainLabel, v1.LabelSelectorOpIn, cfg.IncludeDomainLabel}, + {transformers.DomainLabel, v1.LabelSelectorOpNotIn, cfg.ExcludeDomainLabel}, } - kubecfg.QPS = cfg.KubeConfig.QPS - kubecfg.Burst = cfg.KubeConfig.Burst - kubecfg.Timeout = cfg.KubeConfig.Timeout.Duration + labelSelector := controller.IgnoreCompletedWorkflowsLabelSelector() + for _, selector := range selectors { + if len(selector.values) > 0 { + labelSelectorRequirement := v1.LabelSelectorRequirement{ + Key: selector.label, + Operator: selector.operation, + Values: selector.values, + } - kubeClient, err := kubernetes.NewForConfig(kubecfg) - if err != nil { - return nil, nil, errors.Wrapf(err, "Error building kubernetes clientset") + labelSelector.MatchExpressions = append(labelSelector.MatchExpressions, labelSelectorRequirement) + } } - return kubeClient, kubecfg, err -} -func sharedInformerOptions(cfg *config2.Config) []informers.SharedInformerOption { opts := []informers.SharedInformerOption{ informers.WithTweakListOptions(func(options *v1.ListOptions) { - options.LabelSelector = v1.FormatLabelSelector(controller.IgnoreCompletedWorkflowsLabelSelector()) + options.LabelSelector = v1.FormatLabelSelector(labelSelector) }), } + if cfg.LimitNamespace != defaultNamespace { opts = append(opts, informers.WithNamespace(cfg.LimitNamespace)) } @@ -166,7 +162,7 @@ func executeRootCmd(cfg *config2.Config) { // set up signals so we handle the first shutdown signal gracefully ctx := signals.SetupSignalHandler(baseCtx) - kubeClient, kubecfg, err := getKubeConfig(ctx, cfg) + kubeClient, kubecfg, err := utils.GetKubeConfig(ctx, cfg) if err != nil { logger.Fatalf(ctx, "Error 
building kubernetes clientset: %s", err.Error()) } diff --git a/cmd/controller/cmd/webhook.go b/cmd/controller/cmd/webhook.go index 3af087e7b..40c039e05 100644 --- a/cmd/controller/cmd/webhook.go +++ b/cmd/controller/cmd/webhook.go @@ -17,6 +17,7 @@ import ( "github.com/flyteorg/flytepropeller/pkg/controller/executors" "github.com/flyteorg/flytepropeller/pkg/signals" + "github.com/flyteorg/flytepropeller/pkg/utils" "github.com/flyteorg/flytepropeller/pkg/webhook" "github.com/flyteorg/flytestdlib/logger" "github.com/flyteorg/flytestdlib/profutils" @@ -105,7 +106,7 @@ func runWebhook(origContext context.Context, propellerCfg *config.Config, cfg *w fmt.Println(string(raw)) - kubeClient, kubecfg, err := getKubeConfig(ctx, propellerCfg) + kubeClient, kubecfg, err := utils.GetKubeConfig(ctx, propellerCfg) if err != nil { return err } diff --git a/cmd/manager/cmd/root.go b/cmd/manager/cmd/root.go new file mode 100644 index 000000000..fc3da4af6 --- /dev/null +++ b/cmd/manager/cmd/root.go @@ -0,0 +1,202 @@ +// Commands for FlytePropeller manager. 
+package cmd + +import ( + "context" + "flag" + "os" + "runtime" + + "github.com/flyteorg/flytestdlib/config" + "github.com/flyteorg/flytestdlib/config/viper" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/profutils" + "github.com/flyteorg/flytestdlib/promutils" + "github.com/flyteorg/flytestdlib/version" + + "github.com/flyteorg/flytepropeller/manager" + managerConfig "github.com/flyteorg/flytepropeller/manager/config" + propellerConfig "github.com/flyteorg/flytepropeller/pkg/controller/config" + "github.com/flyteorg/flytepropeller/pkg/signals" + "github.com/flyteorg/flytepropeller/pkg/utils" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/klog" +) + +const ( + appName = "flytepropeller-manager" + podDefaultNamespace = "flyte" + podNameEnvVar = "POD_NAME" + podNamespaceEnvVar = "POD_NAMESPACE" +) + +var ( + cfgFile string + configAccessor = viper.NewAccessor(config.Options{StrictMode: true}) +) + +// rootCmd represents the base command when called without any subcommands +var rootCmd = &cobra.Command{ + Use: appName, + Short: "Runs FlytePropeller Manager to scale out FlytePropeller by executing multiple instances configured according to the defined sharding scheme.", + Long: ` +FlytePropeller Manager is used to effectively scale out FlyteWorkflow processing among a collection of FlytePropeller instances. Users configure a sharding mechanism (ex. 'hash', 'project', or 'domain') to define the sharding environment. + +The FlytePropeller Manager uses a kubernetes PodTemplate to construct the base FlytePropeller PodSpec. This means, apart from the configured sharding scheme, all managed FlytePropeller instances will be identical. + +The Manager ensures liveness and correctness by periodically scanning kubernets pods and recovering state (ie. starting missing pods, etc). 
Live configuration updates are currently unsupported, meaning configuration changes require an application restart. + +Sample configuration, illustrating 3 separate sharding techniques, is provided below: + + manager: + pod-application: "flytepropeller" + pod-namespace: "flyte" + pod-template-name: "flytepropeller-template" + pod-template-namespace: "flyte" + scan-interval: 10s + shard: + # distribute FlyteWorkflow processing over 3 machines evenly + type: hash + pod-count: 3 + + # process the specified projects on defined replicas and all uncovered projects on another + type: project + enableUncoveredReplica: true + replicas: + - entities: + - flytesnacks + - entities: + - flyteexamples + - flytelab + + # process the 'production' domain on a single instace and all other domains on another + type: domain + enableUncoveredReplica: true + replicas: + - entities: + - production + `, + PersistentPreRunE: initConfig, + Run: func(cmd *cobra.Command, args []string) { + executeRootCmd(propellerConfig.GetConfig(), managerConfig.GetConfig()) + }, +} + +// Execute adds all child commands to the root command and sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute() { + version.LogBuildInformation(appName) + logger.Infof(context.TODO(), "detected %d CPU's\n", runtime.NumCPU()) + if err := rootCmd.Execute(); err != nil { + logger.Error(context.TODO(), err) + os.Exit(1) + } +} + +func init() { + // allows `$ flytepropeller-manager --logtostderr` to work + klog.InitFlags(flag.CommandLine) + pflag.CommandLine.AddGoFlagSet(flag.CommandLine) + err := flag.CommandLine.Parse([]string{}) + if err != nil { + logAndExit(err) + } + + // Here you will define your flags and configuration settings. Cobra supports persistent flags, which, if defined + // here, will be global for your application. 
+ rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", + "config file (default is $HOME/config.yaml)") + + configAccessor.InitializePflags(rootCmd.PersistentFlags()) + + rootCmd.AddCommand(viper.GetConfigCommand()) +} + +func initConfig(cmd *cobra.Command, _ []string) error { + configAccessor = viper.NewAccessor(config.Options{ + StrictMode: false, + SearchPaths: []string{cfgFile}, + }) + + configAccessor.InitializePflags(cmd.PersistentFlags()) + + err := configAccessor.UpdateConfig(context.TODO()) + if err != nil { + return err + } + + return nil +} + +func logAndExit(err error) { + logger.Error(context.Background(), err) + os.Exit(-1) +} + +func executeRootCmd(propellerCfg *propellerConfig.Config, cfg *managerConfig.Config) { + baseCtx := context.Background() + + // set up signals so we handle the first shutdown signal gracefully + ctx := signals.SetupSignalHandler(baseCtx) + + // lookup owner reference + kubeClient, _, err := utils.GetKubeConfig(ctx, propellerCfg) + if err != nil { + logger.Fatalf(ctx, "error building kubernetes clientset [%v]", err) + } + + ownerReferences := make([]metav1.OwnerReference, 0) + lookupOwnerReferences := true + podName, found := os.LookupEnv(podNameEnvVar) + if !found { + lookupOwnerReferences = false + } + + podNamespace, found := os.LookupEnv(podNamespaceEnvVar) + if !found { + lookupOwnerReferences = false + podNamespace = podDefaultNamespace + } + + if lookupOwnerReferences { + p, err := kubeClient.CoreV1().Pods(podNamespace).Get(ctx, podName, metav1.GetOptions{}) + if err != nil { + logger.Fatalf(ctx, "failed to get pod '%v' in namespace '%v' [%v]", podName, podNamespace, err) + } + + for _, ownerReference := range p.OwnerReferences { + // must set owner reference controller to false because k8s does not allow setting pod + // owner references to a controller that does not acknowledge ownership. in this case + // the owner is technically the FlytePropeller Manager pod and not that pods owner. 
+ ownerReference.BlockOwnerDeletion = new(bool) // fresh *bool pointing at false: the source field is optional and may be nil, so it must not be dereferenced + ownerReference.Controller = new(bool) // fresh *bool pointing at false: also avoids writing through a pointer shared with the fetched pod + + ownerReferences = append(ownerReferences, ownerReference) + } + } + + // Add the propeller_manager subscope because the MetricsPrefix only has "flyte:" to get uniform collection of metrics. + scope := promutils.NewScope(propellerCfg.MetricsPrefix).NewSubScope("propeller_manager") + + go func() { + err := profutils.StartProfilingServerWithDefaultHandlers(ctx, propellerCfg.ProfilerPort.Port, nil) + if err != nil { + logger.Panicf(ctx, "failed to start profiling and metrics server [%v]", err) + } + }() + + m, err := manager.New(ctx, propellerCfg, cfg, podNamespace, ownerReferences, kubeClient, scope) + if err != nil { + logger.Fatalf(ctx, "failed to start manager [%v]", err) + } else if m == nil { + logger.Fatalf(ctx, "failed to start manager, nil manager received") + } + + if err = m.Run(ctx); err != nil { + logger.Fatalf(ctx, "error running manager [%v]", err) + } +} diff --git a/cmd/manager/main.go b/cmd/manager/main.go new file mode 100644 index 000000000..9ced29741 --- /dev/null +++ b/cmd/manager/main.go @@ -0,0 +1,9 @@ +package main + +import ( + "github.com/flyteorg/flytepropeller/cmd/manager/cmd" +) + +func main() { + cmd.Execute() +} diff --git a/manager/config/config.go b/manager/config/config.go new file mode 100644 index 000000000..d6bc21ac1 --- /dev/null +++ b/manager/config/config.go @@ -0,0 +1,62 @@ +package config + +import ( + "time" + + "github.com/flyteorg/flytestdlib/config" +) + +//go:generate pflags Config --default-var=DefaultConfig +//go:generate enumer --type=ShardType --trimprefix=ShardType -json -yaml + +var ( + DefaultConfig = &Config{ + PodApplication: "flytepropeller", + PodTemplateContainerName: "flytepropeller", + PodTemplateName: "flytepropeller-template", + PodTemplateNamespace: "flyte", + ScanInterval: config.Duration{ + Duration: 10 * time.Second, + }, + ShardConfig: ShardConfig{ + Type: ShardTypeHash, + ShardCount: 3, + }, + } + + configSection = 
config.MustRegisterSection("manager", DefaultConfig) +) + +type ShardType int + +const ( + ShardTypeDomain ShardType = iota + ShardTypeProject + ShardTypeHash +) + +// Configuration for defining shard replicas when using project or domain shard types +type PerShardMappingsConfig struct { + IDs []string `json:"ids" pflag:",The list of ids to be managed"` +} + +// Configuration for the FlytePropeller sharding strategy +type ShardConfig struct { + Type ShardType `json:"type" pflag:",Shard implementation to use"` + PerShardMappings []PerShardMappingsConfig `json:"per-shard-mapping" pflag:"-"` + ShardCount int `json:"shard-count" pflag:",The number of shards to manage for a 'hash' shard type"` +} + +// Configuration for the FlytePropeller Manager instance +type Config struct { + PodApplication string `json:"pod-application" pflag:",Application name for managed pods"` + PodTemplateContainerName string `json:"pod-template-container-name" pflag:",The container name within the K8s PodTemplate name used to set FlyteWorkflow CRD labels selectors"` + PodTemplateName string `json:"pod-template-name" pflag:",K8s PodTemplate name to use for starting FlytePropeller pods"` + PodTemplateNamespace string `json:"pod-template-namespace" pflag:",Namespace where the k8s PodTemplate is located"` + ScanInterval config.Duration `json:"scan-interval" pflag:",Frequency to scan FlytePropeller pods and start / restart if necessary"` + ShardConfig ShardConfig `json:"shard" pflag:",Configure the shard strategy for this manager"` +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} diff --git a/manager/config/config_flags.go b/manager/config/config_flags.go new file mode 100755 index 000000000..0e143f881 --- /dev/null +++ b/manager/config/config_flags.go @@ -0,0 +1,61 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. 
+ +package config + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustJsonMarshal(v interface{}) string { + raw, err := json.Marshal(v) + if err != nil { + panic(err) + } + + return string(raw) +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. 
+func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pod-application"), DefaultConfig.PodApplication, "Application name for managed pods") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pod-template-container-name"), DefaultConfig.PodTemplateContainerName, "The container name within the K8s PodTemplate name used to set FlyteWorkflow CRD labels selectors") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pod-template-name"), DefaultConfig.PodTemplateName, "K8s PodTemplate name to use for starting FlytePropeller pods") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "pod-template-namespace"), DefaultConfig.PodTemplateNamespace, "Namespace where the k8s PodTemplate is located") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "scan-interval"), DefaultConfig.ScanInterval.String(), "Frequency to scan FlytePropeller pods and start / restart if necessary") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "shard.type"), DefaultConfig.ShardConfig.Type.String(), "Shard implementation to use") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "shard.shard-count"), DefaultConfig.ShardConfig.ShardCount, "The number of shards to manage for a 'hash' shard type") + return cmdFlags +} diff --git a/manager/config/config_flags_test.go b/manager/config/config_flags_test.go new file mode 100755 index 000000000..887452276 --- /dev/null +++ b/manager/config/config_flags_test.go @@ -0,0 +1,200 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. 
+func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeRaw_Config(t *testing.T, vStringSlice, result interface{}) { + 
assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_pod-application", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("pod-application", testValue) + if vString, err := cmdFlags.GetString("pod-application"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.PodApplication) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_pod-template-container-name", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("pod-template-container-name", testValue) + if vString, err := cmdFlags.GetString("pod-template-container-name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.PodTemplateContainerName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_pod-template-name", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("pod-template-name", testValue) + if vString, err := cmdFlags.GetString("pod-template-name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.PodTemplateName) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_pod-template-namespace", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("pod-template-namespace", testValue) + if vString, err := cmdFlags.GetString("pod-template-namespace"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.PodTemplateNamespace) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_scan-interval", func(t *testing.T) { + + t.Run("Override", func(t 
*testing.T) { + testValue := DefaultConfig.ScanInterval.String() + + cmdFlags.Set("scan-interval", testValue) + if vString, err := cmdFlags.GetString("scan-interval"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.ScanInterval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_shard.type", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("shard.type", testValue) + if vString, err := cmdFlags.GetString("shard.type"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.ShardConfig.Type) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_shard.shard-count", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("shard.shard-count", testValue) + if vInt, err := cmdFlags.GetInt("shard.shard-count"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.ShardConfig.ShardCount) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/manager/config/doc.go b/manager/config/doc.go new file mode 100644 index 000000000..2930e72a9 --- /dev/null +++ b/manager/config/doc.go @@ -0,0 +1,4 @@ +/* +Package config details configuration data structures for the FlytePropeller Manager implementation. +*/ +package config diff --git a/manager/config/shardtype_enumer.go b/manager/config/shardtype_enumer.go new file mode 100644 index 000000000..78ae145d8 --- /dev/null +++ b/manager/config/shardtype_enumer.go @@ -0,0 +1,86 @@ +// Code generated by "enumer --type=ShardType --trimprefix=ShardType -json -yaml"; DO NOT EDIT. 
+ +// +package config + +import ( + "encoding/json" + "fmt" +) + +const _ShardTypeName = "DomainProjectHash" + +var _ShardTypeIndex = [...]uint8{0, 6, 13, 17} + +func (i ShardType) String() string { + if i < 0 || i >= ShardType(len(_ShardTypeIndex)-1) { + return fmt.Sprintf("ShardType(%d)", i) + } + return _ShardTypeName[_ShardTypeIndex[i]:_ShardTypeIndex[i+1]] +} + +var _ShardTypeValues = []ShardType{0, 1, 2} + +var _ShardTypeNameToValueMap = map[string]ShardType{ + _ShardTypeName[0:6]: 0, + _ShardTypeName[6:13]: 1, + _ShardTypeName[13:17]: 2, +} + +// ShardTypeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func ShardTypeString(s string) (ShardType, error) { + if val, ok := _ShardTypeNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to ShardType values", s) +} + +// ShardTypeValues returns all values of the enum +func ShardTypeValues() []ShardType { + return _ShardTypeValues +} + +// IsAShardType returns "true" if the value is listed in the enum definition. 
"false" otherwise +func (i ShardType) IsAShardType() bool { + for _, v := range _ShardTypeValues { + if i == v { + return true + } + } + return false +} + +// MarshalJSON implements the json.Marshaler interface for ShardType +func (i ShardType) MarshalJSON() ([]byte, error) { + return json.Marshal(i.String()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for ShardType +func (i *ShardType) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("ShardType should be a string, got %s", data) + } + + var err error + *i, err = ShardTypeString(s) + return err +} + +// MarshalYAML implements a YAML Marshaler for ShardType +func (i ShardType) MarshalYAML() (interface{}, error) { + return i.String(), nil +} + +// UnmarshalYAML implements a YAML Unmarshaler for ShardType +func (i *ShardType) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + + var err error + *i, err = ShardTypeString(s) + return err +} diff --git a/manager/doc.go b/manager/doc.go new file mode 100644 index 000000000..2afdcb148 --- /dev/null +++ b/manager/doc.go @@ -0,0 +1,63 @@ +/* +Package manager introduces a FlytePropeller Manager implementation that enables horizontal scaling of FlytePropeller by sharding FlyteWorkflows. + +The FlytePropeller Manager manages a collection of FlytePropeller instances to effectively distribute load. Each managed FlytePropeller instance is created as a k8s pod using a configurable k8s PodTemplate resource. The FlytePropeller Manager uses a control loop to periodically check the status of managed FlytePropeller instances and creates, updates, or deletes pods as required. It is important to note that if the FlytePropeller Manager fails, managed instances are left running. This is in an effort to ensure progress continues in evaluating FlyteWorkflow CRDs. 
+ +FlytePropeller Manager is configured at the root of the FlytePropeller configuration. Below is an example of the variety of configuration options along with succinct associated descriptions for each field: + + manager: + pod-application: "flytepropeller" # application name for managed pods + pod-template-container-name: "flytepropeller" # the container name within the K8s PodTemplate name used to set FlyteWorkflow CRD labels selectors + pod-template-name: "flytepropeller-template" # k8s PodTemplate name to use for starting FlytePropeller pods + pod-template-namespace: "flyte" # namespace where the k8s PodTemplate is located + scan-interval: 10s # frequency to scan FlytePropeller pods and start / restart if necessary + shard: # configure sharding strategy + # shard configuration redacted + +FlytePropeller Manager handles dynamic updates to both the k8s PodTemplate and shard configuration. The k8s PodTemplate resource has an associated resource version which uniquely identifies changes. Additionally, shard configuration modifications may be tracked using a simple hash. Flyte stores these values as annotations on managed FlytePropeller instances. Therefore, if either of these values changes, the FlytePropeller Manager instance will detect it and perform the necessary deployment updates. + +Shard Strategies + +Flyte defines a variety of Shard Strategies for configuring how FlyteWorkflows are sharded. These options may include the shard type (ex. hash, project, or domain) along with the number of shards or the distribution of project / domain IDs over shards. + +Internally, FlyteWorkflow CRDs are initialized with k8s labels for project, domain, and a shard-key. The project and domain label values are associated with the environment of the registered workflow. The shard-key value is a range-bounded hash over various components of the FlyteWorkflow metadata; currently the keyspace range is defined as [0,32). 
A sharded Flyte deployment ensures deterministic FlyteWorkflow evaluations by setting disjoint k8s label selectors, based on the aforementioned labels, on each managed FlytePropeller instance. This ensures that only a single FlytePropeller instance is responsible for processing each FlyteWorkflow. + +The Hash Shard Strategy, denoted by "type: hash" in the configuration below, uses consistent hashing to evenly distribute FlyteWorkflows over managed FlytePropeller instances. This is achieved by partitioning the keyspace (i.e. [0,32)) into a collection of disjoint ranges and using label selectors to assign those ranges to managed FlytePropeller instances. For example, with "shard-count: 4" the first instance is responsible for FlyteWorkflows with "shard-keys" in the range [0,8), the second [8,16), the third [16,24), and the fourth [24,32). It may be useful to note that the default shard type is "hash", so it will be implicitly defined if otherwise left out of the configuration. An example configuration for the Hash Shard Strategy is provided below: + + # a configuration example using the "hash" shard type + manager: + # pod and scanning configuration redacted + shard: + type: hash # use the "hash" shard strategy + shard-count: 4 # the total number of shards + +The Project and Domain Shard Strategies, denoted by "type: project" and "type: domain" respectively, use the FlyteWorkflow project and domain metadata to distribute FlyteWorkflows over managed FlytePropeller instances. These Shard Strategies are configured using a "per-shard-mapping" option, which is a list of ID lists. Each element in the "per-shard-mapping" list defines a new shard and the ID list assigns responsibility for the specified IDs to that shard. The assignment is performed using k8s label selectors, where each managed FlytePropeller instance includes FlyteWorkflows with the specified project or domain labels. + +A shard configured as a single wildcard ID (i.e. 
"*") is responsible for all IDs that are not covered by other shards. Only a single shard may be configured with a wildcard ID and on that shard their must be only one ID, namely the wildcard. In this case, the managed FlytePropeller instance uses k8s label selectors to exclude FlyteWorkflows with project or domain IDs from other shards. + + # a configuration example using the "project" shard type + manager: + # pod and scanning configuration redacted + shard: + type: project # use the "project" shard strategy + per-shard-mapping: # a list of per shard mappings - one shard is created for each element + - ids: # the list of ids to be managed by the first shard + - flytesnacks + - ids: # the list of ids to be managed by the second shard + - flyteexamples + - flytelabs + - ids: # the list of ids to be managed by the third shard + - "*" # use the wildcard to manage all ids not managed by other shards + + # a configuration example using the "domain" shard type + manager: + # pod and scanning configuration redacted + shard: + type: domain # use the "domain" shard strategy + per-shard-mapping: # a list of per shard mappings - one shard is created for each element + - ids: # the list of ids to be managed by the first shard + - production + - ids: # the list of ids to be managed by the second shard + - "*" # use the wildcard to manage all ids not managed by other shards +*/ +package manager diff --git a/manager/manager.go b/manager/manager.go new file mode 100644 index 000000000..fd3cc6e9b --- /dev/null +++ b/manager/manager.go @@ -0,0 +1,295 @@ +package manager + +import ( + "context" + "fmt" + "time" + + managerConfig "github.com/flyteorg/flytepropeller/manager/config" + "github.com/flyteorg/flytepropeller/manager/shardstrategy" + propellerConfig "github.com/flyteorg/flytepropeller/pkg/controller/config" + leader "github.com/flyteorg/flytepropeller/pkg/leaderelection" + "github.com/flyteorg/flytepropeller/pkg/utils" + + stderrors "github.com/flyteorg/flytestdlib/errors" + 
"github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/prometheus/client_golang/prometheus" + + v1 "k8s.io/api/core/v1" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/util/wait" + + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/leaderelection" +) + +const ( + podTemplateResourceVersion = "podTemplateResourceVersion" + shardConfigHash = "shardConfigHash" +) + +type metrics struct { + Scope promutils.Scope + RoundTime promutils.StopWatch + PodsCreated prometheus.Counter + PodsDeleted prometheus.Counter + PodsRunning prometheus.Gauge +} + +func newManagerMetrics(scope promutils.Scope) *metrics { + return &metrics{ + Scope: scope, + RoundTime: scope.MustNewStopWatch("round_time", "Time to perform one round of validating managed pod status'", time.Millisecond), + PodsCreated: scope.MustNewCounter("pods_created_count", "Total number of pods created"), + PodsDeleted: scope.MustNewCounter("pods_deleted_count", "Total number of pods deleted"), + PodsRunning: scope.MustNewGauge("pods_running_count", "Number of managed pods currently running"), + } +} + +// Manager periodically scans k8s to ensure liveness of multiple FlytePropeller controller instances +// and rectifies state based on the configured sharding strategy. 
+type Manager struct { + kubeClient kubernetes.Interface + leaderElector *leaderelection.LeaderElector + metrics *metrics + ownerReferences []metav1.OwnerReference + podApplication string + podNamespace string + podTemplateContainerName string + podTemplateName string + podTemplateNamespace string + scanInterval time.Duration + shardStrategy shardstrategy.ShardStrategy +} + +func (m *Manager) createPods(ctx context.Context) error { + t := m.metrics.RoundTime.Start() + defer t.Stop() + + // retrieve pod metadata + podTemplate, err := m.kubeClient.CoreV1().PodTemplates(m.podTemplateNamespace).Get(ctx, m.podTemplateName, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("failed to retrieve pod template '%s' from namespace '%s' [%v]", m.podTemplateName, m.podTemplateNamespace, err) + } + + shardConfigHash, err := m.shardStrategy.HashCode() + if err != nil { + return err + } + + podAnnotations := map[string]string{ + "podTemplateResourceVersion": podTemplate.ObjectMeta.ResourceVersion, + "shardConfigHash": fmt.Sprintf("%d", shardConfigHash), + } + podNames := m.getPodNames() + podLabels := map[string]string{ + "app": m.podApplication, + } + + // disable leader election on all managed pods + container, err := utils.GetContainer(&podTemplate.Template.Spec, m.podTemplateContainerName) + if err != nil { + return fmt.Errorf("failed to retrieve flytepropeller container from pod template [%v]", err) + } + + container.Args = append(container.Args, "--propeller.leader-election.enabled=false") + + // retrieve existing pods + listOptions := metav1.ListOptions{ + LabelSelector: labels.SelectorFromSet(podLabels).String(), + } + + pods, err := m.kubeClient.CoreV1().Pods(m.podNamespace).List(ctx, listOptions) + if err != nil { + return err + } + + // note: we are unable to short-circuit if 'len(pods) == len(m.podNames)' because there may be + // unmanaged flytepropeller pods - which is invalid configuration but will be detected later + + // determine missing managed pods + 
podExists := make(map[string]bool) + for _, podName := range podNames { + podExists[podName] = false + } + + podsRunning := 0 + for _, pod := range pods.Items { + podName := pod.ObjectMeta.Name + + // validate existing pod annotations + deletePod := false + for key, value := range podAnnotations { + if pod.ObjectMeta.Annotations[key] != value { + logger.Infof(ctx, "detected pod '%s' with stale configuration", podName) + deletePod = true + break + } + } + + if pod.Status.Phase == v1.PodFailed { + logger.Warnf(ctx, "detected pod '%s' in 'failed' state", podName) + deletePod = true + } + + if deletePod { + err := m.kubeClient.CoreV1().Pods(m.podNamespace).Delete(ctx, podName, metav1.DeleteOptions{}) + if err != nil { + return err + } + + m.metrics.PodsDeleted.Inc() + logger.Infof(ctx, "deleted pod '%s'", podName) + continue + } + + // update podExists to track existing pods + if _, ok := podExists[podName]; ok { + podExists[podName] = true + + if pod.Status.Phase == v1.PodRunning { + podsRunning++ + } + } + } + + m.metrics.PodsRunning.Set(float64(podsRunning)) + + // create non-existent pods + errs := stderrors.ErrorCollection{} + for i, podName := range podNames { + if exists := podExists[podName]; !exists { + pod := &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: podAnnotations, + Name: podName, + Namespace: m.podNamespace, + Labels: podLabels, + OwnerReferences: m.ownerReferences, + }, + Spec: *podTemplate.Template.Spec.DeepCopy(), + } + + err := m.shardStrategy.UpdatePodSpec(&pod.Spec, m.podTemplateContainerName, i) + if err != nil { + errs.Append(fmt.Errorf("failed to update pod spec for '%s' [%v]", podName, err)) + continue + } + + _, err = m.kubeClient.CoreV1().Pods(m.podNamespace).Create(ctx, pod, metav1.CreateOptions{}) + if err != nil { + errs.Append(fmt.Errorf("failed to create pod '%s' [%v]", podName, err)) + continue + } + + m.metrics.PodsCreated.Inc() + logger.Infof(ctx, "created pod '%s'", podName) + } + } + + return errs.ErrorOrDefault() +} + 
+func (m *Manager) getPodNames() []string { + podCount := m.shardStrategy.GetPodCount() + var podNames []string + for i := 0; i < podCount; i++ { + podNames = append(podNames, fmt.Sprintf("%s-%d", m.podApplication, i)) + } + + return podNames +} + +// Run starts the manager instance as either a k8s leader, if configured, or as a standalone process. +func (m *Manager) Run(ctx context.Context) error { + if m.leaderElector != nil { + logger.Infof(ctx, "running with leader election") + m.leaderElector.Run(ctx) + } else { + logger.Infof(ctx, "running without leader election") + if err := m.run(ctx); err != nil { + return err + } + } + + return nil +} + +func (m *Manager) run(ctx context.Context) error { + logger.Infof(ctx, "started manager") + wait.UntilWithContext(ctx, + func(ctx context.Context) { + logger.Debugf(ctx, "validating managed pod(s) state") + err := m.createPods(ctx) + if err != nil { + logger.Errorf(ctx, "failed to create pod(s) [%v]", err) + } + }, + m.scanInterval, + ) + + logger.Infof(ctx, "shutting down manager") + return nil +} + +// New creates a new FlytePropeller Manager instance. 
+func New(ctx context.Context, propellerCfg *propellerConfig.Config, cfg *managerConfig.Config, podNamespace string, ownerReferences []metav1.OwnerReference, kubeClient kubernetes.Interface, scope promutils.Scope) (*Manager, error) { + shardStrategy, err := shardstrategy.NewShardStrategy(ctx, cfg.ShardConfig) + if err != nil { + return nil, fmt.Errorf("failed to initialize shard strategy [%v]", err) + } + + manager := &Manager{ + kubeClient: kubeClient, + metrics: newManagerMetrics(scope), + ownerReferences: ownerReferences, + podApplication: cfg.PodApplication, + podNamespace: podNamespace, + podTemplateContainerName: cfg.PodTemplateContainerName, + podTemplateName: cfg.PodTemplateName, + podTemplateNamespace: cfg.PodTemplateNamespace, + scanInterval: cfg.ScanInterval.Duration, + shardStrategy: shardStrategy, + } + + // configure leader elector + eventRecorder, err := utils.NewK8sEventRecorder(ctx, kubeClient, "flytepropeller-manager", propellerCfg.PublishK8sEvents) + if err != nil { + return nil, fmt.Errorf("failed to initialize k8s event recorder [%v]", err) + } + + lock, err := leader.NewResourceLock(kubeClient.CoreV1(), kubeClient.CoordinationV1(), eventRecorder, propellerCfg.LeaderElection) + if err != nil { + return nil, fmt.Errorf("failed to initialize resource lock [%v]", err) + } + + if lock != nil { + logger.Infof(ctx, "creating leader elector for the controller") + manager.leaderElector, err = leader.NewLeaderElector( + lock, + propellerCfg.LeaderElection, + func(ctx context.Context) { + logger.Infof(ctx, "started leading") + if err := manager.run(ctx); err != nil { + logger.Error(ctx, err) + } + }, + func() { + // need to check if this elector obtained leadership until k8s client-go api is fixed. currently the + // OnStoppingLeader func is called as a defer on every elector run, regardless of election status. 
+ if manager.leaderElector.IsLeader() { + logger.Info(ctx, "stopped leading") + } + }) + + if err != nil { + return nil, fmt.Errorf("failed to initialize leader elector [%v]", err) + } + } + + return manager, nil +} diff --git a/manager/manager_test.go b/manager/manager_test.go new file mode 100644 index 000000000..9eb831d8b --- /dev/null +++ b/manager/manager_test.go @@ -0,0 +1,184 @@ +package manager + +import ( + "context" + "fmt" + "testing" + + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/flyteorg/flytepropeller/manager/shardstrategy" + "github.com/flyteorg/flytepropeller/manager/shardstrategy/mocks" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/kubernetes/fake" +) + +var ( + podTemplate = &v1.PodTemplate{ + ObjectMeta: metav1.ObjectMeta{ + ResourceVersion: "0", + }, + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Command: []string{"flytepropeller"}, + Args: []string{"--config", "/etc/flyte/config/*.yaml"}, + }, + }, + }, + }, + } +) + +func createShardStrategy(podCount int) shardstrategy.ShardStrategy { + shardStrategy := mocks.ShardStrategy{} + shardStrategy.OnGetPodCount().Return(podCount) + shardStrategy.OnHashCode().Return(0, nil) + shardStrategy.OnUpdatePodSpecMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + + return &shardStrategy +} + +func TestCreatePods(t *testing.T) { + t.Parallel() + tests := []struct { + name string + shardStrategy shardstrategy.ShardStrategy + }{ + {"2", createShardStrategy(2)}, + {"3", createShardStrategy(3)}, + {"4", createShardStrategy(4)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.TODO() + scope := promutils.NewScope(fmt.Sprintf("create_%s", tt.name)) + kubeClient := fake.NewSimpleClientset(podTemplate) + + manager := Manager{ + 
kubeClient: kubeClient, + metrics: newManagerMetrics(scope), + podApplication: "flytepropeller", + shardStrategy: tt.shardStrategy, + } + + // ensure no pods are "running" + kubePodsClient := kubeClient.CoreV1().Pods("") + pods, err := kubePodsClient.List(ctx, metav1.ListOptions{}) + assert.NoError(t, err) + assert.Equal(t, 0, len(pods.Items)) + + // create all pods and validate state + err = manager.createPods(ctx) + assert.NoError(t, err) + + pods, err = kubePodsClient.List(ctx, metav1.ListOptions{}) + assert.NoError(t, err) + assert.Equal(t, tt.shardStrategy.GetPodCount(), len(pods.Items)) + + // execute again to ensure no new pods are created + err = manager.createPods(ctx) + assert.NoError(t, err) + + pods, err = kubePodsClient.List(ctx, metav1.ListOptions{}) + assert.NoError(t, err) + assert.Equal(t, tt.shardStrategy.GetPodCount(), len(pods.Items)) + }) + } +} + +func TestUpdatePods(t *testing.T) { + t.Parallel() + tests := []struct { + name string + shardStrategy shardstrategy.ShardStrategy + }{ + {"2", createShardStrategy(2)}, + {"3", createShardStrategy(3)}, + {"4", createShardStrategy(4)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.TODO() + scope := promutils.NewScope(fmt.Sprintf("update_%s", tt.name)) + + initObjects := []runtime.Object{podTemplate} + for i := 0; i < tt.shardStrategy.GetPodCount(); i++ { + initObjects = append(initObjects, &v1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + podTemplateResourceVersion: "1", + shardConfigHash: "1", + }, + Labels: map[string]string{ + "app": "flytepropeller", + }, + Name: fmt.Sprintf("flytepropeller-%d", i), + }, + }) + } + + kubeClient := fake.NewSimpleClientset(initObjects...) 
+ + manager := Manager{ + kubeClient: kubeClient, + metrics: newManagerMetrics(scope), + podApplication: "flytepropeller", + shardStrategy: tt.shardStrategy, + } + + // ensure all pods are "running" + kubePodsClient := kubeClient.CoreV1().Pods("") + pods, err := kubePodsClient.List(ctx, metav1.ListOptions{}) + assert.NoError(t, err) + assert.Equal(t, tt.shardStrategy.GetPodCount(), len(pods.Items)) + for _, pod := range pods.Items { + assert.Equal(t, "1", pod.ObjectMeta.Annotations[podTemplateResourceVersion]) + } + + // create all pods and validate state + err = manager.createPods(ctx) + assert.NoError(t, err) + + pods, err = kubePodsClient.List(ctx, metav1.ListOptions{}) + assert.NoError(t, err) + assert.Equal(t, tt.shardStrategy.GetPodCount(), len(pods.Items)) + for _, pod := range pods.Items { + assert.Equal(t, podTemplate.ObjectMeta.ResourceVersion, pod.ObjectMeta.Annotations[podTemplateResourceVersion]) + } + }) + } +} + +func TestGetPodNames(t *testing.T) { + t.Parallel() + tests := []struct { + name string + shardStrategy shardstrategy.ShardStrategy + podCount int + }{ + {"2", createShardStrategy(2), 2}, + {"3", createShardStrategy(3), 3}, + {"4", createShardStrategy(4), 4}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := Manager{ + podApplication: "flytepropeller", + shardStrategy: tt.shardStrategy, + } + + assert.Equal(t, tt.podCount, len(manager.getPodNames())) + }) + } +} diff --git a/manager/shardstrategy/doc.go b/manager/shardstrategy/doc.go new file mode 100644 index 000000000..096315e0b --- /dev/null +++ b/manager/shardstrategy/doc.go @@ -0,0 +1,4 @@ +/* +Package shardstrategy defines a variety of sharding stratgies to distribute FlyteWorkflows over managed FlytePropeller instances. 
+*/ +package shardstrategy diff --git a/manager/shardstrategy/environment.go b/manager/shardstrategy/environment.go new file mode 100644 index 000000000..e6b819cbd --- /dev/null +++ b/manager/shardstrategy/environment.go @@ -0,0 +1,62 @@ +package shardstrategy + +import ( + "fmt" + + "github.com/flyteorg/flytepropeller/pkg/utils" + + v1 "k8s.io/api/core/v1" +) + +// EnvironmentShardStrategy assigns either project or domain identifers to individual +// FlytePropeller instances to determine FlyteWorkflow processing responsibility. +type EnvironmentShardStrategy struct { + EnvType environmentType + PerShardIDs [][]string +} + +type environmentType int + +const ( + Project environmentType = iota + Domain +) + +func (e environmentType) String() string { + return [...]string{"project", "domain"}[e] +} + +func (e *EnvironmentShardStrategy) GetPodCount() int { + return len(e.PerShardIDs) +} + +func (e *EnvironmentShardStrategy) HashCode() (uint32, error) { + return computeHashCode(e) +} + +func (e *EnvironmentShardStrategy) UpdatePodSpec(pod *v1.PodSpec, containerName string, podIndex int) error { + container, err := utils.GetContainer(pod, containerName) + if err != nil { + return err + } + + if podIndex < 0 || podIndex >= e.GetPodCount() { + return fmt.Errorf("invalid podIndex '%d' out of range [0,%d)", podIndex, e.GetPodCount()) + } + + if len(e.PerShardIDs[podIndex]) == 1 && e.PerShardIDs[podIndex][0] == "*" { + for i, shardIDs := range e.PerShardIDs { + if i != podIndex { + for _, id := range shardIDs { + container.Args = append(container.Args, fmt.Sprintf("--propeller.exclude-%s-label", e.EnvType), id) + } + } + } + } else { + for _, id := range e.PerShardIDs[podIndex] { + container.Args = append(container.Args, fmt.Sprintf("--propeller.include-%s-label", e.EnvType), id) + } + } + + return nil +} diff --git a/manager/shardstrategy/hash.go b/manager/shardstrategy/hash.go new file mode 100644 index 000000000..7de7e69f3 --- /dev/null +++ b/manager/shardstrategy/hash.go 
@@ -0,0 +1,72 @@ +package shardstrategy + +import ( + "fmt" + + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + "github.com/flyteorg/flytepropeller/pkg/utils" + + v1 "k8s.io/api/core/v1" +) + +// HashShardStrategy evenly assigns disjoint keyspace responsibilities over a collection of pods. +// All FlyteWorkflows are assigned a shard-key using a hash of their executionID and are then +// processed by the FlytePropeller instance responsible for that keyspace range. +type HashShardStrategy struct { + ShardCount int +} + +func (h *HashShardStrategy) GetPodCount() int { + return h.ShardCount +} + +func (h *HashShardStrategy) HashCode() (uint32, error) { + return computeHashCode(h) +} + +func (h *HashShardStrategy) UpdatePodSpec(pod *v1.PodSpec, containerName string, podIndex int) error { + container, err := utils.GetContainer(pod, containerName) + if err != nil { + return err + } + + if podIndex < 0 || podIndex >= h.GetPodCount() { + return fmt.Errorf("invalid podIndex '%d' out of range [0,%d)", podIndex, h.GetPodCount()) + } + + startKey, endKey := ComputeKeyRange(v1alpha1.ShardKeyspaceSize, h.GetPodCount(), podIndex) + for i := startKey; i < endKey; i++ { + container.Args = append(container.Args, "--propeller.include-shard-key-label", fmt.Sprintf("%d", i)) + } + + return nil +} + +// ComputeKeyRange computes a [startKey, endKey) pair denoting the key responsibilities for the +// provided pod index given the keyspaceSize and podCount parameters. 
+func ComputeKeyRange(keyspaceSize, podCount, podIndex int) (int, int) { + keysPerPod := keyspaceSize / podCount + keyRemainder := keyspaceSize - (podCount * keysPerPod) + + return computeStartKey(keysPerPod, keyRemainder, podIndex), computeStartKey(keysPerPod, keyRemainder, podIndex+1) +} + +func computeStartKey(keysPerPod, keysRemainder, podIndex int) int { + return (intMin(podIndex, keysRemainder) * (keysPerPod + 1)) + (intMax(0, podIndex-keysRemainder) * keysPerPod) +} + +func intMin(a, b int) int { + if a < b { + return a + } + + return b +} + +func intMax(a, b int) int { + if a > b { + return a + } + + return b +} diff --git a/manager/shardstrategy/hash_test.go b/manager/shardstrategy/hash_test.go new file mode 100644 index 000000000..6685fbd45 --- /dev/null +++ b/manager/shardstrategy/hash_test.go @@ -0,0 +1,25 @@ +package shardstrategy + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestComputeKeyRange(t *testing.T) { + keyspaceSize := 32 + for podCount := 1; podCount < keyspaceSize; podCount++ { + keysCovered := 0 + minKeyRangeSize := keyspaceSize / podCount + for podIndex := 0; podIndex < podCount; podIndex++ { + startIndex, endIndex := ComputeKeyRange(keyspaceSize, podCount, podIndex) + + rangeSize := endIndex - startIndex + keysCovered += rangeSize + assert.True(t, rangeSize-minKeyRangeSize >= 0) + assert.True(t, rangeSize-minKeyRangeSize <= 1) + } + + assert.Equal(t, keyspaceSize, keysCovered) + } +} diff --git a/manager/shardstrategy/mocks/shard_strategy.go b/manager/shardstrategy/mocks/shard_strategy.go new file mode 100644 index 000000000..5f5925974 --- /dev/null +++ b/manager/shardstrategy/mocks/shard_strategy.go @@ -0,0 +1,117 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. 
+ +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + + v1 "k8s.io/api/core/v1" +) + +// ShardStrategy is an autogenerated mock type for the ShardStrategy type +type ShardStrategy struct { + mock.Mock +} + +type ShardStrategy_GetPodCount struct { + *mock.Call +} + +func (_m ShardStrategy_GetPodCount) Return(_a0 int) *ShardStrategy_GetPodCount { + return &ShardStrategy_GetPodCount{Call: _m.Call.Return(_a0)} +} + +func (_m *ShardStrategy) OnGetPodCount() *ShardStrategy_GetPodCount { + c := _m.On("GetPodCount") + return &ShardStrategy_GetPodCount{Call: c} +} + +func (_m *ShardStrategy) OnGetPodCountMatch(matchers ...interface{}) *ShardStrategy_GetPodCount { + c := _m.On("GetPodCount", matchers...) + return &ShardStrategy_GetPodCount{Call: c} +} + +// GetPodCount provides a mock function with given fields: +func (_m *ShardStrategy) GetPodCount() int { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +type ShardStrategy_HashCode struct { + *mock.Call +} + +func (_m ShardStrategy_HashCode) Return(_a0 uint32, _a1 error) *ShardStrategy_HashCode { + return &ShardStrategy_HashCode{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ShardStrategy) OnHashCode() *ShardStrategy_HashCode { + c := _m.On("HashCode") + return &ShardStrategy_HashCode{Call: c} +} + +func (_m *ShardStrategy) OnHashCodeMatch(matchers ...interface{}) *ShardStrategy_HashCode { + c := _m.On("HashCode", matchers...) 
+ return &ShardStrategy_HashCode{Call: c} +} + +// HashCode provides a mock function with given fields: +func (_m *ShardStrategy) HashCode() (uint32, error) { + ret := _m.Called() + + var r0 uint32 + if rf, ok := ret.Get(0).(func() uint32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint32) + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ShardStrategy_UpdatePodSpec struct { + *mock.Call +} + +func (_m ShardStrategy_UpdatePodSpec) Return(_a0 error) *ShardStrategy_UpdatePodSpec { + return &ShardStrategy_UpdatePodSpec{Call: _m.Call.Return(_a0)} +} + +func (_m *ShardStrategy) OnUpdatePodSpec(pod *v1.PodSpec, containerName string, podIndex int) *ShardStrategy_UpdatePodSpec { + c := _m.On("UpdatePodSpec", pod, containerName, podIndex) + return &ShardStrategy_UpdatePodSpec{Call: c} +} + +func (_m *ShardStrategy) OnUpdatePodSpecMatch(matchers ...interface{}) *ShardStrategy_UpdatePodSpec { + c := _m.On("UpdatePodSpec", matchers...) 
+ return &ShardStrategy_UpdatePodSpec{Call: c} +} + +// UpdatePodSpec provides a mock function with given fields: pod, containerName, podIndex +func (_m *ShardStrategy) UpdatePodSpec(pod *v1.PodSpec, containerName string, podIndex int) error { + ret := _m.Called(pod, containerName, podIndex) + + var r0 error + if rf, ok := ret.Get(0).(func(*v1.PodSpec, string, int) error); ok { + r0 = rf(pod, containerName, podIndex) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/manager/shardstrategy/shard_strategy.go b/manager/shardstrategy/shard_strategy.go new file mode 100644 index 000000000..217539295 --- /dev/null +++ b/manager/shardstrategy/shard_strategy.go @@ -0,0 +1,98 @@ +package shardstrategy + +import ( + "bytes" + "context" + "encoding/gob" + "fmt" + "hash/fnv" + + "github.com/flyteorg/flytepropeller/manager/config" + "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + + v1 "k8s.io/api/core/v1" +) + +//go:generate mockery -name ShardStrategy -case=underscore + +// ShardStrategy defines necessary functionality for a sharding strategy. +type ShardStrategy interface { + // GetPodCount returns the total number of pods for the sharding strategy. + GetPodCount() int + // HashCode generates a unique hash code to identify shard strategy updates. + HashCode() (uint32, error) + // UpdatePodSpec amends the PodSpec for the specified index to include label selectors. + UpdatePodSpec(pod *v1.PodSpec, containerName string, podIndex int) error +} + +// NewShardStrategy creates and validates a new ShardStrategy defined by the configuration. 
+func NewShardStrategy(ctx context.Context, shardConfig config.ShardConfig) (ShardStrategy, error) { + switch shardConfig.Type { + case config.ShardTypeHash: + if shardConfig.ShardCount <= 0 { + return nil, fmt.Errorf("configured ShardCount (%d) must be greater than zero", shardConfig.ShardCount) + } else if shardConfig.ShardCount > v1alpha1.ShardKeyspaceSize { + return nil, fmt.Errorf("configured ShardCount (%d) is larger than available keyspace size (%d)", shardConfig.ShardCount, v1alpha1.ShardKeyspaceSize) + } + + return &HashShardStrategy{ + ShardCount: shardConfig.ShardCount, + }, nil + case config.ShardTypeProject, config.ShardTypeDomain: + perShardIDs := make([][]string, 0) + wildcardIDFound := false + for _, perShardMapping := range shardConfig.PerShardMappings { + if len(perShardMapping.IDs) == 0 { + return nil, fmt.Errorf("unable to create shard with 0 configured ids") + } + + // validate wildcard ID + for _, id := range perShardMapping.IDs { + if id == "*" { + if len(perShardMapping.IDs) != 1 { + return nil, fmt.Errorf("shards responsible for the wildcard id (ie. '*') may only contain one id") + } + + if wildcardIDFound { + return nil, fmt.Errorf("may only define one shard responsible for the wildcard id (ie. 
'*')") + } + + wildcardIDFound = true + } + } + + perShardIDs = append(perShardIDs, perShardMapping.IDs) + } + + var envType environmentType + switch shardConfig.Type { + case config.ShardTypeProject: + envType = Project + case config.ShardTypeDomain: + envType = Domain + } + + return &EnvironmentShardStrategy{ + EnvType: envType, + PerShardIDs: perShardIDs, + }, nil + } + + return nil, fmt.Errorf("shard strategy '%s' does not exist", shardConfig.Type) +} + +func computeHashCode(data interface{}) (uint32, error) { + hash := fnv.New32a() + + buffer := new(bytes.Buffer) + encoder := gob.NewEncoder(buffer) + if err := encoder.Encode(data); err != nil { + return 0, err + } + + if _, err := hash.Write(buffer.Bytes()); err != nil { + return 0, err + } + + return hash.Sum32(), nil +} diff --git a/manager/shardstrategy/shard_strategy_test.go b/manager/shardstrategy/shard_strategy_test.go new file mode 100644 index 000000000..5be9cfddc --- /dev/null +++ b/manager/shardstrategy/shard_strategy_test.go @@ -0,0 +1,161 @@ +package shardstrategy + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + v1 "k8s.io/api/core/v1" +) + +var ( + hashShardStrategy = &HashShardStrategy{ + ShardCount: 3, + } + + projectShardStrategy = &EnvironmentShardStrategy{ + EnvType: Project, + PerShardIDs: [][]string{ + []string{"flytesnacks"}, + []string{"flytefoo", "flytebar"}, + }, + } + + projectShardStrategyWildcard = &EnvironmentShardStrategy{ + EnvType: Project, + PerShardIDs: [][]string{ + []string{"flytesnacks"}, + []string{"flytefoo", "flytebar"}, + []string{"*"}, + }, + } + + domainShardStrategy = &EnvironmentShardStrategy{ + EnvType: Domain, + PerShardIDs: [][]string{ + []string{"production"}, + []string{"foo", "bar"}, + }, + } + + domainShardStrategyWildcard = &EnvironmentShardStrategy{ + EnvType: Domain, + PerShardIDs: [][]string{ + []string{"production"}, + []string{"foo", "bar"}, + []string{"*"}, + }, + } +) + +func TestGetPodCount(t *testing.T) { + tests := []struct { + 
name string + shardStrategy ShardStrategy + podCount int + }{ + {"hash", hashShardStrategy, 3}, + {"project", projectShardStrategy, 2}, + {"project_wildcard", projectShardStrategyWildcard, 3}, + {"domain", domainShardStrategy, 2}, + {"domain_wildcard", domainShardStrategyWildcard, 3}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.podCount, tt.shardStrategy.GetPodCount()) + }) + } +} + +func TestUpdatePodSpec(t *testing.T) { + t.Parallel() + tests := []struct { + name string + shardStrategy ShardStrategy + }{ + {"hash", hashShardStrategy}, + {"project", projectShardStrategy}, + {"project_wildcard", projectShardStrategyWildcard}, + {"domain", domainShardStrategy}, + {"domain_wildcard", domainShardStrategyWildcard}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for podIndex := 0; podIndex < tt.shardStrategy.GetPodCount(); podIndex++ { + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "flytepropeller", + }, + }, + } + + err := tt.shardStrategy.UpdatePodSpec(&podSpec, "flytepropeller", podIndex) + assert.NoError(t, err) + } + }) + } +} + +func TestUpdatePodSpecInvalidPodIndex(t *testing.T) { + t.Parallel() + tests := []struct { + name string + shardStrategy ShardStrategy + }{ + {"hash", hashShardStrategy}, + {"project", projectShardStrategy}, + {"project_wildcard", projectShardStrategyWildcard}, + {"domain", domainShardStrategy}, + {"domain_wildcard", domainShardStrategyWildcard}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "flytepropeller", + }, + }, + } + + lowerErr := tt.shardStrategy.UpdatePodSpec(&podSpec, "flytepropeller", -1) + assert.Error(t, lowerErr) + + upperErr := tt.shardStrategy.UpdatePodSpec(&podSpec, "flytepropeller", tt.shardStrategy.GetPodCount()) + assert.Error(t, upperErr) + }) + } +} + +func TestUpdatePodSpecInvalidPodSpec(t *testing.T) { 
+ t.Parallel() + tests := []struct { + name string + shardStrategy ShardStrategy + }{ + {"hash", hashShardStrategy}, + {"project", projectShardStrategy}, + {"project_wildcard", projectShardStrategyWildcard}, + {"domain", domainShardStrategy}, + {"domain_wildcard", domainShardStrategyWildcard}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "flytefoo", + }, + }, + } + + err := tt.shardStrategy.UpdatePodSpec(&podSpec, "flytepropeller", 0) + assert.Error(t, err) + }) + } +} diff --git a/pkg/controller/config/config.go b/pkg/controller/config/config.go index 156a74786..dce67eecb 100644 --- a/pkg/controller/config/config.go +++ b/pkg/controller/config/config.go @@ -137,6 +137,12 @@ type Config struct { NodeConfig NodeConfig `json:"node-config,omitempty" pflag:",config for a workflow node"` MaxStreakLength int `json:"max-streak-length" pflag:",Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled."` EventConfig EventConfig `json:"event-config,omitempty" pflag:",Configures execution event behavior."` + IncludeShardKeyLabel []string `json:"include-shard-key-label" pflag:",Include the specified shard key label in the k8s FlyteWorkflow CRD label selector"` + ExcludeShardKeyLabel []string `json:"exclude-shard-key-label" pflag:",Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector"` + IncludeProjectLabel []string `json:"include-project-label" pflag:",Include the specified project label in the k8s FlyteWorkflow CRD label selector"` + ExcludeProjectLabel []string `json:"exclude-project-label" pflag:",Exclude the specified project label from the k8s FlyteWorkflow CRD label selector"` + IncludeDomainLabel []string `json:"include-domain-label" pflag:",Include the specified domain label in the k8s FlyteWorkflow CRD label selector"` + ExcludeDomainLabel []string 
`json:"exclude-domain-label" pflag:",Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector"` } // KubeClientConfig contains the configuration used by flytepropeller to configure its internal Kubernetes Client. diff --git a/pkg/controller/config/config_flags.go b/pkg/controller/config/config_flags.go index d0590612b..039161f9a 100755 --- a/pkg/controller/config/config_flags.go +++ b/pkg/controller/config/config_flags.go @@ -95,5 +95,11 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "max-streak-length"), defaultConfig.MaxStreakLength, "Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "event-config.raw-output-policy"), defaultConfig.EventConfig.RawOutputPolicy, "How output data should be passed along in execution events.") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "event-config.fallback-to-output-reference"), defaultConfig.EventConfig.FallbackToOutputReference, "Whether output data should be sent by reference when it is too large to be sent inline in execution events.") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "include-shard-key-label"), []string{}, "Include the specified shard key label in the k8s FlyteWorkflow CRD label selector") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "exclude-shard-key-label"), []string{}, "Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "include-project-label"), []string{}, "Include the specified project label in the k8s FlyteWorkflow CRD label selector") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "exclude-project-label"), []string{}, "Exclude the specified project label from the k8s FlyteWorkflow CRD label selector") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "include-domain-label"), []string{}, "Include the specified 
domain label in the k8s FlyteWorkflow CRD label selector") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "exclude-domain-label"), []string{}, "Exclude the specified domain label from the k8s FlyteWorkflow CRD label selector") return cmdFlags } diff --git a/pkg/controller/config/config_flags_test.go b/pkg/controller/config/config_flags_test.go index 7b1ca36d9..4b9ed3afe 100755 --- a/pkg/controller/config/config_flags_test.go +++ b/pkg/controller/config/config_flags_test.go @@ -729,4 +729,88 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_include-shard-key-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + + cmdFlags.Set("include-shard-key-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("include-shard-key-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.IncludeShardKeyLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_exclude-shard-key-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + + cmdFlags.Set("exclude-shard-key-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("exclude-shard-key-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.ExcludeShardKeyLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_include-project-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + + cmdFlags.Set("include-project-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("include-project-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.IncludeProjectLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_exclude-project-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + 
+ cmdFlags.Set("exclude-project-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("exclude-project-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.ExcludeProjectLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_include-domain-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + + cmdFlags.Set("include-domain-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("include-domain-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.IncludeDomainLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_exclude-domain-label", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_Config("1,1", ",") + + cmdFlags.Set("exclude-domain-label", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("exclude-domain-label"); err == nil { + testDecodeRaw_Config(t, join_Config(vStringSlice, ","), &actual.ExcludeDomainLabel) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 771e2dbf1..ef99e0c78 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -31,24 +31,22 @@ import ( "github.com/flyteorg/flytestdlib/storage" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" - corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/clock" "k8s.io/client-go/kubernetes" - "k8s.io/client-go/kubernetes/scheme" - typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" "k8s.io/client-go/tools/cache" "k8s.io/client-go/tools/leaderelection" "k8s.io/client-go/tools/record" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" clientset "github.com/flyteorg/flytepropeller/pkg/client/clientset/versioned" - flyteScheme "github.com/flyteorg/flytepropeller/pkg/client/clientset/versioned/scheme" 
informers "github.com/flyteorg/flytepropeller/pkg/client/informers/externalversions" lister "github.com/flyteorg/flytepropeller/pkg/client/listers/flyteworkflow/v1alpha1" "github.com/flyteorg/flytepropeller/pkg/controller/nodes" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/recovery" "github.com/flyteorg/flytepropeller/pkg/controller/nodes/subworkflow/launchplan" "github.com/flyteorg/flytepropeller/pkg/controller/workflow" + leader "github.com/flyteorg/flytepropeller/pkg/leaderelection" + "github.com/flyteorg/flytepropeller/pkg/utils" ) const resourceLevelMonitorCycleDuration = 5 * time.Second @@ -287,23 +285,6 @@ func newControllerMetrics(scope promutils.Scope) *metrics { } } -func newK8sEventRecorder(ctx context.Context, kubeclientset kubernetes.Interface, publishK8sEvents bool) (record.EventRecorder, error) { - // Create event broadcaster - // Add FlyteWorkflow controller types to the default Kubernetes Scheme so Events can be - // logged for FlyteWorkflow Controller types. 
- err := flyteScheme.AddToScheme(scheme.Scheme) - if err != nil { - return nil, err - } - logger.Info(ctx, "Creating event broadcaster") - eventBroadcaster := record.NewBroadcaster() - eventBroadcaster.StartLogging(logger.InfofNoCtx) - if publishK8sEvents { - eventBroadcaster.StartRecordingToSink(&typedcorev1.EventSinkImpl{Interface: kubeclientset.CoreV1().Events("")}) - } - return eventBroadcaster.NewRecorder(scheme.Scheme, corev1.EventSource{Component: controllerAgentName}), nil -} - func getAdminClient(ctx context.Context) (client service.AdminServiceClient, err error) { cfg := admin.GetConfig(ctx) clients, err := admin.NewClientsetBuilder().WithConfig(cfg).Build(ctx) @@ -351,7 +332,7 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter return nil, errors.Wrapf(err, "failed to initialize WF GC") } - eventRecorder, err := newK8sEventRecorder(ctx, kubeclientset, cfg.PublishK8sEvents) + eventRecorder, err := utils.NewK8sEventRecorder(ctx, kubeclientset, controllerAgentName, cfg.PublishK8sEvents) if err != nil { logger.Errorf(ctx, "failed to event recorder %v", err) return nil, errors.Wrapf(err, "failed to initialize resource lock.") @@ -363,7 +344,7 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter numWorkers: cfg.Workers, } - lock, err := newResourceLock(kubeclientset.CoreV1(), kubeclientset.CoordinationV1(), eventRecorder, cfg.LeaderElection) + lock, err := leader.NewResourceLock(kubeclientset.CoreV1(), kubeclientset.CoordinationV1(), eventRecorder, cfg.LeaderElection) if err != nil { logger.Errorf(ctx, "failed to initialize resource lock.") return nil, errors.Wrapf(err, "failed to initialize resource lock.") @@ -371,7 +352,7 @@ func New(ctx context.Context, cfg *config.Config, kubeclientset kubernetes.Inter if lock != nil { logger.Infof(ctx, "Creating leader elector for the controller.") - controller.leaderElector, err = newLeaderElector(lock, cfg.LeaderElection, controller.onStartedLeading, func() 
{ + controller.leaderElector, err = leader.NewLeaderElector(lock, cfg.LeaderElection, controller.onStartedLeading, func() { logger.Fatal(ctx, "Lost leader state. Shutting down.") }) diff --git a/pkg/controller/leaderelection.go b/pkg/leaderelection/leader_election.go similarity index 94% rename from pkg/controller/leaderelection.go rename to pkg/leaderelection/leader_election.go index 251409e26..acbe5dc80 100644 --- a/pkg/controller/leaderelection.go +++ b/pkg/leaderelection/leader_election.go @@ -28,7 +28,7 @@ const ( ) // NewResourceLock creates a new config map resource lock for use in a leader election loop -func newResourceLock(corev1 v1.CoreV1Interface, coordinationV1 v12.CoordinationV1Interface, eventRecorder record.EventRecorder, options config.LeaderElectionConfig) ( +func NewResourceLock(corev1 v1.CoreV1Interface, coordinationV1 v12.CoordinationV1Interface, eventRecorder record.EventRecorder, options config.LeaderElectionConfig) ( resourcelock.Interface, error) { if !options.Enabled { @@ -66,7 +66,7 @@ func getUniqueLeaderID() string { return fmt.Sprintf("%v_%v", id, rand.String(10)) } -func newLeaderElector(lock resourcelock.Interface, cfg config.LeaderElectionConfig, +func NewLeaderElector(lock resourcelock.Interface, cfg config.LeaderElectionConfig, leaderFn func(ctx context.Context), leaderStoppedFn func()) (*leaderelection.LeaderElector, error) { return leaderelection.NewLeaderElector(leaderelection.LeaderElectionConfig{ Lock: lock, diff --git a/pkg/utils/k8s.go b/pkg/utils/k8s.go index 3d1da706f..b1ce78c2f 100644 --- a/pkg/utils/k8s.go +++ b/pkg/utils/k8s.go @@ -1,11 +1,17 @@ package utils import ( + "context" + "fmt" + "os" "regexp" "strings" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" + flyteScheme "github.com/flyteorg/flytepropeller/pkg/client/clientset/versioned/scheme" + "github.com/flyteorg/flytepropeller/pkg/controller/config" + 
"github.com/flyteorg/flytestdlib/logger" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/timestamp" "github.com/pkg/errors" @@ -13,6 +19,12 @@ import ( "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/scheme" + typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + "k8s.io/client-go/tools/record" ) var NotTheOwnerError = errors.Errorf("FlytePropeller is not the owner") @@ -84,6 +96,44 @@ func ToK8sResourceRequirements(resources *core.Resources) (*v1.ResourceRequireme return res, nil } +// GetContainer searches the provided pod spec for a container with the specified name +func GetContainer(pod *v1.PodSpec, containerName string) (*v1.Container, error) { + for i := 0; i < len(pod.Containers); i++ { + if pod.Containers[i].Name == containerName { + return &pod.Containers[i], nil + } + } + + return nil, fmt.Errorf("container '%s' not found in podtemplate, ", containerName) +} + +func GetKubeConfig(_ context.Context, cfg *config.Config) (*kubernetes.Clientset, *restclient.Config, error) { + var kubecfg *restclient.Config + var err error + if cfg.KubeConfigPath != "" { + kubeConfigPath := os.ExpandEnv(cfg.KubeConfigPath) + kubecfg, err = clientcmd.BuildConfigFromFlags(cfg.MasterURL, kubeConfigPath) + if err != nil { + return nil, nil, errors.Wrapf(err, "Error building kubeconfig") + } + } else { + kubecfg, err = restclient.InClusterConfig() + if err != nil { + return nil, nil, errors.Wrapf(err, "Cannot get InCluster kubeconfig") + } + } + + kubecfg.QPS = cfg.KubeConfig.QPS + kubecfg.Burst = cfg.KubeConfig.Burst + kubecfg.Timeout = cfg.KubeConfig.Timeout.Duration + + kubeClient, err := kubernetes.NewForConfig(kubecfg) + if err != nil { + return nil, nil, errors.Wrapf(err, "Error building kubernetes clientset") + } + return kubeClient, 
kubecfg, err +} + func GetWorkflowIDFromOwner(reference *metav1.OwnerReference, namespace string) (v1alpha1.WorkflowID, error) { if reference == nil { return "", NotTheOwnerError @@ -113,3 +163,20 @@ func SanitizeLabelValue(name string) string { } return strings.Trim(name, "-") } + +func NewK8sEventRecorder(ctx context.Context, kubeclientset kubernetes.Interface, controllerAgentName string, publishK8sEvents bool) (record.EventRecorder, error) { + // Create event broadcaster + // Add FlyteWorkflow controller types to the default Kubernetes Scheme so Events can be + // logged for FlyteWorkflow Controller types. + err := flyteScheme.AddToScheme(scheme.Scheme) + if err != nil { + return nil, err + } + logger.Info(ctx, "Creating event broadcaster") + eventBroadcaster := record.NewBroadcaster() + eventBroadcaster.StartLogging(logger.InfofNoCtx) + if publishK8sEvents { + eventBroadcaster.StartRecordingToSink(&typedcorev1.EventSinkImpl{Interface: kubeclientset.CoreV1().Events("")}) + } + return eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: controllerAgentName}), nil +}