Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customize the strategy for resolving the principal #15833

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
Expand Down Expand Up @@ -121,40 +119,24 @@ public final class OAuth2ClientHttpRequestInterceptor implements ClientHttpReque

private final OAuth2AuthorizedClientManager authorizedClientManager;

private final ClientRegistrationIdResolver clientRegistrationIdResolver;
private ClientRegistrationIdResolver clientRegistrationIdResolver = new RequestAttributeClientRegistrationIdResolver();

private PrincipalResolver principalResolver = new SecurityContextHolderPrincipalResolver();

// @formatter:off
private OAuth2AuthorizationFailureHandler authorizationFailureHandler =
(clientRegistrationId, principal, attributes) -> { };
// @formatter:on

private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();

/**
* Constructs a {@code OAuth2ClientHttpRequestInterceptor} using the provided
* parameters.
* @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which
* manages the authorized client(s)
*/
public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager) {
this(authorizedClientManager, new RequestAttributeClientRegistrationIdResolver());
}

/**
* Constructs a {@code OAuth2ClientHttpRequestInterceptor} using the provided
* parameters.
* @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which
* manages the authorized client(s)
* @param clientRegistrationIdResolver the strategy for resolving a
* {@code clientRegistrationId} from the intercepted request
*/
public OAuth2ClientHttpRequestInterceptor(OAuth2AuthorizedClientManager authorizedClientManager,
ClientRegistrationIdResolver clientRegistrationIdResolver) {
Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null");
Assert.notNull(clientRegistrationIdResolver, "clientRegistrationIdResolver cannot be null");
this.authorizedClientManager = authorizedClientManager;
this.clientRegistrationIdResolver = clientRegistrationIdResolver;
}

/**
Expand Down Expand Up @@ -238,20 +220,31 @@ public static OAuth2AuthorizationFailureHandler authorizationFailureHandler(
}

/**
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
* @param securityContextHolderStrategy the {@link SecurityContextHolderStrategy} to
* use
* Sets the strategy for resolving a {@code clientRegistrationId} from an intercepted
* request.
* @param clientRegistrationIdResolver the strategy for resolving a
* {@code clientRegistrationId} from an intercepted request
*/
public void setClientRegistrationIdResolver(ClientRegistrationIdResolver clientRegistrationIdResolver) {
Assert.notNull(clientRegistrationIdResolver, "clientRegistrationIdResolver cannot be null");
this.clientRegistrationIdResolver = clientRegistrationIdResolver;
}

/**
* Sets the strategy for resolving a {@link Authentication principal} from an
* intercepted request.
* @param principalResolver the strategy for resolving a {@link Authentication
* principal}
*/
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
this.securityContextHolderStrategy = securityContextHolderStrategy;
public void setPrincipalResolver(PrincipalResolver principalResolver) {
Assert.notNull(principalResolver, "principalResolver cannot be null");
this.principalResolver = principalResolver;
}

