Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Support gRPC config for agent-service plugin
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Jul 3, 2023
1 parent f7b37aa commit 8b27579
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 24 deletions.
27 changes: 23 additions & 4 deletions go/tasks/plugins/webapi/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
)

var (
defaultTimeout = config.Duration{Duration: 10 * time.Second}

defaultConfig = Config{
WebAPI: webapi.PluginConfig{
ResourceQuotas: map[core.ResourceNamespace]int{
Expand Down Expand Up @@ -39,8 +41,11 @@ var (
Value: 50,
},
},
DefaultGrpcEndpoint: "dns:///flyte-agent.flyte.svc.cluster.local:80",
SupportedTaskTypes: []string{"task_type_1", "task_type_2"},
DefaultGrpcEndpoint: GrpcEndpoint{
Endpoint: "dns:///flyte-agent.flyte.svc.cluster.local:80",
Insecure: true,
},
SupportedTaskTypes: []string{"task_type_1", "task_type_2"},
}

configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig)
Expand All @@ -54,15 +59,29 @@ type Config struct {
// ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time
ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."`

DefaultGrpcEndpoint string `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of agent service."`
DefaultGrpcEndpoint GrpcEndpoint `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of agent service."`

// Maps endpoint to their plugin handler. {TaskType: Endpoint}
EndpointForTaskTypes map[string]string `json:"endpointForTaskTypes" pflag:"-,"`
EndpointForTaskTypes map[string]GrpcEndpoint `json:"endpointForTaskTypes" pflag:"-,"`

// SupportedTaskTypes is a list of task types that are supported by this plugin.
SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."`
}

type GrpcEndpoint struct {
// Endpoint points to a gRPC service
Endpoint string `json:"endpoint"`

// Insecure indicates whether the communication with the gRPC service is insecure
Insecure bool `json:"insecure"`

// DefaultServiceConfig sets default gRPC service config; check https://github.com/grpc/grpc/blob/master/doc/service_config.md for more details
DefaultServiceConfig string `json:"defaultServiceConfig"`

// Timeouts defines various RPC timeout values for different plugin operations: CreateTask, GetTask, DeleteTask; if not configured, defaults to 10s
Timeouts map[string]config.Duration `json:"timeouts"`
}

func GetConfig() *Config {
return configSection.GetConfig().(*Config)
}
Expand Down
15 changes: 15 additions & 0 deletions go/tasks/plugins/webapi/agent/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,28 @@ import (
"testing"
"time"

"github.com/flyteorg/flytestdlib/config"

"github.com/stretchr/testify/assert"
)

func TestGetAndSetConfig(t *testing.T) {
cfg := defaultConfig
cfg.WebAPI.Caching.Workers = 1
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second
cfg.DefaultGrpcEndpoint.Insecure = false
cfg.DefaultGrpcEndpoint.DefaultServiceConfig = "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"
cfg.DefaultGrpcEndpoint.Timeouts = map[string]config.Duration{
"CreateTask": {
Duration: 1 * time.Millisecond,
},
"GetTask": {
Duration: 2 * time.Millisecond,
},
"DeleteTask": {
Duration: 3 * time.Millisecond,
},
}
err := SetConfig(&cfg)
assert.NoError(t, err)
assert.Equal(t, &cfg, GetConfig())
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ func (m *MockClient) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _
return &admin.DeleteTaskResponse{}, nil
}

func mockGetClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
func mockGetClientFunc(_ context.Context, _ GrpcEndpoint, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
return &MockClient{}, nil
}

func mockGetBadClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
func mockGetBadClientFunc(_ context.Context, _ GrpcEndpoint, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
return nil, fmt.Errorf("error")
}

Expand Down
59 changes: 48 additions & 11 deletions go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ package agent

import (
"context"
"crypto/x509"
"encoding/gob"
"fmt"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flytestdlib/config"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"

"google.golang.org/grpc/grpclog"

Expand All @@ -21,7 +25,7 @@ import (
"google.golang.org/grpc"
)

type GetClientFunc func(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)
type GetClientFunc func(ctx context.Context, endpoint GrpcEndpoint, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)

type Plugin struct {
metricScope promutils.Scope
Expand Down Expand Up @@ -72,7 +76,10 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

res, err := client.CreateTask(ctx, &admin.CreateTaskRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix})
newCtx, cancel := context.WithTimeout(ctx, getFinalTimeout("CreateTask", endpoint.Timeouts).Duration)
defer cancel()

res, err := client.CreateTask(newCtx, &admin.CreateTaskRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix})
if err != nil {
return nil, nil, err
}
Expand All @@ -94,7 +101,10 @@ func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest weba
return nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

res, err := client.GetTask(ctx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
newCtx, cancel := context.WithTimeout(ctx, getFinalTimeout("GetTask", endpoint.Timeouts).Duration)
defer cancel()

res, err := client.GetTask(newCtx, &admin.GetTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
if err != nil {
return nil, err
}
Expand All @@ -117,7 +127,10 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
return fmt.Errorf("failed to connect to agent with error: %v", err)
}

_, err = client.DeleteTask(ctx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
newCtx, cancel := context.WithTimeout(ctx, getFinalTimeout("DeleteTask", endpoint.Timeouts).Duration)
defer cancel()

_, err = client.DeleteTask(newCtx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta})
return err
}

Expand All @@ -144,28 +157,44 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase
return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State)
}

