Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Per task type grpc endpoint config
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Jul 27, 2023
1 parent fed2449 commit a95e504
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 33 deletions.
6 changes: 5 additions & 1 deletion go/tasks/plugins/webapi/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,14 @@ type Config struct {
// ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time
ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."`

// DefaultGrpcEndpoint is the gRPC endpoint used when no task-type-specific endpoint matches the task type.
DefaultGrpcEndpoint GrpcEndpoint `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of agent service."`

// GrpcEndpoints holds the gRPC endpoints of agent services, keyed by endpoint ID; EndpointForTaskTypes selects among them by ID.
GrpcEndpoints map[string]*GrpcEndpoint `json:"grpcEndpoints" pflag:",The grpc endpoints of agent services."`

// EndpointForTaskTypes maps a task type to the ID of its endpoint in GrpcEndpoints. {TaskType: EndpointID}
EndpointForTaskTypes map[string]GrpcEndpoint `json:"endpointForTaskTypes" pflag:"-,"`
EndpointForTaskTypes map[string]string `json:"endpointForTaskTypes" pflag:"-,"`

// SupportedTaskTypes is a list of task types that are supported by this plugin.
SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."`
Expand Down
8 changes: 8 additions & 0 deletions go/tasks/plugins/webapi/agent/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ func TestGetAndSetConfig(t *testing.T) {
},
}
cfg.DefaultGrpcEndpoint.DefaultTimeout = config.Duration{Duration: 10 * time.Second}
cfg.GrpcEndpoints = map[string]*GrpcEndpoint{
"endpoint_1": {
Insecure: cfg.DefaultGrpcEndpoint.Insecure,
DefaultServiceConfig: cfg.DefaultGrpcEndpoint.DefaultServiceConfig,
Timeouts: cfg.DefaultGrpcEndpoint.Timeouts,
},
}
cfg.EndpointForTaskTypes = map[string]string{"task_type_1": "endpoint_1"}
err := SetConfig(&cfg)
assert.NoError(t, err)
assert.Equal(t, &cfg, GetConfig())
Expand Down
4 changes: 2 additions & 2 deletions go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ func (m *MockClient) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _
return &admin.DeleteTaskResponse{}, nil
}

