Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Change kubeflow plugins to allow settings specs for different replica (
Browse files Browse the repository at this point in the history
…#345)

* change pytorch plugin to accept new pytorch task idl

Signed-off-by: Yubo Wang <[email protected]>

* merge elastic config in

Signed-off-by: Yubo Wang <[email protected]>

* add unit tests for pytorch

Signed-off-by: Yubo Wang <[email protected]>

* add tfjob

Signed-off-by: Yubo Wang <[email protected]>

* add mpi job

Signed-off-by: Yubo Wang <[email protected]>

* add test to commone operator

Signed-off-by: Yubo Wang <[email protected]>

* update flyteidl

Signed-off-by: Yubo Wang <[email protected]>

* add function header comments

Signed-off-by: Yubo Wang <[email protected]>

* fix lint

Signed-off-by: Yubo Wang <[email protected]>

---------

Signed-off-by: Yubo Wang <[email protected]>
Co-authored-by: Yubo Wang <[email protected]>
  • Loading branch information
yubofredwang and Yubo Wang authored May 9, 2023
1 parent 22ee8e9 commit 9cbd406
Show file tree
Hide file tree
Showing 10 changed files with 975 additions and 121 deletions.
2 changes: 1 addition & 1 deletion flyteplugins/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/athena v1.0.0
github.com/bstadlbauer/dask-k8s-operator-go-client v0.1.0
github.com/coocood/freecache v1.1.1
github.com/flyteorg/flyteidl v1.3.19
github.com/flyteorg/flyteidl v1.5.2
github.com/flyteorg/flytestdlib v1.0.15
github.com/go-test/deep v1.0.7
github.com/golang/protobuf v1.5.2
Expand Down
4 changes: 2 additions & 2 deletions flyteplugins/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/flyteorg/flyteidl v1.3.19 h1:i79Dh7UoP8Z4LEJ2ox6jlfZVJtFZ+r4g84CJj1gh22Y=
github.com/flyteorg/flyteidl v1.3.19/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM=
github.com/flyteorg/flyteidl v1.5.2 h1:DZPzYkTg92qA4e17fd0ZW1M+gh1gJKh/VOK+F4bYgM8=
github.com/flyteorg/flyteidl v1.5.2/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I=
github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0=
github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s=
github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"sort"
"time"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"
flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
Expand All @@ -21,6 +23,12 @@ const (
PytorchTaskType = "pytorch"
)

type ReplicaEntry struct {
PodSpec *v1.PodSpec
ReplicaNum int32
RestartPolicy commonOp.RestartPolicy
}

// ExtractMPICurrentCondition will return the first job condition for MPI
func ExtractMPICurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) {
if jobConditions != nil {
Expand Down Expand Up @@ -180,3 +188,72 @@ func OverridePrimaryContainerName(podSpec *v1.PodSpec, primaryContainerName stri
}
}
}

// ParseRunPolicy converts a kubeflow plugin RunPolicy object to a k8s RunPolicy object.
func ParseRunPolicy(flyteRunPolicy kfplugins.RunPolicy) commonOp.RunPolicy {
runPolicy := commonOp.RunPolicy{}
if flyteRunPolicy.GetBackoffLimit() != 0 {
var backoffLimit = flyteRunPolicy.GetBackoffLimit()
runPolicy.BackoffLimit = &backoffLimit
}
var cleanPodPolicy = ParseCleanPodPolicy(flyteRunPolicy.GetCleanPodPolicy())
runPolicy.CleanPodPolicy = &cleanPodPolicy
if flyteRunPolicy.GetActiveDeadlineSeconds() != 0 {
var ddlSeconds = int64(flyteRunPolicy.GetActiveDeadlineSeconds())
runPolicy.ActiveDeadlineSeconds = &ddlSeconds
}
if flyteRunPolicy.GetTtlSecondsAfterFinished() != 0 {
var ttl = flyteRunPolicy.GetTtlSecondsAfterFinished()
runPolicy.TTLSecondsAfterFinished = &ttl
}

return runPolicy
}

// Get k8s clean pod policy from flyte kubeflow plugins clean pod policy.
func ParseCleanPodPolicy(flyteCleanPodPolicy kfplugins.CleanPodPolicy) commonOp.CleanPodPolicy {
cleanPodPolicyMap := map[kfplugins.CleanPodPolicy]commonOp.CleanPodPolicy{
kfplugins.CleanPodPolicy_CLEANPOD_POLICY_NONE: commonOp.CleanPodPolicyNone,
kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL: commonOp.CleanPodPolicyAll,
kfplugins.CleanPodPolicy_CLEANPOD_POLICY_RUNNING: commonOp.CleanPodPolicyRunning,
}
return cleanPodPolicyMap[flyteCleanPodPolicy]
}

