diff --git a/handler/oauth2/revocation.go b/handler/oauth2/revocation.go index 0a3e3fc4..f83c891e 100644 --- a/handler/oauth2/revocation.go +++ b/handler/oauth2/revocation.go @@ -17,21 +17,47 @@ type TokenRevocationHandler struct { TokenRevocationStorage TokenRevocationStorage RefreshTokenStrategy RefreshTokenStrategy AccessTokenStrategy AccessTokenStrategy + + // RevokeRefreshTokenOnRequestOnly is used to indicate if the refresh token should be revoked only if + // token passed to the request is a refresh token. The default behavior revokes both the access and refresh + // tokens if the token passed to the request is either. + // + // [RFC7009 - Section 2.1] Depending on the authorization server's revocation policy, the + // revocation of a particular token may cause the revocation of related + // tokens and the underlying authorization grant. If the particular + // token is a refresh token and the authorization server supports the + // revocation of access tokens, then the authorization server SHOULD + // also invalidate all access tokens based on the same authorization + // grant (see Implementation Note). If the token passed to the request + // is an access token, the server MAY revoke the respective refresh + // token as well. + RevokeRefreshTokenOnRequestOnly bool } // RevokeToken implements https://tools.ietf.org/html/rfc7009#section-2.1 // The token type hint indicates which token type check should be performed first. func (r *TokenRevocationHandler) RevokeToken(ctx context.Context, token string, tokenType fosite.TokenType, client fosite.Client) error { + actualTokenType := tokenType discoveryFuncs := []func() (request fosite.Requester, err error){ func() (request fosite.Requester, err error) { // Refresh token signature := r.RefreshTokenStrategy.RefreshTokenSignature(ctx, token) - return r.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil) + ar, err := r.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil) + if err == nil { + actualTokenType = fosite.RefreshToken + } + + return ar, err }, func() (request fosite.Requester, err error) { // Access token signature := r.AccessTokenStrategy.AccessTokenSignature(ctx, token) - return r.TokenRevocationStorage.GetAccessTokenSession(ctx, signature, nil) + ar, err := r.TokenRevocationStorage.GetAccessTokenSession(ctx, signature, nil) + if err == nil { + actualTokenType = fosite.AccessToken + } + + return ar, err }, } @@ -55,7 +81,10 @@ func (r *TokenRevocationHandler) RevokeToken(ctx context.Context, token string, } requestID := ar.GetID() - err1 = r.TokenRevocationStorage.RevokeRefreshToken(ctx, requestID) + if !r.RevokeRefreshTokenOnRequestOnly || actualTokenType == fosite.RefreshToken { + err1 = r.TokenRevocationStorage.RevokeRefreshToken(ctx, requestID) + } + err2 = r.TokenRevocationStorage.RevokeAccessToken(ctx, requestID) return storeErrorsToRevocationError(err1, err2) diff --git a/handler/oauth2/revocation_test.go b/handler/oauth2/revocation_test.go index b657bed3..1eca13d3 100644 --- a/handler/oauth2/revocation_test.go +++ b/handler/oauth2/revocation_test.go @@ -246,3 +246,233 @@ func TestRevokeToken(t *testing.T) { }) } } + +func TestRevokeTokenWithRefreshTokenOnRequestOnly(t *testing.T) { + ctrl := gomock.NewController(t) + store := internal.NewMockTokenRevocationStorage(ctrl) + atStrat := internal.NewMockAccessTokenStrategy(ctrl) + rtStrat := internal.NewMockRefreshTokenStrategy(ctrl) + ar := internal.NewMockAccessRequester(ctrl) + defer ctrl.Finish() + + h := TokenRevocationHandler{ + TokenRevocationStorage: store, + RefreshTokenStrategy: rtStrat, + AccessTokenStrategy: atStrat, + RevokeRefreshTokenOnRequestOnly: true, + } + + var token string + var tokenType fosite.TokenType + + for k, c := range []struct { + description string + mock func() + expectErr error + client fosite.Client + }{ + { + description: "should fail - token was issued to another client", + expectErr: fosite.ErrUnauthorizedClient, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.RefreshToken + rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token) + store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil) + ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "foo"}) + }, + }, + { + description: "should pass - refresh token discovery first; refresh token found", + expectErr: nil, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.RefreshToken + rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token) + store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil) + ar.EXPECT().GetID().Return("refresh token discovery first; refresh token found") + ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"}) + store.EXPECT().RevokeRefreshToken(gomock.Any(), gomock.Any()) + store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any()) + }, + }, + { + description: "should pass - access token discovery first; access token found", + expectErr: nil, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.AccessToken + atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token) + store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil) + ar.EXPECT().GetID().Return("access token discovery first; access token found") + ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"}) + store.EXPECT().RevokeRefreshToken(gomock.Any(), gomock.Any()) + store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any()) + }, + }, + { + description: "should pass - refresh token discovery first; refresh token not found", + expectErr: nil, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.AccessToken + atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token) + store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound) + + rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token) + store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil) + ar.EXPECT().GetID().Return("refresh token discovery first; refresh token not found") + ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"}) + store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any()) + }, + }, + { + description: "should pass - access token discovery first; access token not found", + expectErr: nil, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.RefreshToken + rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token) + store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound) + + atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token) + store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil) + ar.EXPECT().GetID().Return("access token discovery first; access token not found") + ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"}) + store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any()) + }, + }, + { + description: "should pass - refresh token discovery first; both tokens not found", + expectErr: nil, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.RefreshToken + rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token) + store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound) + + atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token) + store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound) + }, + }, + { + description: "should pass - access token discovery first; both tokens not found", + expectErr: nil, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.AccessToken + atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token) + store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound) + + rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token) + store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound) + }, + }, + { + + description: "should pass - refresh token discovery first; refresh token is inactive", + expectErr: nil, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.RefreshToken + rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token) + store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrInactiveToken) + + atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token) + store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound) + }, + }, + { + description: "should pass - access token discovery first; refresh token is inactive", + expectErr: nil, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.AccessToken + atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token) + store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound) + + rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token) + store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrInactiveToken) + }, + }, + { + description: "should fail - store error for access token get", + expectErr: fosite.ErrTemporarilyUnavailable, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.AccessToken + atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token) + store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("random error")) + + rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token) + store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound) + }, + }, + { + description: "should fail - store error for refresh token get", + expectErr: fosite.ErrTemporarilyUnavailable, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.RefreshToken + atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token) + store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fosite.ErrNotFound) + + rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token) + store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("random error")) + }, + }, + { + description: "should fail - store error for access token revoke", + expectErr: fosite.ErrTemporarilyUnavailable, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.AccessToken + atStrat.EXPECT().AccessTokenSignature(gomock.Any(), token) + store.EXPECT().GetAccessTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil) + + ar.EXPECT().GetID().Return("access token revoke") + ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"}) + store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any()).Return(fmt.Errorf("random error")) + }, + }, + { + description: "should fail - store error for refresh token revoke", + expectErr: fosite.ErrTemporarilyUnavailable, + client: &fosite.DefaultClient{ID: "bar"}, + mock: func() { + token = "foo" + tokenType = fosite.RefreshToken + rtStrat.EXPECT().RefreshTokenSignature(gomock.Any(), token) + store.EXPECT().GetRefreshTokenSession(gomock.Any(), gomock.Any(), gomock.Any()).Return(ar, nil) + + ar.EXPECT().GetID() + ar.EXPECT().GetClient().Return(&fosite.DefaultClient{ID: "bar"}) + store.EXPECT().RevokeRefreshToken(gomock.Any(), gomock.Any()).Return(fmt.Errorf("random error")) + store.EXPECT().RevokeAccessToken(gomock.Any(), gomock.Any()).Return(fosite.ErrNotFound) + }, + }, + } { + t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), func(tt *testing.T) { + c.mock() + err := h.RevokeToken(context.Background(), token, tokenType, c.client) + + if c.expectErr != nil { + require.EqualError(tt, err, c.expectErr.Error()) + } else { + require.NoError(tt, err) + } + }) + } +}