From f8ff056eb67d4e67716034681e0e6321a2650042 Mon Sep 17 00:00:00 2001 From: Marcus Hert Da Coregio Date: Wed, 28 Feb 2024 10:06:45 -0300 Subject: [PATCH] Update Max Sessions on WebFlux Delete WebSessionStoreReactiveSessionRegistry.java and gives the responsibility to remove the sessions from the WebSessionStore to the handler Issue gh-6192 --- .../config/web/server/ServerHttpSecurity.java | 25 +++- .../server/SessionManagementSpecTests.java | 55 +------ .../server/ServerSessionManagementDslTests.kt | 13 +- .../concurrent-sessions-control.adoc | 57 ++++---- ...rolServerAuthenticationSuccessHandler.java | 20 +-- ...dServerMaximumSessionsExceededHandler.java | 25 ++-- ...ebSessionStoreReactiveSessionRegistry.java | 100 ------------- ...rverAuthenticationSuccessHandlerTests.java | 22 +-- ...erMaximumSessionsExceededHandlerTests.java | 28 +++- ...sionStoreReactiveSessionRegistryTests.java | 136 ------------------ 10 files changed, 116 insertions(+), 365 deletions(-) delete mode 100644 web/src/main/java/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistry.java delete mode 100644 web/src/test/java/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistryTests.java diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 1939a4f1574..27a7e6a9be7 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -214,6 +214,8 @@ import org.springframework.web.server.WebFilter; import org.springframework.web.server.WebFilterChain; import org.springframework.web.server.WebSession; +import org.springframework.web.server.adapter.WebHttpHandlerBuilder; +import org.springframework.web.server.session.DefaultWebSessionManager; import org.springframework.web.util.pattern.PathPatternParser; /** @@ -1964,7 +1966,7 @@ public class SessionManagementSpec { private SessionLimit sessionLimit = SessionLimit.UNLIMITED; - private ServerMaximumSessionsExceededHandler maximumSessionsExceededHandler = new InvalidateLeastUsedServerMaximumSessionsExceededHandler(); + private ServerMaximumSessionsExceededHandler maximumSessionsExceededHandler; /** * Configures how many sessions are allowed for a given user. @@ -1983,9 +1985,8 @@ void configure(ServerHttpSecurity http) { if (this.concurrentSessions != null) { ReactiveSessionRegistry reactiveSessionRegistry = getSessionRegistry(); ConcurrentSessionControlServerAuthenticationSuccessHandler concurrentSessionControlStrategy = new ConcurrentSessionControlServerAuthenticationSuccessHandler( - reactiveSessionRegistry); + reactiveSessionRegistry, getMaximumSessionsExceededHandler()); concurrentSessionControlStrategy.setSessionLimit(this.sessionLimit); - concurrentSessionControlStrategy.setMaximumSessionsExceededHandler(this.maximumSessionsExceededHandler); RegisterSessionServerAuthenticationSuccessHandler registerSessionAuthenticationStrategy = new RegisterSessionServerAuthenticationSuccessHandler( reactiveSessionRegistry); this.authenticationSuccessHandler = new DelegatingServerAuthenticationSuccessHandler( @@ -1997,6 +1998,24 @@ void configure(ServerHttpSecurity http) { } } + private ServerMaximumSessionsExceededHandler getMaximumSessionsExceededHandler() { + if (this.maximumSessionsExceededHandler != null) { + return this.maximumSessionsExceededHandler; + } + DefaultWebSessionManager webSessionManager = getBeanOrNull( + WebHttpHandlerBuilder.WEB_SESSION_MANAGER_BEAN_NAME, DefaultWebSessionManager.class); + if (webSessionManager != null) { + this.maximumSessionsExceededHandler = new InvalidateLeastUsedServerMaximumSessionsExceededHandler( + webSessionManager.getSessionStore()); + } + if (this.maximumSessionsExceededHandler == null) { + throw new IllegalStateException( + "Could not create a default ServerMaximumSessionsExceededHandler. Please provide " + + "a ServerMaximumSessionsExceededHandler via DSL"); + } + return this.maximumSessionsExceededHandler; + } + private void configureSuccessHandlerOnAuthenticationFilters() { if (ServerHttpSecurity.this.formLogin != null) { ServerHttpSecurity.this.formLogin.defaultSuccessHandlers.add(0, this.authenticationSuccessHandler); diff --git a/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java index 22a5eee49c6..1a380bb6e3f 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java @@ -34,9 +34,8 @@ import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.config.users.ReactiveAuthenticationTestConfiguration; -import org.springframework.security.core.session.ReactiveSessionInformation; +import org.springframework.security.core.session.InMemoryReactiveSessionRegistry; import org.springframework.security.core.session.ReactiveSessionRegistry; -import org.springframework.security.core.userdetails.PasswordEncodedUser; import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken; import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; @@ -52,10 +51,8 @@ import org.springframework.security.web.server.authentication.PreventLoginServerMaximumSessionsExceededHandler; import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler; import org.springframework.security.web.server.authentication.ServerAuthenticationConverter; -import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; import org.springframework.security.web.server.authentication.SessionLimit; import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; -import org.springframework.security.web.session.WebSessionStoreReactiveSessionRegistry; import org.springframework.test.context.junit.jupiter.SpringExtension; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.util.LinkedMultiValueMap; @@ -322,45 +319,6 @@ void oauth2LoginWhenMaxSessionsThenPreventLogin() { // @formatter:on } - @Test - void loginWhenUnlimitedSessionsButSessionsInvalidatedManuallyThenInvalidates() { - ConcurrentSessionsMaxSessionPreventsLoginFalseConfig.sessionLimit = SessionLimit.UNLIMITED; - this.spring.register(ConcurrentSessionsMaxSessionPreventsLoginFalseConfig.class).autowire(); - MultiValueMap data = new LinkedMultiValueMap<>(); - data.add("username", "user"); - data.add("password", "password"); - - ResponseCookie firstLogin = loginReturningCookie(data); - ResponseCookie secondLogin = loginReturningCookie(data); - this.client.get().uri("/").cookie(firstLogin.getName(), firstLogin.getValue()).exchange().expectStatus().isOk(); - this.client.get() - .uri("/") - .cookie(secondLogin.getName(), secondLogin.getValue()) - .exchange() - .expectStatus() - .isOk(); - ReactiveSessionRegistry sessionRegistry = this.spring.getContext().getBean(ReactiveSessionRegistry.class); - sessionRegistry.getAllSessions(PasswordEncodedUser.user()) - .flatMap(ReactiveSessionInformation::invalidate) - .blockLast(); - this.client.get() - .uri("/") - .cookie(firstLogin.getName(), firstLogin.getValue()) - .exchange() - .expectStatus() - .isFound() - .expectHeader() - .location("/login"); - this.client.get() - .uri("/") - .cookie(secondLogin.getName(), secondLogin.getValue()) - .exchange() - .expectStatus() - .isFound() - .expectHeader() - .location("/login"); - } - @Test void oauth2LoginWhenMaxSessionDoesNotPreventLoginThenSecondLoginSucceedsAndFirstSessionIsInvalidated() { OAuth2LoginConcurrentSessionsConfig.maxSessions = 1; @@ -490,10 +448,9 @@ static class OAuth2LoginConcurrentSessionsConfig { ServerOAuth2AuthorizationRequestResolver resolver = mock(ServerOAuth2AuthorizationRequestResolver.class); - ServerAuthenticationSuccessHandler successHandler = mock(ServerAuthenticationSuccessHandler.class); - @Bean - SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { + SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http, + DefaultWebSessionManager webSessionManager) { // @formatter:off http .authorizeExchange((exchanges) -> exchanges @@ -509,7 +466,7 @@ SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) { .maximumSessions(SessionLimit.of(maxSessions)) .maximumSessionsExceededHandler(preventLogin ? new PreventLoginServerMaximumSessionsExceededHandler() - : new InvalidateLeastUsedServerMaximumSessionsExceededHandler()) + : new InvalidateLeastUsedServerMaximumSessionsExceededHandler(webSessionManager.getSessionStore())) ) ); // @formatter:on @@ -611,8 +568,8 @@ DefaultWebSessionManager webSessionManager() { } @Bean - ReactiveSessionRegistry reactiveSessionRegistry(DefaultWebSessionManager webSessionManager) { - return new WebSessionStoreReactiveSessionRegistry(webSessionManager.getSessionStore()); + ReactiveSessionRegistry reactiveSessionRegistry() { + return new InMemoryReactiveSessionRegistry(); } } diff --git a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerSessionManagementDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerSessionManagementDslTests.kt index eda22ac6bed..a8c23c43bc2 100644 --- a/config/src/test/kotlin/org/springframework/security/config/web/server/ServerSessionManagementDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/web/server/ServerSessionManagementDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,13 +29,13 @@ import org.springframework.security.config.annotation.web.reactive.EnableWebFlux import org.springframework.security.config.test.SpringTestContext import org.springframework.security.config.test.SpringTestContextExtension import org.springframework.security.config.users.ReactiveAuthenticationTestConfiguration +import org.springframework.security.core.session.InMemoryReactiveSessionRegistry import org.springframework.security.core.session.ReactiveSessionRegistry import org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers import org.springframework.security.web.server.SecurityWebFilterChain import org.springframework.security.web.server.authentication.InvalidateLeastUsedServerMaximumSessionsExceededHandler import org.springframework.security.web.server.authentication.PreventLoginServerMaximumSessionsExceededHandler import org.springframework.security.web.server.authentication.SessionLimit -import org.springframework.security.web.session.WebSessionStoreReactiveSessionRegistry import org.springframework.test.web.reactive.server.WebTestClient import org.springframework.util.LinkedMultiValueMap import org.springframework.util.MultiValueMap @@ -45,7 +45,6 @@ import org.springframework.web.reactive.config.EnableWebFlux import org.springframework.web.reactive.function.BodyInserters import org.springframework.web.server.adapter.WebHttpHandlerBuilder import org.springframework.web.server.session.DefaultWebSessionManager -import reactor.core.publisher.Mono /** * Tests for [ServerSessionManagementDsl] @@ -208,7 +207,7 @@ class ServerSessionManagementDslTests { } @Bean - open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain { + open fun springSecurity(http: ServerHttpSecurity, webSessionManager: DefaultWebSessionManager): SecurityWebFilterChain { return http { authorizeExchange { authorize(anyExchange, authenticated) @@ -217,7 +216,7 @@ class ServerSessionManagementDslTests { sessionManagement { sessionConcurrency { maximumSessions = SessionLimit.of(maxSessions) - maximumSessionsExceededHandler = InvalidateLeastUsedServerMaximumSessionsExceededHandler() + maximumSessionsExceededHandler = InvalidateLeastUsedServerMaximumSessionsExceededHandler(webSessionManager.sessionStore) } } } @@ -263,8 +262,8 @@ class ServerSessionManagementDslTests { } @Bean - open fun reactiveSessionRegistry(webSessionManager: DefaultWebSessionManager): ReactiveSessionRegistry { - return WebSessionStoreReactiveSessionRegistry(webSessionManager.sessionStore) + open fun reactiveSessionRegistry(): ReactiveSessionRegistry { + return InMemoryReactiveSessionRegistry() } } diff --git a/docs/modules/ROOT/pages/reactive/authentication/concurrent-sessions-control.adoc b/docs/modules/ROOT/pages/reactive/authentication/concurrent-sessions-control.adoc index 480c2693814..eb313ec52c2 100644 --- a/docs/modules/ROOT/pages/reactive/authentication/concurrent-sessions-control.adoc +++ b/docs/modules/ROOT/pages/reactive/authentication/concurrent-sessions-control.adoc @@ -34,13 +34,14 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http) { .sessionManagement((sessions) -> sessions .concurrentSessions((concurrency) -> concurrency .maximumSessions(SessionLimit.of(1)) + ) ); return http.build(); } @Bean -ReactiveSessionRegistry reactiveSessionRegistry(WebSessionManager webSessionManager) { - return new WebSessionStoreReactiveSessionRegistry(((DefaultWebSessionManager) webSessionManager).getSessionStore()); +ReactiveSessionRegistry reactiveSessionRegistry() { + return new InMemoryReactiveSessionRegistry(); } ---- @@ -60,8 +61,8 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain { } } @Bean -open fun reactiveSessionRegistry(webSessionManager: WebSessionManager): ReactiveSessionRegistry { - return WebSessionStoreReactiveSessionRegistry((webSessionManager as DefaultWebSessionManager).sessionStore) +open fun reactiveSessionRegistry(): ReactiveSessionRegistry { + return InMemoryReactiveSessionRegistry() } ---- ====== @@ -88,8 +89,8 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http) { } @Bean -ReactiveSessionRegistry reactiveSessionRegistry(WebSessionManager webSessionManager) { - return new WebSessionStoreReactiveSessionRegistry(((DefaultWebSessionManager) webSessionManager).getSessionStore()); +ReactiveSessionRegistry reactiveSessionRegistry() { + return new InMemoryReactiveSessionRegistry(); } ---- @@ -110,7 +111,7 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain { } @Bean open fun reactiveSessionRegistry(webSessionManager: WebSessionManager): ReactiveSessionRegistry { - return WebSessionStoreReactiveSessionRegistry((webSessionManager as DefaultWebSessionManager).sessionStore) + return InMemoryReactiveSessionRegistry() } ---- ====== @@ -148,8 +149,8 @@ private SessionLimit maxSessions() { } @Bean -ReactiveSessionRegistry reactiveSessionRegistry(WebSessionManager webSessionManager) { - return new WebSessionStoreReactiveSessionRegistry(((DefaultWebSessionManager) webSessionManager).getSessionStore()); +ReactiveSessionRegistry reactiveSessionRegistry() { + return new InMemoryReactiveSessionRegistry(); } ---- @@ -178,8 +179,8 @@ fun maxSessions(): SessionLimit { } @Bean -open fun reactiveSessionRegistry(webSessionManager: WebSessionManager): ReactiveSessionRegistry { - return WebSessionStoreReactiveSessionRegistry((webSessionManager as DefaultWebSessionManager).sessionStore) +open fun reactiveSessionRegistry(): ReactiveSessionRegistry { + return InMemoryReactiveSessionRegistry() } ---- ====== @@ -215,8 +216,8 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http) { } @Bean -ReactiveSessionRegistry reactiveSessionRegistry(WebSessionManager webSessionManager) { - return new WebSessionStoreReactiveSessionRegistry(((DefaultWebSessionManager) webSessionManager).getSessionStore()); +ReactiveSessionRegistry reactiveSessionRegistry() { + return new InMemoryReactiveSessionRegistry(); } ---- @@ -238,8 +239,8 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain { } @Bean -open fun reactiveSessionRegistry(webSessionManager: WebSessionManager): ReactiveSessionRegistry { - return WebSessionStoreReactiveSessionRegistry((webSessionManager as DefaultWebSessionManager).sessionStore) +open fun reactiveSessionRegistry(): ReactiveSessionRegistry { + return InMemoryReactiveSessionRegistry() } ---- ====== @@ -248,15 +249,8 @@ open fun reactiveSessionRegistry(webSessionManager: WebSessionManager): Reactive == Specifying a `ReactiveSessionRegistry` In order to keep track of the user's sessions, Spring Security uses a {security-api-url}org/springframework/security/core/session/ReactiveSessionRegistry.html[ReactiveSessionRegistry], and, every time a user logs in, their session information is saved. -Typically, in a Spring WebFlux application, you will use the {security-api-url}/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistry.html[WebSessionStoreReactiveSessionRegistry] which makes sure that the `WebSession` is invalidated whenever the `ReactiveSessionInformation` is invalidated. - -Spring Security ships with {security-api-url}/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistry.html[WebSessionStoreReactiveSessionRegistry] and {security-api-url}org/springframework/security/core/session/InMemoryReactiveSessionRegistry.html[InMemoryReactiveSessionRegistry] implementations of `ReactiveSessionRegistry`. -[NOTE] -==== -When creating the `WebSessionStoreReactiveSessionRegistry`, you need to provide the `WebSessionStore` that is being used by your application. -If you are using Spring WebFlux, you can use the `WebSessionManager` bean (which is usually an instance of `DefaultWebSessionManager`) to get the `WebSessionStore`. -==== +Spring Security ships with {security-api-url}org/springframework/security/core/session/InMemoryReactiveSessionRegistry.html[InMemoryReactiveSessionRegistry] implementation of `ReactiveSessionRegistry`. To specify a `ReactiveSessionRegistry` implementation you can either declare it as a bean: @@ -281,7 +275,7 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http) { @Bean ReactiveSessionRegistry reactiveSessionRegistry() { - return new InMemoryReactiveSessionRegistry(); + return new MyReactiveSessionRegistry(); } ---- @@ -303,7 +297,7 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain { @Bean open fun reactiveSessionRegistry(): ReactiveSessionRegistry { - return InMemoryReactiveSessionRegistry() + return MyReactiveSessionRegistry() } ---- ====== @@ -324,7 +318,7 @@ SecurityWebFilterChain filterChain(ServerHttpSecurity http) { .sessionManagement((sessions) -> sessions .concurrentSessions((concurrency) -> concurrency .maximumSessions(SessionLimit.of(1)) - .sessionRegistry(new InMemoryReactiveSessionRegistry()) + .sessionRegistry(new MyReactiveSessionRegistry()) ) ); return http.build(); @@ -342,7 +336,7 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain { sessionManagement { sessionConcurrency { maximumSessions = SessionLimit.of(1) - sessionRegistry = InMemoryReactiveSessionRegistry() + sessionRegistry = MyReactiveSessionRegistry() } } } @@ -355,7 +349,7 @@ open fun springSecurity(http: ServerHttpSecurity): SecurityWebFilterChain { At times, it is handy to be able to invalidate all or some of a user's sessions. For example, when a user changes their password, you may want to invalidate all of their sessions so that they are forced to log in again. -To do that, you can use the `ReactiveSessionRegistry` bean to retrieve all the user's sessions and then invalidate them: +To do that, you can use the `ReactiveSessionRegistry` bean to retrieve all the user's sessions, invalidate them, and them remove them from the `WebSessionStore`: .Using ReactiveSessionRegistry to invalidate sessions manually [tabs] @@ -367,13 +361,12 @@ Java:: public class SessionControl { private final ReactiveSessionRegistry reactiveSessionRegistry; - public SessionControl(ReactiveSessionRegistry reactiveSessionRegistry) { - this.reactiveSessionRegistry = reactiveSessionRegistry; - } + private final WebSessionStore webSessionStore; public Mono invalidateSessions(String username) { return this.reactiveSessionRegistry.getAllSessions(username) - .flatMap(ReactiveSessionInformation::invalidate) + .flatMap((session) -> session.invalidate().thenReturn(session)) + .flatMap((session) -> this.webSessionStore.removeSession(session.getSessionId())) .then(); } } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java b/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java index 556bf042dfe..cc28044da7c 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java @@ -45,13 +45,16 @@ public final class ConcurrentSessionControlServerAuthenticationSuccessHandler private final ReactiveSessionRegistry sessionRegistry; - private SessionLimit sessionLimit = SessionLimit.of(1); + private final ServerMaximumSessionsExceededHandler maximumSessionsExceededHandler; - private ServerMaximumSessionsExceededHandler maximumSessionsExceededHandler = new InvalidateLeastUsedServerMaximumSessionsExceededHandler(); + private SessionLimit sessionLimit = SessionLimit.of(1); - public ConcurrentSessionControlServerAuthenticationSuccessHandler(ReactiveSessionRegistry sessionRegistry) { + public ConcurrentSessionControlServerAuthenticationSuccessHandler(ReactiveSessionRegistry sessionRegistry, + ServerMaximumSessionsExceededHandler maximumSessionsExceededHandler) { Assert.notNull(sessionRegistry, "sessionRegistry cannot be null"); + Assert.notNull(maximumSessionsExceededHandler, "maximumSessionsExceededHandler cannot be null"); this.sessionRegistry = sessionRegistry; + this.maximumSessionsExceededHandler = maximumSessionsExceededHandler; } @Override @@ -97,15 +100,4 @@ public void setSessionLimit(SessionLimit sessionLimit) { this.sessionLimit = sessionLimit; } - /** - * Sets the {@link ServerMaximumSessionsExceededHandler} to use. The default is - * {@link InvalidateLeastUsedServerMaximumSessionsExceededHandler}. - * @param maximumSessionsExceededHandler the - * {@link ServerMaximumSessionsExceededHandler} to use - */ - public void setMaximumSessionsExceededHandler(ServerMaximumSessionsExceededHandler maximumSessionsExceededHandler) { - Assert.notNull(maximumSessionsExceededHandler, "maximumSessionsExceededHandler cannot be null"); - this.maximumSessionsExceededHandler = maximumSessionsExceededHandler; - } - } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/InvalidateLeastUsedServerMaximumSessionsExceededHandler.java b/web/src/main/java/org/springframework/security/web/server/authentication/InvalidateLeastUsedServerMaximumSessionsExceededHandler.java index efc4a3b6742..28446df993e 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/InvalidateLeastUsedServerMaximumSessionsExceededHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/InvalidateLeastUsedServerMaximumSessionsExceededHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,19 +20,18 @@ import java.util.Comparator; import java.util.List; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import org.springframework.core.log.LogMessage; import org.springframework.security.core.session.ReactiveSessionInformation; +import org.springframework.web.server.session.WebSessionStore; /** * Implementation of {@link ServerMaximumSessionsExceededHandler} that invalidates the - * least recently used session(s). It only invalidates the amount of sessions that exceed - * the maximum allowed. For example, if the maximum was exceeded by 1, only the least - * recently used session will be invalidated. + * least recently used {@link ReactiveSessionInformation} and removes the related sessions + * from the {@link WebSessionStore}. It only invalidates the amount of sessions that + * exceed the maximum allowed. For example, if the maximum was exceeded by 1, only the + * least recently used session will be invalidated. * * @author Marcus da Coregio * @since 6.3 @@ -40,7 +39,11 @@ public final class InvalidateLeastUsedServerMaximumSessionsExceededHandler implements ServerMaximumSessionsExceededHandler { - private final Log logger = LogFactory.getLog(getClass()); + private final WebSessionStore webSessionStore; + + public InvalidateLeastUsedServerMaximumSessionsExceededHandler(WebSessionStore webSessionStore) { + this.webSessionStore = webSessionStore; + } @Override public Mono handle(MaximumSessionsContext context) { @@ -51,10 +54,8 @@ public Mono handle(MaximumSessionsContext context) { maximumSessionsExceededBy); return Flux.fromIterable(leastRecentlyUsedSessionsToInvalidate) - .doOnComplete(() -> this.logger - .debug(LogMessage.format("Invalidated %d least recently used sessions for authentication %s", - leastRecentlyUsedSessionsToInvalidate.size(), context.getAuthentication().getName()))) - .flatMap(ReactiveSessionInformation::invalidate) + .flatMap((toInvalidate) -> toInvalidate.invalidate().thenReturn(toInvalidate)) + .flatMap((toInvalidate) -> this.webSessionStore.removeSession(toInvalidate.getSessionId())) .then(); } diff --git a/web/src/main/java/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistry.java b/web/src/main/java/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistry.java deleted file mode 100644 index 781c438a6e1..00000000000 --- a/web/src/main/java/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistry.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Copyright 2002-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.web.session; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import org.springframework.security.core.session.InMemoryReactiveSessionRegistry; -import org.springframework.security.core.session.ReactiveSessionInformation; -import org.springframework.security.core.session.ReactiveSessionRegistry; -import org.springframework.util.Assert; -import org.springframework.web.server.WebSession; -import org.springframework.web.server.session.WebSessionStore; - -/** - * A {@link ReactiveSessionRegistry} implementation that uses a {@link WebSessionStore} to - * invalidate a {@link WebSession} when the {@link ReactiveSessionInformation} is - * invalidated. - * - * @author Marcus da Coregio - * @since 6.3 - */ -public final class WebSessionStoreReactiveSessionRegistry implements ReactiveSessionRegistry { - - private final WebSessionStore webSessionStore; - - private ReactiveSessionRegistry sessionRegistry = new InMemoryReactiveSessionRegistry(); - - public WebSessionStoreReactiveSessionRegistry(WebSessionStore webSessionStore) { - Assert.notNull(webSessionStore, "webSessionStore cannot be null"); - this.webSessionStore = webSessionStore; - } - - @Override - public Flux getAllSessions(Object principal) { - return this.sessionRegistry.getAllSessions(principal).map(WebSessionInformation::new); - } - - @Override - public Mono saveSessionInformation(ReactiveSessionInformation information) { - return this.sessionRegistry.saveSessionInformation(new WebSessionInformation(information)); - } - - @Override - public Mono getSessionInformation(String sessionId) { - return this.sessionRegistry.getSessionInformation(sessionId).map(WebSessionInformation::new); - } - - @Override - public Mono removeSessionInformation(String sessionId) { - return this.sessionRegistry.removeSessionInformation(sessionId).map(WebSessionInformation::new); - } - - @Override - public Mono updateLastAccessTime(String sessionId) { - return this.sessionRegistry.updateLastAccessTime(sessionId).map(WebSessionInformation::new); - } - - /** - * Sets the {@link ReactiveSessionRegistry} to use. - * @param sessionRegistry the {@link ReactiveSessionRegistry} to use. Cannot be null. - */ - public void setSessionRegistry(ReactiveSessionRegistry sessionRegistry) { - Assert.notNull(sessionRegistry, "sessionRegistry cannot be null"); - this.sessionRegistry = sessionRegistry; - } - - final class WebSessionInformation extends ReactiveSessionInformation { - - WebSessionInformation(ReactiveSessionInformation sessionInformation) { - super(sessionInformation.getPrincipal(), sessionInformation.getSessionId(), - sessionInformation.getLastAccessTime()); - } - - @Override - public Mono invalidate() { - return WebSessionStoreReactiveSessionRegistry.this.webSessionStore.retrieveSession(getSessionId()) - .flatMap(WebSession::invalidate) - .then(Mono - .defer(() -> WebSessionStoreReactiveSessionRegistry.this.removeSessionInformation(getSessionId()))) - .then(Mono.defer(super::invalidate)); - } - - } - -} diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/session/ConcurrentSessionControlServerAuthenticationSuccessHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/session/ConcurrentSessionControlServerAuthenticationSuccessHandlerTests.java index 643b3015fc8..b23c93c23cf 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/session/ConcurrentSessionControlServerAuthenticationSuccessHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/session/ConcurrentSessionControlServerAuthenticationSuccessHandlerTests.java @@ -74,34 +74,36 @@ void setup() { given(this.exchange.getRequest()).willReturn(MockServerHttpRequest.get("/").build()); given(this.exchange.getSession()).willReturn(Mono.just(new MockWebSession())); given(this.handler.handle(any())).willReturn(Mono.empty()); - this.strategy = new ConcurrentSessionControlServerAuthenticationSuccessHandler(this.sessionRegistry); - this.strategy.setMaximumSessionsExceededHandler(this.handler); + this.strategy = new ConcurrentSessionControlServerAuthenticationSuccessHandler(this.sessionRegistry, + this.handler); } @Test void constructorWhenNullRegistryThenException() { assertThatIllegalArgumentException() - .isThrownBy(() -> new ConcurrentSessionControlServerAuthenticationSuccessHandler(null)) + .isThrownBy(() -> new ConcurrentSessionControlServerAuthenticationSuccessHandler(null, this.handler)) .withMessage("sessionRegistry cannot be null"); } @Test - void setMaximumSessionsForAuthenticationWhenNullThenException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.strategy.setSessionLimit(null)) - .withMessage("sessionLimit cannot be null"); + void constructorWhenNullHandlerThenException() { + assertThatIllegalArgumentException() + .isThrownBy( + () -> new ConcurrentSessionControlServerAuthenticationSuccessHandler(this.sessionRegistry, null)) + .withMessage("maximumSessionsExceededHandler cannot be null"); } @Test - void setMaximumSessionsExceededHandlerWhenNullThenException() { - assertThatIllegalArgumentException().isThrownBy(() -> this.strategy.setMaximumSessionsExceededHandler(null)) - .withMessage("maximumSessionsExceededHandler cannot be null"); + void setMaximumSessionsForAuthenticationWhenNullThenException() { + assertThatIllegalArgumentException().isThrownBy(() -> this.strategy.setSessionLimit(null)) + .withMessage("sessionLimit cannot be null"); } @Test void onAuthenticationWhenSessionLimitIsUnlimitedThenDoNothing() { ServerMaximumSessionsExceededHandler handler = mock(ServerMaximumSessionsExceededHandler.class); + this.strategy = new ConcurrentSessionControlServerAuthenticationSuccessHandler(this.sessionRegistry, handler); this.strategy.setSessionLimit(SessionLimit.UNLIMITED); - this.strategy.setMaximumSessionsExceededHandler(handler); this.strategy.onAuthenticationSuccess(null, TestAuthentication.authenticatedUser()).block(); verifyNoInteractions(handler, this.sessionRegistry); } diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java index 3c16e6fdd92..eb1db969c31 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java @@ -19,6 +19,7 @@ import java.time.Instant; import java.util.List; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; @@ -26,10 +27,13 @@ import org.springframework.security.core.session.ReactiveSessionInformation; import org.springframework.security.web.server.authentication.InvalidateLeastUsedServerMaximumSessionsExceededHandler; import org.springframework.security.web.server.authentication.MaximumSessionsContext; +import org.springframework.web.server.session.InMemoryWebSessionStore; +import org.springframework.web.server.session.WebSessionStore; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -40,7 +44,14 @@ */ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests { - InvalidateLeastUsedServerMaximumSessionsExceededHandler handler = new InvalidateLeastUsedServerMaximumSessionsExceededHandler(); + InvalidateLeastUsedServerMaximumSessionsExceededHandler handler; + + WebSessionStore webSessionStore = spy(new InMemoryWebSessionStore()); + + @BeforeEach + void setup() { + this.handler = new InvalidateLeastUsedServerMaximumSessionsExceededHandler(this.webSessionStore); + } @Test void handleWhenInvokedThenInvalidatesLeastRecentlyUsedSessions() { @@ -48,7 +59,9 @@ void handleWhenInvokedThenInvalidatesLeastRecentlyUsedSessions() { ReactiveSessionInformation session2 = mock(ReactiveSessionInformation.class); given(session1.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760010L)); given(session2.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760000L)); + given(session2.getSessionId()).willReturn("session2"); given(session2.invalidate()).willReturn(Mono.empty()); + MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class), List.of(session1, session2), 2, null); @@ -57,6 +70,10 @@ void handleWhenInvokedThenInvalidatesLeastRecentlyUsedSessions() { verify(session2).invalidate(); verify(session1).getLastAccessTime(); // used by comparator to sort the sessions verify(session2).getLastAccessTime(); // used by comparator to sort the sessions + verify(session2).getSessionId(); // used to invalidate session against the + // WebSessionStore + verify(this.webSessionStore).removeSession("session2"); + verifyNoMoreInteractions(this.webSessionStore); verifyNoMoreInteractions(session2); verifyNoMoreInteractions(session1); } @@ -71,17 +88,24 @@ void handleWhenMoreThanOneSessionToInvalidateThenInvalidatesAllOfThem() { given(session3.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760030L)); given(session1.invalidate()).willReturn(Mono.empty()); given(session2.invalidate()).willReturn(Mono.empty()); + given(session1.getSessionId()).willReturn("session1"); + given(session2.getSessionId()).willReturn("session2"); + MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class), List.of(session1, session2, session3), 2, null); - this.handler.handle(context).block(); // @formatter:off verify(session1).invalidate(); verify(session2).invalidate(); + verify(session1).getSessionId(); + verify(session2).getSessionId(); verify(session1, atLeastOnce()).getLastAccessTime(); // used by comparator to sort the sessions verify(session2, atLeastOnce()).getLastAccessTime(); // used by comparator to sort the sessions verify(session3, atLeastOnce()).getLastAccessTime(); // used by comparator to sort the sessions + verify(this.webSessionStore).removeSession("session1"); + verify(this.webSessionStore).removeSession("session2"); + verifyNoMoreInteractions(this.webSessionStore); verifyNoMoreInteractions(session1); verifyNoMoreInteractions(session2); verifyNoMoreInteractions(session3); diff --git a/web/src/test/java/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistryTests.java b/web/src/test/java/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistryTests.java deleted file mode 100644 index 83dc8d5a501..00000000000 --- a/web/src/test/java/org/springframework/security/web/session/WebSessionStoreReactiveSessionRegistryTests.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright 2002-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.security.web.session; - -import java.time.Instant; -import java.util.List; - -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import org.springframework.security.core.session.ReactiveSessionInformation; -import org.springframework.security.core.session.ReactiveSessionRegistry; -import org.springframework.web.server.WebSession; -import org.springframework.web.server.session.WebSessionStore; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.BDDMockito.given; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; - -/** - * Tests for {@link WebSessionStoreReactiveSessionRegistry} - * - * @author Marcus da Coregio - */ -class WebSessionStoreReactiveSessionRegistryTests { - - WebSessionStore webSessionStore = mock(); - - WebSessionStoreReactiveSessionRegistry registry = new WebSessionStoreReactiveSessionRegistry(this.webSessionStore); - - @Test - void constructorWhenWebSessionStoreNullThenException() { - assertThatIllegalArgumentException().isThrownBy(() -> new WebSessionStoreReactiveSessionRegistry(null)) - .withMessage("webSessionStore cannot be null"); - } - - @Test - void getSessionInformationWhenSavedThenReturnsWebSessionInformation() { - ReactiveSessionInformation session = createSession(); - this.registry.saveSessionInformation(session).block(); - ReactiveSessionInformation saved = this.registry.getSessionInformation(session.getSessionId()).block(); - assertThat(saved).isInstanceOf(WebSessionStoreReactiveSessionRegistry.WebSessionInformation.class); - assertThat(saved.getPrincipal()).isEqualTo(session.getPrincipal()); - assertThat(saved.getSessionId()).isEqualTo(session.getSessionId()); - assertThat(saved.getLastAccessTime()).isEqualTo(session.getLastAccessTime()); - } - - @Test - void invalidateWhenReturnedFromGetSessionInformationThenWebSessionInvalidatedAndRemovedFromRegistry() { - ReactiveSessionInformation session = createSession(); - WebSession webSession = mock(); - given(webSession.invalidate()).willReturn(Mono.empty()); - given(this.webSessionStore.retrieveSession(session.getSessionId())).willReturn(Mono.just(webSession)); - - this.registry.saveSessionInformation(session).block(); - ReactiveSessionInformation saved = this.registry.getSessionInformation(session.getSessionId()).block(); - saved.invalidate().block(); - verify(webSession).invalidate(); - assertThat(this.registry.getSessionInformation(saved.getSessionId()).block()).isNull(); - } - - @Test - void invalidateWhenReturnedFromRemoveSessionInformationThenWebSessionInvalidatedAndRemovedFromRegistry() { - ReactiveSessionInformation session = createSession(); - WebSession webSession = mock(); - given(webSession.invalidate()).willReturn(Mono.empty()); - given(this.webSessionStore.retrieveSession(session.getSessionId())).willReturn(Mono.just(webSession)); - - this.registry.saveSessionInformation(session).block(); - ReactiveSessionInformation saved = this.registry.removeSessionInformation(session.getSessionId()).block(); - saved.invalidate().block(); - verify(webSession).invalidate(); - assertThat(this.registry.getSessionInformation(saved.getSessionId()).block()).isNull(); - } - - @Test - void invalidateWhenReturnedFromGetAllSessionsThenWebSessionInvalidatedAndRemovedFromRegistry() { - ReactiveSessionInformation session = createSession(); - WebSession webSession = mock(); - given(webSession.invalidate()).willReturn(Mono.empty()); - given(this.webSessionStore.retrieveSession(session.getSessionId())).willReturn(Mono.just(webSession)); - - this.registry.saveSessionInformation(session).block(); - List saved = this.registry.getAllSessions(session.getPrincipal()) - .collectList() - .block(); - saved.forEach((info) -> info.invalidate().block()); - verify(webSession).invalidate(); - assertThat(this.registry.getAllSessions(session.getPrincipal()).collectList().block()).isEmpty(); - } - - @Test - void setSessionRegistryThenUses() { - ReactiveSessionRegistry sessionRegistry = mock(); - given(sessionRegistry.saveSessionInformation(any())).willReturn(Mono.empty()); - given(sessionRegistry.removeSessionInformation(any())).willReturn(Mono.empty()); - given(sessionRegistry.updateLastAccessTime(any())).willReturn(Mono.empty()); - given(sessionRegistry.getSessionInformation(any())).willReturn(Mono.empty()); - given(sessionRegistry.getAllSessions(any())).willReturn(Flux.empty()); - this.registry.setSessionRegistry(sessionRegistry); - ReactiveSessionInformation session = createSession(); - this.registry.saveSessionInformation(session).block(); - verify(sessionRegistry).saveSessionInformation(any()); - this.registry.removeSessionInformation(session.getSessionId()).block(); - verify(sessionRegistry).removeSessionInformation(any()); - this.registry.updateLastAccessTime(session.getSessionId()).block(); - verify(sessionRegistry).updateLastAccessTime(any()); - this.registry.getSessionInformation(session.getSessionId()).block(); - verify(sessionRegistry).getSessionInformation(any()); - this.registry.getAllSessions(session.getPrincipal()).blockFirst(); - verify(sessionRegistry).getAllSessions(any()); - } - - private static ReactiveSessionInformation createSession() { - return new ReactiveSessionInformation("principal", "sessionId", Instant.now()); - } - -}