From e84585e5ad4522854625b21d84f18c77b520920b Mon Sep 17 00:00:00 2001 From: HUAN-PING SU Date: Sat, 13 Feb 2021 08:20:40 +0800 Subject: [PATCH] tensorflow plugin implementation (#103) * Tensorflow plugin implementaion * Fix checkstyle * Update go/tasks/plugins/k8s/kfoperators/common/common_operator.go Co-authored-by: Haytham Abuelfutuh * Address comments * Fix lint error * Address comment * Fix lint * Update to running Co-authored-by: Haytham Abuelfutuh --- copilot/go.sum | 14 + go.sum | 2 + .../k8s/kfoperators/common/common_operator.go | 138 +++++++ .../common/common_operator_test.go | 67 ++++ .../k8s/kfoperators/pytorch/pytorch.go | 108 +----- .../k8s/kfoperators/pytorch/pytorch_test.go | 7 +- .../k8s/kfoperators/tensorflow/tensorflow.go | 154 ++++++++ .../kfoperators/tensorflow/tensorflow_test.go | 366 ++++++++++++++++++ 8 files changed, 754 insertions(+), 102 deletions(-) create mode 100644 go/tasks/plugins/k8s/kfoperators/common/common_operator.go create mode 100644 go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go create mode 100644 go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go create mode 100644 go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go diff --git a/copilot/go.sum b/copilot/go.sum index efedde154e..d210f7a610 100644 --- a/copilot/go.sum +++ b/copilot/go.sum @@ -68,6 +68,15 @@ github.com/aws/aws-sdk-go v1.28.9/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN github.com/aws/aws-sdk-go v1.29.23 h1:wtiGLOzxAP755OfuVTDIy/NbUIYEDxbIbBEDfNhUpeU= github.com/aws/aws-sdk-go v1.29.23/go.mod h1:1KvfttTE3SPKMpo8g2c6jL3ZKfXtFvKscTgahTma5Xg= github.com/aws/aws-sdk-go-v2 v0.20.0/go.mod h1:2LhT7UgHOXK3UXONKI5OMgIyoQL6zTAw/jwIeX6yqzw= +github.com/aws/aws-sdk-go-v2 v1.0.0/go.mod h1:smfAbmpW+tcRVuNUjo3MOArSZmW72t62rkCzc2i0TWM= +github.com/aws/aws-sdk-go-v2/config v1.0.0/go.mod h1:WysE/OpUgE37tjtmtJd8GXgT8s1euilE5XtUkRNUQ1w= +github.com/aws/aws-sdk-go-v2/credentials v1.0.0/go.mod h1:/SvsiqBf509hG4Bddigr3NB12MIpfHhZapyBurJe8aY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.0.0/go.mod h1:wpMHDCXvOXZxGCRSidyepa8uJHY4vaBGfY2/+oKU/Bc= +github.com/aws/aws-sdk-go-v2/service/athena v1.0.0/go.mod h1:qY8QFbemf2ceqweXcS6hQqiiIe1z42WqTvHsK2Lb0rE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0/go.mod h1:3jExOmpbjgPnz2FJaMOfbSk1heTkZ66aD3yNtVhnjvI= +github.com/aws/aws-sdk-go-v2/service/sagemaker v1.0.0/go.mod h1:8/T2od4WQj1qKPr2ppDgjCnMFR6hfYJM4hzjH1D+HWg= +github.com/aws/aws-sdk-go-v2/service/sts v1.0.0/go.mod h1:5f+cELGATgill5Pu3/vK3Ebuigstc+qYEHW5MvGWZO4= +github.com/aws/smithy-go v1.0.0/go.mod h1:EzMw8dbp/YJL4A5/sbhGddag+NPT7q084agLbB9LgIw= github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1/go.mod h1:jvdWlw8vowVGnZqSDC7yhPd7AifQeQbRDkZcQXV2nRg= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= @@ -233,6 +242,8 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v0.0.0-20170612174753-24818f796faf/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -287,6 +298,8 @@ github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5i github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -320,6 +333,7 @@ github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0 h1:NGL46+1RYcCXb3sShp0nQq github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0/go.mod h1:/L5qH+AD540e7Cetbui1tuJeXdmNhO8jM6VkXeDdDhQ= github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f h1:PGuAMDzAen0AulUfaEhNQMYmUpa41pAVo3zHI+GJsCM= github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f/go.mod h1:llRdnznGEAqC3DcNm6yEj472xaFVfLM7hnYofMb12tQ= +github.com/lyft/flyteidl v0.18.9/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= github.com/lyft/flyteidl v0.18.11 h1:24NaFYWxANhRbwKfvkgu8axGTWUcl1tgZBqNJutKNJ8= github.com/lyft/flyteidl v0.18.11/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= diff --git a/go.sum b/go.sum index f139c7350f..365aa3889f 100644 --- a/go.sum +++ b/go.sum @@ -154,6 +154,7 @@ github.com/aws/aws-sdk-go-v2 v0.20.0/go.mod h1:2LhT7UgHOXK3UXONKI5OMgIyoQL6zTAw/ github.com/aws/aws-sdk-go-v2 v0.24.0/go.mod h1:2LhT7UgHOXK3UXONKI5OMgIyoQL6zTAw/jwIeX6yqzw= github.com/aws/aws-sdk-go-v2 v1.0.0 h1:ncEVPoHArsG+HjoDe/3ex/TG1CbLwMQ4eaWj0UGdyTo= github.com/aws/aws-sdk-go-v2 v1.0.0/go.mod h1:smfAbmpW+tcRVuNUjo3MOArSZmW72t62rkCzc2i0TWM= +github.com/aws/aws-sdk-go-v2 v1.1.0 h1:sKP6QWxdN1oRYjl+k6S3bpgBI+XUx/0mqVOLIw4lR/Q= github.com/aws/aws-sdk-go-v2/config v1.0.0 h1:x6vSFAwqAvhYPeSu60f0ZUlGHo3PKKmwDOTL8aMXtv4= github.com/aws/aws-sdk-go-v2/config v1.0.0/go.mod h1:WysE/OpUgE37tjtmtJd8GXgT8s1euilE5XtUkRNUQ1w= github.com/aws/aws-sdk-go-v2/credentials v1.0.0 h1:0M7netgZ8gCV4v7z1km+Fbl7j6KQYyZL7SS0/l5Jn/4= @@ -168,6 +169,7 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0 h1:IAutMPSryn github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0/go.mod h1:3jExOmpbjgPnz2FJaMOfbSk1heTkZ66aD3yNtVhnjvI= github.com/aws/aws-sdk-go-v2/service/sagemaker v1.0.0 h1:WAKXnA5HISN6P8sbXsJ9486ThbRPnoBAtMyDSG7+jNM= github.com/aws/aws-sdk-go-v2/service/sagemaker v1.0.0/go.mod h1:8/T2od4WQj1qKPr2ppDgjCnMFR6hfYJM4hzjH1D+HWg= +github.com/aws/aws-sdk-go-v2/service/sagemaker v1.1.0 h1:qsaGAmYqUzym7g4uaBzx5uOYoEJW0wIHhgObLqZc1mo= github.com/aws/aws-sdk-go-v2/service/sts v1.0.0 h1:6XCgxNfE4L/Fnq+InhVNd16DKc6Ue1f3dJl3IwwJRUQ= github.com/aws/aws-sdk-go-v2/service/sts v1.0.0/go.mod h1:5f+cELGATgill5Pu3/vK3Ebuigstc+qYEHW5MvGWZO4= github.com/aws/smithy-go v1.0.0 h1:hkhcRKG9rJ4Fn+RbfXY7Tz7b3ITLDyolBnLLBhwbg/c= diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go new file mode 100644 index 0000000000..b33967e152 --- /dev/null +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -0,0 +1,138 @@ +package common + +import ( + "fmt" + "sort" + "time" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/tasklog" + + commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + flyteerr "github.com/lyft/flyteplugins/go/tasks/errors" + "github.com/lyft/flyteplugins/go/tasks/logs" + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + v1 "k8s.io/api/core/v1" +) + +const ( + TensorflowTaskType = "tensorflow" + PytorchTaskType = "pytorch" +) + +func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) { + if jobConditions != nil { + sort.Slice(jobConditions, func(i, j int) bool { + return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time) + }) + + for _, jc := range jobConditions { + if jc.Status == v1.ConditionTrue { + return jc, nil + } + } + } + + return commonOp.JobCondition{}, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions) +} + +func GetPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Time, + taskPhaseInfo pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, error) { + switch currentCondition.Type { + case commonOp.JobCreated: + return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil + case commonOp.JobRunning: + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil + case commonOp.JobSucceeded: + return pluginsCore.PhaseInfoSuccess(&taskPhaseInfo), nil + case commonOp.JobFailed: + details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message) + return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil + case commonOp.JobRestarting: + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil + } + + return pluginsCore.PhaseInfoUndefined, nil +} + +func GetLogs(taskType string, name string, namespace string, + workersCount int32, psReplicasCount int32, chiefReplicasCount int32) ([]*core.TaskLog, error) { + taskLogs := make([]*core.TaskLog, 0, 10) + + logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig()) + + if err != nil { + return nil, err + } + + if logPlugin == nil { + return nil, nil + } + + if taskType == PytorchTaskType { + masterTaskLog, masterErr := logPlugin.GetTaskLogs( + tasklog.Input{ + PodName: name + "-master-0", + Namespace: namespace, + LogName: "master", + }, + ) + if masterErr != nil { + return nil, masterErr + } + taskLogs = append(taskLogs, masterTaskLog.TaskLogs...) + } + + // get all workers log + for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ { + workerLog, err := logPlugin.GetTaskLogs(tasklog.Input{ + PodName: name + fmt.Sprintf("-worker-%d", workerIndex), + Namespace: namespace, + }) + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, workerLog.TaskLogs...) + } + // get all parameter servers logs + for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ { + psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{ + PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex), + Namespace: namespace, + }) + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, psReplicaLog.TaskLogs...) + } + // get chief worker log, and the max number of chief worker is 1 + if chiefReplicasCount != 0 { + chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{ + PodName: name + fmt.Sprintf("-chiefReplica-%d", 0), + Namespace: namespace, + }) + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...) + } + + return taskLogs, nil +} + +func OverrideDefaultContainerName(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, + defaultContainerName string) { + // Pytorch operator forces pod to have container named 'pytorch' + // https://github.com/kubeflow/pytorch-operator/blob/037cd1b18eb77f657f2a4bc8a8334f2a06324b57/pkg/apis/pytorch/validation/validation.go#L54-L62 + // Tensorflow operator forces pod to have container named 'tensorflow' + // https://github.com/kubeflow/tf-operator/blob/984adc287e6fe82841e4ca282dc9a2cbb71e2d4a/pkg/apis/tensorflow/validation/validation.go#L55-L63 + // hence we have to override the name set here + // https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/pluginmachinery/flytek8s/container_helper.go#L116 + flyteDefaultContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + for idx, c := range podSpec.Containers { + if c.Name == flyteDefaultContainerName { + podSpec.Containers[idx].Name = defaultContainerName + return + } + } +} diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go new file mode 100644 index 0000000000..579d58ba41 --- /dev/null +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -0,0 +1,67 @@ +package common + +import ( + "testing" + "time" + + commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" +) + +func TestExtractCurrentCondition(t *testing.T) { + jobCreated := commonOp.JobCondition{ + Type: commonOp.JobCreated, + Status: corev1.ConditionTrue, + } + jobRunningActive := commonOp.JobCondition{ + Type: commonOp.JobRunning, + Status: corev1.ConditionFalse, + } + jobConditions := []commonOp.JobCondition{ + jobCreated, + jobRunningActive, + } + currentCondition, err := ExtractCurrentCondition(jobConditions) + assert.NoError(t, err) + assert.Equal(t, currentCondition, jobCreated) +} + +func TestGetPhaseInfo(t *testing.T) { + jobCreated := commonOp.JobCondition{ + Type: commonOp.JobCreated, + } + taskPhase, err := GetPhaseInfo(jobCreated, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + jobSucceeded := commonOp.JobCondition{ + Type: commonOp.JobSucceeded, + } + taskPhase, err = GetPhaseInfo(jobSucceeded, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseSuccess, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + jobFailed := commonOp.JobCondition{ + Type: commonOp.JobFailed, + } + taskPhase, err = GetPhaseInfo(jobFailed, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + jobRestarting := commonOp.JobCondition{ + Type: commonOp.JobRestarting, + } + taskPhase, err = GetPhaseInfo(jobRestarting, time.Now(), pluginsCore.TaskInfo{}) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) +} diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 898f87da6a..1ea5fbff6d 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -2,12 +2,9 @@ package pytorch import ( "context" - "fmt" - - "sort" "time" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/tasklog" + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" flyteerr "github.com/lyft/flyteplugins/go/tasks/errors" @@ -20,19 +17,12 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" - "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flyteplugins/go/tasks/logs" - //commonOp "github.com/kubeflow/common/pkg/apis/common/v1" // switch to real 'common' once https://github.com/kubeflow/pytorch-operator/issues/263 resolved ptOp "github.com/kubeflow/pytorch-operator/pkg/apis/pytorch/v1" commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) -const ( - pytorchTaskType = "pytorch" -) - type pytorchOperatorResourceHandler struct { } @@ -71,7 +61,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } - overrideDefaultContainerName(taskCtx, podSpec) + common.OverrideDefaultContainerName(taskCtx, podSpec, ptOp.DefaultContainerName) workers := pytorchTaskExtraArgs.GetWorkers() @@ -113,12 +103,12 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont workersCount := app.Spec.PyTorchReplicaSpecs[ptOp.PyTorchReplicaTypeWorker].Replicas - taskLogs, err := getLogs(app, *workersCount) + taskLogs, err := common.GetLogs(common.PytorchTaskType, app.Name, app.Namespace, *workersCount, 0, 0) if err != nil { return pluginsCore.PhaseInfoUndefined, err } - currentCondition, err := extractCurrentCondition(app.Status.Conditions) + currentCondition, err := common.ExtractCurrentCondition(app.Status.Conditions) if err != nil { return pluginsCore.PhaseInfoUndefined, err } @@ -131,89 +121,7 @@ func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginCont CustomInfo: statusDetails, } - switch currentCondition.Type { - case commonOp.JobCreated: - return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil - case commonOp.JobRunning: - return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil - case commonOp.JobSucceeded: - return pluginsCore.PhaseInfoSuccess(&taskPhaseInfo), nil - case commonOp.JobFailed: - details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message) - return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil - case commonOp.JobRestarting: - details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message) - return pluginsCore.PhaseInfoRetryableFailure(flyteerr.RuntimeFailure, details, &taskPhaseInfo), nil - } - - return pluginsCore.PhaseInfoUndefined, nil -} - -func getLogs(app *ptOp.PyTorchJob, workersCount int32) ([]*core.TaskLog, error) { - p, err := logs.InitializeLogPlugins(logs.GetLogConfig()) - if err != nil { - return nil, err - } - - if p == nil { - return nil, nil - } - - o, err := p.GetTaskLogs(tasklog.Input{ - PodName: app.Name + "-master-0", - Namespace: app.Namespace, - LogName: "master", - }) - - if err != nil { - return nil, err - } - - taskLogs := make([]*core.TaskLog, 0, 10) - taskLogs = append(taskLogs, o.TaskLogs...) - - for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ { - workerLog, err := p.GetTaskLogs(tasklog.Input{ - PodName: app.Name + fmt.Sprintf("-worker-%d", workerIndex), - Namespace: app.Namespace, - }) - - if err != nil { - return nil, err - } - - taskLogs = append(taskLogs, workerLog.TaskLogs...) - } - - return taskLogs, nil -} - -func extractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) { - sort.Slice(jobConditions[:], func(i, j int) bool { - return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time) - }) - - for _, jc := range jobConditions { - if jc.Status == v1.ConditionTrue { - return jc, nil - } - } - - return commonOp.JobCondition{}, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions) -} - -func overrideDefaultContainerName(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec) { - // Pytorch operator forces pod to have container named 'pytorch' - // https://github.com/kubeflow/pytorch-operator/blob/037cd1b18eb77f657f2a4bc8a8334f2a06324b57/pkg/apis/pytorch/validation/validation.go#L54-L62 - // hence we have to override the name set here - // https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/pluginmachinery/flytek8s/container_helper.go#L116 - flyteDefaultContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() - for idx, c := range podSpec.Containers { - if c.Name == flyteDefaultContainerName { - podSpec.Containers[idx].Name = ptOp.DefaultContainerName - return - } - } + return common.GetPhaseInfo(currentCondition, occurredAt, taskPhaseInfo) } func init() { @@ -223,11 +131,11 @@ func init() { pluginmachinery.PluginRegistry().RegisterK8sPlugin( k8s.PluginEntry{ - ID: pytorchTaskType, - RegisteredTaskTypes: []pluginsCore.TaskType{pytorchTaskType}, + ID: common.PytorchTaskType, + RegisteredTaskTypes: []pluginsCore.TaskType{common.PytorchTaskType}, ResourceToWatch: &ptOp.PyTorchJob{}, Plugin: pytorchOperatorResourceHandler{}, IsDefault: false, - DefaultForTaskTypes: []pluginsCore.TaskType{pytorchTaskType}, + DefaultForTaskTypes: []pluginsCore.TaskType{common.PytorchTaskType}, }) } diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 67d4fb4aa8..8c0f0587c8 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" + commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" "github.com/lyft/flyteplugins/go/tasks/logs" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s" @@ -331,7 +333,7 @@ func TestGetTaskPhase(t *testing.T) { taskPhase, err = pytorchResourceHandler.GetTaskPhase(ctx, nil, dummyPytorchJobResourceCreator(commonOp.JobRestarting)) assert.NoError(t, err) - assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskPhase.Phase()) + assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase()) assert.NotNil(t, taskPhase.Info()) assert.Nil(t, err) } @@ -345,7 +347,8 @@ func TestGetLogs(t *testing.T) { workers := int32(2) pytorchResourceHandler := pytorchOperatorResourceHandler{} - jobLogs, err := getLogs(dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning), workers) + pytorchJob := dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning) + jobLogs, err := common.GetLogs(common.PytorchTaskType, pytorchJob.Name, pytorchJob.Namespace, workers, 0, 0) assert.NoError(t, err) assert.Equal(t, 3, len(jobLogs)) assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri) diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go new file mode 100644 index 0000000000..304aace398 --- /dev/null +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -0,0 +1,154 @@ +package tensorflow + +import ( + "context" + "time" + + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + flyteerr "github.com/lyft/flyteplugins/go/tasks/errors" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s" + v1 "k8s.io/api/core/v1" + "k8s.io/client-go/kubernetes/scheme" + + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + + //commonOp "github.com/kubeflow/common/pkg/apis/common/v1" // switch to real 'common' once https://github.com/kubeflow/pytorch-operator/issues/263 resolved + commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" + tfOp "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type tensorflowOperatorResourceHandler struct { +} + +// Sanity test that the plugin implements method of k8s.Plugin +var _ k8s.Plugin = tensorflowOperatorResourceHandler{} + +// Defines a func to create a query object (typically just object and type meta portions) that's used to query k8s +// resources. +func (tensorflowOperatorResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionMetadata) (k8s.Resource, error) { + return &tfOp.TFJob{ + TypeMeta: metav1.TypeMeta{ + Kind: tfOp.Kind, + APIVersion: tfOp.SchemeGroupVersion.String(), + }, + }, nil +} + +// Defines a func to create the full resource object that will be posted to k8s. +func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) { + taskTemplate, err := taskCtx.TaskReader().Read(ctx) + + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error()) + } else if taskTemplate == nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") + } + + tensorflowTaskExtraArgs := plugins.DistributedTensorflowTrainingTask{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &tensorflowTaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx.TaskExecutionMetadata(), taskCtx.TaskReader(), taskCtx.InputReader(), taskCtx.OutputWriter()) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) + } + + common.OverrideDefaultContainerName(taskCtx, podSpec, tfOp.DefaultContainerName) + + workers := tensorflowTaskExtraArgs.GetWorkers() + psReplicas := tensorflowTaskExtraArgs.GetPsReplicas() + chiefReplicas := tensorflowTaskExtraArgs.GetChiefReplicas() + + jobSpec := tfOp.TFJobSpec{ + TTLSecondsAfterFinished: nil, + TFReplicaSpecs: map[tfOp.TFReplicaType]*commonOp.ReplicaSpec{ + tfOp.TFReplicaTypePS: { + Replicas: &psReplicas, + Template: v1.PodTemplateSpec{ + Spec: *podSpec, + }, + RestartPolicy: commonOp.RestartPolicyNever, + }, + tfOp.TFReplicaTypeChief: { + Replicas: &chiefReplicas, + Template: v1.PodTemplateSpec{ + Spec: *podSpec, + }, + RestartPolicy: commonOp.RestartPolicyNever, + }, + tfOp.TFReplicaTypeWorker: { + Replicas: &workers, + Template: v1.PodTemplateSpec{ + Spec: *podSpec, + }, + RestartPolicy: commonOp.RestartPolicyNever, + }, + }, + } + + job := &tfOp.TFJob{ + TypeMeta: metav1.TypeMeta{ + Kind: tfOp.Kind, + APIVersion: tfOp.SchemeGroupVersion.String(), + }, + Spec: jobSpec, + } + + return job, nil +} + +// Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast, +// any operations that might take a long time (limits are configured system-wide) should be offloaded to the +// background. +func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource k8s.Resource) (pluginsCore.PhaseInfo, error) { + app := resource.(*tfOp.TFJob) + + workersCount := app.Spec.TFReplicaSpecs[tfOp.TFReplicaTypeWorker].Replicas + psReplicasCount := app.Spec.TFReplicaSpecs[tfOp.TFReplicaTypePS].Replicas + chiefCount := app.Spec.TFReplicaSpecs[tfOp.TFReplicaTypeChief].Replicas + + taskLogs, err := common.GetLogs(common.TensorflowTaskType, app.Name, app.Namespace, + *workersCount, *psReplicasCount, *chiefCount) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + currentCondition, err := common.ExtractCurrentCondition(app.Status.Conditions) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + occurredAt := time.Now() + statusDetails, _ := utils.MarshalObjToStruct(app.Status) + taskPhaseInfo := pluginsCore.TaskInfo{ + Logs: taskLogs, + OccurredAt: &occurredAt, + CustomInfo: statusDetails, + } + + return common.GetPhaseInfo(currentCondition, occurredAt, taskPhaseInfo) +} + +func init() { + if err := tfOp.AddToScheme(scheme.Scheme); err != nil { + panic(err) + } + + pluginmachinery.PluginRegistry().RegisterK8sPlugin( + k8s.PluginEntry{ + ID: common.TensorflowTaskType, + RegisteredTaskTypes: []pluginsCore.TaskType{common.TensorflowTaskType}, + ResourceToWatch: &tfOp.TFJob{}, + Plugin: tensorflowOperatorResourceHandler{}, + IsDefault: false, + DefaultForTaskTypes: []pluginsCore.TaskType{common.TensorflowTaskType}, + }) +} diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go new file mode 100644 index 0000000000..8d9d7d6f65 --- /dev/null +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -0,0 +1,366 @@ +package tensorflow + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/lyft/flyteplugins/go/tasks/plugins/k8s/kfoperators/common" + + commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" + "github.com/lyft/flyteplugins/go/tasks/logs" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + + "github.com/stretchr/testify/mock" + + "github.com/lyft/flytestdlib/storage" + + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + + pluginIOMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks" + + "github.com/golang/protobuf/jsonpb" + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/stretchr/testify/assert" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + tfOp "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1" +) + +const testImage = "image://" +const serviceAccount = "tensorflow_sa" + +var ( + dummyEnvVars = []*core.KeyValuePair{ + {Key: "Env_Var", Value: "Env_Val"}, + } + + testArgs = []string{ + "test-args", + } + + resourceRequirements = &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1000m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("100m"), + corev1.ResourceMemory: resource.MustParse("512Mi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + } + + jobName = "the-job" + jobNamespace = "tensorflow-namespace" +) + +func dummyTensorFlowCustomObj(workers int32, psReplicas int32, chiefReplicas int32) *plugins.DistributedTensorflowTrainingTask { + return &plugins.DistributedTensorflowTrainingTask{ + Workers: workers, + PsReplicas: psReplicas, + ChiefReplicas: chiefReplicas, + } +} + +func dummySparkTaskTemplate(id string, tensorflowCustomObj *plugins.DistributedTensorflowTrainingTask) *core.TaskTemplate { + + tfObjJSON, err := utils.MarshalToString(tensorflowCustomObj) + if err != nil { + panic(err) + } + + structObj := structpb.Struct{} + + err = jsonpb.UnmarshalString(tfObjJSON, &structObj) + if err != nil { + panic(err) + } + + return &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: "container", + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: testImage, + Args: testArgs, + Env: dummyEnvVars, + }, + }, + Custom: &structObj, + } +} + +func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecutionContext { + taskCtx := &mocks.TaskExecutionContext{} + inputReader := &pluginIOMocks.InputReader{} + inputReader.OnGetInputPrefixPath().Return(storage.DataReference("/input/prefix")) + inputReader.OnGetInputPath().Return(storage.DataReference("/input")) + inputReader.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) + taskCtx.OnInputReader().Return(inputReader) + + outputReader := &pluginIOMocks.OutputWriter{} + outputReader.OnGetOutputPath().Return(storage.DataReference("/data/outputs.pb")) + outputReader.OnGetOutputPrefixPath().Return(storage.DataReference("/data/")) + outputReader.OnGetRawOutputPrefix().Return(storage.DataReference("")) + taskCtx.OnOutputWriter().Return(outputReader) + + taskReader := &mocks.TaskReader{} + taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) + taskCtx.OnTaskReader().Return(taskReader) + + tID := &mocks.TaskExecutionID{} + tID.OnGetID().Return(core.TaskExecutionIdentifier{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my_name", + Project: "my_project", + Domain: "my_domain", + }, + }, + }) + tID.OnGetGeneratedName().Return("some-acceptable-name") + + resources := &mocks.TaskOverrides{} + resources.OnGetResources().Return(resourceRequirements) + + taskExecutionMetadata := &mocks.TaskExecutionMetadata{} + taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) + taskExecutionMetadata.OnGetNamespace().Return("test-namespace") + taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"}) + taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) + taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ + Kind: "node", + Name: "blah", + }) + taskExecutionMetadata.OnIsInterruptible().Return(true) + taskExecutionMetadata.OnGetOverrides().Return(resources) + taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) + taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + return taskCtx +} + +func dummyTensorFlowJobResource(tensorflowResourceHandler tensorflowOperatorResourceHandler, + workers int32, psReplicas int32, chiefReplicas int32, conditionType commonOp.JobConditionType) *tfOp.TFJob { + var jobConditions []commonOp.JobCondition + + now := time.Now() + + jobCreated := commonOp.JobCondition{ + Type: commonOp.JobCreated, + Status: corev1.ConditionTrue, + Reason: "TensorFlowJobCreated", + Message: "TensorFlowJob the-job is created.", + LastUpdateTime: v1.Time{ + Time: now, + }, + LastTransitionTime: v1.Time{ + Time: now, + }, + } + jobRunningActive := commonOp.JobCondition{ + Type: commonOp.JobRunning, + Status: corev1.ConditionTrue, + Reason: "TensorFlowJobRunning", + Message: "TensorFlowJob the-job is running.", + LastUpdateTime: v1.Time{ + Time: now.Add(time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(time.Minute), + }, + } + jobRunningInactive := *jobRunningActive.DeepCopy() + jobRunningInactive.Status = corev1.ConditionFalse + jobSucceeded := commonOp.JobCondition{ + Type: commonOp.JobSucceeded, + Status: corev1.ConditionTrue, + Reason: "TensorFlowJobSucceeded", + Message: "TensorFlowJob the-job is successfully completed.", + LastUpdateTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + } + jobFailed := commonOp.JobCondition{ + Type: commonOp.JobFailed, + Status: corev1.ConditionTrue, + Reason: "TensorFlowJobFailed", + Message: "TensorFlowJob the-job is failed.", + LastUpdateTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + } + jobRestarting := commonOp.JobCondition{ + Type: commonOp.JobRestarting, + Status: corev1.ConditionTrue, + Reason: "TensorFlowJobRestarting", + Message: "TensorFlowJob the-job is restarting because some replica(s) failed.", + LastUpdateTime: v1.Time{ + Time: now.Add(3 * time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(3 * time.Minute), + }, + } + + switch conditionType { + case commonOp.JobCreated: + jobConditions = []commonOp.JobCondition{ + jobCreated, + } + case commonOp.JobRunning: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningActive, + } + case commonOp.JobSucceeded: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningInactive, + jobSucceeded, + } + case commonOp.JobFailed: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningInactive, + jobFailed, + } + case commonOp.JobRestarting: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningInactive, + jobFailed, + jobRestarting, + } + } + + tfObj := dummyTensorFlowCustomObj(workers, psReplicas, chiefReplicas) + taskTemplate := dummySparkTaskTemplate("the job", tfObj) + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) + if err != nil { + panic(err) + } + + return &tfOp.TFJob{ + ObjectMeta: v1.ObjectMeta{ + Name: jobName, + Namespace: jobNamespace, + }, + Spec: resource.(*tfOp.TFJob).Spec, + Status: commonOp.JobStatus{ + Conditions: jobConditions, + ReplicaStatuses: nil, + StartTime: nil, + CompletionTime: nil, + LastReconcileTime: nil, + }, + } +} + +func TestBuildResourceTensorFlow(t *testing.T) { + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + + tfObj := dummyTensorFlowCustomObj(100, 50, 1) + taskTemplate := dummySparkTaskTemplate("the job", tfObj) + + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + tensorflowJob, ok := resource.(*tfOp.TFJob) + assert.True(t, ok) + assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[tfOp.TFReplicaTypeWorker].Replicas) + assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[tfOp.TFReplicaTypePS].Replicas) + assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[tfOp.TFReplicaTypeChief].Replicas) + + for _, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { + var hasContainerWithDefaultTensorFlowName = false + + for _, container := range replicaSpec.Template.Spec.Containers { + if container.Name == tfOp.DefaultContainerName { + hasContainerWithDefaultTensorFlowName = true + } + + assert.Equal(t, resourceRequirements.Requests, container.Resources.Requests) + assert.Equal(t, resourceRequirements.Limits, container.Resources.Limits) + } + + assert.True(t, hasContainerWithDefaultTensorFlowName) + } +} + +func TestGetTaskPhase(t *testing.T) { + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + ctx := context.TODO() + + dummyTensorFlowJobResourceCreator := func(conditionType commonOp.JobConditionType) *tfOp.TFJob { + return dummyTensorFlowJobResource(tensorflowResourceHandler, 2, 1, 1, conditionType) + } + + taskPhase, err := tensorflowResourceHandler.GetTaskPhase(ctx, nil, dummyTensorFlowJobResourceCreator(commonOp.JobCreated)) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + taskPhase, err = tensorflowResourceHandler.GetTaskPhase(ctx, nil, dummyTensorFlowJobResourceCreator(commonOp.JobRunning)) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + taskPhase, err = tensorflowResourceHandler.GetTaskPhase(ctx, nil, dummyTensorFlowJobResourceCreator(commonOp.JobSucceeded)) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseSuccess, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + taskPhase, err = tensorflowResourceHandler.GetTaskPhase(ctx, nil, dummyTensorFlowJobResourceCreator(commonOp.JobFailed)) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + taskPhase, err = tensorflowResourceHandler.GetTaskPhase(ctx, nil, dummyTensorFlowJobResourceCreator(commonOp.JobRestarting)) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) +} + +func TestGetLogs(t *testing.T) { + assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ + IsKubernetesEnabled: true, + KubernetesURL: "k8s.com", + })) + + workers := int32(2) + psReplicas := int32(1) + chiefReplicas := int32(1) + + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + tensorFlowJob := dummyTensorFlowJobResource(tensorflowResourceHandler, workers, psReplicas, chiefReplicas, commonOp.JobRunning) + jobLogs, err := common.GetLogs(common.TensorflowTaskType, tensorFlowJob.Name, tensorFlowJob.Namespace, + workers, psReplicas, chiefReplicas) + assert.NoError(t, err) + assert.Equal(t, 4, len(jobLogs)) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[0].Uri) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[1].Uri) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-psReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[2].Uri) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-chiefReplica-0/pod?namespace=tensorflow-namespace", jobNamespace, jobName), jobLogs[3].Uri) +}