Skip to content

Commit

Permalink
Update to return List of StatusCodes and add Saml2Error to result obj…
Browse files Browse the repository at this point in the history
…ect and other formatting
  • Loading branch information
YoungKi Hong authored and jzheaux committed Mar 22, 2024
1 parent 76331a5 commit 6e45e65
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,15 @@
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.HashSet;
import java.util.Arrays;
import java.util.Optional;
import java.util.function.Consumer;

import javax.annotation.Nonnull;
Expand Down Expand Up @@ -98,8 +97,6 @@
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

import static org.opensaml.saml.saml2.core.StatusCode.*;

/**
* Implementation of {@link AuthenticationProvider} for SAML authentications when
* receiving a {@code Response} object containing an {@code Assertion}. This
Expand Down Expand Up @@ -174,7 +171,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv

private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();

private static final Set<String> includeChildStatusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH));
private static final Set<String> includeChildStatusCodes = new HashSet<>(
Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH));

/**
* Creates an {@link OpenSaml4AuthenticationProvider}
Expand Down Expand Up @@ -379,11 +377,13 @@ public static Converter<ResponseToken, Saml2ResponseValidatorResult> createDefau
Response response = responseToken.getResponse();
Saml2AuthenticationToken token = responseToken.getToken();
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
String statusCode = getStatusCode(response);
if (!StatusCode.SUCCESS.equals(statusCode)) {
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
response.getID());
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
List<String> statusCodes = getStatusCodes(response);
if (!isSuccess(statusCodes)) {
for (String statusCode : statusCodes) {
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
response.getID());
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
}
}

String inResponseTo = response.getInResponseTo();
Expand Down Expand Up @@ -412,24 +412,37 @@ public static Converter<ResponseToken, Saml2ResponseValidatorResult> createDefau
};
}

private static String getStatusCode(Response response) {
private static List<String> getStatusCodes(Response response) {
if (response.getStatus() == null) {
return StatusCode.SUCCESS;
return Arrays.asList(StatusCode.SUCCESS);
}
if (response.getStatus().getStatusCode() == null) {
return StatusCode.SUCCESS;
return Arrays.asList(StatusCode.SUCCESS);
}

StatusCode parentStatusCode = response.getStatus().getStatusCode();
String parentStatusCodeValue = parentStatusCode.getValue();
if (includeChildStatusCodes.contains(parentStatusCodeValue)) {
return Optional.ofNullable(parentStatusCode.getStatusCode())
.map(StatusCode::getValue)
.map(childStatusCodeValue -> parentStatusCodeValue + childStatusCodeValue)
.orElse(parentStatusCodeValue);
StatusCode statusCode = parentStatusCode.getStatusCode();
if (statusCode != null) {
String childStatusCodeValue = statusCode.getValue();
if (childStatusCodeValue != null) {
return Arrays.asList(parentStatusCodeValue, childStatusCodeValue);
}
}
return Arrays.asList(parentStatusCodeValue);
}

return Arrays.asList(parentStatusCodeValue);
}

private static boolean isSuccess(List<String> statusCodes) {
if (statusCodes.size() != 1) {
return false;
}

return parentStatusCodeValue;
String statusCode = statusCodes.get(0);
return StatusCode.SUCCESS.equals(statusCode);
}

private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -86,8 +86,6 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
Expand Down Expand Up @@ -736,15 +734,16 @@ public void authenticateWhenCustomResponseValidatorThenUses() {
}

@Test
public void setsOnlyParentStatusCodeOnResultDescription() {
public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatusCodes() {
ResponseToken mockResponseToken = mock(ResponseToken.class);
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
given(mockResponseToken.getToken()).willReturn(mockSamlToken);

RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);

RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
RelyingPartyRegistration.AssertingPartyDetails.class);
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);

Status parentStatus = new StatusBuilder().buildObject();
Expand All @@ -763,24 +762,30 @@ public void setsOnlyParentStatusCodeOnResultDescription() {

given(mockResponseToken.getResponse()).willReturn(mockResponse);

Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
.createDefaultResponseValidator();
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);

String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue());
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
assertFalse(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(childStatusCode.getValue())));
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response",
parentStatusCode.getValue());
assertThat(
result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage)));
assertThat(result.getErrors()
.stream()
.noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue())));
}

@Test
public void setsParentAndChildStatusCodeOnResultDescription() {
public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildStatusCode() {
ResponseToken mockResponseToken = mock(ResponseToken.class);
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
given(mockResponseToken.getToken()).willReturn(mockSamlToken);

RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);

RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
RelyingPartyRegistration.AssertingPartyDetails.class);
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);

Status parentStatus = new StatusBuilder().buildObject();
Expand All @@ -799,11 +804,20 @@ public void setsParentAndChildStatusCodeOnResultDescription() {

given(mockResponseToken.getResponse()).willReturn(mockResponse);

Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
.createDefaultResponseValidator();
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);

String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue() + childStatusCode.getValue());
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
String expectedParentErrorMessage = String.format("Invalid status [%s] for SAML response",
parentStatusCode.getValue());
String expectedChildErrorMessage = String.format("Invalid status [%s] for SAML response",
childStatusCode.getValue());
assertThat(result.getErrors()
.stream()
.anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage)));
assertThat(result.getErrors()
.stream()
.anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage)));
}

@Test
Expand Down

0 comments on commit 6e45e65

Please sign in to comment.