Skip to content

Commit

Permalink
Use flyte configuration defaults for sidecar pod spec (flyteorg#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Oct 14, 2020
1 parent 639c589 commit 15016e2
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 42 deletions.
2 changes: 1 addition & 1 deletion flyteplugins/copilot/data/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (

type Downloader struct {
format core.DataLoadingConfig_LiteralMapFormat
store *storage.DataStore
store *storage.DataStore
// TODO support download mode
mode core.IOStrategy_DownloadMode
}
Expand Down
2 changes: 1 addition & 1 deletion flyteplugins/copilot/data/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (

"github.com/golang/protobuf/proto"
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
"github.com/lyft/flytestdlib/futures"
"github.com/lyft/flytestdlib/logger"
"github.com/lyft/flytestdlib/storage"
"github.com/lyft/flytestdlib/futures"
"github.com/pkg/errors"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils"
Expand Down
41 changes: 27 additions & 14 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,45 @@ import (
"strings"
"time"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils"

"github.com/lyft/flytestdlib/logger"
v1 "k8s.io/api/core/v1"
v12 "k8s.io/apimachinery/pkg/apis/meta/v1"

pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils"
)

const PodKind = "pod"
const OOMKilled = "OOMKilled"
const Interrupted = "Interrupted"
const SIGKILL = 137

// Updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
func UpdatePod(taskExecutionMetadata pluginsCore.TaskExecutionMetadata,
resourceRequirements []v1.ResourceRequirements, podSpec *v1.PodSpec) {
if len(podSpec.RestartPolicy) == 0 {
podSpec.RestartPolicy = v1.RestartPolicyNever
}
podSpec.Tolerations = append(
GetPodTolerations(taskExecutionMetadata.IsInterruptible(), resourceRequirements...), podSpec.Tolerations...)
if len(podSpec.ServiceAccountName) == 0 {
podSpec.ServiceAccountName = taskExecutionMetadata.GetK8sServiceAccount()
}
if len(podSpec.SchedulerName) == 0 {
podSpec.SchedulerName = config.GetK8sPluginConfig().SchedulerName
}
podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().DefaultNodeSelector)
if taskExecutionMetadata.IsInterruptible() {
podSpec.NodeSelector = utils.UnionMaps(podSpec.NodeSelector, config.GetK8sPluginConfig().InterruptibleNodeSelector)
}
if podSpec.Affinity == nil {
podSpec.Affinity = config.GetK8sPluginConfig().DefaultAffinity
}
}

func ToK8sPodSpec(ctx context.Context, taskExecutionMetadata pluginsCore.TaskExecutionMetadata, taskReader pluginsCore.TaskReader,
inputs io.InputReader, outputPaths io.OutputFilePaths) (*v1.PodSpec, error) {
task, err := taskReader.Read(ctx)
Expand All @@ -40,21 +64,10 @@ func ToK8sPodSpec(ctx context.Context, taskExecutionMetadata pluginsCore.TaskExe
containers := []v1.Container{
*c,
}

pod := &v1.PodSpec{
// We could specify Scheduler, Affinity, nodename etc
RestartPolicy: v1.RestartPolicyNever,
Containers: containers,
Tolerations: GetPodTolerations(taskExecutionMetadata.IsInterruptible(), c.Resources),
ServiceAccountName: taskExecutionMetadata.GetK8sServiceAccount(),
SchedulerName: config.GetK8sPluginConfig().SchedulerName,
NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
Affinity: config.GetK8sPluginConfig().DefaultAffinity,
}

if taskExecutionMetadata.IsInterruptible() {
pod.NodeSelector = utils.UnionMaps(pod.NodeSelector, config.GetK8sPluginConfig().InterruptibleNodeSelector)
Containers: containers,
}
UpdatePod(taskExecutionMetadata, []v1.ResourceRequirements{c.Resources}, pod)

if err := AddCoPilotToPod(ctx, config.GetK8sPluginConfig().CoPilot, pod, task.GetInterface(), taskExecutionMetadata, inputs, outputPaths, task.GetContainer().GetDataConfig()); err != nil {
return nil, err
Expand Down
57 changes: 55 additions & 2 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,68 @@ func dummyInputReader() io.InputReader {
return inputReader
}

func TestToK8sPodIterruptible(t *testing.T) {
ctx := context.TODO()
func TestPodSetup(t *testing.T) {
configAccessor := viper.NewAccessor(config1.Options{
StrictMode: true,
SearchPaths: []string{"testdata/config.yaml"},
})
err := configAccessor.UpdateConfig(context.TODO())
assert.NoError(t, err)

t.Run("UpdatePod", updatePod)
t.Run("ToK8sPodInterruptible", toK8sPodInterruptible)
}

func updatePod(t *testing.T) {
taskExecutionMetadata := dummyTaskExecutionMetadata(&v1.ResourceRequirements{
Limits: v1.ResourceList{
v1.ResourceCPU: resource.MustParse("1024m"),
v1.ResourceStorage: resource.MustParse("100M"),
},
Requests: v1.ResourceList{
v1.ResourceCPU: resource.MustParse("1024m"),
v1.ResourceStorage: resource.MustParse("100M"),
},
})

pod := &v1.Pod{
Spec: v1.PodSpec{
Tolerations: []v1.Toleration{
{
Key: "my toleration key",
Value: "my toleration value",
},
},
NodeSelector: map[string]string{
"user": "also configured",
},
},
}
UpdatePod(taskExecutionMetadata, []v1.ResourceRequirements{}, &pod.Spec)
assert.Equal(t, v1.RestartPolicyNever, pod.Spec.RestartPolicy)
for _, tol := range pod.Spec.Tolerations {
if tol.Key == "x/flyte" {
assert.Equal(t, tol.Value, "interruptible")
assert.Equal(t, tol.Operator, v1.TolerationOperator("Equal"))
assert.Equal(t, tol.Effect, v1.TaintEffect("NoSchedule"))
} else if tol.Key == "my toleration key" {
assert.Equal(t, tol.Value, "my toleration value")
} else {
t.Fatalf("unexpected toleration [%+v]", tol)
}
}
assert.Equal(t, "service-account", pod.Spec.ServiceAccountName)
assert.Equal(t, "flyte-scheduler", pod.Spec.SchedulerName)
assert.Len(t, pod.Spec.Tolerations, 2)
assert.EqualValues(t, map[string]string{
"x/interruptible": "true",
"user": "also configured",
}, pod.Spec.NodeSelector)
}

func toK8sPodInterruptible(t *testing.T) {
ctx := context.TODO()

op := &pluginsIOMock.OutputFilePaths{}
op.On("GetOutputPrefixPath").Return(storage.DataReference(""))
op.On("GetRawOutputPrefix").Return(storage.DataReference(""))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ plugins:
- container
# All k8s plugins default configuration
k8s:
scheduler-name: flyte-scheduler
default-annotations:
- annotationKey1: annotationValue1
- annotationKey2: annotationValue2
Expand Down
14 changes: 2 additions & 12 deletions flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ import (
"github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins"

"github.com/lyft/flyteplugins/go/tasks/pluginmachinery"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"

pluginsCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s"

"github.com/lyft/flyteplugins/go/tasks/errors"
Expand Down Expand Up @@ -61,15 +59,7 @@ func validateAndFinalizePod(

}
pod.Spec.Containers = finalizedContainers
if pod.Spec.Tolerations == nil {
pod.Spec.Tolerations = make([]k8sv1.Toleration, 0)
}
pod.Spec.Tolerations = append(
flytek8s.GetPodTolerations(taskCtx.TaskExecutionMetadata().IsInterruptible(), resReqs...), pod.Spec.Tolerations...)
if taskCtx.TaskExecutionMetadata().IsInterruptible() && len(config.GetK8sPluginConfig().InterruptibleNodeSelector) > 0 {
pod.Spec.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector
}

flytek8s.UpdatePod(taskCtx.TaskExecutionMetadata(), resReqs, &pod.Spec)
return &pod, nil
}

Expand Down
22 changes: 10 additions & 12 deletions flyteplugins/go/tasks/plugins/k8s/sidecar/sidecar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,17 @@ func TestBuildSidecarResource(t *testing.T) {

// Assert user-specified tolerations don't get overridden
assert.Len(t, res.(*v1.Pod).Spec.Tolerations, 2)
expectedTolerations := []v1.Toleration{
{
Key: "flyte/gpu",
Operator: "Equal",
Value: "dedicated",
Effect: "NoSchedule",
},
{
Key: "my toleration key",
Value: "my toleration value",
},
for _, tol := range res.(*v1.Pod).Spec.Tolerations {
if tol.Key == "flyte/gpu" {
assert.Equal(t, tol.Value, "dedicated")
assert.Equal(t, tol.Operator, v1.TolerationOperator("Equal"))
assert.Equal(t, tol.Effect, v1.TaintEffect("NoSchedule"))
} else if tol.Key == "my toleration key" {
assert.Equal(t, tol.Value, "my toleration value")
} else {
t.Fatalf("unexpected toleration [%+v]", tol)
}
}
assert.EqualValues(t, expectedTolerations, res.(*v1.Pod).Spec.Tolerations)
}

func TestBuildSidecarResourceMissingPrimary(t *testing.T) {
Expand Down

0 comments on commit 15016e2

Please sign in to comment.