Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Watch agent metadata service #5017

Merged
merged 22 commits into from
Jun 8, 2024
71 changes: 41 additions & 30 deletions flyteplugins/go/tasks/plugins/webapi/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import (
"context"
"crypto/x509"
"fmt"

"golang.org/x/exp/maps"
"google.golang.org/grpc"
Expand Down Expand Up @@ -98,8 +97,7 @@
return context.WithTimeout(ctx, timeout)
}

func initializeAgentRegistry(cs *ClientSet) (Registry, error) {
logger.Infof(context.Background(), "Initializing agent registry")
func updateAgentRegistry(ctx context.Context, cs *ClientSet) {
agentRegistry := make(Registry)
cfg := GetConfig()
var agentDeployments []*Deployment
Expand All @@ -115,25 +113,31 @@
}
agentDeployments = append(agentDeployments, maps.Values(cfg.AgentDeployments)...)
for _, agentDeployment := range agentDeployments {
client := cs.agentMetadataClients[agentDeployment.Endpoint]
client, ok := cs.agentMetadataClients[agentDeployment.Endpoint]
if !ok {
logger.Warningf(ctx, "Agent client not found in the clientSet for the endpoint: %v", agentDeployment.Endpoint)
continue

Check warning on line 119 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L118-L119

Added lines #L118 - L119 were not covered by tests
}

finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployment)
finalCtx, cancel := getFinalContext(ctx, "ListAgents", agentDeployment)
defer cancel()

res, err := client.ListAgents(finalCtx, &admin.ListAgentsRequest{})
if err != nil {
grpcStatus, ok := status.FromError(err)
if grpcStatus.Code() == codes.Unimplemented {
// we should not panic here, as we want to continue to support old agent settings
logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment)
logger.Warningf(finalCtx, "list agent method not implemented for agent: [%v]", agentDeployment.Endpoint)

Check warning on line 130 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L130

Added line #L130 was not covered by tests
continue
}

if !ok {
return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment, err)
logger.Errorf(finalCtx, "failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment.Endpoint, err)
continue

Check warning on line 136 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L135-L136

Added lines #L135 - L136 were not covered by tests
}

return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployment, err)
logger.Errorf(finalCtx, "failed to list agent: [%v] with error: [%v]", agentDeployment.Endpoint, err)
continue

Check warning on line 140 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L139-L140

Added lines #L139 - L140 were not covered by tests
}

for _, agent := range res.GetAgents() {
Expand All @@ -148,20 +152,27 @@
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
agentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent}
}
logger.Infof(context.Background(), "[%v] is a sync agent: [%v]", agent.Name, agent.IsSync)
logger.Infof(context.Background(), "[%v] supports task category: [%v]", agent.Name, supportedTaskCategories)
}
// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok {
if _, ok := agentRegistry[taskType]; !ok {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}

Check warning on line 161 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L159-L161

Added lines #L159 - L161 were not covered by tests
}
}
}
}

return agentRegistry, nil
logger.Debugf(ctx, "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry))
setAgentRegistry(agentRegistry)
}

