diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java index 46a67cc4d39..8732914c2aa 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java @@ -56,13 +56,16 @@ static String getTokenValue(String actualToken, String token) { System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); - return Utf8.decode(csrfBytes); + return (csrfBytes != null) ? Utf8.decode(csrfBytes) : null; } - private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { + static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { + if (csrfBytes.length < randomBytes.length) { + return null; + } int len = Math.min(randomBytes.length, csrfBytes.length); byte[] xoredCsrf = new byte[len]; - System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); for (int i = 0; i < len; i++) { xoredCsrf[i] ^= randomBytes[i]; } diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java index 884c3d2fc20..7166b8916b8 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java @@ -32,6 +32,7 @@ import org.springframework.security.web.csrf.MissingCsrfTokenException; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.Mockito.mock; /** @@ -141,6 +142,14 @@ public void preSendWhenUnsubscribeThenIgnores() { this.interceptor.preSend(message(), this.channel); } + @Test + public void preSendWhenCsrfBytesIsLongerThanRandomBytesThenArrayIndexOutOfBoundsExceptionWillNotBeThrown() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); + DefaultCsrfToken token = new DefaultCsrfToken("header", "param", "tokenl"); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), token); + assertThatNoException().isThrownBy(() -> this.interceptor.preSend(message(), this.channel)); + } + private Message message() { return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build(); } diff --git a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java index 8d966331ae2..bce518bab7a 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java +++ b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java @@ -119,7 +119,7 @@ private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { } int len = Math.min(randomBytes.length, csrfBytes.length); byte[] xoredCsrf = new byte[len]; - System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); for (int i = 0; i < len; i++) { xoredCsrf[i] ^= randomBytes[i]; } diff --git a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java index 6f508624119..7d432b791e5 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java @@ -30,6 +30,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.mock; @@ -216,6 +217,13 @@ public void resolveCsrfTokenIsInvalidThenReturnsNull() { assertThat(tokenValue).isNull(); } + @Test + public void resolveCsrfTokenValueWhenCsrfBytesIsLongerThanRandomBytesThenArrayIndexOutOfBoundsExceptionWillNotBeThrown() { + this.request.setParameter(this.token.getParameterName(), XOR_CSRF_TOKEN_VALUE); + CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "ABCDE"); + assertThatNoException().isThrownBy(() -> { this.handler.resolveCsrfTokenValue(this.request, csrfToken); }); + } + private static Answer fillByteArray() { return (invocation) -> { byte[] bytes = invocation.getArgument(0);