Skip to content

Commit

Permalink
Merge branch '5.8.x' into 6.0.x
Browse files Browse the repository at this point in the history
Closes gh-14164
  • Loading branch information
jzheaux committed Nov 17, 2023
2 parents 00da9c9 + 4ca5468 commit c6c6eb4
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletContext;
Expand All @@ -42,6 +44,7 @@
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.function.SingletonSupplier;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.handler.HandlerMappingIntrospector;

Expand Down Expand Up @@ -197,34 +200,51 @@ public C requestMatchers(HttpMethod method, String... patterns) {
if (servletContext == null) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
boolean isProgrammaticApiAvailable = isProgrammaticApiAvailable(servletContext);
List<RequestMatcher> matchers = new ArrayList<>();
for (String pattern : patterns) {
AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
if (isProgrammaticApiAvailable) {
matchers.add(resolve(ant, mvc, servletContext));
}
else {
matchers.add(new DeferredRequestMatcher(() -> resolve(ant, mvc, servletContext), mvc, ant));
}
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
}

private static boolean isProgrammaticApiAvailable(ServletContext servletContext) {
try {
servletContext.getServletRegistrations();
return true;
}
catch (UnsupportedOperationException ex) {
return false;
}
}

private RequestMatcher resolve(AntPathRequestMatcher ant, MvcRequestMatcher mvc, ServletContext servletContext) {
Map<String, ? extends ServletRegistration> registrations = mappableServletRegistrations(servletContext);
if (registrations.isEmpty()) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
return ant;
}
if (!hasDispatcherServlet(registrations)) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
return ant;
}
ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
if (dispatcherServlet != null) {
if (registrations.size() == 1) {
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
return mvc;
}
List<RequestMatcher> matchers = new ArrayList<>();
for (String pattern : patterns) {
AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
return new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext);
}
dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
if (dispatcherServlet != null) {
String mapping = dispatcherServlet.getMappings().iterator().next();
List<MvcRequestMatcher> matchers = createMvcMatchers(method, patterns);
for (MvcRequestMatcher matcher : matchers) {
matcher.setServletPath(mapping.substring(0, mapping.length() - 2));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
mvc.setServletPath(mapping.substring(0, mapping.length() - 2));
return mvc;
}
String errorMessage = computeErrorMessage(registrations.values());
throw new IllegalArgumentException(errorMessage);
Expand Down Expand Up @@ -444,6 +464,38 @@ static List<RequestMatcher> regexMatchers(String... regexPatterns) {

}

static class DeferredRequestMatcher implements RequestMatcher {

final Supplier<RequestMatcher> requestMatcher;

final AtomicReference<String> description = new AtomicReference<>();

DeferredRequestMatcher(Supplier<RequestMatcher> resolver, RequestMatcher... candidates) {
this.requestMatcher = SingletonSupplier.of(() -> {
RequestMatcher matcher = resolver.get();
this.description.set(matcher.toString());
return matcher;
});
this.description.set("Deferred " + candidates);
}

@Override
public boolean matches(HttpServletRequest request) {
return this.requestMatcher.get().matches(request);
}

@Override
public MatchResult matcher(HttpServletRequest request) {
return this.requestMatcher.get().matcher(request);
}

@Override
public String toString() {
return this.description.get();
}

}

static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {

private final AntPathRequestMatcher ant;
Expand Down Expand Up @@ -493,6 +545,11 @@ private boolean isDispatcherServlet(ServletRegistration registration) {
}
}

@Override
public String toString() {
return "DispatcherServletDelegating [" + "ant = " + this.ant + ", mvc = " + this.mvc + "]";
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.security.config.annotation.web;

import java.util.ArrayList;
import java.util.List;

import jakarta.servlet.DispatcherType;
Expand Down Expand Up @@ -164,6 +165,7 @@ public void requestMatchersWhenMvcPresentInClassPathAndMvcIntrospectorBeanNotAva

@Test
public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("servletOne", Servlet.class).addMapping("/one");
Expand All @@ -182,6 +184,7 @@ public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType(

@Test
public void requestMatchersWhenAmbiguousServletsThenException() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
Expand All @@ -192,6 +195,7 @@ public void requestMatchersWhenAmbiguousServletsThenException() {

@Test
public void requestMatchersWhenMultipleDispatcherServletMappingsThenException() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/", "/mvc/*");
Expand All @@ -201,6 +205,7 @@ public void requestMatchersWhenMultipleDispatcherServletMappingsThenException()

@Test
public void requestMatchersWhenPathDispatcherServletAndOtherServletsThenException() {
mockMvcIntrospector(true);
MockServletContext servletContext = new MockServletContext();
given(this.context.getServletContext()).willReturn(servletContext);
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*");
Expand Down Expand Up @@ -309,11 +314,29 @@ private void mockMvcIntrospector(boolean isPresent) {

private static class TestRequestMatcherRegistry extends AbstractRequestMatcherRegistry<List<RequestMatcher>> {

@Override
public List<RequestMatcher> requestMatchers(RequestMatcher... requestMatchers) {
return unwrap(super.requestMatchers(requestMatchers));
}

@Override
protected List<RequestMatcher> chainRequestMatchers(List<RequestMatcher> requestMatchers) {
return requestMatchers;
}

private static List<RequestMatcher> unwrap(List<RequestMatcher> wrappedMatchers) {
List<RequestMatcher> requestMatchers = new ArrayList<>();
for (RequestMatcher requestMatcher : wrappedMatchers) {
if (requestMatcher instanceof AbstractRequestMatcherRegistry.DeferredRequestMatcher) {
requestMatchers.add(((DeferredRequestMatcher) requestMatcher).requestMatcher.get());
}
else {
requestMatchers.add(requestMatcher);
}
}
return requestMatchers;
}

}

}

0 comments on commit c6c6eb4

Please sign in to comment.