diff --git a/src/ReverseProxy/Forwarder/ForwarderRequestConfig.cs b/src/ReverseProxy/Forwarder/ForwarderRequestConfig.cs index 8c0f2453a..a2344dbdf 100644 --- a/src/ReverseProxy/Forwarder/ForwarderRequestConfig.cs +++ b/src/ReverseProxy/Forwarder/ForwarderRequestConfig.cs @@ -3,11 +3,13 @@ using System; using System.Net.Http; +using System.Threading; +using Microsoft.AspNetCore.Http; namespace Yarp.ReverseProxy.Forwarder; /// -/// Config for +/// Config for /// public sealed record ForwarderRequestConfig { diff --git a/src/ReverseProxy/Forwarder/HttpForwarder.cs b/src/ReverseProxy/Forwarder/HttpForwarder.cs index ca70a1b1e..b0d2a2388 100644 --- a/src/ReverseProxy/Forwarder/HttpForwarder.cs +++ b/src/ReverseProxy/Forwarder/HttpForwarder.cs @@ -83,12 +83,21 @@ public HttpForwarder(ILogger logger, IClock clock) /// ASP .NET Core (Kestrel) will finally send response trailers (if any) /// after we complete the steps above and relinquish control. /// - public async ValueTask SendAsync( + public ValueTask SendAsync( HttpContext context, string destinationPrefix, HttpMessageInvoker httpClient, ForwarderRequestConfig requestConfig, HttpTransformer transformer) + => SendAsync(context, destinationPrefix, httpClient, requestConfig, transformer, CancellationToken.None); + + public async ValueTask SendAsync( + HttpContext context, + string destinationPrefix, + HttpMessageInvoker httpClient, + ForwarderRequestConfig requestConfig, + HttpTransformer transformer, + CancellationToken cancellationToken) { _ = context ?? throw new ArgumentNullException(nameof(context)); _ = destinationPrefix ?? throw new ArgumentNullException(nameof(destinationPrefix)); @@ -110,7 +119,7 @@ public async ValueTask SendAsync( ForwarderTelemetry.Log.ForwarderStart(destinationPrefix); - var activityCancellationSource = ActivityCancellationTokenSource.Rent(requestConfig?.ActivityTimeout ?? DefaultTimeout, context.RequestAborted); + var activityCancellationSource = ActivityCancellationTokenSource.Rent(requestConfig?.ActivityTimeout ?? DefaultTimeout, context.RequestAborted, cancellationToken); try { var isClientHttp2OrGreater = ProtocolHelper.IsHttp2OrGreater(context.Request.Protocol); @@ -193,7 +202,7 @@ public async ValueTask SendAsync( { // :: Step 5: Copy response status line Client ◄-- Proxy ◄-- Destination // :: Step 6: Copy response headers Client ◄-- Proxy ◄-- Destination - var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer); + var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token); if (!copyBody) { @@ -260,7 +269,7 @@ public async ValueTask SendAsync( } // :: Step 8: Copy response trailer headers and finish response Client ◄-- Proxy ◄-- Destination - await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer); + await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token); if (isStreamingRequest) { @@ -402,7 +411,7 @@ public async ValueTask SendAsync( destinationRequest.Content = requestContent; // :: Step 3: Copy request headers Client --► Proxy --► Destination - await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix); + await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix, activityToken.Token); // The transformer generated a response, do not forward. if (RequestUtilities.IsResponseSet(context.Response)) @@ -410,6 +419,9 @@ public async ValueTask SendAsync( return (destinationRequest, requestContent, false); } + // Transforms may have taken a while, especially if they buffered the body, they count as forward progress. + activityToken.ResetTimeout(); + FixupUpgradeRequestHeaders(context, destinationRequest, outgoingUpgrade, outgoingConnect); // Allow someone to custom build the request uri, otherwise provide a default for them. @@ -653,7 +665,7 @@ async ValueTask ReportErrorAsync(ForwarderError error, int statu } } - private static ValueTask CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer) + private static ValueTask CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken) { context.Response.StatusCode = (int)source.StatusCode; @@ -667,7 +679,7 @@ private static ValueTask CopyResponseStatusAndHeadersAsync(HttpResponseMes } // Copies headers - return transformer.TransformResponseAsync(context, source); + return transformer.TransformResponseAsync(context, source, cancellationToken); } private async ValueTask HandleUpgradedResponse(HttpContext context, HttpResponseMessage destinationResponse, @@ -891,10 +903,10 @@ private async ValueTask HandleResponseBodyErrorAsync(HttpContext return error; } - private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer) + private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken) { // Copies trailers - return transformer.TransformResponseTrailersAsync(context, source); + return transformer.TransformResponseTrailersAsync(context, source, cancellationToken); } diff --git a/src/ReverseProxy/Forwarder/HttpTransformer.cs b/src/ReverseProxy/Forwarder/HttpTransformer.cs index a3a847385..d36363e0d 100644 --- a/src/ReverseProxy/Forwarder/HttpTransformer.cs +++ b/src/ReverseProxy/Forwarder/HttpTransformer.cs @@ -7,6 +7,7 @@ using System.Net.Http; using System.Net.Http.Headers; using System.Runtime.CompilerServices; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; @@ -53,6 +54,23 @@ private static bool IsBodylessStatusCode(HttpStatusCode statusCode) => _ => false }; + /// + /// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are + /// initialized except RequestUri, which will be initialized after the callback if no value is provided. + /// See for constructing a custom request Uri. + /// The string parameter represents the destination URI prefix that should be used when constructing the RequestUri. + /// The headers are copied by the base implementation, excluding some protocol headers like HTTP/2 pseudo headers (":authority"). + /// This method may be overridden to conditionally produce a response, such as for error conditions, and prevent the request from + /// being proxied. This is indicated by setting the `HttpResponse.StatusCode` to a value other than 200, or calling `HttpResponse.StartAsync()`, + /// or writing to the `HttpResponse.Body` or `BodyWriter`. + /// + /// The incoming request. + /// The outgoing proxy request. + /// The uri prefix for the selected destination server which can be used to create the RequestUri. + /// Indicates that the request is being canceled. + public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken) + => TransformRequestAsync(httpContext, proxyRequest, destinationPrefix); + /// /// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are /// initialized except RequestUri, which will be initialized after the callback if no value is provided. @@ -126,9 +144,24 @@ public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequ /// /// The incoming request. /// The response from the destination. This can be null if the destination did not respond. + /// Indicates that the request is being canceled. /// A bool indicating if the response should be proxied to the client or not. A derived implementation /// that returns false may send an alternate response inline or return control to the caller for it to retry, respond, /// etc. + public virtual ValueTask TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken) + => TransformResponseAsync(httpContext, proxyResponse); + + /// + /// A callback that is invoked when the proxied response is received. The status code and reason phrase will be copied + /// to the HttpContext.Response before the callback is invoked, but may still be modified there. The headers will be + /// copied to HttpContext.Response.Headers by the base implementation, excludes certain protocol headers like + /// `Transfer-Encoding: chunked`. + /// + /// The incoming request. + /// The response from the destination. This can be null if the destination did not respond. + /// A bool indicating if the response should be proxied to the client or not. A derived implementation + /// that returns false may send an alternate response inline or return control to the caller for it to retry, respond, + /// etc. public virtual ValueTask TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse) { if (proxyResponse is null) @@ -171,6 +204,16 @@ public virtual ValueTask TransformResponseAsync(HttpContext httpContext, H return new ValueTask(true); } + /// + /// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be + /// copied to the HttpContext.Response by the base implementation. + /// + /// The incoming request. + /// The response from the destination. + /// Indicates that the request is being canceled. + public virtual ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken) + => TransformResponseTrailersAsync(httpContext, proxyResponse); + /// /// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be /// copied to the HttpContext.Response by the base implementation. diff --git a/src/ReverseProxy/Forwarder/IHttpForwarder.cs b/src/ReverseProxy/Forwarder/IHttpForwarder.cs index 378e62892..8372625df 100644 --- a/src/ReverseProxy/Forwarder/IHttpForwarder.cs +++ b/src/ReverseProxy/Forwarder/IHttpForwarder.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System.Net.Http; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; @@ -24,4 +25,19 @@ public interface IHttpForwarder /// The result of forwarding the request and response. ValueTask SendAsync(HttpContext context, string destinationPrefix, HttpMessageInvoker httpClient, ForwarderRequestConfig requestConfig, HttpTransformer transformer); + + /// + /// Forwards the incoming request to the destination server, and the response back to the client. + /// + /// The HttpContext to forward. + /// The url prefix for where to forward the request to. + /// The HTTP client used to forward the request. + /// Config for the outgoing request. + /// Request and response transforms. Use if + /// custom transformations are not needed. + /// A cancellation token that can be used to abort the request. + /// The result of forwarding the request and response. + ValueTask SendAsync(HttpContext context, string destinationPrefix, HttpMessageInvoker httpClient, + ForwarderRequestConfig requestConfig, HttpTransformer transformer, CancellationToken cancellationToken) + => SendAsync(context, destinationPrefix, httpClient, requestConfig, transformer); } diff --git a/src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs b/src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs index 145d9f477..7f2f2edf1 100644 --- a/src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs +++ b/src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Net.Http; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; @@ -63,11 +64,11 @@ internal StructuredTransformer(bool? copyRequestHeaders, bool? copyResponseHeade /// internal ResponseTrailersTransform[] ResponseTrailerTransforms { get; } - public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix) + public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken) { if (ShouldCopyRequestHeaders.GetValueOrDefault(true)) { - await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix); + await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, cancellationToken); } if (RequestTransforms.Length == 0) @@ -83,6 +84,7 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H Path = httpContext.Request.Path, Query = new QueryTransformContext(httpContext.Request), HeadersCopied = ShouldCopyRequestHeaders.GetValueOrDefault(true), + CancellationToken = cancellationToken, }; foreach (var requestTransform in RequestTransforms) @@ -101,11 +103,11 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H transformContext.DestinationPrefix, transformContext.Path, transformContext.Query.QueryString); } - public override async ValueTask TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse) + public override async ValueTask TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken) { if (ShouldCopyResponseHeaders.GetValueOrDefault(true)) { - await base.TransformResponseAsync(httpContext, proxyResponse); + await base.TransformResponseAsync(httpContext, proxyResponse, cancellationToken); } if (ResponseTransforms.Length == 0) @@ -118,6 +120,7 @@ public override async ValueTask TransformResponseAsync(HttpContext httpCon HttpContext = httpContext, ProxyResponse = proxyResponse, HeadersCopied = ShouldCopyResponseHeaders.GetValueOrDefault(true), + CancellationToken = cancellationToken, }; foreach (var responseTransform in ResponseTransforms) @@ -128,11 +131,11 @@ public override async ValueTask TransformResponseAsync(HttpContext httpCon return !transformContext.SuppressResponseBody; } - public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse) + public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken) { if (ShouldCopyResponseTrailers.GetValueOrDefault(true)) { - await base.TransformResponseTrailersAsync(httpContext, proxyResponse); + await base.TransformResponseTrailersAsync(httpContext, proxyResponse, cancellationToken); } if (ResponseTrailerTransforms.Length == 0) @@ -150,6 +153,7 @@ public override async ValueTask TransformResponseTrailersAsync(HttpContext httpC HttpContext = httpContext, ProxyResponse = proxyResponse, HeadersCopied = ShouldCopyResponseTrailers.GetValueOrDefault(true), + CancellationToken = cancellationToken, }; foreach (var responseTrailerTransform in ResponseTrailerTransforms) diff --git a/src/ReverseProxy/Transforms/RequestTransformContext.cs b/src/ReverseProxy/Transforms/RequestTransformContext.cs index b6cc84cfd..41ecd7b9b 100644 --- a/src/ReverseProxy/Transforms/RequestTransformContext.cs +++ b/src/ReverseProxy/Transforms/RequestTransformContext.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System.Net.Http; +using System.Threading; using Microsoft.AspNetCore.Http; namespace Yarp.ReverseProxy.Transforms; @@ -48,4 +49,9 @@ public class RequestTransformContext /// port and path base. The 'Path' and 'Query' properties will be appended to this after the transforms have run. /// public string DestinationPrefix { get; init; } = default!; + + /// + /// A indicating that the request is being aborted. + /// + public CancellationToken CancellationToken { get; set; } } diff --git a/src/ReverseProxy/Transforms/ResponseTrailersTransformContext.cs b/src/ReverseProxy/Transforms/ResponseTrailersTransformContext.cs index bb59470c0..d4372fdfa 100644 --- a/src/ReverseProxy/Transforms/ResponseTrailersTransformContext.cs +++ b/src/ReverseProxy/Transforms/ResponseTrailersTransformContext.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System.Net.Http; +using System.Threading; using Microsoft.AspNetCore.Http; namespace Yarp.ReverseProxy.Transforms; @@ -27,4 +28,9 @@ public class ResponseTrailersTransformContext /// should operate on. /// public bool HeadersCopied { get; set; } + + /// + /// A indicating that the request is being aborted. + /// + public CancellationToken CancellationToken { get; set; } } diff --git a/src/ReverseProxy/Transforms/ResponseTransformContext.cs b/src/ReverseProxy/Transforms/ResponseTransformContext.cs index da5933375..9eef19a5d 100644 --- a/src/ReverseProxy/Transforms/ResponseTransformContext.cs +++ b/src/ReverseProxy/Transforms/ResponseTransformContext.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System.Net.Http; +using System.Threading; using Microsoft.AspNetCore.Http; namespace Yarp.ReverseProxy.Transforms; @@ -33,4 +34,9 @@ public class ResponseTransformContext /// Defaults to false. /// public bool SuppressResponseBody { get; set; } + + /// + /// A indicating that the request is being aborted. + /// + public CancellationToken CancellationToken { get; set; } } diff --git a/src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs b/src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs index 7449dc55b..08da5bd8f 100644 --- a/src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs +++ b/src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs @@ -19,7 +19,8 @@ internal sealed class ActivityCancellationTokenSource : CancellationTokenSource }; private int _activityTimeoutMs; - private CancellationTokenRegistration _linkedRegistration; + private CancellationTokenRegistration _linkedRegistration1; + private CancellationTokenRegistration _linkedRegistration2; private ActivityCancellationTokenSource() { } @@ -28,7 +29,7 @@ public void ResetTimeout() CancelAfter(_activityTimeoutMs); } - public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, CancellationToken linkedToken) + public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, CancellationToken linkedToken1 = default, CancellationToken linkedToken2 = default) { if (_sharedSources.TryDequeue(out var cts)) { @@ -40,7 +41,8 @@ public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, Can } cts._activityTimeoutMs = (int)activityTimeout.TotalMilliseconds; - cts._linkedRegistration = linkedToken.UnsafeRegister(_linkedTokenCancelDelegate, cts); + cts._linkedRegistration1 = linkedToken1.UnsafeRegister(_linkedTokenCancelDelegate, cts); + cts._linkedRegistration2 = linkedToken2.UnsafeRegister(_linkedTokenCancelDelegate, cts); cts.ResetTimeout(); return cts; @@ -48,8 +50,10 @@ public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, Can public void Return() { - _linkedRegistration.Dispose(); - _linkedRegistration = default; + _linkedRegistration1.Dispose(); + _linkedRegistration1 = default; + _linkedRegistration2.Dispose(); + _linkedRegistration2 = default; if (TryReset()) { diff --git a/test/ReverseProxy.Tests/Transforms/Builder/TransformBuilderTests.cs b/test/ReverseProxy.Tests/Transforms/Builder/TransformBuilderTests.cs index 847bd9d2c..8bc06540a 100644 --- a/test/ReverseProxy.Tests/Transforms/Builder/TransformBuilderTests.cs +++ b/test/ReverseProxy.Tests/Transforms/Builder/TransformBuilderTests.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Linq; using System.Net.Http; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.HttpOverrides; @@ -307,7 +308,7 @@ public async Task UseOriginalHost(bool? useOriginalHost, bool? copyHeaders) httpContext.Request.Host = new HostString("StartHost"); var proxyRequest = new HttpRequestMessage(); var destinationPrefix = "http://destinationhost:9090/path"; - await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix); + await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None); if (useOriginalHost.GetValueOrDefault(false)) { @@ -373,7 +374,7 @@ public async Task UseCustomHost(bool? useOriginalHost, bool? copyHeaders) var proxyRequest = new HttpRequestMessage(); var destinationPrefix = "http://destinationhost:9090/path"; - await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix); + await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None); Assert.Equal("CustomHost", proxyRequest.Headers.Host); } diff --git a/test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs b/test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs index 6018b35a7..53a62e583 100644 --- a/test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs +++ b/test/ReverseProxy.Tests/Utilities/ActivityCancellationTokenSourceTests.cs @@ -45,7 +45,7 @@ public void ActivityCancellationTokenSource_DoesNotPoolsCanceledSources() } [Fact] - public void ActivityCancellationTokenSource_RespectsLinkedToken() + public void ActivityCancellationTokenSource_RespectsLinkedToken1() { var linkedCts = new CancellationTokenSource(); @@ -56,14 +56,40 @@ public void ActivityCancellationTokenSource_RespectsLinkedToken() } [Fact] - public void ActivityCancellationTokenSource_ClearsRegistrations() + public void ActivityCancellationTokenSource_RespectsLinkedToken2() { var linkedCts = new CancellationTokenSource(); - var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), linkedCts.Token); + var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), default, linkedCts.Token); + linkedCts.Cancel(); + + Assert.True(cts.IsCancellationRequested); + } + + [Fact] + public void ActivityCancellationTokenSource_RespectsBothLinkedTokens() + { + var linkedCts1 = new CancellationTokenSource(); + var linkedCts2 = new CancellationTokenSource(); + + var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), linkedCts1.Token, linkedCts2.Token); + linkedCts1.Cancel(); + linkedCts2.Cancel(); + + Assert.True(cts.IsCancellationRequested); + } + + [Fact] + public void ActivityCancellationTokenSource_ClearsRegistrations() + { + var linkedCts1 = new CancellationTokenSource(); + var linkedCts2 = new CancellationTokenSource(); + + var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), linkedCts1.Token, linkedCts2.Token); cts.Return(); - linkedCts.Cancel(); + linkedCts1.Cancel(); + linkedCts2.Cancel(); Assert.False(cts.IsCancellationRequested); }