Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable direct cancellation for IHttpForwarder #1985

Merged
merged 4 commits into from
Jan 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/ReverseProxy/Forwarder/ForwarderRequestConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

using System;
using System.Net.Http;
using System.Threading;
using Microsoft.AspNetCore.Http;

namespace Yarp.ReverseProxy.Forwarder;

/// <summary>
/// Config for <see cref="IHttpForwarder.SendAsync"/>
/// Config for <see cref="IHttpForwarder.SendAsync(HttpContext, string, HttpMessageInvoker, ForwarderRequestConfig, HttpTransformer, CancellationToken)"/>
/// </summary>
public sealed record ForwarderRequestConfig
{
Expand Down
30 changes: 21 additions & 9 deletions src/ReverseProxy/Forwarder/HttpForwarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,21 @@ public HttpForwarder(ILogger<HttpForwarder> logger, IClock clock)
/// ASP .NET Core (Kestrel) will finally send response trailers (if any)
/// after we complete the steps above and relinquish control.
/// </remarks>
public async ValueTask<ForwarderError> SendAsync(
public ValueTask<ForwarderError> SendAsync(
HttpContext context,
string destinationPrefix,
HttpMessageInvoker httpClient,
ForwarderRequestConfig requestConfig,
HttpTransformer transformer)
=> SendAsync(context, destinationPrefix, httpClient, requestConfig, transformer, CancellationToken.None);

public async ValueTask<ForwarderError> SendAsync(
HttpContext context,
string destinationPrefix,
HttpMessageInvoker httpClient,
ForwarderRequestConfig requestConfig,
HttpTransformer transformer,
CancellationToken cancellationToken)
{
_ = context ?? throw new ArgumentNullException(nameof(context));
_ = destinationPrefix ?? throw new ArgumentNullException(nameof(destinationPrefix));
Expand All @@ -110,7 +119,7 @@ public async ValueTask<ForwarderError> SendAsync(

ForwarderTelemetry.Log.ForwarderStart(destinationPrefix);

var activityCancellationSource = ActivityCancellationTokenSource.Rent(requestConfig?.ActivityTimeout ?? DefaultTimeout, context.RequestAborted);
var activityCancellationSource = ActivityCancellationTokenSource.Rent(requestConfig?.ActivityTimeout ?? DefaultTimeout, context.RequestAborted, cancellationToken);
try
{
var isClientHttp2OrGreater = ProtocolHelper.IsHttp2OrGreater(context.Request.Protocol);
Expand Down Expand Up @@ -193,7 +202,7 @@ public async ValueTask<ForwarderError> SendAsync(
{
// :: Step 5: Copy response status line Client ◄-- Proxy ◄-- Destination
// :: Step 6: Copy response headers Client ◄-- Proxy ◄-- Destination
var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer);
var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token);

if (!copyBody)
{
Expand Down Expand Up @@ -260,7 +269,7 @@ public async ValueTask<ForwarderError> SendAsync(
}

// :: Step 8: Copy response trailer headers and finish response Client ◄-- Proxy ◄-- Destination
await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer);
await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token);

if (isStreamingRequest)
{
Expand Down Expand Up @@ -402,14 +411,17 @@ public async ValueTask<ForwarderError> SendAsync(
destinationRequest.Content = requestContent;

// :: Step 3: Copy request headers Client --► Proxy --► Destination
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix);
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix, activityToken.Token);

// The transformer generated a response, do not forward.
if (RequestUtilities.IsResponseSet(context.Response))
{
return (destinationRequest, requestContent, false);
}

// Transforms may have taken a while, especially if they buffered the body, they count as forward progress.
activityToken.ResetTimeout();

FixupUpgradeRequestHeaders(context, destinationRequest, outgoingUpgrade, outgoingConnect);

// Allow someone to custom build the request uri, otherwise provide a default for them.
Expand Down Expand Up @@ -653,7 +665,7 @@ async ValueTask<ForwarderError> ReportErrorAsync(ForwarderError error, int statu
}
}

private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer)
private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken)
{
context.Response.StatusCode = (int)source.StatusCode;

Expand All @@ -667,7 +679,7 @@ private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMes
}

// Copies headers
return transformer.TransformResponseAsync(context, source);
return transformer.TransformResponseAsync(context, source, cancellationToken);
}

private async ValueTask<ForwarderError> HandleUpgradedResponse(HttpContext context, HttpResponseMessage destinationResponse,
Expand Down Expand Up @@ -891,10 +903,10 @@ private async ValueTask<ForwarderError> HandleResponseBodyErrorAsync(HttpContext
return error;
}

private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer)
private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken)
{
// Copies trailers
return transformer.TransformResponseTrailersAsync(context, source);
return transformer.TransformResponseTrailersAsync(context, source, cancellationToken);
}


