From fe3c5bf9b42a5317e4aa8d9327f00db92b9b89e6 Mon Sep 17 00:00:00 2001 From: kratosmy Date: Sat, 27 Apr 2024 15:13:22 +0800 Subject: [PATCH] Fix possible ArrayIndexOutOfBoundsException in XorCsrfTokenRequestAttributeHandler and XorCsrfTokenUtils --- .../messaging/web/csrf/XorCsrfTokenUtils.java | 9 ++++++--- .../web/csrf/XorCsrfChannelInterceptorTests.java | 8 ++++++++ .../csrf/XorCsrfTokenRequestAttributeHandler.java | 2 +- .../XorCsrfTokenRequestAttributeHandlerTests.java | 12 ++++++++++++ 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java index 46a67cc4d39..8732914c2aa 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java @@ -56,13 +56,16 @@ static String getTokenValue(String actualToken, String token) { System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); - return Utf8.decode(csrfBytes); + return (csrfBytes != null) ? Utf8.decode(csrfBytes) : null; } - private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { + static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { + if (csrfBytes.length < randomBytes.length) { + return null; + } int len = Math.min(randomBytes.length, csrfBytes.length); byte[] xoredCsrf = new byte[len]; - System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); for (int i = 0; i < len; i++) { xoredCsrf[i] ^= randomBytes[i]; } diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java index 884c3d2fc20..22f90d56603 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java @@ -32,6 +32,7 @@ import org.springframework.security.web.csrf.MissingCsrfTokenException; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.Mockito.mock; /** @@ -141,6 +142,13 @@ public void preSendWhenUnsubscribeThenIgnores() { this.interceptor.preSend(message(), this.channel); } + @Test + public void preSendWhenCsrfBytesIsLongerThanRandomBytesThenArrayIndexOutOfBoundsExceptionWillNotBeThrown() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + assertThatNoException().isThrownBy(() -> this.interceptor.preSend(message(), this.channel)); + } + private Message message() { return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build(); } diff --git a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java index 8d966331ae2..bce518bab7a 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java +++ b/web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java @@ -119,7 +119,7 @@ private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { } int len = Math.min(randomBytes.length, csrfBytes.length); byte[] xoredCsrf = new byte[len]; - System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, len); for (int i = 0; i < len; i++) { xoredCsrf[i] ^= randomBytes[i]; } diff --git a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java index 6f508624119..70bb069cf54 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java @@ -30,6 +30,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.mock; @@ -216,6 +217,17 @@ public void resolveCsrfTokenIsInvalidThenReturnsNull() { assertThat(tokenValue).isNull(); } + @Test + public void resolveCsrfTokenValueWhenCsrfBytesIsLongerThanRandomBytesThenArrayIndexOutOfBoundsExceptionWillNotBeThrown() { + this.request.setParameter(this.token.getParameterName(), XOR_CSRF_TOKEN_VALUE); + CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "ABCDE"); + // @formatter:off + assertThatNoException().isThrownBy(() -> { + this.handler.resolveCsrfTokenValue(this.request, csrfToken); + }); + // @formatter:on + } + private static Answer fillByteArray() { return (invocation) -> { byte[] bytes = invocation.getArgument(0);