
fix v1 pytorch job plugin with elastic policy (#359)
* fix pytorch job plugin elastic policy

Signed-off-by: Yubo Wang <[email protected]>

* add ElasticConfig interface

Signed-off-by: Yubo Wang <[email protected]>

* add more testing

Signed-off-by: Yubo Wang <[email protected]>

---------

Signed-off-by: Yubo Wang <[email protected]>
Co-authored-by: Yubo Wang <[email protected]>
yubofredwang and Yubo Wang authored Jun 14, 2023
1 parent 53e63bc commit 318fa6b
Showing 2 changed files with 133 additions and 29 deletions.
63 changes: 40 additions & 23 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
@@ -57,12 +57,6 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification")
}

pytorchTaskExtraArgs := plugins.DistributedPyTorchTrainingTask{}
err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &pytorchTaskExtraArgs)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
}

podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
@@ -80,6 +74,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
RestartPolicy: commonOp.RestartPolicyNever,
}
runPolicy := commonOp.RunPolicy{}
var elasticPolicy *kubeflowv1.ElasticPolicy

if taskTemplate.TaskTypeVersion == 0 {
pytorchTaskExtraArgs := plugins.DistributedPyTorchTrainingTask{}
@@ -90,6 +85,11 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
}

workerReplica.ReplicaNum = pytorchTaskExtraArgs.GetWorkers()
// Set elastic config
elasticConfig := pytorchTaskExtraArgs.GetElasticConfig()
if elasticConfig != nil {
elasticPolicy = ParseElasticConfig(elasticConfig)
}
} else if taskTemplate.TaskTypeVersion == 1 {
kfPytorchTaskExtraArgs := kfplugins.DistributedPyTorchTrainingTask{}

@@ -134,6 +134,11 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
if kfPytorchTaskExtraArgs.GetRunPolicy() != nil {
runPolicy = common.ParseRunPolicy(*kfPytorchTaskExtraArgs.GetRunPolicy())
}
// Set elastic config
elasticConfig := kfPytorchTaskExtraArgs.GetElasticConfig()
if elasticConfig != nil {
elasticPolicy = ParseElasticConfig(elasticConfig)
}
} else {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification,
"Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion)
@@ -164,23 +169,9 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
RunPolicy: runPolicy,
}

// Set elastic config
elasticConfig := pytorchTaskExtraArgs.GetElasticConfig()
if elasticConfig != nil {
minReplicas := elasticConfig.GetMinReplicas()
maxReplicas := elasticConfig.GetMaxReplicas()
nProcPerNode := elasticConfig.GetNprocPerNode()
maxRestarts := elasticConfig.GetMaxRestarts()
rdzvBackend := kubeflowv1.RDZVBackend(elasticConfig.GetRdzvBackend())
elasticPolicy := kubeflowv1.ElasticPolicy{
MinReplicas: &minReplicas,
MaxReplicas: &maxReplicas,
RDZVBackend: &rdzvBackend,
NProcPerNode: &nProcPerNode,
MaxRestarts: &maxRestarts,
}
jobSpec.ElasticPolicy = &elasticPolicy
// Remove master replica if elastic policy is set
if elasticPolicy != nil {
jobSpec.ElasticPolicy = elasticPolicy
// Remove master replica spec if elastic policy is set
delete(jobSpec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeMaster)
}

@@ -195,6 +186,32 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
return job, nil
}

// ElasticConfig is a unified interface for the elastic config carried by both the v0 and v1 plugin task
// specifications. It should always be kept aligned with the ElasticConfig message defined in flyteidl.
type ElasticConfig interface {
GetMinReplicas() int32
GetMaxReplicas() int32
GetNprocPerNode() int32
GetMaxRestarts() int32
GetRdzvBackend() string
}

// ParseElasticConfig converts an elastic config from either the v0 or v1 kubeflow pytorch idl into a kubeflow ElasticPolicy.
func ParseElasticConfig(elasticConfig ElasticConfig) *kubeflowv1.ElasticPolicy {
minReplicas := elasticConfig.GetMinReplicas()
maxReplicas := elasticConfig.GetMaxReplicas()
nProcPerNode := elasticConfig.GetNprocPerNode()
maxRestarts := elasticConfig.GetMaxRestarts()
rdzvBackend := kubeflowv1.RDZVBackend(elasticConfig.GetRdzvBackend())
return &kubeflowv1.ElasticPolicy{
MinReplicas: &minReplicas,
MaxReplicas: &maxReplicas,
RDZVBackend: &rdzvBackend,
NProcPerNode: &nProcPerNode,
MaxRestarts: &maxRestarts,
}
}
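
The following is an illustrative sketch, not part of this commit's diff: any type that exposes the five getters above satisfies ElasticConfig (including the v0 plugins.ElasticConfig and v1 kfplugins.ElasticConfig messages), so a single ParseElasticConfig call covers both IDL versions. The stub type, its field values, and the exampleParseElasticConfig helper are hypothetical and assume they compile inside this package.

// stubElasticConfig is a hypothetical type used only to illustrate the
// ElasticConfig contract; its method set matches the interface exactly.
type stubElasticConfig struct {
	minReplicas, maxReplicas, nprocPerNode, maxRestarts int32
	rdzvBackend                                         string
}

func (s stubElasticConfig) GetMinReplicas() int32  { return s.minReplicas }
func (s stubElasticConfig) GetMaxReplicas() int32  { return s.maxReplicas }
func (s stubElasticConfig) GetNprocPerNode() int32 { return s.nprocPerNode }
func (s stubElasticConfig) GetMaxRestarts() int32  { return s.maxRestarts }
func (s stubElasticConfig) GetRdzvBackend() string { return s.rdzvBackend }

// exampleParseElasticConfig is a hypothetical helper showing the call:
// ParseElasticConfig copies each getter value into the pointer fields of a
// kubeflowv1.ElasticPolicy.
func exampleParseElasticConfig() *kubeflowv1.ElasticPolicy {
	cfg := stubElasticConfig{minReplicas: 1, maxReplicas: 2, nprocPerNode: 4, maxRestarts: 3, rdzvBackend: "c10d"}
	return ParseElasticConfig(cfg)
}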

// Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast;
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
99 changes: 93 additions & 6 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
@@ -519,10 +519,6 @@ func TestBuildResourcePytorchV1(t *testing.T) {
},
},
},
RunPolicy: &kfplugins.RunPolicy{
CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL,
BackoffLimit: 100,
},
}

masterResourceRequirements := &corev1.ResourceRequirements{
@@ -567,14 +563,45 @@ func TestBuildResourcePytorchV1(t *testing.T) {
assert.Equal(t, commonOp.RestartPolicyAlways, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].RestartPolicy)
assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].RestartPolicy)

assert.Equal(t, commonOp.CleanPodPolicyAll, *pytorchJob.Spec.RunPolicy.CleanPodPolicy)
assert.Equal(t, int32(100), *pytorchJob.Spec.RunPolicy.BackoffLimit)
assert.Nil(t, pytorchJob.Spec.RunPolicy.CleanPodPolicy)
assert.Nil(t, pytorchJob.Spec.RunPolicy.BackoffLimit)
assert.Nil(t, pytorchJob.Spec.RunPolicy.TTLSecondsAfterFinished)
assert.Nil(t, pytorchJob.Spec.RunPolicy.ActiveDeadlineSeconds)

assert.Nil(t, pytorchJob.Spec.ElasticPolicy)
}

func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) {
taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
Replicas: 100,
},
RunPolicy: &kfplugins.RunPolicy{
CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL,
BackoffLimit: 100,
ActiveDeadlineSeconds: 1000,
TtlSecondsAfterFinished: 10000,
},
}
pytorchResourceHandler := pytorchOperatorResourceHandler{}

taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
taskTemplate.TaskTypeVersion = 1

res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
assert.NoError(t, err)
assert.NotNil(t, res)

pytorchJob, ok := res.(*kubeflowv1.PyTorchJob)
assert.True(t, ok)
assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
assert.Nil(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas)
assert.Equal(t, commonOp.CleanPodPolicyAll, *pytorchJob.Spec.RunPolicy.CleanPodPolicy)
assert.Equal(t, int32(100), *pytorchJob.Spec.RunPolicy.BackoffLimit)
assert.Equal(t, int64(1000), *pytorchJob.Spec.RunPolicy.ActiveDeadlineSeconds)
assert.Equal(t, int32(10000), *pytorchJob.Spec.RunPolicy.TTLSecondsAfterFinished)
}

func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) {
taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
@@ -638,3 +665,63 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) {

assert.Nil(t, pytorchJob.Spec.ElasticPolicy)
}

func TestBuildResourcePytorchV1WithElastic(t *testing.T) {
taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
Replicas: 2,
},
ElasticConfig: &kfplugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"},
}
taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
taskTemplate.TaskTypeVersion = 1

pytorchResourceHandler := pytorchOperatorResourceHandler{}
resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
assert.NoError(t, err)
assert.NotNil(t, resource)

pytorchJob, ok := resource.(*kubeflowv1.PyTorchJob)
assert.True(t, ok)
assert.Equal(t, int32(2), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
assert.NotNil(t, pytorchJob.Spec.ElasticPolicy)
assert.Equal(t, int32(1), *pytorchJob.Spec.ElasticPolicy.MinReplicas)
assert.Equal(t, int32(2), *pytorchJob.Spec.ElasticPolicy.MaxReplicas)
assert.Equal(t, int32(4), *pytorchJob.Spec.ElasticPolicy.NProcPerNode)
assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *pytorchJob.Spec.ElasticPolicy.RDZVBackend)

assert.Equal(t, 1, len(pytorchJob.Spec.PyTorchReplicaSpecs))
assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)

var hasContainerWithDefaultPytorchName = false

for _, container := range pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers {
if container.Name == kubeflowv1.PytorchJobDefaultContainerName {
hasContainerWithDefaultPytorchName = true
}
}

assert.True(t, hasContainerWithDefaultPytorchName)
}

func TestBuildResourcePytorchV1WithZeroWorker(t *testing.T) {
taskConfig := &kfplugins.DistributedPyTorchTrainingTask{
WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
Replicas: 0,
},
}
pytorchResourceHandler := pytorchOperatorResourceHandler{}
taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
taskTemplate.TaskTypeVersion = 1
_, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
assert.Error(t, err)
}

func TestParseElasticConfig(t *testing.T) {
elasticConfig := plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}
elasticPolicy := ParseElasticConfig(&elasticConfig)
assert.Equal(t, int32(1), *elasticPolicy.MinReplicas)
assert.Equal(t, int32(2), *elasticPolicy.MaxReplicas)
assert.Equal(t, int32(4), *elasticPolicy.NProcPerNode)
assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *elasticPolicy.RDZVBackend)
}