// mockGetClientFunc has the same shape as GetClientFunc but ignores every
// argument: it always returns a new MockClient and a nil error, so no
// connection is made.
func mockGetClientFunc(_ context.Context, _ GrpcEndpoint, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
func mockGetClientFunc(_ context.Context, _ *GrpcEndpoint, _ map[*GrpcEndpoint]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
return &MockClient{}, nil
}

// mockGetBadClientFunc has the same shape as GetClientFunc but ignores every
// argument: it always returns a nil client together with the error "error".
func mockGetBadClientFunc(_ context.Context, _ GrpcEndpoint, _ map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
func mockGetBadClientFunc(_ context.Context, _ *GrpcEndpoint, _ map[*GrpcEndpoint]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
return nil, fmt.Errorf("error")
}

Expand Down
40 changes: 26 additions & 14 deletions go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ import (
"google.golang.org/grpc"
)

// GetClientFunc returns an AsyncAgentServiceClient for the given endpoint.
// connectionCache is supplied by the caller so an implementation can look up
// and store connections in it (getClientFunc does both).
type GetClientFunc func(ctx context.Context, endpoint GrpcEndpoint, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)
type GetClientFunc func(ctx context.Context, endpoint *GrpcEndpoint, connectionCache map[*GrpcEndpoint]*grpc.ClientConn) (service.AsyncAgentServiceClient, error)

// Plugin is the agent web API plugin.
//
// cfg is consulted by Create, Get and Delete to resolve the endpoint for a
// task type. getClient is then called with that endpoint and connectionCache
// to obtain the agent service client; it is a field so tests can substitute a
// mock. connectionCache is created empty in newAgentPlugin and is filled by
// getClientFunc as connections are established.
type Plugin struct {
metricScope promutils.Scope
cfg *Config
getClient GetClientFunc
connectionCache map[string]*grpc.ClientConn
connectionCache map[*GrpcEndpoint]*grpc.ClientConn
}

type ResourceWrapper struct {
Expand Down Expand Up @@ -70,7 +70,10 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR

outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

endpoint := getFinalEndpoint(taskTemplate.Type, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
endpoint, err := getFinalEndpoint(taskTemplate.Type, p.cfg)
if err != nil {
return nil, nil, fmt.Errorf("failed to find agent endpoint with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err)
Expand All @@ -96,7 +99,10 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
metadata := taskCtx.ResourceMeta().(*ResourceMetaWrapper)

endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg)
if err != nil {
return nil, fmt.Errorf("failed to find agent endpoint with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return nil, fmt.Errorf("failed to connect to agent with error: %v", err)
Expand All @@ -122,7 +128,10 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
}
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)

endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
endpoint, err := getFinalEndpoint(metadata.TaskType, p.cfg)
if err != nil {
return fmt.Errorf("failed to find agent endpoint with error: %v", err)
}
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return fmt.Errorf("failed to connect to agent with error: %v", err)
Expand Down Expand Up @@ -158,16 +167,19 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase
return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.State)
}

func getFinalEndpoint(taskType string, defaultEndpoint GrpcEndpoint, endpointForTaskTypes map[string]GrpcEndpoint) GrpcEndpoint {
if t, exists := endpointForTaskTypes[taskType]; exists {
return t
func getFinalEndpoint(taskType string, cfg *Config) (*GrpcEndpoint, error) {
if id, exists := cfg.EndpointForTaskTypes[taskType]; exists {
if endpoint, exists := cfg.GrpcEndpoints[id]; exists {
return endpoint, nil
}
return nil, fmt.Errorf("no endpoint definition found for ID %s that matches task type %s", id, taskType)
}

return defaultEndpoint
return &cfg.DefaultGrpcEndpoint, nil
}

func getClientFunc(ctx context.Context, endpoint GrpcEndpoint, connectionCache map[string]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
conn, ok := connectionCache[endpoint.Endpoint]
func getClientFunc(ctx context.Context, endpoint *GrpcEndpoint, connectionCache map[*GrpcEndpoint]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) {
conn, ok := connectionCache[endpoint]
if ok {
return service.NewAsyncAgentServiceClient(conn), nil
}
Expand Down Expand Up @@ -195,7 +207,7 @@ func getClientFunc(ctx context.Context, endpoint GrpcEndpoint, connectionCache m
if err != nil {
return nil, err
}
connectionCache[endpoint.Endpoint] = conn
connectionCache[endpoint] = conn
defer func() {
if err != nil {
if cerr := conn.Close(); cerr != nil {
Expand Down Expand Up @@ -233,7 +245,7 @@ func getFinalTimeout(operation string, endpoint *GrpcEndpoint) config.Duration {
return endpoint.DefaultTimeout
}

func getFinalContext(ctx context.Context, operation string, endpoint GrpcEndpoint) (context.Context, context.CancelFunc) {
func getFinalContext(ctx context.Context, operation string, endpoint *GrpcEndpoint) (context.Context, context.CancelFunc) {
timeout := getFinalTimeout(operation, endpoint).Duration
if timeout == 0 {
return ctx, func() {}
Expand All @@ -252,7 +264,7 @@ func newAgentPlugin() webapi.PluginEntry {
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
getClient: getClientFunc,
connectionCache: make(map[string]*grpc.ClientConn),
connectionCache: make(map[*GrpcEndpoint]*grpc.ClientConn),
}, nil
},
}
Expand Down
50 changes: 34 additions & 16 deletions go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,18 @@ func TestPlugin(t *testing.T) {
fakeSetupContext := pluginCoreMocks.SetupContext{}
fakeSetupContext.OnMetricsScope().Return(promutils.NewScope("test"))

cfg := defaultConfig
cfg.WebAPI.Caching.Workers = 1
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second
cfg.DefaultGrpcEndpoint = GrpcEndpoint{Endpoint: "test-agent.flyte.svc.cluster.local:80"}
cfg.GrpcEndpoints = map[string]*GrpcEndpoint{"spark_agent": {Endpoint: "localhost:80"}}
cfg.EndpointForTaskTypes = map[string]string{"spark": "spark_agent", "bar": "bar_agent"}

plugin := Plugin{
metricScope: fakeSetupContext.MetricsScope(),
cfg: GetConfig(),
}
t.Run("get config", func(t *testing.T) {
cfg := defaultConfig
cfg.WebAPI.Caching.Workers = 1
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second
cfg.DefaultGrpcEndpoint = GrpcEndpoint{Endpoint: "test-agent.flyte.svc.cluster.local:80"}
cfg.EndpointForTaskTypes = map[string]GrpcEndpoint{"spark": {Endpoint: "localhost:80"}}
err := SetConfig(&cfg)
assert.NoError(t, err)
assert.Equal(t, cfg.WebAPI, plugin.GetConfig())
Expand All @@ -48,37 +50,53 @@ func TestPlugin(t *testing.T) {
})

t.Run("test getFinalEndpoint", func(t *testing.T) {
defaultGrpcEndpoint := GrpcEndpoint{Endpoint: "localhost:8080"}
endpoint := getFinalEndpoint("spark", defaultGrpcEndpoint, map[string]GrpcEndpoint{"spark": {Endpoint: "localhost:80"}})
assert.Equal(t, "localhost:80", endpoint.Endpoint)
endpoint = getFinalEndpoint("spark", defaultGrpcEndpoint, map[string]GrpcEndpoint{})
assert.Equal(t, "localhost:8080", endpoint.Endpoint)
endpoint, _ := getFinalEndpoint("spark", &cfg)
assert.Equal(t, cfg.GrpcEndpoints["spark_agent"].Endpoint, endpoint.Endpoint)
endpoint, _ = getFinalEndpoint("foo", &cfg)
assert.Equal(t, cfg.DefaultGrpcEndpoint.Endpoint, endpoint.Endpoint)
_, err := getFinalEndpoint("bar", &cfg)
assert.NotNil(t, err)
})

t.Run("test getClientFunc", func(t *testing.T) {
client, err := getClientFunc(context.Background(), GrpcEndpoint{Endpoint: "localhost:80"}, map[string]*grpc.ClientConn{})
client, err := getClientFunc(context.Background(), &GrpcEndpoint{Endpoint: "localhost:80"}, map[*GrpcEndpoint]*grpc.ClientConn{})
assert.NoError(t, err)
assert.NotNil(t, client)
})

t.Run("test getClientFunc more config", func(t *testing.T) {
client, err := getClientFunc(context.Background(), GrpcEndpoint{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[string]*grpc.ClientConn{})
client, err := getClientFunc(context.Background(), &GrpcEndpoint{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[*GrpcEndpoint]*grpc.ClientConn{})
assert.NoError(t, err)
assert.NotNil(t, client)
})

t.Run("test getClientFunc cache hit", func(t *testing.T) {
connectionCache := make(map[*GrpcEndpoint]*grpc.ClientConn)
endpoint := &GrpcEndpoint{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}

client, err := getClientFunc(context.Background(), endpoint, connectionCache)
assert.NoError(t, err)
assert.NotNil(t, client)
assert.NotNil(t, client, connectionCache[endpoint])

cachedClient, err := getClientFunc(context.Background(), endpoint, connectionCache)
assert.NoError(t, err)
assert.NotNil(t, cachedClient)
assert.Equal(t, client, cachedClient)
})

t.Run("test getFinalTimeout", func(t *testing.T) {
timeout := getFinalTimeout("CreateTask", GrpcEndpoint{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}})
timeout := getFinalTimeout("CreateTask", &GrpcEndpoint{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}})
assert.Equal(t, 1*time.Millisecond, timeout.Duration)
timeout = getFinalTimeout("DeleteTask", GrpcEndpoint{Endpoint: "localhost:8080", DefaultTimeout: config.Duration{Duration: 10 * time.Second}})
timeout = getFinalTimeout("DeleteTask", &GrpcEndpoint{Endpoint: "localhost:8080", DefaultTimeout: config.Duration{Duration: 10 * time.Second}})
assert.Equal(t, 10*time.Second, timeout.Duration)
})

t.Run("test getFinalContext", func(t *testing.T) {
ctx, _ := getFinalContext(context.TODO(), "DeleteTask", GrpcEndpoint{})
ctx, _ := getFinalContext(context.TODO(), "DeleteTask", &GrpcEndpoint{})
assert.Equal(t, context.TODO(), ctx)

ctx, _ = getFinalContext(context.TODO(), "CreateTask", GrpcEndpoint{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}})
ctx, _ = getFinalContext(context.TODO(), "CreateTask", &GrpcEndpoint{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}})
assert.NotEqual(t, context.TODO(), ctx)
})
}

0 comments on commit a95e504

Please sign in to comment.