diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 0eb13f6a8..c0020b642 100755 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -23,6 +23,8 @@ const PodKind = "pod" const OOMKilled = "OOMKilled" const Interrupted = "Interrupted" const SIGKILL = 137 +const defaultContainerTemplateName = "default" +const primaryContainerTemplateName = "primary" // ApplyInterruptibleNodeAffinity configures the node-affinity for the pod using the configuration specified. func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) { @@ -135,7 +137,7 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (* return pod, nil } -func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec) (*v1.Pod, error) { +func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryContainerName string) (*v1.Pod, error) { pod := v1.Pod{ TypeMeta: v12.TypeMeta{ Kind: PodKind, @@ -144,14 +146,59 @@ func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec) (*v1.Pod } if podTemplate != nil { + // merge template PodSpec basePodSpec := podTemplate.Template.Spec.DeepCopy() err := mergo.Merge(basePodSpec, podSpec, mergo.WithOverride, mergo.WithAppendSlice) if err != nil { return nil, err } - basePodSpec.Containers = podSpec.Containers + // merge template Containers + var mergedContainers []v1.Container + var defaultContainerTemplate, primaryContainerTemplate *v1.Container + for i := 0; i < len(podTemplate.Template.Spec.Containers); i++ { + if podTemplate.Template.Spec.Containers[i].Name == defaultContainerTemplateName { + defaultContainerTemplate = &podTemplate.Template.Spec.Containers[i] + } else if podTemplate.Template.Spec.Containers[i].Name == primaryContainerTemplateName { + primaryContainerTemplate = &podTemplate.Template.Spec.Containers[i] + } + } + + for _, container := range podSpec.Containers { + // if applicable start with defaultContainerTemplate + var mergedContainer *v1.Container + if defaultContainerTemplate != nil { + mergedContainer = defaultContainerTemplate.DeepCopy() + } + + // if applicable merge with primaryContainerTemplate + if container.Name == primaryContainerName && primaryContainerTemplate != nil { + if mergedContainer == nil { + mergedContainer = primaryContainerTemplate.DeepCopy() + } else { + err := mergo.Merge(mergedContainer, primaryContainerTemplate, mergo.WithOverride, mergo.WithAppendSlice) + if err != nil { + return nil, err + } + } + } + + // if applicable merge with existing container + if mergedContainer == nil { + mergedContainers = append(mergedContainers, container) + } else { + err := mergo.Merge(mergedContainer, container, mergo.WithOverride, mergo.WithAppendSlice) + if err != nil { + return nil, err + } + + mergedContainers = append(mergedContainers, *mergedContainer) + } + + } + // update Pod fields + basePodSpec.Containers = mergedContainers pod.ObjectMeta = podTemplate.Template.ObjectMeta pod.Spec = *basePodSpec } else { diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index a3b5380c1..3af65a413 100755 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -1016,6 +1016,14 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) { func TestBuildPodWithSpec(t *testing.T) { var priority int32 = 1 podSpec := v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Name: "foo", + }, + v1.Container{ + Name: "bar", + }, + }, NodeSelector: map[string]string{ "baz": "bar", }, @@ -1031,13 +1039,27 @@ func TestBuildPodWithSpec(t *testing.T) { }, } - pod, err := BuildPodWithSpec(nil, &podSpec) + pod, err := BuildPodWithSpec(nil, &podSpec, "foo") assert.Nil(t, err) assert.True(t, reflect.DeepEqual(pod.Spec, podSpec)) + defaultContainerTemplate := v1.Container{ + Name: defaultContainerTemplateName, + TerminationMessagePath: "/dev/default-termination-log", + } + + primaryContainerTemplate := v1.Container{ + Name: primaryContainerTemplateName, + TerminationMessagePath: "/dev/primary-termination-log", + } + podTemplate := v1.PodTemplate{ Template: v1.PodTemplateSpec{ Spec: v1.PodSpec{ + Containers: []v1.Container{ + defaultContainerTemplate, + primaryContainerTemplate, + }, HostNetwork: true, NodeSelector: map[string]string{ "foo": "bar", @@ -1052,7 +1074,7 @@ func TestBuildPodWithSpec(t *testing.T) { }, } - pod, err = BuildPodWithSpec(&podTemplate, &podSpec) + pod, err = BuildPodWithSpec(&podTemplate, &podSpec, "foo") assert.Nil(t, err) // validate a PodTemplate-only field @@ -1065,4 +1087,14 @@ func TestBuildPodWithSpec(t *testing.T) { assert.Equal(t, len(podTemplate.Template.Spec.NodeSelector)+len(podSpec.NodeSelector), len(pod.Spec.NodeSelector)) // validate an appended array assert.Equal(t, len(podTemplate.Template.Spec.Tolerations)+len(podSpec.Tolerations), len(pod.Spec.Tolerations)) + + // validate primary container + primaryContainer := pod.Spec.Containers[0] + assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name) + assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath) + + // validate default container + defaultContainer := pod.Spec.Containers[1] + assert.Equal(t, podSpec.Containers[1].Name, defaultContainer.Name) + assert.Equal(t, defaultContainerTemplate.TerminationMessagePath, defaultContainer.TerminationMessagePath) } diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/container.go b/flyteplugins/go/tasks/plugins/k8s/pod/container.go index 0ec8240ec..c457c8bc7 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/container.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/container.go @@ -5,6 +5,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteplugins/go/tasks/errors" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" @@ -27,6 +28,14 @@ func (containerPodBuilder) buildPodSpec(ctx context.Context, task *core.TaskTemp return podSpec, nil } +func (containerPodBuilder) getPrimaryContainerName(task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (string, error) { + primaryContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + if primaryContainerName == "" { + return "", errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification, missing generated name") + } + return primaryContainerName, nil +} + func (containerPodBuilder) updatePodMetadata(ctx context.Context, pod *v1.Pod, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) error { return nil } diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go index 9e0aea1cb..3a399641a 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/plugin.go @@ -34,6 +34,7 @@ var ( type podBuilder interface { buildPodSpec(ctx context.Context, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, error) + getPrimaryContainerName(task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (string, error) updatePodMetadata(ctx context.Context, pod *v1.Pod, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) error } @@ -61,7 +62,6 @@ func (p plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecu builder = p.defaultPodBuilder } - // build pod podSpec, err := builder.buildPodSpec(ctx, task, taskCtx) if err != nil { return nil, err @@ -70,7 +70,12 @@ func (p plugin) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecu podSpec.ServiceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) - pod, err := flytek8s.BuildPodWithSpec(podTemplate, podSpec) + primaryContainerName, err := builder.getPrimaryContainerName(task, taskCtx) + if err != nil { + return nil, err + } + + pod, err := flytek8s.BuildPodWithSpec(podTemplate, podSpec, primaryContainerName) if err != nil { return nil, err } diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/sidecar.go b/flyteplugins/go/tasks/plugins/k8s/pod/sidecar.go index 7a8495258..1c4cf0276 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/sidecar.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/sidecar.go @@ -77,19 +77,31 @@ func (sidecarPodBuilder) buildPodSpec(ctx context.Context, task *core.TaskTempla return &podSpec, nil } -func getPrimaryContainerNameFromConfig(task *core.TaskTemplate) (string, error) { - if len(task.GetConfig()) == 0 { - return "", errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification, config needs to be non-empty and include missing [%s] key", PrimaryContainerKey) - } +func (sidecarPodBuilder) getPrimaryContainerName(task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) (string, error) { + switch task.TaskTypeVersion { + case 0: + // Handles pod tasks when they are defined as Sidecar tasks and marshal the podspec using k8s proto. + sidecarJob := sidecarJob{} + err := utils.UnmarshalStructToObj(task.GetCustom(), &sidecarJob) + if err != nil { + return "", errors.Errorf(errors.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) + } - primaryContainerName, ok := task.GetConfig()[PrimaryContainerKey] - if !ok { - return "", errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification, config missing [%s] key in [%v]", PrimaryContainerKey, task.GetConfig()) - } + return sidecarJob.PrimaryContainerName, nil + default: + if len(task.GetConfig()) == 0 { + return "", errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification, config needs to be non-empty and include missing [%s] key", PrimaryContainerKey) + } - return primaryContainerName, nil + primaryContainerName, ok := task.GetConfig()[PrimaryContainerKey] + if !ok { + return "", errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification, config missing [%s] key in [%v]", PrimaryContainerKey, task.GetConfig()) + } + + return primaryContainerName, nil + } } func mergeMapInto(src map[string]string, dst map[string]string) { @@ -98,11 +110,10 @@ func mergeMapInto(src map[string]string, dst map[string]string) { } } -func (sidecarPodBuilder) updatePodMetadata(ctx context.Context, pod *v1.Pod, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) error { +func (s sidecarPodBuilder) updatePodMetadata(ctx context.Context, pod *v1.Pod, task *core.TaskTemplate, taskCtx pluginsCore.TaskExecutionContext) error { pod.Annotations = make(map[string]string) pod.Labels = make(map[string]string) - var primaryContainerName string switch task.TaskTypeVersion { case 0: // Handles pod tasks when they are defined as Sidecar tasks and marshal the podspec using k8s proto. @@ -114,32 +125,20 @@ func (sidecarPodBuilder) updatePodMetadata(ctx context.Context, pod *v1.Pod, tas mergeMapInto(sidecarJob.Annotations, pod.Annotations) mergeMapInto(sidecarJob.Labels, pod.Labels) - - primaryContainerName = sidecarJob.PrimaryContainerName - case 1: - // Handles pod tasks that marshal the pod spec to the task custom. - containerName, err := getPrimaryContainerNameFromConfig(task) - if err != nil { - return err - } - - primaryContainerName = containerName default: // Handles pod tasks that marshal the pod spec to the k8s_pod task target. - if task.GetK8SPod() == nil || task.GetK8SPod().Metadata != nil { + if task.GetK8SPod() != nil && task.GetK8SPod().Metadata != nil { mergeMapInto(task.GetK8SPod().Metadata.Annotations, pod.Annotations) mergeMapInto(task.GetK8SPod().Metadata.Labels, pod.Labels) } - - containerName, err := getPrimaryContainerNameFromConfig(task) - if err != nil { - return err - } - - primaryContainerName = containerName } // validate pod and update resource requirements + primaryContainerName, err := s.getPrimaryContainerName(task, taskCtx) + if err != nil { + return err + } + if err := validateAndFinalizePodSpec(ctx, taskCtx, primaryContainerName, &pod.Spec); err != nil { return err }