From 1dd1495790930d1174f2b214531f1431ff27734d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 24 Nov 2022 14:36:28 +0100 Subject: [PATCH 01/13] Merge pod template spec with pod spec in separate func MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- .../pluginmachinery/flytek8s/pod_helper.go | 119 ++++++++++-------- 1 file changed, 67 insertions(+), 52 deletions(-) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 82f482285..54e983088 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -144,70 +144,82 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (* return pod, nil } -func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryContainerName string) (*v1.Pod, error) { - pod := v1.Pod{ - TypeMeta: v12.TypeMeta{ - Kind: PodKind, - APIVersion: v1.SchemeGroupVersion.String(), - }, +func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string) error { + err := mergo.Merge(podTemplatePodSpec, podSpec, mergo.WithOverride, mergo.WithAppendSlice) + if err != nil { + return err } - if podTemplate != nil { - // merge template PodSpec - basePodSpec := podTemplate.Template.Spec.DeepCopy() - err := mergo.Merge(basePodSpec, podSpec, mergo.WithOverride, mergo.WithAppendSlice) - if err != nil { - return nil, err + // merge template Containers + var mergedContainers []v1.Container + var defaultContainerTemplate, primaryContainerTemplate *v1.Container + for i := 0; i < len(podTemplatePodSpec.Containers); i++ { + if podTemplatePodSpec.Containers[i].Name == defaultContainerTemplateName { + defaultContainerTemplate = &podTemplatePodSpec.Containers[i] + } else if podTemplatePodSpec.Containers[i].Name == primaryContainerTemplateName { + primaryContainerTemplate = &podTemplatePodSpec.Containers[i] } + } - // merge template Containers - var mergedContainers []v1.Container - var defaultContainerTemplate, primaryContainerTemplate *v1.Container - for i := 0; i < len(podTemplate.Template.Spec.Containers); i++ { - if podTemplate.Template.Spec.Containers[i].Name == defaultContainerTemplateName { - defaultContainerTemplate = &podTemplate.Template.Spec.Containers[i] - } else if podTemplate.Template.Spec.Containers[i].Name == primaryContainerTemplateName { - primaryContainerTemplate = &podTemplate.Template.Spec.Containers[i] - } + for _, container := range podSpec.Containers { + // if applicable start with defaultContainerTemplate + var mergedContainer *v1.Container + if defaultContainerTemplate != nil { + mergedContainer = defaultContainerTemplate.DeepCopy() } - for _, container := range podSpec.Containers { - // if applicable start with defaultContainerTemplate - var mergedContainer *v1.Container - if defaultContainerTemplate != nil { - mergedContainer = defaultContainerTemplate.DeepCopy() - } - - // if applicable merge with primaryContainerTemplate - if container.Name == primaryContainerName && primaryContainerTemplate != nil { - if mergedContainer == nil { - mergedContainer = primaryContainerTemplate.DeepCopy() - } else { - err := mergo.Merge(mergedContainer, primaryContainerTemplate, mergo.WithOverride, mergo.WithAppendSlice) - if err != nil { - return nil, err - } - } - } - - // if applicable merge with existing container + // if applicable merge with primaryContainerTemplate + if container.Name == primaryContainerName && primaryContainerTemplate != nil { if mergedContainer == nil { - mergedContainers = append(mergedContainers, container) + mergedContainer = primaryContainerTemplate.DeepCopy() } else { - err := mergo.Merge(mergedContainer, container, mergo.WithOverride, mergo.WithAppendSlice) + err := mergo.Merge(mergedContainer, primaryContainerTemplate, mergo.WithOverride, mergo.WithAppendSlice) if err != nil { - return nil, err + return err } + } + } - mergedContainers = append(mergedContainers, *mergedContainer) + // if applicable merge with existing container # TODO test + if mergedContainer == nil { + mergedContainers = append(mergedContainers, container) + + } else { + err := mergo.Merge(mergedContainer, container, mergo.WithOverride, mergo.WithAppendSlice) + if err != nil { + return err } + mergedContainers = append(mergedContainers, *mergedContainer) + } + + } + + // update Pod fields + podTemplatePodSpec.Containers = mergedContainers + return nil +} + +func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryContainerName string) (*v1.Pod, error) { + pod := v1.Pod{ + TypeMeta: v12.TypeMeta{ + Kind: PodKind, + APIVersion: v1.SchemeGroupVersion.String(), + }, + } + + if podTemplate != nil { + // merge template PodSpec + basePodSpec := podTemplate.Template.Spec.DeepCopy() + + err := MergePodSpecs(basePodSpec, podSpec, primaryContainerName) + if err != nil { + return nil, err } - // update Pod fields - basePodSpec.Containers = mergedContainers pod.ObjectMeta = podTemplate.Template.ObjectMeta pod.Spec = *basePodSpec + } else { pod.Spec = *podSpec } @@ -231,12 +243,15 @@ func BuildIdentityPod() *v1.Pod { // Important considerations. // Pending Status in Pod could be for various reasons and sometimes could signal a problem // Case I: Pending because the Image pull is failing and it is backing off -// This could be transient. So we can actually rely on the failure reason. -// The failure transitions from ErrImagePull -> ImagePullBackoff +// +// This could be transient. So we can actually rely on the failure reason. +// The failure transitions from ErrImagePull -> ImagePullBackoff +// // Case II: Not enough resources are available. This is tricky. It could be that the total number of -// resources requested is beyond the capability of the system. for this we will rely on configuration -// and hence input gates. We should not allow bad requests that Request for large number of resource through. -// In the case it makes through, we will fail after timeout +// +// resources requested is beyond the capability of the system. for this we will rely on configuration +// and hence input gates. We should not allow bad requests that Request for large number of resource through. +// In the case it makes through, we will fail after timeout func DemystifyPending(status v1.PodStatus) (pluginsCore.PhaseInfo, error) { // Search over the difference conditions in the status object. Note that the 'Pending' this function is // demystifying is the 'phase' of the pod status. This is different than the PodReady condition type also used below From 01b3885edb0e3d838065801505041a67a48787f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 24 Nov 2022 19:41:14 +0100 Subject: [PATCH 02/13] Merge pod template spec with pod spec in separate func MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- go/tasks/pluginmachinery/flytek8s/pod_helper.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 54e983088..2b3df9edb 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -144,10 +144,10 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (* return pod, nil } -func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string) error { +func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string) (*v1.PodSpec, error) { err := mergo.Merge(podTemplatePodSpec, podSpec, mergo.WithOverride, mergo.WithAppendSlice) if err != nil { - return err + return nil, err } // merge template Containers @@ -175,7 +175,7 @@ func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryC } else { err := mergo.Merge(mergedContainer, primaryContainerTemplate, mergo.WithOverride, mergo.WithAppendSlice) if err != nil { - return err + return nil, err } } } @@ -187,7 +187,7 @@ func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryC } else { err := mergo.Merge(mergedContainer, container, mergo.WithOverride, mergo.WithAppendSlice) if err != nil { - return err + return nil, err } mergedContainers = append(mergedContainers, *mergedContainer) @@ -197,7 +197,8 @@ func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryC // update Pod fields podTemplatePodSpec.Containers = mergedContainers - return nil + + return podTemplatePodSpec, nil } func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryContainerName string) (*v1.Pod, error) { @@ -212,13 +213,13 @@ func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryC // merge template PodSpec basePodSpec := podTemplate.Template.Spec.DeepCopy() - err := MergePodSpecs(basePodSpec, podSpec, primaryContainerName) + mergedPodSpec, err := MergePodSpecs(basePodSpec, podSpec, primaryContainerName) if err != nil { return nil, err } pod.ObjectMeta = podTemplate.Template.ObjectMeta - pod.Spec = *basePodSpec + pod.Spec = *mergedPodSpec } else { pod.Spec = *podSpec From 22f00bc9bb3ef303f9a3f97ffc88240db4543108 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 24 Nov 2022 19:41:33 +0100 Subject: [PATCH 03/13] Apply pod template to pytorch job pod spec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go | 11 +++++++++++ .../plugins/k8s/kfoperators/pytorch/pytorch_test.go | 6 +++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 0ad038f64..19a678e27 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -66,6 +66,17 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } + podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) + + if podTemplate != nil { + basePodSpec := podTemplate.Template.Spec.DeepCopy() + mergedPodSpec, err := flytek8s.MergePodSpecs(basePodSpec, podSpec, kubeflowv1.PytorchJobDefaultContainerName) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) + } + podSpec = mergedPodSpec + } + common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.PytorchJobDefaultContainerName) workers := pytorchTaskExtraArgs.GetWorkers() diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 5185d150d..0c4dcb3b1 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -69,7 +69,7 @@ func dummyPytorchCustomObj(workers int32) *plugins.DistributedPyTorchTrainingTas } } -func dummySparkTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate { +func dummyPytorchTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate { ptObjJSON, err := utils.MarshalToString(pytorchCustomObj) if err != nil { @@ -260,7 +260,7 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl } ptObj := dummyPytorchCustomObj(workers) - taskTemplate := dummySparkTaskTemplate("the job", ptObj) + taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) if err != nil { panic(err) @@ -286,7 +286,7 @@ func TestBuildResourcePytorch(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} ptObj := dummyPytorchCustomObj(100) - taskTemplate := dummySparkTaskTemplate("the job", ptObj) + taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) assert.NoError(t, err) From 03afd51b2833190ad83786decc4156df554dbdef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Wed, 30 Nov 2022 16:22:32 +0100 Subject: [PATCH 04/13] Handle nil podspecs before merging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- go/tasks/pluginmachinery/flytek8s/pod_helper.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 2b3df9edb..17824a51c 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -2,6 +2,7 @@ package flytek8s import ( "context" + "errors" "fmt" "strings" "time" @@ -145,6 +146,10 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (* } func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string) (*v1.PodSpec, error) { + if podTemplatePodSpec == nil || podSpec == nil { + return nil, errors.New("podTemplatePodSpec and podSpec cannot be nil.") + } + err := mergo.Merge(podTemplatePodSpec, podSpec, mergo.WithOverride, mergo.WithAppendSlice) if err != nil { return nil, err From 4d92d8a33784bc1d0d93294de9ed471f1701866a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Wed, 30 Nov 2022 19:58:06 +0100 Subject: [PATCH 05/13] Pass both default and primare container name to MergePodSpecs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- go/tasks/pluginmachinery/flytek8s/pod_helper.go | 4 ++-- go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 17824a51c..77b27ef9c 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -145,7 +145,7 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (* return pod, nil } -func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string) (*v1.PodSpec, error) { +func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string, defaultContainerName string) (*v1.PodSpec, error) { if podTemplatePodSpec == nil || podSpec == nil { return nil, errors.New("podTemplatePodSpec and podSpec cannot be nil.") } @@ -218,7 +218,7 @@ func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryC // merge template PodSpec basePodSpec := podTemplate.Template.Spec.DeepCopy() - mergedPodSpec, err := MergePodSpecs(basePodSpec, podSpec, primaryContainerName) + mergedPodSpec, err := MergePodSpecs(basePodSpec, podSpec, primaryContainerName, defaultContainerTemplateName) if err != nil { return nil, err } diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 19a678e27..4741741c5 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -24,6 +24,9 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +const defaultContainerTemplateName = kubeflowv1.PytorchJobDefaultContainerName +const primaryContainerTemplateName = "primary" + type pytorchOperatorResourceHandler struct { } @@ -70,7 +73,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx if podTemplate != nil { basePodSpec := podTemplate.Template.Spec.DeepCopy() - mergedPodSpec, err := flytek8s.MergePodSpecs(basePodSpec, podSpec, kubeflowv1.PytorchJobDefaultContainerName) + mergedPodSpec, err := flytek8s.MergePodSpecs(basePodSpec, podSpec, primaryContainerTemplateName, defaultContainerTemplateName) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) } From 2f6a24b2cd6100fe2ef3d993614421c10bf8f768 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Wed, 30 Nov 2022 20:04:28 +0100 Subject: [PATCH 06/13] Move podSpec.DeepCopy into MergePodSpecs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- .../pluginmachinery/flytek8s/pod_helper.go | 23 ++++++++++--------- .../k8s/kfoperators/pytorch/pytorch.go | 3 +-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 77b27ef9c..7e7859e08 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -150,7 +150,10 @@ func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryC return nil, errors.New("podTemplatePodSpec and podSpec cannot be nil.") } - err := mergo.Merge(podTemplatePodSpec, podSpec, mergo.WithOverride, mergo.WithAppendSlice) + var podTemplatePodSpecCopy *v1.PodSpec + podTemplatePodSpecCopy = podTemplatePodSpec.DeepCopy() + + err := mergo.Merge(podTemplatePodSpecCopy, podSpec, mergo.WithOverride, mergo.WithAppendSlice) if err != nil { return nil, err } @@ -158,11 +161,11 @@ func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryC // merge template Containers var mergedContainers []v1.Container var defaultContainerTemplate, primaryContainerTemplate *v1.Container - for i := 0; i < len(podTemplatePodSpec.Containers); i++ { - if podTemplatePodSpec.Containers[i].Name == defaultContainerTemplateName { - defaultContainerTemplate = &podTemplatePodSpec.Containers[i] - } else if podTemplatePodSpec.Containers[i].Name == primaryContainerTemplateName { - primaryContainerTemplate = &podTemplatePodSpec.Containers[i] + for i := 0; i < len(podTemplatePodSpecCopy.Containers); i++ { + if podTemplatePodSpecCopy.Containers[i].Name == defaultContainerTemplateName { + defaultContainerTemplate = &podTemplatePodSpecCopy.Containers[i] + } else if podTemplatePodSpecCopy.Containers[i].Name == primaryContainerTemplateName { + primaryContainerTemplate = &podTemplatePodSpecCopy.Containers[i] } } @@ -201,9 +204,9 @@ func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryC } // update Pod fields - podTemplatePodSpec.Containers = mergedContainers + podTemplatePodSpecCopy.Containers = mergedContainers - return podTemplatePodSpec, nil + return podTemplatePodSpecCopy, nil } func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryContainerName string) (*v1.Pod, error) { @@ -216,9 +219,7 @@ func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryC if podTemplate != nil { // merge template PodSpec - basePodSpec := podTemplate.Template.Spec.DeepCopy() - - mergedPodSpec, err := MergePodSpecs(basePodSpec, podSpec, primaryContainerName, defaultContainerTemplateName) + mergedPodSpec, err := MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName, defaultContainerTemplateName) if err != nil { return nil, err } diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 4741741c5..06b009bf1 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -72,8 +72,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) if podTemplate != nil { - basePodSpec := podTemplate.Template.Spec.DeepCopy() - mergedPodSpec, err := flytek8s.MergePodSpecs(basePodSpec, podSpec, primaryContainerTemplateName, defaultContainerTemplateName) + mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerTemplateName, defaultContainerTemplateName) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) } From 5efe69b0315468579c40f3fb6f3c3b2696da10ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Wed, 30 Nov 2022 20:05:39 +0100 Subject: [PATCH 07/13] Add tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- .../flytek8s/pod_helper_test.go | 139 ++++++++++++++++++ 1 file changed, 139 insertions(+) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 3af65a413..9f80aa1f9 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -23,6 +23,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" @@ -1013,6 +1014,144 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) { }) } +func TestMergePodSpecs(t *testing.T) { + var priority int32 = 1 + + podSpec1, _ := MergePodSpecs(nil, nil, "defaultname", "primaryname") + assert.Nil(t, podSpec1) + + podSpec2, _ := MergePodSpecs(&v1.PodSpec{}, nil, "defaultname", "primaryname") + assert.Nil(t, podSpec2) + + podSpec3, _ := MergePodSpecs(nil, &v1.PodSpec{}, "defaultname", "primaryname") + assert.Nil(t, podSpec3) + + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "foo", + }, + v1.Container{ + Name: "bar", + }, + }, + NodeSelector: map[string]string{ + "baz": "bar", + }, + Priority: &priority, + SchedulerName: "overrideScheduler", + Tolerations: []v1.Toleration{ + v1.Toleration{ + Key: "bar", + }, + v1.Toleration{ + Key: "baz", + }, + }, + } + + defaultContainerTemplate := v1.Container{ + Name: defaultContainerTemplateName, + TerminationMessagePath: "/dev/default-termination-log", + } + + primaryContainerTemplate := v1.Container{ + Name: primaryContainerTemplateName, + TerminationMessagePath: "/dev/primary-termination-log", + } + + podTemplateSpec := v1.PodSpec{ + Containers: []v1.Container{ + defaultContainerTemplate, + primaryContainerTemplate, + }, + HostNetwork: true, + NodeSelector: map[string]string{ + "foo": "bar", + }, + SchedulerName: "defaultScheduler", + Tolerations: []v1.Toleration{ + v1.Toleration{ + Key: "foo", + }, + }, + } + + mergedPodSpec, err := MergePodSpecs(&podTemplateSpec, &podSpec, "foo", "default") + assert.Nil(t, err) + + // validate a PodTemplate-only field + assert.Equal(t, podTemplateSpec.HostNetwork, mergedPodSpec.HostNetwork) + // validate a PodSpec-only field + assert.Equal(t, podSpec.Priority, mergedPodSpec.Priority) + // validate an overwritten PodTemplate field + assert.Equal(t, podSpec.SchedulerName, mergedPodSpec.SchedulerName) + // validate a merged map + assert.Equal(t, len(podTemplateSpec.NodeSelector)+len(podSpec.NodeSelector), len(mergedPodSpec.NodeSelector)) + // validate an appended array + assert.Equal(t, len(podTemplateSpec.Tolerations)+len(podSpec.Tolerations), len(mergedPodSpec.Tolerations)) + + // validate primary container + primaryContainer := mergedPodSpec.Containers[0] + assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name) + assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath) + + // validate default container + defaultContainer := mergedPodSpec.Containers[1] + assert.Equal(t, podSpec.Containers[1].Name, defaultContainer.Name) + assert.Equal(t, defaultContainerTemplate.TerminationMessagePath, defaultContainer.TerminationMessagePath) + +} + +func TestBuildPodWithSpec2(t *testing.T) { + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "foo", + }, + v1.Container{ + Name: "bar", + }, + }, + } + + pod, err := BuildPodWithSpec(nil, &podSpec, "foo") + assert.Nil(t, err) + assert.True(t, reflect.DeepEqual(pod.Spec, podSpec)) + + primaryContainerTemplate := v1.Container{ + Name: primaryContainerTemplateName, + TerminationMessagePath: "/dev/primary-termination-log", + } + + podTemplate := v1.PodTemplate{ + Template: v1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "fooKey": "barVal", + }, + }, + Spec: v1.PodSpec{ + Containers: []v1.Container{ + primaryContainerTemplate, + }, + }, + }, + } + + pod, err = BuildPodWithSpec(&podTemplate, &podSpec, "foo") + assert.Nil(t, err) + + // Test that template podSpec is merged + primaryContainer := pod.Spec.Containers[0] + assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name) + assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath) + + // Test that template object metadata is copied + assert.Contains(t, pod.ObjectMeta.Labels, "fooKey") + assert.Equal(t, pod.ObjectMeta.Labels["fooKey"], "barVal") +} + func TestBuildPodWithSpec(t *testing.T) { var priority int32 = 1 podSpec := v1.PodSpec{ From 1980e8a50cb790e3037f8077f38c4d88991f2815 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Wed, 30 Nov 2022 20:06:36 +0100 Subject: [PATCH 08/13] Merge pod template into tfjob and mpijob MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 13 +++++++++++++ .../k8s/kfoperators/tensorflow/tensorflow.go | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 197132519..4c51abd8f 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -21,6 +21,9 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) +const defaultContainerTemplateName = kubeflowv1.MPIJobDefaultContainerName +const primaryContainerTemplateName = "primary" + type mpiOperatorResourceHandler struct { } @@ -67,6 +70,16 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } + podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) + + if podTemplate != nil { + mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerTemplateName, defaultContainerTemplateName) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) + } + podSpec = mergedPodSpec + } + // workersPodSpec is deepCopy of podSpec submitted by flyte // WorkerPodSpec doesn't need any Argument & command. It will be trigger from launcher pod workersPodSpec := podSpec.DeepCopy() diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index d533d35a0..b5f2c9901 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -24,6 +24,9 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +const defaultContainerTemplateName = kubeflowv1.TFJobDefaultContainerName +const primaryContainerTemplateName = "primary" + type tensorflowOperatorResourceHandler struct { } @@ -66,6 +69,16 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } + podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) + + if podTemplate != nil { + mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerTemplateName, defaultContainerTemplateName) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) + } + podSpec = mergedPodSpec + } + common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.TFJobDefaultContainerName) workers := tensorflowTaskExtraArgs.GetWorkers() From d6f3a74bbcc30acb091904e38e0d475bdc783040 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Wed, 30 Nov 2022 20:20:37 +0100 Subject: [PATCH 09/13] Lint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- go/tasks/pluginmachinery/flytek8s/pod_helper.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 7e7859e08..4ee466627 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -147,11 +147,10 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (* func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string, defaultContainerName string) (*v1.PodSpec, error) { if podTemplatePodSpec == nil || podSpec == nil { - return nil, errors.New("podTemplatePodSpec and podSpec cannot be nil.") + return nil, errors.New("podTemplatePodSpec and podSpec cannot be nil") } - var podTemplatePodSpecCopy *v1.PodSpec - podTemplatePodSpecCopy = podTemplatePodSpec.DeepCopy() + var podTemplatePodSpecCopy *v1.PodSpec = podTemplatePodSpec.DeepCopy() err := mergo.Merge(podTemplatePodSpecCopy, podSpec, mergo.WithOverride, mergo.WithAppendSlice) if err != nil { From f71fbc83a99686bfca877414783818ec1fa59860 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Sat, 10 Dec 2022 13:08:07 +0100 Subject: [PATCH 10/13] Correct usage of default and primate container (template) name MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- go/tasks/pluginmachinery/flytek8s/pod_helper.go | 4 ++-- go/tasks/pluginmachinery/flytek8s/pod_helper_test.go | 8 ++++---- go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 5 +---- go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go | 9 +++------ .../plugins/k8s/kfoperators/tensorflow/tensorflow.go | 9 +++------ 5 files changed, 13 insertions(+), 22 deletions(-) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 4ee466627..54f9a1f01 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -145,7 +145,7 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (* return pod, nil } -func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string, defaultContainerName string) (*v1.PodSpec, error) { +func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string) (*v1.PodSpec, error) { if podTemplatePodSpec == nil || podSpec == nil { return nil, errors.New("podTemplatePodSpec and podSpec cannot be nil") } @@ -218,7 +218,7 @@ func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryC if podTemplate != nil { // merge template PodSpec - mergedPodSpec, err := MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName, defaultContainerTemplateName) + mergedPodSpec, err := MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName) if err != nil { return nil, err } diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 9f80aa1f9..bad87d5f8 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -1017,13 +1017,13 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) { func TestMergePodSpecs(t *testing.T) { var priority int32 = 1 - podSpec1, _ := MergePodSpecs(nil, nil, "defaultname", "primaryname") + podSpec1, _ := MergePodSpecs(nil, nil, "foo") assert.Nil(t, podSpec1) - podSpec2, _ := MergePodSpecs(&v1.PodSpec{}, nil, "defaultname", "primaryname") + podSpec2, _ := MergePodSpecs(&v1.PodSpec{}, nil, "foo") assert.Nil(t, podSpec2) - podSpec3, _ := MergePodSpecs(nil, &v1.PodSpec{}, "defaultname", "primaryname") + podSpec3, _ := MergePodSpecs(nil, &v1.PodSpec{}, "foo") assert.Nil(t, podSpec3) podSpec := v1.PodSpec{ @@ -1077,7 +1077,7 @@ func TestMergePodSpecs(t *testing.T) { }, } - mergedPodSpec, err := MergePodSpecs(&podTemplateSpec, &podSpec, "foo", "default") + mergedPodSpec, err := MergePodSpecs(&podTemplateSpec, &podSpec, "foo") assert.Nil(t, err) // validate a PodTemplate-only field diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 4c51abd8f..f1c928f77 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -21,9 +21,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) -const defaultContainerTemplateName = kubeflowv1.MPIJobDefaultContainerName -const primaryContainerTemplateName = "primary" - type mpiOperatorResourceHandler struct { } @@ -73,7 +70,7 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) if podTemplate != nil { - mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerTemplateName, defaultContainerTemplateName) + mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.MPIJobDefaultContainerName) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) } diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 06b009bf1..cf1a761ff 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -24,9 +24,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -const defaultContainerTemplateName = kubeflowv1.PytorchJobDefaultContainerName -const primaryContainerTemplateName = "primary" - type pytorchOperatorResourceHandler struct { } @@ -69,18 +66,18 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } + common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.PytorchJobDefaultContainerName) + podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) if podTemplate != nil { - mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerTemplateName, defaultContainerTemplateName) + mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.PytorchJobDefaultContainerName) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) } podSpec = mergedPodSpec } - common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.PytorchJobDefaultContainerName) - workers := pytorchTaskExtraArgs.GetWorkers() jobSpec := kubeflowv1.PyTorchJobSpec{ diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index b5f2c9901..93780a52e 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -24,9 +24,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -const defaultContainerTemplateName = kubeflowv1.TFJobDefaultContainerName -const primaryContainerTemplateName = "primary" - type tensorflowOperatorResourceHandler struct { } @@ -69,18 +66,18 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } + common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.TFJobDefaultContainerName) + podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) if podTemplate != nil { - mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerTemplateName, defaultContainerTemplateName) + mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.TFJobDefaultContainerName) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) } podSpec = mergedPodSpec } - common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.TFJobDefaultContainerName) - workers := tensorflowTaskExtraArgs.GetWorkers() psReplicas := tensorflowTaskExtraArgs.GetPsReplicas() chiefReplicas := tensorflowTaskExtraArgs.GetChiefReplicas() From 1c8f4de329d513007aaf468afae30e396984cc0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Sat, 10 Dec 2022 13:10:35 +0100 Subject: [PATCH 11/13] Override mpi default container name MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index f1c928f77..528737bf3 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -67,6 +67,8 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } + common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.MPIJobDefaultContainerName) + podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) if podTemplate != nil { From ffbe158ff0ae611dac281663b1ad44cd8e408a6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Sat, 10 Dec 2022 13:24:52 +0100 Subject: [PATCH 12/13] Carry over ObjectMeta from pod template MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 9 +++++++-- go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go | 9 +++++++-- .../plugins/k8s/kfoperators/tensorflow/tensorflow.go | 9 +++++++-- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 528737bf3..ec42029b4 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -71,12 +71,15 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) + objectMeta := metav1.ObjectMeta{} + if podTemplate != nil { mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.MPIJobDefaultContainerName) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) } podSpec = mergedPodSpec + objectMeta = podTemplate.Template.ObjectMeta } // workersPodSpec is deepCopy of podSpec submitted by flyte @@ -101,14 +104,16 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu kubeflowv1.MPIJobReplicaTypeLauncher: { Replicas: &launcherReplicas, Template: v1.PodTemplateSpec{ - Spec: *podSpec, + ObjectMeta: objectMeta, + Spec: *podSpec, }, RestartPolicy: commonKf.RestartPolicyNever, }, kubeflowv1.MPIJobReplicaTypeWorker: { Replicas: &workers, Template: v1.PodTemplateSpec{ - Spec: *workersPodSpec, + ObjectMeta: objectMeta, + Spec: *workersPodSpec, }, RestartPolicy: commonKf.RestartPolicyNever, }, diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index cf1a761ff..ce36fd2c8 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -70,12 +70,15 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) + objectMeta := metav1.ObjectMeta{} + if podTemplate != nil { mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.PytorchJobDefaultContainerName) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) } podSpec = mergedPodSpec + objectMeta = podTemplate.Template.ObjectMeta } workers := pytorchTaskExtraArgs.GetWorkers() @@ -84,14 +87,16 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ kubeflowv1.PyTorchJobReplicaTypeMaster: { Template: v1.PodTemplateSpec{ - Spec: *podSpec, + ObjectMeta: objectMeta, + Spec: *podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, }, kubeflowv1.PyTorchJobReplicaTypeWorker: { Replicas: &workers, Template: v1.PodTemplateSpec{ - Spec: *podSpec, + ObjectMeta: objectMeta, + Spec: *podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, }, diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 93780a52e..d2370cf94 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -70,12 +70,15 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) + objectMeta := metav1.ObjectMeta{} + if podTemplate != nil { mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.TFJobDefaultContainerName) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) } podSpec = mergedPodSpec + objectMeta = podTemplate.Template.ObjectMeta } workers := tensorflowTaskExtraArgs.GetWorkers() @@ -87,14 +90,16 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task kubeflowv1.TFJobReplicaTypePS: { Replicas: &psReplicas, Template: v1.PodTemplateSpec{ - Spec: *podSpec, + ObjectMeta: objectMeta, + Spec: *podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, }, kubeflowv1.TFJobReplicaTypeChief: { Replicas: &chiefReplicas, Template: v1.PodTemplateSpec{ - Spec: *podSpec, + ObjectMeta: objectMeta, + Spec: *podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, }, From b5ed78a188edca3a5be6d5ff247b87f93acfe878 Mon Sep 17 00:00:00 2001 From: "Fabio M. Graetz, Ph.D" Date: Mon, 12 Dec 2022 19:15:32 +0100 Subject: [PATCH 13/13] Remove old `TestBuildPodWithSpec` test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- .../flytek8s/pod_helper_test.go | 88 +------------------ 1 file changed, 1 insertion(+), 87 deletions(-) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index bad87d5f8..01c16051e 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -1103,7 +1103,7 @@ func TestMergePodSpecs(t *testing.T) { } -func TestBuildPodWithSpec2(t *testing.T) { +func TestBuildPodWithSpec(t *testing.T) { podSpec := v1.PodSpec{ Containers: []v1.Container{ v1.Container{ @@ -1151,89 +1151,3 @@ func TestBuildPodWithSpec2(t *testing.T) { assert.Contains(t, pod.ObjectMeta.Labels, "fooKey") assert.Equal(t, pod.ObjectMeta.Labels["fooKey"], "barVal") } - -func TestBuildPodWithSpec(t *testing.T) { - var priority int32 = 1 - podSpec := v1.PodSpec{ - Containers: []v1.Container{ - v1.Container{ - Name: "foo", - }, - v1.Container{ - Name: "bar", - }, - }, - NodeSelector: map[string]string{ - "baz": "bar", - }, - Priority: &priority, - SchedulerName: "overrideScheduler", - Tolerations: []v1.Toleration{ - v1.Toleration{ - Key: "bar", - }, - v1.Toleration{ - Key: "baz", - }, - }, - } - - pod, err := BuildPodWithSpec(nil, &podSpec, "foo") - assert.Nil(t, err) - assert.True(t, reflect.DeepEqual(pod.Spec, podSpec)) - - defaultContainerTemplate := v1.Container{ - Name: defaultContainerTemplateName, - TerminationMessagePath: "/dev/default-termination-log", - } - - primaryContainerTemplate := v1.Container{ - Name: primaryContainerTemplateName, - TerminationMessagePath: "/dev/primary-termination-log", - } - - podTemplate := v1.PodTemplate{ - Template: v1.PodTemplateSpec{ - Spec: v1.PodSpec{ - Containers: []v1.Container{ - defaultContainerTemplate, - primaryContainerTemplate, - }, - HostNetwork: true, - NodeSelector: map[string]string{ - "foo": "bar", - }, - SchedulerName: "defaultScheduler", - Tolerations: []v1.Toleration{ - v1.Toleration{ - Key: "foo", - }, - }, - }, - }, - } - - pod, err = BuildPodWithSpec(&podTemplate, &podSpec, "foo") - assert.Nil(t, err) - - // validate a PodTemplate-only field - assert.Equal(t, podTemplate.Template.Spec.HostNetwork, pod.Spec.HostNetwork) - // validate a PodSpec-only field - assert.Equal(t, podSpec.Priority, pod.Spec.Priority) - // validate an overwritten PodTemplate field - assert.Equal(t, podSpec.SchedulerName, pod.Spec.SchedulerName) - // validate a merged map - assert.Equal(t, len(podTemplate.Template.Spec.NodeSelector)+len(podSpec.NodeSelector), len(pod.Spec.NodeSelector)) - // validate an appended array - assert.Equal(t, len(podTemplate.Template.Spec.Tolerations)+len(podSpec.Tolerations), len(pod.Spec.Tolerations)) - - // validate primary container - primaryContainer := pod.Spec.Containers[0] - assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name) - assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath) - - // validate default container - defaultContainer := pod.Spec.Containers[1] - assert.Equal(t, podSpec.Containers[1].Name, defaultContainer.Name) - assert.Equal(t, defaultContainerTemplate.TerminationMessagePath, defaultContainer.TerminationMessagePath) -}