Skip to content

Commit

Permalink
pytorch plugin implementation (flyteorg#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
igorvalko authored May 28, 2020
1 parent 502fafe commit 74715c3
Show file tree
Hide file tree
Showing 4 changed files with 593 additions and 1 deletion.
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ require (
github.com/golang/protobuf v1.3.3
github.com/googleapis/gnostic v0.4.1 // indirect
github.com/hashicorp/golang-lru v0.5.4
github.com/lyft/flyteidl v0.17.9
github.com/kubeflow/pytorch-operator v0.6.0
github.com/kubeflow/tf-operator v0.5.3
github.com/lyft/flyteidl v0.17.32
github.com/lyft/flytestdlib v0.3.3
github.com/magiconair/properties v1.8.1
github.com/mitchellh/mapstructure v1.1.2
Expand Down
13 changes: 13 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
github.com/PuerkitoBio/purell v1.1.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tNFfI=
github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M=
github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
Expand Down Expand Up @@ -114,6 +116,7 @@ github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZ
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc=
github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs=
github.com/emicklei/go-restful v2.9.5+incompatible h1:spTtZBk5DYEvbxMVutUuTyh1Ao2r4iyvLdACqsl/Ljk=
github.com/emicklei/go-restful v2.9.5+incompatible/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
Expand Down Expand Up @@ -153,10 +156,12 @@ github.com/go-openapi/errors v0.19.2/go.mod h1:qX0BLWsyaKfvhluLejVpVNwNRdXZhEbTA
github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0=
github.com/go-openapi/jsonpointer v0.17.0/go.mod h1:cOnomiV+CVVwFLk0A/MExoFMjwdsUdVpsRhURCKh+3M=
github.com/go-openapi/jsonpointer v0.18.0/go.mod h1:cOnomiV+CVVwFLk0A/MExoFMjwdsUdVpsRhURCKh+3M=
github.com/go-openapi/jsonpointer v0.19.2 h1:A9+F4Dc/MCNB5jibxf6rRvOvR/iFgQdyNx9eIhnGqq0=
github.com/go-openapi/jsonpointer v0.19.2/go.mod h1:3akKfEdA7DF1sugOqz1dVQHBcuDBPKZGEoHC/NkiQRg=
github.com/go-openapi/jsonreference v0.0.0-20160704190145-13c6e3589ad9/go.mod h1:W3Z9FmVs9qj+KR4zFKmDPGiLdk1D9Rlm7cyMvf57TTg=
github.com/go-openapi/jsonreference v0.17.0/go.mod h1:g4xxGn04lDIRh0GJb5QlpE3HfopLOL6uZrK/VgnsK9I=
github.com/go-openapi/jsonreference v0.18.0/go.mod h1:g4xxGn04lDIRh0GJb5QlpE3HfopLOL6uZrK/VgnsK9I=
github.com/go-openapi/jsonreference v0.19.2 h1:o20suLFB4Ri0tuzpWtyHlh7E7HnkqTNLq6aR6WVNS1w=
github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwohSTlpa0o73RUL1owJc=
github.com/go-openapi/loads v0.17.0/go.mod h1:72tmFy5wsWx89uEVddd0RjRWPZm92WRLhf7AC+0+OOU=
github.com/go-openapi/loads v0.18.0/go.mod h1:72tmFy5wsWx89uEVddd0RjRWPZm92WRLhf7AC+0+OOU=
Expand All @@ -167,13 +172,15 @@ github.com/go-openapi/runtime v0.19.0/go.mod h1:OwNfisksmmaZse4+gpV3Ne9AyMOlP1lt
github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc=
github.com/go-openapi/spec v0.17.0/go.mod h1:XkF/MOi14NmjsfZ8VtAKf8pIlbZzyoTvZsdfssdxcBI=
github.com/go-openapi/spec v0.18.0/go.mod h1:XkF/MOi14NmjsfZ8VtAKf8pIlbZzyoTvZsdfssdxcBI=
github.com/go-openapi/spec v0.19.2 h1:SStNd1jRcYtfKCN7R0laGNs80WYYvn5CbBjM2sOmCrE=
github.com/go-openapi/spec v0.19.2/go.mod h1:sCxk3jxKgioEJikev4fgkNmwS+3kuYdJtcsZsD5zxMY=
github.com/go-openapi/strfmt v0.17.0/go.mod h1:P82hnJI0CXkErkXi8IKjPbNBM6lV6+5pLP5l494TcyU=
github.com/go-openapi/strfmt v0.18.0/go.mod h1:P82hnJI0CXkErkXi8IKjPbNBM6lV6+5pLP5l494TcyU=
github.com/go-openapi/strfmt v0.19.0/go.mod h1:+uW+93UVvGGq2qGaZxdDeJqSAqBqBdl+ZPMF/cC8nDY=
github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I=
github.com/go-openapi/swag v0.17.0/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg=
github.com/go-openapi/swag v0.18.0/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg=
github.com/go-openapi/swag v0.19.2 h1:jvO6bCMBEilGwMfHhrd61zIID4oIFdwb76V17SM88dE=
github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
github.com/go-openapi/validate v0.18.0/go.mod h1:Uh4HdOzKt19xGIGm1qHf/ofbX1YQ4Y+MYsct2VUrAJ4=
github.com/go-openapi/validate v0.19.2/go.mod h1:1tRCw7m3jtI8eNWEEliiAqUIcBztB2KDnRCRMUi7GTA=
Expand Down Expand Up @@ -292,12 +299,17 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kubeflow/pytorch-operator v0.6.0 h1:y9Vzk7Jd5H/s610Y+ucURypCHgJugB25UL8GEz4DRL4=
github.com/kubeflow/pytorch-operator v0.6.0/go.mod h1:zHblV+yTwVG4PCgKTU2wPfOmQ6TJdfT87lDfHrP1a1Y=
github.com/kubeflow/tf-operator v0.5.3 h1:Ejn5vEAwHBKHU2sJTlUIRpezqIX3WeqXZ2dZx6zn6vY=
github.com/kubeflow/tf-operator v0.5.3/go.mod h1:EBtz5LQoKaHUl/5fV5vD1qXVNVNyn3TrFaH6eVoQ8SY=
github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0 h1:NGL46+1RYcCXb3sShp0nQq4W38fcgnpCD4+X02eeLL0=
github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0/go.mod h1:/L5qH+AD540e7Cetbui1tuJeXdmNhO8jM6VkXeDdDhQ=
github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f h1:PGuAMDzAen0AulUfaEhNQMYmUpa41pAVo3zHI+GJsCM=
github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f/go.mod h1:llRdnznGEAqC3DcNm6yEj472xaFVfLM7hnYofMb12tQ=
github.com/lyft/flyteidl v0.17.9 h1:JXT9PovHqS9V3YN74x9zWT0kvIEL48c2uNoujF1KMes=
github.com/lyft/flyteidl v0.17.9/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flyteidl v0.17.29/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20=
github.com/lyft/flytestdlib v0.3.0 h1:nIkX4MlyYdcLLzaF35RI2P5BhARt+qMgHoFto8eVNzU=
github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU=
github.com/lyft/flytestdlib v0.3.2 h1:bY6Y+Fg6Jdc7zY4GAYuR7t2hjWwynIdmRvtLcRNaGnw=
Expand All @@ -312,6 +324,7 @@ github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czP
github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63 h1:nTT4s92Dgz2HlrB2NaMgvlfqHH39OgMhA7z3PK7PGD4=
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
Expand Down
224 changes: 224 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
package pytorch

import (
"context"
"fmt"
"sort"
"time"

"github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins"
flyteerr "github.com/lyft/flyteplugins/go/tasks/errors"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s"
v1 "k8s.io/api/core/v1"
"k8s.io/client-go/kubernetes/scheme"

pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils"

logUtils "github.com/lyft/flyteidl/clients/go/coreutils/logs"
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
"github.com/lyft/flyteplugins/go/tasks/logs"

//commonOp "github.com/kubeflow/common/pkg/apis/common/v1" // switch to real 'common' once https://github.com/kubeflow/pytorch-operator/issues/263 resolved
ptOp "github.com/kubeflow/pytorch-operator/pkg/apis/pytorch/v1"
commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

const (
pytorchTaskType = "pytorch"
)

type pytorchOperatorResourceHandler struct {
}

// Sanity test that the plugin implements method of k8s.Plugin
var _ k8s.Plugin = pytorchOperatorResourceHandler{}

// Defines a func to create a query object (typically just object and type meta portions) that's used to query k8s
// resources.
func (pytorchOperatorResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionMetadata) (k8s.Resource, error) {
return &ptOp.PyTorchJob{
TypeMeta: metav1.TypeMeta{
Kind: ptOp.Kind,
APIVersion: ptOp.SchemeGroupVersion.String(),
},
}, nil
}

// Defines a func to create the full resource object that will be posted to k8s.
func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)

if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error())
} else if taskTemplate == nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification")
}

