Skip to content

Commit

Permalink
Refactor container task, pod task and respective array task plugin im…
Browse files Browse the repository at this point in the history
…pls (flyteorg#195)
  • Loading branch information
Katrina Rogan authored Aug 13, 2021
1 parent 81416b9 commit 7d44700
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 93 deletions.
77 changes: 49 additions & 28 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/container_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,53 +105,74 @@ func ApplyResourceOverrides(ctx context.Context, resources v1.ResourceRequiremen
return &resources
}

// Returns a K8s Container for the execution
// Transforms a task template target of type core.Container into a bare-bones kubernetes container, which can be further
// modified with flyte-specific customizations specified by various static and run-time attributes.
func ToK8sContainer(ctx context.Context, taskContainer *core.Container, iFace *core.TypedInterface, parameters template.Parameters) (*v1.Container, error) {
modifiedCommand, err := template.Render(ctx, taskContainer.GetCommand(), parameters)
if err != nil {
return nil, err
}

modifiedArgs, err := template.Render(ctx, taskContainer.GetArgs(), parameters)
if err != nil {
return nil, err
}

envVars := DecorateEnvVars(ctx, ToK8sEnvVar(taskContainer.GetEnv()), parameters.TaskExecMetadata.GetTaskExecutionID())

// Perform preliminary validations
if parameters.TaskExecMetadata.GetOverrides() == nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "platform/compiler error, overrides not set for task")
}
if parameters.TaskExecMetadata.GetOverrides() == nil || parameters.TaskExecMetadata.GetOverrides().GetResources() == nil {
return nil, errors.Errorf(errors.BadTaskSpecification, "resource requirements not found for container task, required!")
}

res := parameters.TaskExecMetadata.GetOverrides().GetResources()
if res != nil {
res = ApplyResourceOverrides(ctx, *res)
}

// Make the container name the same as the pod name, unless it violates K8s naming conventions
// Container names are subject to the DNS-1123 standard
containerName := parameters.TaskExecMetadata.GetTaskExecutionID().GetGeneratedName()
if errs := validation.IsDNS1123Label(containerName); len(errs) > 0 {
containerName = rand.String(4)
}
c := &v1.Container{
container := &v1.Container{
Name: containerName,
Image: taskContainer.GetImage(),
Args: modifiedArgs,
Command: modifiedCommand,
Env: envVars,
Args: taskContainer.GetArgs(),
Command: taskContainer.GetCommand(),
Env: ToK8sEnvVar(taskContainer.GetEnv()),
TerminationMessagePolicy: v1.TerminationMessageFallbackToLogsOnError,
}
if err := AddCoPilotToContainer(ctx, config.GetK8sPluginConfig().CoPilot, container, iFace, taskContainer.DataConfig); err != nil {
return nil, err
}
return container, nil
}

type ResourceCustomizationMode int

