From 321b7804d68ca2c5dd48f092df6a4984c7384977 Mon Sep 17 00:00:00 2001 From: Ananth Bhaskararaman Date: Mon, 9 Oct 2023 03:48:32 +0530 Subject: [PATCH] Don't really need to make header configurable. --- cmd/hallpass/main.go | 2 +- internal/middleware/tlsident.go | 6 +++--- internal/middleware/tlsident_test.go | 10 ++++------ pkg/asgard/heimdallr.go | 8 ++++---- pkg/asgard/heimdallr_test.go | 7 +++---- 5 files changed, 15 insertions(+), 18 deletions(-) diff --git a/cmd/hallpass/main.go b/cmd/hallpass/main.go index 56c214c..f710317 100644 --- a/cmd/hallpass/main.go +++ b/cmd/hallpass/main.go @@ -72,7 +72,7 @@ func main() { defer ssllog.Close() } - ti := middleware.TLSIdentifier(middleware.RequestContextHeaderName, cert.Namespace) + ti := middleware.TLSIdentifier(cert.Namespace) hdlr := sundry.RequestLogHandler(ti(reverseProxy)) addr := fmt.Sprintf("%s:%d", config.HallPass.Host, config.HallPass.Port) diff --git a/internal/middleware/tlsident.go b/internal/middleware/tlsident.go index dcfe515..2b40895 100644 --- a/internal/middleware/tlsident.go +++ b/internal/middleware/tlsident.go @@ -16,8 +16,8 @@ import ( // TLSIdentifier returns a HTTP Handler middleware function that identifies clients using // TLS client certificates. // It parses the client certficiate into a RequestContext which is -// JSON-serialised into the headerName header. -func TLSIdentifier(headerName string, namespace uuid.UUID) func(http.Handler) http.Handler { +// JSON-serialised into the request context header. +func TLSIdentifier(namespace uuid.UUID) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { @@ -61,7 +61,7 @@ func TLSIdentifier(headerName string, namespace uuid.UUID) func(http.Handler) ht http.Error(w, "unexpected error", http.StatusInternalServerError) return } - r.Header.Set(headerName, string(rctxHeader)) + r.Header.Set(RequestContextHeaderName, string(rctxHeader)) next.ServeHTTP(w, r) }) } diff --git a/internal/middleware/tlsident_test.go b/internal/middleware/tlsident_test.go index b762df7..b00a68f 100644 --- a/internal/middleware/tlsident_test.go +++ b/internal/middleware/tlsident_test.go @@ -25,8 +25,6 @@ import ( "github.com/google/uuid" ) -const testHeader = "rctx-test" - func TestCertAuthorizerNoTLS(t *testing.T) { defer func() { if r := recover(); r == nil { @@ -38,7 +36,7 @@ func TestCertAuthorizerNoTLS(t *testing.T) { defer backendServer.Close() backendUrl, _ := url.Parse(backendServer.URL) - ti := TLSIdentifier(testHeader, uuid.Nil)(httputil.NewSingleHostReverseProxy(backendUrl)) + ti := TLSIdentifier(uuid.Nil)(httputil.NewSingleHostReverseProxy(backendUrl)) rr := httptest.NewRecorder() request := httptest.NewRequest(http.MethodGet, "/", nil) ti.ServeHTTP(rr, request) @@ -85,9 +83,9 @@ func TestHofund(t *testing.T) { // backend server handler checks if request has expected header backendServer := httptest.NewServer( http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { - rctxVal := r.Header.Get(testHeader) + rctxVal := r.Header.Get(RequestContextHeaderName) if rctxVal == "" { - t.Errorf("expected %s header in request", testHeader) + t.Errorf("expected %s header in request", RequestContextHeaderName) } var rctx AuthorizedRequestContext if err := json.Unmarshal([]byte(rctxVal), &rctx); err != nil { @@ -111,7 +109,7 @@ func TestHofund(t *testing.T) { t.Errorf("error parsing backedn url %s", err) } - ti := TLSIdentifier(testHeader, ns)(httputil.NewSingleHostReverseProxy(backendUrl)) + ti := TLSIdentifier(ns)(httputil.NewSingleHostReverseProxy(backendUrl)) // TLS server accepts client requests requiring TLS client cert auth server := httptest.NewUnstartedServer(ti) diff --git a/pkg/asgard/heimdallr.go b/pkg/asgard/heimdallr.go index adcf7f0..d7b5e33 100644 --- a/pkg/asgard/heimdallr.go +++ b/pkg/asgard/heimdallr.go @@ -65,15 +65,15 @@ func MustFromContext(ctx context.Context) *Identity { } // Heimdallr returns a HTTP Handler middleware function that parses an AuthorizedRequestContext -// from headerName. If namespace does not match the parsed one, the request is forbidden. -// The AuthorizedRequestContext is stored in the request context. +// from the request context header. If namespace does not match the parsed one, the +// request is forbidden. The AuthorizedRequestContext is stored in the request context. // // If Heimdallr is used in an AWS Lambda Web Adapter powered API server, Bouncer Lambda Authorizer // must be configured as an authorizer for the API Gateway method. -func Heimdallr(headerName string, namespace uuid.UUID) func(http.Handler) http.Handler { +func Heimdallr(namespace uuid.UUID) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - hdr := r.Header.Get(headerName) + hdr := r.Header.Get(middleware.RequestContextHeaderName) if hdr == "" { http.Error(w, middleware.ServiceUnavailableMsg, http.StatusServiceUnavailable) return diff --git a/pkg/asgard/heimdallr_test.go b/pkg/asgard/heimdallr_test.go index 3d1d2a9..2553db2 100644 --- a/pkg/asgard/heimdallr_test.go +++ b/pkg/asgard/heimdallr_test.go @@ -13,11 +13,10 @@ import ( "net/http/httptest" "testing" + "github.com/RealImage/bifrost/internal/middleware" "github.com/google/uuid" ) -const testHeader = "X-Test-Header" - var testPubKey = &ecdsa.PublicKey{ Curve: elliptic.P256(), } @@ -72,11 +71,11 @@ func TestHeimdallr(t *testing.T) { t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) if tc.headerName == "" { - tc.headerName = testHeader + tc.headerName = middleware.RequestContextHeaderName } req.Header.Set(tc.headerName, tc.headerValue) w := httptest.NewRecorder() - h := Heimdallr(testHeader, tc.expectedNs) + h := Heimdallr(tc.expectedNs) h(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := MustFromContext(r.Context()) if ns := id.Namespace; ns != tc.expectedNs {