Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Propagate request id on incoming and outgoing requests
Browse files Browse the repository at this point in the history
Signed-off-by: Haytham Abuelfutuh <[email protected]>
  • Loading branch information
EngHabu committed Jul 1, 2023
1 parent 78b3e14 commit 9172b7e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ require (
github.com/flyteorg/flyteidl v1.5.11
github.com/flyteorg/flyteplugins v1.0.67
github.com/flyteorg/flytepropeller v1.1.98
github.com/flyteorg/flytestdlib v1.0.15
github.com/flyteorg/flytestdlib v1.0.17
github.com/flyteorg/stow v0.3.6
github.com/ghodss/yaml v1.0.0
github.com/go-gormigrate/gormigrate/v2 v2.0.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ github.com/flyteorg/flytepropeller v1.1.98 h1:Zk2ENYB9VZRT5tFUIFjm+aCkr0TU2EuyJ5
github.com/flyteorg/flytepropeller v1.1.98/go.mod h1:R0CB6Uzp9F4YyvPmLRE7XyXxDebAPFD+LbHTf07mBzI=
github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0=
github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s=
github.com/flyteorg/flytestdlib v1.0.17 h1:O+xuCLy1/H/Va4vA1vv/hFG555rCfGNh10ld9yIYumU=
github.com/flyteorg/flytestdlib v1.0.17/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s=
github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk=
github.com/flyteorg/stow v0.3.6/go.mod h1:5dfBitPM004dwaZdoVylVjxFT4GWAgI0ghAndhNUzCo=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
Expand Down
46 changes: 44 additions & 2 deletions pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"crypto/tls"
"fmt"
"google.golang.org/grpc/metadata"
"k8s.io/apimachinery/pkg/util/rand"
"net"
"net/http"
"strings"
Expand Down Expand Up @@ -77,7 +79,8 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
scope promutils.Scope, opts ...grpc.ServerOption) (*grpc.Server, error) {

logger.Infof(ctx, "Registering default middleware with blanket auth validation")
pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(auth.BlanketAuthorization, auth.ExecutionUserIdentifierInterceptor))
pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(
RequestIDInterceptor, auth.BlanketAuthorization, auth.ExecutionUserIdentifierInterceptor))

// Not yet implemented for streaming
var chainedUnaryInterceptors grpc.UnaryServerInterceptor
Expand Down Expand Up @@ -228,11 +231,50 @@ func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig.
return nil, errors.Wrap(err, "error registering signal service")
}

mux.Handle("/", gwmux)
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
ctx := GetOrGenerateRequestIDForRequest(r)
gwmux.ServeHTTP(w, r.WithContext(ctx))
})

return mux, nil
}

// RequestIDInterceptor is a server interceptor that sets the request id on the context for any incoming calls.
func RequestIDInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return handler(GetOrGenerateRequestIDForGRPC(ctx), req)
}

// GetOrGenerateRequestIDForGRPC returns a context with request id set from the context or from grpc metadata if it exists,
// otherwise it generates a new one.
func GetOrGenerateRequestIDForGRPC(ctx context.Context) context.Context {
if ctx.Value(contextutils.RequestIDKey) != nil {
return ctx
} else if md, exists := metadata.FromIncomingContext(ctx); exists && len(md.Get(contextutils.RequestIDKey.String())) > 0 {
return contextutils.WithRequestID(ctx, md.Get(contextutils.RequestIDKey.String())[0])
} else {
return contextutils.WithRequestID(ctx, generateRequestID())
}
}

// GetOrGenerateRequestIDForRequest returns a context with request id set from the context or from metadata if it exists,
// otherwise it generates a new one.
func GetOrGenerateRequestIDForRequest(req *http.Request) context.Context {
ctx := req.Context()
if ctx.Value(contextutils.RequestIDKey) != nil {
return ctx
} else if md, exists := metadata.FromIncomingContext(ctx); exists && len(md.Get(contextutils.RequestIDKey.String())) > 0 {
return contextutils.WithRequestID(ctx, md.Get(contextutils.RequestIDKey.String())[0])
} else if req.Header != nil && req.Header.Get(contextutils.RequestIDKey.String()) != "" {
return contextutils.WithRequestID(ctx, req.Header.Get(contextutils.RequestIDKey.String()))
} else {
return contextutils.WithRequestID(ctx, generateRequestID())
}
}

func generateRequestID() string {
return "a-" + rand.String(20)
}

func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig,
authCfg *authConfig.Config, storageConfig *storage.Config,
additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error {
Expand Down

0 comments on commit 9172b7e

Please sign in to comment.