Skip to content

Commit

Permalink
Add matchable plugin overrides for single task executions too (flyteo…
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Nov 30, 2020
1 parent 5f77327 commit efb9474
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 20 deletions.
22 changes: 16 additions & 6 deletions pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (m *ExecutionManager) addLabelsAndAnnotations(requestSpec *admin.ExecutionS
}

func (m *ExecutionManager) addPluginOverrides(ctx context.Context, executionID *core.WorkflowExecutionIdentifier,
workflowName, launchPlanName string, partiallyPopulatedInputs *workflowengineInterfaces.ExecuteWorkflowInput) error {
workflowName, launchPlanName string) ([]*admin.PluginOverride, error) {
override, err := m.resourceManager.GetResource(ctx, interfaces.ResourceRequest{
Project: executionID.Project,
Domain: executionID.Domain,
Expand All @@ -185,13 +185,13 @@ func (m *ExecutionManager) addPluginOverrides(ctx context.Context, executionID *
if err != nil {
ec, ok := err.(errors.FlyteAdminError)
if !ok || ec.Code() != codes.NotFound {
return err
return nil, err
}
}
if override != nil && override.Attributes != nil && override.Attributes.GetPluginOverrides() != nil {
partiallyPopulatedInputs.TaskPluginOverrides = override.Attributes.GetPluginOverrides().Overrides
return override.Attributes.GetPluginOverrides().Overrides, nil
}
return nil
return nil, nil
}

func (m *ExecutionManager) offloadInputs(ctx context.Context, literalMap *core.LiteralMap, identifier *core.WorkflowExecutionIdentifier, key string) (storage.DataReference, error) {
Expand Down Expand Up @@ -476,6 +476,14 @@ func (m *ExecutionManager) launchSingleTaskExecution(
executeTaskInputs.Annotations = request.Spec.Annotations.Values
}

overrides, err := m.addPluginOverrides(ctx, &workflowExecutionID, workflowExecutionID.Name, "")
if err != nil {
return nil, nil, err
}
if overrides != nil {
executeTaskInputs.TaskPluginOverrides = overrides
}

execInfo, err := m.workflowExecutor.ExecuteTask(ctx, executeTaskInputs)
if err != nil {
m.systemMetrics.PropellerFailures.Inc()
Expand Down Expand Up @@ -642,11 +650,13 @@ func (m *ExecutionManager) launchExecutionAndPrepareModel(
return nil, nil, err
}

err = m.addPluginOverrides(ctx, &workflowExecutionID, launchPlan.GetSpec().WorkflowId.Name, launchPlan.Id.Name,
&executeWorkflowInputs)
overrides, err := m.addPluginOverrides(ctx, &workflowExecutionID, launchPlan.GetSpec().WorkflowId.Name, launchPlan.Id.Name)
if err != nil {
return nil, nil, err
}
if overrides != nil {
executeWorkflowInputs.TaskPluginOverrides = overrides
}

execInfo, err := m.workflowExecutor.ExecuteWorkflow(ctx, executeWorkflowInputs)
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2157,10 +2157,10 @@ func TestAddPluginOverrides(t *testing.T) {
db, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), workflowengineMocks.NewMockExecutor(),
mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil)

err := execManager.(*ExecutionManager).addPluginOverrides(
context.Background(), executionID, workflowName, launchPlanName, &partiallyPopulatedInputs)
taskPluginOverrides, err := execManager.(*ExecutionManager).addPluginOverrides(
context.Background(), executionID, workflowName, launchPlanName)
assert.NoError(t, err)
assert.Len(t, partiallyPopulatedInputs.TaskPluginOverrides, 2)
assert.Len(t, taskPluginOverrides, 2)
for _, override := range partiallyPopulatedInputs.TaskPluginOverrides {
if override.TaskType == "python" {
assert.EqualValues(t, []string{"plugin a"}, override.PluginId)
Expand Down Expand Up @@ -2190,8 +2190,8 @@ func TestPluginOverrides_ResourceGetFailure(t *testing.T) {
db, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), workflowengineMocks.NewMockExecutor(),
mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil)

err := execManager.(*ExecutionManager).addPluginOverrides(
context.Background(), executionID, workflowName, launchPlanName, &workflowengineInterfaces.ExecuteWorkflowInput{})
_, err := execManager.(*ExecutionManager).addPluginOverrides(
context.Background(), executionID, workflowName, launchPlanName)
assert.Error(t, err, "uh oh")
}

Expand Down
1 change: 1 addition & 0 deletions pkg/workflowengine/impl/propeller_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ func (c *FlytePropeller) ExecuteTask(ctx context.Context, input interfaces.Execu
flyteWf.Labels = labels
annotations := addMapValues(input.Annotations, flyteWf.Annotations)
flyteWf.Annotations = annotations
addExecutionOverrides(input.TaskPluginOverrides, flyteWf)

/*
TODO(katrogan): uncomment once propeller has updated the flyte workflow CRD.
Expand Down
19 changes: 10 additions & 9 deletions pkg/workflowengine/interfaces/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ type ExecuteWorkflowInput struct {
}

type ExecuteTaskInput struct {
ExecutionID *core.WorkflowExecutionIdentifier
WfClosure core.CompiledWorkflowClosure
Inputs *core.LiteralMap
ReferenceName string
Auth *admin.AuthRole
AcceptedAt time.Time
Labels map[string]string
Annotations map[string]string
QueueingBudget time.Duration
ExecutionID *core.WorkflowExecutionIdentifier
WfClosure core.CompiledWorkflowClosure
Inputs *core.LiteralMap
ReferenceName string
Auth *admin.AuthRole
AcceptedAt time.Time
Labels map[string]string
Annotations map[string]string
QueueingBudget time.Duration
TaskPluginOverrides []*admin.PluginOverride
}

type TerminateWorkflowInput struct {
Expand Down

0 comments on commit efb9474

Please sign in to comment.