@Override
public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution)
throws IOException {
Authentication principal = this.securityContextHolderStrategy.getContext().getAuthentication();
Authentication principal = this.principalResolver.resolve(request);
if (principal == null) {
principal = ANONYMOUS_AUTHENTICATION;
}
Expand Down Expand Up @@ -378,4 +371,24 @@ public interface ClientRegistrationIdResolver {

}

/**
* A strategy for resolving a {@link Authentication principal} from an intercepted
* request.
*/
@FunctionalInterface
public interface PrincipalResolver {

/**
* Resolve the {@link Authentication principal} from the current request, which is
* used to obtain an {@link OAuth2AuthorizedClient}.
* @param request the intercepted request, containing HTTP method, URI, headers,
* and request attributes
* @return the {@link Authentication principal} to be used for resolving an
* {@link OAuth2AuthorizedClient}.
*/
@Nullable
Authentication resolve(HttpRequest request);

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.web.client;

import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;

import org.springframework.http.HttpRequest;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.util.Assert;

/**
* A strategy for resolving a {@link Authentication principal} from an intercepted request
* using {@link ClientHttpRequest#getAttributes() attributes}.
*
* @author Steve Riesenberg
* @since 6.4
*/
public class RequestAttributePrincipalResolver implements OAuth2ClientHttpRequestInterceptor.PrincipalResolver {

private static final String PRINCIPAL_ATTR_NAME = RequestAttributePrincipalResolver.class.getName()
.concat(".principal");

@Override
public Authentication resolve(HttpRequest request) {
return (Authentication) request.getAttributes().get(PRINCIPAL_ATTR_NAME);
}

/**
* Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the
* {@link Authentication principal} to be used to look up the
* {@link OAuth2AuthorizedClient}.
* @param principal the {@link Authentication principal} to be used to look up the
* {@link OAuth2AuthorizedClient}
* @return the {@link Consumer} to populate the attributes
*/
public static Consumer<Map<String, Object>> principal(Authentication principal) {
Assert.notNull(principal, "principal cannot be null");
return (attributes) -> attributes.put(PRINCIPAL_ATTR_NAME, principal);
}

/**
* Modifies the {@link ClientHttpRequest#getAttributes() attributes} to include the
* {@link Authentication principal} to be used to look up the
* {@link OAuth2AuthorizedClient}.
* @param principalName the {@code principalName} to be used to look up the
* {@link OAuth2AuthorizedClient}
* @return the {@link Consumer} to populate the attributes
*/
public static Consumer<Map<String, Object>> principal(String principalName) {
Assert.hasText(principalName, "principalName cannot be empty");
Authentication principal = createAuthentication(principalName);
return (attributes) -> attributes.put(PRINCIPAL_ATTR_NAME, principal);
}

private static Authentication createAuthentication(String principalName) {
return new AbstractAuthenticationToken(Collections.emptySet()) {
@Override
public Object getPrincipal() {
return principalName;
}

@Override
public Object getCredentials() {
return null;
}
};
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.client.web.client;

import org.springframework.http.HttpRequest;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;

/**
* A strategy for resolving a {@link Authentication principal} from an intercepted request
* using the {@link SecurityContextHolder}.
*
* @author Steve Riesenberg
* @since 6.4
*/
public class SecurityContextHolderPrincipalResolver implements OAuth2ClientHttpRequestInterceptor.PrincipalResolver {

private final SecurityContextHolderStrategy securityContextHolderStrategy;

/**
* Constructs a {@code SecurityContextHolderPrincipalResolver}.
*/
public SecurityContextHolderPrincipalResolver() {
this(SecurityContextHolder.getContextHolderStrategy());
}

/**
* Constructs a {@code SecurityContextHolderPrincipalResolver} using the provided
* parameters.
* @param securityContextHolderStrategy the {@link SecurityContextHolderStrategy} to
* use for resolving the {@link Authentication principal}
*/
public SecurityContextHolderPrincipalResolver(SecurityContextHolderStrategy securityContextHolderStrategy) {
this.securityContextHolderStrategy = securityContextHolderStrategy;
}

@Override
public Authentication resolve(HttpRequest request) {
return this.securityContextHolderStrategy.getContext().getAuthentication();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.security.oauth2.client.ClientAuthorizationException;
import org.springframework.security.oauth2.client.OAuth2AuthorizationFailureHandler;
Expand Down Expand Up @@ -110,15 +109,15 @@ public class OAuth2ClientHttpRequestInterceptorTests {
@Mock
private OAuth2AuthorizedClientRepository authorizedClientRepository;

@Mock
private SecurityContextHolderStrategy securityContextHolderStrategy;

@Mock
private OAuth2AuthorizedClientService authorizedClientService;

@Mock
private OAuth2ClientHttpRequestInterceptor.ClientRegistrationIdResolver clientRegistrationIdResolver;

@Mock
private OAuth2ClientHttpRequestInterceptor.PrincipalResolver principalResolver;

@Captor
private ArgumentCaptor<OAuth2AuthorizeRequest> authorizeRequestCaptor;

Expand Down Expand Up @@ -167,13 +166,6 @@ public void constructorWhenAuthorizedClientManagerIsNullThenThrowsIllegalArgumen
.withMessage("authorizedClientManager cannot be null");
}

@Test
public void constructorWhenClientRegistrationIdResolverIsNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new OAuth2ClientHttpRequestInterceptor(this.authorizedClientManager, null))
.withMessage("clientRegistrationIdResolver cannot be null");
}

@Test
public void setAuthorizationFailureHandlerWhenNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException()
Expand All @@ -198,10 +190,16 @@ public void authorizationFailureHandlerWhenAuthorizedClientServiceIsNullThenThro
}

@Test
public void setSecurityContextHolderStrategyWhenNullThenThrowsIllegalArgumentException() {
public void setClientRegistrationIdResolverWhenNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.requestInterceptor.setSecurityContextHolderStrategy(null))
.withMessage("securityContextHolderStrategy cannot be null");
.isThrownBy(() -> this.requestInterceptor.setClientRegistrationIdResolver(null))
.withMessage("clientRegistrationIdResolver cannot be null");
}

@Test
public void setPrincipalResolverWhenNullThenThrowsIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.requestInterceptor.setPrincipalResolver(null))
.withMessage("principalResolver cannot be null");
}

