Skip to content

Commit

Permalink
Polish Max Sessions on WebFlux
Browse files Browse the repository at this point in the history
This commit changes the PreventLoginServerMaximumSessionsExceededHandler to invalidate the WebSession in addition to throwing the error, this is needed otherwise the session would still be saved with the security context. It also changes the SessionRegistryWebSession to first perform the operation on the delegate and then invoke the needed method on the ReactiveSessionRegistry

Issue gh-6192
  • Loading branch information
marcusdacoregio committed Feb 27, 2024
1 parent c639d0a commit a5ce8ae
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2143,19 +2143,23 @@ public boolean isStarted() {
@Override
public Mono<Void> changeSessionId() {
String currentId = this.session.getId();
return SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId)
.flatMap((information) -> this.session.changeSessionId().thenReturn(information))
.flatMap((information) -> {
information = information.withSessionId(this.session.getId());
return SessionRegistryWebFilter.this.sessionRegistry.saveSessionInformation(information);
});
return this.session.changeSessionId()
.then(Mono.defer(
() -> SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId)
.flatMap((information) -> {
information = information.withSessionId(this.session.getId());
return SessionRegistryWebFilter.this.sessionRegistry
.saveSessionInformation(information);
})));
}

@Override
public Mono<Void> invalidate() {
String currentId = this.session.getId();
return SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId)
.flatMap((information) -> this.session.invalidate());
return this.session.invalidate()
.then(Mono.defer(() -> SessionRegistryWebFilter.this.sessionRegistry
.removeSessionInformation(currentId)))
.then();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import org.springframework.web.server.session.DefaultWebSessionManager;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -95,14 +96,19 @@ void loginWhenMaxSessionPreventsLoginThenSecondLoginFails() {
ResponseCookie firstLoginSessionCookie = loginReturningCookie(data);

// second login should fail
this.client.mutateWith(csrf())
ResponseCookie secondLoginSessionCookie = this.client.mutateWith(csrf())
.post()
.uri("/login")
.contentType(MediaType.MULTIPART_FORM_DATA)
.body(BodyInserters.fromFormData(data))
.exchange()
.expectHeader()
.location("/login?error");
.location("/login?error")
.returnResult(Void.class)
.getResponseCookies()
.getFirst("SESSION");

assertThat(secondLoginSessionCookie).isNull();

// first login should still be valid
this.client.mutateWith(csrf())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ private Mono<Void> handleConcurrency(WebFilterExchange exchange, Authentication
}
}
}
return this.maximumSessionsExceededHandler
.handle(new MaximumSessionsContext(authentication, registeredSessions, maximumSessions));
return this.maximumSessionsExceededHandler.handle(new MaximumSessionsContext(authentication,
registeredSessions, maximumSessions, currentSession));
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,7 @@

import org.springframework.security.core.Authentication;
import org.springframework.security.core.session.ReactiveSessionInformation;
import org.springframework.web.server.WebSession;

public final class MaximumSessionsContext {

Expand All @@ -29,11 +30,14 @@ public final class MaximumSessionsContext {

private final int maximumSessionsAllowed;

private final WebSession currentSession;

public MaximumSessionsContext(Authentication authentication, List<ReactiveSessionInformation> sessions,
int maximumSessionsAllowed) {
int maximumSessionsAllowed, WebSession currentSession) {
this.authentication = authentication;
this.sessions = sessions;
this.maximumSessionsAllowed = maximumSessionsAllowed;
this.currentSession = currentSession;
}

public Authentication getAuthentication() {
Expand All @@ -48,4 +52,8 @@ public int getMaximumSessionsAllowed() {
return this.maximumSessionsAllowed;
}

public WebSession getCurrentSession() {
return this.currentSession;
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,9 +31,9 @@ public final class PreventLoginServerMaximumSessionsExceededHandler implements S

@Override
public Mono<Void> handle(MaximumSessionsContext context) {
return Mono
.error(new SessionAuthenticationException("Maximum sessions of " + context.getMaximumSessionsAllowed()
+ " for authentication '" + context.getAuthentication().getName() + "' exceeded"));
return context.getCurrentSession()
.invalidate()
.then(Mono.defer(() -> Mono.error(new SessionAuthenticationException("Maximum sessions exceeded"))));
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,7 +50,7 @@ void handleWhenInvokedThenInvalidatesLeastRecentlyUsedSessions() {
given(session2.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760000L));
given(session2.invalidate()).willReturn(Mono.empty());
MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class),
List.of(session1, session2), 2);
List.of(session1, session2), 2, null);

this.handler.handle(context).block();

Expand All @@ -72,7 +72,7 @@ void handleWhenMoreThanOneSessionToInvalidateThenInvalidatesAllOfThem() {
given(session1.invalidate()).willReturn(Mono.empty());
given(session2.invalidate()).willReturn(Mono.empty());
MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class),
List.of(session1, session2, session3), 2);
List.of(session1, session2, session3), 2, null);

this.handler.handle(context).block();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,13 +19,19 @@
import java.util.Collections;

import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import org.springframework.security.authentication.TestAuthentication;
import org.springframework.security.web.authentication.session.SessionAuthenticationException;
import org.springframework.security.web.server.authentication.MaximumSessionsContext;
import org.springframework.security.web.server.authentication.PreventLoginServerMaximumSessionsExceededHandler;
import org.springframework.web.server.WebSession;

import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

/**
* Tests for {@link PreventLoginServerMaximumSessionsExceededHandler}.
Expand All @@ -35,13 +41,17 @@
class PreventLoginServerMaximumSessionsExceededHandlerTests {

@Test
void handleWhenInvokedThenThrowsSessionAuthenticationException() {
void handleWhenInvokedThenInvalidateWebSessionAndThrowsSessionAuthenticationException() {
PreventLoginServerMaximumSessionsExceededHandler handler = new PreventLoginServerMaximumSessionsExceededHandler();
WebSession webSession = mock();
given(webSession.invalidate()).willReturn(Mono.empty());
MaximumSessionsContext context = new MaximumSessionsContext(TestAuthentication.authenticatedUser(),
Collections.emptyList(), 1);
assertThatExceptionOfType(SessionAuthenticationException.class)
.isThrownBy(() -> handler.handle(context).block())
.withMessage("Maximum sessions of 1 for authentication 'user' exceeded");
Collections.emptyList(), 1, webSession);
StepVerifier.create(handler.handle(context)).expectErrorSatisfies((ex) -> {
assertThat(ex).isInstanceOf(SessionAuthenticationException.class);
assertThat(ex.getMessage()).isEqualTo("Maximum sessions exceeded");
}).verify();
verify(webSession).invalidate();
}

}

0 comments on commit a5ce8ae

Please sign in to comment.