Skip to content

Commit

Permalink
Don't really need to make header configurable.
Browse files Browse the repository at this point in the history
  • Loading branch information
ananthb committed Oct 8, 2023
1 parent ef2cede commit 321b780
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 18 deletions.
2 changes: 1 addition & 1 deletion cmd/hallpass/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func main() {
defer ssllog.Close()
}

ti := middleware.TLSIdentifier(middleware.RequestContextHeaderName, cert.Namespace)
ti := middleware.TLSIdentifier(cert.Namespace)
hdlr := sundry.RequestLogHandler(ti(reverseProxy))

addr := fmt.Sprintf("%s:%d", config.HallPass.Host, config.HallPass.Port)
Expand Down
6 changes: 3 additions & 3 deletions internal/middleware/tlsident.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ import (
// TLSIdentifier returns a HTTP Handler middleware function that identifies clients using
// TLS client certificates.
// It parses the client certficiate into a RequestContext which is
// JSON-serialised into the headerName header.
func TLSIdentifier(headerName string, namespace uuid.UUID) func(http.Handler) http.Handler {
// JSON-serialised into the request context header.
func TLSIdentifier(namespace uuid.UUID) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
Expand Down Expand Up @@ -61,7 +61,7 @@ func TLSIdentifier(headerName string, namespace uuid.UUID) func(http.Handler) ht
http.Error(w, "unexpected error", http.StatusInternalServerError)
return
}
r.Header.Set(headerName, string(rctxHeader))
r.Header.Set(RequestContextHeaderName, string(rctxHeader))
next.ServeHTTP(w, r)
})
}
Expand Down
10 changes: 4 additions & 6 deletions internal/middleware/tlsident_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import (
"github.com/google/uuid"
)

const testHeader = "rctx-test"

func TestCertAuthorizerNoTLS(t *testing.T) {
defer func() {
if r := recover(); r == nil {
Expand All @@ -38,7 +36,7 @@ func TestCertAuthorizerNoTLS(t *testing.T) {
defer backendServer.Close()
backendUrl, _ := url.Parse(backendServer.URL)

ti := TLSIdentifier(testHeader, uuid.Nil)(httputil.NewSingleHostReverseProxy(backendUrl))
ti := TLSIdentifier(uuid.Nil)(httputil.NewSingleHostReverseProxy(backendUrl))
rr := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/", nil)
ti.ServeHTTP(rr, request)
Expand Down Expand Up @@ -85,9 +83,9 @@ func TestHofund(t *testing.T) {
// backend server handler checks if request has expected header
backendServer := httptest.NewServer(
http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
rctxVal := r.Header.Get(testHeader)
rctxVal := r.Header.Get(RequestContextHeaderName)
if rctxVal == "" {
t.Errorf("expected %s header in request", testHeader)
t.Errorf("expected %s header in request", RequestContextHeaderName)
}
var rctx AuthorizedRequestContext
if err := json.Unmarshal([]byte(rctxVal), &rctx); err != nil {
Expand All @@ -111,7 +109,7 @@ func TestHofund(t *testing.T) {
t.Errorf("error parsing backedn url %s", err)
}

ti := TLSIdentifier(testHeader, ns)(httputil.NewSingleHostReverseProxy(backendUrl))
ti := TLSIdentifier(ns)(httputil.NewSingleHostReverseProxy(backendUrl))

// TLS server accepts client requests requiring TLS client cert auth
server := httptest.NewUnstartedServer(ti)
Expand Down
8 changes: 4 additions & 4 deletions pkg/asgard/heimdallr.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ func MustFromContext(ctx context.Context) *Identity {
}

// Heimdallr returns a HTTP Handler middleware function that parses an AuthorizedRequestContext
// from headerName. If namespace does not match the parsed one, the request is forbidden.
// The AuthorizedRequestContext is stored in the request context.
// from the request context header. If namespace does not match the parsed one, the
// request is forbidden. The AuthorizedRequestContext is stored in the request context.
//
// If Heimdallr is used in an AWS Lambda Web Adapter powered API server, Bouncer Lambda Authorizer
// must be configured as an authorizer for the API Gateway method.
func Heimdallr(headerName string, namespace uuid.UUID) func(http.Handler) http.Handler {
func Heimdallr(namespace uuid.UUID) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hdr := r.Header.Get(headerName)
hdr := r.Header.Get(middleware.RequestContextHeaderName)
if hdr == "" {
http.Error(w, middleware.ServiceUnavailableMsg, http.StatusServiceUnavailable)
return
Expand Down
7 changes: 3 additions & 4 deletions pkg/asgard/heimdallr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@ import (
"net/http/httptest"
"testing"

"github.com/RealImage/bifrost/internal/middleware"
"github.com/google/uuid"
)

const testHeader = "X-Test-Header"

var testPubKey = &ecdsa.PublicKey{
Curve: elliptic.P256(),
}
Expand Down Expand Up @@ -72,11 +71,11 @@ func TestHeimdallr(t *testing.T) {
t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.headerName == "" {
tc.headerName = testHeader
tc.headerName = middleware.RequestContextHeaderName
}
req.Header.Set(tc.headerName, tc.headerValue)
w := httptest.NewRecorder()
h := Heimdallr(testHeader, tc.expectedNs)
h := Heimdallr(tc.expectedNs)
h(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := MustFromContext(r.Context())
if ns := id.Namespace; ns != tc.expectedNs {
Expand Down

0 comments on commit 321b780

Please sign in to comment.