Skip to content

Commit

Permalink
Honor primaryContainerName in default pod spec (#300)
Browse files Browse the repository at this point in the history
## Overview
Fixes a bug where we weren't propagating primary container name when using the default pod spec.

## Test Plan
Updated unittests to cover this case

## Rollout Plan (if applicable)
Just merge for now. We can take out eventually. No one is depending on this.

## Upstream Changes
Should this change be upstreamed to OSS (flyteorg/flyte)? If so, please check this box for auditing. Note, this is the responsibility of each developer. See [this guide](https://unionai.atlassian.net/wiki/spaces/ENG/pages/447610883/Flyte+-+Union+Cloud+Development+Runbook/#When-are-versions-updated%3F).
- [ ] To be upstreamed
  • Loading branch information
andrewwdye authored May 28, 2024
1 parent 63b46a2 commit de87729
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 18 deletions.
3 changes: 1 addition & 2 deletions fasttask/plugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,10 @@ func (p *Plugin) getExecutionEnv(ctx context.Context, tCtx core.TaskExecutionCon
}

fastTaskEnvironmentSpec.PodTemplateSpec = podTemplateSpecBytes
fastTaskEnvironmentSpec.PrimaryContainerName = primaryContainerName
if err := utils.MarshalStruct(fastTaskEnvironmentSpec, environmentSpec); err != nil {
return nil, fmt.Errorf("unable to marshal EnvironmentSpec [%v], Err: [%v]", fastTaskEnvironmentSpec, err.Error())
}

fastTaskEnvironmentSpec.PrimaryContainerName = primaryContainerName
}

environment, err := executionEnvClient.Create(ctx, executionEnv.GetId(), environmentSpec)
Expand Down
63 changes: 47 additions & 16 deletions fasttask/plugin/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@ package plugin

import (
"context"
"encoding/json"
"testing"
"time"

"github.com/golang/protobuf/proto"
_struct "github.com/golang/protobuf/ptypes/struct"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/protobuf/types/known/structpb"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"

idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
coremocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s"
iomocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils"

Expand Down Expand Up @@ -82,7 +85,7 @@ func TestFinalize(t *testing.T) {
// initialize static execution context attributes
taskMetadata := &coremocks.TaskExecutionMetadata{}
taskExecutionID := &coremocks.TaskExecutionID{}
taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task_id", nil)
taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task-id", nil)
taskMetadata.OnGetTaskExecutionID().Return(taskExecutionID)

// create TaskExecutionContext
Expand All @@ -106,7 +109,7 @@ func TestFinalize(t *testing.T) {

// create FastTaskService mock
fastTaskService := &mocks.FastTaskService{}
fastTaskService.OnCleanup(ctx, "task_id", "foo", "w0").Return(nil)
fastTaskService.OnCleanup(ctx, "task-id", "foo", "w0").Return(nil)

// initialize plugin
plugin := &Plugin{
Expand All @@ -120,6 +123,8 @@ func TestFinalize(t *testing.T) {

func TestGetExecutionEnv(t *testing.T) {
ctx := context.TODO()
tCtx := &coremocks.TaskExecutionContext{}
tCtx.OnTaskReader().Return(&coremocks.TaskReader{})

expectedExtant := &pb.FastTaskEnvironment{
QueueId: "foo",
Expand All @@ -128,11 +133,19 @@ func TestGetExecutionEnv(t *testing.T) {
err := utils.MarshalStruct(expectedExtant, expectedExtantStruct)
assert.Nil(t, err)

toFastTaskSpec := func(spec *pb.FastTaskEnvironmentSpec) *structpb.Struct {
specStruct := &_struct.Struct{}
err := utils.MarshalStruct(spec, specStruct)
assert.Nil(t, err)
return specStruct
}

tests := []struct {
name string
fastTaskExtant *pb.FastTaskEnvironment
fastTaskSpec *pb.FastTaskEnvironmentSpec
clientGetExists bool
name string
fastTaskExtant *pb.FastTaskEnvironment
fastTaskSpec *pb.FastTaskEnvironmentSpec
clientGetExists bool
createExectionEnvMatcher interface{} // func (environmentSpec *structpb.Struct) bool
}{
{
name: "ExecutionExtant",
Expand All @@ -141,21 +154,36 @@ func TestGetExecutionEnv(t *testing.T) {
},
},
{
name: "ExecutionSpecExists",
fastTaskSpec: &pb.FastTaskEnvironmentSpec{},
clientGetExists: true,
name: "ExecutionSpecExists",
fastTaskSpec: &pb.FastTaskEnvironmentSpec{},
clientGetExists: true,
createExectionEnvMatcher: expectedExtantStruct,
},
{
name: "ExecutionSpecCreate",
fastTaskSpec: &pb.FastTaskEnvironmentSpec{
PodTemplateSpec: []byte("bar"),
},
clientGetExists: false,
createExectionEnvMatcher: toFastTaskSpec(
&pb.FastTaskEnvironmentSpec{
PodTemplateSpec: []byte("bar"),
},
),
},
{
name: "ExecutionSpecInjectPodTemplateAndCreate",
fastTaskSpec: &pb.FastTaskEnvironmentSpec{},
clientGetExists: false,
createExectionEnvMatcher: mock.MatchedBy(func(environmentSpec *structpb.Struct) bool {
spec := &pb.FastTaskEnvironmentSpec{}
err := utils.UnmarshalStruct(environmentSpec, spec)
assert.Nil(t, err)
var podTemplateSpec v1.PodTemplateSpec
err = json.Unmarshal(spec.GetPodTemplateSpec(), &podTemplateSpec)
assert.Nil(t, err)
return podTemplateSpec.Namespace == "test-namespace" && spec.GetPrimaryContainerName() == "task-id"
}),
},
}

Expand Down Expand Up @@ -189,9 +217,9 @@ func TestGetExecutionEnv(t *testing.T) {
},
},
})
taskExecutionID.OnGetGeneratedNameMatch().Return("task_id")
taskExecutionID.OnGetGeneratedNameMatch().Return("task-id")
taskMetadata.OnGetTaskExecutionID().Return(taskExecutionID)
taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task_id", nil)
taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task-id", nil)
taskMetadata.OnGetTaskExecutionID().Return(taskExecutionID)

taskOverrides := &coremocks.TaskOverrides{}
Expand All @@ -211,6 +239,9 @@ func TestGetExecutionEnv(t *testing.T) {
Args: []string{},
},
},
Config: map[string]string{
flytek8s.PrimaryContainerKey: "primary",
},
}

// create ExecutionEnvClient mock
Expand All @@ -220,7 +251,7 @@ func TestGetExecutionEnv(t *testing.T) {
} else {
executionEnvClient.OnGetMatch(ctx, mock.Anything).Return(nil)
}
executionEnvClient.OnCreateMatch(ctx, "foo", mock.Anything).Return(expectedExtantStruct, nil)
executionEnvClient.OnCreateMatch(ctx, "foo", test.createExectionEnvMatcher).Return(expectedExtantStruct, nil)

// create TaskExecutionContext
tCtx := &coremocks.TaskExecutionContext{}
Expand Down Expand Up @@ -303,7 +334,7 @@ func TestHandleNotYetStarted(t *testing.T) {
Name: "execution_id",
})
taskExecutionID := &coremocks.TaskExecutionID{}
taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task_id", nil)
taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task-id", nil)
taskMetadata.OnGetTaskExecutionID().Return(taskExecutionID)

for _, test := range tests {
Expand Down Expand Up @@ -341,7 +372,7 @@ func TestHandleNotYetStarted(t *testing.T) {

// create FastTaskService mock
fastTaskService := &mocks.FastTaskService{}
fastTaskService.OnOfferOnQueue(ctx, "foo", "task_id", "namespace", "execution_id", []string{}).Return(test.workerID, nil)
fastTaskService.OnOfferOnQueue(ctx, "foo", "task-id", "namespace", "execution_id", []string{}).Return(test.workerID, nil)

// initialize plugin
plugin := &Plugin{
Expand Down Expand Up @@ -426,7 +457,7 @@ func TestHandleRunning(t *testing.T) {
Name: "execution_id",
})
taskExecutionID := &coremocks.TaskExecutionID{}
taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task_id", nil)
taskExecutionID.OnGetGeneratedNameWithMatch(mock.Anything, mock.Anything).Return("task-id", nil)
taskMetadata.OnGetTaskExecutionID().Return(taskExecutionID)

for _, test := range tests {
Expand Down Expand Up @@ -463,7 +494,7 @@ func TestHandleRunning(t *testing.T) {

// create FastTaskService mock
fastTaskService := &mocks.FastTaskService{}
fastTaskService.OnCheckStatusMatch(ctx, "task_id", "foo", "w0").Return(test.taskStatusPhase, "", test.checkStatusError)
fastTaskService.OnCheckStatusMatch(ctx, "task-id", "foo", "w0").Return(test.taskStatusPhase, "", test.checkStatusError)

// initialize plugin
plugin := &Plugin{
Expand Down

0 comments on commit de87729

Please sign in to comment.