Skip to content

Commit

Permalink
Update toleration and annotations in ray pod spec (flyteorg#302)
Browse files (browse the repository at this point in the history)
Signed-off-by: Kevin Su <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Dec 22, 2022
1 parent c2cfdbe commit 1505002
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 12 deletions.
32 changes: 21 additions & 11 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
logger.Errorf(ctx, "Default Pod creation logic works for default container in the task template only.")
return nil, fmt.Errorf("container not specified in task template")
}

templateParameters := template.Parameters{
Task: taskCtx.TaskReader(),
Inputs: taskCtx.InputReader(),
Expand Down Expand Up @@ -95,7 +96,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
enableIngress := true
rayClusterSpec := rayv1alpha1.RayClusterSpec{
HeadGroupSpec: rayv1alpha1.HeadGroupSpec{
Template: buildHeadPodTemplate(container),
Template: buildHeadPodTemplate(container, taskCtx),
ServiceType: v1.ServiceType(GetConfig().ServiceType),
Replicas: &headReplicas,
EnableIngress: &enableIngress,
Expand All @@ -105,7 +106,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
}

for _, spec := range rayJob.RayCluster.WorkerGroupSpec {
workerPodTemplate := buildWorkerPodTemplate(container)
workerPodTemplate := buildWorkerPodTemplate(container, taskCtx)

minReplicas := spec.Replicas
maxReplicas := spec.Replicas
Expand Down Expand Up @@ -162,7 +163,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
return &rayJobObject, nil
}

func buildHeadPodTemplate(container *v1.Container) v1.PodTemplateSpec {
func buildHeadPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
// Some configs are copied from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97
// They should always be the same, so we can hard-code them here.
primaryContainer := &v1.Container{Name: "ray-head", Image: container.Image}
Expand Down Expand Up @@ -192,16 +193,20 @@ func buildHeadPodTemplate(container *v1.Container) v1.PodTemplateSpec {
ContainerPort: 8265,
},
}
pod := &v1.PodSpec{
Containers: []v1.Container{*primaryContainer},
}
flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{primaryContainer.Resources}, pod)

podTemplateSpec := v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{*primaryContainer},
},
Spec: *pod,
}
podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
return podTemplateSpec
}

func buildWorkerPodTemplate(container *v1.Container) v1.PodTemplateSpec {
func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
// Some configs are copied from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185
// They should always be the same, so we can hard-code them here.
initContainers := []v1.Container{
Expand Down Expand Up @@ -307,12 +312,17 @@ func buildWorkerPodTemplate(container *v1.Container) v1.PodTemplateSpec {
},
}

pod := &v1.PodSpec{
Containers: []v1.Container{*primaryContainer},
InitContainers: initContainers,
}
flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{primaryContainer.Resources}, pod)

podTemplateSpec := v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{*primaryContainer},
InitContainers: initContainers,
},
Spec: *pod,
}
podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
return podTemplateSpec
}

Expand Down
20 changes: 19 additions & 1 deletion flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"testing"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/golang/protobuf/jsonpb"
structpb "github.com/golang/protobuf/ptypes/struct"
Expand Down Expand Up @@ -136,14 +138,24 @@ func dummyRayTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecut
taskExecutionMetadata.OnGetOverrides().Return(resources)
taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount)
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskExecutionMetadata.OnGetSecurityContext().Return(core.SecurityContext{RunAs: &core.Identity{K8SServiceAccount: serviceAccount}})
taskExecutionMetadata.OnGetSecurityContext().Return(core.SecurityContext{
RunAs: &core.Identity{K8SServiceAccount: serviceAccount},
})
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
return taskCtx
}

func TestBuildResourceRay(t *testing.T) {
rayJobResourceHandler := rayJobResourceHandler{}
taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObj())
toleration := []corev1.Toleration{{
Key: "storage",
Value: "dedicated",
Operator: corev1.TolerationOpExists,
Effect: corev1.TaintEffectNoSchedule,
}}
err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration})
assert.Nil(t, err)

RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate))
assert.Nil(t, err)
Expand All @@ -157,6 +169,9 @@ func TestBuildResourceRay(t *testing.T) {
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount)
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams,
map[string]string{"dashboard-host": "0.0.0.0", "include-dashboard": "true", "node-ip-address": "$MY_POD_IP", "num-cpus": "1"})
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"})
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"})
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration)

workerReplica := int32(3)
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, &workerReplica)
Expand All @@ -165,6 +180,9 @@ func TestBuildResourceRay(t *testing.T) {
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName)
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount)
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"node-ip-address": "$MY_POD_IP"})
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"})
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"})
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration)
}

func TestGetPropertiesRay(t *testing.T) {
Expand Down

0 comments on commit 1505002

Please sign in to comment.