From de1d1d345f3bf487701b1f732b2c56f76e81e086 Mon Sep 17 00:00:00 2001 From: "Fabio M. Graetz, Ph.D" Date: Wed, 7 Jun 2023 20:15:55 +0200 Subject: [PATCH] Don't add master replica log link when doing elastic pytorch training (#356) * Don't add master log link when doing elastic pytorch training Signed-off-by: Fabio Graetz * Lint Signed-off-by: Fabio Graetz --------- Signed-off-by: Fabio Graetz --- .../k8s/kfoperators/common/common_operator.go | 4 ++-- .../common/common_operator_test.go | 6 +++--- go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 2 +- .../plugins/k8s/kfoperators/mpi/mpi_test.go | 2 +- .../k8s/kfoperators/pytorch/pytorch.go | 8 ++++++- .../k8s/kfoperators/pytorch/pytorch_test.go | 21 ++++++++++++++++++- .../k8s/kfoperators/tensorflow/tensorflow.go | 2 +- .../kfoperators/tensorflow/tensorflow_test.go | 2 +- 8 files changed, 36 insertions(+), 11 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index d86ae42df..19e8a9724 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -104,7 +104,7 @@ func GetMPIPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Tim } // GetLogs will return the logs for kubeflow job -func GetLogs(taskType string, name string, namespace string, +func GetLogs(taskType string, name string, namespace string, hasMaster bool, workersCount int32, psReplicasCount int32, chiefReplicasCount int32) ([]*core.TaskLog, error) { taskLogs := make([]*core.TaskLog, 0, 10) @@ -118,7 +118,7 @@ func GetLogs(taskType string, name string, namespace string, return nil, nil } - if taskType == PytorchTaskType { + if taskType == PytorchTaskType && hasMaster { masterTaskLog, masterErr := logPlugin.GetTaskLogs( tasklog.Input{ PodName: name + "-master-0", diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index ee2dc5a94..96c4bcd87 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -167,18 +167,18 @@ func TestGetLogs(t *testing.T) { workers := int32(1) launcher := int32(1) - jobLogs, err := GetLogs(MPITaskType, "test", "mpi-namespace", workers, launcher, 0) + jobLogs, err := GetLogs(MPITaskType, "test", "mpi-namespace", false, workers, launcher, 0) assert.NoError(t, err) assert.Equal(t, 1, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", "mpi-namespace", "test"), jobLogs[0].Uri) - jobLogs, err = GetLogs(PytorchTaskType, "test", "pytorch-namespace", workers, launcher, 0) + jobLogs, err = GetLogs(PytorchTaskType, "test", "pytorch-namespace", true, workers, launcher, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[0].Uri) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[1].Uri) - jobLogs, err = GetLogs(TensorflowTaskType, "test", "tensorflow-namespace", workers, launcher, 1) + jobLogs, err = GetLogs(TensorflowTaskType, "test", "tensorflow-namespace", false, workers, launcher, 1) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[0].Uri) diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index d4e35a25d..e9e1f6037 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -204,7 +204,7 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext numWorkers = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas numLauncherReplicas = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas - taskLogs, err := common.GetLogs(common.MPITaskType, app.Name, app.Namespace, + taskLogs, err := common.GetLogs(common.MPITaskType, app.Name, app.Namespace, false, *numWorkers, *numLauncherReplicas, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 778b20a08..9ec11f5da 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -389,7 +389,7 @@ func TestGetLogs(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning) - jobLogs, err := common.GetLogs(common.MPITaskType, mpiJob.Name, mpiJob.Namespace, workers, launcher, 0) + jobLogs, err := common.GetLogs(common.MPITaskType, mpiJob.Name, mpiJob.Namespace, false, workers, launcher, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", jobNamespace, jobName), jobLogs[0].Uri) diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index d5cd747c6..1e9fe6115 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -201,9 +201,15 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { app := resource.(*kubeflowv1.PyTorchJob) + // Elastic PytorchJobs don't use master replicas + hasMaster := false + if _, ok := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster]; ok { + hasMaster = true + } + workersCount := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas - taskLogs, err := common.GetLogs(common.PytorchTaskType, app.Name, app.Namespace, *workersCount, 0, 0) + taskLogs, err := common.GetLogs(common.PytorchTaskType, app.Name, app.Namespace, hasMaster, *workersCount, 0, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err } diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index fea07505a..4a17b7490 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -416,11 +416,12 @@ func TestGetLogs(t *testing.T) { KubernetesURL: "k8s.com", })) + hasMaster := true workers := int32(2) pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, workers, 0, 0) + jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, hasMaster, workers, 0, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri) @@ -428,6 +429,24 @@ func TestGetLogs(t *testing.T) { assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[2].Uri) } +func TestGetLogsElastic(t *testing.T) { + assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ + IsKubernetesEnabled: true, + KubernetesURL: "k8s.com", + })) + + hasMaster := false + workers := int32(2) + + pytorchResourceHandler := pytorchOperatorResourceHandler{} + pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) + jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, hasMaster, workers, 0, 0) + assert.NoError(t, err) + assert.Equal(t, 2, len(jobLogs)) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[1].Uri) +} + func TestGetProperties(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} expected := k8s.PluginProperties{} diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 6ee3ce440..ea2930d6f 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -203,7 +203,7 @@ func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginC psReplicasCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas chiefCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas - taskLogs, err := common.GetLogs(common.TensorflowTaskType, app.Name, app.Namespace, + taskLogs, err := common.GetLogs(common.TensorflowTaskType, app.Name, app.Namespace, false, *workersCount, *psReplicasCount, *chiefCount) if err != nil { return pluginsCore.PhaseInfoUndefined, err diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 8174258e1..4e2bc1388 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -373,7 +373,7 @@ func TestGetLogs(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, commonOp.JobRunning) - jobLogs, err := common.GetLogs(common.TensorflowTaskType, tensorFlowJob.Name, tensorFlowJob.Namespace, + jobLogs, err := common.GetLogs(common.TensorflowTaskType, tensorFlowJob.Name, tensorFlowJob.Namespace, false, workers, psReplicas, chiefReplicas) assert.NoError(t, err) assert.Equal(t, 4, len(jobLogs))