Skip to content

Commit

Permalink
Detect missing middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
Tratcher committed Nov 10, 2023
1 parent 8726e55 commit ed09592
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 2 deletions.
28 changes: 27 additions & 1 deletion src/ReverseProxy/Model/ProxyPipelineInitializerMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
using System.Diagnostics;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
#if NET8_0_OR_GREATER
using Microsoft.AspNetCore.Http.Timeouts;
#endif
using Microsoft.Extensions.Logging;
#if NET8_0_OR_GREATER
using Yarp.ReverseProxy.Configuration;
#endif
using Yarp.ReverseProxy.Utilities;

namespace Yarp.ReverseProxy.Model;
Expand Down Expand Up @@ -41,7 +47,17 @@ public Task Invoke(HttpContext context)
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
return Task.CompletedTask;
}

#if NET8_0_OR_GREATER
// There's no way to detect the presence of the timeout middleware before this, only the options.
if (endpoint.Metadata.GetMetadata<RequestTimeoutAttribute>() != null
&& context.Features.Get<IHttpRequestTimeoutFeature>() == null)
{
Log.TimeoutNotApplied(_logger, route.Config.RouteId);
// Out of an abundance of caution, refuse the request rather than allowing it to proceed without the configured timeout.
throw new InvalidOperationException($"The timeout was not applied for route '{route.Config.RouteId}', ensure `IApplicationBuilder.UseRequestTimeouts()`"
+ " is called between `IApplicationBuilder.UseRouting()` and `IApplicationBuilder.UseEndpoints()`.");
}
#endif
var destinationsState = cluster.DestinationsState;
context.Features.Set<IReverseProxyFeature>(new ReverseProxyFeature
{
Expand Down Expand Up @@ -80,9 +96,19 @@ private static class Log
EventIds.NoClusterFound,
"Route '{routeId}' has no cluster information.");

private static readonly Action<ILogger, string, Exception?> _timeoutNotApplied = LoggerMessage.Define<string>(
LogLevel.Error,
EventIds.TimeoutNotApplied,
"The timeout was not applied for route '{routeId}', ensure `IApplicationBuilder.UseRequestTimeouts()` is called between `IApplicationBuilder.UseRouting()` and `IApplicationBuilder.UseEndpoints()`.");

public static void NoClusterFound(ILogger logger, string routeId)
{
_noClusterFound(logger, routeId, null);
}

public static void TimeoutNotApplied(ILogger logger, string routeId)
{
_timeoutNotApplied(logger, routeId, null);
}
}
}
1 change: 1 addition & 0 deletions src/ReverseProxy/Utilities/EventIds.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,5 @@ internal static class EventIds
public static readonly EventId RetryingWebSocketDowngradeNoConnect = new EventId(61, "RetryingWebSocketDowngradeNoConnect");
public static readonly EventId RetryingWebSocketDowngradeNoHttp2 = new EventId(62, "RetryingWebSocketDowngradeNoHttp2");
public static readonly EventId InvalidSecWebSocketKeyHeader = new EventId(63, "InvalidSecWebSocketKeyHeader");
public static readonly EventId TimeoutNotApplied = new(64, nameof(TimeoutNotApplied));
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
#if NET8_0_OR_GREATER
using Microsoft.AspNetCore.Http.Timeouts;
#endif
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.Routing.Patterns;
using Moq;
Expand Down Expand Up @@ -119,14 +122,85 @@ public async Task Invoke_NoHealthyEndpoints_CallsNext()

Assert.Equal(StatusCodes.Status418ImATeapot, httpContext.Response.StatusCode);
}
#if NET8_0_OR_GREATER
[Fact]
public async Task Invoke_MissingTimeoutMiddleware_RefuseRequest()
{
var httpClient = new HttpMessageInvoker(new Mock<HttpMessageHandler>().Object);
var cluster1 = new ClusterState(clusterId: "cluster1")
{
Model = new ClusterModel(new ClusterConfig(), httpClient)
};

var aspNetCoreEndpoints = new List<Endpoint>();
var routeConfig = new RouteModel(
config: new RouteConfig(),
cluster: cluster1,
transformer: HttpTransformer.Default);
var aspNetCoreEndpoint = CreateAspNetCoreEndpoint(routeConfig,
builder =>
{
builder.Metadata.Add(new RequestTimeoutAttribute(1));
});
aspNetCoreEndpoints.Add(aspNetCoreEndpoint);
var httpContext = new DefaultHttpContext();
httpContext.SetEndpoint(aspNetCoreEndpoint);

var sut = Create<ProxyPipelineInitializerMiddleware>();

await sut.Invoke(httpContext);

Assert.Equal(StatusCodes.Status503ServiceUnavailable, httpContext.Response.StatusCode);
}

[Fact]
public async Task Invoke_MissingTimeoutMiddleware_DefaultPolicyAllowed()
{
var httpClient = new HttpMessageInvoker(new Mock<HttpMessageHandler>().Object);
var cluster1 = new ClusterState(clusterId: "cluster1");
cluster1.Model = new ClusterModel(new ClusterConfig(), httpClient);
var destination1 = cluster1.Destinations.GetOrAdd(
"destination1",
id => new DestinationState(id) { Model = new DestinationModel(new DestinationConfig { Address = "https://localhost:123/a/b/" }) });
cluster1.DestinationsState = new ClusterDestinationsState(new[] { destination1 }, new[] { destination1 });

var aspNetCoreEndpoints = new List<Endpoint>();
var routeConfig = new RouteModel(
config: new RouteConfig(),
cluster1,
HttpTransformer.Default);
var aspNetCoreEndpoint = CreateAspNetCoreEndpoint(routeConfig,
builder =>
{
builder.Metadata.Add(new RequestTimeoutAttribute(TimeoutPolicyConstants.Default));
});
aspNetCoreEndpoints.Add(aspNetCoreEndpoint);
var httpContext = new DefaultHttpContext();
httpContext.SetEndpoint(aspNetCoreEndpoint);

var sut = Create<ProxyPipelineInitializerMiddleware>();

await sut.Invoke(httpContext);

var proxyFeature = httpContext.GetReverseProxyFeature();
Assert.NotNull(proxyFeature);
Assert.NotNull(proxyFeature.AvailableDestinations);
Assert.Single(proxyFeature.AvailableDestinations);
Assert.Same(destination1, proxyFeature.AvailableDestinations[0]);
Assert.Same(cluster1.Model, proxyFeature.Cluster);

Assert.Equal(StatusCodes.Status418ImATeapot, httpContext.Response.StatusCode);
}
#endif

private static Endpoint CreateAspNetCoreEndpoint(RouteModel routeConfig)
private static Endpoint CreateAspNetCoreEndpoint(RouteModel routeConfig, Action<RouteEndpointBuilder> configure = null)
{
var endpointBuilder = new RouteEndpointBuilder(
requestDelegate: httpContext => Task.CompletedTask,
routePattern: RoutePatternFactory.Parse("/"),
order: 0);
endpointBuilder.Metadata.Add(routeConfig);
configure?.Invoke(endpointBuilder);
return endpointBuilder.Build();
}
}

0 comments on commit ed09592

Please sign in to comment.