From 9a2bbbaf2f3ac9e38222ba7755ac4bda1ac2fdb7 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Tue, 9 May 2023 13:40:42 -0700 Subject: [PATCH] Change kubeflow plugins to allow setting specs for different replicas (#345) * change pytorch plugin to accept new pytorch task idl Signed-off-by: Yubo Wang * merge elastic config in Signed-off-by: Yubo Wang * add unit tests for pytorch Signed-off-by: Yubo Wang * add tfjob Signed-off-by: Yubo Wang * add mpi job Signed-off-by: Yubo Wang * add test to common operator Signed-off-by: Yubo Wang * update flyteidl Signed-off-by: Yubo Wang * add function header comments Signed-off-by: Yubo Wang * fix lint Signed-off-by: Yubo Wang --------- Signed-off-by: Yubo Wang Co-authored-by: Yubo Wang --- go.mod | 2 +- go.sum | 4 +- .../k8s/kfoperators/common/common_operator.go | 77 +++++++ .../common/common_operator_test.go | 101 ++++++++++ go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 147 ++++++++++---- .../plugins/k8s/kfoperators/mpi/mpi_test.go | 136 ++++++++++++- .../k8s/kfoperators/pytorch/pytorch.go | 146 ++++++++++---- .../k8s/kfoperators/pytorch/pytorch_test.go | 166 +++++++++++++++- .../k8s/kfoperators/tensorflow/tensorflow.go | 129 +++++++++--- .../kfoperators/tensorflow/tensorflow_test.go | 188 +++++++++++++++++- 10 files changed, 975 insertions(+), 121 deletions(-) diff --git a/go.mod b/go.mod index 6a2331949..7d73674eb 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/athena v1.0.0 github.com/bstadlbauer/dask-k8s-operator-go-client v0.1.0 github.com/coocood/freecache v1.1.1 - github.com/flyteorg/flyteidl v1.3.19 + github.com/flyteorg/flyteidl v1.5.2 github.com/flyteorg/flytestdlib v1.0.15 github.com/go-test/deep v1.0.7 github.com/golang/protobuf v1.5.2 diff --git a/go.sum b/go.sum index 70bbe278f..1af063022 100644 --- a/go.sum +++ b/go.sum @@ -232,8 +232,8 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/flyteorg/flyteidl v1.3.19 h1:i79Dh7UoP8Z4LEJ2ox6jlfZVJtFZ+r4g84CJj1gh22Y= -github.com/flyteorg/flyteidl v1.3.19/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= +github.com/flyteorg/flyteidl v1.5.2 h1:DZPzYkTg92qA4e17fd0ZW1M+gh1gJKh/VOK+F4bYgM8= +github.com/flyteorg/flyteidl v1.5.2/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 88419b64c..d86ae42df 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -5,9 +5,11 @@ import ( "sort" "time" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/logs" pluginsCore 
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -21,6 +23,12 @@ const ( PytorchTaskType = "pytorch" ) +type ReplicaEntry struct { + PodSpec *v1.PodSpec + ReplicaNum int32 + RestartPolicy commonOp.RestartPolicy +} + // ExtractMPICurrentCondition will return the first job condition for MPI func ExtractMPICurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) { if jobConditions != nil { @@ -180,3 +188,72 @@ func OverridePrimaryContainerName(podSpec *v1.PodSpec, primaryContainerName stri } } } + +// ParseRunPolicy converts a kubeflow plugin RunPolicy object to a k8s RunPolicy object. +func ParseRunPolicy(flyteRunPolicy kfplugins.RunPolicy) commonOp.RunPolicy { + runPolicy := commonOp.RunPolicy{} + if flyteRunPolicy.GetBackoffLimit() != 0 { + var backoffLimit = flyteRunPolicy.GetBackoffLimit() + runPolicy.BackoffLimit = &backoffLimit + } + var cleanPodPolicy = ParseCleanPodPolicy(flyteRunPolicy.GetCleanPodPolicy()) + runPolicy.CleanPodPolicy = &cleanPodPolicy + if flyteRunPolicy.GetActiveDeadlineSeconds() != 0 { + var ddlSeconds = int64(flyteRunPolicy.GetActiveDeadlineSeconds()) + runPolicy.ActiveDeadlineSeconds = &ddlSeconds + } + if flyteRunPolicy.GetTtlSecondsAfterFinished() != 0 { + var ttl = flyteRunPolicy.GetTtlSecondsAfterFinished() + runPolicy.TTLSecondsAfterFinished = &ttl + } + + return runPolicy +} + +// Get k8s clean pod policy from flyte kubeflow plugins clean pod policy. +func ParseCleanPodPolicy(flyteCleanPodPolicy kfplugins.CleanPodPolicy) commonOp.CleanPodPolicy { + cleanPodPolicyMap := map[kfplugins.CleanPodPolicy]commonOp.CleanPodPolicy{ + kfplugins.CleanPodPolicy_CLEANPOD_POLICY_NONE: commonOp.CleanPodPolicyNone, + kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL: commonOp.CleanPodPolicyAll, + kfplugins.CleanPodPolicy_CLEANPOD_POLICY_RUNNING: commonOp.CleanPodPolicyRunning, + } + return cleanPodPolicyMap[flyteCleanPodPolicy] +} + +// Get k8s restart policy from flyte kubeflow plugins restart policy. +func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.RestartPolicy { + restartPolicyMap := map[kfplugins.RestartPolicy]commonOp.RestartPolicy{ + kfplugins.RestartPolicy_RESTART_POLICY_NEVER: commonOp.RestartPolicyNever, + kfplugins.RestartPolicy_RESTART_POLICY_ON_FAILURE: commonOp.RestartPolicyOnFailure, + kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS: commonOp.RestartPolicyAlways, + } + return restartPolicyMap[flyteRestartPolicy] +} + +// OverrideContainerSpec overrides the specified container's properties in the given podSpec. The function +// updates the image, resources and command arguments of the container that matches the given containerName. 
+func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) error { + for idx, c := range podSpec.Containers { + if c.Name == containerName { + if image != "" { + podSpec.Containers[idx].Image = image + } + if resources != nil { + // if resources requests and limits both not set, we will not override the resources + if len(resources.Requests) >= 1 || len(resources.Limits) >= 1 { + resources, err := flytek8s.ToK8sResourceRequirements(resources) + if err != nil { + return flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) + } + podSpec.Containers[idx].Resources = *resources + } + } else { + podSpec.Containers[idx].Resources = v1.ResourceRequirements{} + } + if len(args) != 0 { + podSpec.Containers[idx].Args = args + } + } + } + return nil +} diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index 1fb26f128..ee2dc5a94 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -5,12 +5,15 @@ import ( "testing" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/logs" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" commonOp "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" ) func TestExtractMPICurrentCondition(t *testing.T) { @@ -183,3 +186,101 @@ func TestGetLogs(t *testing.T) { assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-chiefReplica-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[2].Uri) } + +func dummyPodSpec() v1.PodSpec { + return v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "primary container", + Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"}, + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.MustParse("2"), + "memory": resource.MustParse("200Mi"), + "gpu": resource.MustParse("1"), + }, + Requests: v1.ResourceList{ + "cpu": resource.MustParse("1"), + "memory": resource.MustParse("100Mi"), + "gpu": resource.MustParse("1"), + }, + }, + VolumeMounts: []v1.VolumeMount{ + { + Name: "volume mount", + }, + }, + }, + { + Name: "secondary container", + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "gpu": resource.MustParse("2"), + }, + Requests: v1.ResourceList{ + "gpu": resource.MustParse("2"), + }, + }, + }, + }, + Volumes: []v1.Volume{ + { + Name: "dshm", + }, + }, + Tolerations: []v1.Toleration{ + { + Key: "my toleration key", + Value: "my toleration value", + }, + }, + } +} + +func TestOverrideContainerSpec(t *testing.T) { + podSpec := dummyPodSpec() + err := OverrideContainerSpec( + &podSpec, "primary container", "testing-image", + &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + }, + }, + []string{"python", "-m", "run.py"}, + ) + assert.NoError(t, err) + assert.Equal(t, 2, len(podSpec.Containers)) + assert.Equal(t, "testing-image", 
podSpec.Containers[0].Image) + assert.NotNil(t, podSpec.Containers[0].Resources.Limits) + assert.NotNil(t, podSpec.Containers[0].Resources.Requests) + // verify resources are overridden with the values passed in + assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("250m"))) + assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("500m"))) + assert.Equal(t, []string{"python", "-m", "run.py"}, podSpec.Containers[0].Args) +} + +func TestOverrideContainerSpecEmptyFields(t *testing.T) { + podSpec := dummyPodSpec() + err := OverrideContainerSpec(&podSpec, "primary container", "", &core.Resources{}, []string{}) + assert.NoError(t, err) + assert.Equal(t, 2, len(podSpec.Containers)) + assert.NotNil(t, podSpec.Containers[0].Resources.Limits) + assert.NotNil(t, podSpec.Containers[0].Resources.Requests) + // verify resources not overridden if empty resources + assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("1"))) + assert.True(t, podSpec.Containers[0].Resources.Requests.Memory().Equal(resource.MustParse("100Mi"))) + assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("2"))) + assert.True(t, podSpec.Containers[0].Resources.Limits.Memory().Equal(resource.MustParse("200Mi"))) +} + +func TestOverrideContainerNilResources(t *testing.T) { + podSpec := dummyPodSpec() + err := OverrideContainerSpec(&podSpec, "primary container", "", nil, []string{}) + assert.NoError(t, err) + assert.Equal(t, 2, len(podSpec.Containers)) + assert.Nil(t, podSpec.Containers[0].Resources.Limits) + assert.Nil(t, podSpec.Containers[0].Resources.Requests) +} diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 7837f5f4a..d4e35a25d 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -7,6 +7,8 @@ import ( "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" + flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -48,7 +50,6 @@ func (mpiOperatorResourceHandler) BuildIdentityResource(ctx context.Context, tas // Defines a func to create the full resource object that will be posted to k8s. 
func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) - taskTemplateConfig := taskTemplate.GetConfig() if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error()) @@ -56,69 +57,127 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") } - mpiTaskExtraArgs := plugins.DistributedMPITrainingTask{} - err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &mpiTaskExtraArgs) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) - } - - workers := mpiTaskExtraArgs.GetNumWorkers() - launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas() - slots := mpiTaskExtraArgs.GetSlots() - podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.MPIJobDefaultContainerName) - // workersPodSpec is deepCopy of podSpec submitted by flyte - workersPodSpec := podSpec.DeepCopy() + var launcherReplica = common.ReplicaEntry{ + ReplicaNum: int32(1), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + } + var workerReplica = common.ReplicaEntry{ + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + } + slots := int32(1) + runPolicy := commonOp.RunPolicy{} + + if taskTemplate.TaskTypeVersion == 0 { + mpiTaskExtraArgs := plugins.DistributedMPITrainingTask{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &mpiTaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + workerReplica.ReplicaNum = mpiTaskExtraArgs.GetNumWorkers() + launcherReplica.ReplicaNum = mpiTaskExtraArgs.GetNumLauncherReplicas() + slots = mpiTaskExtraArgs.GetSlots() - // If users don't specify "worker_spec_command" in the task config, the command/args are empty. - // However, in some cases, the workers need command/args. - // For example, in horovod tasks, each worker runs a command launching ssh daemon. 
+ // V1 requires passing worker command as template config parameter + taskTemplateConfig := taskTemplate.GetConfig() + workerSpecCommand := []string{} + if val, ok := taskTemplateConfig[workerSpecCommandKey]; ok { + workerSpecCommand = strings.Split(val, " ") + } - workerSpecCommand := []string{} - if val, ok := taskTemplateConfig[workerSpecCommandKey]; ok { - workerSpecCommand = strings.Split(val, " ") - } + for k := range workerReplica.PodSpec.Containers { + if workerReplica.PodSpec.Containers[k].Name == kubeflowv1.MPIJobDefaultContainerName { + workerReplica.PodSpec.Containers[k].Args = workerSpecCommand + workerReplica.PodSpec.Containers[k].Command = []string{} + } + } + + } else if taskTemplate.TaskTypeVersion == 1 { + kfMPITaskExtraArgs := kfplugins.DistributedMPITrainingTask{} + + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &kfMPITaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } - for k := range workersPodSpec.Containers { - workersPodSpec.Containers[k].Args = workerSpecCommand - workersPodSpec.Containers[k].Command = []string{} + launcherReplicaSpec := kfMPITaskExtraArgs.GetLauncherReplicas() + if launcherReplicaSpec != nil { + // flyte commands will be passed as args to the container + err = common.OverrideContainerSpec( + launcherReplica.PodSpec, + kubeflowv1.MPIJobDefaultContainerName, + launcherReplicaSpec.GetImage(), + launcherReplicaSpec.GetResources(), + launcherReplicaSpec.GetCommand(), + ) + if err != nil { + return nil, err + } + launcherReplica.RestartPolicy = common.ParseRestartPolicy(launcherReplicaSpec.GetRestartPolicy()) + } + + workerReplicaSpec := kfMPITaskExtraArgs.GetWorkerReplicas() + if workerReplicaSpec != nil { + err = common.OverrideContainerSpec( + workerReplica.PodSpec, + kubeflowv1.MPIJobDefaultContainerName, + workerReplicaSpec.GetImage(), + workerReplicaSpec.GetResources(), + workerReplicaSpec.GetCommand(), + ) + if err != nil { + return nil, err + } + workerReplica.RestartPolicy = common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()) + workerReplica.ReplicaNum = workerReplicaSpec.GetReplicas() + } + + if kfMPITaskExtraArgs.GetRunPolicy() != nil { + runPolicy = common.ParseRunPolicy(*kfMPITaskExtraArgs.GetRunPolicy()) + } + + } else { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, + "Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion) } - if workers == 0 { + if workerReplica.ReplicaNum == 0 { return nil, fmt.Errorf("number of worker should be more then 0") } - if launcherReplicas == 0 { + if launcherReplica.ReplicaNum == 0 { return nil, fmt.Errorf("number of launch worker should be more then 0") } jobSpec := kubeflowv1.MPIJobSpec{ - SlotsPerWorker: &slots, - MPIReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{}, - } - - for _, t := range []struct { - podSpec v1.PodSpec - replicaNum *int32 - replicaType commonOp.ReplicaType - }{ - {*podSpec, &launcherReplicas, kubeflowv1.MPIJobReplicaTypeLauncher}, - {*workersPodSpec, &workers, kubeflowv1.MPIJobReplicaTypeWorker}, - } { - if *t.replicaNum > 0 { - jobSpec.MPIReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{ - Replicas: t.replicaNum, + SlotsPerWorker: &slots, + RunPolicy: runPolicy, + MPIReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ + kubeflowv1.MPIJobReplicaTypeLauncher: { + Replicas: &launcherReplica.ReplicaNum, Template: v1.PodTemplateSpec{ ObjectMeta: 
*objectMeta, - Spec: t.podSpec, + Spec: *launcherReplica.PodSpec, }, - RestartPolicy: commonOp.RestartPolicyNever, - } - } + RestartPolicy: launcherReplica.RestartPolicy, + }, + kubeflowv1.MPIJobReplicaTypeWorker: { + Replicas: &workerReplica.ReplicaNum, + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *workerReplica.PodSpec, + }, + RestartPolicy: workerReplica.RestartPolicy, + }, + }, } job := &kubeflowv1.MPIJob{ diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 2e9e9283a..778b20a08 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -8,6 +8,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" "github.com/flyteorg/flyteplugins/go/tasks/logs" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -68,9 +69,24 @@ func dummyMPICustomObj(workers int32, launcher int32, slots int32) *plugins.Dist } } -func dummyMPITaskTemplate(id string, mpiCustomObj *plugins.DistributedMPITrainingTask) *core.TaskTemplate { +func dummyMPITaskTemplate(id string, args ...interface{}) *core.TaskTemplate { + + var mpiObjJSON string + var err error + + for _, arg := range args { + switch t := arg.(type) { + case *kfplugins.DistributedMPITrainingTask: + var mpiCustomObj = t + mpiObjJSON, err = utils.MarshalToString(mpiCustomObj) + case *plugins.DistributedMPITrainingTask: + var mpiCustomObj = t + mpiObjJSON, err = utils.MarshalToString(mpiCustomObj) + default: + err = fmt.Errorf("Unkonw input type %T", t) + } + } - mpiObjJSON, err := utils.MarshalToString(mpiCustomObj) if err != nil { panic(err) } @@ -428,3 +444,119 @@ func TestReplicaCounts(t *testing.T) { }) } } + +func TestBuildResourceMPIV1(t *testing.T) { + launcherCommand := []string{"python", "launcher.py"} + workerCommand := []string{"/usr/sbin/sshd", "/.sshd_config"} + taskConfig := &kfplugins.DistributedMPITrainingTask{ + LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + }, + }, + Command: launcherCommand, + }, + WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + }, + }, + Command: workerCommand, + }, + Slots: int32(1), + } + + launcherResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + }, + } + + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + }, + } + + mpiResourceHandler := mpiOperatorResourceHandler{} + + taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) + taskTemplate.TaskTypeVersion = 1 + + resource, err := mpiResourceHandler.BuildResource(context.TODO(), 
dummyMPITaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + mpiJob, ok := resource.(*kubeflowv1.MPIJob) + assert.True(t, ok) + assert.Equal(t, int32(1), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas) + assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker) + assert.Equal(t, *launcherResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Resources) + assert.Equal(t, *workerResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + assert.Equal(t, launcherCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Args) + assert.Equal(t, workerCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Args) +} + +func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) { + workerCommand := []string{"/usr/sbin/sshd", "/.sshd_config"} + + taskConfig := &kfplugins.DistributedMPITrainingTask{ + WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + }, + }, + Command: []string{"/usr/sbin/sshd", "/.sshd_config"}, + }, + Slots: int32(1), + } + + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + }, + } + + mpiResourceHandler := mpiOperatorResourceHandler{} + + taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) + taskTemplate.TaskTypeVersion = 1 + + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + mpiJob, ok := resource.(*kubeflowv1.MPIJob) + assert.True(t, ok) + assert.Equal(t, int32(1), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas) + assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker) + assert.Equal(t, *workerResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + assert.Equal(t, testArgs, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Args) + assert.Equal(t, workerCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Args) +} diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 338f6cd56..d5cd747c6 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -6,6 +6,7 @@ import ( "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" @@ -68,64 +69,121 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx } 
common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.PytorchJobDefaultContainerName) - workers := pytorchTaskExtraArgs.GetWorkers() - if workers == 0 { - return nil, fmt.Errorf("number of worker should be more then 0") + var masterReplica = common.ReplicaEntry{ + ReplicaNum: int32(1), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + } + var workerReplica = common.ReplicaEntry{ + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, } + runPolicy := commonOp.RunPolicy{} - var jobSpec kubeflowv1.PyTorchJobSpec + if taskTemplate.TaskTypeVersion == 0 { + pytorchTaskExtraArgs := plugins.DistributedPyTorchTrainingTask{} - elasticConfig := pytorchTaskExtraArgs.GetElasticConfig() + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &pytorchTaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } - if elasticConfig != nil { - minReplicas := elasticConfig.GetMinReplicas() - maxReplicas := elasticConfig.GetMaxReplicas() - nProcPerNode := elasticConfig.GetNprocPerNode() - maxRestarts := elasticConfig.GetMaxRestarts() - rdzvBackend := kubeflowv1.RDZVBackend(elasticConfig.GetRdzvBackend()) + workerReplica.ReplicaNum = pytorchTaskExtraArgs.GetWorkers() + } else if taskTemplate.TaskTypeVersion == 1 { + kfPytorchTaskExtraArgs := kfplugins.DistributedPyTorchTrainingTask{} - jobSpec = kubeflowv1.PyTorchJobSpec{ - ElasticPolicy: &kubeflowv1.ElasticPolicy{ - MinReplicas: &minReplicas, - MaxReplicas: &maxReplicas, - RDZVBackend: &rdzvBackend, - NProcPerNode: &nProcPerNode, - MaxRestarts: &maxRestarts, - }, - PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ - kubeflowv1.PyTorchJobReplicaTypeWorker: { - Replicas: &workers, - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *podSpec, - }, - RestartPolicy: commonOp.RestartPolicyNever, - }, - }, + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &kfPytorchTaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + // Replace specs of master replica, master should always have 1 replica + masterReplicaSpec := kfPytorchTaskExtraArgs.GetMasterReplicas() + if masterReplicaSpec != nil { + err := common.OverrideContainerSpec( + masterReplica.PodSpec, + kubeflowv1.PytorchJobDefaultContainerName, + masterReplicaSpec.GetImage(), + masterReplicaSpec.GetResources(), + nil, + ) + if err != nil { + return nil, err + } + masterReplica.RestartPolicy = common.ParseRestartPolicy(masterReplicaSpec.GetRestartPolicy()) + } + + // Replace specs of worker replica + workerReplicaSpec := kfPytorchTaskExtraArgs.GetWorkerReplicas() + if workerReplicaSpec != nil { + err := common.OverrideContainerSpec( + workerReplica.PodSpec, + kubeflowv1.PytorchJobDefaultContainerName, + workerReplicaSpec.GetImage(), + workerReplicaSpec.GetResources(), + nil, + ) + if err != nil { + return nil, err + } + workerReplica.RestartPolicy = common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()) + workerReplica.ReplicaNum = workerReplicaSpec.GetReplicas() } + if kfPytorchTaskExtraArgs.GetRunPolicy() != nil { + runPolicy = common.ParseRunPolicy(*kfPytorchTaskExtraArgs.GetRunPolicy()) + } } else { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, + "Invalid TaskSpecification, unsupported task template version 
[%v] key", taskTemplate.TaskTypeVersion) + } + + if workerReplica.ReplicaNum == 0 { + return nil, fmt.Errorf("number of worker should be more then 0") + } - jobSpec = kubeflowv1.PyTorchJobSpec{ - PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ - kubeflowv1.PyTorchJobReplicaTypeMaster: { - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *podSpec, - }, - RestartPolicy: commonOp.RestartPolicyNever, + jobSpec := kubeflowv1.PyTorchJobSpec{ + PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ + kubeflowv1.PyTorchJobReplicaTypeMaster: { + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *masterReplica.PodSpec, }, - kubeflowv1.PyTorchJobReplicaTypeWorker: { - Replicas: &workers, - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *podSpec, - }, - RestartPolicy: commonOp.RestartPolicyNever, + RestartPolicy: masterReplica.RestartPolicy, + }, + kubeflowv1.PyTorchJobReplicaTypeWorker: { + Replicas: &workerReplica.ReplicaNum, + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *workerReplica.PodSpec, }, + RestartPolicy: workerReplica.RestartPolicy, }, + }, + RunPolicy: runPolicy, + } + + // Set elastic config + elasticConfig := pytorchTaskExtraArgs.GetElasticConfig() + if elasticConfig != nil { + minReplicas := elasticConfig.GetMinReplicas() + maxReplicas := elasticConfig.GetMaxReplicas() + nProcPerNode := elasticConfig.GetNprocPerNode() + maxRestarts := elasticConfig.GetMaxRestarts() + rdzvBackend := kubeflowv1.RDZVBackend(elasticConfig.GetRdzvBackend()) + elasticPolicy := kubeflowv1.ElasticPolicy{ + MinReplicas: &minReplicas, + MaxReplicas: &maxReplicas, + RDZVBackend: &rdzvBackend, + NProcPerNode: &nProcPerNode, + MaxRestarts: &maxRestarts, } + jobSpec.ElasticPolicy = &elasticPolicy + // Remove master replica if elastic policy is set + delete(jobSpec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeMaster) } + job := &kubeflowv1.PyTorchJob{ TypeMeta: metav1.TypeMeta{ Kind: kubeflowv1.PytorchJobKind, diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 74bd3fe92..fea07505a 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -26,6 +26,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" @@ -35,6 +36,7 @@ import ( ) const testImage = "image://" +const testImageMaster = "image://master" const serviceAccount = "pytorch_sa" var ( @@ -76,9 +78,24 @@ func dummyElasticPytorchCustomObj(workers int32, elasticConfig plugins.ElasticCo } } -func dummyPytorchTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate { +func dummyPytorchTaskTemplate(id string, args ...interface{}) *core.TaskTemplate { + + var ptObjJSON string + var err error + + for _, arg := range args { + switch t := arg.(type) { + case *kfplugins.DistributedPyTorchTrainingTask: + var pytorchCustomObj = t + ptObjJSON, err = utils.MarshalToString(pytorchCustomObj) + case *plugins.DistributedPyTorchTrainingTask: + var pytorchCustomObj = t + ptObjJSON, err = utils.MarshalToString(pytorchCustomObj) + default: + err = fmt.Errorf("Unkonw input type %T", t) + } + } - 
ptObjJSON, err := utils.MarshalToString(pytorchCustomObj) if err != nil { panic(err) } @@ -457,3 +474,148 @@ func TestReplicaCounts(t *testing.T) { }) } } + +func TestBuildResourcePytorchV1(t *testing.T) { + taskConfig := &kfplugins.DistributedPyTorchTrainingTask{ + MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Image: testImageMaster, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + }, + }, + RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, + }, + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + }, + }, + }, + RunPolicy: &kfplugins.RunPolicy{ + CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, + BackoffLimit: 100, + }, + } + + masterResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + }, + } + + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + }, + } + + pytorchResourceHandler := pytorchOperatorResourceHandler{} + + taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) + taskTemplate.TaskTypeVersion = 1 + + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, res) + + pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) + + assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) + assert.Nil(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) + + assert.Equal(t, testImageMaster, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Image) + assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Image) + + assert.Equal(t, *masterResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Resources) + assert.Equal(t, *workerResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + + assert.Equal(t, commonOp.RestartPolicyAlways, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].RestartPolicy) + assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].RestartPolicy) + + assert.Equal(t, commonOp.CleanPodPolicyAll, *pytorchJob.Spec.RunPolicy.CleanPodPolicy) + assert.Equal(t, int32(100), *pytorchJob.Spec.RunPolicy.BackoffLimit) + assert.Nil(t, pytorchJob.Spec.RunPolicy.TTLSecondsAfterFinished) + assert.Nil(t, pytorchJob.Spec.RunPolicy.ActiveDeadlineSeconds) + + assert.Nil(t, pytorchJob.Spec.ElasticPolicy) +} + +func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { + taskConfig := 
&kfplugins.DistributedPyTorchTrainingTask{ + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + }, + }, + }, + } + // Master Replica should use resource from task override if not set + taskOverrideResourceRequirements := &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1000m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("100m"), + corev1.ResourceMemory: resource.MustParse("512Mi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + } + + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + }, + } + + pytorchResourceHandler := pytorchOperatorResourceHandler{} + + taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) + taskTemplate.TaskTypeVersion = 1 + + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, res) + + pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) + + assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) + assert.Nil(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) + + assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Image) + assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Image) + + assert.Equal(t, *taskOverrideResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Resources) + assert.Equal(t, *workerResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + + assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].RestartPolicy) + assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].RestartPolicy) + + assert.Nil(t, pytorchJob.Spec.ElasticPolicy) +} diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index b5a5a675f..6ee3ce440 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -6,6 +6,7 @@ import ( "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" @@ -56,23 +57,109 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") } - tensorflowTaskExtraArgs := plugins.DistributedTensorflowTrainingTask{} - err = 
utils.UnmarshalStruct(taskTemplate.GetCustom(), &tensorflowTaskExtraArgs) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) - } - podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.TFJobDefaultContainerName) - workers := tensorflowTaskExtraArgs.GetWorkers() - psReplicas := tensorflowTaskExtraArgs.GetPsReplicas() - chiefReplicas := tensorflowTaskExtraArgs.GetChiefReplicas() + replicaSpecMap := map[commonOp.ReplicaType]*common.ReplicaEntry{ + kubeflowv1.TFJobReplicaTypeChief: { + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + }, + kubeflowv1.TFJobReplicaTypeWorker: { + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + }, + kubeflowv1.TFJobReplicaTypePS: { + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + }, + } + runPolicy := commonOp.RunPolicy{} + + if taskTemplate.TaskTypeVersion == 0 { + tensorflowTaskExtraArgs := plugins.DistributedTensorflowTrainingTask{} + + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &tensorflowTaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].ReplicaNum = tensorflowTaskExtraArgs.GetChiefReplicas() + replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum = tensorflowTaskExtraArgs.GetWorkers() + replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].ReplicaNum = tensorflowTaskExtraArgs.GetPsReplicas() + + } else if taskTemplate.TaskTypeVersion == 1 { + kfTensorflowTaskExtraArgs := kfplugins.DistributedTensorflowTrainingTask{} + + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &kfTensorflowTaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + chiefReplicaSpec := kfTensorflowTaskExtraArgs.GetChiefReplicas() + if chiefReplicaSpec != nil { + err := common.OverrideContainerSpec( + replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].PodSpec, + kubeflowv1.TFJobDefaultContainerName, + chiefReplicaSpec.GetImage(), + chiefReplicaSpec.GetResources(), + nil, + ) + if err != nil { + return nil, err + } + replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].RestartPolicy = common.ParseRestartPolicy(chiefReplicaSpec.GetRestartPolicy()) + replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].ReplicaNum = chiefReplicaSpec.GetReplicas() + } + + workerReplicaSpec := kfTensorflowTaskExtraArgs.GetWorkerReplicas() + if workerReplicaSpec != nil { + err := common.OverrideContainerSpec( + replicaSpecMap[kubeflowv1.MPIJobReplicaTypeWorker].PodSpec, + kubeflowv1.TFJobDefaultContainerName, + workerReplicaSpec.GetImage(), + workerReplicaSpec.GetResources(), + nil, + ) + if err != nil { + return nil, err + } + replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy = common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()) + replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum = workerReplicaSpec.GetReplicas() + } + + psReplicaSpec := 
kfTensorflowTaskExtraArgs.GetPsReplicas() + if psReplicaSpec != nil { + err := common.OverrideContainerSpec( + replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].PodSpec, + kubeflowv1.TFJobDefaultContainerName, + psReplicaSpec.GetImage(), + psReplicaSpec.GetResources(), + nil, + ) + if err != nil { + return nil, err + } + replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].RestartPolicy = common.ParseRestartPolicy(psReplicaSpec.GetRestartPolicy()) + replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].ReplicaNum = psReplicaSpec.GetReplicas() + } + + if kfTensorflowTaskExtraArgs.GetRunPolicy() != nil { + runPolicy = common.ParseRunPolicy(*kfTensorflowTaskExtraArgs.GetRunPolicy()) + } - if workers == 0 { + } else { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, + "Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion) + } + + if replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum == 0 { return nil, fmt.Errorf("number of worker should be more then 0") } @@ -80,27 +167,21 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task TFReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{}, } - for _, t := range []struct { - podSpec v1.PodSpec - replicaNum *int32 - replicaType commonOp.ReplicaType - }{ - {*podSpec, &workers, kubeflowv1.TFJobReplicaTypeWorker}, - {*podSpec, &psReplicas, kubeflowv1.TFJobReplicaTypePS}, - {*podSpec, &chiefReplicas, kubeflowv1.TFJobReplicaTypeChief}, - } { - if *t.replicaNum > 0 { - jobSpec.TFReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{ - Replicas: t.replicaNum, + for replicaType, replicaEntry := range replicaSpecMap { + if replicaEntry.ReplicaNum > 0 { + jobSpec.TFReplicaSpecs[replicaType] = &commonOp.ReplicaSpec{ + Replicas: &replicaEntry.ReplicaNum, Template: v1.PodTemplateSpec{ ObjectMeta: *objectMeta, - Spec: t.podSpec, + Spec: *replicaEntry.PodSpec, }, - RestartPolicy: commonOp.RestartPolicyNever, + RestartPolicy: replicaEntry.RestartPolicy, } } } + jobSpec.RunPolicy = runPolicy + job := &kubeflowv1.TFJob{ TypeMeta: metav1.TypeMeta{ Kind: kubeflowv1.TFJobKind, diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 37f22bf34..8174258e1 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -26,6 +26,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" @@ -70,9 +71,24 @@ func dummyTensorFlowCustomObj(workers int32, psReplicas int32, chiefReplicas int } } -func dummyTensorFlowTaskTemplate(id string, tensorflowCustomObj *plugins.DistributedTensorflowTrainingTask) *core.TaskTemplate { +func dummyTensorFlowTaskTemplate(id string, args ...interface{}) *core.TaskTemplate { + + var tfObjJSON string + var err error + + for _, arg := range args { + switch t := arg.(type) { + case *kfplugins.DistributedTensorflowTrainingTask: + var tensorflowCustomObj = t + tfObjJSON, err = utils.MarshalToString(tensorflowCustomObj) + case *plugins.DistributedTensorflowTrainingTask: + var tensorflowCustomObj = t + tfObjJSON, err = utils.MarshalToString(tensorflowCustomObj) + default: + err = 
fmt.Errorf("Unkonw input type %T", t) + } + } - tfObjJSON, err := utils.MarshalToString(tensorflowCustomObj) if err != nil { panic(err) } @@ -420,3 +436,171 @@ func TestReplicaCounts(t *testing.T) { }) } } + +func TestBuildResourceTensorFlowV1(t *testing.T) { + taskConfig := &kfplugins.DistributedTensorflowTrainingTask{ + ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 1, + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, + RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, + }, + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, + }, + PsReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 50, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + }, + }, + }, + RunPolicy: &kfplugins.RunPolicy{ + CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, + ActiveDeadlineSeconds: int32(100), + }, + } + + resourceRequirementsMap := map[commonOp.ReplicaType]*corev1.ResourceRequirements{ + kubeflowv1.TFJobReplicaTypeChief: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + }, + kubeflowv1.TFJobReplicaTypeWorker: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + kubeflowv1.TFJobReplicaTypePS: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + }, + }, + } + + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + + taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) + taskTemplate.TaskTypeVersion = 1 + + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + tensorflowJob, ok := resource.(*kubeflowv1.TFJob) + assert.True(t, ok) + assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas) + assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas) + + for replicaType, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { + var hasContainerWithDefaultTensorFlowName = false + + for _, container := range 
replicaSpec.Template.Spec.Containers { + if container.Name == kubeflowv1.TFJobDefaultContainerName { + hasContainerWithDefaultTensorFlowName = true + assert.Equal(t, *resourceRequirementsMap[replicaType], container.Resources) + } + } + + assert.True(t, hasContainerWithDefaultTensorFlowName) + } + assert.Equal(t, commonOp.CleanPodPolicyAll, *tensorflowJob.Spec.RunPolicy.CleanPodPolicy) + assert.Equal(t, int64(100), *tensorflowJob.Spec.RunPolicy.ActiveDeadlineSeconds) +} + +func TestBuildResourceTensorFlowV1WithOnlyWorker(t *testing.T) { + taskConfig := &kfplugins.DistributedTensorflowTrainingTask{ + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, + }, + } + + resourceRequirementsMap := map[commonOp.ReplicaType]*corev1.ResourceRequirements{ + kubeflowv1.TFJobReplicaTypeWorker: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + } + + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + + taskTemplate := dummyTensorFlowTaskTemplate("v1 with only worker replica", taskConfig) + taskTemplate.TaskTypeVersion = 1 + + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + tensorflowJob, ok := resource.(*kubeflowv1.TFJob) + assert.True(t, ok) + assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) + assert.Nil(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief]) + assert.Nil(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS]) + + for replicaType, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { + var hasContainerWithDefaultTensorFlowName = false + + for _, container := range replicaSpec.Template.Spec.Containers { + if container.Name == kubeflowv1.TFJobDefaultContainerName { + hasContainerWithDefaultTensorFlowName = true + assert.Equal(t, *resourceRequirementsMap[replicaType], container.Resources) + } + } + + assert.True(t, hasContainerWithDefaultTensorFlowName) + } +}
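
Usage note: the sketch below is illustrative only and is not part of the patch. It shows how a TaskTypeVersion 1 task template for the updated PyTorch plugin might be assembled, mirroring the dummyPytorchTaskTemplate helper in pytorch_test.go. The message and field names (DistributedPyTorchTrainingTask, MasterReplicas, WorkerReplicas, RunPolicy, core.Resources) are the ones exercised by the tests above; the package wrapper, image string, replica count, and the jsonpb round trip into the template's Custom struct are assumptions chosen for the example. Tasks that keep TaskTypeVersion 0 continue to use the legacy plugins.DistributedPyTorchTrainingTask message, handled by the version 0 branch in BuildResource.

package main

import (
	"fmt"

	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
	kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"
	"github.com/golang/protobuf/jsonpb"
	structpb "github.com/golang/protobuf/ptypes/struct"
)

// buildPytorchV1TaskTemplate assembles an example TaskTypeVersion 1 template
// with per-replica overrides: the master gets its own image, resources and
// restart policy, the workers get a replica count, and a RunPolicy is attached.
// Values here are placeholders, not defaults taken from the plugin.
func buildPytorchV1TaskTemplate() (*core.TaskTemplate, error) {
	taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
		MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
			Image: "image://master", // hypothetical image
			Resources: &core.Resources{
				Requests: []*core.Resources_ResourceEntry{{Name: core.Resources_CPU, Value: "250m"}},
				Limits:   []*core.Resources_ResourceEntry{{Name: core.Resources_CPU, Value: "500m"}},
			},
			RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS,
		},
		WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
			Replicas: 100,
		},
		RunPolicy: &kfplugins.RunPolicy{
			CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL,
			BackoffLimit:   100,
		},
	}

	// Serialize the plugin config into the template's Custom struct, the same
	// round trip the dummy*TaskTemplate test helpers in this patch perform.
	cfgJSON, err := (&jsonpb.Marshaler{}).MarshalToString(taskConfig)
	if err != nil {
		return nil, err
	}
	custom := &structpb.Struct{}
	if err := jsonpb.UnmarshalString(cfgJSON, custom); err != nil {
		return nil, err
	}

	return &core.TaskTemplate{
		Type:            "pytorch",
		TaskTypeVersion: 1, // version 1 selects the kfplugins path; 0 keeps the legacy message
		Custom:          custom,
	}, nil
}

func main() {
	tmpl, err := buildPytorchV1TaskTemplate()
	if err != nil {
		panic(err)
	}
	fmt.Println(tmpl.GetType(), tmpl.GetTaskTypeVersion())
}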