diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java index af362b2d56e..a1318a7eb9b 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java @@ -20,16 +20,15 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.HashSet; -import java.util.Arrays; -import java.util.Optional; import java.util.function.Consumer; import javax.annotation.Nonnull; @@ -98,8 +97,6 @@ import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; -import static org.opensaml.saml.saml2.core.StatusCode.*; - /** * Implementation of {@link AuthenticationProvider} for SAML authentications when * receiving a {@code Response} object containing an {@code Assertion}. This @@ -174,7 +171,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv private Converter responseAuthenticationConverter = createDefaultResponseAuthenticationConverter(); - private static final Set includeChildStatusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH)); + private static final Set includeChildStatusCodes = new HashSet<>( + Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH)); /** * Creates an {@link OpenSaml4AuthenticationProvider} @@ -379,11 +377,13 @@ public static Converter createDefau Response response = responseToken.getResponse(); Saml2AuthenticationToken token = responseToken.getToken(); Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success(); - String statusCode = getStatusCode(response); - if (!StatusCode.SUCCESS.equals(statusCode)) { - String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode, - response.getID()); - result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message)); + List statusCodes = getStatusCodes(response); + if (!isSuccess(statusCodes)) { + for (String statusCode : statusCodes) { + String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode, + response.getID()); + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message)); + } } String inResponseTo = response.getInResponseTo(); @@ -412,24 +412,37 @@ public static Converter createDefau }; } - private static String getStatusCode(Response response) { + private static List getStatusCodes(Response response) { if (response.getStatus() == null) { - return StatusCode.SUCCESS; + return Arrays.asList(StatusCode.SUCCESS); } if (response.getStatus().getStatusCode() == null) { - return StatusCode.SUCCESS; + return Arrays.asList(StatusCode.SUCCESS); } StatusCode parentStatusCode = response.getStatus().getStatusCode(); String parentStatusCodeValue = parentStatusCode.getValue(); if (includeChildStatusCodes.contains(parentStatusCodeValue)) { - return Optional.ofNullable(parentStatusCode.getStatusCode()) - .map(StatusCode::getValue) - .map(childStatusCodeValue -> parentStatusCodeValue + childStatusCodeValue) - .orElse(parentStatusCodeValue); + StatusCode statusCode = parentStatusCode.getStatusCode(); + if (statusCode != null) { + String childStatusCodeValue = statusCode.getValue(); + if (childStatusCodeValue != null) { + return Arrays.asList(parentStatusCodeValue, childStatusCodeValue); + } + } + return Arrays.asList(parentStatusCodeValue); + } + + return Arrays.asList(parentStatusCodeValue); + } + + private static boolean isSuccess(List statusCodes) { + if (statusCodes.size() != 1) { + return false; } - return parentStatusCodeValue; + String statusCode = statusCodes.get(0); + return StatusCode.SUCCESS.equals(statusCode); } private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest, diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index 53aeaba17f7..8432b5760a2 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -86,8 +86,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.atLeastOnce; @@ -736,7 +734,7 @@ public void authenticateWhenCustomResponseValidatorThenUses() { } @Test - public void setsOnlyParentStatusCodeOnResultDescription() { + public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatusCodes() { ResponseToken mockResponseToken = mock(ResponseToken.class); Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class); given(mockResponseToken.getToken()).willReturn(mockSamlToken); @@ -744,7 +742,8 @@ public void setsOnlyParentStatusCodeOnResultDescription() { RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class); given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration); - RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class); + RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock( + RelyingPartyRegistration.AssertingPartyDetails.class); given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails); Status parentStatus = new StatusBuilder().buildObject(); @@ -763,16 +762,21 @@ public void setsOnlyParentStatusCodeOnResultDescription() { given(mockResponseToken.getResponse()).willReturn(mockResponse); - Converter validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator(); + Converter validator = OpenSaml4AuthenticationProvider + .createDefaultResponseValidator(); Saml2ResponseValidatorResult result = validator.convert(mockResponseToken); - String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue()); - assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage))); - assertFalse(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(childStatusCode.getValue()))); + String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", + parentStatusCode.getValue()); + assertThat( + result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage))); + assertThat(result.getErrors() + .stream() + .noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue()))); } @Test - public void setsParentAndChildStatusCodeOnResultDescription() { + public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildStatusCode() { ResponseToken mockResponseToken = mock(ResponseToken.class); Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class); given(mockResponseToken.getToken()).willReturn(mockSamlToken); @@ -780,7 +784,8 @@ public void setsParentAndChildStatusCodeOnResultDescription() { RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class); given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration); - RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class); + RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock( + RelyingPartyRegistration.AssertingPartyDetails.class); given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails); Status parentStatus = new StatusBuilder().buildObject(); @@ -799,11 +804,20 @@ public void setsParentAndChildStatusCodeOnResultDescription() { given(mockResponseToken.getResponse()).willReturn(mockResponse); - Converter validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator(); + Converter validator = OpenSaml4AuthenticationProvider + .createDefaultResponseValidator(); Saml2ResponseValidatorResult result = validator.convert(mockResponseToken); - String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue() + childStatusCode.getValue()); - assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage))); + String expectedParentErrorMessage = String.format("Invalid status [%s] for SAML response", + parentStatusCode.getValue()); + String expectedChildErrorMessage = String.format("Invalid status [%s] for SAML response", + childStatusCode.getValue()); + assertThat(result.getErrors() + .stream() + .anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage))); + assertThat(result.getErrors() + .stream() + .anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage))); } @Test