Skip to content

Commit

Permalink
tensorflow plugin implementation (flyteorg#103)
Browse files Browse the repository at this point in the history
* Tensorflow plugin implementaion

* Fix checkstyle

* Update go/tasks/plugins/k8s/kfoperators/common/common_operator.go

Co-authored-by: Haytham Abuelfutuh <[email protected]>

* Address comments

* Fix lint error

* Address comment

* Fix lint

* Update to running

Co-authored-by: Haytham Abuelfutuh <[email protected]>
  • Loading branch information
pingsutw and EngHabu authored Feb 13, 2021
1 parent 275483f commit e84585e
Show file tree
Hide file tree
Showing 8 changed files with 754 additions and 102 deletions.
14 changes: 14 additions & 0 deletions copilot/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ github.com/aws/aws-sdk-go v1.28.9/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN
github.com/aws/aws-sdk-go v1.29.23 h1:wtiGLOzxAP755OfuVTDIy/NbUIYEDxbIbBEDfNhUpeU=
github.com/aws/aws-sdk-go v1.29.23/go.mod h1:1KvfttTE3SPKMpo8g2c6jL3ZKfXtFvKscTgahTma5Xg=
github.com/aws/aws-sdk-go-v2 v0.20.0/go.mod h1:2LhT7UgHOXK3UXONKI5OMgIyoQL6zTAw/jwIeX6yqzw=
github.com/aws/aws-sdk-go-v2 v1.0.0/go.mod h1:smfAbmpW+tcRVuNUjo3MOArSZmW72t62rkCzc2i0TWM=
github.com/aws/aws-sdk-go-v2/config v1.0.0/go.mod h1:WysE/OpUgE37tjtmtJd8GXgT8s1euilE5XtUkRNUQ1w=
github.com/aws/aws-sdk-go-v2/credentials v1.0.0/go.mod h1:/SvsiqBf509hG4Bddigr3NB12MIpfHhZapyBurJe8aY=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.0.0/go.mod h1:wpMHDCXvOXZxGCRSidyepa8uJHY4vaBGfY2/+oKU/Bc=
github.com/aws/aws-sdk-go-v2/service/athena v1.0.0/go.mod h1:qY8QFbemf2ceqweXcS6hQqiiIe1z42WqTvHsK2Lb0rE=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0/go.mod h1:3jExOmpbjgPnz2FJaMOfbSk1heTkZ66aD3yNtVhnjvI=
github.com/aws/aws-sdk-go-v2/service/sagemaker v1.0.0/go.mod h1:8/T2od4WQj1qKPr2ppDgjCnMFR6hfYJM4hzjH1D+HWg=
github.com/aws/aws-sdk-go-v2/service/sts v1.0.0/go.mod h1:5f+cELGATgill5Pu3/vK3Ebuigstc+qYEHW5MvGWZO4=
github.com/aws/smithy-go v1.0.0/go.mod h1:EzMw8dbp/YJL4A5/sbhGddag+NPT7q084agLbB9LgIw=
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1/go.mod h1:jvdWlw8vowVGnZqSDC7yhPd7AifQeQbRDkZcQXV2nRg=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
Expand Down Expand Up @@ -233,6 +242,8 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v0.0.0-20170612174753-24818f796faf/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
Expand Down Expand Up @@ -287,6 +298,8 @@ github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5i
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc=
github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.5/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
Expand Down Expand Up @@ -320,6 +333,7 @@ github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0 h1:NGL46+1RYcCXb3sShp0nQq
github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0/go.mod h1:/L5qH+AD540e7Cetbui1tuJeXdmNhO8jM6VkXeDdDhQ=
github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f h1:PGuAMDzAen0AulUfaEhNQMYmUpa41pAVo3zHI+GJsCM=
github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f/go.mod h1:llRdnznGEAqC3DcNm6yEj472xaFVfLM7hnYofMb12tQ=
github.com/lyft/flyteidl v0.18.9/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.18.11 h1:24NaFYWxANhRbwKfvkgu8axGTWUcl1tgZBqNJutKNJ8=
github.com/lyft/flyteidl v0.18.11/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU=
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ github.com/aws/aws-sdk-go-v2 v0.20.0/go.mod h1:2LhT7UgHOXK3UXONKI5OMgIyoQL6zTAw/
github.com/aws/aws-sdk-go-v2 v0.24.0/go.mod h1:2LhT7UgHOXK3UXONKI5OMgIyoQL6zTAw/jwIeX6yqzw=
github.com/aws/aws-sdk-go-v2 v1.0.0 h1:ncEVPoHArsG+HjoDe/3ex/TG1CbLwMQ4eaWj0UGdyTo=
github.com/aws/aws-sdk-go-v2 v1.0.0/go.mod h1:smfAbmpW+tcRVuNUjo3MOArSZmW72t62rkCzc2i0TWM=
github.com/aws/aws-sdk-go-v2 v1.1.0 h1:sKP6QWxdN1oRYjl+k6S3bpgBI+XUx/0mqVOLIw4lR/Q=
github.com/aws/aws-sdk-go-v2/config v1.0.0 h1:x6vSFAwqAvhYPeSu60f0ZUlGHo3PKKmwDOTL8aMXtv4=
github.com/aws/aws-sdk-go-v2/config v1.0.0/go.mod h1:WysE/OpUgE37tjtmtJd8GXgT8s1euilE5XtUkRNUQ1w=
github.com/aws/aws-sdk-go-v2/credentials v1.0.0 h1:0M7netgZ8gCV4v7z1km+Fbl7j6KQYyZL7SS0/l5Jn/4=
Expand All @@ -168,6 +169,7 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0 h1:IAutMPSryn
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0/go.mod h1:3jExOmpbjgPnz2FJaMOfbSk1heTkZ66aD3yNtVhnjvI=
github.com/aws/aws-sdk-go-v2/service/sagemaker v1.0.0 h1:WAKXnA5HISN6P8sbXsJ9486ThbRPnoBAtMyDSG7+jNM=
github.com/aws/aws-sdk-go-v2/service/sagemaker v1.0.0/go.mod h1:8/T2od4WQj1qKPr2ppDgjCnMFR6hfYJM4hzjH1D+HWg=
github.com/aws/aws-sdk-go-v2/service/sagemaker v1.1.0 h1:qsaGAmYqUzym7g4uaBzx5uOYoEJW0wIHhgObLqZc1mo=
github.com/aws/aws-sdk-go-v2/service/sts v1.0.0 h1:6XCgxNfE4L/Fnq+InhVNd16DKc6Ue1f3dJl3IwwJRUQ=
github.com/aws/aws-sdk-go-v2/service/sts v1.0.0/go.mod h1:5f+cELGATgill5Pu3/vK3Ebuigstc+qYEHW5MvGWZO4=
github.com/aws/smithy-go v1.0.0 h1:hkhcRKG9rJ4Fn+RbfXY7Tz7b3ITLDyolBnLLBhwbg/c=
Expand Down
138 changes: 138 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package common

import (
"fmt"
"sort"
"time"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/tasklog"

commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
flyteerr "github.com/lyft/flyteplugins/go/tasks/errors"
"github.com/lyft/flyteplugins/go/tasks/logs"
pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
v1 "k8s.io/api/core/v1"
)

const (
TensorflowTaskType = "tensorflow"
PytorchTaskType = "pytorch"
)

func ExtractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) {
if jobConditions != nil {
sort.Slice(jobConditions, func(i, j int) bool {
return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time)
})

for _, jc := range jobConditions {
if jc.Status == v1.ConditionTrue {
return jc, nil
}
}
}

