diff --git a/flyteplugins/go.mod b/flyteplugins/go.mod index 610aba01e..8ce37c384 100644 --- a/flyteplugins/go.mod +++ b/flyteplugins/go.mod @@ -11,7 +11,9 @@ require ( github.com/golang/protobuf v1.3.3 github.com/googleapis/gnostic v0.4.1 // indirect github.com/hashicorp/golang-lru v0.5.4 - github.com/lyft/flyteidl v0.17.9 + github.com/kubeflow/pytorch-operator v0.6.0 + github.com/kubeflow/tf-operator v0.5.3 + github.com/lyft/flyteidl v0.17.32 github.com/lyft/flytestdlib v0.3.3 github.com/magiconair/properties v1.8.1 github.com/mitchellh/mapstructure v1.1.2 diff --git a/flyteplugins/go.sum b/flyteplugins/go.sum index ae9c3db9e..deaab3b17 100644 --- a/flyteplugins/go.sum +++ b/flyteplugins/go.sum @@ -50,8 +50,10 @@ github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/purell v1.1.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tNFfI= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -114,6 +116,7 @@ github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= +github.com/emicklei/go-restful v2.9.5+incompatible h1:spTtZBk5DYEvbxMVutUuTyh1Ao2r4iyvLdACqsl/Ljk= github.com/emicklei/go-restful v2.9.5+incompatible/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= @@ -153,10 +156,12 @@ github.com/go-openapi/errors v0.19.2/go.mod h1:qX0BLWsyaKfvhluLejVpVNwNRdXZhEbTA github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0= github.com/go-openapi/jsonpointer v0.17.0/go.mod h1:cOnomiV+CVVwFLk0A/MExoFMjwdsUdVpsRhURCKh+3M= github.com/go-openapi/jsonpointer v0.18.0/go.mod h1:cOnomiV+CVVwFLk0A/MExoFMjwdsUdVpsRhURCKh+3M= +github.com/go-openapi/jsonpointer v0.19.2 h1:A9+F4Dc/MCNB5jibxf6rRvOvR/iFgQdyNx9eIhnGqq0= github.com/go-openapi/jsonpointer v0.19.2/go.mod h1:3akKfEdA7DF1sugOqz1dVQHBcuDBPKZGEoHC/NkiQRg= github.com/go-openapi/jsonreference v0.0.0-20160704190145-13c6e3589ad9/go.mod h1:W3Z9FmVs9qj+KR4zFKmDPGiLdk1D9Rlm7cyMvf57TTg= github.com/go-openapi/jsonreference v0.17.0/go.mod 
h1:g4xxGn04lDIRh0GJb5QlpE3HfopLOL6uZrK/VgnsK9I= github.com/go-openapi/jsonreference v0.18.0/go.mod h1:g4xxGn04lDIRh0GJb5QlpE3HfopLOL6uZrK/VgnsK9I= +github.com/go-openapi/jsonreference v0.19.2 h1:o20suLFB4Ri0tuzpWtyHlh7E7HnkqTNLq6aR6WVNS1w= github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwohSTlpa0o73RUL1owJc= github.com/go-openapi/loads v0.17.0/go.mod h1:72tmFy5wsWx89uEVddd0RjRWPZm92WRLhf7AC+0+OOU= github.com/go-openapi/loads v0.18.0/go.mod h1:72tmFy5wsWx89uEVddd0RjRWPZm92WRLhf7AC+0+OOU= @@ -167,6 +172,7 @@ github.com/go-openapi/runtime v0.19.0/go.mod h1:OwNfisksmmaZse4+gpV3Ne9AyMOlP1lt github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc= github.com/go-openapi/spec v0.17.0/go.mod h1:XkF/MOi14NmjsfZ8VtAKf8pIlbZzyoTvZsdfssdxcBI= github.com/go-openapi/spec v0.18.0/go.mod h1:XkF/MOi14NmjsfZ8VtAKf8pIlbZzyoTvZsdfssdxcBI= +github.com/go-openapi/spec v0.19.2 h1:SStNd1jRcYtfKCN7R0laGNs80WYYvn5CbBjM2sOmCrE= github.com/go-openapi/spec v0.19.2/go.mod h1:sCxk3jxKgioEJikev4fgkNmwS+3kuYdJtcsZsD5zxMY= github.com/go-openapi/strfmt v0.17.0/go.mod h1:P82hnJI0CXkErkXi8IKjPbNBM6lV6+5pLP5l494TcyU= github.com/go-openapi/strfmt v0.18.0/go.mod h1:P82hnJI0CXkErkXi8IKjPbNBM6lV6+5pLP5l494TcyU= @@ -174,6 +180,7 @@ github.com/go-openapi/strfmt v0.19.0/go.mod h1:+uW+93UVvGGq2qGaZxdDeJqSAqBqBdl+Z github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I= github.com/go-openapi/swag v0.17.0/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg= github.com/go-openapi/swag v0.18.0/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg= +github.com/go-openapi/swag v0.19.2 h1:jvO6bCMBEilGwMfHhrd61zIID4oIFdwb76V17SM88dE= github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/validate v0.18.0/go.mod h1:Uh4HdOzKt19xGIGm1qHf/ofbX1YQ4Y+MYsct2VUrAJ4= github.com/go-openapi/validate v0.19.2/go.mod h1:1tRCw7m3jtI8eNWEEliiAqUIcBztB2KDnRCRMUi7GTA= @@ -292,12 +299,17 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kubeflow/pytorch-operator v0.6.0 h1:y9Vzk7Jd5H/s610Y+ucURypCHgJugB25UL8GEz4DRL4= +github.com/kubeflow/pytorch-operator v0.6.0/go.mod h1:zHblV+yTwVG4PCgKTU2wPfOmQ6TJdfT87lDfHrP1a1Y= +github.com/kubeflow/tf-operator v0.5.3 h1:Ejn5vEAwHBKHU2sJTlUIRpezqIX3WeqXZ2dZx6zn6vY= +github.com/kubeflow/tf-operator v0.5.3/go.mod h1:EBtz5LQoKaHUl/5fV5vD1qXVNVNyn3TrFaH6eVoQ8SY= github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0 h1:NGL46+1RYcCXb3sShp0nQq4W38fcgnpCD4+X02eeLL0= github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0/go.mod h1:/L5qH+AD540e7Cetbui1tuJeXdmNhO8jM6VkXeDdDhQ= github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f h1:PGuAMDzAen0AulUfaEhNQMYmUpa41pAVo3zHI+GJsCM= github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f/go.mod h1:llRdnznGEAqC3DcNm6yEj472xaFVfLM7hnYofMb12tQ= github.com/lyft/flyteidl v0.17.9 h1:JXT9PovHqS9V3YN74x9zWT0kvIEL48c2uNoujF1KMes= github.com/lyft/flyteidl v0.17.9/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= +github.com/lyft/flyteidl v0.17.29/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= github.com/lyft/flytestdlib v0.3.0 h1:nIkX4MlyYdcLLzaF35RI2P5BhARt+qMgHoFto8eVNzU= 
github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= github.com/lyft/flytestdlib v0.3.2 h1:bY6Y+Fg6Jdc7zY4GAYuR7t2hjWwynIdmRvtLcRNaGnw= @@ -312,6 +324,7 @@ github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czP github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63 h1:nTT4s92Dgz2HlrB2NaMgvlfqHH39OgMhA7z3PK7PGD4= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go new file mode 100644 index 000000000..98f0b06f9 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -0,0 +1,224 @@ +package pytorch + +import ( + "context" + "fmt" + "sort" + "time" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + flyteerr "github.com/lyft/flyteplugins/go/tasks/errors" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s" + v1 "k8s.io/api/core/v1" + "k8s.io/client-go/kubernetes/scheme" + + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + + logUtils "github.com/lyft/flyteidl/clients/go/coreutils/logs" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/logs" + + //commonOp "github.com/kubeflow/common/pkg/apis/common/v1" // switch to real 'common' once https://github.com/kubeflow/pytorch-operator/issues/263 resolved + ptOp "github.com/kubeflow/pytorch-operator/pkg/apis/pytorch/v1" + commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const ( + pytorchTaskType = "pytorch" +) + +type pytorchOperatorResourceHandler struct { +} + +// Sanity test that the plugin implements method of k8s.Plugin +var _ k8s.Plugin = pytorchOperatorResourceHandler{} + +// Defines a func to create a query object (typically just object and type meta portions) that's used to query k8s +// resources. +func (pytorchOperatorResourceHandler) BuildIdentityResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionMetadata) (k8s.Resource, error) { + return &ptOp.PyTorchJob{ + TypeMeta: metav1.TypeMeta{ + Kind: ptOp.Kind, + APIVersion: ptOp.SchemeGroupVersion.String(), + }, + }, nil +} + +// Defines a func to create the full resource object that will be posted to k8s. 
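+// For PyTorch this builds a PyTorchJob with a single Master replica and a Worker replica group whose size
+// comes from the task's DistributedPyTorchTrainingTask custom proto; both replica types reuse the same pod
+// spec produced by flytek8s.ToK8sPodSpec and are configured with RestartPolicyNever.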
+func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (k8s.Resource, error) { + taskTemplate, err := taskCtx.TaskReader().Read(ctx) + + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error()) + } else if taskTemplate == nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") + } + + pytorchTaskExtraArgs := plugins.DistributedPyTorchTrainingTask{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &pytorchTaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + podSpec, err := flytek8s.ToK8sPodSpec(ctx, taskCtx.TaskExecutionMetadata(), taskCtx.TaskReader(), taskCtx.InputReader(), taskCtx.OutputWriter()) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) + } + + overrideDefaultContainerName(taskCtx, podSpec) + + workers := pytorchTaskExtraArgs.GetWorkers() + + jobSpec := ptOp.PyTorchJobSpec{ + TTLSecondsAfterFinished: nil, + PyTorchReplicaSpecs: map[ptOp.PyTorchReplicaType]*commonOp.ReplicaSpec{ + ptOp.PyTorchReplicaTypeMaster: { + Template: v1.PodTemplateSpec{ + Spec: *podSpec, + }, + RestartPolicy: commonOp.RestartPolicyNever, + }, + ptOp.PyTorchReplicaTypeWorker: { + Replicas: &workers, + Template: v1.PodTemplateSpec{ + Spec: *podSpec, + }, + RestartPolicy: commonOp.RestartPolicyNever, + }, + }, + } + + job := &ptOp.PyTorchJob{ + TypeMeta: metav1.TypeMeta{ + Kind: ptOp.Kind, + APIVersion: ptOp.SchemeGroupVersion.String(), + }, + Spec: jobSpec, + } + + return job, nil +} + +// Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast, +// any operations that might take a long time (limits are configured system-wide) should be offloaded to the +// background. 
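+// The phase is derived from the most recent PyTorchJob condition whose Status is ConditionTrue:
+// JobCreated maps to Queued, JobRunning to Running, JobSucceeded to Success, and JobFailed/JobRestarting
+// to a retryable failure; any other condition yields PhaseInfoUndefined.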
+func (pytorchOperatorResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource k8s.Resource) (pluginsCore.PhaseInfo, error) { + app := resource.(*ptOp.PyTorchJob) + + workersCount := app.Spec.PyTorchReplicaSpecs[ptOp.PyTorchReplicaTypeWorker].Replicas + + taskLogs, err := getLogs(app, *workersCount) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + currentCondition, err := extractCurrentCondition(app.Status.Conditions) + if err != nil { + return pluginsCore.PhaseInfoUndefined, err + } + + occurredAt := time.Now() + statusDetails, _ := utils.MarshalObjToStruct(app.Status) + taskPhaseInfo := pluginsCore.TaskInfo{ + Logs: taskLogs, + OccurredAt: &occurredAt, + CustomInfo: statusDetails, + } + + switch currentCondition.Type { + case commonOp.JobCreated: + return pluginsCore.PhaseInfoQueued(occurredAt, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil + case commonOp.JobRunning: + return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &taskPhaseInfo), nil + case commonOp.JobSucceeded: + return pluginsCore.PhaseInfoSuccess(&taskPhaseInfo), nil + case commonOp.JobFailed: + details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message) + return pluginsCore.PhaseInfoRetryableFailure(flyteerr.DownstreamSystemError, details, &taskPhaseInfo), nil + case commonOp.JobRestarting: + details := fmt.Sprintf("Job failed:\n\t%v - %v", currentCondition.Reason, currentCondition.Message) + return pluginsCore.PhaseInfoRetryableFailure(flyteerr.RuntimeFailure, details, &taskPhaseInfo), nil + } + + return pluginsCore.PhaseInfoUndefined, nil +} + +func getLogs(app *ptOp.PyTorchJob, workersCount int32) ([]*core.TaskLog, error) { + // If kubeClient was available, it would be better to use + // https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/logs/logging_utils.go#L12 + makeTaskLog := func(appName, appNamespace, suffix, url string) (core.TaskLog, error) { + return logUtils.NewKubernetesLogPlugin(url).GetTaskLog( + appName+"-"+suffix, + appNamespace, + "", + "", + suffix+" logs (via Kubernetes)") + } + + var taskLogs []*core.TaskLog + + logConfig := logs.GetLogConfig() + if logConfig.IsKubernetesEnabled { + masterTaskLog, masterErr := makeTaskLog(app.Name, app.Namespace, "master-0", logConfig.KubernetesURL) + if masterErr != nil { + return nil, masterErr + } + taskLogs = append(taskLogs, &masterTaskLog) + + for workerIndex := int32(0); workerIndex < workersCount; workerIndex++ { + workerLog, err := makeTaskLog(app.Name, app.Namespace, fmt.Sprintf("worker-%d", workerIndex), logConfig.KubernetesURL) + if err != nil { + return nil, err + } + taskLogs = append(taskLogs, &workerLog) + } + } + return taskLogs, nil +} + +func extractCurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) { + sort.Slice(jobConditions[:], func(i, j int) bool { + return jobConditions[i].LastTransitionTime.Time.After(jobConditions[j].LastTransitionTime.Time) + }) + + for _, jc := range jobConditions { + if jc.Status == v1.ConditionTrue { + return jc, nil + } + } + + return commonOp.JobCondition{}, fmt.Errorf("found no current condition. 
Conditions: %+v", jobConditions) +} + +func overrideDefaultContainerName(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec) { + // Pytorch operator forces pod to have container named 'pytorch' + // https://github.com/kubeflow/pytorch-operator/blob/037cd1b18eb77f657f2a4bc8a8334f2a06324b57/pkg/apis/pytorch/validation/validation.go#L54-L62 + // hence we have to override the name set here + // https://github.com/lyft/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/pluginmachinery/flytek8s/container_helper.go#L116 + flyteDefaultContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + for idx, c := range podSpec.Containers { + if c.Name == flyteDefaultContainerName { + podSpec.Containers[idx].Name = ptOp.DefaultContainerName + return + } + } +} + +func init() { + if err := ptOp.AddToScheme(scheme.Scheme); err != nil { + panic(err) + } + + pluginmachinery.PluginRegistry().RegisterK8sPlugin( + k8s.PluginEntry{ + ID: pytorchTaskType, + RegisteredTaskTypes: []pluginsCore.TaskType{pytorchTaskType}, + ResourceToWatch: &ptOp.PyTorchJob{}, + Plugin: pytorchOperatorResourceHandler{}, + IsDefault: false, + }) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go new file mode 100644 index 000000000..2b325e841 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -0,0 +1,353 @@ +package pytorch + +import ( + "context" + "fmt" + "testing" + "time" + + commonOp "github.com/kubeflow/tf-operator/pkg/apis/common/v1" + "github.com/lyft/flyteplugins/go/tasks/logs" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + + "github.com/stretchr/testify/mock" + + "github.com/lyft/flytestdlib/storage" + + pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + + pluginIOMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks" + + "github.com/golang/protobuf/jsonpb" + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/stretchr/testify/assert" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + ptOp "github.com/kubeflow/pytorch-operator/pkg/apis/pytorch/v1" +) + +const testImage = "image://" +const serviceAccount = "pytorch_sa" + +var ( + dummyEnvVars = []*core.KeyValuePair{ + {Key: "Env_Var", Value: "Env_Val"}, + } + + testArgs = []string{ + "test-args", + } + + resourceRequirements = &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1000m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("100m"), + corev1.ResourceMemory: resource.MustParse("512Mi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + } + + jobName = "the-job" + jobNamespace = "pytorch-namespace" +) + +func dummyPytorchCustomObj(workers int32) *plugins.DistributedPyTorchTrainingTask { + return &plugins.DistributedPyTorchTrainingTask{ + Workers: workers, + } +} + +func dummySparkTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate { + + 
ptObjJSON, err := utils.MarshalToString(pytorchCustomObj) + if err != nil { + panic(err) + } + + structObj := structpb.Struct{} + + err = jsonpb.UnmarshalString(ptObjJSON, &structObj) + if err != nil { + panic(err) + } + + return &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: "container", + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: testImage, + Args: testArgs, + Env: dummyEnvVars, + }, + }, + Custom: &structObj, + } +} + +func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecutionContext { + taskCtx := &mocks.TaskExecutionContext{} + inputReader := &pluginIOMocks.InputReader{} + inputReader.OnGetInputPrefixPath().Return(storage.DataReference("/input/prefix")) + inputReader.OnGetInputPath().Return(storage.DataReference("/input")) + inputReader.OnGetMatch(mock.Anything).Return(&core.LiteralMap{}, nil) + taskCtx.OnInputReader().Return(inputReader) + + outputReader := &pluginIOMocks.OutputWriter{} + outputReader.OnGetOutputPath().Return(storage.DataReference("/data/outputs.pb")) + outputReader.OnGetOutputPrefixPath().Return(storage.DataReference("/data/")) + taskCtx.OnOutputWriter().Return(outputReader) + + taskReader := &mocks.TaskReader{} + taskReader.OnReadMatch(mock.Anything).Return(taskTemplate, nil) + taskCtx.OnTaskReader().Return(taskReader) + + tID := &mocks.TaskExecutionID{} + tID.OnGetID().Return(core.TaskExecutionIdentifier{ + NodeExecutionId: &core.NodeExecutionIdentifier{ + ExecutionId: &core.WorkflowExecutionIdentifier{ + Name: "my_name", + Project: "my_project", + Domain: "my_domain", + }, + }, + }) + tID.OnGetGeneratedName().Return("some-acceptable-name") + + resources := &mocks.TaskOverrides{} + resources.OnGetResources().Return(resourceRequirements) + + taskExecutionMetadata := &mocks.TaskExecutionMetadata{} + taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) + taskExecutionMetadata.OnGetNamespace().Return("test-namespace") + taskExecutionMetadata.OnGetAnnotations().Return(map[string]string{"annotation-1": "val1"}) + taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) + taskExecutionMetadata.OnGetOwnerReference().Return(v1.OwnerReference{ + Kind: "node", + Name: "blah", + }) + taskExecutionMetadata.OnIsInterruptible().Return(true) + taskExecutionMetadata.OnGetOverrides().Return(resources) + taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) + taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) + return taskCtx +} + +func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandler, workers int32, conditionType commonOp.JobConditionType) *ptOp.PyTorchJob { + var jobConditions []commonOp.JobCondition + + now := time.Now() + + jobCreated := commonOp.JobCondition{ + Type: commonOp.JobCreated, + Status: corev1.ConditionTrue, + Reason: "PyTorchJobCreated", + Message: "PyTorchJob the-job is created.", + LastUpdateTime: v1.Time{ + Time: now, + }, + LastTransitionTime: v1.Time{ + Time: now, + }, + } + jobRunningActive := commonOp.JobCondition{ + Type: commonOp.JobRunning, + Status: corev1.ConditionTrue, + Reason: "PyTorchJobRunning", + Message: "PyTorchJob the-job is running.", + LastUpdateTime: v1.Time{ + Time: now.Add(time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(time.Minute), + }, + } + jobRunningInactive := *jobRunningActive.DeepCopy() + jobRunningInactive.Status = corev1.ConditionFalse + jobSucceeded := commonOp.JobCondition{ + Type: commonOp.JobSucceeded, + Status: corev1.ConditionTrue, + Reason: 
"PyTorchJobSucceeded", + Message: "PyTorchJob the-job is successfully completed.", + LastUpdateTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + } + jobFailed := commonOp.JobCondition{ + Type: commonOp.JobFailed, + Status: corev1.ConditionTrue, + Reason: "PyTorchJobFailed", + Message: "PyTorchJob the-job is failed.", + LastUpdateTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(2 * time.Minute), + }, + } + jobRestarting := commonOp.JobCondition{ + Type: commonOp.JobRestarting, + Status: corev1.ConditionTrue, + Reason: "PyTorchJobRestarting", + Message: "PyTorchJob the-job is restarting because some replica(s) failed.", + LastUpdateTime: v1.Time{ + Time: now.Add(3 * time.Minute), + }, + LastTransitionTime: v1.Time{ + Time: now.Add(3 * time.Minute), + }, + } + + switch conditionType { + case commonOp.JobCreated: + jobConditions = []commonOp.JobCondition{ + jobCreated, + } + case commonOp.JobRunning: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningActive, + } + case commonOp.JobSucceeded: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningInactive, + jobSucceeded, + } + case commonOp.JobFailed: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningInactive, + jobFailed, + } + case commonOp.JobRestarting: + jobConditions = []commonOp.JobCondition{ + jobCreated, + jobRunningInactive, + jobFailed, + jobRestarting, + } + } + + ptObj := dummyPytorchCustomObj(workers) + taskTemplate := dummySparkTaskTemplate("the job", ptObj) + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) + if err != nil { + panic(err) + } + + return &ptOp.PyTorchJob{ + ObjectMeta: v1.ObjectMeta{ + Name: jobName, + Namespace: jobNamespace, + }, + Spec: resource.(*ptOp.PyTorchJob).Spec, + Status: commonOp.JobStatus{ + Conditions: jobConditions, + ReplicaStatuses: nil, + StartTime: nil, + CompletionTime: nil, + LastReconcileTime: nil, + }, + } +} + +func TestBuildResourcePytorch(t *testing.T) { + pytorchResourceHandler := pytorchOperatorResourceHandler{} + + ptObj := dummyPytorchCustomObj(100) + taskTemplate := dummySparkTaskTemplate("the job", ptObj) + + resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + pytorchJob, ok := resource.(*ptOp.PyTorchJob) + assert.True(t, ok) + assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[ptOp.PyTorchReplicaTypeWorker].Replicas) + + for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs { + var hasContainerWithDefaultPytorchName = false + + for _, container := range replicaSpec.Template.Spec.Containers { + if container.Name == ptOp.DefaultContainerName { + hasContainerWithDefaultPytorchName = true + } + + assert.Equal(t, resourceRequirements.Requests, container.Resources.Requests) + assert.Equal(t, resourceRequirements.Limits, container.Resources.Limits) + } + + assert.True(t, hasContainerWithDefaultPytorchName) + } +} + +func TestGetTaskPhase(t *testing.T) { + pytorchResourceHandler := pytorchOperatorResourceHandler{} + ctx := context.TODO() + + dummyPytorchJobResourceCreator := func(conditionType commonOp.JobConditionType) *ptOp.PyTorchJob { + return dummyPytorchJobResource(pytorchResourceHandler, 2, conditionType) + } + + taskPhase, err := pytorchResourceHandler.GetTaskPhase(ctx, nil, 
dummyPytorchJobResourceCreator(commonOp.JobCreated)) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseQueued, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + taskPhase, err = pytorchResourceHandler.GetTaskPhase(ctx, nil, dummyPytorchJobResourceCreator(commonOp.JobRunning)) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRunning, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + taskPhase, err = pytorchResourceHandler.GetTaskPhase(ctx, nil, dummyPytorchJobResourceCreator(commonOp.JobSucceeded)) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseSuccess, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + taskPhase, err = pytorchResourceHandler.GetTaskPhase(ctx, nil, dummyPytorchJobResourceCreator(commonOp.JobFailed)) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) + + taskPhase, err = pytorchResourceHandler.GetTaskPhase(ctx, nil, dummyPytorchJobResourceCreator(commonOp.JobRestarting)) + assert.NoError(t, err) + assert.Equal(t, pluginsCore.PhaseRetryableFailure, taskPhase.Phase()) + assert.NotNil(t, taskPhase.Info()) + assert.Nil(t, err) +} + +func TestGetLogs(t *testing.T) { + assert.NoError(t, logs.SetLogConfig(&logs.LogConfig{ + IsKubernetesEnabled: true, + KubernetesURL: "k8s.com", + })) + + workers := int32(2) + + pytorchResourceHandler := pytorchOperatorResourceHandler{} + jobLogs, err := getLogs(dummyPytorchJobResource(pytorchResourceHandler, workers, commonOp.JobRunning), workers) + assert.NoError(t, err) + assert.Equal(t, 3, len(jobLogs)) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-master-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[0].Uri) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-0/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[1].Uri) + assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-worker-1/pod?namespace=pytorch-namespace", jobNamespace, jobName), jobLogs[2].Uri) +}