func initializeClients(ctx context.Context) (*ClientSet, error) {
logger.Infof(ctx, "Initializing agent clients")

asyncAgentClients := make(map[string]service.AsyncAgentServiceClient)
syncAgentClients := make(map[string]service.SyncAgentServiceClient)
agentMetadataClients := make(map[string]service.AgentMetadataServiceClient)
func getAgentClientSets(ctx context.Context) *ClientSet {
clientSet := &ClientSet{
asyncAgentClients: make(map[string]service.AsyncAgentServiceClient),
syncAgentClients: make(map[string]service.SyncAgentServiceClient),
agentMetadataClients: make(map[string]service.AgentMetadataServiceClient),
}

var agentDeployments []*Deployment
cfg := GetConfig()
Expand All @@ -170,19 +181,19 @@
agentDeployments = append(agentDeployments, &cfg.DefaultAgent)
}
agentDeployments = append(agentDeployments, maps.Values(cfg.AgentDeployments)...)
for _, agentService := range agentDeployments {
conn, err := getGrpcConnection(ctx, agentService)
for _, agentDeployment := range agentDeployments {
if _, ok := clientSet.agentMetadataClients[agentDeployment.Endpoint]; ok {
logger.Infof(ctx, "Agent client already initialized for [%v]", agentDeployment.Endpoint)
continue

Check warning on line 187 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L186-L187

Added lines #L186 - L187 were not covered by tests
}
conn, err := getGrpcConnection(ctx, agentDeployment)
if err != nil {
return nil, err
logger.Errorf(ctx, "failed to create connection to agent: [%v] with error: [%v]", agentDeployment, err)
continue

Check warning on line 192 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L191-L192

Added lines #L191 - L192 were not covered by tests
}
syncAgentClients[agentService.Endpoint] = service.NewSyncAgentServiceClient(conn)
asyncAgentClients[agentService.Endpoint] = service.NewAsyncAgentServiceClient(conn)
agentMetadataClients[agentService.Endpoint] = service.NewAgentMetadataServiceClient(conn)
clientSet.syncAgentClients[agentDeployment.Endpoint] = service.NewSyncAgentServiceClient(conn)
clientSet.asyncAgentClients[agentDeployment.Endpoint] = service.NewAsyncAgentServiceClient(conn)
clientSet.agentMetadataClients[agentDeployment.Endpoint] = service.NewAgentMetadataServiceClient(conn)
}

return &ClientSet{
syncAgentClients: syncAgentClients,
asyncAgentClients: asyncAgentClients,
agentMetadataClients: agentMetadataClients,
}, nil
return clientSet
}
4 changes: 1 addition & 3 deletions flyteplugins/go/tasks/plugins/webapi/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ func TestInitializeClients(t *testing.T) {
ctx := context.Background()
err := SetConfig(&cfg)
assert.NoError(t, err)
cs, err := initializeClients(ctx)
assert.NoError(t, err)
assert.NotNil(t, cs)
cs := getAgentClientSets(ctx)
_, ok := cs.syncAgentClients["y"]
assert.True(t, ok)
_, ok = cs.asyncAgentClients["x"]
Expand Down
4 changes: 4 additions & 0 deletions flyteplugins/go/tasks/plugins/webapi/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ var (
// AsyncPlugin should be registered to at least one task type.
// Reference: https://github.com/flyteorg/flyte/blob/master/flyteplugins/go/tasks/pluginmachinery/registry.go#L27
SupportedTaskTypes: []string{"task_type_1", "task_type_2"},
PollInterval: config.Duration{Duration: 10 * time.Second},
}

configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig)
Expand All @@ -71,6 +72,9 @@ type Config struct {

// SupportedTaskTypes is a list of task types that are supported by this plugin.
SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."`

// PollInterval is the interval at which the plugin should poll the agent for metadata updates
PollInterval config.Duration `json:"pollInterval" pflag:",The interval at which the plugin should poll the agent for metadata updates."`
}

type Deployment struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ import (
)

func TestEndToEnd(t *testing.T) {
agentRegistry = Registry{
"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}},
"spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}},
}
iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error {
return nil
}
Expand Down Expand Up @@ -118,6 +122,7 @@ func TestEndToEnd(t *testing.T) {
cfg: GetConfig(),
cs: &ClientSet{
asyncAgentClients: map[string]service.AsyncAgentServiceClient{},
syncAgentClients: map[string]service.SyncAgentServiceClient{},
agentMetadataClients: map[string]service.AgentMetadataServiceClient{},
},
}, nil
Expand Down Expand Up @@ -326,7 +331,6 @@ func newMockSyncAgentPlugin() webapi.PluginEntry {
defaultAgentEndpoint: syncAgentClient,
},
},
agentRegistry: Registry{"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}}},
}, nil
},
}
Expand Down
77 changes: 47 additions & 30 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
"context"
"encoding/gob"
"fmt"
"sync"
"time"

"golang.org/x/exp/maps"
"k8s.io/apimachinery/pkg/util/wait"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
Expand All @@ -24,11 +26,27 @@

type Registry map[string]map[int32]*Agent // map[taskTypeName][taskTypeVersion] => Agent

type Plugin struct {
metricScope promutils.Scope
cfg *Config
cs *ClientSet
var (
agentRegistry Registry
mu sync.RWMutex
)

func getAgentRegistry() Registry {
mu.Lock()
defer mu.Unlock()
return agentRegistry
}

func setAgentRegistry(r Registry) {
mu.Lock()
defer mu.Unlock()
agentRegistry = r
}

type Plugin struct {
metricScope promutils.Scope
cfg *Config
cs *ClientSet
}

type ResourceWrapper struct {
Expand Down Expand Up @@ -95,7 +113,7 @@
outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

taskCategory := admin.TaskCategory{Name: taskTemplate.Type, Version: taskTemplate.TaskTypeVersion}
agent, isSync := getFinalAgent(&taskCategory, p.cfg, p.agentRegistry)
agent, isSync := getFinalAgent(&taskCategory, p.cfg)

taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

Expand Down Expand Up @@ -193,7 +211,7 @@

func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg, p.agentRegistry)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg)

client, err := p.getAsyncAgentClient(ctx, agent)
if err != nil {
Expand Down Expand Up @@ -226,7 +244,7 @@
return nil
}
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg, p.agentRegistry)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg)

