From ffd12ee3b9c71fc337175f48010db8c28e85b68e Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Wed, 11 Oct 2023 14:01:36 -0600 Subject: [PATCH] Refine requestMatcher Validation Rules Closes gh-14078 --- .../web/AbstractRequestMatcherRegistry.java | 145 ++++++++++++++++-- .../security/config/MockServletContext.java | 5 + .../config/TestMockHttpServletMappings.java | 46 ++++++ .../AbstractRequestMatcherRegistryTests.java | 113 +++++++++++++- 4 files changed, 292 insertions(+), 17 deletions(-) create mode 100644 config/src/test/java/org/springframework/security/config/TestMockHttpServletMappings.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java index e4a94b9cb6e..f2618aaa1e2 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import javax.servlet.DispatcherType; import javax.servlet.ServletContext; import javax.servlet.ServletRegistration; +import javax.servlet.http.HttpServletRequest; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; @@ -321,11 +322,30 @@ public C requestMatchers(HttpMethod method, String... patterns) { if (!hasDispatcherServlet(registrations)) { return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns)); } - if (registrations.size() > 1) { - String errorMessage = computeErrorMessage(registrations.values()); - throw new IllegalArgumentException(errorMessage); + ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations); + if (dispatcherServlet != null) { + if (registrations.size() == 1) { + return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0])); + } + List matchers = new ArrayList<>(); + for (String pattern : patterns) { + AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null); + MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0); + matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext)); + } + return requestMatchers(matchers.toArray(new RequestMatcher[0])); } - return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0])); + dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations); + if (dispatcherServlet != null) { + String mapping = dispatcherServlet.getMappings().iterator().next(); + List matchers = createMvcMatchers(method, patterns); + for (MvcRequestMatcher matcher : matchers) { + matcher.setServletPath(mapping.substring(0, mapping.length() - 2)); + } + return requestMatchers(matchers.toArray(new RequestMatcher[0])); + } + String errorMessage = computeErrorMessage(registrations.values()); + throw new IllegalArgumentException(errorMessage); } private Map mappableServletRegistrations(ServletContext servletContext) { @@ -343,22 +363,66 @@ private boolean hasDispatcherServlet(Map if (registrations == null) { return false; } - Class dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet", - null); for (ServletRegistration registration : registrations.values()) { - try { - Class clazz = Class.forName(registration.getClassName()); - if (dispatcherServlet.isAssignableFrom(clazz)) { - return true; - } - } - catch (ClassNotFoundException ex) { - return false; + if (isDispatcherServlet(registration)) { + return true; } } return false; } + private ServletRegistration requireOneRootDispatcherServlet( + Map registrations) { + ServletRegistration rootDispatcherServlet = null; + for (ServletRegistration registration : registrations.values()) { + if (!isDispatcherServlet(registration)) { + continue; + } + if (registration.getMappings().size() > 1) { + return null; + } + if (!"/".equals(registration.getMappings().iterator().next())) { + return null; + } + rootDispatcherServlet = registration; + } + return rootDispatcherServlet; + } + + private ServletRegistration requireOnlyPathMappedDispatcherServlet( + Map registrations) { + ServletRegistration pathDispatcherServlet = null; + for (ServletRegistration registration : registrations.values()) { + if (!isDispatcherServlet(registration)) { + return null; + } + if (registration.getMappings().size() > 1) { + return null; + } + String mapping = registration.getMappings().iterator().next(); + if (!mapping.startsWith("/") || !mapping.endsWith("/*")) { + return null; + } + if (pathDispatcherServlet != null) { + return null; + } + pathDispatcherServlet = registration; + } + return pathDispatcherServlet; + } + + private boolean isDispatcherServlet(ServletRegistration registration) { + Class dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet", + null); + try { + Class clazz = Class.forName(registration.getClassName()); + return dispatcherServlet.isAssignableFrom(clazz); + } + catch (ClassNotFoundException ex) { + return false; + } + } + private String computeErrorMessage(Collection registrations) { String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. " + "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); " @@ -498,4 +562,55 @@ static List regexMatchers(String... regexPatterns) { } + static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher { + + private final AntPathRequestMatcher ant; + + private final MvcRequestMatcher mvc; + + private final ServletContext servletContext; + + DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc, + ServletContext servletContext) { + this.ant = ant; + this.mvc = mvc; + this.servletContext = servletContext; + } + + @Override + public boolean matches(HttpServletRequest request) { + String name = request.getHttpServletMapping().getServletName(); + ServletRegistration registration = this.servletContext.getServletRegistration(name); + Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context"); + if (isDispatcherServlet(registration)) { + return this.mvc.matches(request); + } + return this.ant.matches(request); + } + + @Override + public MatchResult matcher(HttpServletRequest request) { + String name = request.getHttpServletMapping().getServletName(); + ServletRegistration registration = this.servletContext.getServletRegistration(name); + Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context"); + if (isDispatcherServlet(registration)) { + return this.mvc.matcher(request); + } + return this.ant.matcher(request); + } + + private boolean isDispatcherServlet(ServletRegistration registration) { + Class dispatcherServlet = ClassUtils + .resolveClassName("org.springframework.web.servlet.DispatcherServlet", null); + try { + Class clazz = Class.forName(registration.getClassName()); + return dispatcherServlet.isAssignableFrom(clazz); + } + catch (ClassNotFoundException ex) { + return false; + } + } + + } + } diff --git a/config/src/test/java/org/springframework/security/config/MockServletContext.java b/config/src/test/java/org/springframework/security/config/MockServletContext.java index df3ca415c7d..b677a1aaaee 100644 --- a/config/src/test/java/org/springframework/security/config/MockServletContext.java +++ b/config/src/test/java/org/springframework/security/config/MockServletContext.java @@ -55,6 +55,11 @@ public ServletRegistration.Dynamic addServlet(@NonNull String servletName, Class return this.registrations; } + @Override + public ServletRegistration getServletRegistration(String servletName) { + return this.registrations.get(servletName); + } + private static class MockServletRegistration implements ServletRegistration.Dynamic { private final String name; diff --git a/config/src/test/java/org/springframework/security/config/TestMockHttpServletMappings.java b/config/src/test/java/org/springframework/security/config/TestMockHttpServletMappings.java new file mode 100644 index 00000000000..48b71fde4fc --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/TestMockHttpServletMappings.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.MappingMatch; + +import org.springframework.mock.web.MockHttpServletMapping; + +public final class TestMockHttpServletMappings { + + private TestMockHttpServletMappings() { + + } + + public static MockHttpServletMapping extension(HttpServletRequest request, String extension) { + String uri = request.getRequestURI(); + String matchValue = uri.substring(0, uri.lastIndexOf(extension)); + return new MockHttpServletMapping(matchValue, "*" + extension, "extension", MappingMatch.EXTENSION); + } + + public static MockHttpServletMapping path(HttpServletRequest request, String path) { + String uri = request.getRequestURI(); + String matchValue = uri.substring(path.length()); + return new MockHttpServletMapping(matchValue, path + "/*", "path", MappingMatch.PATH); + } + + public static MockHttpServletMapping defaultMapping() { + return new MockHttpServletMapping("", "/", "default", MappingMatch.DEFAULT); + } + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java index 391f131494e..fb06997c4f0 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import javax.servlet.DispatcherType; import javax.servlet.Servlet; +import javax.servlet.http.HttpServletMapping; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -29,8 +30,11 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; import org.springframework.http.HttpMethod; +import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.config.MockServletContext; +import org.springframework.security.config.TestMockHttpServletMappings; import org.springframework.security.config.annotation.ObjectPostProcessor; +import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry.DispatcherServletDelegatingRequestMatcher; import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher; @@ -43,6 +47,9 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; /** * Tests for {@link AbstractRequestMatcherRegistry}. @@ -197,6 +204,8 @@ public void requestMatchersWhenMvcPresentInClassPathAndMvcIntrospectorBeanNotAva public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() { MockServletContext servletContext = new MockServletContext(); given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("servletOne", Servlet.class).addMapping("/one"); + servletContext.addServlet("servletTwo", Servlet.class).addMapping("/two"); List requestMatchers = this.matcherRegistry.requestMatchers("/**"); assertThat(requestMatchers).isNotEmpty(); assertThat(requestMatchers).hasSize(1); @@ -214,7 +223,26 @@ public void requestMatchersWhenAmbiguousServletsThenException() { MockServletContext servletContext = new MockServletContext(); given(this.context.getServletContext()).willReturn(servletContext); servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/"); - servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/**"); + servletContext.addServlet("servletTwo", DispatcherServlet.class).addMapping("/servlet/*"); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.matcherRegistry.requestMatchers("/**")); + } + + @Test + public void requestMatchersWhenMultipleDispatcherServletMappingsThenException() { + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/", "/mvc/*"); + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.matcherRegistry.requestMatchers("/**")); + } + + @Test + public void requestMatchersWhenPathDispatcherServletAndOtherServletsThenException() { + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*"); + servletContext.addServlet("default", Servlet.class).addMapping("/"); assertThatExceptionOfType(IllegalArgumentException.class) .isThrownBy(() -> this.matcherRegistry.requestMatchers("/**")); } @@ -231,6 +259,87 @@ public void requestMatchersWhenUnmappableServletsThenSkips() { assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class); } + @Test + public void requestMatchersWhenOnlyDispatcherServletThenAllows() { + mockMvcIntrospector(true); + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*"); + List requestMatchers = this.matcherRegistry.requestMatchers("/**"); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class); + } + + @Test + public void requestMatchersWhenImplicitServletsThenAllows() { + mockMvcIntrospector(true); + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("defaultServlet", Servlet.class); + servletContext.addServlet("jspServlet", Servlet.class).addMapping("*.jsp", "*.jspx"); + servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/"); + List requestMatchers = this.matcherRegistry.requestMatchers("/**"); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class); + } + + @Test + public void requestMatchersWhenPathBasedNonDispatcherServletThenAllows() { + mockMvcIntrospector(true); + MockServletContext servletContext = new MockServletContext(); + given(this.context.getServletContext()).willReturn(servletContext); + servletContext.addServlet("path", Servlet.class).addMapping("/services/*"); + servletContext.addServlet("default", DispatcherServlet.class).addMapping("/"); + List requestMatchers = this.matcherRegistry.requestMatchers("/services/*"); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint") { + @Override + public HttpServletMapping getHttpServletMapping() { + return TestMockHttpServletMappings.defaultMapping(); + } + }; + assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue(); + request = new MockHttpServletRequest("GET", "/services/endpoint") { + @Override + public HttpServletMapping getHttpServletMapping() { + return TestMockHttpServletMappings.path(this, "/services"); + } + }; + request.setServletPath("/services"); + request.setPathInfo("/endpoint"); + assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue(); + } + + @Test + public void matchesWhenDispatcherServletThenMvc() { + MockServletContext servletContext = new MockServletContext(); + servletContext.addServlet("default", DispatcherServlet.class).addMapping("/"); + servletContext.addServlet("path", Servlet.class).addMapping("/services/*"); + MvcRequestMatcher mvc = mock(MvcRequestMatcher.class); + AntPathRequestMatcher ant = mock(AntPathRequestMatcher.class); + DispatcherServletDelegatingRequestMatcher requestMatcher = new DispatcherServletDelegatingRequestMatcher(ant, + mvc, servletContext); + MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint") { + @Override + public HttpServletMapping getHttpServletMapping() { + return TestMockHttpServletMappings.defaultMapping(); + } + }; + assertThat(requestMatcher.matches(request)).isFalse(); + verify(mvc).matches(request); + verifyNoInteractions(ant); + request = new MockHttpServletRequest("GET", "/services/endpoint") { + @Override + public HttpServletMapping getHttpServletMapping() { + return TestMockHttpServletMappings.path(this, "/services"); + } + }; + assertThat(requestMatcher.matches(request)).isFalse(); + verify(ant).matches(request); + verifyNoMoreInteractions(mvc); + } + private void mockMvcIntrospector(boolean isPresent) { ApplicationContext context = this.matcherRegistry.getApplicationContext(); given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);