Skip to content

Commit

Permalink
[flytepropeller] Watch agent metadata service dynamically (#5460)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
3 people authored Jun 28, 2024
1 parent 4643e2a commit 5b0d787
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 87 deletions.
25 changes: 24 additions & 1 deletion flyteplugins/go/tasks/pluginmachinery/core/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package core
import (
"context"
"fmt"
"sync"

"k8s.io/utils/strings/slices"
)

//go:generate mockery -all -case=underscore
Expand Down Expand Up @@ -55,7 +58,27 @@ type Plugin interface {
Finalize(ctx context.Context, tCtx TaskExecutionContext) error
}

// Loads and validates a plugin.
type AgentService struct {
mu sync.RWMutex
supportedTaskTypes []TaskType
CorePlugin Plugin
}

// ContainTaskType check if agent supports this task type.
func (p *AgentService) ContainTaskType(taskType TaskType) bool {
p.mu.RLock()
defer p.mu.RUnlock()
return slices.Contains(p.supportedTaskTypes, taskType)
}

// SetSupportedTaskType set supportTaskType in the agent service.
func (p *AgentService) SetSupportedTaskType(taskTypes []TaskType) {
p.mu.Lock()
defer p.mu.Unlock()
p.supportedTaskTypes = taskTypes
}

// LoadPlugin Loads and validates a plugin.
func LoadPlugin(ctx context.Context, iCtx SetupContext, entry PluginEntry) (Plugin, error) {
plugin, err := entry.LoadPlugin(ctx, iCtx)
if err != nil {
Expand Down
14 changes: 14 additions & 0 deletions flyteplugins/go/tasks/pluginmachinery/core/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,17 @@ func TestLoadPlugin(t *testing.T) {
})

}

func TestAgentService(t *testing.T) {
agentService := core.AgentService{}
taskTypes := []core.TaskType{"sensor", "chatgpt"}

for _, taskType := range taskTypes {
assert.Equal(t, false, agentService.ContainTaskType(taskType))
}

agentService.SetSupportedTaskType(taskTypes)
for _, taskType := range taskTypes {
assert.Equal(t, true, agentService.ContainTaskType(taskType))
}
}
41 changes: 22 additions & 19 deletions flyteplugins/go/tasks/plugins/webapi/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,11 @@ func getFinalContext(ctx context.Context, operation string, agent *Deployment) (
return context.WithTimeout(ctx, timeout)
}

func updateAgentRegistry(ctx context.Context, cs *ClientSet) {
agentRegistry := make(Registry)
func getAgentRegistry(ctx context.Context, cs *ClientSet) Registry {
newAgentRegistry := make(Registry)
cfg := GetConfig()
var agentDeployments []*Deployment

// Ensure that the old configuration is backward compatible
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
agent := Agent{AgentDeployment: cfg.AgentDeployments[agentDeploymentID], IsSync: false}
agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: &agent}
}

if len(cfg.DefaultAgent.Endpoint) != 0 {
agentDeployments = append(agentDeployments, &cfg.DefaultAgent)
}
Expand Down Expand Up @@ -137,27 +131,36 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) {
deprecatedSupportedTaskTypes := agent.SupportedTaskTypes
for _, supportedTaskType := range deprecatedSupportedTaskTypes {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
agentRegistry[supportedTaskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
newAgentRegistry[supportedTaskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}

supportedTaskCategories := agent.SupportedTaskCategories
for _, supportedCategory := range supportedTaskCategories {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
agentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent}
newAgentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent}
}
}
// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok {
if _, ok := agentRegistry[taskType]; !ok {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}
}

// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok {
if _, ok := newAgentRegistry[taskType]; !ok {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}
}
}
logger.Debugf(ctx, "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry))
setAgentRegistry(agentRegistry)

// Ensure that the old configuration is backward compatible
for _, taskType := range cfg.SupportedTaskTypes {
if _, ok := newAgentRegistry[taskType]; !ok {
agent := &Agent{AgentDeployment: &cfg.DefaultAgent, IsSync: false}
newAgentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}
}

return newAgentRegistry
}

func getAgentClientSets(ctx context.Context) *ClientSet {
Expand Down
19 changes: 12 additions & 7 deletions flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ import (
)

func TestEndToEnd(t *testing.T) {
agentRegistry = Registry{
"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}},
"spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}},
}
iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error {
return nil
}
Expand Down Expand Up @@ -117,7 +113,7 @@ func TestEndToEnd(t *testing.T) {
t.Run("failed to create a job", func(t *testing.T) {
agentPlugin := newMockAsyncAgentPlugin()
agentPlugin.PluginLoader = func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return Plugin{
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
cs: &ClientSet{
Expand Down Expand Up @@ -259,6 +255,9 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext {

func newMockAsyncAgentPlugin() webapi.PluginEntry {
asyncAgentClient := new(agentMocks.AsyncAgentServiceClient)
agentRegistry := Registry{
"spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}},
}

mockCreateRequestMatcher := mock.MatchedBy(func(request *admin.CreateTaskRequest) bool {
expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"}
Expand All @@ -283,20 +282,25 @@ func newMockAsyncAgentPlugin() webapi.PluginEntry {
ID: "agent-service",
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return Plugin{
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: &cfg,
cs: &ClientSet{
asyncAgentClients: map[string]service.AsyncAgentServiceClient{
defaultAgentEndpoint: asyncAgentClient,
},
},
registry: agentRegistry,
}, nil
},
}
}

func newMockSyncAgentPlugin() webapi.PluginEntry {
agentRegistry := Registry{
"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}},
}

syncAgentClient := new(agentMocks.SyncAgentServiceClient)
output, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1})
resource := &admin.Resource{Phase: flyteIdlCore.TaskExecution_SUCCEEDED, Outputs: output}
Expand All @@ -323,14 +327,15 @@ func newMockSyncAgentPlugin() webapi.PluginEntry {
ID: "agent-service",
SupportedTaskTypes: []core.TaskType{"openai"},
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return Plugin{
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: &cfg,
cs: &ClientSet{
syncAgentClients: map[string]service.SyncAgentServiceClient{
defaultAgentEndpoint: syncAgentClient,
},
},
registry: agentRegistry,
}, nil
},
}
Expand Down
Loading

0 comments on commit 5b0d787

Please sign in to comment.