Skip to content

Commit

Permalink
Update Max Sessions on WebFlux
Browse files Browse the repository at this point in the history
Delete WebSessionStoreReactiveSessionRegistry.java and gives the responsibility to remove the sessions from the WebSessionStore to the handler

Issue gh-6192
  • Loading branch information
marcusdacoregio committed Feb 28, 2024
1 parent f3bcf7e commit f8ff056
Show file tree
Hide file tree
Showing 10 changed files with 116 additions and 365 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.WebSession;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import org.springframework.web.server.session.DefaultWebSessionManager;
import org.springframework.web.util.pattern.PathPatternParser;

/**
Expand Down Expand Up @@ -1964,7 +1966,7 @@ public class SessionManagementSpec {

private SessionLimit sessionLimit = SessionLimit.UNLIMITED;

private ServerMaximumSessionsExceededHandler maximumSessionsExceededHandler = new InvalidateLeastUsedServerMaximumSessionsExceededHandler();
private ServerMaximumSessionsExceededHandler maximumSessionsExceededHandler;

/**
* Configures how many sessions are allowed for a given user.
Expand All @@ -1983,9 +1985,8 @@ void configure(ServerHttpSecurity http) {
if (this.concurrentSessions != null) {
ReactiveSessionRegistry reactiveSessionRegistry = getSessionRegistry();
ConcurrentSessionControlServerAuthenticationSuccessHandler concurrentSessionControlStrategy = new ConcurrentSessionControlServerAuthenticationSuccessHandler(
reactiveSessionRegistry);
reactiveSessionRegistry, getMaximumSessionsExceededHandler());
concurrentSessionControlStrategy.setSessionLimit(this.sessionLimit);
concurrentSessionControlStrategy.setMaximumSessionsExceededHandler(this.maximumSessionsExceededHandler);
RegisterSessionServerAuthenticationSuccessHandler registerSessionAuthenticationStrategy = new RegisterSessionServerAuthenticationSuccessHandler(
reactiveSessionRegistry);
this.authenticationSuccessHandler = new DelegatingServerAuthenticationSuccessHandler(
Expand All @@ -1997,6 +1998,24 @@ void configure(ServerHttpSecurity http) {
}
}

private ServerMaximumSessionsExceededHandler getMaximumSessionsExceededHandler() {
if (this.maximumSessionsExceededHandler != null) {
return this.maximumSessionsExceededHandler;
}
DefaultWebSessionManager webSessionManager = getBeanOrNull(
WebHttpHandlerBuilder.WEB_SESSION_MANAGER_BEAN_NAME, DefaultWebSessionManager.class);
if (webSessionManager != null) {
this.maximumSessionsExceededHandler = new InvalidateLeastUsedServerMaximumSessionsExceededHandler(
webSessionManager.getSessionStore());
}
if (this.maximumSessionsExceededHandler == null) {
throw new IllegalStateException(
"Could not create a default ServerMaximumSessionsExceededHandler. Please provide "
+ "a ServerMaximumSessionsExceededHandler via DSL");
}
return this.maximumSessionsExceededHandler;
}

private void configureSuccessHandlerOnAuthenticationFilters() {
if (ServerHttpSecurity.this.formLogin != null) {
ServerHttpSecurity.this.formLogin.defaultSuccessHandlers.add(0, this.authenticationSuccessHandler);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@
import org.springframework.security.config.test.SpringTestContext;
import org.springframework.security.config.test.SpringTestContextExtension;
import org.springframework.security.config.users.ReactiveAuthenticationTestConfiguration;
import org.springframework.security.core.session.ReactiveSessionInformation;
import org.springframework.security.core.session.InMemoryReactiveSessionRegistry;
import org.springframework.security.core.session.ReactiveSessionRegistry;
import org.springframework.security.core.userdetails.PasswordEncodedUser;
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
Expand All @@ -52,10 +51,8 @@
import org.springframework.security.web.server.authentication.PreventLoginServerMaximumSessionsExceededHandler;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler;
import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
import org.springframework.security.web.server.authentication.SessionLimit;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.security.web.session.WebSessionStoreReactiveSessionRegistry;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.util.LinkedMultiValueMap;
Expand Down Expand Up @@ -322,45 +319,6 @@ void oauth2LoginWhenMaxSessionsThenPreventLogin() {
// @formatter:on
}

@Test
void loginWhenUnlimitedSessionsButSessionsInvalidatedManuallyThenInvalidates() {
ConcurrentSessionsMaxSessionPreventsLoginFalseConfig.sessionLimit = SessionLimit.UNLIMITED;
this.spring.register(ConcurrentSessionsMaxSessionPreventsLoginFalseConfig.class).autowire();
MultiValueMap<String, String> data = new LinkedMultiValueMap<>();
data.add("username", "user");
data.add("password", "password");

ResponseCookie firstLogin = loginReturningCookie(data);
ResponseCookie secondLogin = loginReturningCookie(data);
this.client.get().uri("/").cookie(firstLogin.getName(), firstLogin.getValue()).exchange().expectStatus().isOk();
this.client.get()
.uri("/")
.cookie(secondLogin.getName(), secondLogin.getValue())
.exchange()
.expectStatus()
.isOk();
ReactiveSessionRegistry sessionRegistry = this.spring.getContext().getBean(ReactiveSessionRegistry.class);
sessionRegistry.getAllSessions(PasswordEncodedUser.user())
.flatMap(ReactiveSessionInformation::invalidate)
.blockLast();
this.client.get()
.uri("/")
.cookie(firstLogin.getName(), firstLogin.getValue())
.exchange()
.expectStatus()
.isFound()
.expectHeader()
.location("/login");
this.client.get()
.uri("/")
.cookie(secondLogin.getName(), secondLogin.getValue())
.exchange()
.expectStatus()
.isFound()
.expectHeader()
.location("/login");
}

@Test
void oauth2LoginWhenMaxSessionDoesNotPreventLoginThenSecondLoginSucceedsAndFirstSessionIsInvalidated() {
OAuth2LoginConcurrentSessionsConfig.maxSessions = 1;
Expand Down Expand Up @@ -490,10 +448,9 @@ static class OAuth2LoginConcurrentSessionsConfig {

ServerOAuth2AuthorizationRequestResolver resolver = mock(ServerOAuth2AuthorizationRequestResolver.class);

ServerAuthenticationSuccessHandler successHandler = mock(ServerAuthenticationSuccessHandler.class);

@Bean
SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http,
DefaultWebSessionManager webSessionManager) {
// @formatter:off
http
.authorizeExchange((exchanges) -> exchanges
Expand All @@ -509,7 +466,7 @@ SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
.maximumSessions(SessionLimit.of(maxSessions))
.maximumSessionsExceededHandler(preventLogin
? new PreventLoginServerMaximumSessionsExceededHandler()
: new InvalidateLeastUsedServerMaximumSessionsExceededHandler())
: new InvalidateLeastUsedServerMaximumSessionsExceededHandler(webSessionManager.getSessionStore()))
)
);
// @formatter:on
Expand Down Expand Up @@ -611,8 +568,8 @@ DefaultWebSessionManager webSessionManager() {
}

@Bean
ReactiveSessionRegistry reactiveSessionRegistry(DefaultWebSessionManager webSessionManager) {
return new WebSessionStoreReactiveSessionRegistry(webSessionManager.getSessionStore());
ReactiveSessionRegistry reactiveSessionRegistry() {
return new InMemoryReactiveSessionRegistry();
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,13 +29,13 @@ import org.springframework.security.config.annotation.web.reactive.EnableWebFlux
import org.springframework.security.config.test.SpringTestContext
import org.springframework.security.config.test.SpringTestContextExtension
import org.springframework.security.config.users.ReactiveAuthenticationTestConfiguration
import org.springframework.security.core.session.InMemoryReactiveSessionRegistry
import org.springframework.security.core.session.ReactiveSessionRegistry
import org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers
import org.springframework.security.web.server.SecurityWebFilterChain
import org.springframework.security.web.server.authentication.InvalidateLeastUsedServerMaximumSessionsExceededHandler
import org.springframework.security.web.server.authentication.PreventLoginServerMaximumSessionsExceededHandler
import org.springframework.security.web.server.authentication.SessionLimit
import org.springframework.security.web.session.WebSessionStoreReactiveSessionRegistry
import org.springframework.test.web.reactive.server.WebTestClient
import org.springframework.util.LinkedMultiValueMap
import org.springframework.util.MultiValueMap
Expand All @@ -45,7 +45,6 @@ import org.springframework.web.reactive.config.EnableWebFlux
import org.springframework.web.reactive.function.BodyInserters
import org.springframework.web.server.adapter.WebHttpHandlerBuilder
import org.springframework.web.server.session.DefaultWebSessionManager
import reactor.core.publisher.Mono

/**
* Tests for [ServerSessionManagementDsl]
Expand Down Expand Up @@ -208,7 +207,7 @@ class ServerSessionManagementDslTests {
}

@Bean
open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain {
open fun springSecurity(http: ServerHttpSecurity, webSessionManager: DefaultWebSessionManager): SecurityWebFilterChain {
return http {
authorizeExchange {
authorize(anyExchange, authenticated)
Expand All @@ -217,7 +216,7 @@ class ServerSessionManagementDslTests {
sessionManagement {
sessionConcurrency {
maximumSessions = SessionLimit.of(maxSessions)
maximumSessionsExceededHandler = InvalidateLeastUsedServerMaximumSessionsExceededHandler()
maximumSessionsExceededHandler = InvalidateLeastUsedServerMaximumSessionsExceededHandler(webSessionManager.sessionStore)
}
}
}
Expand Down Expand Up @@ -263,8 +262,8 @@ class ServerSessionManagementDslTests {
}

@Bean
open fun reactiveSessionRegistry(webSessionManager: DefaultWebSessionManager): ReactiveSessionRegistry {
return WebSessionStoreReactiveSessionRegistry(webSessionManager.sessionStore)
open fun reactiveSessionRegistry(): ReactiveSessionRegistry {
return InMemoryReactiveSessionRegistry()
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http) {
.sessionManagement((sessions) -> sessions
.concurrentSessions((concurrency) -> concurrency
.maximumSessions(SessionLimit.of(1))
)
);
return http.build();
}
@Bean
ReactiveSessionRegistry reactiveSessionRegistry(WebSessionManager webSessionManager) {
return new WebSessionStoreReactiveSessionRegistry(((DefaultWebSessionManager) webSessionManager).getSessionStore());
ReactiveSessionRegistry reactiveSessionRegistry() {
return new InMemoryReactiveSessionRegistry();
}
----
Expand All @@ -60,8 +61,8 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain {
}
}
@Bean
open fun reactiveSessionRegistry(webSessionManager: WebSessionManager): ReactiveSessionRegistry {
return WebSessionStoreReactiveSessionRegistry((webSessionManager as DefaultWebSessionManager).sessionStore)
open fun reactiveSessionRegistry(): ReactiveSessionRegistry {
return InMemoryReactiveSessionRegistry()
}
----
======
Expand All @@ -88,8 +89,8 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http) {
}
@Bean
ReactiveSessionRegistry reactiveSessionRegistry(WebSessionManager webSessionManager) {
return new WebSessionStoreReactiveSessionRegistry(((DefaultWebSessionManager) webSessionManager).getSessionStore());
ReactiveSessionRegistry reactiveSessionRegistry() {
return new InMemoryReactiveSessionRegistry();
}
----
Expand All @@ -110,7 +111,7 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain {
}
@Bean
open fun reactiveSessionRegistry(webSessionManager: WebSessionManager): ReactiveSessionRegistry {
return WebSessionStoreReactiveSessionRegistry((webSessionManager as DefaultWebSessionManager).sessionStore)
return InMemoryReactiveSessionRegistry()
}
----
======
Expand Down Expand Up @@ -148,8 +149,8 @@ private SessionLimit maxSessions() {
}
@Bean
ReactiveSessionRegistry reactiveSessionRegistry(WebSessionManager webSessionManager) {
return new WebSessionStoreReactiveSessionRegistry(((DefaultWebSessionManager) webSessionManager).getSessionStore());
ReactiveSessionRegistry reactiveSessionRegistry() {
return new InMemoryReactiveSessionRegistry();
}
----
Expand Down Expand Up @@ -178,8 +179,8 @@ fun maxSessions(): SessionLimit {
}
@Bean
open fun reactiveSessionRegistry(webSessionManager: WebSessionManager): ReactiveSessionRegistry {
return WebSessionStoreReactiveSessionRegistry((webSessionManager as DefaultWebSessionManager).sessionStore)
open fun reactiveSessionRegistry(): ReactiveSessionRegistry {
return InMemoryReactiveSessionRegistry()
}
----
======
Expand Down Expand Up @@ -215,8 +216,8 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http) {
}
@Bean
ReactiveSessionRegistry reactiveSessionRegistry(WebSessionManager webSessionManager) {
return new WebSessionStoreReactiveSessionRegistry(((DefaultWebSessionManager) webSessionManager).getSessionStore());
ReactiveSessionRegistry reactiveSessionRegistry() {
return new InMemoryReactiveSessionRegistry();
}
----
Expand All @@ -238,8 +239,8 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain {
}
@Bean
open fun reactiveSessionRegistry(webSessionManager: WebSessionManager): ReactiveSessionRegistry {
return WebSessionStoreReactiveSessionRegistry((webSessionManager as DefaultWebSessionManager).sessionStore)
open fun reactiveSessionRegistry(): ReactiveSessionRegistry {
return InMemoryReactiveSessionRegistry()
}
----
======
Expand All @@ -248,15 +249,8 @@ open fun reactiveSessionRegistry(webSessionManager: WebSessionManager): Reactive
== Specifying a `ReactiveSessionRegistry`

In order to keep track of the user's sessions, Spring Security uses a {security-api-url}org/springframework/security/core/session/ReactiveSessionRegistry.html[ReactiveSessionRegistry], and, every time a user logs in, their session information is saved.
Typically, in a Spring WebFlux application, you will use the {security-api-url}/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistry.html[WebSessionStoreReactiveSessionRegistry] which makes sure that the `WebSession` is invalidated whenever the `ReactiveSessionInformation` is invalidated.

Spring Security ships with {security-api-url}/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistry.html[WebSessionStoreReactiveSessionRegistry] and {security-api-url}org/springframework/security/core/session/InMemoryReactiveSessionRegistry.html[InMemoryReactiveSessionRegistry] implementations of `ReactiveSessionRegistry`.

[NOTE]
====
When creating the `WebSessionStoreReactiveSessionRegistry`, you need to provide the `WebSessionStore` that is being used by your application.
If you are using Spring WebFlux, you can use the `WebSessionManager` bean (which is usually an instance of `DefaultWebSessionManager`) to get the `WebSessionStore`.
====
Spring Security ships with {security-api-url}org/springframework/security/core/session/InMemoryReactiveSessionRegistry.html[InMemoryReactiveSessionRegistry] implementation of `ReactiveSessionRegistry`.

To specify a `ReactiveSessionRegistry` implementation you can either declare it as a bean:

Expand All @@ -281,7 +275,7 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http) {
@Bean
ReactiveSessionRegistry reactiveSessionRegistry() {
return new InMemoryReactiveSessionRegistry();
return new MyReactiveSessionRegistry();
}
----
Expand All @@ -303,7 +297,7 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain {
@Bean
open fun reactiveSessionRegistry(): ReactiveSessionRegistry {
return InMemoryReactiveSessionRegistry()
return MyReactiveSessionRegistry()
}
----
======
Expand All @@ -324,7 +318,7 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http) {
.sessionManagement((sessions) -> sessions
.concurrentSessions((concurrency) -> concurrency
.maximumSessions(SessionLimit.of(1))
.sessionRegistry(new InMemoryReactiveSessionRegistry())
.sessionRegistry(new MyReactiveSessionRegistry())
)
);
return http.build();
Expand All @@ -342,7 +336,7 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain {
sessionManagement {
sessionConcurrency {
maximumSessions = SessionLimit.of(1)
sessionRegistry = InMemoryReactiveSessionRegistry()
sessionRegistry = MyReactiveSessionRegistry()
}
}
}
Expand All @@ -355,7 +349,7 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain {

At times, it is handy to be able to invalidate all or some of a user's sessions.
For example, when a user changes their password, you may want to invalidate all of their sessions so that they are forced to log in again.
To do that, you can use the `ReactiveSessionRegistry` bean to retrieve all the user's sessions and then invalidate them:
To do that, you can use the `ReactiveSessionRegistry` bean to retrieve all the user's sessions, invalidate them, and them remove them from the `WebSessionStore`:

.Using ReactiveSessionRegistry to invalidate sessions manually
[tabs]
Expand All @@ -367,13 +361,12 @@ Java::
public class SessionControl {
private final ReactiveSessionRegistry reactiveSessionRegistry;
public SessionControl(ReactiveSessionRegistry reactiveSessionRegistry) {
this.reactiveSessionRegistry = reactiveSessionRegistry;
}
private final WebSessionStore webSessionStore;
public Mono<Void> invalidateSessions(String username) {
return this.reactiveSessionRegistry.getAllSessions(username)
.flatMap(ReactiveSessionInformation::invalidate)
.flatMap((session) -> session.invalidate().thenReturn(session))
.flatMap((session) -> this.webSessionStore.removeSession(session.getSessionId()))
.then();
}
}
Expand Down
Loading

0 comments on commit f8ff056

Please sign in to comment.