Skip to content

Commit

Permalink
Enable direct cancellation for IHttpForwarder, transforms #1542 (#1985)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tratcher authored Jan 13, 2023
1 parent c253f4a commit b5192aa
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 27 deletions.
4 changes: 3 additions & 1 deletion src/ReverseProxy/Forwarder/ForwarderRequestConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

using System;
using System.Net.Http;
using System.Threading;
using Microsoft.AspNetCore.Http;

namespace Yarp.ReverseProxy.Forwarder;

/// <summary>
/// Config for <see cref="IHttpForwarder.SendAsync"/>
/// Config for <see cref="IHttpForwarder.SendAsync(HttpContext, string, HttpMessageInvoker, ForwarderRequestConfig, HttpTransformer, CancellationToken)"/>
/// </summary>
public sealed record ForwarderRequestConfig
{
Expand Down
30 changes: 21 additions & 9 deletions src/ReverseProxy/Forwarder/HttpForwarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,21 @@ public HttpForwarder(ILogger<HttpForwarder> logger, IClock clock)
/// ASP .NET Core (Kestrel) will finally send response trailers (if any)
/// after we complete the steps above and relinquish control.
/// </remarks>
public async ValueTask<ForwarderError> SendAsync(
public ValueTask<ForwarderError> SendAsync(
HttpContext context,
string destinationPrefix,
HttpMessageInvoker httpClient,
ForwarderRequestConfig requestConfig,
HttpTransformer transformer)
=> SendAsync(context, destinationPrefix, httpClient, requestConfig, transformer, CancellationToken.None);

public async ValueTask<ForwarderError> SendAsync(
HttpContext context,
string destinationPrefix,
HttpMessageInvoker httpClient,
ForwarderRequestConfig requestConfig,
HttpTransformer transformer,
CancellationToken cancellationToken)
{
_ = context ?? throw new ArgumentNullException(nameof(context));
_ = destinationPrefix ?? throw new ArgumentNullException(nameof(destinationPrefix));
Expand All @@ -110,7 +119,7 @@ public async ValueTask<ForwarderError> SendAsync(

ForwarderTelemetry.Log.ForwarderStart(destinationPrefix);

var activityCancellationSource = ActivityCancellationTokenSource.Rent(requestConfig?.ActivityTimeout ?? DefaultTimeout, context.RequestAborted);
var activityCancellationSource = ActivityCancellationTokenSource.Rent(requestConfig?.ActivityTimeout ?? DefaultTimeout, context.RequestAborted, cancellationToken);
try
{
var isClientHttp2OrGreater = ProtocolHelper.IsHttp2OrGreater(context.Request.Protocol);
Expand Down Expand Up @@ -193,7 +202,7 @@ public async ValueTask<ForwarderError> SendAsync(
{
// :: Step 5: Copy response status line Client ◄-- Proxy ◄-- Destination
// :: Step 6: Copy response headers Client ◄-- Proxy ◄-- Destination
var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer);
var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token);

if (!copyBody)
{
Expand Down Expand Up @@ -260,7 +269,7 @@ public async ValueTask<ForwarderError> SendAsync(
}

// :: Step 8: Copy response trailer headers and finish response Client ◄-- Proxy ◄-- Destination
await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer);
await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token);

if (isStreamingRequest)
{
Expand Down Expand Up @@ -402,14 +411,17 @@ public async ValueTask<ForwarderError> SendAsync(
destinationRequest.Content = requestContent;

// :: Step 3: Copy request headers Client --► Proxy --► Destination
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix);
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix, activityToken.Token);

// The transformer generated a response, do not forward.
if (RequestUtilities.IsResponseSet(context.Response))
{
return (destinationRequest, requestContent, false);
}

// Transforms may have taken a while, especially if they buffered the body, they count as forward progress.
activityToken.ResetTimeout();

FixupUpgradeRequestHeaders(context, destinationRequest, outgoingUpgrade, outgoingConnect);

// Allow someone to custom build the request uri, otherwise provide a default for them.
Expand Down Expand Up @@ -653,7 +665,7 @@ async ValueTask<ForwarderError> ReportErrorAsync(ForwarderError error, int statu
}
}

private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer)
private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken)
{
context.Response.StatusCode = (int)source.StatusCode;

Expand All @@ -667,7 +679,7 @@ private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMes
}

// Copies headers
return transformer.TransformResponseAsync(context, source);
return transformer.TransformResponseAsync(context, source, cancellationToken);
}

private async ValueTask<ForwarderError> HandleUpgradedResponse(HttpContext context, HttpResponseMessage destinationResponse,
Expand Down Expand Up @@ -891,10 +903,10 @@ private async ValueTask<ForwarderError> HandleResponseBodyErrorAsync(HttpContext
return error;
}

private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer)
private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken)
{
// Copies trailers
return transformer.TransformResponseTrailersAsync(context, source);
return transformer.TransformResponseTrailersAsync(context, source, cancellationToken);
}


