diff --git a/auth/authzserver/provider.go b/auth/authzserver/provider.go index a6124d4c7..1aa88ac1e 100644 --- a/auth/authzserver/provider.go +++ b/auth/authzserver/provider.go @@ -133,12 +133,17 @@ func (p Provider) ValidateAccessToken(ctx context.Context, expectedAudience, tok func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{}) (interfaces.IdentityContext, error) { claims := jwtx.ParseMapStringInterfaceClaims(claimsRaw) - if len(claims.Audience) != 1 { - return nil, fmt.Errorf("expected exactly one granted audience. found [%v]", len(claims.Audience)) + + foundAudIndex := -1 + for audIndex, aud := range claims.Audience { + if expectedAudience.Has(aud) { + foundAudIndex = audIndex + break + } } - if !expectedAudience.Has(claims.Audience[0]) { - return nil, fmt.Errorf("invalid audience [%v]", claims.Audience[0]) + if foundAudIndex < 0 { + return nil, fmt.Errorf("invalid audience [%v]", claims) } userInfo := &service.UserInfoResponse{} @@ -170,7 +175,7 @@ func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{} scopes.Insert(auth.ScopeAll) } - return auth.NewIdentityContext(claims.Audience[0], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo), nil + return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo), nil } // NewProvider creates a new OAuth2 Provider that is able to do OAuth 2-legged and 3-legged flows. It'll lookup diff --git a/auth/authzserver/provider_test.go b/auth/authzserver/provider_test.go index cc47edd02..9ebfd32a8 100644 --- a/auth/authzserver/provider_test.go +++ b/auth/authzserver/provider_test.go @@ -235,5 +235,49 @@ func Test_verifyClaims(t *testing.T) { assert.Equal(t, sets.NewString("all", "offline"), identityCtx.Scopes()) assert.Equal(t, "my-client", identityCtx.AppID()) assert.Equal(t, "123", identityCtx.UserID()) + assert.Equal(t, "https://myserver", identityCtx.Audience()) }) + + t.Run("Multiple audience", func(t *testing.T) { + identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), + map[string]interface{}{ + "aud": []string{"https://myserver"}, + "user_info": map[string]interface{}{ + "preferred_name": "John Doe", + }, + "sub": "123", + "client_id": "my-client", + "scp": []interface{}{"all", "offline"}, + }) + + assert.NoError(t, err) + assert.Equal(t, "https://myserver", identityCtx.Audience()) + }) + + t.Run("No matching audience", func(t *testing.T) { + _, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), + map[string]interface{}{ + "aud": []string{"https://myserver3"}, + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid audience") + }) + + t.Run("Use first matching audience", func(t *testing.T) { + identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2", "https://myserver3"), + map[string]interface{}{ + "aud": []string{"https://myserver", "https://myserver2"}, + "user_info": map[string]interface{}{ + "preferred_name": "John Doe", + }, + "sub": "123", + "client_id": "my-client", + "scp": []interface{}{"all", "offline"}, + }) + + assert.NoError(t, err) + assert.Equal(t, "https://myserver", identityCtx.Audience()) + }) + } diff --git a/auth/interfaces/context.go b/auth/interfaces/context.go index a703b4d3a..35d0aed4b 100644 --- a/auth/interfaces/context.go +++ b/auth/interfaces/context.go @@ -59,6 +59,7 @@ type AuthenticationContext interface { // to the platform. type IdentityContext interface { UserID() string + Audience() string AppID() string UserInfo() *service.UserInfoResponse AuthenticatedAt() time.Time diff --git a/auth/interfaces/mocks/identity_context.go b/auth/interfaces/mocks/identity_context.go index 2b43d5327..40a810bf4 100644 --- a/auth/interfaces/mocks/identity_context.go +++ b/auth/interfaces/mocks/identity_context.go @@ -51,6 +51,38 @@ func (_m *IdentityContext) AppID() string { return r0 } +type IdentityContext_Audience struct { + *mock.Call +} + +func (_m IdentityContext_Audience) Return(_a0 string) *IdentityContext_Audience { + return &IdentityContext_Audience{Call: _m.Call.Return(_a0)} +} + +func (_m *IdentityContext) OnAudience() *IdentityContext_Audience { + c := _m.On("Audience") + return &IdentityContext_Audience{Call: c} +} + +func (_m *IdentityContext) OnAudienceMatch(matchers ...interface{}) *IdentityContext_Audience { + c := _m.On("Audience", matchers...) + return &IdentityContext_Audience{Call: c} +} + +// Audience provides a mock function with given fields: +func (_m *IdentityContext) Audience() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + type IdentityContext_AuthenticatedAt struct { *mock.Call }