From d1a254c495a1ea4b7feff40da4d9d592a73f3199 Mon Sep 17 00:00:00 2001 From: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> Date: Fri, 14 Jul 2023 15:35:30 +0200 Subject: [PATCH 1/8] Add failing test Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> --- go/tasks/plugins/k8s/dask/dask_test.go | 62 ++++++++++++++++++++------ 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/go/tasks/plugins/k8s/dask/dask_test.go b/go/tasks/plugins/k8s/dask/dask_test.go index 31115d06f..80eac321d 100644 --- a/go/tasks/plugins/k8s/dask/dask_test.go +++ b/go/tasks/plugins/k8s/dask/dask_test.go @@ -27,9 +27,11 @@ import ( ) const ( - defaultTestImage = "image://" - testNWorkers = 10 - testTaskID = "some-acceptable-name" + defaultTestImage = "image://" + testNWorkers = 10 + testTaskID = "some-acceptable-name" + podTemplateName = "dask-dummy-pod-template-name" + serviceAccountName = "dummy-service-account" ) var ( @@ -53,6 +55,16 @@ var ( Requests: testPlatformResources.Requests, Limits: testPlatformResources.Requests, } + podTemplate = &v1.PodTemplate{ + ObjectMeta: metav1.ObjectMeta{ + Name: podTemplateName, + }, + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + ServiceAccountName: serviceAccountName, + }, + }, + } ) func dummyDaskJob(status daskAPI.JobStatus) *daskAPI.DaskJob { @@ -90,7 +102,7 @@ func dummpyDaskCustomObj(customImage string, resources *core.Resources) *plugins return &daskJob } -func dummyDaskTaskTemplate(customImage string, resources *core.Resources) *core.TaskTemplate { +func dummyDaskTaskTemplate(customImage string, resources *core.Resources, podTemplateName string) *core.TaskTemplate { // In a real usecase, resources will always be filled, but might be empty if resources == nil { resources = &core.Resources{ @@ -114,9 +126,13 @@ func dummyDaskTaskTemplate(customImage string, resources *core.Resources) *core. 
for _, envVar := range testEnvVars { envVars = append(envVars, &core.KeyValuePair{Key: envVar.Name, Value: envVar.Value}) } + metadata := &core.TaskMetadata{ + PodTemplateName: podTemplateName, + } return &core.TaskTemplate{ - Id: &core.Identifier{Name: "test-build-resource"}, - Type: daskTaskType, + Id: &core.Identifier{Name: "test-build-resource"}, + Type: daskTaskType, + Metadata: metadata, Target: &core.TaskTemplate_Container{ Container: &core.Container{ Image: defaultTestImage, @@ -179,7 +195,7 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc func TestBuildResourceDaskHappyPath(t *testing.T) { daskResourceHandler := daskResourceHandler{} - taskTemplate := dummyDaskTaskTemplate("", nil) + taskTemplate := dummyDaskTaskTemplate("", nil, "") taskContext := dummyDaskTaskContext(taskTemplate, nil, false) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -287,7 +303,7 @@ func TestBuildResourceDaskCustomImages(t *testing.T) { customImage := "customImage" daskResourceHandler := daskResourceHandler{} - taskTemplate := dummyDaskTaskTemplate(customImage, nil) + taskTemplate := dummyDaskTaskTemplate(customImage, nil, "") taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, false) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -320,7 +336,7 @@ func TestBuildResourceDaskDefaultResoureRequirements(t *testing.T) { } daskResourceHandler := daskResourceHandler{} - taskTemplate := dummyDaskTaskTemplate("", nil) + taskTemplate := dummyDaskTaskTemplate("", nil, "") taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, false) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -377,7 +393,7 @@ func TestBuildResourcesDaskCustomResoureRequirements(t *testing.T) { } daskResourceHandler := daskResourceHandler{} - taskTemplate := dummyDaskTaskTemplate("", &protobufResources) + taskTemplate := dummyDaskTaskTemplate("", &protobufResources, "") taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, false) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -432,7 +448,7 @@ func TestBuildResourceDaskInterruptible(t *testing.T) { daskResourceHandler := daskResourceHandler{} - taskTemplate := dummyDaskTaskTemplate("", nil) + taskTemplate := dummyDaskTaskTemplate("", nil, "") taskContext := dummyDaskTaskContext(taskTemplate, nil, true) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -446,7 +462,7 @@ func TestBuildResourceDaskInterruptible(t *testing.T) { assert.Equal(t, defaultNodeSelector, jobSpec.NodeSelector) assert.Equal(t, defaultAffinity, jobSpec.Affinity) - // Scheduler - should not bt interruptible + // Scheduler - should not be interruptible schedulerSpec := daskJob.Spec.Cluster.Spec.Scheduler.Spec assert.Equal(t, defaultTolerations, schedulerSpec.Tolerations) assert.Equal(t, defaultNodeSelector, schedulerSpec.NodeSelector) @@ -463,6 +479,26 @@ func TestBuildResourceDaskInterruptible(t *testing.T) { ) } +func TestBuildResouceDaskUsePodTemplate(t *testing.T) { + flytek8s.DefaultPodTemplateStore.Store(podTemplate) + daskResourceHandler := daskResourceHandler{} + taskTemplate := dummyDaskTaskTemplate("", nil, podTemplateName) + taskContext := dummyDaskTaskContext(taskTemplate, nil, false) + r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + 
assert.NotNil(t, r) + daskJob, ok := r.(*daskAPI.DaskJob) + assert.True(t, ok) + + // The job template has a custom service account set. This should be passed on to all three components + assert.Equal(t, serviceAccountName, daskJob.Spec.Job.Spec.ServiceAccountName) + assert.Equal(t, serviceAccountName, daskJob.Spec.Cluster.Spec.Scheduler.Spec.ServiceAccountName) + assert.Equal(t, serviceAccountName, daskJob.Spec.Cluster.Spec.Worker.Spec.ServiceAccountName) + + // Cleanup + flytek8s.DefaultPodTemplateStore.Delete(podTemplate) +} + func TestGetPropertiesDask(t *testing.T) { daskResourceHandler := daskResourceHandler{} expected := k8s.PluginProperties{} @@ -478,7 +514,7 @@ func TestBuildIdentityResourceDask(t *testing.T) { }, } - taskTemplate := dummyDaskTaskTemplate("", nil) + taskTemplate := dummyDaskTaskTemplate("", nil, "") taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, false) identityResources, err := daskResourceHandler.BuildIdentityResource(context.TODO(), taskContext.TaskExecutionMetadata()) if err != nil { From 7de57e9c87d1d73594e2bd3421068c27ae37d888 Mon Sep 17 00:00:00 2001 From: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> Date: Fri, 14 Jul 2023 16:22:10 +0200 Subject: [PATCH 2/8] WIP Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> --- go/tasks/plugins/k8s/dask/dask.go | 222 ++++++++++-------------------- 1 file changed, 70 insertions(+), 152 deletions(-) diff --git a/go/tasks/plugins/k8s/dask/dask.go b/go/tasks/plugins/k8s/dask/dask.go index 5db33aedf..60d885e84 100755 --- a/go/tasks/plugins/k8s/dask/dask.go +++ b/go/tasks/plugins/k8s/dask/dask.go @@ -6,15 +6,12 @@ import ( "time" daskAPI "github.com/dask/dask-kubernetes/v2023/dask_kubernetes/operator/go_client/pkg/apis/kubernetes.dask.org/v1" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/logs" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" @@ -30,68 +27,13 @@ const ( KindDaskJob = "DaskJob" ) -type defaults struct { - Image string - JobRunnerContainer v1.Container - Resources *v1.ResourceRequirements - Env []v1.EnvVar - Annotations map[string]string - IsInterruptible bool -} - -func getDefaults(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, taskTemplate core.TaskTemplate) (*defaults, error) { - executionMetadata := taskCtx.TaskExecutionMetadata() - - defaultContainerSpec := taskTemplate.GetContainer() - if defaultContainerSpec == nil { - return nil, errors.Errorf(errors.BadTaskSpecification, "task is missing a default container") - } - - defaultImage := defaultContainerSpec.GetImage() - if defaultImage == "" { - return nil, errors.Errorf(errors.BadTaskSpecification, "task is missing a default image") - } - - var defaultEnvVars []v1.EnvVar - if taskTemplate.GetContainer().GetEnv() != nil { - for _, keyValuePair := range taskTemplate.GetContainer().GetEnv() { - 
defaultEnvVars = append(defaultEnvVars, v1.EnvVar{Name: keyValuePair.Key, Value: keyValuePair.Value}) +func getPrimaryContainer(spec *v1.PodSpec, primaryContainerName string) *v1.Container { + for _, container := range spec.Containers { + if container.Name == primaryContainerName { + return &container } } - - containerResources, err := flytek8s.ToK8sResourceRequirements(defaultContainerSpec.GetResources()) - if err != nil { - return nil, err - } - - jobRunnerContainer := v1.Container{ - Name: "job-runner", - Image: defaultImage, - Args: defaultContainerSpec.GetArgs(), - Env: defaultEnvVars, - Resources: *containerResources, - } - - templateParameters := template.Parameters{ - TaskExecMetadata: taskCtx.TaskExecutionMetadata(), - Inputs: taskCtx.InputReader(), - OutputPath: taskCtx.OutputWriter(), - Task: taskCtx.TaskReader(), - } - if err = flytek8s.AddFlyteCustomizationsToContainer(ctx, templateParameters, - flytek8s.ResourceCustomizationModeMergeExistingResources, &jobRunnerContainer); err != nil { - - return nil, err - } - - return &defaults{ - Image: defaultImage, - JobRunnerContainer: jobRunnerContainer, - Resources: &jobRunnerContainer.Resources, - Env: defaultEnvVars, - Annotations: executionMetadata.GetAnnotations(), - IsInterruptible: executionMetadata.IsInterruptible(), - }, nil + return nil } type daskResourceHandler struct { @@ -114,11 +56,6 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC } else if taskTemplate == nil { return nil, errors.Errorf(errors.BadTaskSpecification, "nil task specification") } - defaults, err := getDefaults(ctx, taskCtx, *taskTemplate) - if err != nil { - return nil, err - } - clusterName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() daskJob := plugins.DaskJob{} err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &daskJob) @@ -126,38 +63,45 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC return nil, errors.Wrapf(errors.BadTaskSpecification, err, "invalid TaskSpecification [%v], failed to unmarshal", taskTemplate.GetCustom()) } - workerSpec, err := createWorkerSpec(*daskJob.Workers, *defaults) + podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + + workerSpec, err := createWorkerSpec(*daskJob.Workers, podSpec, primaryContainerName) if err != nil { return nil, err } - schedulerSpec, err := createSchedulerSpec(*daskJob.Scheduler, clusterName, *defaults) + + clusterName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + schedulerSpec, err := createSchedulerSpec(*daskJob.Scheduler, clusterName, podSpec, primaryContainerName) if err != nil { return nil, err } - jobSpec := createJobSpec(*workerSpec, *schedulerSpec, *defaults) + + jobSpec := createJobSpec(*workerSpec, *schedulerSpec, podSpec, objectMeta) job := &daskAPI.DaskJob{ TypeMeta: metav1.TypeMeta{ Kind: KindDaskJob, APIVersion: daskAPI.SchemeGroupVersion.String(), }, - ObjectMeta: metav1.ObjectMeta{ - Name: "will-be-overridden", // Will be overridden by Flyte to `clusterName` - Annotations: defaults.Annotations, - }, - Spec: *jobSpec, + ObjectMeta: *objectMeta, + Spec: *jobSpec, } return job, nil } -func createWorkerSpec(cluster plugins.DaskWorkerGroup, defaults defaults) (*daskAPI.WorkerSpec, error) { - image := defaults.Image +func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, primaryContainerName string) (*daskAPI.WorkerSpec, error) { + workerPodSpec := podSpec.DeepCopy() + primaryContainer := 
getPrimaryContainer(workerPodSpec, primaryContainerName) + primaryContainer.Name = "dask-worker" + + // Set custom image if present if cluster.GetImage() != "" { - image = cluster.GetImage() + primaryContainer.Image = cluster.GetImage() } + // Set custom resources var err error - resources := defaults.Resources + resources := &primaryContainer.Resources clusterResources := cluster.GetResources() if len(clusterResources.Requests) >= 1 || len(clusterResources.Limits) >= 1 { resources, err = flytek8s.ToK8sResourceRequirements(cluster.GetResources()) @@ -165,16 +109,14 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, defaults defaults) (*dask return nil, err } } - if resources == nil { - resources = &v1.ResourceRequirements{} - } + primaryContainer.Resources = *resources + // Set custom args workerArgs := []string{ "dask-worker", "--name", "$(DASK_WORKER_NAME)", } - // If limits are set, append `--nthreads` and `--memory-limit` as per these docs: // https://kubernetes.dask.org/en/latest/kubecluster.html?#best-practices if resources != nil && resources.Limits != nil { @@ -188,83 +130,63 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, defaults defaults) (*dask workerArgs = append(workerArgs, "--memory-limit", memory) } } - - wokerSpec := v1.PodSpec{ - Affinity: &v1.Affinity{}, - Containers: []v1.Container{ - { - Name: "dask-worker", - Image: image, - ImagePullPolicy: v1.PullIfNotPresent, - Args: workerArgs, - Resources: *resources, - Env: defaults.Env, - }, - }, - } - - if defaults.IsInterruptible { - wokerSpec.Tolerations = append(wokerSpec.Tolerations, config.GetK8sPluginConfig().InterruptibleTolerations...) - wokerSpec.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector - } - flytek8s.ApplyInterruptibleNodeSelectorRequirement(defaults.IsInterruptible, wokerSpec.Affinity) + primaryContainer.Args = workerArgs return &daskAPI.WorkerSpec{ Replicas: int(cluster.GetNumberOfWorkers()), - Spec: wokerSpec, + Spec: *workerPodSpec, }, nil } -func createSchedulerSpec(cluster plugins.DaskScheduler, clusterName string, defaults defaults) (*daskAPI.SchedulerSpec, error) { - schedulerImage := defaults.Image - if cluster.GetImage() != "" { - schedulerImage = cluster.GetImage() +func createSchedulerSpec(scheduler plugins.DaskScheduler, clusterName string, podSpec *v1.PodSpec, primaryContainerName string) (*daskAPI.SchedulerSpec, error) { + schedulerPodSpec := podSpec.DeepCopy() + primaryContainer := getPrimaryContainer(schedulerPodSpec, primaryContainerName) + primaryContainer.Name = "scheduler" + + // Override image if applicable + if scheduler.GetImage() != "" { + primaryContainer.Image = scheduler.GetImage() } + // Override resources if applicable var err error - resources := defaults.Resources - - clusterResources := cluster.GetResources() - if len(clusterResources.Requests) >= 1 || len(clusterResources.Limits) >= 1 { - resources, err = flytek8s.ToK8sResourceRequirements(cluster.GetResources()) + resources := &primaryContainer.Resources + schedulerResources := scheduler.GetResources() + if len(schedulerResources.Requests) >= 1 || len(schedulerResources.Limits) >= 1 { + resources, err = flytek8s.ToK8sResourceRequirements(scheduler.GetResources()) if err != nil { return nil, err } } - if resources == nil { - resources = &v1.ResourceRequirements{} + primaryContainer.Resources = *resources + + // Override args + primaryContainer.Args = []string{"dask-scheduler"} + + // Add ports + primaryContainer.Ports = []v1.ContainerPort{ + { + Name: "tcp-comm", + ContainerPort: 8786, + 
Protocol: "TCP", + }, + { + Name: "dashboard", + ContainerPort: 8787, + Protocol: "TCP", + }, } + // Set restart policy + schedulerPodSpec.RestartPolicy = v1.RestartPolicyAlways + return &daskAPI.SchedulerSpec{ - Spec: v1.PodSpec{ - RestartPolicy: v1.RestartPolicyAlways, - Containers: []v1.Container{ - { - Name: "scheduler", - Image: schedulerImage, - Args: []string{"dask-scheduler"}, - Resources: *resources, - Env: defaults.Env, - Ports: []v1.ContainerPort{ - { - Name: "tcp-comm", - ContainerPort: 8786, - Protocol: "TCP", - }, - { - Name: "dashboard", - ContainerPort: 8787, - Protocol: "TCP", - }, - }, - }, - }, - }, + Spec: *schedulerPodSpec, Service: v1.ServiceSpec{ Type: v1.ServiceTypeNodePort, Selector: map[string]string{ - "dask.org/cluster-name": clusterName, - "dask.org/component": "scheduler", + "dask.org/scheduler-name": clusterName, + "dask.org/component": "scheduler", }, Ports: []v1.ServicePort{ { @@ -284,20 +206,16 @@ func createSchedulerSpec(cluster plugins.DaskScheduler, clusterName string, defa }, nil } -func createJobSpec(workerSpec daskAPI.WorkerSpec, schedulerSpec daskAPI.SchedulerSpec, defaults defaults) *daskAPI.DaskJobSpec { +func createJobSpec(workerSpec daskAPI.WorkerSpec, schedulerSpec daskAPI.SchedulerSpec, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta) *daskAPI.DaskJobSpec { + jobPodSpec := podSpec.DeepCopy() + jobPodSpec.RestartPolicy = v1.RestartPolicyNever + return &daskAPI.DaskJobSpec{ Job: daskAPI.JobSpec{ - Spec: v1.PodSpec{ - RestartPolicy: v1.RestartPolicyNever, - Containers: []v1.Container{ - defaults.JobRunnerContainer, - }, - }, + Spec: *jobPodSpec, }, Cluster: daskAPI.DaskCluster{ - ObjectMeta: metav1.ObjectMeta{ - Annotations: defaults.Annotations, - }, + ObjectMeta: *objectMeta, Spec: daskAPI.DaskClusterSpec{ Worker: workerSpec, Scheduler: schedulerSpec, From 19f073624d2c1ca9b787984a6d9f8946d3c3a32c Mon Sep 17 00:00:00 2001 From: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> Date: Sat, 15 Jul 2023 15:21:31 +0200 Subject: [PATCH 3/8] Improve test Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> --- go/tasks/plugins/k8s/dask/dask_test.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/go/tasks/plugins/k8s/dask/dask_test.go b/go/tasks/plugins/k8s/dask/dask_test.go index 80eac321d..224cb4e5b 100644 --- a/go/tasks/plugins/k8s/dask/dask_test.go +++ b/go/tasks/plugins/k8s/dask/dask_test.go @@ -27,11 +27,11 @@ import ( ) const ( - defaultTestImage = "image://" - testNWorkers = 10 - testTaskID = "some-acceptable-name" - podTemplateName = "dask-dummy-pod-template-name" - serviceAccountName = "dummy-service-account" + defaultTestImage = "image://" + testNWorkers = 10 + testTaskID = "some-acceptable-name" + podTemplateName = "dask-dummy-pod-template-name" + templateServiceAccountName = "template-service-account" ) var ( @@ -61,7 +61,7 @@ var ( }, Template: v1.PodTemplateSpec{ Spec: v1.PodSpec{ - ServiceAccountName: serviceAccountName, + ServiceAccountName: templateServiceAccountName, }, }, } @@ -491,9 +491,9 @@ func TestBuildResouceDaskUsePodTemplate(t *testing.T) { assert.True(t, ok) // The job template has a custom service account set. 
This should be passed on to all three components - assert.Equal(t, serviceAccountName, daskJob.Spec.Job.Spec.ServiceAccountName) - assert.Equal(t, serviceAccountName, daskJob.Spec.Cluster.Spec.Scheduler.Spec.ServiceAccountName) - assert.Equal(t, serviceAccountName, daskJob.Spec.Cluster.Spec.Worker.Spec.ServiceAccountName) + assert.Equal(t, templateServiceAccountName, daskJob.Spec.Job.Spec.ServiceAccountName) + assert.Equal(t, templateServiceAccountName, daskJob.Spec.Cluster.Spec.Scheduler.Spec.ServiceAccountName) + assert.Equal(t, templateServiceAccountName, daskJob.Spec.Cluster.Spec.Worker.Spec.ServiceAccountName) // Cleanup flytek8s.DefaultPodTemplateStore.Delete(podTemplate) From 1bad96f3e33654c5b952aedf9cf2ce7278688fa5 Mon Sep 17 00:00:00 2001 From: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> Date: Sun, 16 Jul 2023 12:14:52 +0200 Subject: [PATCH 4/8] Refactor to use `ToK8sPodSpec` Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> --- go.mod | 1 + go.sum | 4 +- .../pluginmachinery/flytek8s/pod_helper.go | 2 +- go/tasks/plugins/k8s/dask/dask.go | 126 ++++++++++++++++-- go/tasks/plugins/k8s/dask/dask_test.go | 67 ++++++---- 5 files changed, 157 insertions(+), 43 deletions(-) diff --git a/go.mod b/go.mod index 80b06c9df..20a65870a 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/ray-project/kuberay/ray-operator v0.0.0-20220728052838-eaa75fa6707c github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.1 + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 golang.org/x/net v0.8.0 golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5 google.golang.org/api v0.76.0 diff --git a/go.sum b/go.sum index fd47f0ad0..0e165e063 100644 --- a/go.sum +++ b/go.sum @@ -735,6 +735,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -760,7 +762,7 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git 
a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 661527c1d..c1db4f38b 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -61,7 +61,7 @@ func ApplyInterruptibleNodeSelectorRequirement(interruptible bool, affinity *v1. nst.MatchExpressions = append(nst.MatchExpressions, nodeSelectorRequirement) } } else { - affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms = []v1.NodeSelectorTerm{v1.NodeSelectorTerm{MatchExpressions: []v1.NodeSelectorRequirement{nodeSelectorRequirement}}} + affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms = []v1.NodeSelectorTerm{{MatchExpressions: []v1.NodeSelectorRequirement{nodeSelectorRequirement}}} } } diff --git a/go/tasks/plugins/k8s/dask/dask.go b/go/tasks/plugins/k8s/dask/dask.go index 60d885e84..459fbbbfd 100755 --- a/go/tasks/plugins/k8s/dask/dask.go +++ b/go/tasks/plugins/k8s/dask/dask.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "golang.org/x/exp/slices" + daskAPI "github.com/dask/dask-kubernetes/v2023/dask_kubernetes/operator/go_client/pkg/apis/kubernetes.dask.org/v1" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyteplugins/go/tasks/errors" @@ -12,6 +14,7 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" @@ -27,13 +30,72 @@ const ( KindDaskJob = "DaskJob" ) -func getPrimaryContainer(spec *v1.PodSpec, primaryContainerName string) *v1.Container { +func mergeMapInto(src map[string]string, dst map[string]string) { + for key, value := range src { + dst[key] = value + } +} + +func getPrimaryContainer(spec *v1.PodSpec, primaryContainerName string) (*v1.Container, error) { for _, container := range spec.Containers { if container.Name == primaryContainerName { - return &container + return &container, nil + } + } + return nil, errors.Errorf(errors.BadTaskSpecification, "primary container [%v] not found in pod spec", primaryContainerName) +} + +func replacePrimaryContainer(spec *v1.PodSpec, primaryContainerName string, container v1.Container) error { + for i, c := range spec.Containers { + if c.Name == primaryContainerName { + spec.Containers[i] = container + return nil + } + } + return errors.Errorf(errors.BadTaskSpecification, "primary container [%v] not found in pod spec", primaryContainerName) +} + +func removeInterruptibleConfig(spec *v1.PodSpec, taskCtx pluginsCore.TaskExecutionContext) { + if !taskCtx.TaskExecutionMetadata().IsInterruptible() { + return + } + + // Tolerations + interruptlibleTolerations := config.GetK8sPluginConfig().InterruptibleTolerations + newTolerations := []v1.Toleration{} + for _, toleration := range spec.Tolerations { + if !slices.Contains(interruptlibleTolerations, toleration) { + newTolerations = append(newTolerations, toleration) + } + } + spec.Tolerations = newTolerations + + // Node selectors + interruptibleNodeSelector := config.GetK8sPluginConfig().InterruptibleNodeSelector + for key := range spec.NodeSelector { + if _, ok := interruptibleNodeSelector[key]; ok { + 
delete(spec.NodeSelector, key) + } + } + + // Node selector requirements + interruptibleNodeSelectorRequirements := config.GetK8sPluginConfig().InterruptibleNodeSelectorRequirement + nodeSelectorTerms := spec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms + for i := range nodeSelectorTerms { + nst := &nodeSelectorTerms[i] + matchExpressions := nst.MatchExpressions + newMatchExpressions := []v1.NodeSelectorRequirement{} + for _, matchExpression := range matchExpressions { + if !nodeSelectorRequirementsAreEqual(matchExpression, *interruptibleNodeSelectorRequirements) { + newMatchExpressions = append(newMatchExpressions, matchExpression) + } } + nst.MatchExpressions = newMatchExpressions } - return nil +} + +func nodeSelectorRequirementsAreEqual(a v1.NodeSelectorRequirement, b v1.NodeSelectorRequirement) bool { + return a.Key == b.Key && a.Operator == b.Operator && slices.Equal(a.Values, b.Values) } type daskResourceHandler struct { @@ -64,6 +126,15 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC } podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + if err != nil { + return nil, err + } + nonInterruptiblePodSpec := podSpec.DeepCopy() + removeInterruptibleConfig(nonInterruptiblePodSpec, taskCtx) + + // Add labels and annotations to objectMeta as they're not added by ToK8sPodSpec + mergeMapInto(taskCtx.TaskExecutionMetadata().GetAnnotations(), objectMeta.Annotations) + mergeMapInto(taskCtx.TaskExecutionMetadata().GetLabels(), objectMeta.Labels) workerSpec, err := createWorkerSpec(*daskJob.Workers, podSpec, primaryContainerName) if err != nil { @@ -71,12 +142,15 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC } clusterName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - schedulerSpec, err := createSchedulerSpec(*daskJob.Scheduler, clusterName, podSpec, primaryContainerName) + schedulerSpec, err := createSchedulerSpec(*daskJob.Scheduler, clusterName, nonInterruptiblePodSpec, primaryContainerName) if err != nil { return nil, err } - jobSpec := createJobSpec(*workerSpec, *schedulerSpec, podSpec, objectMeta) + jobSpec, err := createJobSpec(*workerSpec, *schedulerSpec, nonInterruptiblePodSpec, primaryContainerName, objectMeta) + if err != nil { + return nil, err + } job := &daskAPI.DaskJob{ TypeMeta: metav1.TypeMeta{ @@ -91,7 +165,10 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, primaryContainerName string) (*daskAPI.WorkerSpec, error) { workerPodSpec := podSpec.DeepCopy() - primaryContainer := getPrimaryContainer(workerPodSpec, primaryContainerName) + primaryContainer, err := getPrimaryContainer(workerPodSpec, primaryContainerName) + if err != nil { + return nil, err + } primaryContainer.Name = "dask-worker" // Set custom image if present @@ -100,7 +177,6 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, prim } // Set custom resources - var err error resources := &primaryContainer.Resources clusterResources := cluster.GetResources() if len(clusterResources.Requests) >= 1 || len(clusterResources.Limits) >= 1 { @@ -132,6 +208,10 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, prim } primaryContainer.Args = workerArgs + err = replacePrimaryContainer(workerPodSpec, primaryContainerName, *primaryContainer) + if err != nil { + return nil, err + } return 
&daskAPI.WorkerSpec{ Replicas: int(cluster.GetNumberOfWorkers()), Spec: *workerPodSpec, @@ -140,7 +220,10 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, prim func createSchedulerSpec(scheduler plugins.DaskScheduler, clusterName string, podSpec *v1.PodSpec, primaryContainerName string) (*daskAPI.SchedulerSpec, error) { schedulerPodSpec := podSpec.DeepCopy() - primaryContainer := getPrimaryContainer(schedulerPodSpec, primaryContainerName) + primaryContainer, err := getPrimaryContainer(schedulerPodSpec, primaryContainerName) + if err != nil { + return nil, err + } primaryContainer.Name = "scheduler" // Override image if applicable @@ -149,7 +232,6 @@ func createSchedulerSpec(scheduler plugins.DaskScheduler, clusterName string, po } // Override resources if applicable - var err error resources := &primaryContainer.Resources schedulerResources := scheduler.GetResources() if len(schedulerResources.Requests) >= 1 || len(schedulerResources.Limits) >= 1 { @@ -177,16 +259,21 @@ func createSchedulerSpec(scheduler plugins.DaskScheduler, clusterName string, po }, } - // Set restart policy schedulerPodSpec.RestartPolicy = v1.RestartPolicyAlways + // Set primary container + err = replacePrimaryContainer(schedulerPodSpec, primaryContainerName, *primaryContainer) + if err != nil { + return nil, err + } + return &daskAPI.SchedulerSpec{ Spec: *schedulerPodSpec, Service: v1.ServiceSpec{ Type: v1.ServiceTypeNodePort, Selector: map[string]string{ - "dask.org/scheduler-name": clusterName, - "dask.org/component": "scheduler", + "dask.org/cluster-name": clusterName, + "dask.org/component": "scheduler", }, Ports: []v1.ServicePort{ { @@ -206,10 +293,21 @@ func createSchedulerSpec(scheduler plugins.DaskScheduler, clusterName string, po }, nil } -func createJobSpec(workerSpec daskAPI.WorkerSpec, schedulerSpec daskAPI.SchedulerSpec, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta) *daskAPI.DaskJobSpec { +func createJobSpec(workerSpec daskAPI.WorkerSpec, schedulerSpec daskAPI.SchedulerSpec, podSpec *v1.PodSpec, primaryContainerName string, objectMeta *metav1.ObjectMeta) (*daskAPI.DaskJobSpec, error) { jobPodSpec := podSpec.DeepCopy() jobPodSpec.RestartPolicy = v1.RestartPolicyNever + primaryContainer, err := getPrimaryContainer(jobPodSpec, primaryContainerName) + if err != nil { + return nil, err + } + primaryContainer.Name = "job-runner" + + err = replacePrimaryContainer(jobPodSpec, primaryContainerName, *primaryContainer) + if err != nil { + return nil, err + } + return &daskAPI.DaskJobSpec{ Job: daskAPI.JobSpec{ Spec: *jobPodSpec, @@ -221,7 +319,7 @@ func createJobSpec(workerSpec daskAPI.WorkerSpec, schedulerSpec daskAPI.Schedule Scheduler: schedulerSpec, }, }, - } + }, nil } func (p daskResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, r client.Object) (pluginsCore.PhaseInfo, error) { diff --git a/go/tasks/plugins/k8s/dask/dask_test.go b/go/tasks/plugins/k8s/dask/dask_test.go index 224cb4e5b..f9d22c3d3 100644 --- a/go/tasks/plugins/k8s/dask/dask_test.go +++ b/go/tasks/plugins/k8s/dask/dask_test.go @@ -27,11 +27,13 @@ import ( ) const ( - defaultTestImage = "image://" - testNWorkers = 10 - testTaskID = "some-acceptable-name" - podTemplateName = "dask-dummy-pod-template-name" - templateServiceAccountName = "template-service-account" + defaultTestImage = "image://" + testNWorkers = 10 + testTaskID = "some-acceptable-name" + podTemplateName = "dask-dummy-pod-template-name" + defaultServiceAccountName = "default-service-account" + defaultNamespace = 
"default-namespace" + podTempaltePriorityClassName = "pod-template-priority-class-name" ) var ( @@ -61,7 +63,7 @@ var ( }, Template: v1.PodTemplateSpec{ Spec: v1.PodSpec{ - ServiceAccountName: templateServiceAccountName, + PriorityClassName: podTempaltePriorityClassName, }, }, } @@ -71,7 +73,7 @@ func dummyDaskJob(status daskAPI.JobStatus) *daskAPI.DaskJob { return &daskAPI.DaskJob{ ObjectMeta: metav1.ObjectMeta{ Name: "dask-job-name", - Namespace: "dask-namespace", + Namespace: defaultNamespace, }, Status: daskAPI.DaskJobStatus{ ClusterName: "dask-cluster-name", @@ -122,7 +124,7 @@ func dummyDaskTaskTemplate(customImage string, resources *core.Resources, podTem if err != nil { panic(err) } - envVars := []*core.KeyValuePair{} + var envVars []*core.KeyValuePair for _, envVar := range testEnvVars { envVars = append(envVars, &core.KeyValuePair{Key: envVar.Name, Value: envVar.Value}) } @@ -185,6 +187,8 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc taskExecutionMetadata.OnGetMaxAttempts().Return(uint32(1)) taskExecutionMetadata.OnIsInterruptible().Return(isInterruptible) taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) + taskExecutionMetadata.OnGetK8sServiceAccount().Return(defaultServiceAccountName) + taskExecutionMetadata.OnGetNamespace().Return(defaultNamespace) overrides := &mocks.TaskOverrides{} overrides.OnGetResources().Return(resources) taskExecutionMetadata.OnGetOverrides().Return(overrides) @@ -196,7 +200,7 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, false) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -204,9 +208,8 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { assert.True(t, ok) var defaultTolerations []v1.Toleration - var defaultNodeSelector map[string]string - var defaultAffinity *v1.Affinity - defaultWorkerAffinity := v1.Affinity{ + defaultNodeSelector := map[string]string{} + defaultAffinity := &v1.Affinity{ NodeAffinity: nil, PodAffinity: nil, PodAntiAffinity: nil, @@ -249,7 +252,8 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { assert.Equal(t, defaultResources, schedulerSpec.Containers[0].Resources) assert.Equal(t, []string{"dask-scheduler"}, schedulerSpec.Containers[0].Args) assert.Equal(t, expectedPorts, schedulerSpec.Containers[0].Ports) - assert.Equal(t, testEnvVars, schedulerSpec.Containers[0].Env) + // Flyte adds more environment variables to the scheduler + assert.Contains(t, schedulerSpec.Containers[0].Env, testEnvVars[0]) assert.Equal(t, defaultTolerations, schedulerSpec.Tolerations) assert.Equal(t, defaultNodeSelector, schedulerSpec.NodeSelector) assert.Equal(t, defaultAffinity, schedulerSpec.Affinity) @@ -281,13 +285,13 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { workerSpec := daskJob.Spec.Cluster.Spec.Worker.Spec assert.Equal(t, testNWorkers, daskJob.Spec.Cluster.Spec.Worker.Replicas) assert.Equal(t, "dask-worker", workerSpec.Containers[0].Name) - assert.Equal(t, v1.PullIfNotPresent, workerSpec.Containers[0].ImagePullPolicy) assert.Equal(t, defaultTestImage, workerSpec.Containers[0].Image) assert.Equal(t, defaultResources, workerSpec.Containers[0].Resources) - assert.Equal(t, testEnvVars, workerSpec.Containers[0].Env) + // Flyte adds more environment variables to the worker + 
assert.Contains(t, workerSpec.Containers[0].Env, testEnvVars[0]) assert.Equal(t, defaultTolerations, workerSpec.Tolerations) assert.Equal(t, defaultNodeSelector, workerSpec.NodeSelector) - assert.Equal(t, &defaultWorkerAffinity, workerSpec.Affinity) + assert.Equal(t, defaultAffinity, workerSpec.Affinity) assert.Equal(t, []string{ "dask-worker", "--name", @@ -419,9 +423,19 @@ func TestBuildResourcesDaskCustomResoureRequirements(t *testing.T) { } func TestBuildResourceDaskInterruptible(t *testing.T) { - var defaultNodeSelector map[string]string - var defaultAffinity *v1.Affinity - var defaultTolerations []v1.Toleration + defaultNodeSelector := map[string]string{} + defaultAffinity := v1.Affinity{ + NodeAffinity: &v1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ + NodeSelectorTerms: []v1.NodeSelectorTerm{ + { + MatchExpressions: []v1.NodeSelectorRequirement{}, + }, + }, + }, + }, + } + defaultTolerations := []v1.Toleration{} interruptibleNodeSelector := map[string]string{ "x/interruptible": "true", @@ -449,7 +463,7 @@ func TestBuildResourceDaskInterruptible(t *testing.T) { daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, "") - taskContext := dummyDaskTaskContext(taskTemplate, nil, true) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, true) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -460,13 +474,13 @@ func TestBuildResourceDaskInterruptible(t *testing.T) { jobSpec := daskJob.Spec.Job.Spec assert.Equal(t, defaultTolerations, jobSpec.Tolerations) assert.Equal(t, defaultNodeSelector, jobSpec.NodeSelector) - assert.Equal(t, defaultAffinity, jobSpec.Affinity) + assert.Equal(t, &defaultAffinity, jobSpec.Affinity) // Scheduler - should not be interruptible schedulerSpec := daskJob.Spec.Cluster.Spec.Scheduler.Spec assert.Equal(t, defaultTolerations, schedulerSpec.Tolerations) assert.Equal(t, defaultNodeSelector, schedulerSpec.NodeSelector) - assert.Equal(t, defaultAffinity, schedulerSpec.Affinity) + assert.Equal(t, &defaultAffinity, schedulerSpec.Affinity) // Default Workers - Should be interruptible workerSpec := daskJob.Spec.Cluster.Spec.Worker.Spec @@ -483,17 +497,16 @@ func TestBuildResouceDaskUsePodTemplate(t *testing.T) { flytek8s.DefaultPodTemplateStore.Store(podTemplate) daskResourceHandler := daskResourceHandler{} taskTemplate := dummyDaskTaskTemplate("", nil, podTemplateName) - taskContext := dummyDaskTaskContext(taskTemplate, nil, false) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, false) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) daskJob, ok := r.(*daskAPI.DaskJob) assert.True(t, ok) - // The job template has a custom service account set. 
This should be passed on to all three components - assert.Equal(t, templateServiceAccountName, daskJob.Spec.Job.Spec.ServiceAccountName) - assert.Equal(t, templateServiceAccountName, daskJob.Spec.Cluster.Spec.Scheduler.Spec.ServiceAccountName) - assert.Equal(t, templateServiceAccountName, daskJob.Spec.Cluster.Spec.Worker.Spec.ServiceAccountName) + assert.Equal(t, podTempaltePriorityClassName, daskJob.Spec.Job.Spec.PriorityClassName) + assert.Equal(t, podTempaltePriorityClassName, daskJob.Spec.Cluster.Spec.Scheduler.Spec.PriorityClassName) + assert.Equal(t, podTempaltePriorityClassName, daskJob.Spec.Cluster.Spec.Worker.Spec.PriorityClassName) // Cleanup flytek8s.DefaultPodTemplateStore.Delete(podTemplate) From b9787427b7ec3ab845ca1fb3e27134376aae978c Mon Sep 17 00:00:00 2001 From: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> Date: Mon, 17 Jul 2023 11:40:03 +0200 Subject: [PATCH 5/8] Fix linting Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> --- go/tasks/plugins/k8s/dask/dask.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/go/tasks/plugins/k8s/dask/dask.go b/go/tasks/plugins/k8s/dask/dask.go index 459fbbbfd..c3c31b99e 100755 --- a/go/tasks/plugins/k8s/dask/dask.go +++ b/go/tasks/plugins/k8s/dask/dask.go @@ -185,6 +185,9 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, prim return nil, err } } + if resources == nil { + resources = &v1.ResourceRequirements{} + } primaryContainer.Resources = *resources // Set custom args @@ -195,7 +198,7 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, prim } // If limits are set, append `--nthreads` and `--memory-limit` as per these docs: // https://kubernetes.dask.org/en/latest/kubecluster.html?#best-practices - if resources != nil && resources.Limits != nil { + if resources.Limits != nil { limits := resources.Limits if limits.Cpu() != nil { cpuCount := fmt.Sprintf("%v", limits.Cpu().Value()) From 25b183ba3fe4b0ef52ea7d5436038cc923e25838 Mon Sep 17 00:00:00 2001 From: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> Date: Sat, 16 Sep 2023 21:42:52 +0200 Subject: [PATCH 6/8] Use `Always` restart policy for workers Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> --- go/tasks/plugins/k8s/dask/dask.go | 4 ++++ go/tasks/plugins/k8s/dask/dask_test.go | 1 + 2 files changed, 5 insertions(+) diff --git a/go/tasks/plugins/k8s/dask/dask.go b/go/tasks/plugins/k8s/dask/dask.go index c3c31b99e..87fd298a3 100755 --- a/go/tasks/plugins/k8s/dask/dask.go +++ b/go/tasks/plugins/k8s/dask/dask.go @@ -215,6 +215,10 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, prim if err != nil { return nil, err } + + // All workers are created as k8s deployment and must have a restart policy of Always + workerPodSpec.RestartPolicy = v1.RestartPolicyAlways + return &daskAPI.WorkerSpec{ Replicas: int(cluster.GetNumberOfWorkers()), Spec: *workerPodSpec, diff --git a/go/tasks/plugins/k8s/dask/dask_test.go b/go/tasks/plugins/k8s/dask/dask_test.go index f9d22c3d3..0aa89cf1b 100644 --- a/go/tasks/plugins/k8s/dask/dask_test.go +++ b/go/tasks/plugins/k8s/dask/dask_test.go @@ -301,6 +301,7 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { "--memory-limit", "1Gi", }, workerSpec.Containers[0].Args) + assert.Equal(t, workerSpec.RestartPolicy, v1.RestartPolicyAlways) } func TestBuildResourceDaskCustomImages(t *testing.T) { From 7acfe0f3c6e7d656019502925ab1cd8b43211962 Mon Sep 
17 00:00:00 2001 From: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> Date: Sun, 17 Sep 2023 11:25:36 +0200 Subject: [PATCH 7/8] Add test which checks whether labels are propagated Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> --- go/tasks/plugins/k8s/dask/dask_test.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/go/tasks/plugins/k8s/dask/dask_test.go b/go/tasks/plugins/k8s/dask/dask_test.go index 0aa89cf1b..011c78523 100644 --- a/go/tasks/plugins/k8s/dask/dask_test.go +++ b/go/tasks/plugins/k8s/dask/dask_test.go @@ -44,6 +44,7 @@ var ( "execute-dask-task", } testAnnotations = map[string]string{"annotation-1": "val1"} + testLabels = map[string]string{"label-1": "val1"} testPlatformResources = v1.ResourceRequirements{ Requests: v1.ResourceList{ v1.ResourceCPU: resource.MustParse("4"), @@ -182,7 +183,7 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc taskExecutionMetadata := &mocks.TaskExecutionMetadata{} taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) taskExecutionMetadata.OnGetAnnotations().Return(testAnnotations) - taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) + taskExecutionMetadata.OnGetLabels().Return(testLabels) taskExecutionMetadata.OnGetPlatformResources().Return(&testPlatformResources) taskExecutionMetadata.OnGetMaxAttempts().Return(uint32(1)) taskExecutionMetadata.OnIsInterruptible().Return(isInterruptible) @@ -218,6 +219,7 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { // Job jobSpec := daskJob.Spec.Job.Spec assert.Equal(t, testAnnotations, daskJob.ObjectMeta.GetAnnotations()) + assert.Equal(t, testLabels, daskJob.ObjectMeta.GetLabels()) assert.Equal(t, v1.RestartPolicyNever, jobSpec.RestartPolicy) assert.Equal(t, "job-runner", jobSpec.Containers[0].Name) assert.Equal(t, defaultTestImage, jobSpec.Containers[0].Image) @@ -227,11 +229,12 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { assert.Equal(t, defaultNodeSelector, jobSpec.NodeSelector) assert.Equal(t, defaultAffinity, jobSpec.Affinity) - // Flyte adds more environment variables to the driver + // Flyte adds more environment variables to the runner assert.Contains(t, jobSpec.Containers[0].Env, testEnvVars[0]) // Cluster assert.Equal(t, testAnnotations, daskJob.Spec.Cluster.ObjectMeta.GetAnnotations()) + assert.Equal(t, testLabels, daskJob.Spec.Cluster.ObjectMeta.GetLabels()) // Scheduler schedulerSpec := daskJob.Spec.Cluster.Spec.Scheduler.Spec From fba2d0b4d829c9822bbe0f6015d9920a36aa7f28 Mon Sep 17 00:00:00 2001 From: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> Date: Sun, 17 Sep 2023 11:58:16 +0200 Subject: [PATCH 8/8] Replace `removeInterruptibleConfig` with `TaskExectuionMetadata` wrapper Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> --- go/tasks/plugins/k8s/dask/dask.go | 75 ++++++++++---------------- go/tasks/plugins/k8s/dask/dask_test.go | 16 ++---- 2 files changed, 30 insertions(+), 61 deletions(-) diff --git a/go/tasks/plugins/k8s/dask/dask.go b/go/tasks/plugins/k8s/dask/dask.go index 669cd2799..cb0d9ec93 100755 --- a/go/tasks/plugins/k8s/dask/dask.go +++ b/go/tasks/plugins/k8s/dask/dask.go @@ -5,8 +5,6 @@ import ( "fmt" "time" - "golang.org/x/exp/slices" - daskAPI "github.com/dask/dask-kubernetes/v2023/dask_kubernetes/operator/go_client/pkg/apis/kubernetes.dask.org/v1" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyteplugins/go/tasks/errors" @@ 
-14,7 +12,6 @@ import ( "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" @@ -30,6 +27,27 @@ const ( KindDaskJob = "DaskJob" ) +// Wraps a regular TaskExecutionMetadata and overrides the IsInterruptible method to always return false +// This is useful as the runner and the scheduler pods should never be interruptable +type nonInterruptibleTaskExecutionMetadata struct { + pluginsCore.TaskExecutionMetadata +} + +func (n nonInterruptibleTaskExecutionMetadata) IsInterruptible() bool { + return false +} + +// A wrapper around a regular TaskExecutionContext allowing to inject a custom TaskExecutionMetadata which is +// non-interruptible +type nonInterruptibleTaskExecutionContext struct { + pluginsCore.TaskExecutionContext + metadata nonInterruptibleTaskExecutionMetadata +} + +func (n nonInterruptibleTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { + return n.metadata +} + func mergeMapInto(src map[string]string, dst map[string]string) { for key, value := range src { dst[key] = value @@ -55,49 +73,6 @@ func replacePrimaryContainer(spec *v1.PodSpec, primaryContainerName string, cont return errors.Errorf(errors.BadTaskSpecification, "primary container [%v] not found in pod spec", primaryContainerName) } -func removeInterruptibleConfig(spec *v1.PodSpec, taskCtx pluginsCore.TaskExecutionContext) { - if !taskCtx.TaskExecutionMetadata().IsInterruptible() { - return - } - - // Tolerations - interruptlibleTolerations := config.GetK8sPluginConfig().InterruptibleTolerations - newTolerations := []v1.Toleration{} - for _, toleration := range spec.Tolerations { - if !slices.Contains(interruptlibleTolerations, toleration) { - newTolerations = append(newTolerations, toleration) - } - } - spec.Tolerations = newTolerations - - // Node selectors - interruptibleNodeSelector := config.GetK8sPluginConfig().InterruptibleNodeSelector - for key := range spec.NodeSelector { - if _, ok := interruptibleNodeSelector[key]; ok { - delete(spec.NodeSelector, key) - } - } - - // Node selector requirements - interruptibleNodeSelectorRequirements := config.GetK8sPluginConfig().InterruptibleNodeSelectorRequirement - nodeSelectorTerms := spec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms - for i := range nodeSelectorTerms { - nst := &nodeSelectorTerms[i] - matchExpressions := nst.MatchExpressions - newMatchExpressions := []v1.NodeSelectorRequirement{} - for _, matchExpression := range matchExpressions { - if !nodeSelectorRequirementsAreEqual(matchExpression, *interruptibleNodeSelectorRequirements) { - newMatchExpressions = append(newMatchExpressions, matchExpression) - } - } - nst.MatchExpressions = newMatchExpressions - } -} - -func nodeSelectorRequirementsAreEqual(a v1.NodeSelectorRequirement, b v1.NodeSelectorRequirement) bool { - return a.Key == b.Key && a.Operator == b.Operator && slices.Equal(a.Values, b.Values) -} - type daskResourceHandler struct { } @@ -129,8 +104,12 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC if err != nil { return nil, err } - nonInterruptiblePodSpec := 
podSpec.DeepCopy() - removeInterruptibleConfig(nonInterruptiblePodSpec, taskCtx) + nonInterruptibleTaskMetadata := nonInterruptibleTaskExecutionMetadata{taskCtx.TaskExecutionMetadata()} + nonInterruptibleTaskCtx := nonInterruptibleTaskExecutionContext{taskCtx, nonInterruptibleTaskMetadata} + nonInterruptiblePodSpec, _, _, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx) + if err != nil { + return nil, err + } // Add labels and annotations to objectMeta as they're not added by ToK8sPodSpec mergeMapInto(taskCtx.TaskExecutionMetadata().GetAnnotations(), objectMeta.Annotations) diff --git a/go/tasks/plugins/k8s/dask/dask_test.go b/go/tasks/plugins/k8s/dask/dask_test.go index 73ea60857..966d2293d 100644 --- a/go/tasks/plugins/k8s/dask/dask_test.go +++ b/go/tasks/plugins/k8s/dask/dask_test.go @@ -428,18 +428,8 @@ func TestBuildResourcesDaskCustomResoureRequirements(t *testing.T) { func TestBuildResourceDaskInterruptible(t *testing.T) { defaultNodeSelector := map[string]string{} - defaultAffinity := v1.Affinity{ - NodeAffinity: &v1.NodeAffinity{ - RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ - NodeSelectorTerms: []v1.NodeSelectorTerm{ - { - MatchExpressions: []v1.NodeSelectorRequirement{}, - }, - }, - }, - }, - } - defaultTolerations := []v1.Toleration{} + var defaultAffinity v1.Affinity + var defaultTolerations []v1.Toleration interruptibleNodeSelector := map[string]string{ "x/interruptible": "true", @@ -544,7 +534,7 @@ func TestGetTaskPhaseDask(t *testing.T) { daskResourceHandler := daskResourceHandler{} ctx := context.TODO() - taskTemplate := dummyDaskTaskTemplate("", nil) + taskTemplate := dummyDaskTaskTemplate("", nil, "") taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, false) taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(""))
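
Notes on the techniques used in this series follow.

The primary-container helpers introduced in PATCH 2/8 and hardened in PATCH 4/8 rely on the fact that ranging over `spec.Containers` yields copies: `getPrimaryContainer` hands back a pointer to such a copy, so any mutation has to be written back with `replacePrimaryContainer` before the pod spec is reused. A minimal standalone sketch of that pattern, using only `k8s.io/api/core/v1`; the `main` function and the sample container are illustrative and not part of the plugin:

```go
package main

import (
	"fmt"

	v1 "k8s.io/api/core/v1"
)

// getPrimaryContainer mirrors the helper from the patch: ranging over
// spec.Containers yields copies, so the returned pointer refers to a copy of
// the matching container, not to the element stored in the slice.
func getPrimaryContainer(spec *v1.PodSpec, name string) (*v1.Container, error) {
	for _, c := range spec.Containers {
		if c.Name == name {
			return &c, nil
		}
	}
	return nil, fmt.Errorf("primary container %q not found in pod spec", name)
}

// replacePrimaryContainer writes the (possibly mutated) copy back over the
// element that still carries the original name.
func replacePrimaryContainer(spec *v1.PodSpec, name string, container v1.Container) error {
	for i, c := range spec.Containers {
		if c.Name == name {
			spec.Containers[i] = container
			return nil
		}
	}
	return fmt.Errorf("primary container %q not found in pod spec", name)
}

func main() {
	spec := &v1.PodSpec{Containers: []v1.Container{{Name: "primary", Image: "image://"}}}

	c, err := getPrimaryContainer(spec, "primary")
	if err != nil {
		panic(err)
	}
	c.Name = "dask-worker"               // mutates only the copy
	fmt.Println(spec.Containers[0].Name) // still "primary"

	if err := replacePrimaryContainer(spec, "primary", *c); err != nil {
		panic(err)
	}
	fmt.Println(spec.Containers[0].Name) // now "dask-worker"
}
```

Returning `&spec.Containers[i]` directly would remove the need for the write-back step; the two-function form used in the patch keeps the read and write sites explicit at the cost of one extra call per component spec.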
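
PATCH 8/8 replaces the hand-rolled `removeInterruptibleConfig` with a wrapper around `TaskExecutionMetadata` that always reports `IsInterruptible() == false`, then builds the scheduler and job-runner pod specs from that wrapped context so only the workers pick up interruptible tolerations, node selectors, and affinity. The sketch below shows the embedding-and-override mechanics with minimal stand-in interfaces; the real flyteplugins interfaces carry many more methods, but the delegation works the same way:

```go
package main

import "fmt"

// Minimal stand-ins for the flyteplugins interfaces. Only the methods needed
// to show the pattern are declared here.
type TaskExecutionMetadata interface {
	IsInterruptible() bool
	GetNamespace() string
}

type TaskExecutionContext interface {
	TaskExecutionMetadata() TaskExecutionMetadata
}

// Embedding the interface satisfies it implicitly; redeclaring a single
// method shadows just that method and delegates everything else to the
// wrapped value.
type nonInterruptibleTaskExecutionMetadata struct {
	TaskExecutionMetadata
}

func (n nonInterruptibleTaskExecutionMetadata) IsInterruptible() bool { return false }

type nonInterruptibleTaskExecutionContext struct {
	TaskExecutionContext
	metadata nonInterruptibleTaskExecutionMetadata
}

func (n nonInterruptibleTaskExecutionContext) TaskExecutionMetadata() TaskExecutionMetadata {
	return n.metadata
}

// Demo implementations standing in for the generated mocks used in the tests.
type fakeMetadata struct{}

func (fakeMetadata) IsInterruptible() bool { return true }
func (fakeMetadata) GetNamespace() string  { return "demo-ns" }

type fakeContext struct{ md TaskExecutionMetadata }

func (c fakeContext) TaskExecutionMetadata() TaskExecutionMetadata { return c.md }

func main() {
	ctx := fakeContext{md: fakeMetadata{}}
	wrapped := nonInterruptibleTaskExecutionContext{
		TaskExecutionContext: ctx,
		metadata:             nonInterruptibleTaskExecutionMetadata{ctx.TaskExecutionMetadata()},
	}

	fmt.Println(ctx.TaskExecutionMetadata().IsInterruptible())     // true
	fmt.Println(wrapped.TaskExecutionMetadata().IsInterruptible()) // false
	fmt.Println(wrapped.TaskExecutionMetadata().GetNamespace())    // demo-ns, delegated unchanged
}
```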
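
PATCH 2/8 keeps the existing rule that worker resource limits are translated into `dask-worker` flags: `--nthreads` from the CPU limit and `--memory-limit` from the memory limit, per the dask-kubernetes best-practice note linked in the code. A small sketch of that translation, assuming the same `v1.ResourceRequirements` shape the plugin works with; the function name `buildWorkerArgs` is illustrative, not the plugin's:

```go
package main

import (
	"fmt"

	v1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/resource"
)

// buildWorkerArgs derives the dask-worker command line from container limits,
// mirroring the guards used in createWorkerSpec.
func buildWorkerArgs(resources v1.ResourceRequirements) []string {
	args := []string{"dask-worker", "--name", "$(DASK_WORKER_NAME)"}
	if resources.Limits != nil {
		limits := resources.Limits
		if limits.Cpu() != nil {
			// Value() rounds up to a whole number of CPUs for the thread count.
			args = append(args, "--nthreads", fmt.Sprintf("%v", limits.Cpu().Value()))
		}
		if limits.Memory() != nil {
			args = append(args, "--memory-limit", limits.Memory().String())
		}
	}
	return args
}

func main() {
	reqs := v1.ResourceRequirements{
		Limits: v1.ResourceList{
			v1.ResourceCPU:    resource.MustParse("1"),
			v1.ResourceMemory: resource.MustParse("1Gi"),
		},
	}
	fmt.Println(buildWorkerArgs(reqs))
	// [dask-worker --name $(DASK_WORKER_NAME) --nthreads 1 --memory-limit 1Gi]
}
```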
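
The pod-template test added in PATCH 1/8 and reworked in PATCH 4/8 registers a named default `PodTemplate` in `flytek8s.DefaultPodTemplateStore`, points the task at it through `TaskMetadata.PodTemplateName`, and deletes it again at the end so other tests are unaffected. A test helper in that spirit could look like the following; it is hypothetical rather than part of the patch, it assumes it lives next to `dask_test.go` in the `dask` package, and `t.Cleanup` stands in for the explicit `Delete` call the test makes:

```go
package dask

import (
	"testing"

	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
	v1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// registerPodTemplate stores a named default PodTemplate for the duration of
// a test and removes it again when the test finishes.
func registerPodTemplate(t *testing.T, name, priorityClassName string) {
	tmpl := &v1.PodTemplate{
		ObjectMeta: metav1.ObjectMeta{Name: name},
		Template: v1.PodTemplateSpec{
			Spec: v1.PodSpec{PriorityClassName: priorityClassName},
		},
	}
	flytek8s.DefaultPodTemplateStore.Store(tmpl)
	t.Cleanup(func() { flytek8s.DefaultPodTemplateStore.Delete(tmpl) })
}
```

A field such as `PriorityClassName` works well for this assertion because, unlike the service account, it is not also set from the task execution metadata, so the test isolates the pod-template code path; that is exactly the switch made between PATCH 1/8 and PATCH 4/8.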