func getFinalEndpoint(taskType, defaultEndpoint string, endpointForTaskTypes map[string]string) string {
func getFinalEndpoint(taskType string, defaultEndpoint GrpcEndpoint, endpointForTaskTypes map[string]GrpcEndpoint) GrpcEndpoint {
if t, exists := endpointForTaskTypes[taskType]; exists {
return t
}

return defaultEndpoint
}

func getClientFunc(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
conn, ok := connectionCache[endpoint]
func getClientFunc(ctx context.Context, endpoint GrpcEndpoint, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
conn, ok := connectionCache[endpoint.Endpoint]
if ok {
return service.NewAsyncAgentServiceClient(conn), nil
}

var opts []grpc.DialOption
var err error

opts = append(opts, grpc.WithInsecure())
conn, err = grpc.Dial(endpoint, opts...)
if endpoint.Insecure {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
} else {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}

creds := credentials.NewClientTLSFromCert(pool, "")
opts = append(opts, grpc.WithTransportCredentials(creds))
}

if endpoint.DefaultServiceConfig != "" {
opts = append(opts, grpc.WithDefaultServiceConfig(endpoint.DefaultServiceConfig))
}

var err error
conn, err = grpc.Dial(endpoint.Endpoint, opts...)
if err != nil {
return nil, err
}
connectionCache[endpoint] = conn
connectionCache[endpoint.Endpoint] = conn
defer func() {
if err != nil {
if cerr := conn.Close(); cerr != nil {
Expand All @@ -183,6 +212,14 @@ func getClientFunc(ctx context.Context, endpoint string, connectionCache map[str
return service.NewAsyncAgentServiceClient(conn), nil
}

func getFinalTimeout(operation string, timeouts map[string]config.Duration) config.Duration {
if t, exists := timeouts[operation]; exists {
return t
}

return defaultTimeout
}

func newAgentPlugin() webapi.PluginEntry {
supportedTaskTypes := GetConfig().SupportedTaskTypes

Expand Down
30 changes: 23 additions & 7 deletions go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"testing"
"time"

"github.com/flyteorg/flytestdlib/config"

"google.golang.org/grpc"

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
Expand All @@ -25,8 +27,8 @@ func TestPlugin(t *testing.T) {
cfg := defaultConfig
cfg.WebAPI.Caching.Workers = 1
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second
cfg.DefaultGrpcEndpoint = "test-agent.flyte.svc.cluster.local:80"
cfg.EndpointForTaskTypes = map[string]string{"spark": "localhost:80"}
cfg.DefaultGrpcEndpoint = GrpcEndpoint{Endpoint: "test-agent.flyte.svc.cluster.local:80"}
cfg.EndpointForTaskTypes = map[string]GrpcEndpoint{"spark": {Endpoint: "localhost:80"}}
err := SetConfig(&cfg)
assert.NoError(t, err)
assert.Equal(t, cfg.WebAPI, plugin.GetConfig())
Expand All @@ -46,15 +48,29 @@ func TestPlugin(t *testing.T) {
})

t.Run("test getFinalEndpoint", func(t *testing.T) {
endpoint := getFinalEndpoint("spark", "localhost:8080", map[string]string{"spark": "localhost:80"})
assert.Equal(t, endpoint, "localhost:80")
endpoint = getFinalEndpoint("spark", "localhost:8080", map[string]string{})
assert.Equal(t, endpoint, "localhost:8080")
defaultGrpcEndpoint := GrpcEndpoint{Endpoint: "localhost:8080"}
endpoint := getFinalEndpoint("spark", defaultGrpcEndpoint, map[string]GrpcEndpoint{"spark": {Endpoint: "localhost:80"}})
assert.Equal(t, endpoint.Endpoint, "localhost:80")
endpoint = getFinalEndpoint("spark", defaultGrpcEndpoint, map[string]GrpcEndpoint{})
assert.Equal(t, endpoint.Endpoint, "localhost:8080")
})

t.Run("test getClientFunc", func(t *testing.T) {
client, err := getClientFunc(context.Background(), "localhost:80", map[string]*grpc.ClientConn{})
client, err := getClientFunc(context.Background(), GrpcEndpoint{Endpoint: "localhost:80"}, map[string]*grpc.ClientConn{})
assert.NoError(t, err)
assert.NotNil(t, client)
})

t.Run("test getClientFunc more config", func(t *testing.T) {
client, err := getClientFunc(context.Background(), GrpcEndpoint{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[string]*grpc.ClientConn{})
assert.NoError(t, err)
assert.NotNil(t, client)
})

t.Run("test getFinalTimeout", func(t *testing.T) {
timeout := getFinalTimeout("CreateTask", map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}})
assert.Equal(t, timeout.Duration, 1*time.Millisecond)
timeout = getFinalTimeout("DeleteTask", map[string]config.Duration{})
assert.Equal(t, timeout.Duration, 10*time.Second)
})
}

0 comments on commit 8b27579

Please sign in to comment.