Skip to content

Commit

Permalink
Merge branch '6.0.x' into 6.1.x
Browse files Browse the repository at this point in the history
Closes gh-14085
  • Loading branch information
jzheaux committed Nov 1, 2023
2 parents 0ee006c + fa15c97 commit 624dcaf
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,6 +26,7 @@
import jakarta.servlet.DispatcherType;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletRegistration;
import jakarta.servlet.http.HttpServletRequest;

import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.context.ApplicationContext;
Expand Down Expand Up @@ -203,11 +204,30 @@ public C requestMatchers(HttpMethod method, String... patterns) {
if (!hasDispatcherServlet(registrations)) {
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
}
if (registrations.size() > 1) {
String errorMessage = computeErrorMessage(registrations.values());
throw new IllegalArgumentException(errorMessage);
ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
if (dispatcherServlet != null) {
if (registrations.size() == 1) {
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
}
List<RequestMatcher> matchers = new ArrayList<>();
for (String pattern : patterns) {
AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
}
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
if (dispatcherServlet != null) {
String mapping = dispatcherServlet.getMappings().iterator().next();
List<MvcRequestMatcher> matchers = createMvcMatchers(method, patterns);
for (MvcRequestMatcher matcher : matchers) {
matcher.setServletPath(mapping.substring(0, mapping.length() - 2));
}
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
}
String errorMessage = computeErrorMessage(registrations.values());
throw new IllegalArgumentException(errorMessage);
}

private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
Expand All @@ -225,22 +245,66 @@ private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration>
if (registrations == null) {
return false;
}
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
null);
for (ServletRegistration registration : registrations.values()) {
try {
Class<?> clazz = Class.forName(registration.getClassName());
if (dispatcherServlet.isAssignableFrom(clazz)) {
return true;
}
}
catch (ClassNotFoundException ex) {
return false;
if (isDispatcherServlet(registration)) {
return true;
}
}
return false;
}

private ServletRegistration requireOneRootDispatcherServlet(
Map<String, ? extends ServletRegistration> registrations) {
ServletRegistration rootDispatcherServlet = null;
for (ServletRegistration registration : registrations.values()) {
if (!isDispatcherServlet(registration)) {
continue;
}
if (registration.getMappings().size() > 1) {
return null;
}
if (!"/".equals(registration.getMappings().iterator().next())) {
return null;
}
rootDispatcherServlet = registration;
}
return rootDispatcherServlet;
}

private ServletRegistration requireOnlyPathMappedDispatcherServlet(
Map<String, ? extends ServletRegistration> registrations) {
ServletRegistration pathDispatcherServlet = null;
for (ServletRegistration registration : registrations.values()) {
if (!isDispatcherServlet(registration)) {
return null;
}
if (registration.getMappings().size() > 1) {
return null;
}
String mapping = registration.getMappings().iterator().next();
if (!mapping.startsWith("/") || !mapping.endsWith("/*")) {
return null;
}
if (pathDispatcherServlet != null) {
return null;
}
pathDispatcherServlet = registration;
}
return pathDispatcherServlet;
}

private boolean isDispatcherServlet(ServletRegistration registration) {
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
null);
try {
Class<?> clazz = Class.forName(registration.getClassName());
return dispatcherServlet.isAssignableFrom(clazz);
}
catch (ClassNotFoundException ex) {
return false;
}
}

private String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
+ "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
Expand Down Expand Up @@ -380,4 +444,55 @@ static List<RequestMatcher> regexMatchers(String... regexPatterns) {

}

static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {

private final AntPathRequestMatcher ant;

private final MvcRequestMatcher mvc;

private final ServletContext servletContext;

DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc,
ServletContext servletContext) {
this.ant = ant;
this.mvc = mvc;
this.servletContext = servletContext;
}

@Override
public boolean matches(HttpServletRequest request) {
String name = request.getHttpServletMapping().getServletName();
ServletRegistration registration = this.servletContext.getServletRegistration(name);
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
if (isDispatcherServlet(registration)) {
return this.mvc.matches(request);
}
return this.ant.matches(request);
}

@Override
public MatchResult matcher(HttpServletRequest request) {
String name = request.getHttpServletMapping().getServletName();
ServletRegistration registration = this.servletContext.getServletRegistration(name);
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
if (isDispatcherServlet(registration)) {
return this.mvc.matcher(request);
}
return this.ant.matcher(request);
}

private boolean isDispatcherServlet(ServletRegistration registration) {
Class<?> dispatcherServlet = ClassUtils
.resolveClassName("org.springframework.web.servlet.DispatcherServlet", null);
try {
Class<?> clazz = Class.forName(registration.getClassName());
return dispatcherServlet.isAssignableFrom(clazz);
}
catch (ClassNotFoundException ex) {
return false;
}
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ public ServletRegistration.Dynamic addServlet(@NonNull String servletName, Class
return this.registrations;
}

@Override
public ServletRegistration getServletRegistration(String servletName) {
return this.registrations.get(servletName);
}

private static class MockServletRegistration implements ServletRegistration.Dynamic {

private final String name;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.MappingMatch;

import org.springframework.mock.web.MockHttpServletMapping;

public final class TestMockHttpServletMappings {

private TestMockHttpServletMappings() {

}

public static MockHttpServletMapping extension(HttpServletRequest request, String extension) {
String uri = request.getRequestURI();
String matchValue = uri.substring(0, uri.lastIndexOf(extension));
return new MockHttpServletMapping(matchValue, "*" + extension, "extension", MappingMatch.EXTENSION);
}

public static MockHttpServletMapping path(HttpServletRequest request, String path) {
String uri = request.getRequestURI();
String matchValue = uri.substring(path.length());
return new MockHttpServletMapping(matchValue, path + "/*", "path", MappingMatch.PATH);
}

public static MockHttpServletMapping defaultMapping() {
return new MockHttpServletMapping("", "/", "default", MappingMatch.DEFAULT);
}

}
Loading

0 comments on commit 624dcaf

Please sign in to comment.