Check warning on line 247 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L247

Added line #L247 was not covered by tests

client, err := p.getAsyncAgentClient(ctx, agent)
if err != nil {
Expand Down Expand Up @@ -322,6 +340,13 @@
return client, nil
}

func (p Plugin) watchAgents(ctx context.Context) {
go wait.Until(func() {
clientSet := getAgentClientSets(ctx)
updateAgentRegistry(ctx, clientSet)
}, p.cfg.PollInterval.Duration, ctx.Done())

Check warning on line 347 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L343-L347

Added lines #L343 - L347 were not covered by tests
}

func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *flyteIdl.LiteralMap) error {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
Expand All @@ -344,11 +369,11 @@
return taskCtx.OutputWriter().Put(ctx, opReader)
}

func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config, agentRegistry Registry) (*Deployment, bool) {
if agent, exists := agentRegistry[taskCategory.Name][taskCategory.Version]; exists {
func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) {
r := getAgentRegistry()
if agent, exists := r[taskCategory.Name][taskCategory.Version]; exists {
return agent.AgentDeployment, agent.IsSync
}

return &cfg.DefaultAgent, false
}

Expand All @@ -367,38 +392,30 @@
}

func newAgentPlugin() webapi.PluginEntry {
cs, err := initializeClients(context.Background())
if err != nil {
// We should wait for all agents to be up and running before starting the server
panic(fmt.Sprintf("failed to initialize clients with error: %v", err))
}

agentRegistry, err := initializeAgentRegistry(cs)
if err != nil {
panic(fmt.Sprintf("failed to initialize agent registry with error: %v", err))
}

ctx := context.Background()

Check warning on line 395 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L395

Added line #L395 was not covered by tests
cfg := GetConfig()
supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...)
logger.Infof(context.Background(), "AgentDeployment service supports task types: %v", supportedTaskTypes)

clientSet := getAgentClientSets(ctx)
updateAgentRegistry(ctx, clientSet)
supportedTaskTypes := append(maps.Keys(getAgentRegistry()), cfg.SupportedTaskTypes...)

Check warning on line 400 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L398-L400

Added lines #L398 - L400 were not covered by tests

return webapi.PluginEntry{
ID: "agent-service",
SupportedTaskTypes: supportedTaskTypes,
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: cfg,
cs: cs,
agentRegistry: agentRegistry,
}, nil
plugin := &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: cfg,
cs: clientSet,

Check warning on line 409 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L406-L409

Added lines #L406 - L409 were not covered by tests
}
plugin.watchAgents(ctx)
return plugin, nil

Check warning on line 412 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L411-L412

Added lines #L411 - L412 were not covered by tests
},
}
}

func RegisterAgentPlugin() {
gob.Register(ResourceMetaWrapper{})
gob.Register(ResourceWrapper{})

pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin())
}
13 changes: 6 additions & 7 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ func TestPlugin(t *testing.T) {

t.Run("test getFinalAgent", func(t *testing.T) {
agent := &Agent{AgentDeployment: &Deployment{Endpoint: "localhost:80"}}
agentRegistry := Registry{"spark": {defaultTaskTypeVersion: agent}}
agentRegistry = Registry{"spark": {defaultTaskTypeVersion: agent}}
spark := &admin.TaskCategory{Name: "spark", Version: defaultTaskTypeVersion}
foo := &admin.TaskCategory{Name: "foo", Version: defaultTaskTypeVersion}
bar := &admin.TaskCategory{Name: "bar", Version: defaultTaskTypeVersion}
agentDeployment, _ := getFinalAgent(spark, &cfg, agentRegistry)
agentDeployment, _ := getFinalAgent(spark, &cfg)
assert.Equal(t, agentDeployment.Endpoint, "localhost:80")
agentDeployment, _ = getFinalAgent(foo, &cfg, agentRegistry)
agentDeployment, _ = getFinalAgent(foo, &cfg)
assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint)
agentDeployment, _ = getFinalAgent(bar, &cfg, agentRegistry)
agentDeployment, _ = getFinalAgent(bar, &cfg)
assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint)
})

Expand Down Expand Up @@ -318,11 +318,10 @@ func TestInitializeAgentRegistry(t *testing.T) {
cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"}
err := SetConfig(&cfg)
assert.NoError(t, err)
agentRegistry, err := initializeAgentRegistry(cs)
assert.NoError(t, err)
updateAgentRegistry(context.Background(), cs)

// In golang, the order of keys in a map is random. So, we sort the keys before asserting.
agentRegistryKeys := maps.Keys(agentRegistry)
agentRegistryKeys := maps.Keys(getAgentRegistry())
sort.Strings(agentRegistryKeys)

assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"})
Expand Down
Loading