if res != nil {
c.Resources = *res
const (
AssignResources ResourceCustomizationMode = iota
MergeExistingResources
LeaveResourcesUnmodified
)

// Takes a container definition which specifies how to run a Flyte task and fills in templated command and argument
// values, updates resources and decorates environment variables with platform and task-specific customizations.
func AddFlyteCustomizationsToContainer(ctx context.Context, parameters template.Parameters,
mode ResourceCustomizationMode, container *v1.Container) error {
modifiedCommand, err := template.Render(ctx, container.Command, parameters)
if err != nil {
return err
}
container.Command = modifiedCommand

if err := AddCoPilotToContainer(ctx, config.GetK8sPluginConfig().CoPilot, c, iFace, taskContainer.DataConfig); err != nil {
return nil, err
modifiedArgs, err := template.Render(ctx, container.Args, parameters)
if err != nil {
return err
}
container.Args = modifiedArgs

container.Env = DecorateEnvVars(ctx, container.Env, parameters.TaskExecMetadata.GetTaskExecutionID())

if parameters.TaskExecMetadata.GetOverrides() != nil && parameters.TaskExecMetadata.GetOverrides().GetResources() != nil {
res := parameters.TaskExecMetadata.GetOverrides().GetResources()
switch mode {
case AssignResources:
if res = ApplyResourceOverrides(ctx, *res); res != nil {
container.Resources = *res
}
case MergeExistingResources:
MergeResources(*res, &container.Resources)
container.Resources = *ApplyResourceOverrides(ctx, container.Resources)
case LeaveResourcesUnmodified:
}
}
return c, nil
return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ import (
"context"
"testing"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"
mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flytestdlib/storage"
"github.com/stretchr/testify/mock"
"k8s.io/apimachinery/pkg/util/validation"

"github.com/stretchr/testify/assert"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
Expand Down Expand Up @@ -229,3 +237,131 @@ func TestMergeResources_PartialResourceKeys(t *testing.T) {
Limits: expectedResourceList,
})
}

func TestToK8sContainer(t *testing.T) {
taskContainer := &core.Container{
Image: "myimage",
Args: []string{
"arg1",
"arg2",
"arg3",
},
Command: []string{
"com1",
"com2",
"com3",
},
Env: []*core.KeyValuePair{
{
Key: "k",
Value: "v",
},
},
}

mockTaskExecMetadata := mocks.TaskExecutionMetadata{}
mockTaskOverrides := mocks.TaskOverrides{}
mockTaskOverrides.OnGetResources().Return(&v1.ResourceRequirements{
Limits: v1.ResourceList{
v1.ResourceEphemeralStorage: resource.MustParse("1024Mi"),
},
})
mockTaskExecMetadata.OnGetOverrides().Return(&mockTaskOverrides)
mockTaskExecutionID := mocks.TaskExecutionID{}
mockTaskExecutionID.OnGetGeneratedName().Return("gen_name")
mockTaskExecMetadata.OnGetTaskExecutionID().Return(&mockTaskExecutionID)

templateParameters := template.Parameters{
TaskExecMetadata: &mockTaskExecMetadata,
}

container, err := ToK8sContainer(context.TODO(), taskContainer, nil, templateParameters)
assert.NoError(t, err)
assert.Equal(t, container.Image, "myimage")
assert.EqualValues(t, []string{
"arg1",
"arg2",
"arg3",
}, container.Args)
assert.EqualValues(t, []string{
"com1",
"com2",
"com3",
}, container.Command)
assert.EqualValues(t, []v1.EnvVar{
{
Name: "k",
Value: "v",
},
}, container.Env)
errs := validation.IsDNS1123Label(container.Name)
assert.Nil(t, errs)
}

func TestAddFlyteCustomizationsToContainer(t *testing.T) {
mockTaskExecMetadata := mocks.TaskExecutionMetadata{}
mockTaskExecutionID := mocks.TaskExecutionID{}
mockTaskExecutionID.OnGetGeneratedName().Return("gen_name")
mockTaskExecutionID.OnGetID().Return(core.TaskExecutionIdentifier{
TaskId: &core.Identifier{
ResourceType: core.ResourceType_TASK,
Project: "p1",
Domain: "d1",
Name: "task_name",
Version: "v1",
},
NodeExecutionId: &core.NodeExecutionIdentifier{
NodeId: "node_id",
ExecutionId: &core.WorkflowExecutionIdentifier{
Project: "p2",
Domain: "d2",
Name: "n2",
},
},
RetryAttempt: 1,
})
mockTaskExecMetadata.OnGetTaskExecutionID().Return(&mockTaskExecutionID)

mockOverrides := mocks.TaskOverrides{}
mockOverrides.OnGetResources().Return(&v1.ResourceRequirements{
Requests: v1.ResourceList{
v1.ResourceEphemeralStorage: resource.MustParse("1024Mi"),
},
Limits: v1.ResourceList{
v1.ResourceEphemeralStorage: resource.MustParse("2048Mi"),
},
})
mockTaskExecMetadata.OnGetOverrides().Return(&mockOverrides)

mockInputReader := mocks2.InputReader{}
mockInputPath := storage.DataReference("s3://input/path")
mockInputReader.OnGetInputPath().Return(mockInputPath)
mockInputReader.OnGetInputPrefixPath().Return(mockInputPath)
mockInputReader.On("Get", mock.Anything).Return(nil, nil)

mockOutputPath := mocks2.OutputFilePaths{}
mockOutputPathPrefix := storage.DataReference("s3://output/path")
mockOutputPath.OnGetRawOutputPrefix().Return(mockOutputPathPrefix)
mockOutputPath.OnGetOutputPrefixPath().Return(mockOutputPathPrefix)

templateParameters := template.Parameters{
TaskExecMetadata: &mockTaskExecMetadata,
Inputs: &mockInputReader,
OutputPath: &mockOutputPath,
}
container := &v1.Container{
Command: []string{
"{{ .Input }}",
},
Args: []string{
"{{ .OutputPrefix }}",
},
}
err := AddFlyteCustomizationsToContainer(context.TODO(), templateParameters, AssignResources, container)
assert.NoError(t, err)
assert.EqualValues(t, container.Args, []string{"s3://output/path"})
assert.EqualValues(t, container.Command, []string{"s3://input/path"})
assert.Len(t, container.Resources.Limits, 3)
assert.Len(t, container.Resources.Requests, 3)
assert.Len(t, container.Env, 12)
}
9 changes: 7 additions & 2 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,17 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*
logger.Errorf(ctx, "Default Pod creation logic works for default container in the task template only.")
return nil, fmt.Errorf("container not specified in task template")
}
c, err := ToK8sContainer(ctx, task.GetContainer(), task.Interface, template.Parameters{
templateParameters := template.Parameters{
Task: tCtx.TaskReader(),
Inputs: tCtx.InputReader(),
OutputPath: tCtx.OutputWriter(),
TaskExecMetadata: tCtx.TaskExecutionMetadata(),
})
}
c, err := ToK8sContainer(ctx, task.GetContainer(), task.Interface, templateParameters)
if err != nil {
return nil, err
}
err = AddFlyteCustomizationsToContainer(ctx, templateParameters, AssignResources, c)
if err != nil {
return nil, err
}
Expand Down
43 changes: 8 additions & 35 deletions flyteplugins/go/tasks/plugins/array/k8s/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,40 +94,6 @@ func buildPodMapTask(task *idlCore.TaskTemplate, metadata core.TaskExecutionMeta
return pod, nil
}

// Here we customize the k8sPod primary container by templatizing args.
// The call to ToK8sPodSpec for the task container target
// case already handles this but we must explicitly do so for K8sPod task targets.
func modifyMapPodTaskPrimaryContainer(ctx context.Context, tCtx core.TaskExecutionContext, arrTCtx *arrayTaskContext, container *v1.Container) error {
var err error
container.Args, err = template.Render(ctx, container.Args,
template.Parameters{
TaskExecMetadata: tCtx.TaskExecutionMetadata(),
Inputs: arrTCtx.arrayInputReader,
OutputPath: tCtx.OutputWriter(),
Task: tCtx.TaskReader(),
})
if err != nil {
return err
}
container.Command, err = template.Render(ctx, container.Command,
template.Parameters{
TaskExecMetadata: tCtx.TaskExecutionMetadata(),
Inputs: arrTCtx.arrayInputReader,
OutputPath: tCtx.OutputWriter(),
Task: tCtx.TaskReader(),
})
if err != nil {
return err
}
resources := flytek8s.ApplyResourceOverrides(ctx, container.Resources)
if resources != nil {
container.Resources = *resources
}

container.Env = flytek8s.DecorateEnvVars(ctx, container.Env, tCtx.TaskExecutionMetadata().GetTaskExecutionID())
return nil
}

// Note that Name is not set on the result object.
// It's up to the caller to set the Name before creating the object in K8s.
func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionContext, namespaceTemplate string) (
Expand Down Expand Up @@ -193,7 +159,14 @@ func FlyteArrayJobToK8sPodTemplate(ctx context.Context, tCtx core.TaskExecutionC
if err != nil {
return v1.Pod{}, nil, err
}
err = modifyMapPodTaskPrimaryContainer(ctx, tCtx, arrTCtx, &pod.Spec.Containers[containerIndex])
templateParameters := template.Parameters{
TaskExecMetadata: tCtx.TaskExecutionMetadata(),
Inputs: arrTCtx.arrayInputReader,
OutputPath: tCtx.OutputWriter(),
Task: tCtx.TaskReader(),
}
err = flytek8s.AddFlyteCustomizationsToContainer(
ctx, templateParameters, flytek8s.MergeExistingResources, &pod.Spec.Containers[containerIndex])
if err != nil {
return v1.Pod{}, nil, err
}
Expand Down
23 changes: 18 additions & 5 deletions flyteplugins/go/tasks/plugins/array/k8s/transformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package k8s
import (
"context"
"encoding/json"
"fmt"
"testing"

"k8s.io/apimachinery/pkg/api/resource"
Expand Down Expand Up @@ -168,6 +169,16 @@ func TestFlyteArrayJobToK8sPodTemplate(t *testing.T) {
tMeta.OnGetOwnerReference().Return(v12.OwnerReference{})
tMeta.OnGetSecurityContext().Return(core.SecurityContext{})
tMeta.OnGetK8sServiceAccount().Return("sa")
mockResourceOverrides := mocks.TaskOverrides{}
mockResourceOverrides.OnGetResources().Return(&v1.ResourceRequirements{
Requests: v1.ResourceList{
"ephemeral-storage": resource.MustParse("1024Mi"),
},
Limits: v1.ResourceList{
"ephemeral-storage": resource.MustParse("2048Mi"),
},
})
tMeta.OnGetOverrides().Return(&mockResourceOverrides)
tID := &mocks.TaskExecutionID{}
tID.OnGetID().Return(core.TaskExecutionIdentifier{
NodeExecutionId: &core.NodeExecutionIdentifier{
Expand Down Expand Up @@ -214,14 +225,16 @@ func TestFlyteArrayJobToK8sPodTemplate(t *testing.T) {
defaultMemoryFromConfig := resource.MustParse("1024Mi")
assert.EqualValues(t, v1.ResourceRequirements{
Requests: v1.ResourceList{
v1.ResourceCPU: resource.MustParse("1"),
v1.ResourceMemory: defaultMemoryFromConfig,
v1.ResourceCPU: resource.MustParse("1"),
v1.ResourceMemory: defaultMemoryFromConfig,
v1.ResourceEphemeralStorage: resource.MustParse("1024Mi"),
},
Limits: v1.ResourceList{
v1.ResourceCPU: resource.MustParse("1"),
v1.ResourceMemory: defaultMemoryFromConfig,
v1.ResourceCPU: resource.MustParse("1"),
v1.ResourceMemory: defaultMemoryFromConfig,
v1.ResourceEphemeralStorage: resource.MustParse("2048Mi"),
},
}, pod.Spec.Containers[0].Resources)
}, pod.Spec.Containers[0].Resources, fmt.Sprintf("%+v", pod.Spec.Containers[0].Resources))
assert.EqualValues(t, []v1.EnvVar{
{
Name: "FLYTE_INTERNAL_EXECUTION_ID",
Expand Down
Loading

0 comments on commit 7d44700

Please sign in to comment.