From d25aade50cffcd9153b41d3e04ef623be30a6acb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabio=20Gr=C3=A4tz?= Date: Thu, 24 Nov 2022 19:41:33 +0100 Subject: [PATCH] Apply pod template to pytorch job pod spec --- go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go | 11 +++++++++++ .../plugins/k8s/kfoperators/pytorch/pytorch_test.go | 6 +++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 0ad038f64..19a678e27 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -66,6 +66,17 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } + podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) + + if podTemplate != nil { + basePodSpec := podTemplate.Template.Spec.DeepCopy() + mergedPodSpec, err := flytek8s.MergePodSpecs(basePodSpec, podSpec, kubeflowv1.PytorchJobDefaultContainerName) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) + } + podSpec = mergedPodSpec + } + common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.PytorchJobDefaultContainerName) workers := pytorchTaskExtraArgs.GetWorkers() diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 5185d150d..0c4dcb3b1 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -69,7 +69,7 @@ func dummyPytorchCustomObj(workers int32) *plugins.DistributedPyTorchTrainingTas } } -func dummySparkTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate { +func dummyPytorchTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate { ptObjJSON, err := utils.MarshalToString(pytorchCustomObj) if err != nil { @@ -260,7 +260,7 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl } ptObj := dummyPytorchCustomObj(workers) - taskTemplate := dummySparkTaskTemplate("the job", ptObj) + taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) if err != nil { panic(err) @@ -286,7 +286,7 @@ func TestBuildResourcePytorch(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} ptObj := dummyPytorchCustomObj(100) - taskTemplate := dummySparkTaskTemplate("the job", ptObj) + taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) assert.NoError(t, err)