Expand Down
43 changes: 43 additions & 0 deletions src/ReverseProxy/Forwarder/HttpTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
Expand Down Expand Up @@ -53,6 +54,23 @@ private static bool IsBodylessStatusCode(HttpStatusCode statusCode) =>
_ => false
};

/// <summary>
/// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are
/// initialized except RequestUri, which will be initialized after the callback if no value is provided.
/// See <see cref="RequestUtilities.MakeDestinationAddress(string, PathString, QueryString)"/> for constructing a custom request Uri.
/// The string parameter represents the destination URI prefix that should be used when constructing the RequestUri.
/// The headers are copied by the base implementation, excluding some protocol headers like HTTP/2 pseudo headers (":authority").
/// This method may be overridden to conditionally produce a response, such as for error conditions, and prevent the request from
/// being proxied. This is indicated by setting the `HttpResponse.StatusCode` to a value other than 200, or calling `HttpResponse.StartAsync()`,
/// or writing to the `HttpResponse.Body` or `BodyWriter`.
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyRequest">The outgoing proxy request.</param>
/// <param name="destinationPrefix">The uri prefix for the selected destination server which can be used to create the RequestUri.</param>
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken)
=> TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);

/// <summary>
/// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are
/// initialized except RequestUri, which will be initialized after the callback if no value is provided.
Expand Down Expand Up @@ -126,9 +144,24 @@ public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequ
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyResponse">The response from the destination. This can be null if the destination did not respond.</param>
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
/// <returns>A bool indicating if the response should be proxied to the client or not. A derived implementation
/// that returns false may send an alternate response inline or return control to the caller for it to retry, respond,
/// etc.</returns>
public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken)
=> TransformResponseAsync(httpContext, proxyResponse);

/// <summary>
/// A callback that is invoked when the proxied response is received. The status code and reason phrase will be copied
/// to the HttpContext.Response before the callback is invoked, but may still be modified there. The headers will be
/// copied to HttpContext.Response.Headers by the base implementation, excludes certain protocol headers like
/// `Transfer-Encoding: chunked`.
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyResponse">The response from the destination. This can be null if the destination did not respond.</param>
/// <returns>A bool indicating if the response should be proxied to the client or not. A derived implementation
/// that returns false may send an alternate response inline or return control to the caller for it to retry, respond,
/// etc.</returns>
public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse)
{
if (proxyResponse is null)
Expand Down Expand Up @@ -171,6 +204,16 @@ public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, H
return new ValueTask<bool>(true);
}

/// <summary>
/// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be
/// copied to the HttpContext.Response by the base implementation.
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyResponse">The response from the destination.</param>
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
public virtual ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken)
=> TransformResponseTrailersAsync(httpContext, proxyResponse);

/// <summary>
/// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be
/// copied to the HttpContext.Response by the base implementation.
Expand Down
16 changes: 16 additions & 0 deletions src/ReverseProxy/Forwarder/IHttpForwarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;

Expand All @@ -24,4 +25,19 @@ public interface IHttpForwarder
/// <returns>The result of forwarding the request and response.</returns>
ValueTask<ForwarderError> SendAsync(HttpContext context, string destinationPrefix, HttpMessageInvoker httpClient,
ForwarderRequestConfig requestConfig, HttpTransformer transformer);

/// <summary>
/// Forwards the incoming request to the destination server, and the response back to the client.
/// </summary>
/// <param name="context">The HttpContext to forward.</param>
/// <param name="destinationPrefix">The url prefix for where to forward the request to.</param>
/// <param name="httpClient">The HTTP client used to forward the request.</param>
/// <param name="requestConfig">Config for the outgoing request.</param>
/// <param name="transformer">Request and response transforms. Use <see cref="HttpTransformer.Default"/> if
/// custom transformations are not needed.</param>
/// <param name="cancellationToken">A cancellation token that can be used to abort the request.</param>
/// <returns>The result of forwarding the request and response.</returns>
ValueTask<ForwarderError> SendAsync(HttpContext context, string destinationPrefix, HttpMessageInvoker httpClient,
ForwarderRequestConfig requestConfig, HttpTransformer transformer, CancellationToken cancellationToken)
=> SendAsync(context, destinationPrefix, httpClient, requestConfig, transformer);
}
16 changes: 10 additions & 6 deletions src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
Expand Down Expand Up @@ -63,11 +64,11 @@ internal StructuredTransformer(bool? copyRequestHeaders, bool? copyResponseHeade
/// </summary>
internal ResponseTrailersTransform[] ResponseTrailerTransforms { get; }

public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix)
public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken)
{
if (ShouldCopyRequestHeaders.GetValueOrDefault(true))
{
await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);
await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, cancellationToken);
}

if (RequestTransforms.Length == 0)
Expand All @@ -83,6 +84,7 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
Path = httpContext.Request.Path,
Query = new QueryTransformContext(httpContext.Request),
HeadersCopied = ShouldCopyRequestHeaders.GetValueOrDefault(true),
CancellationToken = cancellationToken,
};