return commonOp.JobCondition{}, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions)
}

func GetPhaseInfo(currentCondition commonOp.JobCondition, occurredAt time.Time,
taskPhaseInfo pluginsCore.TaskInfo) (pluginsCore.PhaseInfo, error) {
switch currentCondition.Type {
case commonOp.JobCreated:
return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil
case commonOp.JobRunning:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil
case commonOp.JobSucceeded:
return pluginsCore.PhaseInfoSuccess(&taskPhaseInfo), nil
case commonOp.JobFailed:
details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message)
return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil
case commonOp.JobRestarting:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil
}

return pluginsCore.PhaseInfoUndefined, nil
}

func GetLogs(taskType string, name string, namespace string,
workersCount int32, psReplicasCount int32, chiefReplicasCount int32) ([]*core.TaskLog, error) {
taskLogs := make([]*core.TaskLog, 0, 10)

logPlugin, err := logs.InitializeLogPlugins(logs.GetLogConfig())

if err != nil {
return nil, err
}

if logPlugin == nil {
return nil, nil
}

if taskType == PytorchTaskType {
masterTaskLog, masterErr := logPlugin.GetTaskLogs(
tasklog.Input{
PodName: name + "-master-0",
Namespace: namespace,
LogName: "master",
},
)
if masterErr != nil {
return nil, masterErr
}
taskLogs = append(taskLogs, masterTaskLog.TaskLogs...)
}

// get all workers log
for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ {
workerLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-worker-%d", workerIndex),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, workerLog.TaskLogs...)
}
// get all parameter servers logs
for psReplicaIndex := int32(0); psReplicaIndex < psReplicasCount; psReplicaIndex++ {
psReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-psReplica-%d", psReplicaIndex),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, psReplicaLog.TaskLogs...)
}
// get chief worker log, and the max number of chief worker is 1
if chiefReplicasCount != 0 {
chiefReplicaLog, err := logPlugin.GetTaskLogs(tasklog.Input{
PodName: name + fmt.Sprintf("-chiefReplica-%d", 0),
Namespace: namespace,
})
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, chiefReplicaLog.TaskLogs...)
}