pytorchTaskExtraArgs := plugins.DistributedPyTorchTrainingTask{}
err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &pytorchTaskExtraArgs)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
}

podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx.TaskExecutionMetadata(), taskCtx.TaskReader(), taskCtx.InputReader(), taskCtx.OutputWriter())
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

overrideDefaultContainerName(taskCtx, podSpec)

workers := pytorchTaskExtraArgs.GetWorkers()

jobSpec := ptOp.PyTorchJobSpec{
TTLSecondsAfterFinished: nil,
PyTorchReplicaSpecs: map[ptOp.PyTorchReplicaType]*commonOp.ReplicaSpec{
ptOp.PyTorchReplicaTypeMaster: {
Template: v1.PodTemplateSpec{
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
ptOp.PyTorchReplicaTypeWorker: {
Replicas: &workers,
Template: v1.PodTemplateSpec{
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
},
}

job := &ptOp.PyTorchJob{
TypeMeta: metav1.TypeMeta{
Kind: ptOp.Kind,
APIVersion: ptOp.SchemeGroupVersion.String(),
},
Spec: jobSpec,
}

return job, nil
}

// Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast,
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
func (pytorchOperatorResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource k8s.Resource) (pluginsCore.PhaseInfo, error) {
app := resource.(*ptOp.PyTorchJob)

workersCount := app.Spec.PyTorchReplicaSpecs[ptOp.PyTorchReplicaTypeWorker].Replicas

taskLogs, err := getLogs(app, *workersCount)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

currentCondition, err := extractCurrentCondition(app.Status.Conditions)
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

occurredAt := time.Now()
statusDetails, _ := utils.MarshalObjToStruct(app.Status)
taskPhaseInfo := pluginsCore.TaskInfo{
Logs: taskLogs,
OccurredAt: &occurredAt,
CustomInfo: statusDetails,
}

switch currentCondition.Type {
case commonOp.JobCreated:
return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil
case commonOp.JobRunning:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil
case commonOp.JobSucceeded:
return pluginsCore.PhaseInfoSuccess(&taskPhaseInfo), nil
case commonOp.JobFailed:
details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message)
return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil
case commonOp.JobRestarting:
details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message)
return pluginsCore.PhaseInfoRetryableFailure(flyteerr.RuntimeFailure, details, &taskPhaseInfo), nil
}

return pluginsCore.PhaseInfoUndefined, nil
}

func getLogs(app *ptOp.PyTorchJob, workersCount int32) ([]*core.TaskLog, error) {
// If kubeClient was available, it would be better to use
// https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/logs/logging_utils.go#L12
makeTaskLog := func(appName, appNamespace, suffix, url string) (core.TaskLog, error) {
return logUtils.NewKubernetesLogPlugin(url).GetTaskLog(
appName+"-"+suffix,
appNamespace,
"",
"",
suffix+" logs (via Kubernetes)")
}

var taskLogs []*core.TaskLog

logConfig := logs.GetLogConfig()
if logConfig.IsKubernetesEnabled {
masterTaskLog, masterErr := makeTaskLog(app.Name, app.Namespace, "master-0", logConfig.KubernetesURL)
if masterErr != nil {
return nil, masterErr
}
taskLogs = append(taskLogs, &masterTaskLog)

for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ {
workerLog, err := makeTaskLog(app.Name, app.Namespace, fmt.Sprintf("worker-%d", workerIndex), logConfig.KubernetesURL)
if err != nil {
return nil, err
}
taskLogs = append(taskLogs, &workerLog)
}
}
return taskLogs, nil
}

func extractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) {
sort.Slice(jobConditions[:], func(i, j int) bool {
return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time)
})

