Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Apply default pod template to PytorchJob pods (#297)
Browse files Browse the repository at this point in the history
* Merge pod template spec with pod spec in separate func

Signed-off-by: Fabio Grätz <[email protected]>

* Merge pod template spec with pod spec in separate func

Signed-off-by: Fabio Grätz <[email protected]>

* Apply pod template to pytorch job pod spec

Signed-off-by: Fabio Grätz <[email protected]>

* Handle nil podspecs before merging

Signed-off-by: Fabio Grätz <[email protected]>

* Pass both default and primare container name to MergePodSpecs

Signed-off-by: Fabio Grätz <[email protected]>

* Move podSpec.DeepCopy into MergePodSpecs

Signed-off-by: Fabio Grätz <[email protected]>

* Add tests

Signed-off-by: Fabio Grätz <[email protected]>

* Merge pod template into tfjob and mpijob

Signed-off-by: Fabio Grätz <[email protected]>

* Lint

Signed-off-by: Fabio Grätz <[email protected]>

* Correct usage of default and primate container (template) name

Signed-off-by: Fabio Grätz <[email protected]>

* Override mpi default container name

Signed-off-by: Fabio Grätz <[email protected]>

* Carry over ObjectMeta from pod template

Signed-off-by: Fabio Grätz <[email protected]>

* Remove old `TestBuildPodWithSpec` test

Signed-off-by: Fabio Grätz <[email protected]>

Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
  • Loading branch information
fg91 and Fabio Grätz authored Dec 16, 2022
1 parent 17b3010 commit 968edb3
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 91 deletions.
125 changes: 73 additions & 52 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package flytek8s

import (
"context"
"errors"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -144,70 +145,87 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*
return pod, nil
}

func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryContainerName string) (*v1.Pod, error) {
pod := v1.Pod{
TypeMeta: v12.TypeMeta{
Kind: PodKind,
APIVersion: v1.SchemeGroupVersion.String(),
},
func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string) (*v1.PodSpec, error) {
if podTemplatePodSpec == nil || podSpec == nil {
return nil, errors.New("podTemplatePodSpec and podSpec cannot be nil")
}

if podTemplate != nil {
// merge template PodSpec
basePodSpec := podTemplate.Template.Spec.DeepCopy()
err := mergo.Merge(basePodSpec, podSpec, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}
var podTemplatePodSpecCopy *v1.PodSpec = podTemplatePodSpec.DeepCopy()

// merge template Containers
var mergedContainers []v1.Container
var defaultContainerTemplate, primaryContainerTemplate *v1.Container
for i := 0; i < len(podTemplate.Template.Spec.Containers); i++ {
if podTemplate.Template.Spec.Containers[i].Name == defaultContainerTemplateName {
defaultContainerTemplate = &podTemplate.Template.Spec.Containers[i]
} else if podTemplate.Template.Spec.Containers[i].Name == primaryContainerTemplateName {
primaryContainerTemplate = &podTemplate.Template.Spec.Containers[i]
}
}
err := mergo.Merge(podTemplatePodSpecCopy, podSpec, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}

for _, container := range podSpec.Containers {
// if applicable start with defaultContainerTemplate
var mergedContainer *v1.Container
if defaultContainerTemplate != nil {
mergedContainer = defaultContainerTemplate.DeepCopy()
}
// merge template Containers
var mergedContainers []v1.Container
var defaultContainerTemplate, primaryContainerTemplate *v1.Container
for i := 0; i < len(podTemplatePodSpecCopy.Containers); i++ {
if podTemplatePodSpecCopy.Containers[i].Name == defaultContainerTemplateName {
defaultContainerTemplate = &podTemplatePodSpecCopy.Containers[i]
} else if podTemplatePodSpecCopy.Containers[i].Name == primaryContainerTemplateName {
primaryContainerTemplate = &podTemplatePodSpecCopy.Containers[i]
}
}

// if applicable merge with primaryContainerTemplate
if container.Name == primaryContainerName && primaryContainerTemplate != nil {
if mergedContainer == nil {
mergedContainer = primaryContainerTemplate.DeepCopy()
} else {
err := mergo.Merge(mergedContainer, primaryContainerTemplate, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}
}
}
for _, container := range podSpec.Containers {
// if applicable start with defaultContainerTemplate
var mergedContainer *v1.Container
if defaultContainerTemplate != nil {
mergedContainer = defaultContainerTemplate.DeepCopy()
}

// if applicable merge with existing container
// if applicable merge with primaryContainerTemplate
if container.Name == primaryContainerName && primaryContainerTemplate != nil {
if mergedContainer == nil {
mergedContainers = append(mergedContainers, container)
mergedContainer = primaryContainerTemplate.DeepCopy()
} else {
err := mergo.Merge(mergedContainer, container, mergo.WithOverride, mergo.WithAppendSlice)
err := mergo.Merge(mergedContainer, primaryContainerTemplate, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}
}
}

// if applicable merge with existing container # TODO test
if mergedContainer == nil {
mergedContainers = append(mergedContainers, container)

mergedContainers = append(mergedContainers, *mergedContainer)
} else {
err := mergo.Merge(mergedContainer, container, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}

mergedContainers = append(mergedContainers, *mergedContainer)
}

}

// update Pod fields
podTemplatePodSpecCopy.Containers = mergedContainers

return podTemplatePodSpecCopy, nil
}

func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryContainerName string) (*v1.Pod, error) {
pod := v1.Pod{
TypeMeta: v12.TypeMeta{
Kind: PodKind,
APIVersion: v1.SchemeGroupVersion.String(),
},
}

if podTemplate != nil {
// merge template PodSpec
mergedPodSpec, err := MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName)
if err != nil {
return nil, err
}

// update Pod fields
basePodSpec.Containers = mergedContainers
pod.ObjectMeta = podTemplate.Template.ObjectMeta
pod.Spec = *basePodSpec
pod.Spec = *mergedPodSpec

} else {
pod.Spec = *podSpec
}
Expand All @@ -231,12 +249,15 @@ func BuildIdentityPod() *v1.Pod {
// Important considerations.
// Pending Status in Pod could be for various reasons and sometimes could signal a problem
// Case I: Pending because the Image pull is failing and it is backing off
// This could be transient. So we can actually rely on the failure reason.
// The failure transitions from ErrImagePull -> ImagePullBackoff
//
// This could be transient. So we can actually rely on the failure reason.
// The failure transitions from ErrImagePull -> ImagePullBackoff
//
// Case II: Not enough resources are available. This is tricky. It could be that the total number of
// resources requested is beyond the capability of the system. for this we will rely on configuration
// and hence input gates. We should not allow bad requests that Request for large number of resource through.
// In the case it makes through, we will fail after timeout
//
// resources requested is beyond the capability of the system. for this we will rely on configuration
// and hence input gates. We should not allow bad requests that Request for large number of resource through.
// In the case it makes through, we will fail after timeout
func DemystifyPending(status v1.PodStatus) (pluginsCore.PhaseInfo, error) {
// Search over the difference conditions in the status object. Note that the 'Pending' this function is
// demystifying is the 'phase' of the pod status. This is different than the PodReady condition type also used below
Expand Down
113 changes: 83 additions & 30 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
Expand Down Expand Up @@ -1013,8 +1014,18 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) {
})
}

func TestBuildPodWithSpec(t *testing.T) {
func TestMergePodSpecs(t *testing.T) {
var priority int32 = 1

podSpec1, _ := MergePodSpecs(nil, nil, "foo")
assert.Nil(t, podSpec1)

podSpec2, _ := MergePodSpecs(&v1.PodSpec{}, nil, "foo")
assert.Nil(t, podSpec2)

podSpec3, _ := MergePodSpecs(nil, &v1.PodSpec{}, "foo")
assert.Nil(t, podSpec3)

podSpec := v1.PodSpec{
Containers: []v1.Container{
v1.Container{
Expand All @@ -1039,10 +1050,6 @@ func TestBuildPodWithSpec(t *testing.T) {
},
}

pod, err := BuildPodWithSpec(nil, &podSpec, "foo")
assert.Nil(t, err)
assert.True(t, reflect.DeepEqual(pod.Spec, podSpec))

defaultContainerTemplate := v1.Container{
Name: defaultContainerTemplateName,
TerminationMessagePath: "/dev/default-termination-log",
Expand All @@ -1053,48 +1060,94 @@ func TestBuildPodWithSpec(t *testing.T) {
TerminationMessagePath: "/dev/primary-termination-log",
}

podTemplate := v1.PodTemplate{
Template: v1.PodTemplateSpec{
Spec: v1.PodSpec{
Containers: []v1.Container{
defaultContainerTemplate,
primaryContainerTemplate,
},
HostNetwork: true,
NodeSelector: map[string]string{
"foo": "bar",
},
SchedulerName: "defaultScheduler",
Tolerations: []v1.Toleration{
v1.Toleration{
Key: "foo",
},
},
podTemplateSpec := v1.PodSpec{
Containers: []v1.Container{
defaultContainerTemplate,
primaryContainerTemplate,
},
HostNetwork: true,
NodeSelector: map[string]string{
"foo": "bar",
},
SchedulerName: "defaultScheduler",
Tolerations: []v1.Toleration{
v1.Toleration{
Key: "foo",
},
},
}

pod, err = BuildPodWithSpec(&podTemplate, &podSpec, "foo")
mergedPodSpec, err := MergePodSpecs(&podTemplateSpec, &podSpec, "foo")
assert.Nil(t, err)

// validate a PodTemplate-only field
assert.Equal(t, podTemplate.Template.Spec.HostNetwork, pod.Spec.HostNetwork)
assert.Equal(t, podTemplateSpec.HostNetwork, mergedPodSpec.HostNetwork)
// validate a PodSpec-only field
assert.Equal(t, podSpec.Priority, pod.Spec.Priority)
assert.Equal(t, podSpec.Priority, mergedPodSpec.Priority)
// validate an overwritten PodTemplate field
assert.Equal(t, podSpec.SchedulerName, pod.Spec.SchedulerName)
assert.Equal(t, podSpec.SchedulerName, mergedPodSpec.SchedulerName)
// validate a merged map
assert.Equal(t, len(podTemplate.Template.Spec.NodeSelector)+len(podSpec.NodeSelector), len(pod.Spec.NodeSelector))
assert.Equal(t, len(podTemplateSpec.NodeSelector)+len(podSpec.NodeSelector), len(mergedPodSpec.NodeSelector))
// validate an appended array
assert.Equal(t, len(podTemplate.Template.Spec.Tolerations)+len(podSpec.Tolerations), len(pod.Spec.Tolerations))
assert.Equal(t, len(podTemplateSpec.Tolerations)+len(podSpec.Tolerations), len(mergedPodSpec.Tolerations))

// validate primary container
primaryContainer := pod.Spec.Containers[0]
primaryContainer := mergedPodSpec.Containers[0]
assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name)
assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath)

// validate default container
defaultContainer := pod.Spec.Containers[1]
defaultContainer := mergedPodSpec.Containers[1]
assert.Equal(t, podSpec.Containers[1].Name, defaultContainer.Name)
assert.Equal(t, defaultContainerTemplate.TerminationMessagePath, defaultContainer.TerminationMessagePath)

}

func TestBuildPodWithSpec(t *testing.T) {
podSpec := v1.PodSpec{
Containers: []v1.Container{
v1.Container{
Name: "foo",
},
v1.Container{
Name: "bar",
},
},
}

pod, err := BuildPodWithSpec(nil, &podSpec, "foo")
assert.Nil(t, err)
assert.True(t, reflect.DeepEqual(pod.Spec, podSpec))

primaryContainerTemplate := v1.Container{
Name: primaryContainerTemplateName,
TerminationMessagePath: "/dev/primary-termination-log",
}

podTemplate := v1.PodTemplate{
Template: v1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"fooKey": "barVal",
},
},
Spec: v1.PodSpec{
Containers: []v1.Container{
primaryContainerTemplate,
},
},
},
}

pod, err = BuildPodWithSpec(&podTemplate, &podSpec, "foo")
assert.Nil(t, err)

// Test that template podSpec is merged
primaryContainer := pod.Spec.Containers[0]
assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name)
assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath)

// Test that template object metadata is copied
assert.Contains(t, pod.ObjectMeta.Labels, "fooKey")
assert.Equal(t, pod.ObjectMeta.Labels["fooKey"], "barVal")
}
21 changes: 19 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.MPIJobDefaultContainerName)

podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace())

objectMeta := metav1.ObjectMeta{}

if podTemplate != nil {
mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.MPIJobDefaultContainerName)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error())
}
podSpec = mergedPodSpec
objectMeta = podTemplate.Template.ObjectMeta
}

// workersPodSpec is deepCopy of podSpec submitted by flyte
// WorkerPodSpec doesn't need any Argument & command. It will be trigger from launcher pod
workersPodSpec := podSpec.DeepCopy()
Expand All @@ -89,14 +104,16 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
kubeflowv1.MPIJobReplicaTypeLauncher: {
Replicas: &launcherReplicas,
Template: v1.PodTemplateSpec{
Spec: *podSpec,
ObjectMeta: objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonKf.RestartPolicyNever,
},
kubeflowv1.MPIJobReplicaTypeWorker: {
Replicas: &workers,
Template: v1.PodTemplateSpec{
Spec: *workersPodSpec,
ObjectMeta: objectMeta,
Spec: *workersPodSpec,
},
RestartPolicy: commonKf.RestartPolicyNever,
},
Expand Down
Loading

0 comments on commit 968edb3

Please sign in to comment.