Skip to content

Commit

Permalink
Flow cancellation to transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
Tratcher committed Jan 12, 2023
1 parent 86ea977 commit 69fb774
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 15 deletions.
14 changes: 7 additions & 7 deletions src/ReverseProxy/Forwarder/HttpForwarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ public async ValueTask<ForwarderError> SendAsync(
{
// :: Step 5: Copy response status line Client ◄-- Proxy ◄-- Destination
// :: Step 6: Copy response headers Client ◄-- Proxy ◄-- Destination
var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer);
var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token);

if (!copyBody)
{
Expand Down Expand Up @@ -269,7 +269,7 @@ public async ValueTask<ForwarderError> SendAsync(
}

// :: Step 8: Copy response trailer headers and finish response Client ◄-- Proxy ◄-- Destination
await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer);
await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token);

if (isStreamingRequest)
{
Expand Down Expand Up @@ -411,7 +411,7 @@ public async ValueTask<ForwarderError> SendAsync(
destinationRequest.Content = requestContent;

// :: Step 3: Copy request headers Client --► Proxy --► Destination
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix);
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix, activityToken.Token);

// The transformer generated a response, do not forward.
if (RequestUtilities.IsResponseSet(context.Response))
Expand Down Expand Up @@ -662,7 +662,7 @@ async ValueTask<ForwarderError> ReportErrorAsync(ForwarderError error, int statu
}
}

private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer)
private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken)
{
context.Response.StatusCode = (int)source.StatusCode;

Expand All @@ -676,7 +676,7 @@ private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMes
}

// Copies headers
return transformer.TransformResponseAsync(context, source);
return transformer.TransformResponseAsync(context, source, cancellationToken);
}

private async ValueTask<ForwarderError> HandleUpgradedResponse(HttpContext context, HttpResponseMessage destinationResponse,
Expand Down Expand Up @@ -900,10 +900,10 @@ private async ValueTask<ForwarderError> HandleResponseBodyErrorAsync(HttpContext
return error;
}

private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer)
private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken)
{
// Copies trailers
return transformer.TransformResponseTrailersAsync(context, source);
return transformer.TransformResponseTrailersAsync(context, source, cancellationToken);
}


Expand Down
43 changes: 43 additions & 0 deletions src/ReverseProxy/Forwarder/HttpTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
Expand Down Expand Up @@ -53,6 +54,23 @@ private static bool IsBodylessStatusCode(HttpStatusCode statusCode) =>
_ => false
};

/// <summary>
/// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are
/// initialized except RequestUri, which will be initialized after the callback if no value is provided.
/// See <see cref="RequestUtilities.MakeDestinationAddress(string, PathString, QueryString)"/> for constructing a custom request Uri.
/// The string parameter represents the destination URI prefix that should be used when constructing the RequestUri.
/// The headers are copied by the base implementation, excluding some protocol headers like HTTP/2 pseudo headers (":authority").
/// This method may be overridden to conditionally produce a response, such as for error conditions, and prevent the request from
/// being proxied. This is indicated by setting the `HttpResponse.StatusCode` to a value other than 200, or calling `HttpResponse.StartAsync()`,
/// or writing to the `HttpResponse.Body` or `BodyWriter`.
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyRequest">The outgoing proxy request.</param>
/// <param name="destinationPrefix">The uri prefix for the selected destination server which can be used to create the RequestUri.</param>
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken)
=> TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);

/// <summary>
/// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are
/// initialized except RequestUri, which will be initialized after the callback if no value is provided.
Expand Down Expand Up @@ -126,9 +144,24 @@ public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequ
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyResponse">The response from the destination. This can be null if the destination did not respond.</param>
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
/// <returns>A bool indicating if the response should be proxied to the client or not. A derived implementation
/// that returns false may send an alternate response inline or return control to the caller for it to retry, respond,
/// etc.</returns>
public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken)
=> TransformResponseAsync(httpContext, proxyResponse);

/// <summary>
/// A callback that is invoked when the proxied response is received. The status code and reason phrase will be copied
/// to the HttpContext.Response before the callback is invoked, but may still be modified there. The headers will be
/// copied to HttpContext.Response.Headers by the base implementation, excludes certain protocol headers like
/// `Transfer-Encoding: chunked`.
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyResponse">The response from the destination. This can be null if the destination did not respond.</param>
/// <returns>A bool indicating if the response should be proxied to the client or not. A derived implementation
/// that returns false may send an alternate response inline or return control to the caller for it to retry, respond,
/// etc.</returns>
public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse)
{
if (proxyResponse is null)
Expand Down Expand Up @@ -171,6 +204,16 @@ public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, H
return new ValueTask<bool>(true);
}

/// <summary>
/// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be
/// copied to the HttpContext.Response by the base implementation.
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyResponse">The response from the destination.</param>
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
public virtual ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken)
=> TransformResponseTrailersAsync(httpContext, proxyResponse);

