diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
index 79c597da1d..02b3f060b5 100644
--- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
+++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
@@ -62,6 +62,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 		logger.Errorf(ctx, "Default Pod creation logic works for default container in the task template only.")
 		return nil, fmt.Errorf("container not specified in task template")
 	}
+
 	templateParameters := template.Parameters{
 		Task:   taskCtx.TaskReader(),
 		Inputs: taskCtx.InputReader(),
@@ -95,7 +96,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 	enableIngress := true
 	rayClusterSpec := rayv1alpha1.RayClusterSpec{
 		HeadGroupSpec: rayv1alpha1.HeadGroupSpec{
-			Template:      buildHeadPodTemplate(container),
+			Template:      buildHeadPodTemplate(container, taskCtx),
 			ServiceType:   v1.ServiceType(GetConfig().ServiceType),
 			Replicas:      &headReplicas,
 			EnableIngress: &enableIngress,
@@ -105,7 +106,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 	}
 
 	for _, spec := range rayJob.RayCluster.WorkerGroupSpec {
-		workerPodTemplate := buildWorkerPodTemplate(container)
+		workerPodTemplate := buildWorkerPodTemplate(container, taskCtx)
 		minReplicas := spec.Replicas
 		maxReplicas := spec.Replicas
@@ -162,7 +163,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 	return &rayJobObject, nil
 }
 
-func buildHeadPodTemplate(container *v1.Container) v1.PodTemplateSpec {
+func buildHeadPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
 	// Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97
 	// They should always be the same, so we could hard code here.
 	primaryContainer := &v1.Container{Name: "ray-head", Image: container.Image}
@@ -192,16 +193,20 @@ func buildHeadPodTemplate(container *v1.Container) v1.PodTemplateSpec {
 			ContainerPort: 8265,
 		},
 	}
+	pod := &v1.PodSpec{
+		Containers: []v1.Container{*primaryContainer},
+	}
+	flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{primaryContainer.Resources}, pod)
 	podTemplateSpec := v1.PodTemplateSpec{
-		Spec: v1.PodSpec{
-			Containers: []v1.Container{*primaryContainer},
-		},
+		Spec: *pod,
 	}
+	podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
+	podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
 	return podTemplateSpec
 }
 
-func buildWorkerPodTemplate(container *v1.Container) v1.PodTemplateSpec {
+func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
 	// Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185
 	// They should always be the same, so we could hard code here.
 	initContainers := []v1.Container{
@@ -307,12 +312,17 @@ func buildWorkerPodTemplate(container *v1.Container) v1.PodTemplateSpec {
 		},
 	}
 
+	pod := &v1.PodSpec{
+		Containers:     []v1.Container{*primaryContainer},
+		InitContainers: initContainers,
+	}
+	flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{primaryContainer.Resources}, pod)
+
 	podTemplateSpec := v1.PodTemplateSpec{
-		Spec: v1.PodSpec{
-			Containers:     []v1.Container{*primaryContainer},
-			InitContainers: initContainers,
-		},
+		Spec: *pod,
 	}
+	podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
+	podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
 	return podTemplateSpec
 }
diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
index 2c13972d12..99da77b6e9 100644
--- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
+++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
@@ -4,6 +4,8 @@ import (
 	"context"
 	"testing"
 
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
+	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
 	"github.com/golang/protobuf/jsonpb"
 	structpb "github.com/golang/protobuf/ptypes/struct"
@@ -136,7 +138,9 @@ func dummyRayTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecut
 	taskExecutionMetadata.OnGetOverrides().Return(resources)
 	taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount)
 	taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
-	taskExecutionMetadata.OnGetSecurityContext().Return(core.SecurityContext{RunAs: &core.Identity{K8SServiceAccount: serviceAccount}})
+	taskExecutionMetadata.OnGetSecurityContext().Return(core.SecurityContext{
+		RunAs: &core.Identity{K8SServiceAccount: serviceAccount},
+	})
 	taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
 	return taskCtx
 }
@@ -144,6 +148,14 @@ func dummyRayTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecut
 
 func TestBuildResourceRay(t *testing.T) {
 	rayJobResourceHandler := rayJobResourceHandler{}
 	taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj())
+	toleration := []corev1.Toleration{{
+		Key:      "storage",
+		Value:    "dedicated",
+		Operator: corev1.TolerationOpExists,
+		Effect:   corev1.TaintEffectNoSchedule,
+	}}
+	err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration})
+	assert.Nil(t, err)
 	RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate))
 	assert.Nil(t, err)
@@ -157,6 +169,9 @@ func TestBuildResourceRay(t *testing.T) {
 	assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount)
 	assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams, map[string]string{"dashboard-host": "0.0.0.0", "include-dashboard": "true", "node-ip-address": "$MY_POD_IP", "num-cpus": "1"})
+	assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"})
+	assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"})
+	assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration)
 
 	workerReplica := int32(3)
 	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, &workerReplica)
@@ -165,6 +180,9 @@ func TestBuildResourceRay(t *testing.T) {
 	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName)
 	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount)
 	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"node-ip-address": "$MY_POD_IP"})
+	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"})
+	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"})
+	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration)
 }
 
 func TestGetPropertiesRay(t *testing.T) {