Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Feat: Configure elastic training in pytorch plugin (#343)
Browse files Browse the repository at this point in the history
  • Loading branch information
fg91 authored Apr 24, 2023
1 parent 2ad8d08 commit 01f2126
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 26 deletions.
2 changes: 1 addition & 1 deletion flyteplugins/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/athena v1.0.0
github.com/bstadlbauer/dask-k8s-operator-go-client v0.1.0
github.com/coocood/freecache v1.1.1
github.com/flyteorg/flyteidl v1.3.16
github.com/flyteorg/flyteidl v1.3.19
github.com/flyteorg/flytestdlib v1.0.15
github.com/go-test/deep v1.0.7
github.com/golang/protobuf v1.5.2
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/flyteorg/flyteidl v1.3.16 h1:mRq1VeUl5LP12dezbGHLQcrLuAmO9kawK9X7arqCInM=
github.com/flyteorg/flyteidl v1.3.16/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM=
github.com/flyteorg/flyteidl v1.3.19 h1:i79Dh7UoP8Z4LEJ2ox6jlfZVJtFZ+r4g84CJj1gh22Y=
github.com/flyteorg/flyteidl v1.3.19/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM=
github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0=
github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s=
github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk=
Expand Down
63 changes: 48 additions & 15 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,59 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
return nil, fmt.Errorf("number of worker should be more then 0")
}

jobSpec := kubeflowv1.PyTorchJobSpec{
PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{
kubeflowv1.PyTorchJobReplicaTypeMaster: {
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *podSpec,
var jobSpec kubeflowv1.PyTorchJobSpec

elasticConfig := pytorchTaskExtraArgs.GetElasticConfig()

if elasticConfig != nil {
minReplicas := elasticConfig.GetMinReplicas()
maxReplicas := elasticConfig.GetMaxReplicas()
nProcPerNode := elasticConfig.GetNprocPerNode()
maxRestarts := elasticConfig.GetMaxRestarts()
rdzvBackend := kubeflowv1.RDZVBackend(elasticConfig.GetRdzvBackend())

jobSpec = kubeflowv1.PyTorchJobSpec{
ElasticPolicy: &kubeflowv1.ElasticPolicy{
MinReplicas: &minReplicas,
MaxReplicas: &maxReplicas,
RDZVBackend: &rdzvBackend,
NProcPerNode: &nProcPerNode,
MaxRestarts: &maxRestarts,
},
PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{
kubeflowv1.PyTorchJobReplicaTypeWorker: {
Replicas: &workers,
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
kubeflowv1.PyTorchJobReplicaTypeWorker: {
Replicas: &workers,
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *podSpec,
}

} else {

jobSpec = kubeflowv1.PyTorchJobSpec{
PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{
kubeflowv1.PyTorchJobReplicaTypeMaster: {
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
kubeflowv1.PyTorchJobReplicaTypeWorker: {
Replicas: &workers,
Template: v1.PodTemplateSpec{
ObjectMeta: *objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
},
}
}

job := &kubeflowv1.PyTorchJob{
TypeMeta: metav1.TypeMeta{
Kind: kubeflowv1.PytorchJobKind,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ func dummyPytorchCustomObj(workers int32) *plugins.DistributedPyTorchTrainingTas
}
}

func dummyElasticPytorchCustomObj(workers int32, elasticConfig plugins.ElasticConfig) *plugins.DistributedPyTorchTrainingTask {
return &plugins.DistributedPyTorchTrainingTask{
Workers: workers,
ElasticConfig: &elasticConfig,
}
}

func dummyPytorchTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate {

ptObjJSON, err := utils.MarshalToString(pytorchCustomObj)
Expand Down Expand Up @@ -260,7 +267,7 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl
}

ptObj := dummyPytorchCustomObj(workers)
taskTemplate := dummyPytorchTaskTemplate("the job", ptObj)
taskTemplate := dummyPytorchTaskTemplate("job1", ptObj)
resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
if err != nil {
panic(err)
Expand All @@ -282,19 +289,53 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl
}
}

func TestBuildResourcePytorch(t *testing.T) {
func TestBuildResourcePytorchElastic(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}

ptObj := dummyPytorchCustomObj(100)
taskTemplate := dummyPytorchTaskTemplate("the job", ptObj)
ptObj := dummyElasticPytorchCustomObj(2, plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"})
taskTemplate := dummyPytorchTaskTemplate("job2", ptObj)

resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
assert.NoError(t, err)
assert.NotNil(t, resource)

pytorchJob, ok := resource.(*kubeflowv1.PyTorchJob)
assert.True(t, ok)
assert.Equal(t, int32(2), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
assert.NotNil(t, pytorchJob.Spec.ElasticPolicy)
assert.Equal(t, int32(1), *pytorchJob.Spec.ElasticPolicy.MinReplicas)
assert.Equal(t, int32(2), *pytorchJob.Spec.ElasticPolicy.MaxReplicas)
assert.Equal(t, int32(4), *pytorchJob.Spec.ElasticPolicy.NProcPerNode)
assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *pytorchJob.Spec.ElasticPolicy.RDZVBackend)

assert.Equal(t, 1, len(pytorchJob.Spec.PyTorchReplicaSpecs))
assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)

var hasContainerWithDefaultPytorchName = false

for _, container := range pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers {
if container.Name == kubeflowv1.PytorchJobDefaultContainerName {
hasContainerWithDefaultPytorchName = true
}
}

assert.True(t, hasContainerWithDefaultPytorchName)
}

func TestBuildResourcePytorch(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}

ptObj := dummyPytorchCustomObj(100)
taskTemplate := dummyPytorchTaskTemplate("job3", ptObj)

res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
assert.NoError(t, err)
assert.NotNil(t, res)

pytorchJob, ok := res.(*kubeflowv1.PyTorchJob)
assert.True(t, ok)
assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
assert.Nil(t, pytorchJob.Spec.ElasticPolicy)

for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs {
var hasContainerWithDefaultPytorchName = false
Expand Down Expand Up @@ -392,17 +433,17 @@ func TestReplicaCounts(t *testing.T) {
ptObj := dummyPytorchCustomObj(test.workerReplicaCount)
taskTemplate := dummyPytorchTaskTemplate("the job", ptObj)

resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
if test.expectError {
assert.Error(t, err)
assert.Nil(t, resource)
assert.Nil(t, res)
return
}

assert.NoError(t, err)
assert.NotNil(t, resource)
assert.NotNil(t, res)

job, ok := resource.(*kubeflowv1.PyTorchJob)
job, ok := res.(*kubeflowv1.PyTorchJob)
assert.True(t, ok)

assert.Len(t, job.Spec.PyTorchReplicaSpecs, len(test.contains))
Expand Down

0 comments on commit 01f2126

Please sign in to comment.