/// <summary>
/// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be
/// copied to the HttpContext.Response by the base implementation.
Expand Down
16 changes: 10 additions & 6 deletions src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
Expand Down Expand Up @@ -63,11 +64,11 @@ internal StructuredTransformer(bool? copyRequestHeaders, bool? copyResponseHeade
/// </summary>
internal ResponseTrailersTransform[] ResponseTrailerTransforms { get; }

public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix)
public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken)
{
if (ShouldCopyRequestHeaders.GetValueOrDefault(true))
{
await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);
await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, cancellationToken);
}

if (RequestTransforms.Length == 0)
Expand All @@ -83,6 +84,7 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
Path = httpContext.Request.Path,
Query = new QueryTransformContext(httpContext.Request),
HeadersCopied = ShouldCopyRequestHeaders.GetValueOrDefault(true),
CancellationToken = cancellationToken,
};

foreach (var requestTransform in RequestTransforms)
Expand All @@ -101,11 +103,11 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
transformContext.DestinationPrefix, transformContext.Path, transformContext.Query.QueryString);
}

public override async ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse)
public override async ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken)
{
if (ShouldCopyResponseHeaders.GetValueOrDefault(true))
{
await base.TransformResponseAsync(httpContext, proxyResponse);
await base.TransformResponseAsync(httpContext, proxyResponse, cancellationToken);
}

if (ResponseTransforms.Length == 0)
Expand All @@ -118,6 +120,7 @@ public override async ValueTask<bool> TransformResponseAsync(HttpContext httpCon
HttpContext = httpContext,
ProxyResponse = proxyResponse,
HeadersCopied = ShouldCopyResponseHeaders.GetValueOrDefault(true),
CancellationToken = cancellationToken,
};

foreach (var responseTransform in ResponseTransforms)
Expand All @@ -128,11 +131,11 @@ public override async ValueTask<bool> TransformResponseAsync(HttpContext httpCon
return !transformContext.SuppressResponseBody;
}

public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse)
public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken)
{
if (ShouldCopyResponseTrailers.GetValueOrDefault(true))
{
await base.TransformResponseTrailersAsync(httpContext, proxyResponse);
await base.TransformResponseTrailersAsync(httpContext, proxyResponse, cancellationToken);
}

if (ResponseTrailerTransforms.Length == 0)
Expand All @@ -150,6 +153,7 @@ public override async ValueTask TransformResponseTrailersAsync(HttpContext httpC
HttpContext = httpContext,
ProxyResponse = proxyResponse,
HeadersCopied = ShouldCopyResponseTrailers.GetValueOrDefault(true),
CancellationToken = cancellationToken,
};

foreach (var responseTrailerTransform in ResponseTrailerTransforms)
Expand Down
6 changes: 6 additions & 0 deletions src/ReverseProxy/Transforms/RequestTransformContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Net.Http;
using System.Threading;
using Microsoft.AspNetCore.Http;

namespace Yarp.ReverseProxy.Transforms;
Expand Down Expand Up @@ -48,4 +49,9 @@ public class RequestTransformContext
/// port and path base. The 'Path' and 'Query' properties will be appended to this after the transforms have run.
/// </summary>
public string DestinationPrefix { get; init; } = default!;

/// <summary>
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
/// </summary>
public CancellationToken CancellationToken { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Net.Http;
using System.Threading;
using Microsoft.AspNetCore.Http;

namespace Yarp.ReverseProxy.Transforms;
Expand All @@ -27,4 +28,9 @@ public class ResponseTrailersTransformContext
/// should operate on.
/// </summary>
public bool HeadersCopied { get; set; }

/// <summary>
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
/// </summary>
public CancellationToken CancellationToken { get; set; }
}
6 changes: 6 additions & 0 deletions src/ReverseProxy/Transforms/ResponseTransformContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System.Net.Http;
using System.Threading;
using Microsoft.AspNetCore.Http;

namespace Yarp.ReverseProxy.Transforms;
Expand Down Expand Up @@ -33,4 +34,9 @@ public class ResponseTransformContext
/// Defaults to false.
/// </summary>
public bool SuppressResponseBody { get; set; }

/// <summary>
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
/// </summary>
public CancellationToken CancellationToken { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.HttpOverrides;
Expand Down Expand Up @@ -307,7 +308,7 @@ public async Task UseOriginalHost(bool? useOriginalHost, bool? copyHeaders)
httpContext.Request.Host = new HostString("StartHost");
var proxyRequest = new HttpRequestMessage();
var destinationPrefix = "http://destinationhost:9090/path";
await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);
await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None);

if (useOriginalHost.GetValueOrDefault(false))
{
Expand Down Expand Up @@ -373,7 +374,7 @@ public async Task UseCustomHost(bool? useOriginalHost, bool? copyHeaders)
var proxyRequest = new HttpRequestMessage();
var destinationPrefix = "http://destinationhost:9090/path";

await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);
await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None);

Assert.Equal("CustomHost", proxyRequest.Headers.Host);
}
Expand Down

0 comments on commit 69fb774

Please sign in to comment.