diff --git a/proxy/credentials_issuer_headers.go b/proxy/credentials_issuer_headers.go index 43b1cbd8ea..c3e7f060d4 100644 --- a/proxy/credentials_issuer_headers.go +++ b/proxy/credentials_issuer_headers.go @@ -19,7 +19,7 @@ type CredentialsHeaders struct { func NewCredentialsIssuerHeaders() *CredentialsHeaders { return &CredentialsHeaders{ - rulesCache: template.New("rules"), + rulesCache: template.New("rules").Option("missingkey=zero"), } } @@ -38,6 +38,8 @@ func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSessi return errors.WithStack(err) } + convertedSession := convertSession(session) + for hdr, templateString := range cfg { var tmpl *template.Template var err error @@ -52,7 +54,7 @@ func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSessi } headerValue := bytes.Buffer{} - err = tmpl.Execute(&headerValue, session) + err = tmpl.Execute(&headerValue, convertedSession) if err != nil { return errors.Wrapf(err, `error executing header template "%s" in rule "%s"`, templateString, rl.ID) } @@ -61,3 +63,21 @@ func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSessi return nil } + +type authSession struct { + Subject string + Extra map[string]string +} + +func convertSession(in *AuthenticationSession) *authSession { + out := authSession{ + Subject: in.Subject, + Extra: map[string]string{}, + } + + for k, v := range in.Extra { + out.Extra[k] = fmt.Sprintf("%s", v) + } + + return &out +} diff --git a/proxy/credentials_issuer_headers_test.go b/proxy/credentials_issuer_headers_test.go index 2bc603b24f..37cdb5ffc2 100644 --- a/proxy/credentials_issuer_headers_test.go +++ b/proxy/credentials_issuer_headers_test.go @@ -11,6 +11,7 @@ import ( ) func TestCredentialsIssuerHeaders(t *testing.T) { + var testMap = map[string]struct { Session *AuthenticationSession Rule *rule.Rule @@ -18,13 +19,41 @@ func TestCredentialsIssuerHeaders(t *testing.T) { Request *http.Request Match http.Header }{ - "Subject": { + "Simple Subject": { Session: &AuthenticationSession{Subject: "foo"}, Rule: &rule.Rule{ID: "test-rule"}, Config: json.RawMessage([]byte(`{"X-User": "{{ .Subject }}"}`)), Request: &http.Request{Header: http.Header{}}, Match: http.Header{"X-User": []string{"foo"}}, }, + "Complex Subject": { + Session: &AuthenticationSession{Subject: "foo"}, + Rule: &rule.Rule{ID: "test-rule2"}, + Config: json.RawMessage([]byte(`{"X-User": "realm:resources:users:{{ .Subject }}"}`)), + Request: &http.Request{Header: http.Header{}}, + Match: http.Header{"X-User": []string{"realm:resources:users:foo"}}, + }, + "Subject & Extras": { + Session: &AuthenticationSession{Subject: "foo", Extra: map[string]interface{}{"iss": "issuer", "aud": "audience"}}, + Rule: &rule.Rule{ID: "test-rule3"}, + Config: json.RawMessage([]byte(`{"X-User": "{{ .Subject }}", "X-Issuer": "{{ .Extra.iss }}", "X-Audience": "{{ .Extra.aud }}"}`)), + Request: &http.Request{Header: http.Header{}}, + Match: http.Header{"X-User": []string{"foo"}, "X-Issuer": []string{"issuer"}, "X-Audience": []string{"audience"}}, + }, + "All In One Header": { + Session: &AuthenticationSession{Subject: "foo", Extra: map[string]interface{}{"iss": "issuer", "aud": "audience"}}, + Rule: &rule.Rule{ID: "test-rule4"}, + Config: json.RawMessage([]byte(`{"X-Kitchen-Sink": "{{ .Subject }} {{ .Extra.iss }} {{ .Extra.aud }}"}`)), + Request: &http.Request{Header: http.Header{}}, + Match: http.Header{"X-Kitchen-Sink": []string{"foo issuer audience"}}, + }, + "Scrub Incoming Headers": { + Session: &AuthenticationSession{Subject: "anonymous"}, + Rule: &rule.Rule{ID: "test-rule5"}, + Config: json.RawMessage([]byte(`{"X-User": "{{ .Subject }}", "X-Issuer": "{{ .Extra.iss }}", "X-Audience": "{{ .Extra.aud }}"}`)), + Request: &http.Request{Header: http.Header{"X-User": []string{"admin"}, "X-Issuer": []string{"issuer"}, "X-Audience": []string{"audience"}}}, + Match: http.Header{"X-User": []string{"anonymous"}, "X-Issuer": []string{""}, "X-Audience": []string{""}}, + }, } for testName, specs := range testMap { @@ -41,7 +70,15 @@ func TestCredentialsIssuerHeaders(t *testing.T) { require.NoError(t, issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)) // Output request headers must match test specs - assert.Equal(t, specs.Request.Header, specs.Match) + assert.Equal(t, specs.Match, specs.Request.Header) }) } + + /* + t.Run("Test Template Caching", func(t *testing.T) { + for testName, specs := range testMap { + + } + }) + */ }