Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Override primary container name instead of flyte generated name (#340)
Browse files Browse the repository at this point in the history
Signed-off-by: byhsu <[email protected]>
Co-authored-by: byhsu <[email protected]>
  • Loading branch information
ByronHsu and ByronHsu authored Apr 17, 2023
1 parent 435436b commit f5f4182
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 23 deletions.
8 changes: 4 additions & 4 deletions go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,20 +255,20 @@ func ApplyFlytePodConfiguration(ctx context.Context, tCtx pluginsCore.TaskExecut

// ToK8sPodSpec builds a PodSpec and ObjectMeta based on the definition passed by the TaskExecutionContext. This
// involves parsing the raw PodSpec definition and applying all Flyte configuration options.
func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, *metav1.ObjectMeta, error) {
func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, *metav1.ObjectMeta, string, error) {
// build raw PodSpec and ObjectMeta
podSpec, objectMeta, primaryContainerName, err := BuildRawPod(ctx, tCtx)
if err != nil {
return nil, nil, err
return nil, nil, "", err
}

// add flyte configuration
podSpec, objectMeta, err = ApplyFlytePodConfiguration(ctx, tCtx, podSpec, objectMeta, primaryContainerName)
if err != nil {
return nil, nil, err
return nil, nil, "", err
}

return podSpec, objectMeta, nil
return podSpec, objectMeta, primaryContainerName, nil
}

// getBasePodTemplate attempts to retrieve the PodTemplate to use as the base for k8s Pod configuration. This value can
Expand Down
18 changes: 9 additions & 9 deletions go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ func toK8sPodInterruptible(t *testing.T) {
},
})

p, _, err := ToK8sPodSpec(ctx, x)
p, _, _, err := ToK8sPodSpec(ctx, x)
assert.NoError(t, err)
assert.Len(t, p.Tolerations, 2)
assert.Equal(t, "x/flyte", p.Tolerations[1].Key)
Expand Down Expand Up @@ -391,7 +391,7 @@ func TestToK8sPod(t *testing.T) {
},
})

p, _, err := ToK8sPodSpec(ctx, x)
p, _, _, err := ToK8sPodSpec(ctx, x)
assert.NoError(t, err)
assert.Equal(t, len(p.Tolerations), 1)
})
Expand All @@ -408,7 +408,7 @@ func TestToK8sPod(t *testing.T) {
},
})

p, _, err := ToK8sPodSpec(ctx, x)
p, _, _, err := ToK8sPodSpec(ctx, x)
assert.NoError(t, err)
assert.Equal(t, len(p.Tolerations), 0)
assert.Equal(t, "some-acceptable-name", p.Containers[0].Name)
Expand All @@ -435,7 +435,7 @@ func TestToK8sPod(t *testing.T) {
DefaultMemoryRequest: resource.MustParse("1024Mi"),
}))

p, _, err := ToK8sPodSpec(ctx, x)
p, _, _, err := ToK8sPodSpec(ctx, x)
assert.NoError(t, err)
assert.Equal(t, 1, len(p.NodeSelector))
assert.Equal(t, "myScheduler", p.SchedulerName)
Expand All @@ -452,7 +452,7 @@ func TestToK8sPod(t *testing.T) {
}))

x := dummyExecContext(&v1.ResourceRequirements{})
p, _, err := ToK8sPodSpec(ctx, x)
p, _, _, err := ToK8sPodSpec(ctx, x)
assert.NoError(t, err)
assert.NotNil(t, p.SecurityContext)
assert.Equal(t, *p.SecurityContext.RunAsGroup, v)
Expand All @@ -464,7 +464,7 @@ func TestToK8sPod(t *testing.T) {
EnableHostNetworkingPod: &enabled,
}))
x := dummyExecContext(&v1.ResourceRequirements{})
p, _, err := ToK8sPodSpec(ctx, x)
p, _, _, err := ToK8sPodSpec(ctx, x)
assert.NoError(t, err)
assert.True(t, p.HostNetwork)
})
Expand All @@ -475,15 +475,15 @@ func TestToK8sPod(t *testing.T) {
EnableHostNetworkingPod: &enabled,
}))
x := dummyExecContext(&v1.ResourceRequirements{})
p, _, err := ToK8sPodSpec(ctx, x)
p, _, _, err := ToK8sPodSpec(ctx, x)
assert.NoError(t, err)
assert.False(t, p.HostNetwork)
})

t.Run("skipSettingHostNetwork", func(t *testing.T) {
assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{}))
x := dummyExecContext(&v1.ResourceRequirements{})
p, _, err := ToK8sPodSpec(ctx, x)
p, _, _, err := ToK8sPodSpec(ctx, x)
assert.NoError(t, err)
assert.False(t, p.HostNetwork)
})
Expand Down Expand Up @@ -517,7 +517,7 @@ func TestToK8sPod(t *testing.T) {
}))

x := dummyExecContext(&v1.ResourceRequirements{})
p, _, err := ToK8sPodSpec(ctx, x)
p, _, _, err := ToK8sPodSpec(ctx, x)
assert.NoError(t, err)
assert.NotNil(t, p.DNSConfig)
assert.Equal(t, []string{"8.8.8.8", "8.8.4.4"}, p.DNSConfig.Nameservers)
Expand Down
6 changes: 2 additions & 4 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,15 @@ func GetLogs(taskType string, name string, namespace string,
return taskLogs, nil
}

func OverrideDefaultContainerName(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec,
defaultContainerName string) {
func OverridePrimaryContainerName(podSpec *v1.PodSpec, primaryContainerName string, defaultContainerName string) {
// Pytorch operator forces pod to have container named 'pytorch'
// https://github.com/kubeflow/pytorch-operator/blob/037cd1b18eb77f657f2a4bc8a8334f2a06324b57/pkg/apis/pytorch/validation/validation.go#L54-L62
// Tensorflow operator forces pod to have container named 'tensorflow'
// https://github.com/kubeflow/tf-operator/blob/984adc287e6fe82841e4ca282dc9a2cbb71e2d4a/pkg/apis/tensorflow/validation/validation.go#L55-L63
// hence we have to override the name set here
// https://github.com/flyteorg/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/pluginmachinery/flytek8s/container_helper.go#L116
flyteDefaultContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
for idx, c := range podSpec.Containers {
if c.Name == flyteDefaultContainerName {
if c.Name == primaryContainerName {
podSpec.Containers[idx].Name = defaultContainerName
return
}
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas()
slots := mpiTaskExtraArgs.GetSlots()

podSpec, objectMeta, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}
common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.MPIJobDefaultContainerName)
common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.MPIJobDefaultContainerName)

// workersPodSpec is deepCopy of podSpec submitted by flyte
// WorkerPodSpec doesn't need any Argument & command. It will be trigger from launcher pod
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
}

podSpec, objectMeta, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}
common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.PytorchJobDefaultContainerName)
common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.PytorchJobDefaultContainerName)

workers := pytorchTaskExtraArgs.GetWorkers()
if workers == 0 {
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
}

podSpec, objectMeta, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}
common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.TFJobDefaultContainerName)
common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.TFJobDefaultContainerName)

workers := tensorflowTaskExtraArgs.GetWorkers()
psReplicas := tensorflowTaskExtraArgs.GetPsReplicas()
Expand Down

0 comments on commit f5f4182

Please sign in to comment.