diff --git a/src/ReverseProxy/Forwarder/HttpForwarder.cs b/src/ReverseProxy/Forwarder/HttpForwarder.cs index d290c334b..01046f1fc 100644 --- a/src/ReverseProxy/Forwarder/HttpForwarder.cs +++ b/src/ReverseProxy/Forwarder/HttpForwarder.cs @@ -416,7 +416,7 @@ public async ValueTask SendAsync( // :: Step 2: Setup copy of request body (background) Client --► Proxy --► Destination // Note that we must do this before step (3) because step (3) may also add headers to the HttpContent that we set up here. - var requestContent = SetupRequestBodyCopy(context.Request, isStreamingRequest, activityToken); + var requestContent = SetupRequestBodyCopy(context, isStreamingRequest, activityToken); destinationRequest.Content = requestContent; // :: Step 3: Copy request headers Client --► Proxy --► Destination @@ -496,12 +496,13 @@ private void FixupUpgradeRequestHeaders(HttpContext context, HttpRequestMessage // else not an upgrade, or H2->H2, no changes needed } - private StreamCopyHttpContent? SetupRequestBodyCopy(HttpRequest request, bool isStreamingRequest, ActivityCancellationTokenSource activityToken) + private StreamCopyHttpContent? SetupRequestBodyCopy(HttpContext context, bool isStreamingRequest, ActivityCancellationTokenSource activityToken) { // If we generate an HttpContent without a Content-Length then for HTTP/1.1 HttpClient will add a Transfer-Encoding: chunked header // even if it's a GET request. Some servers reject requests containing a Transfer-Encoding header if they're not expecting a body. // Try to be as specific as possible about the client's intent to send a body. The one thing we don't want to do is to start // reading the body early because that has side-effects like 100-continue. + var request = context.Request; var hasBody = true; var contentLength = request.Headers.ContentLength; var method = request.Method; @@ -512,10 +513,11 @@ private void FixupUpgradeRequestHeaders(HttpContext context, HttpRequestMessage // 5.0 servers provide a definitive answer for us. hasBody = canHaveBodyFeature.CanHaveBody; - // TODO: Kestrel bug, this shouldn't be true for ExtendedConnect. -#if NET7_0_OR_GREATER +#if NET7_0 + // TODO: Kestrel 7.0 bug only, hasBody shouldn't be true for ExtendedConnect. + // https://github.com/dotnet/aspnetcore/issues/46002 Fixed in 8.0 var connectFeature = request.HttpContext.Features.Get(); - if (connectFeature?.Protocol != null) + if (connectFeature?.IsExtendedConnect == true) { hasBody = false; } @@ -560,31 +562,13 @@ private void FixupUpgradeRequestHeaders(HttpContext context, HttpRequestMessage if (hasBody) { - if (isStreamingRequest) - { - DisableMinRequestBodyDataRateAndMaxRequestBodySize(request.HttpContext); - } - - // Note on `autoFlushHttpClientOutgoingStream: isStreamingRequest`: - // The.NET Core HttpClient stack keeps its own buffers on top of the underlying outgoing connection socket. - // We flush those buffers down to the socket on every write when this is set, - // but it does NOT result in calls to flush on the underlying socket. - // This is necessary because we proxy http2 transparently, - // and we are deliberately unaware of packet structure used e.g. in gRPC duplex channels. - // Because the sockets aren't flushed, the perf impact of this choice is expected to be small. - // Future: It may be wise to set this to true for *all* http2 incoming requests, - // but for now, out of an abundance of caution, we only do it for requests that look like gRPC. - return new StreamCopyHttpContent( - request: request, - autoFlushHttpClientOutgoingStream: isStreamingRequest, - clock: _clock, - activityToken); + return new StreamCopyHttpContent(context, isStreamingRequest, _clock, _logger, activityToken); } return null; } - private ForwarderError HandleRequestBodyFailure(HttpContext context, StreamCopyResult requestBodyCopyResult, Exception requestBodyException, Exception additionalException) + private ForwarderError HandleRequestBodyFailure(HttpContext context, StreamCopyResult requestBodyCopyResult, Exception requestBodyException, Exception additionalException, bool timedOut) { ForwarderError requestBodyError; int statusCode; @@ -593,19 +577,12 @@ private ForwarderError HandleRequestBodyFailure(HttpContext context, StreamCopyR // Failed while trying to copy the request body from the client. It's ambiguous if the request or response failed first. case StreamCopyResult.InputError: requestBodyError = ForwarderError.RequestBodyClient; - statusCode = StatusCodes.Status400BadRequest; + statusCode = timedOut ? StatusCodes.Status408RequestTimeout : StatusCodes.Status400BadRequest; break; // Failed while trying to copy the request body to the destination. It's ambiguous if the request or response failed first. case StreamCopyResult.OutputError: requestBodyError = ForwarderError.RequestBodyDestination; - statusCode = StatusCodes.Status502BadGateway; - break; - // Canceled while trying to copy the request body, either due to a client disconnect or a timeout. This probably caused the response to fail as a secondary error. - case StreamCopyResult.Canceled: - requestBodyError = ForwarderError.RequestBodyCanceled; - // Timeouts (504s) are handled at the SendAsync call site. - // The request body should only be canceled by the RequestAborted token. - statusCode = StatusCodes.Status502BadGateway; + statusCode = timedOut ? StatusCodes.Status504GatewayTimeout : StatusCodes.Status502BadGateway; break; default: throw new NotImplementedException(requestBodyCopyResult.ToString()); @@ -630,33 +607,46 @@ private ForwarderError HandleRequestBodyFailure(HttpContext context, StreamCopyR private async ValueTask HandleRequestFailureAsync(HttpContext context, StreamCopyHttpContent? requestContent, Exception requestException, HttpTransformer transformer, ActivityCancellationTokenSource requestCancellationSource, bool failedDuringRequestCreation) { - if (requestException is OperationCanceledException) + var triedRequestBody = requestContent?.ConsumptionTask.IsCompleted == true; + + if (requestCancellationSource.CancelledByLinkedToken) { - if (requestCancellationSource.CancelledByLinkedToken) + var requestBodyCanceled = false; + if (triedRequestBody) { - // Either the client went away (HttpContext.RequestAborted) or the CancellationToken provided to SendAsync was signaled. - return await ReportErrorAsync(ForwarderError.RequestCanceled, StatusCodes.Status502BadGateway); - } - else - { - Debug.Assert(requestCancellationSource.IsCancellationRequested || requestException.ToString().Contains("ConnectTimeout"), requestException.ToString()); - return await ReportErrorAsync(ForwarderError.RequestTimedOut, StatusCodes.Status504GatewayTimeout); + var (requestBodyCopyResult, requestBodyException) = requestContent!.ConsumptionTask.Result; + requestBodyCanceled = requestBodyCopyResult == StreamCopyResult.Canceled; + if (requestBodyCanceled) + { + requestException = new AggregateException(requestException, requestBodyException!); + } } + // Either the client went away (HttpContext.RequestAborted) or the CancellationToken provided to SendAsync was signaled. + return await ReportErrorAsync(requestBodyCanceled ? ForwarderError.RequestBodyCanceled : ForwarderError.RequestCanceled, + context.RequestAborted.IsCancellationRequested ? StatusCodes.Status400BadRequest : StatusCodes.Status502BadGateway); } // Check for request body errors, these may have triggered the response error. - if (requestContent?.ConsumptionTask.IsCompleted == true) + if (triedRequestBody) { - var (requestBodyCopyResult, requestBodyException) = requestContent.ConsumptionTask.Result; + var (requestBodyCopyResult, requestBodyException) = requestContent!.ConsumptionTask.Result; - if (requestBodyCopyResult != StreamCopyResult.Success) + if (requestBodyCopyResult is StreamCopyResult.InputError or StreamCopyResult.OutputError) { - var error = HandleRequestBodyFailure(context, requestBodyCopyResult, requestBodyException!, requestException); + var error = HandleRequestBodyFailure(context, requestBodyCopyResult, requestBodyException!, requestException, + timedOut: requestCancellationSource.IsCancellationRequested); await transformer.TransformResponseAsync(context, proxyResponse: null, requestCancellationSource.Token); return error; } } + if (requestException is OperationCanceledException) + { + Debug.Assert(requestCancellationSource.IsCancellationRequested || requestException.ToString().Contains("ConnectTimeout"), requestException.ToString()); + + return await ReportErrorAsync(ForwarderError.RequestTimedOut, StatusCodes.Status504GatewayTimeout); + } + // We couldn't communicate with the destination. return await ReportErrorAsync(failedDuringRequestCreation ? ForwarderError.RequestCreation : ForwarderError.Request, StatusCodes.Status502BadGateway); @@ -870,7 +860,7 @@ private ForwarderError FixupUpgradeResponseHeaders(HttpContext context, HttpResp return (StreamCopyResult.Success, null); } - private async ValueTask HandleResponseBodyErrorAsync(HttpContext context, StreamCopyHttpContent? requestContent, StreamCopyResult responseBodyCopyResult, Exception responseBodyException, CancellationTokenSource requestCancellationSource) + private async ValueTask HandleResponseBodyErrorAsync(HttpContext context, StreamCopyHttpContent? requestContent, StreamCopyResult responseBodyCopyResult, Exception responseBodyException, ActivityCancellationTokenSource requestCancellationSource) { if (requestContent is not null && requestContent.Started) { @@ -884,9 +874,10 @@ private async ValueTask HandleResponseBodyErrorAsync(HttpContext var (requestBodyCopyResult, requestBodyError) = await requestContent.ConsumptionTask; // Check for request body errors, these may have triggered the response error. - if (alreadyFinished && requestBodyCopyResult != StreamCopyResult.Success) + if (alreadyFinished && requestBodyCopyResult is StreamCopyResult.InputError or StreamCopyResult.OutputError) { - return HandleRequestBodyFailure(context, requestBodyCopyResult, requestBodyError!, responseBodyException); + return HandleRequestBodyFailure(context, requestBodyCopyResult, requestBodyError!, responseBodyException, + timedOut: requestCancellationSource.IsCancellationRequested && !requestCancellationSource.CancelledByLinkedToken); } } @@ -920,41 +911,6 @@ private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage so return transformer.TransformResponseTrailersAsync(context, source, cancellationToken); } - - /// - /// Disable some ASP .NET Core server limits so that we can handle long-running gRPC requests unconstrained. - /// Note that the gRPC server implementation on ASP .NET Core does the same for client-streaming and duplex methods. - /// Since in Gateway we have no way to determine if the current request requires client-streaming or duplex comm, - /// we do this for *all* incoming requests that look like they might be gRPC. - /// - /// - /// Inspired on - /// . - /// - private void DisableMinRequestBodyDataRateAndMaxRequestBodySize(HttpContext httpContext) - { - var minRequestBodyDataRateFeature = httpContext.Features.Get(); - if (minRequestBodyDataRateFeature is not null) - { - minRequestBodyDataRateFeature.MinDataRate = null; - } - - var maxRequestBodySizeFeature = httpContext.Features.Get(); - if (maxRequestBodySizeFeature is not null) - { - if (!maxRequestBodySizeFeature.IsReadOnly) - { - maxRequestBodySizeFeature.MaxRequestBodySize = null; - } - else - { - // IsReadOnly could be true if middleware has already started reading the request body - // In that case we can't disable the max request body size for the request stream - _logger.LogWarning("Unable to disable max request body size."); - } - } - } - private void ReportProxyError(HttpContext context, ForwarderError error, Exception ex) { context.Features.Set(new ForwarderErrorFeature(error, ex)); diff --git a/src/ReverseProxy/Forwarder/StreamCopier.cs b/src/ReverseProxy/Forwarder/StreamCopier.cs index f34a0ff8e..0bea0b021 100644 --- a/src/ReverseProxy/Forwarder/StreamCopier.cs +++ b/src/ReverseProxy/Forwarder/StreamCopier.cs @@ -124,9 +124,13 @@ internal static class StreamCopier telemetry?.AfterWrite(); } - var result = ex is OperationCanceledException ? StreamCopyResult.Canceled : - (read == 0 ? StreamCopyResult.InputError : StreamCopyResult.OutputError); + if (activityToken.CancelledByLinkedToken) + { + return (StreamCopyResult.Canceled, ex); + } + // If the activity timeout triggered while reading or writing, blame the sender or receiver. + var result = read == 0 ? StreamCopyResult.InputError : StreamCopyResult.OutputError; return (result, ex); } finally diff --git a/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs b/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs index 7fd1a541d..12f433f94 100644 --- a/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs +++ b/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs @@ -9,6 +9,9 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.Extensions.Logging; using Yarp.ReverseProxy.Utilities; namespace Yarp.ReverseProxy.Forwarder; @@ -39,21 +42,22 @@ namespace Yarp.ReverseProxy.Forwarder; /// internal sealed class StreamCopyHttpContent : HttpContent { - private readonly HttpRequest _request; + private readonly HttpContext _context; // HttpClient's machinery keeps an internal buffer that doesn't get flushed to the socket on every write. // Some protocols (e.g. gRPC) may rely on specific bytes being sent, and HttpClient's buffering would prevent it. - private readonly bool _autoFlushHttpClientOutgoingStream; + private bool _isStreamingRequest; private readonly IClock _clock; + private readonly ILogger _logger; private readonly ActivityCancellationTokenSource _activityToken; private readonly TaskCompletionSource<(StreamCopyResult, Exception?)> _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); private int _started; - public StreamCopyHttpContent(HttpRequest request, bool autoFlushHttpClientOutgoingStream, IClock clock, ActivityCancellationTokenSource activityToken) + public StreamCopyHttpContent(HttpContext context, bool isStreamingRequest, IClock clock, ILogger logger, ActivityCancellationTokenSource activityToken) { - _request = request ?? throw new ArgumentNullException(nameof(request)); - _autoFlushHttpClientOutgoingStream = autoFlushHttpClientOutgoingStream; + _context = context ?? throw new ArgumentNullException(nameof(context)); + _isStreamingRequest = isStreamingRequest; _clock = clock ?? throw new ArgumentNullException(nameof(clock)); - + _logger = logger; _activityToken = activityToken; } @@ -137,11 +141,22 @@ protected override async Task SerializeToStreamAsync(Stream stream, TransportCon // _cancellation will be the same as cancellationToken for HTTP/1.1, so we can avoid the overhead of linking them CancellationTokenSource? linkedCts = null; - if (_activityToken.Token != cancellationToken) + if (_activityToken.Token == cancellationToken) + { + // We're talking to the destination via HTTP/1.1, so this can't be a streaming gRPC request. + _isStreamingRequest = false; + // TODO: Log if _isStreamingRequest is true? Something went wrong with protocol selection. + } + else { Debug.Assert(cancellationToken.CanBeCanceled); linkedCts = CancellationTokenSource.CreateLinkedTokenSource(_activityToken.Token, cancellationToken); cancellationToken = linkedCts.Token; + + if (_isStreamingRequest) + { + DisableMinRequestBodyDataRateAndMaxRequestBodySize(_context); + } } try @@ -163,8 +178,20 @@ protected override async Task SerializeToStreamAsync(Stream stream, TransportCon return; } - // Check that the content-length matches the request body size. This can be removed in .NET 7 now that SocketsHttpHandler enforces this: https://github.com/dotnet/runtime/issues/62258. - var (result, error) = await StreamCopier.CopyAsync(isRequest: true, _request.Body, stream, Headers.ContentLength ?? StreamCopier.UnknownLength, _clock, _activityToken, _autoFlushHttpClientOutgoingStream, cancellationToken); + // Check that the content-length matches the request body size. This can be removed in .NET 7 now that SocketsHttpHandler + // enforces this: https://github.com/dotnet/runtime/issues/62258. + // + // Note on `_isStreamingRequest`: + // The.NET Core HttpClient stack keeps its own buffers on top of the underlying outgoing connection socket. + // We flush those buffers down to the socket on every write when this is set, + // but it does NOT result in calls to flush on the underlying socket. + // This is necessary because we proxy http2 transparently, + // and we are deliberately unaware of packet structure used e.g. in gRPC duplex channels. + // Because the sockets aren't flushed, the perf impact of this choice is expected to be small. + // Future: It may be wise to set this to true for *all* http2 incoming requests, + // but for now, out of an abundance of caution, we only do it for requests that look like gRPC. + var (result, error) = await StreamCopier.CopyAsync(isRequest: true, _context.Request.Body, stream, + Headers.ContentLength ?? StreamCopier.UnknownLength, _clock, _activityToken, _isStreamingRequest, cancellationToken); _tcs.TrySetResult((result, error)); // Check for errors that weren't the result of the destination failing. @@ -199,4 +226,38 @@ protected override bool TryComputeLength(out long length) length = -1; return false; } + + /// + /// Disable some ASP .NET Core server limits so that we can handle long-running gRPC requests unconstrained. + /// Note that the gRPC server implementation on ASP .NET Core does the same for client-streaming and duplex methods. + /// Since in Gateway we have no way to determine if the current request requires client-streaming or duplex comm, + /// we do this for *all* incoming requests that look like they might be gRPC. + /// + /// + /// Inspired on + /// . + /// + private void DisableMinRequestBodyDataRateAndMaxRequestBodySize(HttpContext httpContext) + { + var minRequestBodyDataRateFeature = httpContext.Features.Get(); + if (minRequestBodyDataRateFeature is not null) + { + minRequestBodyDataRateFeature.MinDataRate = null; + } + + var maxRequestBodySizeFeature = httpContext.Features.Get(); + if (maxRequestBodySizeFeature is not null) + { + if (!maxRequestBodySizeFeature.IsReadOnly) + { + maxRequestBodySizeFeature.MaxRequestBodySize = null; + } + else + { + // IsReadOnly could be true if middleware has already started reading the request body + // In that case we can't disable the max request body size for the request stream + _logger.LogWarning("Unable to disable max request body size."); + } + } + } } diff --git a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs index ed4e0861d..2fd5c5e43 100644 --- a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs @@ -699,8 +699,8 @@ public async Task UpgradableRequest_CancelsIfIdle() Assert.Equal(StatusCodes.Status101SwitchingProtocols, httpContext.Response.StatusCode); // When both are idle it's a race which gets reported as canceled first. - Assert.True(ForwarderError.UpgradeRequestCanceled == result - || ForwarderError.UpgradeResponseCanceled == result); + Assert.True(ForwarderError.UpgradeRequestClient == result + || ForwarderError.UpgradeResponseDestination == result); events.AssertContainProxyStages(upgrade: true); } @@ -1479,7 +1479,7 @@ public async Task RequestConnectTimedOut_Returns504() } [Fact] - public async Task RequestCanceled_Returns502() + public async Task RequestCanceled_Returns400() { var events = TestEventListener.Collect(); @@ -1503,7 +1503,7 @@ public async Task RequestCanceled_Returns502() var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client); Assert.Equal(ForwarderError.RequestCanceled, proxyError); - Assert.Equal(StatusCodes.Status502BadGateway, httpContext.Response.StatusCode); + Assert.Equal(StatusCodes.Status400BadRequest, httpContext.Response.StatusCode); Assert.Equal(0, proxyResponseStream.Length); var errorFeature = httpContext.Features.Get(); Assert.Equal(ForwarderError.RequestCanceled, errorFeature.Error); @@ -1614,7 +1614,7 @@ public async Task RequestWithBody_KeptAliveByActivity() } [Fact] - public async Task RequestWithBodyCanceled_Returns502() + public async Task RequestWithBodyCanceled_Returns400() { var events = TestEventListener.Collect(); @@ -1640,7 +1640,7 @@ public async Task RequestWithBodyCanceled_Returns502() var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client); Assert.Equal(ForwarderError.RequestCanceled, proxyError); - Assert.Equal(StatusCodes.Status502BadGateway, httpContext.Response.StatusCode); + Assert.Equal(StatusCodes.Status400BadRequest, httpContext.Response.StatusCode); Assert.Equal(0, proxyResponseStream.Length); var errorFeature = httpContext.Features.Get(); Assert.Equal(ForwarderError.RequestCanceled, errorFeature.Error); @@ -1762,7 +1762,7 @@ public async Task RequestBodyCanceledBeforeResponseError_Returns502() var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client); Assert.Equal(ForwarderError.RequestBodyCanceled, proxyError); - Assert.Equal(StatusCodes.Status502BadGateway, httpContext.Response.StatusCode); + Assert.Equal(StatusCodes.Status400BadRequest, httpContext.Response.StatusCode); Assert.Equal(0, proxyResponseStream.Length); var errorFeature = httpContext.Features.Get(); Assert.Equal(ForwarderError.RequestBodyCanceled, errorFeature.Error); @@ -1993,6 +1993,7 @@ public async Task ResponseBodyCancelledAfterStart_Aborted() httpContext.Features.Set(responseBody); var destinationPrefix = "https://localhost:123/"; + var cts = new CancellationTokenSource(); var sut = CreateProxy(); var client = MockHttpHandler.CreateClient( (HttpRequestMessage request, CancellationToken cancellationToken) => @@ -2002,14 +2003,17 @@ public async Task ResponseBodyCancelledAfterStart_Aborted() Content = new StreamContent(new CallbackReadStream((_, _) => { responseBody.HasStarted = true; - throw new TaskCanceledException(); + cts.Cancel(); + cts.Token.ThrowIfCancellationRequested(); + throw new NotImplementedException(); })) }; message.Headers.AcceptRanges.Add("bytes"); return Task.FromResult(message); }); - var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client); + var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client, ForwarderRequestConfig.Empty, + HttpTransformer.Empty, cts.Token); Assert.Equal(ForwarderError.ResponseBodyCanceled, proxyError); Assert.Equal(StatusCodes.Status200OK, httpContext.Response.StatusCode); @@ -2017,7 +2021,7 @@ public async Task ResponseBodyCancelledAfterStart_Aborted() Assert.Equal("bytes", httpContext.Response.Headers[HeaderNames.AcceptRanges]); var errorFeature = httpContext.Features.Get(); Assert.Equal(ForwarderError.ResponseBodyCanceled, errorFeature.Error); - Assert.IsType(errorFeature.Exception); + Assert.IsType(errorFeature.Exception); AssertProxyStartFailedStop(events, destinationPrefix, httpContext.Response.StatusCode, errorFeature.Error); events.AssertContainProxyStages(hasRequestContent: false); @@ -2732,9 +2736,13 @@ public async Task ForwarderCancellations_CancellationsAreVisibleInTransforms(Can ? ForwarderError.RequestTimedOut : ForwarderError.RequestCanceled; - var expectedStatusCode = cancellationScenario == CancellationScenario.ActivityTimeout - ? StatusCodes.Status504GatewayTimeout - : StatusCodes.Status502BadGateway; + var expectedStatusCode = cancellationScenario switch + { + CancellationScenario.ActivityTimeout => StatusCodes.Status504GatewayTimeout, + CancellationScenario.RequestAborted => StatusCodes.Status400BadRequest, + CancellationScenario.ManualCancellationToken => StatusCodes.Status502BadGateway, + _ => throw new NotImplementedException(cancellationScenario.ToString()), + }; Assert.Equal(expectedError, proxyError); Assert.Equal(expectedStatusCode, httpContext.Response.StatusCode); diff --git a/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs b/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs index af2eb16cd..4c0b4d039 100644 --- a/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs @@ -102,8 +102,9 @@ public async Task Cancelled_Reported(bool isRequest) var source = new MemoryStream(new byte[10]); var destination = new MemoryStream(); - using var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); - cts.Cancel(); + var requestCts = new CancellationTokenSource(); + using var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), requestCts.Token); + requestCts.Cancel(); var (result, error) = await StreamCopier.CopyAsync(isRequest, source, destination, StreamCopier.UnknownLength, new ManualClock(), cts, cts.Token); Assert.Equal(StreamCopyResult.Canceled, result); Assert.IsAssignableFrom(error); diff --git a/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs b/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs index 0c6e88bc0..3c5e8852d 100644 --- a/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs @@ -13,19 +13,22 @@ using Yarp.Tests.Common; using Yarp.ReverseProxy.Utilities; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using System.Xml.Linq; +using Microsoft.Extensions.Logging.Abstractions; namespace Yarp.ReverseProxy.Forwarder.Tests; public class StreamCopyHttpContentTests { - private static StreamCopyHttpContent CreateContent(HttpRequest request = null, bool autoFlushHttpClientOutgoingStream = false, IClock clock = null, ActivityCancellationTokenSource contentCancellation = null) + private static StreamCopyHttpContent CreateContent(HttpContext context = null, bool isStreamingRequest = false, IClock clock = null, ActivityCancellationTokenSource contentCancellation = null) { - request ??= new DefaultHttpContext().Request; + context ??= new DefaultHttpContext(); clock ??= new Clock(); contentCancellation ??= ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); - return new StreamCopyHttpContent(request, autoFlushHttpClientOutgoingStream, clock, contentCancellation); + return new StreamCopyHttpContent(context, isStreamingRequest, clock, NullLogger.Instance, contentCancellation); } [Fact] @@ -34,11 +37,11 @@ public async Task CopyToAsync_InvokesStreamCopier() const int SourceSize = (128 * 1024) - 3; var sourceBytes = Enumerable.Range(0, SourceSize).Select(i => (byte)(i % 256)).ToArray(); - var request = new DefaultHttpContext().Request; - request.Body = new MemoryStream(sourceBytes); + var context = new DefaultHttpContext(); + context.Request.Body = new MemoryStream(sourceBytes); var destination = new MemoryStream(); - var sut = CreateContent(request); + var sut = CreateContent(context); Assert.False(sut.ConsumptionTask.IsCompleted); Assert.False(sut.Started); @@ -68,12 +71,12 @@ public async Task CopyToAsync_AutoFlushing(bool autoFlush) expectedFlushes++; var sourceBytes = Enumerable.Range(0, SourceSize).Select(i => (byte)(i % 256)).ToArray(); - var request = new DefaultHttpContext().Request; - request.Body = new MemoryStream(sourceBytes); + var context = new DefaultHttpContext(); + context.Request.Body = new MemoryStream(sourceBytes); var destination = new MemoryStream(); var flushCountingDestination = new FlushCountingStream(destination); - var sut = CreateContent(request, autoFlushHttpClientOutgoingStream: autoFlush); + var sut = CreateContent(context, autoFlush); Assert.False(sut.ConsumptionTask.IsCompleted); Assert.False(sut.Started); @@ -91,11 +94,11 @@ public async Task CopyToAsync_AsyncSequencing() var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var source = new Mock(); source.Setup(s => s.ReadAsync(It.IsAny>(), It.IsAny())).Returns(() => new ValueTask(tcs.Task)); - var request = new DefaultHttpContext().Request; - request.Body = source.Object; + var context = new DefaultHttpContext(); + context.Request.Body = source.Object; var destination = new MemoryStream(); - var sut = CreateContent(request); + var sut = CreateContent(context); Assert.False(sut.ConsumptionTask.IsCompleted); Assert.False(sut.Started); @@ -151,12 +154,12 @@ public async Task SerializeToStreamAsync_RespectsContentCancellation() return 0; }); - var request = new DefaultHttpContext().Request; - request.Body = source; + var context = new DefaultHttpContext(); + context.Request.Body = source; using var contentCts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); - var sut = CreateContent(request, contentCancellation: contentCts); + var sut = CreateContent(context, contentCancellation: contentCts); var copyToTask = sut.CopyToWithCancellationAsync(new MemoryStream()); contentCts.Cancel(); @@ -183,10 +186,10 @@ public async Task SerializeToStreamAsync_CanBeCanceledExternally() return 0; }); - var request = new DefaultHttpContext().Request; - request.Body = source; + var context = new DefaultHttpContext(); + context.Request.Body = source; - var sut = CreateContent(request); + var sut = CreateContent(context); using var cts = new CancellationTokenSource(); var copyToTask = sut.CopyToAsync(new MemoryStream(), cts.Token);