Skip to content

Commit

Permalink
Use agent as name where it fits (flyteorg#381)
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix authored Aug 2, 2023
1 parent c17fb58 commit 76171c0
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 43 deletions.
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/webapi/agent/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ func TestGetAndSetConfig(t *testing.T) {
}
cfg.DefaultAgent.DefaultTimeout = config.Duration{Duration: 10 * time.Second}
cfg.Agents = map[string]*Agent{
"endpoint_1": {
"agent_1": {
Insecure: cfg.DefaultAgent.Insecure,
DefaultServiceConfig: cfg.DefaultAgent.DefaultServiceConfig,
Timeouts: cfg.DefaultAgent.Timeouts,
},
}
cfg.AgentForTaskTypes = map[string]string{"task_type_1": "endpoint_1"}
cfg.AgentForTaskTypes = map[string]string{"task_type_1": "agent_1"}
err := SetConfig(&cfg)
assert.NoError(t, err)
assert.Equal(t, &cfg, GetConfig())
Expand Down
62 changes: 31 additions & 31 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"google.golang.org/grpc"
)

type GetClientFunc func(ctx context.Context, endpoint *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)
type GetClientFunc func(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)

type Plugin struct {
metricScope promutils.Scope
Expand Down Expand Up @@ -70,16 +70,16 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR

outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

endpoint, err := getFinalEndpoint(taskTemplate.Type, p.cfg)
agent, err := getFinalAgent(taskTemplate.Type, p.cfg)
if err != nil {
return nil, nil, fmt.Errorf("failed to find agent endpoint with error: %v", err)
return nil, nil, fmt.Errorf("failed to find agent with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
client, err := p.getClient(ctx, agent, p.connectionCache)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

finalCtx, cancel := getFinalContext(ctx, "CreateTask", endpoint)
finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent)
defer cancel()

taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
Expand All @@ -99,16 +99,16 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
metadata := taskCtx.ResourceMeta().(*ResourceMetaWrapper)

endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg)
agent, err := getFinalAgent(metadata.TaskType, p.cfg)
if err != nil {
return nil, fmt.Errorf("failed to find agent endpoint with error: %v", err)
return nil, fmt.Errorf("failed to find agent with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
client, err := p.getClient(ctx, agent, p.connectionCache)
if err != nil {
return nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

finalCtx, cancel := getFinalContext(ctx, "GetTask", endpoint)
finalCtx, cancel := getFinalContext(ctx, "GetTask", agent)
defer cancel()

res, err := client.GetTask(finalCtx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
Expand All @@ -128,16 +128,16 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
}
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)

endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg)
agent, err := getFinalAgent(metadata.TaskType, p.cfg)
if err != nil {
return fmt.Errorf("failed to find agent endpoint with error: %v", err)
return fmt.Errorf("failed to find agent with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
client, err := p.getClient(ctx, agent, p.connectionCache)
if err != nil {
return fmt.Errorf("failed to connect to agent with error: %v", err)
}

finalCtx, cancel := getFinalContext(ctx, "DeleteTask", endpoint)
finalCtx, cancel := getFinalContext(ctx, "DeleteTask", agent)
defer cancel()

_, err = client.DeleteTask(finalCtx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
Expand Down Expand Up @@ -167,26 +167,26 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase
return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State)
}

func getFinalEndpoint(taskType string, cfg *Config) (*Agent, error) {
func getFinalAgent(taskType string, cfg *Config) (*Agent, error) {
if id, exists := cfg.AgentForTaskTypes[taskType]; exists {
if endpoint, exists := cfg.Agents[id]; exists {
return endpoint, nil
if agent, exists := cfg.Agents[id]; exists {
return agent, nil
}
return nil, fmt.Errorf("no endpoint definition found for ID %s that matches task type %s", id, taskType)
return nil, fmt.Errorf("no agent definition found for ID %s that matches task type %s", id, taskType)
}

return &cfg.DefaultAgent, nil
}

func getClientFunc(ctx context.Context, endpoint *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
conn, ok := connectionCache[endpoint]
func getClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
conn, ok := connectionCache[agent]
if ok {
return service.NewAsyncAgentServiceClient(conn), nil
}

var opts []grpc.DialOption

if endpoint.Insecure {
if agent.Insecure {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
} else {
pool, err := x509.SystemCertPool()
Expand All @@ -198,27 +198,27 @@ func getClientFunc(ctx context.Context, endpoint *Agent, connectionCache map[*Ag
opts = append(opts, grpc.WithTransportCredentials(creds))
}

if len(endpoint.DefaultServiceConfig) != 0 {
opts = append(opts, grpc.WithDefaultServiceConfig(endpoint.DefaultServiceConfig))
if len(agent.DefaultServiceConfig) != 0 {
opts = append(opts, grpc.WithDefaultServiceConfig(agent.DefaultServiceConfig))
}

var err error
conn, err = grpc.Dial(endpoint.Endpoint, opts...)
conn, err = grpc.Dial(agent.Endpoint, opts...)
if err != nil {
return nil, err
}
connectionCache[endpoint] = conn
connectionCache[agent] = conn
defer func() {
if err != nil {
if cerr := conn.Close(); cerr != nil {
grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
grpclog.Infof("Failed to close conn to %s: %v", agent, cerr)
}
return
}
go func() {
<-ctx.Done()
if cerr := conn.Close(); cerr != nil {
grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr)
grpclog.Infof("Failed to close conn to %s: %v", agent, cerr)
}
}()
}()
Expand All @@ -237,16 +237,16 @@ func buildTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionM
}
}

func getFinalTimeout(operation string, endpoint *Agent) config.Duration {
if t, exists := endpoint.Timeouts[operation]; exists {
func getFinalTimeout(operation string, agent *Agent) config.Duration {
if t, exists := agent.Timeouts[operation]; exists {
return t
}

return endpoint.DefaultTimeout
return agent.DefaultTimeout
}

func getFinalContext(ctx context.Context, operation string, endpoint *Agent) (context.Context, context.CancelFunc) {
timeout := getFinalTimeout(operation, endpoint).Duration
func getFinalContext(ctx context.Context, operation string, agent *Agent) (context.Context, context.CancelFunc) {
timeout := getFinalTimeout(operation, agent).Duration
if timeout == 0 {
return ctx, func() {}
}
Expand Down
20 changes: 10 additions & 10 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ func TestPlugin(t *testing.T) {
assert.NotNil(t, p.PluginLoader)
})

t.Run("test getFinalEndpoint", func(t *testing.T) {
endpoint, _ := getFinalEndpoint("spark", &cfg)
assert.Equal(t, cfg.Agents["spark_agent"].Endpoint, endpoint.Endpoint)
endpoint, _ = getFinalEndpoint("foo", &cfg)
assert.Equal(t, cfg.DefaultAgent.Endpoint, endpoint.Endpoint)
_, err := getFinalEndpoint("bar", &cfg)
t.Run("test getFinalAgent", func(t *testing.T) {
agent, _ := getFinalAgent("spark", &cfg)
assert.Equal(t, cfg.Agents["spark_agent"].Endpoint, agent.Endpoint)
agent, _ = getFinalAgent("foo", &cfg)
assert.Equal(t, cfg.DefaultAgent.Endpoint, agent.Endpoint)
_, err := getFinalAgent("bar", &cfg)
assert.NotNil(t, err)
})

Expand All @@ -72,14 +72,14 @@ func TestPlugin(t *testing.T) {

t.Run("test getClientFunc cache hit", func(t *testing.T) {
connectionCache := make(map[*Agent]*grpc.ClientConn)
endpoint := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}
agent := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}

client, err := getClientFunc(context.Background(), endpoint, connectionCache)
client, err := getClientFunc(context.Background(), agent, connectionCache)
assert.NoError(t, err)
assert.NotNil(t, client)
assert.NotNil(t, client, connectionCache[endpoint])
assert.NotNil(t, client, connectionCache[agent])

cachedClient, err := getClientFunc(context.Background(), endpoint, connectionCache)
cachedClient, err := getClientFunc(context.Background(), agent, connectionCache)
assert.NoError(t, err)
assert.NotNil(t, cachedClient)
assert.Equal(t, client, cachedClient)
Expand Down

0 comments on commit 76171c0

Please sign in to comment.