-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Propagate SecurityContext in ChannelInterceptor
Add `SecurityContextPropagationChannelInterceptor` that propagates the current security context through the Spring Messaging API. Namely, it adds the current security context into any message before it is sent and then populates the security context when that message is received, typically in a separate thread.
- Loading branch information
1 parent
817e9d6
commit 60a00bb
Showing
2 changed files
with
324 additions
and
0 deletions.
There are no files selected for viewing
164 changes: 164 additions & 0 deletions
164
...ingframework/security/messaging/context/SecurityContextPropagationChannelInterceptor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
/* | ||
* Copyright 2002-2023 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.springframework.security.messaging.context; | ||
|
||
import java.util.Stack; | ||
|
||
import org.springframework.messaging.Message; | ||
import org.springframework.messaging.MessageChannel; | ||
import org.springframework.messaging.MessageHandler; | ||
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; | ||
import org.springframework.messaging.support.ExecutorChannelInterceptor; | ||
import org.springframework.messaging.support.MessageBuilder; | ||
import org.springframework.security.authentication.AnonymousAuthenticationToken; | ||
import org.springframework.security.core.Authentication; | ||
import org.springframework.security.core.authority.AuthorityUtils; | ||
import org.springframework.security.core.context.SecurityContext; | ||
import org.springframework.security.core.context.SecurityContextHolder; | ||
import org.springframework.security.core.context.SecurityContextHolderStrategy; | ||
import org.springframework.util.Assert; | ||
|
||
/** | ||
* An {@link ExecutorChannelInterceptor} that takes an {@link Authentication} from the | ||
* current {@link SecurityContext} (if any) in the | ||
* {@link #preSend(Message, MessageChannel)} callback and stores it into an | ||
* {@link #authenticationHeaderName} message header. Then sets the context from this | ||
* header in the {@link #beforeHandle(Message, MessageChannel, MessageHandler)} and | ||
* {@link #postReceive(Message, MessageChannel)} both of which typically happen on a | ||
* different thread. | ||
* <p> | ||
* Note: cannot be used in combination with a {@link SecurityContextChannelInterceptor} on | ||
* the same channel since both these interceptors modify a security context on a handling | ||
* and receiving operations. | ||
* | ||
* @author Artem Bilan | ||
* @since 6.2 | ||
* @see SecurityContextChannelInterceptor | ||
*/ | ||
public final class SecurityContextPropagationChannelInterceptor implements ExecutorChannelInterceptor { | ||
|
||
private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>(); | ||
|
||
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder | ||
.getContextHolderStrategy(); | ||
|
||
private SecurityContext empty = this.securityContextHolderStrategy.createEmptyContext(); | ||
|
||
private final String authenticationHeaderName; | ||
|
||
private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous", | ||
AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); | ||
|
||
/** | ||
* Create a new instance using the header of the name | ||
* {@link SimpMessageHeaderAccessor#USER_HEADER}. | ||
*/ | ||
public SecurityContextPropagationChannelInterceptor() { | ||
this(SimpMessageHeaderAccessor.USER_HEADER); | ||
} | ||
|
||
/** | ||
* Create a new instance that uses the specified header to populate the | ||
* {@link Authentication}. | ||
* @param authenticationHeaderName the header name to populate the | ||
* {@link Authentication}. Cannot be null. | ||
*/ | ||
public SecurityContextPropagationChannelInterceptor(String authenticationHeaderName) { | ||
Assert.notNull(authenticationHeaderName, "authenticationHeaderName cannot be null"); | ||
this.authenticationHeaderName = authenticationHeaderName; | ||
} | ||
|
||
public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) { | ||
this.securityContextHolderStrategy = strategy; | ||
this.empty = this.securityContextHolderStrategy.createEmptyContext(); | ||
} | ||
|
||
/** | ||
* Configure an Authentication used for anonymous authentication. Default is: <pre> | ||
* new AnonymousAuthenticationToken("key", "anonymous", | ||
* AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); | ||
* </pre> | ||
* @param authentication the Authentication used for anonymous authentication. Cannot | ||
* be null. | ||
*/ | ||
public void setAnonymousAuthentication(Authentication authentication) { | ||
Assert.notNull(authentication, "authentication cannot be null"); | ||
this.anonymous = authentication; | ||
} | ||
|
||
@Override | ||
public Message<?> preSend(Message<?> message, MessageChannel channel) { | ||
Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication(); | ||
if (authentication == null) { | ||
authentication = this.anonymous; | ||
} | ||
return MessageBuilder.fromMessage(message).setHeader(this.authenticationHeaderName, authentication).build(); | ||
} | ||
|
||
@Override | ||
public Message<?> beforeHandle(Message<?> message, MessageChannel channel, MessageHandler handler) { | ||
return postReceive(message, channel); | ||
} | ||
|
||
@Override | ||
public Message<?> postReceive(Message<?> message, MessageChannel channel) { | ||
setup(message); | ||
return message; | ||
} | ||
|
||
@Override | ||
public void afterMessageHandled(Message<?> message, MessageChannel channel, MessageHandler handler, Exception ex) { | ||
cleanup(); | ||
} | ||
|
||
private void setup(Message<?> message) { | ||
Authentication authentication = message.getHeaders().get(this.authenticationHeaderName, Authentication.class); | ||
SecurityContext currentContext = this.securityContextHolderStrategy.getContext(); | ||
Stack<SecurityContext> contextStack = originalContext.get(); | ||
if (contextStack == null) { | ||
contextStack = new Stack<>(); | ||
originalContext.set(contextStack); | ||
} | ||
contextStack.push(currentContext); | ||
SecurityContext context = this.securityContextHolderStrategy.createEmptyContext(); | ||
context.setAuthentication(authentication); | ||
this.securityContextHolderStrategy.setContext(context); | ||
} | ||
|
||
private void cleanup() { | ||
Stack<SecurityContext> contextStack = originalContext.get(); | ||
if (contextStack == null || contextStack.isEmpty()) { | ||
this.securityContextHolderStrategy.clearContext(); | ||
originalContext.remove(); | ||
return; | ||
} | ||
SecurityContext context = contextStack.pop(); | ||
try { | ||
if (this.empty.equals(context)) { | ||
this.securityContextHolderStrategy.clearContext(); | ||
originalContext.remove(); | ||
} | ||
else { | ||
this.securityContextHolderStrategy.setContext(context); | ||
} | ||
} | ||
catch (Throwable ex) { | ||
this.securityContextHolderStrategy.clearContext(); | ||
} | ||
} | ||
|
||
} |
160 changes: 160 additions & 0 deletions
160
...amework/security/messaging/context/SecurityContextPropagationChannelInterceptorTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
/* | ||
* Copyright 2002-2023 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.springframework.security.messaging.context; | ||
|
||
import org.junit.jupiter.api.AfterEach; | ||
import org.junit.jupiter.api.BeforeEach; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.ExtendWith; | ||
import org.mockito.Mock; | ||
import org.mockito.junit.jupiter.MockitoExtension; | ||
|
||
import org.springframework.messaging.Message; | ||
import org.springframework.messaging.MessageChannel; | ||
import org.springframework.messaging.MessageHandler; | ||
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; | ||
import org.springframework.messaging.support.MessageBuilder; | ||
import org.springframework.security.authentication.AnonymousAuthenticationToken; | ||
import org.springframework.security.authentication.TestingAuthenticationToken; | ||
import org.springframework.security.core.Authentication; | ||
import org.springframework.security.core.context.SecurityContextHolder; | ||
import org.springframework.security.core.context.SecurityContextHolderStrategy; | ||
import org.springframework.security.core.context.SecurityContextImpl; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
import static org.mockito.Mockito.spy; | ||
import static org.mockito.Mockito.times; | ||
import static org.mockito.Mockito.verify; | ||
|
||
@ExtendWith(MockitoExtension.class) | ||
public class SecurityContextPropagationChannelInterceptorTests { | ||
|
||
@Mock | ||
MessageChannel channel; | ||
|
||
@Mock | ||
MessageHandler handler; | ||
|
||
MessageBuilder<String> messageBuilder; | ||
|
||
Authentication authentication; | ||
|
||
SecurityContextPropagationChannelInterceptor interceptor; | ||
|
||
@BeforeEach | ||
public void setup() { | ||
this.authentication = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); | ||
this.messageBuilder = MessageBuilder.withPayload("payload"); | ||
this.interceptor = new SecurityContextPropagationChannelInterceptor(); | ||
} | ||
|
||
@AfterEach | ||
public void cleanup() { | ||
this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); | ||
SecurityContextHolder.clearContext(); | ||
} | ||
|
||
@Test | ||
public void preSendDefaultHeader() { | ||
SecurityContextHolder.getContext().setAuthentication(this.authentication); | ||
Message<?> message = this.interceptor.preSend(this.messageBuilder.build(), this.channel); | ||
assertThat(message.getHeaders()).containsEntry(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); | ||
} | ||
|
||
@Test | ||
public void preSendCustomHeader() { | ||
SecurityContextHolder.getContext().setAuthentication(this.authentication); | ||
String headerName = "header"; | ||
this.interceptor = new SecurityContextPropagationChannelInterceptor(headerName); | ||
Message<?> message = this.interceptor.preSend(this.messageBuilder.build(), this.channel); | ||
assertThat(message.getHeaders()).containsEntry(headerName, this.authentication); | ||
} | ||
|
||
@Test | ||
public void preSendWhenCustomSecurityContextHolderStrategyThenUserSet() { | ||
SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy()); | ||
strategy.setContext(new SecurityContextImpl(this.authentication)); | ||
this.interceptor.setSecurityContextHolderStrategy(strategy); | ||
Message<?> message = this.interceptor.preSend(this.messageBuilder.build(), this.channel); | ||
this.interceptor.beforeHandle(message, this.channel, this.handler); | ||
verify(strategy, times(2)).getContext(); | ||
assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication); | ||
} | ||
|
||
@Test | ||
public void preSendUserNoContext() { | ||
Message<?> message = this.interceptor.preSend(this.messageBuilder.build(), this.channel); | ||
assertThat(message.getHeaders()).containsKey(SimpMessageHeaderAccessor.USER_HEADER); | ||
assertThat(message.getHeaders().get(SimpMessageHeaderAccessor.USER_HEADER)) | ||
.isInstanceOf(AnonymousAuthenticationToken.class); | ||
} | ||
|
||
@Test | ||
public void beforeHandleUserSet() { | ||
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); | ||
this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); | ||
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); | ||
} | ||
|
||
@Test | ||
public void postReceiveUserSet() { | ||
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); | ||
this.interceptor.postReceive(this.messageBuilder.build(), this.channel); | ||
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); | ||
} | ||
|
||
@Test | ||
public void authenticationIsPropagatedFromPreSendToPostReceive() { | ||
SecurityContextHolder.getContext().setAuthentication(this.authentication); | ||
Message<?> message = this.interceptor.preSend(this.messageBuilder.build(), this.channel); | ||
assertThat(message.getHeaders().get(SimpMessageHeaderAccessor.USER_HEADER)).isSameAs(this.authentication); | ||
this.interceptor.postReceive(message, this.channel); | ||
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); | ||
} | ||
|
||
@Test | ||
public void beforeHandleUserNotSet() { | ||
this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); | ||
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); | ||
} | ||
|
||
@Test | ||
public void afterMessageHandledUserNotSet() { | ||
this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); | ||
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); | ||
} | ||
|
||
@Test | ||
public void afterMessageHandled() { | ||
SecurityContextHolder.getContext().setAuthentication(this.authentication); | ||
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); | ||
this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); | ||
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); | ||
} | ||
|
||
@Test | ||
public void restoresOriginalContext() { | ||
TestingAuthenticationToken original = new TestingAuthenticationToken("original", "original", "ROLE_USER"); | ||
SecurityContextHolder.getContext().setAuthentication(original); | ||
this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication); | ||
this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler); | ||
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication); | ||
this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null); | ||
assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(original); | ||
} | ||
|
||
} |