// Get k8s restart policy from flyte kubeflow plugins restart policy.
func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.RestartPolicy {
restartPolicyMap := map[kfplugins.RestartPolicy]commonOp.RestartPolicy{
kfplugins.RestartPolicy_RESTART_POLICY_NEVER: commonOp.RestartPolicyNever,
kfplugins.RestartPolicy_RESTART_POLICY_ON_FAILURE: commonOp.RestartPolicyOnFailure,
kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS: commonOp.RestartPolicyAlways,
}
return restartPolicyMap[flyteRestartPolicy]
}

// OverrideContainerSpec overrides the specified container's properties in the given podSpec. The function
// updates the image, resources and command arguments of the container that matches the given containerName.
func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) error {
for idx, c := range podSpec.Containers {
if c.Name == containerName {
if image != "" {
podSpec.Containers[idx].Image = image
}
if resources != nil {
// if resources requests and limits both not set, we will not override the resources
if len(resources.Requests) >= 1 || len(resources.Limits) >= 1 {
resources, err := flytek8s.ToK8sResourceRequirements(resources)
if err != nil {
return flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecificat ion on Resources [%v], Err: [%v]", resources, err.Error())
}
podSpec.Containers[idx].Resources = *resources
}
} else {
podSpec.Containers[idx].Resources = v1.ResourceRequirements{}
}
if len(args) != 0 {
podSpec.Containers[idx].Args = args
}
}
}
return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ import (
"testing"
"time"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/logs"

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
)

func TestExtractMPICurrentCondition(t *testing.T) {
Expand Down Expand Up @@ -183,3 +186,101 @@ func TestGetLogs(t *testing.T) {
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-chiefReplica-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[2].Uri)

}

func dummyPodSpec() v1.PodSpec {
return v1.PodSpec{
Containers: []v1.Container{
{
Name: "primary container",
Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"},
Resources: v1.ResourceRequirements{
Limits: v1.ResourceList{
"cpu": resource.MustParse("2"),
"memory": resource.MustParse("200Mi"),
"gpu": resource.MustParse("1"),
},
Requests: v1.ResourceList{
"cpu": resource.MustParse("1"),
"memory": resource.MustParse("100Mi"),
"gpu": resource.MustParse("1"),
},
},
VolumeMounts: []v1.VolumeMount{
{
Name: "volume mount",
},
},
},
{
Name: "secondary container",
Resources: v1.ResourceRequirements{
Limits: v1.ResourceList{
"gpu": resource.MustParse("2"),
},
Requests: v1.ResourceList{
"gpu": resource.MustParse("2"),
},
},
},
},
Volumes: []v1.Volume{
{
Name: "dshm",
},
},
Tolerations: []v1.Toleration{
{
Key: "my toleration key",
Value: "my toleration value",
},
},
}
}

func TestOverrideContainerSpec(t *testing.T) {
podSpec := dummyPodSpec()
err := OverrideContainerSpec(
&podSpec, "primary container", "testing-image",
&core.Resources{
Requests: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "250m"},
},
Limits: []*core.Resources_ResourceEntry{
{Name: core.Resources_CPU, Value: "500m"},
},
},
[]string{"python", "-m", "run.py"},
)
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.Equal(t, "testing-image", podSpec.Containers[0].Image)
assert.NotNil(t, podSpec.Containers[0].Resources.Limits)
assert.NotNil(t, podSpec.Containers[0].Resources.Requests)
// verify resources not overridden if empty resources
assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("250m")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("500m")))
assert.Equal(t, []string{"python", "-m", "run.py"}, podSpec.Containers[0].Args)
}

func TestOverrideContainerSpecEmptyFields(t *testing.T) {
podSpec := dummyPodSpec()
err := OverrideContainerSpec(&podSpec, "primary container", "", &core.Resources{}, []string{})
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.NotNil(t, podSpec.Containers[0].Resources.Limits)
assert.NotNil(t, podSpec.Containers[0].Resources.Requests)
// verify resources not overridden if empty resources
assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("1")))
assert.True(t, podSpec.Containers[0].Resources.Requests.Memory().Equal(resource.MustParse("100Mi")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("2")))
assert.True(t, podSpec.Containers[0].Resources.Limits.Memory().Equal(resource.MustParse("200Mi")))
}

func TestOverrideContainerNilResources(t *testing.T) {
podSpec := dummyPodSpec()
err := OverrideContainerSpec(&podSpec, "primary container", "", nil, []string{})
assert.NoError(t, err)
assert.Equal(t, 2, len(podSpec.Containers))
assert.Nil(t, podSpec.Containers[0].Resources.Limits)
assert.Nil(t, podSpec.Containers[0].Resources.Requests)
}
Loading

0 comments on commit 9cbd406

Please sign in to comment.