foreach (var requestTransform in RequestTransforms)
Expand All @@ -101,11 +103,11 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
transformContext.DestinationPrefix, transformContext.Path, transformContext.Query.QueryString);
}

public override async ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse)
public override async ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken)
{
if (ShouldCopyResponseHeaders.GetValueOrDefault(true))
{
await base.TransformResponseAsync(httpContext, proxyResponse);
await base.TransformResponseAsync(httpContext, proxyResponse, cancellationToken);
}

if (ResponseTransforms.Length == 0)
Expand All @@ -118,6 +120,7 @@ public override async ValueTask<bool> TransformResponseAsync(HttpContext httpCon
HttpContext = httpContext,
ProxyResponse = proxyResponse,
HeadersCopied = ShouldCopyResponseHeaders.GetValueOrDefault(true),
CancellationToken = cancellationToken,
};

foreach (var responseTransform in ResponseTransforms)
Expand All @@ -128,11 +131,11 @@ public override async ValueTask<bool> TransformResponseAsync(HttpContext httpCon
return !transformContext.SuppressResponseBody;
}

public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse)
public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken)
{
if (ShouldCopyResponseTrailers.GetValueOrDefault(true))
{
await base.TransformResponseTrailersAsync(httpContext, proxyResponse);
await base.TransformResponseTrailersAsync(httpContext, proxyResponse, cancellationToken);
}

if (ResponseTrailerTransforms.Length == 0)
Expand All @@ -150,6 +153,7 @@ public override async ValueTask TransformResponseTrailersAsync(HttpContext httpC
HttpContext = httpContext,
ProxyResponse = proxyResponse,
HeadersCopied = ShouldCopyResponseTrailers.GetValueOrDefault(true),
CancellationToken = cancellationToken,
};

foreach (var responseTrailerTransform in ResponseTrailerTransforms)
Expand Down
6 changes: 6 additions & 0 deletions src/ReverseProxy/Transforms/RequestTransformContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Net.Http;
using System.Threading;
using Microsoft.AspNetCore.Http;

namespace Yarp.ReverseProxy.Transforms;
Expand Down Expand Up @@ -48,4 +49,9 @@ public class RequestTransformContext
/// port and path base. The 'Path' and 'Query' properties will be appended to this after the transforms have run.
/// </summary>
public string DestinationPrefix { get; init; } = default!;

/// <summary>
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
/// </summary>
public CancellationToken CancellationToken { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Net.Http;
using System.Threading;
using Microsoft.AspNetCore.Http;

namespace Yarp.ReverseProxy.Transforms;
Expand All @@ -27,4 +28,9 @@ public class ResponseTrailersTransformContext
/// should operate on.
/// </summary>
public bool HeadersCopied { get; set; }

/// <summary>
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
/// </summary>
public CancellationToken CancellationToken { get; set; }
}
6 changes: 6 additions & 0 deletions src/ReverseProxy/Transforms/ResponseTransformContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Net.Http;
using System.Threading;
using Microsoft.AspNetCore.Http;

namespace Yarp.ReverseProxy.Transforms;
Expand Down Expand Up @@ -33,4 +34,9 @@ public class ResponseTransformContext
/// Defaults to false.
/// </summary>
public bool SuppressResponseBody { get; set; }

/// <summary>
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
/// </summary>
public CancellationToken CancellationToken { get; set; }
}
14 changes: 9 additions & 5 deletions src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ internal sealed class ActivityCancellationTokenSource : CancellationTokenSource
};

private int _activityTimeoutMs;
private CancellationTokenRegistration _linkedRegistration;
private CancellationTokenRegistration _linkedRegistration1;
private CancellationTokenRegistration _linkedRegistration2;

private ActivityCancellationTokenSource() { }

Expand All @@ -28,7 +29,7 @@ public void ResetTimeout()
CancelAfter(_activityTimeoutMs);
}

public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, CancellationToken linkedToken)
public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, CancellationToken linkedToken1 = default, CancellationToken linkedToken2 = default)
{
if (_sharedSources.TryDequeue(out var cts))
{
Expand All @@ -40,16 +41,19 @@ public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, Can
}

cts._activityTimeoutMs = (int)activityTimeout.TotalMilliseconds;
cts._linkedRegistration = linkedToken.UnsafeRegister(_linkedTokenCancelDelegate, cts);
cts._linkedRegistration1 = linkedToken1.UnsafeRegister(_linkedTokenCancelDelegate, cts);
cts._linkedRegistration2 = linkedToken2.UnsafeRegister(_linkedTokenCancelDelegate, cts);
cts.ResetTimeout();

return cts;
}

public void Return()
{
_linkedRegistration.Dispose();
_linkedRegistration = default;
_linkedRegistration1.Dispose();
_linkedRegistration1 = default;
_linkedRegistration2.Dispose();
_linkedRegistration2 = default;

if (TryReset())
{
Expand Down
Loading

0 comments on commit b5192aa

Please sign in to comment.