for _, jc := range jobConditions {
if jc.Status == v1.ConditionTrue {
return jc, nil
}
}

return commonOp.JobCondition{}, fmt.Errorf("found no current condition. Conditions: %+v", jobConditions)
}

func overrideDefaultContainerName(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec) {
// Pytorch operator forces pod to have container named 'pytorch'
// https://github.com/kubeflow/pytorch-operator/blob/037cd1b18eb77f657f2a4bc8a8334f2a06324b57/pkg/apis/pytorch/validation/validation.go#L54-L62
// hence we have to override the name set here
// https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/pluginmachinery/flytek8s/container_helper.go#L116
flyteDefaultContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
for idx, c := range podSpec.Containers {
if c.Name == flyteDefaultContainerName {
podSpec.Containers[idx].Name = ptOp.DefaultContainerName
return
}
}
}

func init() {
if err := ptOp.AddToScheme(scheme.Scheme); err != nil {
panic(err)
}

pluginmachinery.PluginRegistry().RegisterK8sPlugin(
k8s.PluginEntry{
ID: pytorchTaskType,
RegisteredTaskTypes: []pluginsCore.TaskType{pytorchTaskType},
ResourceToWatch: &ptOp.PyTorchJob{},
Plugin: pytorchOperatorResourceHandler{},
IsDefault: false,
})
}
Loading

0 comments on commit 74715c3

Please sign in to comment.