From 3b564b20263a51a95aacb733a950330186b41114 Mon Sep 17 00:00:00 2001 From: Steve Riesenberg Date: Tue, 28 Sep 2021 16:17:36 -0500 Subject: [PATCH] Add parameters converter support to AbstractWebClientReactiveOAuth2AccessTokenResponseClient This adds support for configuring NimbusJwtClientAuthenticationParametersConverter to any AbstractWebClientReactiveOAuth2AccessTokenResponseClient as an additional parameters converter, which in turns adds reactive support for jwt client authentication. Closes gh-10146 --- ...activeOAuth2AccessTokenResponseClient.java | 73 ++++++++- ...orizationCodeTokenResponseClientTests.java | 145 +++++++++++++++++- ...ntCredentialsTokenResponseClientTests.java | 138 +++++++++++++++++ ...tiveJwtBearerTokenResponseClientTests.java | 49 ++++++ ...ctivePasswordTokenResponseClientTests.java | 145 ++++++++++++++++++ ...eRefreshTokenTokenResponseClientTests.java | 145 ++++++++++++++++++ 6 files changed, 690 insertions(+), 5 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java index ebbe9c8ec33..c502bc701ab 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java @@ -35,6 +35,8 @@ import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.reactive.function.BodyInserters; @@ -70,6 +72,8 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient headersConverter = this::populateTokenRequestHeaders; + private Converter> parametersConverter = this::populateTokenRequestParameters; + private BodyExtractor, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors .oauth2AccessTokenResponse(); @@ -132,7 +136,19 @@ private static String encodeClientCredential(String clientCredential) { } /** - * Creates and returns the body for the token request. + * Populates default parameters for the token request. + * @param grantRequest the grant request + * @return the parameters populated for the token request. + */ + private MultiValueMap populateTokenRequestParameters(T grantRequest) { + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue()); + return parameters; + } + + /** + * Combine the results of {@code parametersConverter} and + * {@link #populateTokenRequestBody}. * *

* This method pre-populates the body with some standard properties, and then @@ -144,9 +160,8 @@ private static String encodeClientCredential(String clientCredential) { * @return the body for the token request. */ private BodyInserters.FormInserter createTokenRequestBody(T grantRequest) { - BodyInserters.FormInserter body = BodyInserters.fromFormData(OAuth2ParameterNames.GRANT_TYPE, - grantRequest.getGrantType().getValue()); - return populateTokenRequestBody(grantRequest, body); + MultiValueMap parameters = getParametersConverter().convert(grantRequest); + return populateTokenRequestBody(grantRequest, BodyInserters.fromFormData(parameters)); } /** @@ -296,6 +311,56 @@ public final void addHeadersConverter(Converter headersConverter }; } + /** + * Returns the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap} + * used in the OAuth 2.0 Access Token Request body. + * @return the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} to {@link MultiValueMap} + */ + final Converter> getParametersConverter() { + return this.parametersConverter; + } + + /** + * Sets the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap} + * used in the OAuth 2.0 Access Token Request body. + * @param parametersConverter the {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} to {@link MultiValueMap} + * @since 5.6 + */ + public final void setParametersConverter(Converter> parametersConverter) { + Assert.notNull(parametersConverter, "parametersConverter cannot be null"); + this.parametersConverter = parametersConverter; + } + + /** + * Add (compose) the provided {@code parametersConverter} to the current + * {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap} + * used in the OAuth 2.0 Access Token Request body. + * @param parametersConverter the {@link Converter} to add (compose) to the current + * {@link Converter} used for converting the + * {@link AbstractOAuth2AuthorizationGrantRequest} to a {@link MultiValueMap} + * @since 5.6 + */ + public final void addParametersConverter(Converter> parametersConverter) { + Assert.notNull(parametersConverter, "parametersConverter cannot be null"); + Converter> currentParametersConverter = this.parametersConverter; + this.parametersConverter = (authorizationGrantRequest) -> { + MultiValueMap parameters = currentParametersConverter.convert(authorizationGrantRequest); + if (parameters == null) { + parameters = new LinkedMultiValueMap<>(); + } + MultiValueMap parametersToAdd = parametersConverter.convert(authorizationGrantRequest); + if (parametersToAdd != null) { + parameters.addAll(parametersToAdd); + } + return parameters; + }; + } + /** * Sets the {@link BodyExtractor} that will be used to decode the * {@link OAuth2AccessTokenResponse} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java index 34640d9a353..1502f517de8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java @@ -16,11 +16,16 @@ package org.springframework.security.oauth2.client.endpoint; +import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.function.Function; +import javax.crypto.spec.SecretKeySpec; + +import com.nimbusds.jose.jwk.JWK; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -36,6 +41,7 @@ import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; @@ -44,6 +50,10 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.reactive.function.client.WebClient; @@ -112,6 +122,75 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t assertThat(accessTokenResponse.getAdditionalParameters()).containsEntry("custom_parameter_2", "custom-value-2"); } + @Test + public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) + .build(); + // @formatter:on + + // Configure Jwt client authentication converter + SecretKeySpec secretKey = new SecretKeySpec( + clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); + JWK jwk = TestJwks.jwk(secretKey).build(); + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationConverter(jwkResolver); + + this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration)).block(); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=authorization_code", + "client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer", + "client_assertion="); + } + + @Test + public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + + // Configure Jwt client authentication converter + JWK jwk = TestJwks.DEFAULT_RSA_JWK; + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationConverter(jwkResolver); + + this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration)).block(); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=authorization_code", + "client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer", + "client_assertion="); + } + + private void configureJwtClientAuthenticationConverter(Function jwkResolver) { + NimbusJwtClientAuthenticationParametersConverter jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>( + jwkResolver); + this.tokenResponseClient.addParametersConverter(jwtClientAuthenticationConverter); + } + // @Test // public void // getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() throws @@ -261,7 +340,10 @@ public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAcce } private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest() { - ClientRegistration registration = this.clientRegistration.build(); + return authorizationCodeGrantRequest(this.clientRegistration.build()); + } + + private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest(ClientRegistration registration) { OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() .clientId(registration.getClientId()).state("state") .authorizationUri(registration.getProviderDetails().getAuthorizationUri()) @@ -414,6 +496,67 @@ public void convertWhenHeadersConverterSetThenCalled() throws Exception { .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); } + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void convertWhenParametersConverterAddedThenCalled() throws Exception { + OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest(); + Converter> addedParametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(addedParametersConverter.convert(request)).willReturn(parameters); + this.tokenResponseClient.addParametersConverter(addedParametersConverter); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.tokenResponseClient.getTokenResponse(request).block(); + verify(addedParametersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=authorization_code", + "custom-parameter-name=custom-parameter-value"); + } + + @Test + public void convertWhenParametersConverterSetThenCalled() throws Exception { + OAuth2AuthorizationCodeGrantRequest request = authorizationCodeGrantRequest(); + Converter> parametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(request)).willReturn(parameters); + this.tokenResponseClient.setParametersConverter(parametersConverter); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.tokenResponseClient.getTokenResponse(request).block(); + verify(parametersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value"); + } + // gh-10260 @Test public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java index a2a3fc0f7fe..0b458ac7cea 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java @@ -20,7 +20,11 @@ import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.Collections; +import java.util.function.Function; +import javax.crypto.spec.SecretKeySpec; + +import com.nimbusds.jose.jwk.JWK; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -39,6 +43,10 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClientResponseException; @@ -152,6 +160,75 @@ public void getTokenResponseWhenPostThenSuccess() throws Exception { "grant_type=client_credentials&client_id=client-id&client_secret=client-secret&scope=read%3Auser"); } + @Test + public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + enqueueJson("{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}"); + // @formatter:on + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) + .build(); + // @formatter:on + + // Configure Jwt client authentication converter + SecretKeySpec secretKey = new SecretKeySpec( + clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); + JWK jwk = TestJwks.jwk(secretKey).build(); + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationConverter(jwkResolver); + + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + this.client.getTokenResponse(request).block(); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=client_credentials", + "client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer", + "client_assertion="); + } + + @Test + public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + enqueueJson("{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}"); + // @formatter:on + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + + // Configure Jwt client authentication converter + JWK jwk = TestJwks.DEFAULT_RSA_JWK; + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationConverter(jwkResolver); + + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + this.client.getTokenResponse(request).block(); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=client_credentials", + "client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer", + "client_assertion="); + } + + private void configureJwtClientAuthenticationConverter(Function jwkResolver) { + NimbusJwtClientAuthenticationParametersConverter jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>( + jwkResolver); + this.client.addParametersConverter(jwtClientAuthenticationConverter); + } + @Test public void getTokenResponseWhenNoScopeThenClientRegistrationScopesDefaulted() { ClientRegistration registration = this.clientRegistration.build(); @@ -285,6 +362,67 @@ public void convertWhenHeadersConverterSetThenCalled() throws Exception { .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); } + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.client.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.client.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void convertWhenParametersConverterAddedThenCalled() throws Exception { + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration.build()); + Converter> addedParametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(addedParametersConverter.convert(request)).willReturn(parameters); + this.client.addParametersConverter(addedParametersConverter); + // @formatter:off + enqueueJson("{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"); + // @formatter:on + this.client.getTokenResponse(request).block(); + verify(addedParametersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=client_credentials", + "custom-parameter-name=custom-parameter-value"); + } + + @Test + public void convertWhenParametersConverterSetThenCalled() throws Exception { + OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest( + this.clientRegistration.build()); + Converter> parametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(request)).willReturn(parameters); + this.client.setParametersConverter(parametersConverter); + // @formatter:off + enqueueJson("{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"); + // @formatter:on + this.client.getTokenResponse(request).block(); + verify(parametersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value"); + } + // gh-10260 @Test public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClientTests.java index 9470556046b..31999665396 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveJwtBearerTokenResponseClientTests.java @@ -40,6 +40,8 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyExtractor; import org.springframework.web.reactive.function.client.WebClient; @@ -228,6 +230,53 @@ public void getTokenResponseWhenHeadersConverterAddedThenCalled() throws Excepti assertThat(actualRequest.getHeader("custom-header-name")).isEqualTo("custom-header-value"); } + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.client.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.client.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void convertWhenParametersConverterAddedThenCalled() throws Exception { + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + Converter> addedParametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(addedParametersConverter.convert(request)).willReturn(parameters); + this.client.addParametersConverter(addedParametersConverter); + enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.client.getTokenResponse(request).block(); + verify(addedParametersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getBody().readUtf8()).contains( + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer", + "custom-parameter-name=custom-parameter-value"); + } + + @Test + public void convertWhenParametersConverterSetThenCalled() throws Exception { + ClientRegistration clientRegistration = this.clientRegistration.build(); + JwtBearerGrantRequest request = new JwtBearerGrantRequest(clientRegistration, this.jwtAssertion); + Converter> parametersConverter = mock(Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(request)).willReturn(parameters); + this.client.setParametersConverter(parametersConverter); + enqueueJson(DEFAULT_ACCESS_TOKEN_RESPONSE); + this.client.getTokenResponse(request).block(); + verify(parametersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value"); + } + @Test public void getTokenResponseWhenBodyExtractorSetThenCalled() { BodyExtractor, ReactiveHttpInputMessage> bodyExtractor = mock( diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java index 5b0118ebf9f..b6d1cd186f6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java @@ -16,9 +16,14 @@ package org.springframework.security.oauth2.client.endpoint; +import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collections; +import java.util.function.Function; +import javax.crypto.spec.SecretKeySpec; + +import com.nimbusds.jose.jwk.JWK; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -39,6 +44,10 @@ import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyExtractor; import static org.assertj.core.api.Assertions.assertThat; @@ -146,6 +155,79 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen assertThat(formParameters).contains("client_secret=client-secret"); } + @Test + public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) + .build(); + // @formatter:on + + // Configure Jwt client authentication converter + SecretKeySpec secretKey = new SecretKeySpec( + clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); + JWK jwk = TestJwks.jwk(secretKey).build(); + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationConverter(jwkResolver); + + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, + this.username, this.password); + this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block(); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=password", + "client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer", + "client_assertion="); + } + + @Test + public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + + // Configure Jwt client authentication converter + JWK jwk = TestJwks.DEFAULT_RSA_JWK; + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationConverter(jwkResolver); + + OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, + this.username, this.password); + this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block(); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=password", + "client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer", + "client_assertion="); + } + + private void configureJwtClientAuthenticationConverter(Function jwkResolver) { + NimbusJwtClientAuthenticationParametersConverter jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>( + jwkResolver); + this.tokenResponseClient.addParametersConverter(jwtClientAuthenticationConverter); + } + @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { // @formatter:off @@ -291,6 +373,69 @@ public void convertWhenHeadersConverterSetThenCalled() throws Exception { .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); } + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void convertWhenParametersConverterAddedThenCalled() throws Exception { + OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(), + this.username, this.password); + Converter> addedParametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(addedParametersConverter.convert(request)).willReturn(parameters); + this.tokenResponseClient.addParametersConverter(addedParametersConverter); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.tokenResponseClient.getTokenResponse(request).block(); + verify(addedParametersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=password", + "custom-parameter-name=custom-parameter-value"); + } + + @Test + public void convertWhenParametersConverterSetThenCalled() throws Exception { + OAuth2PasswordGrantRequest request = new OAuth2PasswordGrantRequest(this.clientRegistrationBuilder.build(), + this.username, this.password); + Converter> parametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(request)).willReturn(parameters); + this.tokenResponseClient.setParametersConverter(parametersConverter); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.tokenResponseClient.getTokenResponse(request).block(); + verify(parametersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value"); + } + // gh-10260 @Test public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java index 01f36f0fb86..4d2c9452372 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java @@ -16,9 +16,14 @@ package org.springframework.security.oauth2.client.endpoint; +import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Collections; +import java.util.function.Function; +import javax.crypto.spec.SecretKeySpec; + +import com.nimbusds.jose.jwk.JWK; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -42,6 +47,10 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.security.oauth2.jose.TestJwks; +import org.springframework.security.oauth2.jose.TestKeys; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.BodyExtractor; import static org.assertj.core.api.Assertions.assertThat; @@ -149,6 +158,79 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen assertThat(formParameters).contains("client_secret=client-secret"); } + @Test + public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY) + .build(); + // @formatter:on + + // Configure Jwt client authentication converter + SecretKeySpec secretKey = new SecretKeySpec( + clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256"); + JWK jwk = TestJwks.jwk(secretKey).build(); + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationConverter(jwkResolver); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block(); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=refresh_token", + "client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer", + "client_assertion="); + } + + @Test + public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + // @formatter:off + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT) + .build(); + // @formatter:on + + // Configure Jwt client authentication converter + JWK jwk = TestJwks.DEFAULT_RSA_JWK; + Function jwkResolver = (registration) -> jwk; + configureJwtClientAuthenticationConverter(jwkResolver); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration, + this.accessToken, this.refreshToken); + this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block(); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=refresh_token", + "client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer", + "client_assertion="); + } + + private void configureJwtClientAuthenticationConverter(Function jwkResolver) { + NimbusJwtClientAuthenticationParametersConverter jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>( + jwkResolver); + this.tokenResponseClient.addParametersConverter(jwtClientAuthenticationConverter); + } + @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { // @formatter:off @@ -294,6 +376,69 @@ public void convertWhenHeadersConverterSetThenCalled() throws Exception { .isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); } + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + } + + @Test + public void convertWhenParametersConverterAddedThenCalled() throws Exception { + OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + Converter> addedParametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(addedParametersConverter.convert(request)).willReturn(parameters); + this.tokenResponseClient.addParametersConverter(addedParametersConverter); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.tokenResponseClient.getTokenResponse(request).block(); + verify(addedParametersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getBody().readUtf8()).contains("grant_type=refresh_token", + "custom-parameter-name=custom-parameter-value"); + } + + @Test + public void convertWhenParametersConverterSetThenCalled() throws Exception { + OAuth2RefreshTokenGrantRequest request = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); + Converter> parametersConverter = mock( + Converter.class); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add("custom-parameter-name", "custom-parameter-value"); + given(parametersConverter.convert(request)).willReturn(parameters); + this.tokenResponseClient.setParametersConverter(parametersConverter); + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" + + " \"token_type\":\"bearer\",\n" + + " \"expires_in\":3600,\n" + + " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + + "}"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + this.tokenResponseClient.getTokenResponse(request).block(); + verify(parametersConverter).convert(request); + RecordedRequest actualRequest = this.server.takeRequest(); + assertThat(actualRequest.getBody().readUtf8()).contains("custom-parameter-name=custom-parameter-value"); + } + // gh-10260 @Test public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {