Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Discard field when replica number equals zero to avoid api client error
Browse files Browse the repository at this point in the history
Signed-off-by: byhsu <[email protected]>
  • Loading branch information
ByronHsu committed Jan 29, 2023
1 parent 4634a81 commit b57c227
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 23 deletions.
36 changes: 16 additions & 20 deletions go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,31 +86,27 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
chiefReplicas := tensorflowTaskExtraArgs.GetChiefReplicas()

jobSpec := kubeflowv1.TFJobSpec{
TFReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{
kubeflowv1.TFJobReplicaTypePS: {
Replicas: &psReplicas,
Template: v1.PodTemplateSpec{
ObjectMeta: objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
kubeflowv1.TFJobReplicaTypeChief: {
Replicas: &chiefReplicas,
TFReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{},
}

for _, t := range []struct {
replicaNum *int32
replicaType commonOp.ReplicaType
}{
{&workers, kubeflowv1.TFJobReplicaTypeWorker},
{&psReplicas, kubeflowv1.TFJobReplicaTypePS},
{&chiefReplicas, kubeflowv1.TFJobReplicaTypeChief},
} {
if *t.replicaNum > 0 {
jobSpec.TFReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{
Replicas: t.replicaNum,
Template: v1.PodTemplateSpec{
ObjectMeta: objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
kubeflowv1.TFJobReplicaTypeWorker: {
Replicas: &workers,
Template: v1.PodTemplateSpec{
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
},
}
}
}

job := &kubeflowv1.TFJob{
Expand Down
25 changes: 22 additions & 3 deletions go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func dummyTensorFlowCustomObj(workers int32, psReplicas int32, chiefReplicas int
}
}

func dummySparkTaskTemplate(id string, tensorflowCustomObj *plugins.DistributedTensorflowTrainingTask) *core.TaskTemplate {
func dummyTensorFlowTaskTemplate(id string, tensorflowCustomObj *plugins.DistributedTensorflowTrainingTask) *core.TaskTemplate {

tfObjJSON, err := utils.MarshalToString(tensorflowCustomObj)
if err != nil {
Expand Down Expand Up @@ -251,7 +251,7 @@ func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorReso
}

tfObj := dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas)
taskTemplate := dummySparkTaskTemplate("the job", tfObj)
taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj)
resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate))
if err != nil {
panic(err)
Expand All @@ -277,7 +277,7 @@ func TestBuildResourceTensorFlow(t *testing.T) {
tensorflowResourceHandler := tensorflowOperatorResourceHandler{}

tfObj := dummyTensorFlowCustomObj(100, 50, 1)
taskTemplate := dummySparkTaskTemplate("the job", tfObj)
taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj)

resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate))
assert.NoError(t, err)
Expand Down Expand Up @@ -371,3 +371,22 @@ func TestGetProperties(t *testing.T) {
expected := k8s.PluginProperties{}
assert.Equal(t, expected, tensorflowResourceHandler.GetProperties())
}

func TestZeroReplicas(t *testing.T) {
// if the number of replicas is zero, the field should not be created or the client might complain.

tensorflowResourceHandler := tensorflowOperatorResourceHandler{}

tfObj := dummyTensorFlowCustomObj(10, 0, 0)
taskTemplate := dummyTensorFlowTaskTemplate("the job", tfObj)

resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate))
assert.NoError(t, err)
assert.NotNil(t, resource)

tensorflowJob, ok := resource.(*kubeflowv1.TFJob)
assert.True(t, ok)

assert.NotContains(t, tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS)
assert.NotContains(t, tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief)
}

0 comments on commit b57c227

Please sign in to comment.