Skip to content

Commit

Permalink
Introduce default plugin for task type (flyteorg#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Nov 2, 2020
1 parent a2b0dd1 commit 7e4ed05
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 87 deletions.
3 changes: 3 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ tasks:
- container
- K8S-ARRAY
- qubole-hive-executor
default-for-task-type:
- container-array: k8s-array
- presto: my-presto
# Uncomment to enable sagemaker plugin
# - sagemaker_training
# - sagemaker_hyperparameter_tuning
Expand Down
5 changes: 1 addition & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ require (
github.com/Azure/go-autorest/autorest v0.10.0 // indirect
github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1
github.com/coreos/etcd v3.3.15+incompatible // indirect
github.com/coreos/go-oidc v2.2.1+incompatible // indirect
github.com/fatih/color v1.9.0
github.com/ghodss/yaml v1.0.0
Expand All @@ -23,11 +22,10 @@ require (
github.com/imdario/mergo v0.3.8 // indirect
github.com/lyft/datacatalog v0.2.1
github.com/lyft/flyteidl v0.18.9
github.com/lyft/flyteplugins v0.5.12
github.com/lyft/flyteplugins v0.5.14
github.com/lyft/flytestdlib v0.3.9
github.com/magiconair/properties v1.8.1
github.com/mattn/go-colorable v0.1.6 // indirect
github.com/mitchellh/go-ps v1.0.0 // indirect
github.com/mitchellh/mapstructure v1.1.2
github.com/ncw/swift v1.0.50 // indirect
github.com/pkg/errors v0.9.1
Expand All @@ -49,7 +47,6 @@ require (
k8s.io/kube-openapi v0.0.0-20200204173128-addea2498afe // indirect
k8s.io/utils v0.0.0-20200229041039-0a110f9eb7ab // indirect
sigs.k8s.io/controller-runtime v0.5.1
sigs.k8s.io/testing_frameworks v0.1.2 // indirect
sigs.k8s.io/yaml v1.2.0 // indirect
)

Expand Down
45 changes: 4 additions & 41 deletions go.sum

Large diffs are not rendered by default.

73 changes: 63 additions & 10 deletions pkg/controller/nodes/task/config/config.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package config

import (
"context"
"fmt"
"strings"
"time"

"github.com/lyft/flytestdlib/config"
"github.com/lyft/flytestdlib/logger"
"k8s.io/apimachinery/pkg/util/sets"

"github.com/lyft/flytestdlib/config"
)

//go:generate pflags Config --default-var defaultConfig
Expand All @@ -14,7 +18,7 @@ const SectionKey = "tasks"

var (
defaultConfig = &Config{
TaskPlugins: TaskPluginConfig{EnabledPlugins: []string{}},
TaskPlugins: TaskPluginConfig{EnabledPlugins: []string{}, DefaultForTaskTypes: map[string]string{}},
MaxPluginPhaseVersions: 100000,
BarrierConfig: BarrierConfig{
Enabled: true,
Expand Down Expand Up @@ -46,22 +50,71 @@ type BarrierConfig struct {
}

type TaskPluginConfig struct {
EnabledPlugins []string `json:"enabled-plugins" pflag:",Plugins enabled currently"`
EnabledPlugins []string `json:"enabled-plugins" pflag:",deprecated"`
// Maps task types to their plugin handler (by ID).
DefaultForTaskTypes map[string]string `json:"default-for-task-types" pflag:"-,"`
}

type BackOffConfig struct {
BaseSecond int `json:"base-second" pflag:",The number of seconds representing the base duration of the exponential backoff"`
MaxDuration config.Duration `json:"max-duration" pflag:",The cap of the backoff duration"`
}

func (p TaskPluginConfig) GetEnabledPluginsSet() sets.String {
s := sets.NewString()
for _, e := range p.EnabledPlugins {
cleanedPluginName := strings.Trim(e, " ")
cleanedPluginName = strings.ToLower(cleanedPluginName)
s.Insert(cleanedPluginName)
type PluginID = string
type TaskType = string

// Contains the set of enabled plugins for this flytepropeller deployment along with default plugin handlers
// for specific task types.
type PluginsConfigMeta struct {
EnabledPlugins sets.String
AllDefaultForTaskTypes map[PluginID][]TaskType
}

func cleanString(source string) string {
cleaned := strings.Trim(source, " ")
cleaned = strings.ToLower(cleaned)
return cleaned
}

func (p TaskPluginConfig) GetEnabledPlugins() (PluginsConfigMeta, error) {
enabledPluginsNames := sets.NewString()
for _, pluginName := range p.EnabledPlugins {
cleanedPluginName := cleanString(pluginName)
enabledPluginsNames.Insert(cleanedPluginName)
}

pluginDefaultForTaskType := make(map[PluginID][]TaskType)
// Reverse the DefaultForTaskTypes map. Having the config use task type as a key guarantees only one default plugin can be specified per
// task type but now we need to sort for which tasks a plugin needs to be the default.
for taskName, pluginName := range p.DefaultForTaskTypes {
existing, found := pluginDefaultForTaskType[pluginName]
if !found {
existing = make([]string, 0, 1)
}
pluginDefaultForTaskType[cleanString(pluginName)] = append(existing, cleanString(taskName))
}

// All plugins are enabled, nothing further to validate here.
if enabledPluginsNames.Len() == 0 {
return PluginsConfigMeta{
EnabledPlugins: enabledPluginsNames,
AllDefaultForTaskTypes: pluginDefaultForTaskType,
}, nil
}

// Finally, validate that default plugins for task types only reference enabled plugins
for pluginName, taskTypes := range pluginDefaultForTaskType {
if !enabledPluginsNames.Has(pluginName) {
logger.Errorf(context.TODO(), "Cannot set default plugin [%s] for task types [%+v] when it is not "+
"configured to be an enabled plugin. Please double check the flytepropeller config.", pluginName, taskTypes)
return PluginsConfigMeta{}, fmt.Errorf("cannot set default plugin [%s] for task types [%+v] when it is not "+
"configured to be an enabled plugin", pluginName, taskTypes)
}
}
return s
return PluginsConfigMeta{
EnabledPlugins: enabledPluginsNames,
AllDefaultForTaskTypes: pluginDefaultForTaskType,
}, nil
}

func GetConfig() *Config {
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/nodes/task/config/config_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 40 additions & 4 deletions pkg/controller/nodes/task/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ type PluginRegistryIface interface {
GetK8sPlugins() []pluginK8s.PluginEntry
}

type taskType = string
type pluginID = string

type Handler struct {
Expand Down Expand Up @@ -204,28 +205,49 @@ func (t *Handler) Setup(ctx context.Context, sCtx handler.SetupContext) error {
return err
}

// Not every task type will have a default plugin specified in the flytepropeller config.
// That's fine, we resort to using the plugins' static RegisteredTaskTypes as a fallback further below.
fallbackTaskHandlerMap := make(map[taskType]map[pluginID]pluginCore.Plugin)

for _, p := range enabledPlugins {
// create a new resource registrar proxy for each plugin, and pass it into the plugin's LoadPlugin() via a setup context
pluginResourceNamespacePrefix := pluginCore.ResourceNamespace(newResourceManagerBuilder.GetID()).CreateSubNamespace(pluginCore.ResourceNamespace(p.ID))
sCtxFinal := newNameSpacedSetupCtx(
tSCtx, newResourceManagerBuilder.GetResourceRegistrar(pluginResourceNamespacePrefix))
logger.Infof(ctx, "Loading Plugin [%s] ENABLED", p.ID)
// cp, err := p.LoadPlugin(ctx, tSCtx)
cp, err := p.LoadPlugin(ctx, sCtxFinal)
if err != nil {
return regErrors.Wrapf(err, "failed to load plugin - %s", p.ID)
}
// For every default plugin for a task type specified in flytepropeller config we validate that the plugin's
// static definition includes that task type as something it is registered to handle.
for _, tt := range p.RegisteredTaskTypes {
logger.Infof(ctx, "Plugin [%s] registered for TaskType [%s]", cp.GetID(), tt)
// TODO(katrogan): Make the default task plugin assignment more explicit (https://github.com/lyft/flyte/issues/516)
t.defaultPlugins[tt] = cp
for _, defaultTaskType := range p.DefaultForTaskTypes {
if defaultTaskType == tt {
if existingHandler, alreadyDefaulted := t.defaultPlugins[tt]; alreadyDefaulted && existingHandler.GetID() != cp.GetID() {
logger.Errorf(ctx, "TaskType [%s] has multiple default handlers specified: [%s] and [%s]",
tt, existingHandler.GetID(), cp.GetID())
return regErrors.New(fmt.Sprintf("TaskType [%s] has multiple default handlers specified: [%s] and [%s]",
tt, existingHandler.GetID(), cp.GetID()))
}
logger.Infof(ctx, "Plugin [%s] registered for TaskType [%s]", cp.GetID(), tt)
t.defaultPlugins[tt] = cp
}
}

pluginsForTaskType, ok := t.pluginsForType[tt]
if !ok {
pluginsForTaskType = make(map[pluginID]pluginCore.Plugin)
}
pluginsForTaskType[cp.GetID()] = cp
t.pluginsForType[tt] = pluginsForTaskType

fallbackMap, ok := fallbackTaskHandlerMap[tt]
if !ok {
fallbackMap = make(map[pluginID]pluginCore.Plugin)
}
fallbackMap[cp.GetID()] = cp
fallbackTaskHandlerMap[tt] = fallbackMap
}
if p.IsDefault {
if err := t.setDefault(ctx, cp); err != nil {
Expand All @@ -234,6 +256,20 @@ func (t *Handler) Setup(ctx context.Context, sCtx handler.SetupContext) error {
}
}

// Read from the fallback task handler map for any remaining tasks without a defaultPlugins registered handler.
for taskType, registeredPlugins := range fallbackTaskHandlerMap {
if _, ok := t.defaultPlugins[taskType]; ok {
break
}
if len(registeredPlugins) != 1 {
logger.Errorf(ctx, "Multiple plugins registered to handle task type: %s. ([%+v])", taskType, registeredPlugins)
return regErrors.New(fmt.Sprintf("Multiple plugins registered to handle task type: %s. ([%+v]). Use default-for-task-type config option to choose the desired plugin.", taskType, registeredPlugins))
}
for _, plugin := range registeredPlugins {
t.defaultPlugins[taskType] = plugin
}
}

rm, err := newResourceManagerBuilder.BuildResourceManager(ctx)
if err != nil {
logger.Errorf(ctx, "Failed to build a resource manager")
Expand Down
52 changes: 37 additions & 15 deletions pkg/controller/nodes/task/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"testing"
"time"

pluginK8sMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s/mocks"

"github.com/lyft/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/lyft/flytepropeller/pkg/apis/flyteworkflow/v1alpha1"

Expand All @@ -28,7 +30,6 @@ import (
"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io"
ioMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks"
pluginK8s "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s"
pluginK8sMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/k8s/mocks"
"github.com/lyft/flytestdlib/promutils"
"github.com/lyft/flytestdlib/storage"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -150,41 +151,60 @@ func Test_task_Setup(t *testing.T) {
defaultPluginID string
}
tests := []struct {
name string
registry PluginRegistryIface
fields wantFields
wantErr bool
name string
registry PluginRegistryIface
enabledPlugins []string
defaultForTaskTypes map[string]string
fields wantFields
wantErr bool
}{
{"no-plugins", testPluginRegistry{}, wantFields{}, false},
{"no-plugins", testPluginRegistry{}, []string{}, map[string]string{}, wantFields{}, false},
{"no-default-only-core", testPluginRegistry{
core: []pluginCore.PluginEntry{corePluginEntry}, k8s: []pluginK8s.PluginEntry{},
}, wantFields{
pluginIDs: map[pluginCore.TaskType]string{corePluginType: corePluginType},
}, false},
}, []string{corePluginType}, map[string]string{
corePluginType: corePluginType},
wantFields{
pluginIDs: map[pluginCore.TaskType]string{corePluginType: corePluginType},
}, false},
{"no-default-only-k8s", testPluginRegistry{
core: []pluginCore.PluginEntry{}, k8s: []pluginK8s.PluginEntry{k8sPluginEntry},
}, []string{k8sPluginType}, map[string]string{
k8sPluginType: k8sPluginType},
wantFields{
pluginIDs: map[pluginCore.TaskType]string{k8sPluginType: k8sPluginType},
}, false},
{"no-default", testPluginRegistry{}, []string{corePluginType, k8sPluginType}, map[string]string{
corePluginType: corePluginType,
k8sPluginType: k8sPluginType,
}, wantFields{
pluginIDs: map[pluginCore.TaskType]string{k8sPluginType: k8sPluginType},
}, false},
{"no-default", testPluginRegistry{
core: []pluginCore.PluginEntry{corePluginEntry}, k8s: []pluginK8s.PluginEntry{k8sPluginEntry},
}, wantFields{
pluginIDs: map[pluginCore.TaskType]string{corePluginType: corePluginType, k8sPluginType: k8sPluginType},
pluginIDs: map[pluginCore.TaskType]string{},
}, false},
{"only-default-core", testPluginRegistry{
core: []pluginCore.PluginEntry{corePluginEntry, corePluginEntryDefault}, k8s: []pluginK8s.PluginEntry{k8sPluginEntry},
}, []string{corePluginType, corePluginDefaultType, k8sPluginType}, map[string]string{
corePluginType: corePluginType,
corePluginDefaultType: corePluginDefaultType,
k8sPluginType: k8sPluginType,
}, wantFields{
pluginIDs: map[pluginCore.TaskType]string{corePluginType: corePluginType, corePluginDefaultType: corePluginDefaultType, k8sPluginType: k8sPluginType},
defaultPluginID: corePluginDefaultType,
}, false},
{"only-default-k8s", testPluginRegistry{
core: []pluginCore.PluginEntry{corePluginEntry}, k8s: []pluginK8s.PluginEntry{k8sPluginEntryDefault},
}, []string{corePluginType, k8sPluginDefaultType}, map[string]string{
corePluginType: corePluginType,
k8sPluginDefaultType: k8sPluginDefaultType,
}, wantFields{
pluginIDs: map[pluginCore.TaskType]string{corePluginType: corePluginType, k8sPluginDefaultType: k8sPluginDefaultType},
defaultPluginID: k8sPluginDefaultType,
}, false},
{"default-both", testPluginRegistry{
core: []pluginCore.PluginEntry{corePluginEntry, corePluginEntryDefault}, k8s: []pluginK8s.PluginEntry{k8sPluginEntry, k8sPluginEntryDefault},
}, []string{corePluginType, corePluginDefaultType, k8sPluginType, k8sPluginDefaultType}, map[string]string{
corePluginType: corePluginType,
corePluginDefaultType: corePluginDefaultType,
k8sPluginType: k8sPluginType,
k8sPluginDefaultType: k8sPluginDefaultType,
}, wantFields{
pluginIDs: map[pluginCore.TaskType]string{corePluginType: corePluginType, corePluginDefaultType: corePluginDefaultType, k8sPluginType: k8sPluginType, k8sPluginDefaultType: k8sPluginDefaultType},
defaultPluginID: corePluginDefaultType,
Expand All @@ -200,6 +220,8 @@ func Test_task_Setup(t *testing.T) {
sCtx.On("MetricsScope").Return(promutils.NewTestScope())

tk, err := New(context.TODO(), mocks.NewFakeKubeClient(), &pluginCatalogMocks.Client{}, promutils.NewTestScope())
tk.cfg.TaskPlugins.EnabledPlugins = tt.enabledPlugins
tk.cfg.TaskPlugins.DefaultForTaskTypes = tt.defaultForTaskTypes
assert.NoError(t, err)
tk.pluginRegistry = tt.registry
if err := tk.Setup(context.TODO(), sCtx); err != nil {
Expand Down
Loading

0 comments on commit 7e4ed05

Please sign in to comment.