Skip to content

Commit

Permalink
Merge 29c0688 into 10d0775
Browse files Browse the repository at this point in the history
  • Loading branch information
jeevb authored Oct 29, 2023
2 parents 10d0775 + 29c0688 commit 6f713c2
Show file tree
Hide file tree
Showing 3 changed files with 337 additions and 47 deletions.
3 changes: 3 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/ray/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package ray
import (
"context"

v1 "k8s.io/api/core/v1"

pluginsConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/config"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/logs"
pluginmachinery "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s"
Expand Down Expand Up @@ -78,6 +80,7 @@ type Config struct {
// Remote Ray Cluster Config
RemoteClusterConfig pluginmachinery.ClusterConfig `json:"remoteClusterConfig" pflag:"Configuration of remote K8s cluster for ray jobs"`
Logs logs.LogConfig `json:"logs" pflag:"-,Log configuration for ray jobs"`
LogsSidecar *v1.Container `json:"logsSidecar" pflag:"-,Sidecar to inject into head pods for capturing ray job logs"`
Defaults DefaultConfig `json:"defaults" pflag:"-,Default configuration for ray jobs"`
EnableUsageStats bool `json:"enableUsageStats" pflag:",Enable usage stats for ray jobs. These stats are submitted to usage-stats.ray.io per https://docs.ray.io/en/latest/cluster/usage-stats.html"`
}
Expand Down
113 changes: 90 additions & 23 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
)

