From 76171c0c03b787824a0eac7db80a1ec0cff35acb Mon Sep 17 00:00:00 2001 From: Honnix Date: Wed, 2 Aug 2023 23:46:08 +0200 Subject: [PATCH] Use agent as name where it fits (#381) Signed-off-by: Hongxin Liang --- .../tasks/plugins/webapi/agent/config_test.go | 4 +- .../go/tasks/plugins/webapi/agent/plugin.go | 62 +++++++++---------- .../tasks/plugins/webapi/agent/plugin_test.go | 20 +++--- 3 files changed, 43 insertions(+), 43 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/config_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/config_test.go index a328051591..1b36d03c86 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/config_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/config_test.go @@ -28,13 +28,13 @@ func TestGetAndSetConfig(t *testing.T) { } cfg.DefaultAgent.DefaultTimeout = config.Duration{Duration: 10 * time.Second} cfg.Agents = map[string]*Agent{ - "endpoint_1": { + "agent_1": { Insecure: cfg.DefaultAgent.Insecure, DefaultServiceConfig: cfg.DefaultAgent.DefaultServiceConfig, Timeouts: cfg.DefaultAgent.Timeouts, }, } - cfg.AgentForTaskTypes = map[string]string{"task_type_1": "endpoint_1"} + cfg.AgentForTaskTypes = map[string]string{"task_type_1": "agent_1"} err := SetConfig(&cfg) assert.NoError(t, err) assert.Equal(t, &cfg, GetConfig()) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index f8665aa607..fd497e11cb 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -25,7 +25,7 @@ import ( "google.golang.org/grpc" ) -type GetClientFunc func(ctx context.Context, endpoint *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) +type GetClientFunc func(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) type Plugin struct { metricScope promutils.Scope @@ -70,16 +70,16 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() - endpoint, err := getFinalEndpoint(taskTemplate.Type, p.cfg) + agent, err := getFinalAgent(taskTemplate.Type, p.cfg) if err != nil { - return nil, nil, fmt.Errorf("failed to find agent endpoint with error: %v", err) + return nil, nil, fmt.Errorf("failed to find agent agent with error: %v", err) } - client, err := p.getClient(ctx, endpoint, p.connectionCache) + client, err := p.getClient(ctx, agent, p.connectionCache) if err != nil { return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err) } - finalCtx, cancel := getFinalContext(ctx, "CreateTask", endpoint) + finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) defer cancel() taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) @@ -99,16 +99,16 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { metadata := taskCtx.ResourceMeta().(*ResourceMetaWrapper) - endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg) + agent, err := getFinalAgent(metadata.TaskType, p.cfg) if err != nil { - return nil, fmt.Errorf("failed to find agent endpoint with error: %v", err) + return nil, fmt.Errorf("failed to find agent with error: %v", err) } - client, err := p.getClient(ctx, endpoint, p.connectionCache) + client, err := p.getClient(ctx, agent, p.connectionCache) if err != nil { return nil, fmt.Errorf("failed to connect to agent with error: %v", err) } - finalCtx, cancel := getFinalContext(ctx, "GetTask", endpoint) + finalCtx, cancel := getFinalContext(ctx, "GetTask", agent) defer cancel() res, err := client.GetTask(finalCtx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) @@ -128,16 +128,16 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error } metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg) + agent, err := getFinalAgent(metadata.TaskType, p.cfg) if err != nil { - return fmt.Errorf("failed to find agent endpoint with error: %v", err) + return fmt.Errorf("failed to find agent agent with error: %v", err) } - client, err := p.getClient(ctx, endpoint, p.connectionCache) + client, err := p.getClient(ctx, agent, p.connectionCache) if err != nil { return fmt.Errorf("failed to connect to agent with error: %v", err) } - finalCtx, cancel := getFinalContext(ctx, "DeleteTask", endpoint) + finalCtx, cancel := getFinalContext(ctx, "DeleteTask", agent) defer cancel() _, err = client.DeleteTask(finalCtx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) @@ -167,26 +167,26 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State) } -func getFinalEndpoint(taskType string, cfg *Config) (*Agent, error) { +func getFinalAgent(taskType string, cfg *Config) (*Agent, error) { if id, exists := cfg.AgentForTaskTypes[taskType]; exists { - if endpoint, exists := cfg.Agents[id]; exists { - return endpoint, nil + if agent, exists := cfg.Agents[id]; exists { + return agent, nil } - return nil, fmt.Errorf("no endpoint definition found for ID %s that matches task type %s", id, taskType) + return nil, fmt.Errorf("no agent definition found for ID %s that matches task type %s", id, taskType) } return &cfg.DefaultAgent, nil } -func getClientFunc(ctx context.Context, endpoint *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - conn, ok := connectionCache[endpoint] +func getClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { + conn, ok := connectionCache[agent] if ok { return service.NewAsyncAgentServiceClient(conn), nil } var opts []grpc.DialOption - if endpoint.Insecure { + if agent.Insecure { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } else { pool, err := x509.SystemCertPool() @@ -198,27 +198,27 @@ func getClientFunc(ctx context.Context, endpoint *Agent, connectionCache map[*Ag opts = append(opts, grpc.WithTransportCredentials(creds)) } - if len(endpoint.DefaultServiceConfig) != 0 { - opts = append(opts, grpc.WithDefaultServiceConfig(endpoint.DefaultServiceConfig)) + if len(agent.DefaultServiceConfig) != 0 { + opts = append(opts, grpc.WithDefaultServiceConfig(agent.DefaultServiceConfig)) } var err error - conn, err = grpc.Dial(endpoint.Endpoint, opts...) + conn, err = grpc.Dial(agent.Endpoint, opts...) if err != nil { return nil, err } - connectionCache[endpoint] = conn + connectionCache[agent] = conn defer func() { if err != nil { if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + grpclog.Infof("Failed to close conn to %s: %v", agent, cerr) } return } go func() { <-ctx.Done() if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + grpclog.Infof("Failed to close conn to %s: %v", agent, cerr) } }() }() @@ -237,16 +237,16 @@ func buildTaskExecutionMetadata(taskExecutionMetadata pluginsCore.TaskExecutionM } } -func getFinalTimeout(operation string, endpoint *Agent) config.Duration { - if t, exists := endpoint.Timeouts[operation]; exists { +func getFinalTimeout(operation string, agent *Agent) config.Duration { + if t, exists := agent.Timeouts[operation]; exists { return t } - return endpoint.DefaultTimeout + return agent.DefaultTimeout } -func getFinalContext(ctx context.Context, operation string, endpoint *Agent) (context.Context, context.CancelFunc) { - timeout := getFinalTimeout(operation, endpoint).Duration +func getFinalContext(ctx context.Context, operation string, agent *Agent) (context.Context, context.CancelFunc) { + timeout := getFinalTimeout(operation, agent).Duration if timeout == 0 { return ctx, func() {} } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 31c8fe034d..180a0d6e6f 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -49,12 +49,12 @@ func TestPlugin(t *testing.T) { assert.NotNil(t, p.PluginLoader) }) - t.Run("test getFinalEndpoint", func(t *testing.T) { - endpoint, _ := getFinalEndpoint("spark", &cfg) - assert.Equal(t, cfg.Agents["spark_agent"].Endpoint, endpoint.Endpoint) - endpoint, _ = getFinalEndpoint("foo", &cfg) - assert.Equal(t, cfg.DefaultAgent.Endpoint, endpoint.Endpoint) - _, err := getFinalEndpoint("bar", &cfg) + t.Run("test getFinalAgent", func(t *testing.T) { + agent, _ := getFinalAgent("spark", &cfg) + assert.Equal(t, cfg.Agents["spark_agent"].Endpoint, agent.Endpoint) + agent, _ = getFinalAgent("foo", &cfg) + assert.Equal(t, cfg.DefaultAgent.Endpoint, agent.Endpoint) + _, err := getFinalAgent("bar", &cfg) assert.NotNil(t, err) }) @@ -72,14 +72,14 @@ func TestPlugin(t *testing.T) { t.Run("test getClientFunc cache hit", func(t *testing.T) { connectionCache := make(map[*Agent]*grpc.ClientConn) - endpoint := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"} + agent := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"} - client, err := getClientFunc(context.Background(), endpoint, connectionCache) + client, err := getClientFunc(context.Background(), agent, connectionCache) assert.NoError(t, err) assert.NotNil(t, client) - assert.NotNil(t, client, connectionCache[endpoint]) + assert.NotNil(t, client, connectionCache[agent]) - cachedClient, err := getClientFunc(context.Background(), endpoint, connectionCache) + cachedClient, err := getClientFunc(context.Background(), agent, connectionCache) assert.NoError(t, err) assert.NotNil(t, cachedClient) assert.Equal(t, client, cachedClient)