@Test
Expand Down Expand Up @@ -605,8 +603,7 @@ public void interceptWhenUnauthorizedAndAuthorizationFailureHandlerSetWithAuthor

@Test
public void interceptWhenCustomClientRegistrationIdResolverSetThenUsed() {
this.requestInterceptor = new OAuth2ClientHttpRequestInterceptor(this.authorizedClientManager,
this.clientRegistrationIdResolver);
this.requestInterceptor.setClientRegistrationIdResolver(this.clientRegistrationIdResolver);
this.requestInterceptor.setAuthorizationFailureHandler(this.authorizationFailureHandler);
given(this.authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class)))
.willReturn(this.authorizedClient);
Expand All @@ -625,31 +622,29 @@ public void interceptWhenCustomClientRegistrationIdResolverSetThenUsed() {
this.server.verify();
verify(this.authorizedClientManager).authorize(this.authorizeRequestCaptor.capture());
verify(this.clientRegistrationIdResolver).resolve(any(HttpRequest.class));
verifyNoMoreInteractions(this.clientRegistrationIdResolver, this.authorizedClientManager);
verifyNoMoreInteractions(this.authorizedClientManager, this.clientRegistrationIdResolver);
verifyNoInteractions(this.authorizationFailureHandler);
OAuth2AuthorizeRequest authorizeRequest = this.authorizeRequestCaptor.getValue();
assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(clientRegistrationId);
assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal);
}

@Test
public void interceptWhenCustomSecurityContextHolderStrategySetThenUsed() {
this.requestInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
public void interceptWhenCustomPrincipalResolverSetThenUsed() {
this.requestInterceptor.setPrincipalResolver(this.principalResolver);
given(this.authorizedClientManager.authorize(any(OAuth2AuthorizeRequest.class)))
.willReturn(this.authorizedClient);

bindToRestClient(withRequestInterceptor());
this.server.expect(requestTo(REQUEST_URI))
.andExpect(hasAuthorizationHeader(this.authorizedClient.getAccessToken()))
.andRespond(withApplicationJson());
SecurityContext securityContext = new SecurityContextImpl();
securityContext.setAuthentication(this.principal);
given(this.securityContextHolderStrategy.getContext()).willReturn(securityContext);
given(this.principalResolver.resolve(any(HttpRequest.class))).willReturn(this.principal);
performRequest(withClientRegistrationId());
this.server.verify();
verify(this.authorizedClientManager).authorize(this.authorizeRequestCaptor.capture());
verify(this.securityContextHolderStrategy).getContext();
verifyNoMoreInteractions(this.authorizedClientManager, this.securityContextHolderStrategy);
verify(this.principalResolver).resolve(any(HttpRequest.class));
verifyNoMoreInteractions(this.authorizedClientManager, this.principalResolver);
OAuth2AuthorizeRequest authorizeRequest = this.authorizeRequestCaptor.getValue();
assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId());
assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal);
Expand Down