Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Enable pod template and Use copy to construct head/worker in ray plugin #349

Merged
merged 4 commits into from
May 19, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 61 additions & 33 deletions go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/logs"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
Expand All @@ -22,6 +21,7 @@ import (

v1 "k8s.io/api/core/v1"

flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors"
ByronHsu marked this conversation as resolved.
Show resolved Hide resolved
"sigs.k8s.io/controller-runtime/pkg/client"
)

Expand All @@ -44,20 +44,35 @@ func (rayJobResourceHandler) GetProperties() k8s.PluginProperties {
func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error())
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error())
} else if taskTemplate == nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "nil task specification")
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification")
}

rayJob := plugins.RayJob{}
err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &rayJob)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
}

container, err := flytek8s.ToK8sContainer(ctx, taskCtx)
podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)

if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "Unable to create container spec: [%v]", err.Error())
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

// Locate the primary container in the pod spec returned by ToK8sPodSpec.
var container v1.Container
found := false
for _, c := range podSpec.Containers {
	if c.Name == primaryContainerName {
		container = c
		found = true
		break
	}
}

if !found {
	// err is always nil at this point (it was checked right after the
	// ToK8sPodSpec call), so calling err.Error() here would panic on a nil
	// error value. Report the name of the missing container instead.
	return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to get primary container [%v] from the pod", primaryContainerName)
}

headReplicas := int32(1)
Expand All @@ -78,7 +93,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
enableIngress := true
rayClusterSpec := rayv1alpha1.RayClusterSpec{
HeadGroupSpec: rayv1alpha1.HeadGroupSpec{
Template: buildHeadPodTemplate(container, taskCtx),
Template: buildHeadPodTemplate(&container, podSpec, objectMeta, taskCtx),
ServiceType: v1.ServiceType(GetConfig().ServiceType),
Replicas: &headReplicas,
EnableIngress: &enableIngress,
Expand All @@ -88,7 +103,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
}

for _, spec := range rayJob.RayCluster.WorkerGroupSpec {
workerPodTemplate := buildWorkerPodTemplate(container, taskCtx)
workerPodTemplate := buildWorkerPodTemplate(&container, podSpec, objectMeta, taskCtx)

minReplicas := spec.Replicas
maxReplicas := spec.Replicas
Expand Down Expand Up @@ -139,18 +154,20 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
Kind: KindRayJob,
APIVersion: rayv1alpha1.SchemeGroupVersion.String(),
},
Spec: jobSpec,
Spec: jobSpec,
ObjectMeta: *objectMeta,
}

return &rayJobObject, nil
}

func buildHeadPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
// Some configs are copied from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97
// They should always be the same, so we can hard-code them here.
primaryContainer := &v1.Container{Name: "ray-head", Image: container.Image}
primaryContainer.Resources = container.Resources
primaryContainer.Env = []v1.EnvVar{
primaryContainer := container.DeepCopy()
primaryContainer.Name = "ray-head"

envs := []v1.EnvVar{
{
Name: "MY_POD_IP",
ValueFrom: &v1.EnvVarSource{
Expand All @@ -160,8 +177,12 @@ func buildHeadPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecu
},
},
}
primaryContainer.Env = append(primaryContainer.Env, container.Env...)
primaryContainer.Ports = []v1.ContainerPort{

primaryContainer.Args = []string{}

primaryContainer.Env = append(primaryContainer.Env, envs...)

ports := []v1.ContainerPort{
{
Name: "redis",
ContainerPort: 6379,
Expand All @@ -175,20 +196,23 @@ func buildHeadPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecu
ContainerPort: 8265,
},
}
pod := &v1.PodSpec{
Containers: []v1.Container{*primaryContainer},
}
flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{primaryContainer.Resources}, pod)

primaryContainer.Ports = append(primaryContainer.Ports, ports...)

headPodSpec := podSpec.DeepCopy()

headPodSpec.Containers = []v1.Container{*primaryContainer}

podTemplateSpec := v1.PodTemplateSpec{
Spec: *pod,
Spec: *headPodSpec,
ObjectMeta: *objectMeta,
}
podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
return podTemplateSpec
}

func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
// Some configs are copied from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185
// They should always be the same, so we can hard-code them here.
initContainers := []v1.Container{
Expand All @@ -203,10 +227,12 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe
Resources: container.Resources,
},
}
primaryContainer := container.DeepCopy()
primaryContainer.Name = "ray-worker"

primaryContainer := &v1.Container{Name: "ray-worker", Image: container.Image}
primaryContainer.Resources = container.Resources
primaryContainer.Env = []v1.EnvVar{
primaryContainer.Args = []string{}

envs := []v1.EnvVar{
{
Name: "RAY_DISABLE_DOCKER_CPU_WARNING",
Value: "1",
Expand Down Expand Up @@ -268,7 +294,9 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe
},
},
}
primaryContainer.Env = append(primaryContainer.Env, container.Env...)

primaryContainer.Env = append(primaryContainer.Env, envs...)

primaryContainer.Lifecycle = &v1.Lifecycle{
PreStop: &v1.LifecycleHandler{
Exec: &v1.ExecAction{
Expand All @@ -279,7 +307,7 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe
},
}

primaryContainer.Ports = []v1.ContainerPort{
ports := []v1.ContainerPort{
{
Name: "redis",
ContainerPort: 6379,
Expand All @@ -293,15 +321,15 @@ func buildWorkerPodTemplate(container *v1.Container, taskCtx pluginsCore.TaskExe
ContainerPort: 8265,
},
}
primaryContainer.Ports = append(primaryContainer.Ports, ports...)

pod := &v1.PodSpec{
Containers: []v1.Container{*primaryContainer},
InitContainers: initContainers,
}
flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), []v1.ResourceRequirements{primaryContainer.Resources}, pod)
workerPodSpec := podSpec.DeepCopy()
workerPodSpec.Containers = []v1.Container{*primaryContainer}
workerPodSpec.InitContainers = initContainers

podTemplateSpec := v1.PodTemplateSpec{
Spec: *pod,
Spec: *workerPodSpec,
ObjectMeta: *objectMetadata,
}
podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
podTemplateSpec.SetAnnotations(utils.UnionMaps(podTemplateSpec.GetAnnotations(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())))
Expand Down Expand Up @@ -350,7 +378,7 @@ func (rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s
return pluginsCore.PhaseInfoNotReady(time.Now(), pluginsCore.DefaultPhaseVersion, "job is pending"), nil
case rayv1alpha1.JobStatusFailed:
reason := fmt.Sprintf("Failed to create Ray job: %s", rayJob.Name)
return pluginsCore.PhaseInfoFailure(errors.TaskFailedWithError, reason, info), nil
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
case rayv1alpha1.JobStatusSucceeded:
return pluginsCore.PhaseInfoSuccess(info), nil
case rayv1alpha1.JobStatusRunning:
Expand Down