From ffbe158ff0ae611dac281663b1ad44cd8e408a6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Sat, 10 Dec 2022 13:24:52 +0100 Subject: [PATCH] Carry over ObjectMeta from pod template MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Fabio Grätz --- go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 9 +++++++-- go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go | 9 +++++++-- .../plugins/k8s/kfoperators/tensorflow/tensorflow.go | 9 +++++++-- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 528737bf3..ec42029b4 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -71,12 +71,15 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) + objectMeta := metav1.ObjectMeta{} + if podTemplate != nil { mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.MPIJobDefaultContainerName) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) } podSpec = mergedPodSpec + objectMeta = podTemplate.Template.ObjectMeta } // workersPodSpec is deepCopy of podSpec submitted by flyte @@ -101,14 +104,16 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu kubeflowv1.MPIJobReplicaTypeLauncher: { Replicas: &launcherReplicas, Template: v1.PodTemplateSpec{ - Spec: *podSpec, + ObjectMeta: objectMeta, + Spec: *podSpec, }, RestartPolicy: commonKf.RestartPolicyNever, }, kubeflowv1.MPIJobReplicaTypeWorker: { Replicas: &workers, Template: v1.PodTemplateSpec{ - Spec: *workersPodSpec, + ObjectMeta: objectMeta, + Spec: *workersPodSpec, }, RestartPolicy: commonKf.RestartPolicyNever, }, diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index cf1a761ff..ce36fd2c8 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -70,12 +70,15 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) + objectMeta := metav1.ObjectMeta{} + if podTemplate != nil { mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.PytorchJobDefaultContainerName) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) } podSpec = mergedPodSpec + objectMeta = podTemplate.Template.ObjectMeta } workers := pytorchTaskExtraArgs.GetWorkers() @@ -84,14 +87,16 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ kubeflowv1.PyTorchJobReplicaTypeMaster: { Template: v1.PodTemplateSpec{ - Spec: *podSpec, + ObjectMeta: objectMeta, + Spec: *podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, }, kubeflowv1.PyTorchJobReplicaTypeWorker: { Replicas: &workers, Template: v1.PodTemplateSpec{ - Spec: *podSpec, + ObjectMeta: objectMeta, + Spec: *podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, }, diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 93780a52e..d2370cf94 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -70,12 +70,15 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) + objectMeta := metav1.ObjectMeta{} + if podTemplate != nil { mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.TFJobDefaultContainerName) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) } podSpec = mergedPodSpec + objectMeta = podTemplate.Template.ObjectMeta } workers := tensorflowTaskExtraArgs.GetWorkers() @@ -87,14 +90,16 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task kubeflowv1.TFJobReplicaTypePS: { Replicas: &psReplicas, Template: v1.PodTemplateSpec{ - Spec: *podSpec, + ObjectMeta: objectMeta, + Spec: *podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, }, kubeflowv1.TFJobReplicaTypeChief: { Replicas: &chiefReplicas, Template: v1.PodTemplateSpec{ - Spec: *podSpec, + ObjectMeta: objectMeta, + Spec: *podSpec, }, RestartPolicy: commonOp.RestartPolicyNever, },