Expand Down
43 changes: 43 additions & 0 deletions src/ReverseProxy/Forwarder/HttpTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
Expand Down Expand Up @@ -53,6 +54,23 @@ private static bool IsBodylessStatusCode(HttpStatusCode statusCode) =>
_ => false
};

/// <summary>
/// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are
/// initialized except RequestUri, which will be initialized after the callback if no value is provided.
/// See <see cref="RequestUtilities.MakeDestinationAddress(string, PathString, QueryString)"/> for constructing a custom request Uri.
/// The string parameter represents the destination URI prefix that should be used when constructing the RequestUri.
/// The headers are copied by the base implementation, excluding some protocol headers like HTTP/2 pseudo headers (":authority").
/// This method may be overridden to conditionally produce a response, such as for error conditions, and prevent the request from
/// being proxied. This is indicated by setting the `HttpResponse.StatusCode` to a value other than 200, or calling `HttpResponse.StartAsync()`,
/// or writing to the `HttpResponse.Body` or `BodyWriter`.
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyRequest">The outgoing proxy request.</param>
/// <param name="destinationPrefix">The uri prefix for the selected destination server which can be used to create the RequestUri.</param>
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken)
=> TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);

/// <summary>
/// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are
/// initialized except RequestUri, which will be initialized after the callback if no value is provided.
Expand Down Expand Up @@ -126,9 +144,24 @@ public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequ
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyResponse">The response from the destination. This can be null if the destination did not respond.</param>
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
/// <returns>A bool indicating if the response should be proxied to the client or not. A derived implementation
/// that returns false may send an alternate response inline or return control to the caller for it to retry, respond,
/// etc.</returns>
public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken)
=> TransformResponseAsync(httpContext, proxyResponse);

/// <summary>
/// A callback that is invoked when the proxied response is received. The status code and reason phrase will be copied
/// to the HttpContext.Response before the callback is invoked, but may still be modified there. The headers will be
/// copied to HttpContext.Response.Headers by the base implementation, excludes certain protocol headers like
/// `Transfer-Encoding: chunked`.
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyResponse">The response from the destination. This can be null if the destination did not respond.</param>
/// <returns>A bool indicating if the response should be proxied to the client or not. A derived implementation
/// that returns false may send an alternate response inline or return control to the caller for it to retry, respond,
/// etc.</returns>
public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse)
{
if (proxyResponse is null)
Expand Down Expand Up @@ -171,6 +204,16 @@ public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, H
return new ValueTask<bool>(true);
}

/// <summary>
/// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be
/// copied to the HttpContext.Response by the base implementation.
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyResponse">The response from the destination.</param>
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
public virtual ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken)
=> TransformResponseTrailersAsync(httpContext, proxyResponse);

/// <summary>
/// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be
/// copied to the HttpContext.Response by the base implementation.
Expand Down
16 changes: 16 additions & 0 deletions src/ReverseProxy/Forwarder/IHttpForwarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;

Expand All @@ -24,4 +25,19 @@ public interface IHttpForwarder
/// <returns>The result of forwarding the request and response.</returns>
ValueTask<ForwarderError> SendAsync(HttpContext context, string destinationPrefix, HttpMessageInvoker httpClient,
ForwarderRequestConfig requestConfig, HttpTransformer transformer);

/// <summary>
/// Forwards the incoming request to the destination server, and the response back to the client.
/// </summary>
/// <param name="context">The HttpContext to forward.</param>
/// <param name="destinationPrefix">The url prefix for where to forward the request to.</param>
/// <param name="httpClient">The HTTP client used to forward the request.</param>
/// <param name="requestConfig">Config for the outgoing request.</param>
/// <param name="transformer">Request and response transforms. Use <see cref="HttpTransformer.Default"/> if
/// custom transformations are not needed.</param>
/// <param name="cancellationToken">A cancellation token that can be used to abort the request.</param>
/// <returns>The result of forwarding the request and response.</returns>
ValueTask<ForwarderError> SendAsync(HttpContext context, string destinationPrefix, HttpMessageInvoker httpClient,
ForwarderRequestConfig requestConfig, HttpTransformer transformer, CancellationToken cancellationToken)
=> SendAsync(context, destinationPrefix, httpClient, requestConfig, transformer);
}
16 changes: 10 additions & 6 deletions src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
Expand Down Expand Up @@ -63,11 +64,11 @@ internal StructuredTransformer(bool? copyRequestHeaders, bool? copyResponseHeade
/// </summary>
internal ResponseTrailersTransform[] ResponseTrailerTransforms { get; }

public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix)
public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken)
{
if (ShouldCopyRequestHeaders.GetValueOrDefault(true))
{
await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);
await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, cancellationToken);
}

if (RequestTransforms.Length == 0)
Expand All @@ -83,6 +84,7 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
Path = httpContext.Request.Path,
Query = new QueryTransformContext(httpContext.Request),
HeadersCopied = ShouldCopyRequestHeaders.GetValueOrDefault(true),
CancellationToken = cancellationToken,
};

