Skip to content

Commit

Permalink
feat: agent propagation (#3654)
Browse files Browse the repository at this point in the history
* add otelgrpc to agent and update grpc server to accept context

* make context propagation work for trigger request

* return context in test method

* test that trigger request/response share the same traceID

* feat: add trace propagation to the rest of the workers
  • Loading branch information
mathnogueira authored Feb 16, 2024
1 parent 0b1c545 commit 6cb3f45
Show file tree
Hide file tree
Showing 22 changed files with 288 additions and 110 deletions.
8 changes: 8 additions & 0 deletions agent/client/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net"
"time"

"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/otel/propagation"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -73,6 +75,12 @@ func (c *Client) connect(ctx context.Context) error {
grpc.WithTransportCredentials(transportCredentials),
grpc.WithDefaultServiceConfig(retryPolicy),
grpc.WithIdleTimeout(0), // disable grpc idle timeout
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
grpc.WithUnaryInterceptor(otelgrpc.UnaryClientInterceptor(
otelgrpc.WithPropagators(
propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}),
),
)),
)
if err != nil {
return fmt.Errorf("could not connect to server: %w", err)
Expand Down
118 changes: 79 additions & 39 deletions agent/client/mocks/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,41 @@ import (
"github.com/avast/retry-go"
"github.com/kubeshop/tracetest/agent/client"
"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/otel/propagation"
"google.golang.org/grpc"
)

type GrpcServerMock struct {
proto.UnimplementedOrchestratorServer
port int
triggerChannel chan *proto.TriggerRequest
pollingChannel chan *proto.PollingRequest
otlpConnectionTestChannel chan *proto.OTLPConnectionTestRequest
terminationChannel chan *proto.ShutdownRequest
dataStoreTestChannel chan *proto.DataStoreConnectionTestRequest
triggerChannel chan Message[*proto.TriggerRequest]
pollingChannel chan Message[*proto.PollingRequest]
otlpConnectionTestChannel chan Message[*proto.OTLPConnectionTestRequest]
terminationChannel chan Message[*proto.ShutdownRequest]
dataStoreTestChannel chan Message[*proto.DataStoreConnectionTestRequest]

lastTriggerResponse *proto.TriggerResponse
lastPollingResponse *proto.PollingResponse
lastOtlpConnectionResponse *proto.OTLPConnectionTestResponse
lastDataStoreConnectionResponse *proto.DataStoreConnectionTestResponse
lastTriggerResponse Message[*proto.TriggerResponse]
lastPollingResponse Message[*proto.PollingResponse]
lastOtlpConnectionResponse Message[*proto.OTLPConnectionTestResponse]
lastDataStoreConnectionResponse Message[*proto.DataStoreConnectionTestResponse]

server *grpc.Server
}

type Message[T any] struct {
Context context.Context
Data T
}

func NewGrpcServer() *GrpcServerMock {
server := &GrpcServerMock{
triggerChannel: make(chan *proto.TriggerRequest),
pollingChannel: make(chan *proto.PollingRequest),
terminationChannel: make(chan *proto.ShutdownRequest),
dataStoreTestChannel: make(chan *proto.DataStoreConnectionTestRequest),
otlpConnectionTestChannel: make(chan *proto.OTLPConnectionTestRequest),
triggerChannel: make(chan Message[*proto.TriggerRequest]),
pollingChannel: make(chan Message[*proto.PollingRequest]),
terminationChannel: make(chan Message[*proto.ShutdownRequest]),
dataStoreTestChannel: make(chan Message[*proto.DataStoreConnectionTestRequest]),
otlpConnectionTestChannel: make(chan Message[*proto.OTLPConnectionTestRequest]),
}
var wg sync.WaitGroup
wg.Add(1)
Expand Down Expand Up @@ -65,7 +73,13 @@ func (s *GrpcServerMock) start(wg *sync.WaitGroup, port int) error {

s.port = lis.Addr().(*net.TCPAddr).Port

server := grpc.NewServer()
server := grpc.NewServer(
grpc.UnaryInterceptor(otelgrpc.UnaryServerInterceptor(
otelgrpc.WithPropagators(
propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}),
),
)),
)
proto.RegisterOrchestratorServer(server, s)

s.server = server
Expand Down Expand Up @@ -107,7 +121,12 @@ func (s *GrpcServerMock) RegisterTriggerAgent(id *proto.AgentIdentification, str

for {
triggerRequest := <-s.triggerChannel
err := stream.Send(triggerRequest)
err := telemetry.InjectContextIntoStream(triggerRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(triggerRequest.Data)
if err != nil {
log.Println("could not send trigger request to agent: %w", err)
}
Expand All @@ -120,7 +139,7 @@ func (s *GrpcServerMock) SendTriggerResult(ctx context.Context, result *proto.Tr
return nil, fmt.Errorf("could not validate token")
}

s.lastTriggerResponse = result
s.lastTriggerResponse = Message[*proto.TriggerResponse]{Data: result, Context: ctx}
return &proto.Empty{}, nil
}

Expand All @@ -131,7 +150,12 @@ func (s *GrpcServerMock) RegisterPollerAgent(id *proto.AgentIdentification, stre

for {
pollerRequest := <-s.pollingChannel
err := stream.Send(pollerRequest)
err := telemetry.InjectContextIntoStream(pollerRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(pollerRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand All @@ -145,7 +169,12 @@ func (s *GrpcServerMock) RegisterDataStoreConnectionTestAgent(id *proto.AgentIde

for {
dsTestRequest := <-s.dataStoreTestChannel
err := stream.Send(dsTestRequest)
err := telemetry.InjectContextIntoStream(dsTestRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(dsTestRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand All @@ -159,7 +188,12 @@ func (s *GrpcServerMock) RegisterOTLPConnectionTestListener(id *proto.AgentIdent

for {
testRequest := <-s.otlpConnectionTestChannel
err := stream.Send(testRequest)
err := telemetry.InjectContextIntoStream(testRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(testRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand All @@ -171,7 +205,7 @@ func (s *GrpcServerMock) SendOTLPConnectionTestResult(ctx context.Context, resul
return nil, fmt.Errorf("could not validate token")
}

s.lastOtlpConnectionResponse = result
s.lastOtlpConnectionResponse = Message[*proto.OTLPConnectionTestResponse]{Data: result, Context: ctx}
return &proto.Empty{}, nil
}

Expand All @@ -180,7 +214,7 @@ func (s *GrpcServerMock) SendDataStoreConnectionTestResult(ctx context.Context,
return nil, fmt.Errorf("could not validate token")
}

s.lastDataStoreConnectionResponse = result
s.lastDataStoreConnectionResponse = Message[*proto.DataStoreConnectionTestResponse]{Data: result, Context: ctx}
return &proto.Empty{}, nil
}

Expand All @@ -189,14 +223,19 @@ func (s *GrpcServerMock) SendPolledSpans(ctx context.Context, result *proto.Poll
return nil, fmt.Errorf("could not validate token")
}

s.lastPollingResponse = result
s.lastPollingResponse = Message[*proto.PollingResponse]{Data: result, Context: ctx}
return &proto.Empty{}, nil
}

func (s *GrpcServerMock) RegisterShutdownListener(_ *proto.AgentIdentification, stream proto.Orchestrator_RegisterShutdownListenerServer) error {
for {
shutdownRequest := <-s.terminationChannel
err := stream.Send(shutdownRequest)
err := telemetry.InjectContextIntoStream(shutdownRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(shutdownRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand All @@ -205,41 +244,42 @@ func (s *GrpcServerMock) RegisterShutdownListener(_ *proto.AgentIdentification,

// Test methods

func (s *GrpcServerMock) SendTriggerRequest(request *proto.TriggerRequest) {
s.triggerChannel <- request
func (s *GrpcServerMock) SendTriggerRequest(ctx context.Context, request *proto.TriggerRequest) {
s.triggerChannel <- Message[*proto.TriggerRequest]{Context: ctx, Data: request}
}

func (s *GrpcServerMock) SendPollingRequest(request *proto.PollingRequest) {
s.pollingChannel <- request
func (s *GrpcServerMock) SendPollingRequest(ctx context.Context, request *proto.PollingRequest) {
s.pollingChannel <- Message[*proto.PollingRequest]{Context: ctx, Data: request}
}

func (s *GrpcServerMock) SendDataStoreConnectionTestRequest(request *proto.DataStoreConnectionTestRequest) {
s.dataStoreTestChannel <- request
func (s *GrpcServerMock) SendDataStoreConnectionTestRequest(ctx context.Context, request *proto.DataStoreConnectionTestRequest) {
s.dataStoreTestChannel <- Message[*proto.DataStoreConnectionTestRequest]{Context: ctx, Data: request}
}

func (s *GrpcServerMock) SendOTLPConnectionTestRequest(request *proto.OTLPConnectionTestRequest) {
s.otlpConnectionTestChannel <- request
func (s *GrpcServerMock) SendOTLPConnectionTestRequest(ctx context.Context, request *proto.OTLPConnectionTestRequest) {
s.otlpConnectionTestChannel <- Message[*proto.OTLPConnectionTestRequest]{Context: ctx, Data: request}
}

func (s *GrpcServerMock) GetLastTriggerResponse() *proto.TriggerResponse {
func (s *GrpcServerMock) GetLastTriggerResponse() Message[*proto.TriggerResponse] {
return s.lastTriggerResponse
}

func (s *GrpcServerMock) GetLastPollingResponse() *proto.PollingResponse {
func (s *GrpcServerMock) GetLastPollingResponse() Message[*proto.PollingResponse] {
return s.lastPollingResponse
}

func (s *GrpcServerMock) GetLastOTLPConnectionResponse() *proto.OTLPConnectionTestResponse {
func (s *GrpcServerMock) GetLastOTLPConnectionResponse() Message[*proto.OTLPConnectionTestResponse] {
return s.lastOtlpConnectionResponse
}

func (s *GrpcServerMock) GetLastDataStoreConnectionResponse() *proto.DataStoreConnectionTestResponse {
func (s *GrpcServerMock) GetLastDataStoreConnectionResponse() Message[*proto.DataStoreConnectionTestResponse] {
return s.lastDataStoreConnectionResponse
}

func (s *GrpcServerMock) TerminateConnection(reason string) {
s.terminationChannel <- &proto.ShutdownRequest{
Reason: reason,
func (s *GrpcServerMock) TerminateConnection(ctx context.Context, reason string) {
s.terminationChannel <- Message[*proto.ShutdownRequest]{
Context: ctx,
Data: &proto.ShutdownRequest{Reason: reason},
}
}

Expand Down
9 changes: 7 additions & 2 deletions agent/client/workflow_listen_for_ds_connection_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
)

func (c *Client) startDataStoreConnectionTestListener(ctx context.Context) error {
Expand Down Expand Up @@ -36,8 +37,12 @@ func (c *Client) startDataStoreConnectionTestListener(ctx context.Context) error
continue
}

// TODO: Get ctx from request
err = c.dataStoreConnectionListener(context.Background(), &req)
ctx, err := telemetry.ExtractContextFromStream(stream)
if err != nil {
log.Println("could not extract context from stream %w", err)
}

err = c.dataStoreConnectionListener(ctx, &req)
if err != nil {
fmt.Println(err.Error())
}
Expand Down
7 changes: 4 additions & 3 deletions agent/client/workflow_listen_for_ds_connection_tests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import (
)

func TestDataStoreConnectionTestWorkflow(t *testing.T) {
ctx := context.Background()
server := mocks.NewGrpcServer()
defer server.Stop()

client, err := client.Connect(context.Background(), server.Addr())
client, err := client.Connect(ctx, server.Addr())
require.NoError(t, err)

var receivedConnectionTestRequest *proto.DataStoreConnectionTestRequest
Expand All @@ -25,14 +26,14 @@ func TestDataStoreConnectionTestWorkflow(t *testing.T) {
return nil
})

err = client.Start(context.Background())
err = client.Start(ctx)
require.NoError(t, err)

connectionTestRequest := &proto.DataStoreConnectionTestRequest{
RequestID: "request-id",
}

server.SendDataStoreConnectionTestRequest(connectionTestRequest)
server.SendDataStoreConnectionTestRequest(ctx, connectionTestRequest)

// ensures there's enough time for networking between server and client
time.Sleep(1 * time.Second)
Expand Down
9 changes: 7 additions & 2 deletions agent/client/workflow_listen_for_otlp_connection_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
)

func (c *Client) startOTLPConnectionTestListener(ctx context.Context) error {
Expand Down Expand Up @@ -36,8 +37,12 @@ func (c *Client) startOTLPConnectionTestListener(ctx context.Context) error {
continue
}

// TODO: Get ctx from request
err = c.otlpConnectionTestListener(context.Background(), &req)
ctx, err := telemetry.ExtractContextFromStream(stream)
if err != nil {
log.Println("could not extract context from stream %w", err)
}

err = c.otlpConnectionTestListener(ctx, &req)
if err != nil {
fmt.Println(err.Error())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import (
)

func TestOtlpConnectionTestWorkflow(t *testing.T) {
ctx := context.Background()
server := mocks.NewGrpcServer()
defer server.Stop()

client, err := client.Connect(context.Background(), server.Addr())
client, err := client.Connect(ctx, server.Addr())
require.NoError(t, err)

var receivedConnectionTestRequest *proto.OTLPConnectionTestRequest
Expand All @@ -25,14 +26,14 @@ func TestOtlpConnectionTestWorkflow(t *testing.T) {
return nil
})

err = client.Start(context.Background())
err = client.Start(ctx)
require.NoError(t, err)

connectionTestRequest := &proto.OTLPConnectionTestRequest{
RequestID: "request-id",
}

server.SendOTLPConnectionTestRequest(connectionTestRequest)
server.SendOTLPConnectionTestRequest(ctx, connectionTestRequest)

// ensures there's enough time for networking between server and client
time.Sleep(1 * time.Second)
Expand Down
9 changes: 7 additions & 2 deletions agent/client/workflow_listen_for_poll_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
)

func (c *Client) startPollerListener(ctx context.Context) error {
Expand Down Expand Up @@ -36,8 +37,12 @@ func (c *Client) startPollerListener(ctx context.Context) error {
continue
}

// TODO: Get ctx from request
err = c.pollListener(context.Background(), &resp)
ctx, err := telemetry.ExtractContextFromStream(stream)
if err != nil {
log.Println("could not extract context from stream %w", err)
}

err = c.pollListener(ctx, &resp)
if err != nil {
fmt.Println(err.Error())
}
Expand Down
Loading

0 comments on commit 6cb3f45

Please sign in to comment.