diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java index 8dc4cbae0bd..c09990e9eae 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java @@ -44,11 +44,13 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.AnyRequestMatcher; import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher; +import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RegexRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.web.context.WebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; import org.springframework.web.servlet.handler.HandlerMappingIntrospector; /** @@ -335,10 +337,10 @@ public C requestMatchers(HttpMethod method, String... patterns) { private RequestMatcher resolve(AntPathRequestMatcher ant, MvcRequestMatcher mvc, ServletContext servletContext) { Map registrations = mappableServletRegistrations(servletContext); if (registrations.isEmpty()) { - return ant; + return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher()); } if (!hasDispatcherServlet(registrations)) { - return ant; + return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher()); } ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations); if (dispatcherServlet != null) { @@ -605,27 +607,70 @@ public String toString() { } + static class MockMvcRequestMatcher implements RequestMatcher { + + @Override + public boolean matches(HttpServletRequest request) { + return request.getAttribute("org.springframework.test.web.servlet.MockMvc.MVC_RESULT_ATTRIBUTE") != null; + } + + } + + static class DispatcherServletRequestMatcher implements RequestMatcher { + + private final ServletContext servletContext; + + DispatcherServletRequestMatcher(ServletContext servletContext) { + this.servletContext = servletContext; + } + + @Override + public boolean matches(HttpServletRequest request) { + String name = request.getHttpServletMapping().getServletName(); + ServletRegistration registration = this.servletContext.getServletRegistration(name); + Assert.notNull(name, "Failed to find servlet [" + name + "] in the servlet context"); + try { + Class clazz = Class.forName(registration.getClassName()); + return DispatcherServlet.class.isAssignableFrom(clazz); + } + catch (ClassNotFoundException ex) { + return false; + } + } + + } + static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher { private final AntPathRequestMatcher ant; private final MvcRequestMatcher mvc; - private final ServletContext servletContext; + private final RequestMatcher dispatcherServlet; DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc, ServletContext servletContext) { + this(ant, mvc, new OrRequestMatcher(new MockMvcRequestMatcher(), + new DispatcherServletRequestMatcher(servletContext))); + } + + DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc, + RequestMatcher dispatcherServlet) { this.ant = ant; this.mvc = mvc; - this.servletContext = servletContext; + this.dispatcherServlet = dispatcherServlet; + } + + RequestMatcher requestMatcher(HttpServletRequest request) { + if (this.dispatcherServlet.matches(request)) { + return this.mvc; + } + return this.ant; } @Override public boolean matches(HttpServletRequest request) { - String name = request.getHttpServletMapping().getServletName(); - ServletRegistration registration = this.servletContext.getServletRegistration(name); - Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context"); - if (isDispatcherServlet(registration)) { + if (this.dispatcherServlet.matches(request)) { return this.mvc.matches(request); } return this.ant.matches(request); @@ -633,27 +678,12 @@ public boolean matches(HttpServletRequest request) { @Override public MatchResult matcher(HttpServletRequest request) { - String name = request.getHttpServletMapping().getServletName(); - ServletRegistration registration = this.servletContext.getServletRegistration(name); - Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context"); - if (isDispatcherServlet(registration)) { + if (this.dispatcherServlet.matches(request)) { return this.mvc.matcher(request); } return this.ant.matcher(request); } - private boolean isDispatcherServlet(ServletRegistration registration) { - Class dispatcherServlet = ClassUtils - .resolveClassName("org.springframework.web.servlet.DispatcherServlet", null); - try { - Class clazz = Class.forName(registration.getClassName()); - return dispatcherServlet.isAssignableFrom(clazz); - } - catch (ClassNotFoundException ex) { - return false; - } - } - @Override public String toString() { return "DispatcherServletDelegating [" + "ant = " + this.ant + ", mvc = " + this.mvc + "]"; diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java index e5f96703ba7..efa51cbc20f 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java @@ -30,27 +30,35 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Configuration; import org.springframework.http.HttpMethod; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.config.MockServletContext; import org.springframework.security.config.TestMockHttpServletMappings; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry.DispatcherServletDelegatingRequestMatcher; +import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher; import org.springframework.security.web.util.matcher.RegexRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.InstanceOfAssertFactories.type; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; /** * Tests for {@link AbstractRequestMatcherRegistry}. @@ -206,18 +214,65 @@ public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType( mockMvcIntrospector(true); MockServletContext servletContext = new MockServletContext(); given(this.context.getServletContext()).willReturn(servletContext); + MockHttpServletRequest request = new MockHttpServletRequest(); + List requestMatchers = this.matcherRegistry.requestMatchers("/**"); + assertThat(requestMatchers).isNotEmpty(); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class)) + .extracting((matcher) -> matcher.requestMatcher(request)) + .isInstanceOf(AntPathRequestMatcher.class); servletContext.addServlet("servletOne", Servlet.class).addMapping("/one"); servletContext.addServlet("servletTwo", Servlet.class).addMapping("/two"); - List requestMatchers = this.matcherRegistry.requestMatchers("/**"); + requestMatchers = this.matcherRegistry.requestMatchers("/**"); assertThat(requestMatchers).isNotEmpty(); assertThat(requestMatchers).hasSize(1); - assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class); + assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class)) + .extracting((matcher) -> matcher.requestMatcher(request)) + .isInstanceOf(AntPathRequestMatcher.class); servletContext.addServlet("servletOne", Servlet.class); servletContext.addServlet("servletTwo", Servlet.class); requestMatchers = this.matcherRegistry.requestMatchers("/**"); assertThat(requestMatchers).isNotEmpty(); assertThat(requestMatchers).hasSize(1); - assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class); + assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class)) + .extracting((matcher) -> matcher.requestMatcher(request)) + .isInstanceOf(AntPathRequestMatcher.class); + } + + // gh-14418 + @Test + public void requestMatchersWhenNoDispatcherServletMockMvcThenMvcRequestMatcherType() throws Exception { + MockServletContext servletContext = new MockServletContext(); + try (SpringTestContext spring = new SpringTestContext(this)) { + spring.register(MockMvcConfiguration.class) + .postProcessor((context) -> context.setServletContext(servletContext)) + .autowire(); + this.matcherRegistry.setApplicationContext(spring.getContext()); + MockMvc mvc = MockMvcBuilders.webAppContextSetup(spring.getContext()).build(); + MockHttpServletRequest request = mvc.perform(get("/")).andReturn().getRequest(); + List requestMatchers = this.matcherRegistry.requestMatchers("/**"); + assertThat(requestMatchers).isNotEmpty(); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class)) + .extracting((matcher) -> matcher.requestMatcher(request)) + .isInstanceOf(MvcRequestMatcher.class); + servletContext.addServlet("servletOne", Servlet.class).addMapping("/one"); + servletContext.addServlet("servletTwo", Servlet.class).addMapping("/two"); + requestMatchers = this.matcherRegistry.requestMatchers("/**"); + assertThat(requestMatchers).isNotEmpty(); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class)) + .extracting((matcher) -> matcher.requestMatcher(request)) + .isInstanceOf(MvcRequestMatcher.class); + servletContext.addServlet("servletOne", Servlet.class); + servletContext.addServlet("servletTwo", Servlet.class); + requestMatchers = this.matcherRegistry.requestMatchers("/**"); + assertThat(requestMatchers).isNotEmpty(); + assertThat(requestMatchers).hasSize(1); + assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class)) + .extracting((matcher) -> matcher.requestMatcher(request)) + .isInstanceOf(MvcRequestMatcher.class); + } } @Test @@ -398,4 +453,11 @@ private List unwrap(List wrappedMatchers) { } + @Configuration + @EnableWebSecurity + @EnableWebMvc + static class MockMvcConfiguration { + + } + } diff --git a/etc/checkstyle/checkstyle.xml b/etc/checkstyle/checkstyle.xml index c827af5f4d8..e0a8b9d743f 100644 --- a/etc/checkstyle/checkstyle.xml +++ b/etc/checkstyle/checkstyle.xml @@ -20,6 +20,7 @@ +