return taskLogs, nil
}

func OverrideDefaultContainerName(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec,
defaultContainerName string) {
// Pytorch operator forces pod to have container named 'pytorch'
// https://github.com/kubeflow/pytorch-operator/blob/037cd1b18eb77f657f2a4bc8a8334f2a06324b57/pkg/apis/pytorch/validation/validation.go#L54-L62
// Tensorflow operator forces pod to have container named 'tensorflow'
// https://github.com/kubeflow/tf-operator/blob/984adc287e6fe82841e4ca282dc9a2cbb71e2d4a/pkg/apis/tensorflow/validation/validation.go#L55-L63
// hence we have to override the name set here
// https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/pluginmachinery/flytek8s/container_helper.go#L116
flyteDefaultContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
for idx, c := range podSpec.Containers {
if c.Name == flyteDefaultContainerName {
podSpec.Containers[idx].Name = defaultContainerName
return
}
}
}
67 changes: 67 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package common

import (
"testing"
"time"

commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
)

func TestExtractCurrentCondition(t *testing.T) {
jobCreated := commonOp.JobCondition{
Type: commonOp.JobCreated,
Status: corev1.ConditionTrue,
}
jobRunningActive := commonOp.JobCondition{
Type: commonOp.JobRunning,
Status: corev1.ConditionFalse,
}
jobConditions := []commonOp.JobCondition{
jobCreated,
jobRunningActive,
}
currentCondition, err := ExtractCurrentCondition(jobConditions)
assert.NoError(t, err)
assert.Equal(t, currentCondition, jobCreated)
}

func TestGetPhaseInfo(t *testing.T) {
jobCreated := commonOp.JobCondition{
Type: commonOp.JobCreated,
}
taskPhase, err := GetPhaseInfo(jobCreated, time.Now(), pluginsCore.TaskInfo{})
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase())
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, err)

jobSucceeded := commonOp.JobCondition{
Type: commonOp.JobSucceeded,
}
taskPhase, err = GetPhaseInfo(jobSucceeded, time.Now(), pluginsCore.TaskInfo{})
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseSuccess, taskPhase.Phase())
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, err)

jobFailed := commonOp.JobCondition{
Type: commonOp.JobFailed,
}
taskPhase, err = GetPhaseInfo(jobFailed, time.Now(), pluginsCore.TaskInfo{})
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskPhase.Phase())
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, err)

jobRestarting := commonOp.JobCondition{
Type: commonOp.JobRestarting,
}
taskPhase, err = GetPhaseInfo(jobRestarting, time.Now(), pluginsCore.TaskInfo{})
assert.NoError(t, err)
assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase())
assert.NotNil(t, taskPhase.Info())
assert.Nil(t, err)
}
Loading

0 comments on commit e84585e

Please sign in to comment.