Skip to content

Commit

Permalink
Merge branch '6.0.x' into 6.1.x
Browse files Browse the repository at this point in the history
Closes gh-13722
  • Loading branch information
jzheaux committed Aug 21, 2023
2 parents 037c7cb + 5fb6f57 commit 0df1884
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -194,18 +196,31 @@ public C requestMatchers(HttpMethod method, String... patterns) {
if (servletContext == null) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
Map<String, ? extends ServletRegistration> registrations = servletContext.getServletRegistrations();
if (registrations == null) {
Map<String, ? extends ServletRegistration> registrations = mappableServletRegistrations(servletContext);
if (registrations.isEmpty()) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
if (!hasDispatcherServlet(registrations)) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
Assert.isTrue(registrations.size() == 1,
"This method cannot decide whether these patterns are Spring MVC patterns or not. If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); otherwise, please use requestMatchers(AntPathRequestMatcher).");
if (registrations.size() > 1) {
String errorMessage = computeErrorMessage(registrations.values());
throw new IllegalArgumentException(errorMessage);
}
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
}

private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
Map<String, ServletRegistration> mappable = new LinkedHashMap<>();
for (Map.Entry<String, ? extends ServletRegistration> entry : servletContext.getServletRegistrations()
.entrySet()) {
if (!entry.getValue().getMappings().isEmpty()) {
mappable.put(entry.getKey(), entry.getValue());
}
}
return mappable;
}

private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) {
if (registrations == null) {
return false;
Expand All @@ -226,6 +241,19 @@ private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration>
return false;
}

private String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
+ "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
+ "otherwise, please use requestMatchers(AntPathRequestMatcher).\n\n"
+ "This is because there is more than one mappable servlet in your servlet context: %s.\n\n"
+ "For each MvcRequestMatcher, call MvcRequestMatcher#setServletPath to indicate the servlet path.";
Map<String, Collection<String>> mappings = new LinkedHashMap<>();
for (ServletRegistration registration : registrations) {
mappings.put(registration.getClassName(), registration.getMappings());
}
return String.format(template, mappings);
}

/**
* <p>
* If the {@link HandlerMappingIntrospector} is available in the classpath, maps to an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

package org.springframework.security.config;

import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

Expand All @@ -35,7 +37,7 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet

public static MockServletContext mvc() {
MockServletContext servletContext = new MockServletContext();
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
return servletContext;
}

Expand All @@ -59,6 +61,8 @@ private static class MockServletRegistration implements ServletRegistration.Dyna

private final Class<?> clazz;

private final Set<String> mappings = new LinkedHashSet<>();

MockServletRegistration(String name, Class<?> clazz) {
this.name = name;
this.clazz = clazz;
Expand Down Expand Up @@ -91,12 +95,13 @@ public void setAsyncSupported(boolean isAsyncSupported) {

@Override
public Set<String> addMapping(String... urlPatterns) {
return null;
this.mappings.addAll(Arrays.asList(urlPatterns));
return this.mappings;
}

@Override
public Collection<String> getMappings() {
return null;
return this.mappings;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,24 @@ public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType(
public void requestMatchersWhenAmbiguousServletsThenException() {
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
servletContext.addServlet("servletTwo", Servlet.class);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/**");
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
}

@Test
public void requestMatchersWhenUnmappableServletsThenSkips() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
servletContext.addServlet("servletTwo", Servlet.class);
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
assertThat(requestMatchers).hasSize(1);
assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
}

private void mockMvcIntrospector(boolean isPresent) {
ApplicationContext context = this.matcherRegistry.getApplicationContext();
given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);
Expand Down

0 comments on commit 0df1884

Please sign in to comment.