foreach (var requestTransform in RequestTransforms)
Expand All @@ -101,11 +103,11 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
transformContext.DestinationPrefix, transformContext.Path, transformContext.Query.QueryString);
}

public override async ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse)
public override async ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken)
{
if (ShouldCopyResponseHeaders.GetValueOrDefault(true))
{
await base.TransformResponseAsync(httpContext, proxyResponse);
await base.TransformResponseAsync(httpContext, proxyResponse, cancellationToken);
}

if (ResponseTransforms.Length == 0)
Expand All @@ -118,6 +120,7 @@ public override async ValueTask<bool> TransformResponseAsync(HttpContext httpCon
HttpContext = httpContext,
ProxyResponse = proxyResponse,
HeadersCopied = ShouldCopyResponseHeaders.GetValueOrDefault(true),
CancellationToken = cancellationToken,
};

foreach (var responseTransform in ResponseTransforms)
Expand All @@ -128,11 +131,11 @@ public override async ValueTask<bool> TransformResponseAsync(HttpContext httpCon
return !transformContext.SuppressResponseBody;
}

public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse)
public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken)
{
if (ShouldCopyResponseTrailers.GetValueOrDefault(true))
{
await base.TransformResponseTrailersAsync(httpContext, proxyResponse);
await base.TransformResponseTrailersAsync(httpContext, proxyResponse, cancellationToken);
}

if (ResponseTrailerTransforms.Length == 0)
Expand All @@ -150,6 +153,7 @@ public override async ValueTask TransformResponseTrailersAsync(HttpContext httpC
HttpContext = httpContext,
ProxyResponse = proxyResponse,
HeadersCopied = ShouldCopyResponseTrailers.GetValueOrDefault(true),
CancellationToken = cancellationToken,
};

foreach (var responseTrailerTransform in ResponseTrailerTransforms)
Expand Down
6 changes: 6 additions & 0 deletions src/ReverseProxy/Transforms/RequestTransformContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Net.Http;
using System.Threading;
using Microsoft.AspNetCore.Http;

namespace Yarp.ReverseProxy.Transforms;
Expand Down Expand Up @@ -48,4 +49,9 @@ public class RequestTransformContext
/// port and path base. The 'Path' and 'Query' properties will be appended to this after the transforms have run.
/// </summary>
public string DestinationPrefix { get; init; } = default!;

/// <summary>
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
/// </summary>
public CancellationToken CancellationToken { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Net.Http;
using System.Threading;
using Microsoft.AspNetCore.Http;

namespace Yarp.ReverseProxy.Transforms;
Expand All @@ -27,4 +28,9 @@ public class ResponseTrailersTransformContext
/// should operate on.
/// </summary>
public bool HeadersCopied { get; set; }

/// <summary>
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
/// </summary>
public CancellationToken CancellationToken { get; set; }
}
6 changes: 6 additions & 0 deletions src/ReverseProxy/Transforms/ResponseTransformContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Net.Http;
using System.Threading;
using Microsoft.AspNetCore.Http;

namespace Yarp.ReverseProxy.Transforms;
Expand Down Expand Up @@ -33,4 +34,9 @@ public class ResponseTransformContext
/// Defaults to false.
/// </summary>
public bool SuppressResponseBody { get; set; }

/// <summary>
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
/// </summary>
public CancellationToken CancellationToken { get; set; }
}
14 changes: 9 additions & 5 deletions src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ internal sealed class ActivityCancellationTokenSource : CancellationTokenSource
};

private int _activityTimeoutMs;
private CancellationTokenRegistration _linkedRegistration;
private CancellationTokenRegistration _linkedRegistration1;
private CancellationTokenRegistration _linkedRegistration2;

private ActivityCancellationTokenSource() { }

Expand All @@ -28,7 +29,7 @@ public void ResetTimeout()
CancelAfter(_activityTimeoutMs);
}

public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, CancellationToken linkedToken)
public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, CancellationToken linkedToken1 = default, CancellationToken linkedToken2 = default)
{
if (_sharedSources.TryDequeue(out var cts))
{
Expand All @@ -40,16 +41,19 @@ public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, Can
}

cts._activityTimeoutMs = (int)activityTimeout.TotalMilliseconds;
cts._linkedRegistration = linkedToken.UnsafeRegister(_linkedTokenCancelDelegate, cts);
cts._linkedRegistration1 = linkedToken1.UnsafeRegister(_linkedTokenCancelDelegate, cts);
cts._linkedRegistration2 = linkedToken2.UnsafeRegister(_linkedTokenCancelDelegate, cts);
cts.ResetTimeout();

return cts;
}

public void Return()
{
_linkedRegistration.Dispose();
_linkedRegistration = default;
_linkedRegistration1.Dispose();
_linkedRegistration1 = default;
_linkedRegistration2.Dispose();
_linkedRegistration2 = default;

if (TryReset())
{
Expand Down
Loading