From 9172b7e4d41d1c3797d499e60e56c7e616206965 Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Sat, 1 Jul 2023 13:20:18 -0700 Subject: [PATCH 1/2] Propagate request id on incoming and outgoing requests Signed-off-by: Haytham Abuelfutuh --- go.mod | 2 +- go.sum | 2 ++ pkg/server/service.go | 46 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 2a469a5c3..ec659c812 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/flyteorg/flyteidl v1.5.11 github.com/flyteorg/flyteplugins v1.0.67 github.com/flyteorg/flytepropeller v1.1.98 - github.com/flyteorg/flytestdlib v1.0.15 + github.com/flyteorg/flytestdlib v1.0.17 github.com/flyteorg/stow v0.3.6 github.com/ghodss/yaml v1.0.0 github.com/go-gormigrate/gormigrate/v2 v2.0.0 diff --git a/go.sum b/go.sum index 819fb160b..5a3a72f02 100644 --- a/go.sum +++ b/go.sum @@ -320,6 +320,8 @@ github.com/flyteorg/flytepropeller v1.1.98 h1:Zk2ENYB9VZRT5tFUIFjm+aCkr0TU2EuyJ5 github.com/flyteorg/flytepropeller v1.1.98/go.mod h1:R0CB6Uzp9F4YyvPmLRE7XyXxDebAPFD+LbHTf07mBzI= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= +github.com/flyteorg/flytestdlib v1.0.17 h1:O+xuCLy1/H/Va4vA1vv/hFG555rCfGNh10ld9yIYumU= +github.com/flyteorg/flytestdlib v1.0.17/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= github.com/flyteorg/stow v0.3.6/go.mod h1:5dfBitPM004dwaZdoVylVjxFT4GWAgI0ghAndhNUzCo= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= diff --git a/pkg/server/service.go b/pkg/server/service.go index f3b27416f..051efbfd2 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -4,6 +4,8 @@ import ( "context" "crypto/tls" "fmt" + "google.golang.org/grpc/metadata" + "k8s.io/apimachinery/pkg/util/rand" "net" "net/http" "strings" @@ -77,7 +79,8 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c scope promutils.Scope, opts ...grpc.ServerOption) (*grpc.Server, error) { logger.Infof(ctx, "Registering default middleware with blanket auth validation") - pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer(auth.BlanketAuthorization, auth.ExecutionUserIdentifierInterceptor)) + pluginRegistry.RegisterDefault(plugins.PluginIDUnaryServiceMiddleware, grpcmiddleware.ChainUnaryServer( + RequestIDInterceptor, auth.BlanketAuthorization, auth.ExecutionUserIdentifierInterceptor)) // Not yet implemented for streaming var chainedUnaryInterceptors grpc.UnaryServerInterceptor @@ -228,11 +231,50 @@ func newHTTPServer(ctx context.Context, cfg *config.ServerConfig, _ *authConfig. return nil, errors.Wrap(err, "error registering signal service") } - mux.Handle("/", gwmux) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + ctx := GetOrGenerateRequestIDForRequest(r) + gwmux.ServeHTTP(w, r.WithContext(ctx)) + }) return mux, nil } +// RequestIDInterceptor is a server interceptor that sets the request id on the context for any incoming calls. +func RequestIDInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { + return handler(GetOrGenerateRequestIDForGRPC(ctx), req) +} + +// GetOrGenerateRequestIDForGRPC returns a context with request id set from the context or from grpc metadata if it exists, +// otherwise it generates a new one. +func GetOrGenerateRequestIDForGRPC(ctx context.Context) context.Context { + if ctx.Value(contextutils.RequestIDKey) != nil { + return ctx + } else if md, exists := metadata.FromIncomingContext(ctx); exists && len(md.Get(contextutils.RequestIDKey.String())) > 0 { + return contextutils.WithRequestID(ctx, md.Get(contextutils.RequestIDKey.String())[0]) + } else { + return contextutils.WithRequestID(ctx, generateRequestID()) + } +} + +// GetOrGenerateRequestIDForRequest returns a context with request id set from the context or from metadata if it exists, +// otherwise it generates a new one. +func GetOrGenerateRequestIDForRequest(req *http.Request) context.Context { + ctx := req.Context() + if ctx.Value(contextutils.RequestIDKey) != nil { + return ctx + } else if md, exists := metadata.FromIncomingContext(ctx); exists && len(md.Get(contextutils.RequestIDKey.String())) > 0 { + return contextutils.WithRequestID(ctx, md.Get(contextutils.RequestIDKey.String())[0]) + } else if req.Header != nil && req.Header.Get(contextutils.RequestIDKey.String()) != "" { + return contextutils.WithRequestID(ctx, req.Header.Get(contextutils.RequestIDKey.String())) + } else { + return contextutils.WithRequestID(ctx, generateRequestID()) + } +} + +func generateRequestID() string { + return "a-" + rand.String(20) +} + func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, cfg *config.ServerConfig, authCfg *authConfig.Config, storageConfig *storage.Config, additionalHandlers map[string]func(http.ResponseWriter, *http.Request), scope promutils.Scope) error { From db7d6bdd6caa9fb002b300c76b0a6aa2788f16b2 Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Thu, 13 Jul 2023 12:49:46 -0700 Subject: [PATCH 2/2] goimports Signed-off-by: Haytham Abuelfutuh --- pkg/server/service.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/server/service.go b/pkg/server/service.go index 051efbfd2..4a7983087 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -4,13 +4,14 @@ import ( "context" "crypto/tls" "fmt" - "google.golang.org/grpc/metadata" - "k8s.io/apimachinery/pkg/util/rand" "net" "net/http" "strings" "time" + "google.golang.org/grpc/metadata" + "k8s.io/apimachinery/pkg/util/rand" + "github.com/flyteorg/flytestdlib/contextutils" "github.com/flyteorg/flytestdlib/promutils/labeled"