From 4f8556ede15df2352d809b10211c62d392f2aea6 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 12 May 2023 10:55:51 -0700 Subject: [PATCH] Add environment variables to execution spec (#556) * Add envs to execution spec Signed-off-by: Kevin Su * update Signed-off-by: Kevin Su * update Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * update idl Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su --------- Signed-off-by: Kevin Su --- flyteadmin/go.mod | 6 +- flyteadmin/go.sum | 12 +- .../manager/impl/execution_manager_test.go | 250 ++++++++++++++++++ flyteadmin/pkg/manager/impl/shared/iface.go | 2 + .../manager/impl/testutils/mock_requests.go | 1 + flyteadmin/pkg/manager/impl/util/shared.go | 5 + .../interfaces/application_configuration.go | 16 ++ .../workflowengine/impl/prepare_execution.go | 8 + .../impl/prepare_execution_test.go | 8 + 9 files changed, 299 insertions(+), 9 deletions(-) diff --git a/flyteadmin/go.mod b/flyteadmin/go.mod index 5674480a9..7b9c13c41 100644 --- a/flyteadmin/go.mod +++ b/flyteadmin/go.mod @@ -13,9 +13,9 @@ require ( github.com/cloudevents/sdk-go/v2 v2.8.0 github.com/coreos/go-oidc v2.2.1+incompatible github.com/evanphx/json-patch v4.12.0+incompatible - github.com/flyteorg/flyteidl v1.5.0 - github.com/flyteorg/flyteplugins v1.0.40 - github.com/flyteorg/flytepropeller v1.1.70 + github.com/flyteorg/flyteidl v1.5.3 + github.com/flyteorg/flyteplugins v1.0.56 + github.com/flyteorg/flytepropeller v1.1.87 github.com/flyteorg/flytestdlib v1.0.15 github.com/flyteorg/stow v0.3.6 github.com/ghodss/yaml v1.0.0 diff --git a/flyteadmin/go.sum b/flyteadmin/go.sum index 9ba7e3c50..1ef64e388 100644 --- a/flyteadmin/go.sum +++ b/flyteadmin/go.sum @@ -312,12 +312,12 @@ github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/flyteorg/flyteidl v1.5.0 h1:vdaA5Cg9eqi5NMuASSod/AE7RXlHvzdWjSL9abDyd/M= -github.com/flyteorg/flyteidl v1.5.0/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I= -github.com/flyteorg/flyteplugins v1.0.40 h1:RTsYingqmqr13qBbi4CB2ArXDHNHUOkAF+HTLJQiQ/s= -github.com/flyteorg/flyteplugins v1.0.40/go.mod h1:qyUPqVspLcLGJpKxVwHDWf+kBpOGuItOxCaF6zAmDio= -github.com/flyteorg/flytepropeller v1.1.70 h1:/d1qqz13rdVADM85ST70eerAdBstJJz9UUB/mNSZi0w= -github.com/flyteorg/flytepropeller v1.1.70/go.mod h1:MezHUJmgPzm4Pu8nIy6LLiEkxNA6buTQ7hInSqCViTY= +github.com/flyteorg/flyteidl v1.5.3 h1:qHyU9kvcxGIkXoloi768ayx9FHrs961dZC3WYziGGZA= +github.com/flyteorg/flyteidl v1.5.3/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I= +github.com/flyteorg/flyteplugins v1.0.56 h1:kBTDgTpdSi7wcptk4cMwz5vfh1MU82VaUMMboe1InXw= +github.com/flyteorg/flyteplugins v1.0.56/go.mod h1:aFCKSn8TPzxSAILIiogHtUnHlUCN9+y6Vf+r9R4KZDU= +github.com/flyteorg/flytepropeller v1.1.87 h1:Px7ASDjrWyeVrUb15qXmhw9QK7xPcFjL5Yetr2P6iGM= +github.com/flyteorg/flytepropeller v1.1.87/go.mod h1:rBTB2jJpSZL1SvbgyiVh5Cobh3Azi/FvawXMxqB/uvo= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= diff --git a/flyteadmin/pkg/manager/impl/execution_manager_test.go b/flyteadmin/pkg/manager/impl/execution_manager_test.go index a3c384312..7549e03f5 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager_test.go +++ b/flyteadmin/pkg/manager/impl/execution_manager_test.go @@ -1083,6 +1083,90 @@ func TestCreateExecutionOverwriteCache(t *testing.T) { } } +func TestCreateExecutionWithEnvs(t *testing.T) { + tests := []struct { + name string + task bool + envs []*core.KeyValuePair + want []*core.KeyValuePair + }{ + { + name: "LaunchPlanDefault", + task: false, + envs: nil, + want: nil, + }, + { + name: "LaunchPlanEnable", + task: false, + envs: []*core.KeyValuePair{{Key: "foo", Value: "bar"}}, + want: []*core.KeyValuePair{{Key: "foo", Value: "bar"}}, + }, + { + name: "TaskDefault", + task: false, + envs: nil, + want: nil, + }, + { + name: "TaskEnable", + task: true, + envs: []*core.KeyValuePair{{Key: "foo", Value: "bar"}}, + want: []*core.KeyValuePair{{Key: "foo", Value: "bar"}}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + request := testutils.GetExecutionRequest() + if tt.task { + request.Spec.LaunchPlan.ResourceType = core.ResourceType_TASK + } + request.Spec.Envs.Values = tt.envs + + repository := getMockRepositoryForExecTest() + setDefaultLpCallbackForExecTest(repository) + setDefaultTaskCallbackForExecTest(repository) + + exCreateFunc := func(ctx context.Context, input models.Execution) error { + var spec admin.ExecutionSpec + err := proto.Unmarshal(input.Spec, &spec) + assert.Nil(t, err) + + if tt.task { + assert.Equal(t, uint(0), input.LaunchPlanID) + assert.NotEqual(t, uint(0), input.TaskID) + } else { + assert.NotEqual(t, uint(0), input.LaunchPlanID) + assert.Equal(t, uint(0), input.TaskID) + } + if len(tt.envs) != 0 { + assert.Equal(t, tt.envs[0].Key, spec.GetEnvs().Values[0].Key) + assert.Equal(t, tt.envs[0].Value, spec.GetEnvs().Values[0].Value) + } else { + assert.Nil(t, spec.GetEnvs().GetValues()) + } + + return nil + } + + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback(exCreateFunc) + mockExecutor := workflowengineMocks.WorkflowExecutor{} + mockExecutor.OnExecuteMatch(mock.Anything, mock.Anything, mock.Anything).Return(workflowengineInterfaces.ExecutionResponse{}, nil) + mockExecutor.OnID().Return("testMockExecutor") + r := plugins.NewRegistry() + r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &mockExecutor) + execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + + _, err := execManager.CreateExecution(context.Background(), request, requestedAt) + assert.Nil(t, err) + }) + } +} + func makeExecutionGetFunc( t *testing.T, closureBytes []byte, startTime *time.Time) repositoryMocks.GetExecutionFunc { return func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { @@ -1205,6 +1289,39 @@ func makeExecutionOverwriteCacheGetFunc( } } +func makeExecutionWithEnvs( + t *testing.T, closureBytes []byte, startTime *time.Time, envs []*core.KeyValuePair) repositoryMocks.GetExecutionFunc { + return func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { + assert.Equal(t, "project", input.Project) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, "name", input.Name) + + request := testutils.GetExecutionRequest() + request.Spec.Envs.Values = envs + + specBytes, err := proto.Marshal(request.Spec) + assert.Nil(t, err) + + return models.Execution{ + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, + BaseModel: models.BaseModel{ + ID: uint(8), + }, + Spec: specBytes, + Phase: core.WorkflowExecution_QUEUED.String(), + Closure: closureBytes, + LaunchPlanID: uint(1), + WorkflowID: uint(2), + StartedAt: startTime, + Cluster: testCluster, + }, nil + } +} + func TestRelaunchExecution(t *testing.T) { // Set up mocks. repository := getMockRepositoryForExecTest() @@ -1518,6 +1635,58 @@ func TestRelaunchExecutionOverwriteCacheOverride(t *testing.T) { }) } +func TestRelaunchExecutionEnvsOverride(t *testing.T) { + // Set up mocks. + repository := getMockRepositoryForExecTest() + setDefaultLpCallbackForExecTest(repository) + mockExecutor := workflowengineMocks.WorkflowExecutor{} + mockExecutor.OnExecuteMatch(mock.Anything, mock.Anything, mock.Anything).Return(workflowengineInterfaces.ExecutionResponse{}, nil) + mockExecutor.OnID().Return("testMockExecutor") + r := plugins.NewRegistry() + r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &mockExecutor) + execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + startTime := time.Now() + startTimeProto, _ := ptypes.TimestampProto(startTime) + existingClosure := admin.ExecutionClosure{ + Phase: core.WorkflowExecution_RUNNING, + StartedAt: startTimeProto, + } + existingClosureBytes, _ := proto.Marshal(&existingClosure) + env := []*core.KeyValuePair{{Key: "foo", Value: "bar"}} + executionGetFunc := makeExecutionWithEnvs(t, existingClosureBytes, &startTime, env) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc) + + var createCalled bool + exCreateFunc := func(ctx context.Context, input models.Execution) error { + createCalled = true + assert.Equal(t, "relaunchy", input.Name) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, "project", input.Project) + assert.Equal(t, uint(8), input.SourceExecutionID) + var spec admin.ExecutionSpec + err := proto.Unmarshal(input.Spec, &spec) + assert.Nil(t, err) + assert.Equal(t, admin.ExecutionMetadata_RELAUNCH, spec.Metadata.Mode) + assert.Equal(t, int32(admin.ExecutionMetadata_RELAUNCH), input.Mode) + assert.NotNil(t, spec.GetEnvs()) + assert.Equal(t, spec.GetEnvs().Values[0].Key, env[0].Key) + assert.Equal(t, spec.GetEnvs().Values[0].Value, env[0].Value) + return nil + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback(exCreateFunc) + + _, err := execManager.RelaunchExecution(context.Background(), admin.ExecutionRelaunchRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Name: "relaunchy", + }, requestedAt) + assert.Nil(t, err) + assert.True(t, createCalled) +} + func TestRecoverExecution(t *testing.T) { // Set up mocks. repository := getMockRepositoryForExecTest() @@ -1879,6 +2048,67 @@ func TestRecoverExecutionOverwriteCacheOverride(t *testing.T) { assert.True(t, proto.Equal(expectedResponse, response)) } +func TestRecoverExecutionEnvsOverride(t *testing.T) { + // Set up mocks. + repository := getMockRepositoryForExecTest() + setDefaultLpCallbackForExecTest(repository) + mockExecutor := workflowengineMocks.WorkflowExecutor{} + mockExecutor.OnExecuteMatch(mock.Anything, mock.Anything, mock.Anything).Return(workflowengineInterfaces.ExecutionResponse{}, nil) + mockExecutor.OnID().Return("testMockExecutor") + r := plugins.NewRegistry() + r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &mockExecutor) + execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + startTime := time.Now() + startTimeProto, _ := ptypes.TimestampProto(startTime) + existingClosure := admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + StartedAt: startTimeProto, + } + existingClosureBytes, _ := proto.Marshal(&existingClosure) + env := []*core.KeyValuePair{{Key: "foo", Value: "bar"}} + executionGetFunc := makeExecutionWithEnvs(t, existingClosureBytes, &startTime, env) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc) + + exCreateFunc := func(ctx context.Context, input models.Execution) error { + assert.Equal(t, "recovered", input.Name) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, "project", input.Project) + assert.Equal(t, uint(8), input.SourceExecutionID) + var spec admin.ExecutionSpec + err := proto.Unmarshal(input.Spec, &spec) + assert.Nil(t, err) + assert.Equal(t, admin.ExecutionMetadata_RECOVERED, spec.Metadata.Mode) + assert.Equal(t, int32(admin.ExecutionMetadata_RECOVERED), input.Mode) + assert.NotNil(t, spec.GetEnvs()) + assert.Equal(t, spec.GetEnvs().GetValues()[0].Key, env[0].Key) + assert.Equal(t, spec.GetEnvs().GetValues()[0].Value, env[0].Value) + return nil + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback(exCreateFunc) + + // Issue request. + response, err := execManager.RecoverExecution(context.Background(), admin.ExecutionRecoverRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Name: "recovered", + }, requestedAt) + + // And verify response. + assert.Nil(t, err) + + expectedResponse := &admin.ExecutionCreateResponse{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "recovered", + }, + } + assert.True(t, proto.Equal(expectedResponse, response)) +} + func TestCreateWorkflowEvent(t *testing.T) { repository := repositoryMocks.NewMockRepository() startTime := time.Now() @@ -4208,6 +4438,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { requestMaxParallelism := int32(10) requestInterruptible := false requestOverwriteCache := false + requestEnvironmentVariables := []*core.KeyValuePair{{Key: "hello", Value: "world"}} launchPlanLabels := map[string]string{"launchPlanLabelKey": "launchPlanLabelValue"} launchPlanAnnotations := map[string]string{"launchPlanAnnotationKey": "launchPlanAnnotationValue"} @@ -4217,6 +4448,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { launchPlanMaxParallelism := int32(50) launchPlanInterruptible := true launchPlanOverwriteCache := true + launchPlanEnvironmentVariables := []*core.KeyValuePair{{Key: "foo", Value: "bar"}} applicationConfig := runtime.NewConfigurationProvider() @@ -4305,6 +4537,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { MaxParallelism: requestMaxParallelism, Interruptible: &wrappers.BoolValue{Value: requestInterruptible}, OverwriteCache: requestOverwriteCache, + Envs: &admin.Envs{Values: requestEnvironmentVariables}, }, } execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, nil) @@ -4316,6 +4549,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Equal(t, requestOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) assert.Equal(t, requestLabels, execConfig.GetLabels().Values) assert.Equal(t, requestAnnotations, execConfig.GetAnnotations().Values) + assert.Equal(t, requestEnvironmentVariables, execConfig.GetEnvs().Values) }) t.Run("request with partial config", func(t *testing.T) { request := &admin.ExecutionCreateRequest{ @@ -4343,6 +4577,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { MaxParallelism: launchPlanMaxParallelism, Interruptible: &wrappers.BoolValue{Value: launchPlanInterruptible}, OverwriteCache: launchPlanOverwriteCache, + Envs: &admin.Envs{Values: launchPlanEnvironmentVariables}, }, } execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, launchPlan) @@ -4354,6 +4589,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.True(t, proto.Equal(launchPlan.Spec.Annotations, execConfig.Annotations)) assert.Equal(t, requestOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) assert.Equal(t, requestLabels, execConfig.GetLabels().Values) + assert.Equal(t, launchPlanEnvironmentVariables, execConfig.GetEnvs().Values) }) t.Run("request with empty security context", func(t *testing.T) { request := &admin.ExecutionCreateRequest{ @@ -4381,6 +4617,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { MaxParallelism: launchPlanMaxParallelism, Interruptible: &wrappers.BoolValue{Value: launchPlanInterruptible}, OverwriteCache: launchPlanOverwriteCache, + Envs: &admin.Envs{Values: launchPlanEnvironmentVariables}, }, } execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, launchPlan) @@ -4391,6 +4628,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Equal(t, launchPlanK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) assert.Equal(t, launchPlanOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) assert.Equal(t, launchPlanLabels, execConfig.GetLabels().Values) + assert.Equal(t, launchPlanEnvironmentVariables, execConfig.GetEnvs().Values) }) t.Run("request with no config", func(t *testing.T) { request := &admin.ExecutionCreateRequest{ @@ -4413,6 +4651,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { MaxParallelism: launchPlanMaxParallelism, Interruptible: &wrappers.BoolValue{Value: launchPlanInterruptible}, OverwriteCache: launchPlanOverwriteCache, + Envs: &admin.Envs{Values: launchPlanEnvironmentVariables}, }, } execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, launchPlan) @@ -4424,6 +4663,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Equal(t, launchPlanOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) assert.Equal(t, launchPlanLabels, execConfig.GetLabels().Values) assert.Equal(t, launchPlanAnnotations, execConfig.GetAnnotations().Values) + assert.Equal(t, launchPlanEnvironmentVariables, execConfig.GetEnvs().Values) }) t.Run("launchplan with partial config", func(t *testing.T) { request := &admin.ExecutionCreateRequest{ @@ -4474,6 +4714,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Equal(t, rmOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) assert.Equal(t, rmLabels, execConfig.GetLabels().Values) assert.Equal(t, rmAnnotations, execConfig.GetAnnotations().Values) + assert.Nil(t, execConfig.GetEnvs()) }) t.Run("matchable resource partial config", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, @@ -4520,6 +4761,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Nil(t, execConfig.GetRawOutputDataConfig()) assert.Nil(t, execConfig.GetLabels()) assert.Equal(t, rmAnnotations, execConfig.GetAnnotations().Values) + assert.Nil(t, execConfig.GetEnvs()) }) t.Run("matchable resource with no config", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, @@ -4557,6 +4799,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Nil(t, execConfig.GetRawOutputDataConfig()) assert.Nil(t, execConfig.GetLabels()) assert.Nil(t, execConfig.GetAnnotations()) + assert.Nil(t, execConfig.GetEnvs()) }) t.Run("fetch security context from deprecated config", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, @@ -4599,6 +4842,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Nil(t, execConfig.GetRawOutputDataConfig()) assert.Nil(t, execConfig.GetLabels()) assert.Nil(t, execConfig.GetAnnotations()) + assert.Nil(t, execConfig.GetEnvs()) }) t.Run("matchable resource workflow resource", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, @@ -4652,6 +4896,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Nil(t, execConfig.GetRawOutputDataConfig()) assert.Nil(t, execConfig.GetLabels()) assert.Nil(t, execConfig.GetAnnotations()) + assert.Nil(t, execConfig.GetEnvs()) }) t.Run("matchable resource failure", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, @@ -4682,6 +4927,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Nil(t, execConfig.GetRawOutputDataConfig()) assert.Nil(t, execConfig.GetLabels()) assert.Nil(t, execConfig.GetAnnotations()) + assert.Nil(t, execConfig.GetEnvs()) }) t.Run("application configuration", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, @@ -4790,6 +5036,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { launchPlan := &admin.LaunchPlan{ Spec: &admin.LaunchPlanSpec{ Interruptible: &wrappers.BoolValue{Value: true}, + Envs: &admin.Envs{Values: []*core.KeyValuePair{{Key: "foo", Value: "bar"}}}, }, } @@ -4801,6 +5048,9 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Nil(t, execConfig.GetRawOutputDataConfig()) assert.Nil(t, execConfig.GetLabels()) assert.Nil(t, execConfig.GetAnnotations()) + assert.Equal(t, 1, len(execConfig.Envs.Values)) + assert.Equal(t, "foo", execConfig.Envs.Values[0].Key) + assert.Equal(t, "bar", execConfig.Envs.Values[0].Value) }) t.Run("launch plan with no interruptible override specified", func(t *testing.T) { request := &admin.ExecutionCreateRequest{ diff --git a/flyteadmin/pkg/manager/impl/shared/iface.go b/flyteadmin/pkg/manager/impl/shared/iface.go index 7baae65a1..0caba1d50 100644 --- a/flyteadmin/pkg/manager/impl/shared/iface.go +++ b/flyteadmin/pkg/manager/impl/shared/iface.go @@ -24,4 +24,6 @@ type WorkflowExecutionConfigInterface interface { GetInterruptible() *wrappers.BoolValue // GetOverwriteCache indicates a workflow should skip all its cached results and re-compute its output, overwriting any already stored data. GetOverwriteCache() bool + // GetEnvs defines environment variables to be set for the execution. + GetEnvs() *admin.Envs } diff --git a/flyteadmin/pkg/manager/impl/testutils/mock_requests.go b/flyteadmin/pkg/manager/impl/testutils/mock_requests.go index 2a8d47dd9..8b8473376 100644 --- a/flyteadmin/pkg/manager/impl/testutils/mock_requests.go +++ b/flyteadmin/pkg/manager/impl/testutils/mock_requests.go @@ -221,6 +221,7 @@ func GetExecutionRequest() admin.ExecutionCreateRequest { }, }, RawOutputDataConfig: &admin.RawOutputDataConfig{OutputLocationPrefix: "default_raw_output"}, + Envs: &admin.Envs{}, }, Inputs: &core.LiteralMap{ Literals: map[string]*core.Literal{ diff --git a/flyteadmin/pkg/manager/impl/util/shared.go b/flyteadmin/pkg/manager/impl/util/shared.go index bf9490473..31cf89608 100644 --- a/flyteadmin/pkg/manager/impl/util/shared.go +++ b/flyteadmin/pkg/manager/impl/util/shared.go @@ -325,5 +325,10 @@ func MergeIntoExecConfig(workflowExecConfig admin.WorkflowExecutionConfig, spec workflowExecConfig.OverwriteCache = spec.GetOverwriteCache() } + if (workflowExecConfig.GetEnvs() == nil || len(workflowExecConfig.GetEnvs().Values) == 0) && + (spec.GetEnvs() != nil && len(spec.GetEnvs().Values) > 0) { + workflowExecConfig.Envs = spec.GetEnvs() + } + return workflowExecConfig } diff --git a/flyteadmin/pkg/runtime/interfaces/application_configuration.go b/flyteadmin/pkg/runtime/interfaces/application_configuration.go index 16b1f921d..cf9bf2e9e 100644 --- a/flyteadmin/pkg/runtime/interfaces/application_configuration.go +++ b/flyteadmin/pkg/runtime/interfaces/application_configuration.go @@ -94,6 +94,9 @@ type ApplicationConfig struct { // Enabling will use Storage (s3/gcs/etc) to offload static parts of CRDs. UseOffloadedWorkflowClosure bool `json:"useOffloadedWorkflowClosure"` + + // Environment variables to be set for the execution. + Envs map[string]string `json:"envs,omitempty"` } func (a *ApplicationConfig) GetRoleNameKey() string { @@ -166,6 +169,19 @@ func (a *ApplicationConfig) GetOverwriteCache() bool { return a.OverwriteCache } +func (a *ApplicationConfig) GetEnvs() *admin.Envs { + var envs []*core.KeyValuePair + for k, v := range a.Envs { + envs = append(envs, &core.KeyValuePair{ + Key: k, + Value: v, + }) + } + return &admin.Envs{ + Values: envs, + } +} + // GetAsWorkflowExecutionConfig returns the WorkflowExecutionConfig as extracted from this object func (a *ApplicationConfig) GetAsWorkflowExecutionConfig() admin.WorkflowExecutionConfig { // These values should always be set as their fallback values equals to their zero value or nil, diff --git a/flyteadmin/pkg/workflowengine/impl/prepare_execution.go b/flyteadmin/pkg/workflowengine/impl/prepare_execution.go index f2a778e27..30b76fd9b 100644 --- a/flyteadmin/pkg/workflowengine/impl/prepare_execution.go +++ b/flyteadmin/pkg/workflowengine/impl/prepare_execution.go @@ -64,6 +64,14 @@ func addExecutionOverrides(taskPluginOverrides []*admin.PluginOverride, } executionConfig.OverwriteCache = workflowExecutionConfig.GetOverwriteCache() + + envs := make(map[string]string) + if workflowExecutionConfig.GetEnvs() != nil { + for _, v := range workflowExecutionConfig.GetEnvs().Values { + envs[v.Key] = v.Value + } + executionConfig.EnvironmentVariables = envs + } } if taskResources != nil { var requests = v1alpha1.TaskResourceSpec{} diff --git a/flyteadmin/pkg/workflowengine/impl/prepare_execution_test.go b/flyteadmin/pkg/workflowengine/impl/prepare_execution_test.go index 38e155636..60925c08b 100644 --- a/flyteadmin/pkg/workflowengine/impl/prepare_execution_test.go +++ b/flyteadmin/pkg/workflowengine/impl/prepare_execution_test.go @@ -166,6 +166,14 @@ func TestAddExecutionOverrides(t *testing.T) { addExecutionOverrides(nil, workflowExecutionConfig, nil, nil, workflow) assert.True(t, workflow.ExecutionConfig.OverwriteCache) }) + t.Run("Override environment variables", func(t *testing.T) { + workflowExecutionConfig := &admin.WorkflowExecutionConfig{ + Envs: &admin.Envs{Values: []*core.KeyValuePair{{Key: "key", Value: "value"}}}, + } + workflow := &v1alpha1.FlyteWorkflow{} + addExecutionOverrides(nil, workflowExecutionConfig, nil, nil, workflow) + assert.Equal(t, workflow.ExecutionConfig.EnvironmentVariables, map[string]string{"key": "value"}) + }) } func TestPrepareFlyteWorkflow(t *testing.T) {