diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 88419b64c..f9c98cfd6 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -96,7 +96,7 @@ func GetMPIPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Tim } // GetLogs will return the logs for kubeflow job -func GetLogs(taskType string, name string, namespace string, +func GetLogs(taskType string, name string, namespace string, hasMaster bool, workersCount int32, psReplicasCount int32, chiefReplicasCount int32) ([]*core.TaskLog, error) { taskLogs := make([]*core.TaskLog, 0, 10) @@ -110,7 +110,7 @@ func GetLogs(taskType string, name string, namespace string, return nil, nil } - if taskType == PytorchTaskType { + if taskType == PytorchTaskType && hasMaster == true { masterTaskLog, masterErr := logPlugin.GetTaskLogs( tasklog.Input{ PodName: name + "-master-0", diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index 1fb26f128..9a1de47fe 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -164,18 +164,18 @@ func TestGetLogs(t *testing.T) { workers := int32(1) launcher := int32(1) - jobLogs, err := GetLogs(MPITaskType, "test", "mpi-namespace", workers, launcher, 0) + jobLogs, err := GetLogs(MPITaskType, "test", "mpi-namespace", false, workers, launcher, 0) assert.NoError(t, err) assert.Equal(t, 1, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", "mpi-namespace", "test"), jobLogs[0].Uri) - jobLogs, err = GetLogs(PytorchTaskType, "test", "pytorch-namespace", workers, launcher, 0) + jobLogs, err = GetLogs(PytorchTaskType, "test", "pytorch-namespace", true, workers, launcher, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[0].Uri) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", "pytorch-namespace", "test"), jobLogs[1].Uri) - jobLogs, err = GetLogs(TensorflowTaskType, "test", "tensorflow-namespace", workers, launcher, 1) + jobLogs, err = GetLogs(TensorflowTaskType, "test", "tensorflow-namespace", false, workers, launcher, 1) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[0].Uri) diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 7837f5f4a..8f65b135a 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -145,7 +145,7 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext numWorkers = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas numLauncherReplicas = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas - taskLogs, err := common.GetLogs(common.MPITaskType, app.Name, app.Namespace, + taskLogs, err := common.GetLogs(common.MPITaskType, app.Name, app.Namespace, false, *numWorkers, *numLauncherReplicas, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 2e9e9283a..8a7bb7aa9 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -373,7 +373,7 @@ func TestGetLogs(t *testing.T) { mpiResourceHandler := mpiOperatorResourceHandler{} mpiJob := dummyMPIJobResource(mpiResourceHandler, workers, launcher, slots, mpiOp.JobRunning) - jobLogs, err := common.GetLogs(common.MPITaskType, mpiJob.Name, mpiJob.Namespace, workers, launcher, 0) + jobLogs, err := common.GetLogs(common.MPITaskType, mpiJob.Name, mpiJob.Namespace, false, workers, launcher, 0) assert.NoError(t, err) assert.Equal(t, 2, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=mpi-namespace", jobNamespace, jobName), jobLogs[0].Uri) diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 338f6cd56..516505bce 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -143,9 +143,15 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { app := resource.(*kubeflowv1.PyTorchJob) + // Elastic PytorchJobs don't use master replicas + hasMaster := false + if _, ok := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster]; ok { + hasMaster = true + } + workersCount := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas - taskLogs, err := common.GetLogs(common.PytorchTaskType, app.Name, app.Namespace, *workersCount, 0, 0) + taskLogs, err := common.GetLogs(common.PytorchTaskType, app.Name, app.Namespace, hasMaster, *workersCount, 0, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err } diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 74bd3fe92..6dee03e03 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -399,11 +399,12 @@ func TestGetLogs(t *testing.T) { KubernetesURL: "k8s.com", })) + hasMaster := true workers := int32(2) pytorchResourceHandler := pytorchOperatorResourceHandler{} pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) - jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, workers, 0, 0) + jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, hasMaster, workers, 0, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri) @@ -411,6 +412,24 @@ func TestGetLogs(t *testing.T) { assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[2].Uri) } +func TestGetLogsElastic(t *testing.T) { + assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ + IsKubernetesEnabled: true, + KubernetesURL: "k8s.com", + })) + + hasMaster := false + workers := int32(2) + + pytorchResourceHandler := pytorchOperatorResourceHandler{} + pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) + jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, hasMaster, workers, 0, 0) + assert.NoError(t, err) + assert.Equal(t, 2, len(jobLogs)) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[1].Uri) +} + func TestGetProperties(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} expected := k8s.PluginProperties{} diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index b5a5a675f..de3f3d65c 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -122,7 +122,7 @@ func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginC psReplicasCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas chiefCount := app.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas - taskLogs, err := common.GetLogs(common.TensorflowTaskType, app.Name, app.Namespace, + taskLogs, err := common.GetLogs(common.TensorflowTaskType, app.Name, app.Namespace, false, *workersCount, *psReplicasCount, *chiefCount) if err != nil { return pluginsCore.PhaseInfoUndefined, err diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 37f22bf34..22d102a6b 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -357,7 +357,7 @@ func TestGetLogs(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, commonOp.JobRunning) - jobLogs, err := common.GetLogs(common.TensorflowTaskType, tensorFlowJob.Name, tensorFlowJob.Namespace, + jobLogs, err := common.GetLogs(common.TensorflowTaskType, tensorFlowJob.Name, tensorFlowJob.Namespace, false, workers, psReplicas, chiefReplicas) assert.NoError(t, err) assert.Equal(t, 4, len(jobLogs))