Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add test for pytorch elastic task
Browse files Browse the repository at this point in the history
  • Loading branch information
Fabio Grätz committed Apr 22, 2023
1 parent 53c7a27 commit e1a5a7e
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ func dummyPytorchCustomObj(workers int32) *plugins.DistributedPyTorchTrainingTas
}
}

func dummyElasticPytorchCustomObj(workers int32, elasticConfig plugins.ElasticConfig) *plugins.DistributedPyTorchTrainingTask {
return &plugins.DistributedPyTorchTrainingTask{
Workers: workers,
ElasticConfig: &elasticConfig,
}
}

func dummyPytorchTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate {

ptObjJSON, err := utils.MarshalToString(pytorchCustomObj)
Expand Down Expand Up @@ -282,6 +289,39 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl
}
}

func TestBuildResourcePytorchElastic(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}

ptObj := dummyElasticPytorchCustomObj(2, plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"})
taskTemplate := dummyPytorchTaskTemplate("the job", ptObj)

resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
assert.NoError(t, err)
assert.NotNil(t, resource)

pytorchJob, ok := resource.(*kubeflowv1.PyTorchJob)
assert.True(t, ok)
assert.Equal(t, int32(2), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
assert.NotNil(t, pytorchJob.Spec.ElasticPolicy)
assert.Equal(t, int32(1), *pytorchJob.Spec.ElasticPolicy.MinReplicas)
assert.Equal(t, int32(2), *pytorchJob.Spec.ElasticPolicy.MaxReplicas)
assert.Equal(t, int32(4), *pytorchJob.Spec.ElasticPolicy.NProcPerNode)
assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *pytorchJob.Spec.ElasticPolicy.RDZVBackend)

assert.Equal(t, 1, len(pytorchJob.Spec.PyTorchReplicaSpecs))
assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)

var hasContainerWithDefaultPytorchName = false

for _, container := range pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers {
if container.Name == kubeflowv1.PytorchJobDefaultContainerName {
hasContainerWithDefaultPytorchName = true
}
}

assert.True(t, hasContainerWithDefaultPytorchName)
}

func TestBuildResourcePytorch(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}

Expand All @@ -295,6 +335,7 @@ func TestBuildResourcePytorch(t *testing.T) {
pytorchJob, ok := resource.(*kubeflowv1.PyTorchJob)
assert.True(t, ok)
assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
assert.Nil(t, pytorchJob.Spec.ElasticPolicy)

for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
var hasContainerWithDefaultPytorchName = false
Expand Down

0 comments on commit e1a5a7e

Please sign in to comment.