From f9d5dda65439f578f6e0968bc2ccf1b2f38f6946 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 19 Jul 2024 18:21:00 -0600 Subject: [PATCH] Polish Tests - Use test objects - Ensure assertThat is checked Issue gh-11725 --- .../OpenSaml4AuthenticationProviderTests.java | 57 +++++-------------- 1 file changed, 15 insertions(+), 42 deletions(-) diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index abe4a4549a0..0193a531470 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -51,7 +51,6 @@ import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedAttribute; import org.opensaml.saml.saml2.core.EncryptedID; -import org.opensaml.saml.saml2.core.Issuer; import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.ProxyRestriction; @@ -737,16 +736,7 @@ public void authenticateWhenCustomResponseValidatorThenUses() { @Test public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatusCodes() { - ResponseToken mockResponseToken = mock(ResponseToken.class); - Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class); - given(mockResponseToken.getToken()).willReturn(mockSamlToken); - - RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class); - given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration); - - RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock( - RelyingPartyRegistration.AssertingPartyDetails.class); - given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails); + Saml2AuthenticationToken token = TestSaml2AuthenticationTokens.token(); Status parentStatus = new StatusBuilder().buildObject(); StatusCode parentStatusCode = new StatusCodeBuilder().buildObject(); @@ -756,40 +746,27 @@ public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatus parentStatusCode.setStatusCode(childStatusCode); parentStatus.setStatusCode(parentStatusCode); - Response mockResponse = mock(Response.class); - given(mockResponse.getStatus()).willReturn(parentStatus); - Issuer mockIssuer = mock(Issuer.class); - given(mockIssuer.getValue()).willReturn("mockedIssuer"); - given(mockResponse.getIssuer()).willReturn(mockIssuer); - - given(mockResponseToken.getResponse()).willReturn(mockResponse); + Response response = TestOpenSamlObjects.response(); + response.setStatus(parentStatus); + response.setIssuer(TestOpenSamlObjects.issuer("mockedIssuer")); Converter validator = OpenSaml4AuthenticationProvider .createDefaultResponseValidator(); - Saml2ResponseValidatorResult result = validator.convert(mockResponseToken); + Saml2ResponseValidatorResult result = validator.convert(new ResponseToken(response, token)); String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue()); assertThat( - result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage))); + result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage))) + .isTrue(); assertThat(result.getErrors() .stream() - .noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue()))); + .noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue()))).isTrue(); } @Test public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildStatusCode() { - ResponseToken mockResponseToken = mock(ResponseToken.class); - Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class); - given(mockResponseToken.getToken()).willReturn(mockSamlToken); - - RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class); - given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration); - - RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock( - RelyingPartyRegistration.AssertingPartyDetails.class); - given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails); - + Saml2AuthenticationToken token = TestSaml2AuthenticationTokens.token(); Status parentStatus = new StatusBuilder().buildObject(); StatusCode parentStatusCode = new StatusCodeBuilder().buildObject(); parentStatusCode.setValue(StatusCode.REQUESTER); @@ -798,17 +775,13 @@ public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildSt parentStatusCode.setStatusCode(childStatusCode); parentStatus.setStatusCode(parentStatusCode); - Response mockResponse = mock(Response.class); - given(mockResponse.getStatus()).willReturn(parentStatus); - Issuer mockIssuer = mock(Issuer.class); - given(mockIssuer.getValue()).willReturn("mockedIssuer"); - given(mockResponse.getIssuer()).willReturn(mockIssuer); - - given(mockResponseToken.getResponse()).willReturn(mockResponse); + Response response = TestOpenSamlObjects.response(); + response.setStatus(parentStatus); + response.setIssuer(TestOpenSamlObjects.issuer("mockedIssuer")); Converter validator = OpenSaml4AuthenticationProvider .createDefaultResponseValidator(); - Saml2ResponseValidatorResult result = validator.convert(mockResponseToken); + Saml2ResponseValidatorResult result = validator.convert(new ResponseToken(response, token)); String expectedParentErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue()); @@ -816,10 +789,10 @@ public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildSt childStatusCode.getValue()); assertThat(result.getErrors() .stream() - .anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage))); + .anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage))).isTrue(); assertThat(result.getErrors() .stream() - .anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage))); + .anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage))).isTrue(); } @Test