const (
rayStateMountPath = "/tmp/ray"
defaultRayStateVolName = "system-ray-state"
rayTaskType = "ray"
KindRayJob = "RayJob"
IncludeDashboard = "include-dashboard"
Expand Down Expand Up @@ -61,17 +63,18 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

var container v1.Container
found := false
for _, c := range podSpec.Containers {
var primaryContainer *v1.Container
var primaryContainerIdx int
for idx, c := range podSpec.Containers {
if c.Name == primaryContainerName {
container = c
found = true
c := c
primaryContainer = &c
primaryContainerIdx = idx
break
}
}

if !found {
if primaryContainer == nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to get primary container from the pod: [%v]", err.Error())
}

Expand Down Expand Up @@ -101,9 +104,15 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
}

enableIngress := true
headPodSpec := podSpec.DeepCopy()
rayClusterSpec := rayv1alpha1.RayClusterSpec{
HeadGroupSpec: rayv1alpha1.HeadGroupSpec{
Template: buildHeadPodTemplate(&container, podSpec, objectMeta, taskCtx),
Template: buildHeadPodTemplate(
&headPodSpec.Containers[primaryContainerIdx],
headPodSpec,
objectMeta,
taskCtx,
),
ServiceType: v1.ServiceType(cfg.ServiceType),
Replicas: &headReplicas,
EnableIngress: &enableIngress,
Expand All @@ -113,7 +122,13 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
}

for _, spec := range rayJob.RayCluster.WorkerGroupSpec {
workerPodTemplate := buildWorkerPodTemplate(&container, podSpec, objectMeta, taskCtx)
workerPodSpec := podSpec.DeepCopy()
workerPodTemplate := buildWorkerPodTemplate(
&workerPodSpec.Containers[primaryContainerIdx],
workerPodSpec,
objectMeta,
taskCtx,
)

minReplicas := spec.Replicas
maxReplicas := spec.Replicas
Expand Down Expand Up @@ -161,7 +176,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC

jobSpec := rayv1alpha1.RayJobSpec{
RayClusterSpec: rayClusterSpec,
Entrypoint: strings.Join(container.Args, " "),
Entrypoint: strings.Join(primaryContainer.Args, " "),
ShutdownAfterJobFinishes: cfg.ShutdownAfterJobFinishes,
TTLSecondsAfterFinished: &cfg.TTLSecondsAfterFinished,
RuntimeEnv: rayJob.RuntimeEnv,
Expand All @@ -179,10 +194,66 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
return &rayJobObject, nil
}

func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
func injectLogsSidecar(primaryContainer *v1.Container, podSpec *v1.PodSpec) {
cfg := GetConfig()
if cfg.LogsSidecar == nil {
return
}
sidecar := cfg.LogsSidecar.DeepCopy()

// Ray logs integration
var rayStateVolMount *v1.VolumeMount
// Look for an existing volume mount on the primary container, mounted at /tmp/ray
for _, vm := range primaryContainer.VolumeMounts {
if vm.MountPath == rayStateMountPath {
vm := vm
rayStateVolMount = &vm
break
}
}
// No existing volume mount exists at /tmp/ray. We create a new volume and volume
// mount and add it to the pod and container specs respectively
if rayStateVolMount == nil {
vol := v1.Volume{
Name: defaultRayStateVolName,
VolumeSource: v1.VolumeSource{
EmptyDir: &v1.EmptyDirVolumeSource{},
},
}
podSpec.Volumes = append(podSpec.Volumes, vol)
volMount := v1.VolumeMount{
Name: defaultRayStateVolName,
MountPath: rayStateMountPath,
}
primaryContainer.VolumeMounts = append(primaryContainer.VolumeMounts, volMount)
rayStateVolMount = &volMount
}
// We need to mirror the ray state volume mount into the sidecar as readonly,
// so that we can read the logs written by the head node.
readOnlyRayStateVolMount := *rayStateVolMount.DeepCopy()
readOnlyRayStateVolMount.ReadOnly = true

// Update volume mounts on sidecar
// If one already exists with the desired mount path, simply replace it. Otherwise,
// add it to sidecar's volume mounts.
foundExistingSidecarVolMount := false
for idx, vm := range sidecar.VolumeMounts {
if vm.MountPath == rayStateMountPath {
foundExistingSidecarVolMount = true
sidecar.VolumeMounts[idx] = readOnlyRayStateVolMount

Check warning on line 243 in flyteplugins/go/tasks/plugins/k8s/ray/ray.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/k8s/ray/ray.go#L241-L243

Added lines #L241 - L243 were not covered by tests
}
}
if !foundExistingSidecarVolMount {
sidecar.VolumeMounts = append(sidecar.VolumeMounts, readOnlyRayStateVolMount)
}

// Add sidecar to containers
podSpec.Containers = append(podSpec.Containers, *sidecar)
}

func buildHeadPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
// Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97
// They should always be the same, so we could hard code here.
primaryContainer := container.DeepCopy()
primaryContainer.Name = "ray-head"

envs := []v1.EnvVar{
Expand Down Expand Up @@ -217,12 +288,11 @@ func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMe

primaryContainer.Ports = append(primaryContainer.Ports, ports...)

headPodSpec := podSpec.DeepCopy()

headPodSpec.Containers = []v1.Container{*primaryContainer}
// Inject a sidecar for capturing and exposing Ray job logs
injectLogsSidecar(primaryContainer, podSpec)

podTemplateSpec := v1.PodTemplateSpec{
Spec: *headPodSpec,
Spec: *podSpec,
ObjectMeta: *objectMeta,
}
cfg := config.GetK8sPluginConfig()
Expand All @@ -231,7 +301,7 @@ func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMe
return podTemplateSpec
}

func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
func buildWorkerPodTemplate(primaryContainer *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
// Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185
// They should always be the same, so we could hard code here.
initContainers := []v1.Container{
Expand All @@ -243,10 +313,11 @@ func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, object
"-c",
"until nslookup $RAY_IP.$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace).svc.cluster.local; do echo waiting for myservice; sleep 2; done",
},
Resources: container.Resources,
Resources: primaryContainer.Resources,
},
}
primaryContainer := container.DeepCopy()
podSpec.InitContainers = append(podSpec.InitContainers, initContainers...)

primaryContainer.Name = "ray-worker"

primaryContainer.Args = []string{}
Expand Down Expand Up @@ -342,12 +413,8 @@ func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, object
}
primaryContainer.Ports = append(primaryContainer.Ports, ports...)

workerPodSpec := podSpec.DeepCopy()
workerPodSpec.Containers = []v1.Container{*primaryContainer}
workerPodSpec.InitContainers = initContainers

podTemplateSpec := v1.PodTemplateSpec{
Spec: *workerPodSpec,
Spec: *podSpec,
ObjectMeta: *objectMetadata,
}
podTemplateSpec.SetLabels(utils.UnionMaps(podTemplateSpec.GetLabels(), utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())))
Expand Down
Loading

0 comments on commit 6f713c2

Please sign in to comment.