Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable request transforms to reject requests #1923

Merged
merged 5 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion samples/ReverseProxy.Auth.Sample/Startup.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Net.Http.Headers;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authentication.Cookies;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Yarp.ReverseProxy.Transforms;

namespace Yarp.Sample
{
Expand All @@ -31,8 +34,37 @@ public void ConfigureServices(IServiceCollection services)
// Required to supply the authentication UI in Views/*
services.AddRazorPages();

services.AddSingleton<TokenService>();

services.AddReverseProxy()
.LoadFromConfig(_configuration.GetSection("ReverseProxy"));
.LoadFromConfig(_configuration.GetSection("ReverseProxy"))
.AddTransforms(transformBuilderContext => // Add transforms inline
{
// For each route+cluster pair decide if we want to add transforms, and if so, which?
// This logic is re-run each time a route is rebuilt.

// Only do this for routes that require auth.
if (string.Equals("myPolicy", transformBuilderContext.Route.AuthorizationPolicy))
{
transformBuilderContext.AddRequestTransform(async transformContext =>
{
// AuthN and AuthZ will have already been completed after request routing.
var ticket = await transformContext.HttpContext.AuthenticateAsync(CookieAuthenticationDefaults.AuthenticationScheme);
var tokenService = transformContext.HttpContext.RequestServices.GetRequiredService<TokenService>();
var token = await tokenService.GetAuthTokenAsync(ticket.Principal);

// Reject invalid requests
if (string.IsNullOrEmpty(token))
{
var response = transformContext.HttpContext.Response;
response.StatusCode = 401;
return;
}

transformContext.ProxyRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token);
});
}
}); ;

services.AddAuthentication(CookieAuthenticationDefaults.AuthenticationScheme)
.AddCookie();
Expand Down
21 changes: 21 additions & 0 deletions samples/ReverseProxy.Auth.Sample/TokenService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Security.Claims;
using System.Threading.Tasks;

namespace Yarp.Sample
{
internal class TokenService
{
internal Task<string> GetAuthTokenAsync(ClaimsPrincipal user)
{
// we only have tokens for bob
if (string.Equals("Bob", user.Identity.Name))
{
return Task.FromResult("valid");
}
return Task.FromResult<string>(null);
}
}
Tratcher marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<form action="Login" method="post">
<input hidden name="returnurl" type="text" value="@ViewData["ReturnUrl"]" /><br />
<input name="Name" type="text" value="My Name" /><br />
<input name="Name" type="text" value="Bob" /><br />
<input name="myClaimValue" type="text" value="green" /><br />
<input type="submit">
<div><b>Note:</b>The authorization policy will check for the value of "green", other values should pass authentication, but not authorize for specific routes</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ public void ValidateRoute(TransformRouteValidationContext context)
public void ValidateCluster(TransformClusterValidationContext context)
{
// Check all clusters for a custom property and validate the associated transform data.
string value = null;
if (context.Cluster.Metadata?.TryGetValue("CustomMetadata", out value) ?? false)
if (context.Cluster.Metadata?.TryGetValue("CustomMetadata", out var value) ?? false)
{
if (string.IsNullOrEmpty(value))
{
Expand Down
32 changes: 30 additions & 2 deletions src/ReverseProxy/Forwarder/HttpForwarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ public async ValueTask<ForwarderError> SendAsync(
_ = requestConfig ?? throw new ArgumentNullException(nameof(requestConfig));
_ = transformer ?? throw new ArgumentNullException(nameof(transformer));

if (context.Response.StatusCode != StatusCodes.Status200OK || context.Response.HasStarted)
Tratcher marked this conversation as resolved.
Show resolved Hide resolved
{
throw new InvalidOperationException("The request cannot be forwarded, the response has already started");
}

// HttpClient overload for SendAsync changes response behavior to fully buffered which impacts performance
// See discussion in https://github.com/microsoft/reverse-proxy/issues/458
if (httpClient is HttpClient)
Expand All @@ -116,6 +121,15 @@ public async ValueTask<ForwarderError> SendAsync(
var (destinationRequest, requestContent) = await CreateRequestMessageAsync(
context, destinationPrefix, transformer, requestConfig, isStreamingRequest, activityCancellationSource);

// Transforms generated a response, do not proxy.
if (context.Response.StatusCode != StatusCodes.Status200OK || context.Response.HasStarted)
{
Log.NotProxying(_logger, context.Response.StatusCode);
return ForwarderError.None;
}

Log.Proxying(_logger, destinationRequest, isStreamingRequest);

// :: Step 4: Send the outgoing request using HttpClient
HttpResponseMessage destinationResponse;
try
Expand Down Expand Up @@ -282,6 +296,12 @@ public async ValueTask<ForwarderError> SendAsync(
// :: Step 3: Copy request headers Client --► Proxy --► Destination
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix);

// The transformer generated a response, do not forward.
if (context.Response.StatusCode != StatusCodes.Status200OK || context.Response.HasStarted)
{
return (destinationRequest, requestContent);
}

if (isUpgradeRequest)
{
RestoreUpgradeHeaders(context, destinationRequest);
Expand All @@ -291,8 +311,6 @@ public async ValueTask<ForwarderError> SendAsync(
var request = context.Request;
destinationRequest.RequestUri ??= RequestUtilities.MakeDestinationAddress(destinationPrefix, request.Path, request.QueryString);

Log.Proxying(_logger, destinationRequest, isStreamingRequest);

if (requestConfig?.AllowResponseBuffering != true)
{
context.Features.Get<IHttpResponseBodyFeature>()?.DisableBuffering();
Expand Down Expand Up @@ -765,6 +783,11 @@ private static class Log
EventIds.ForwardingError,
"{error}: {message}");

private static readonly Action<ILogger, int, Exception?> _notProxying = LoggerMessage.Define<int>(
LogLevel.Information,
EventIds.NotForwarding,
"Not Proxying, a {statusCode} response was set by the transforms.");

public static void ResponseReceived(ILogger logger, HttpResponseMessage msg)
{
_responseReceived(logger, msg.Version, (int)msg.StatusCode, null);
Expand All @@ -782,6 +805,11 @@ public static void Proxying(ILogger logger, HttpRequestMessage msg, bool isStrea
}
}

public static void NotProxying(ILogger logger, int statusCode)
{
_notProxying(logger, statusCode, null);
}

public static void ErrorProxying(ILogger logger, ForwarderError error, Exception ex)
{
_proxyError(logger, error, GetMessage(error), ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
foreach (var requestTransform in RequestTransforms)
{
await requestTransform.ApplyAsync(transformContext);

// The transform generated a response, do not apply further transforms and do not forward.
if (httpContext.Response.StatusCode != StatusCodes.Status200OK || httpContext.Response.HasStarted)
{
return;
}
}

// Allow a transform to directly set a custom RequestUri.
Expand Down
1 change: 1 addition & 0 deletions src/ReverseProxy/Utilities/EventIds.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ internal static class EventIds
public static readonly EventId ResponseReceived = new EventId(56, "ResponseReceived");
public static readonly EventId DelegationQueueReset = new EventId(57, "DelegationQueueReset");
public static readonly EventId Http10RequestVersionDetected = new EventId(58, "Http10RequestVersionDetected");
public static readonly EventId NotForwarding = new EventId(59, "NotForwarding");
}
131 changes: 127 additions & 4 deletions test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,125 @@ public async Task TransformRequestAsync_ReplaceBody()
events.AssertContainProxyStages();
}

[Fact]
public async Task TransformRequestAsync_SetsStatus_ShortCircuits()
{
var events = TestEventListener.Collect();

var httpContext = new DefaultHttpContext();
httpContext.Request.Method = "POST";
httpContext.Request.Protocol = "HTTP/2";

var destinationPrefix = "https://localhost/";

var transforms = new DelegateHttpTransforms()
{
CopyRequestHeaders = true,
OnRequest = (context, request, destination) =>
{
context.Response.StatusCode = 401;
return Task.CompletedTask;
}
};

var sut = CreateProxy();
var client = MockHttpHandler.CreateClient(
async (HttpRequestMessage request, CancellationToken cancellationToken) =>
{
throw new NotImplementedException();
});

var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client, ForwarderRequestConfig.Empty, transforms);

Assert.Equal(ForwarderError.None, proxyError);
Assert.Equal(StatusCodes.Status401Unauthorized, httpContext.Response.StatusCode);

AssertProxyStartStop(events, destinationPrefix, httpContext.Response.StatusCode);
events.AssertContainProxyStages(new ForwarderStage[0]);
}

[Fact]
public async Task TransformRequestAsync_StartsResponse_ShortCircuits()
{
var events = TestEventListener.Collect();

var httpContext = new DefaultHttpContext();
var responseBody = new TestResponseBody();
httpContext.Features.Set<IHttpResponseFeature>(responseBody);
httpContext.Features.Set<IHttpResponseBodyFeature>(responseBody);
httpContext.Request.Method = "POST";
httpContext.Request.Protocol = "HTTP/2";

var destinationPrefix = "https://localhost/";

var transforms = new DelegateHttpTransforms()
{
CopyRequestHeaders = true,
OnRequest = (context, request, destination) =>
{
return context.Response.StartAsync();
}
};

var sut = CreateProxy();
var client = MockHttpHandler.CreateClient(
(HttpRequestMessage request, CancellationToken cancellationToken) =>
{
throw new NotImplementedException();
});

var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client, ForwarderRequestConfig.Empty, transforms);

Assert.Equal(ForwarderError.None, proxyError);
Assert.Equal(StatusCodes.Status200OK, httpContext.Response.StatusCode);
Assert.True(httpContext.Response.HasStarted);

AssertProxyStartStop(events, destinationPrefix, httpContext.Response.StatusCode);
events.AssertContainProxyStages(new ForwarderStage[0]);
}

[Fact]
public async Task TransformRequestAsync_WritesToResponse_ShortCircuits()
{
var events = TestEventListener.Collect();

var httpContext = new DefaultHttpContext();
var resultStream = new MemoryStream();
var responseBody = new TestResponseBody(resultStream);
httpContext.Features.Set<IHttpResponseFeature>(responseBody);
httpContext.Features.Set<IHttpResponseBodyFeature>(responseBody);
httpContext.Request.Method = "POST";
httpContext.Request.Protocol = "HTTP/2";

var destinationPrefix = "https://localhost/";

var transforms = new DelegateHttpTransforms()
{
CopyRequestHeaders = true,
OnRequest = (context, request, destination) =>
{
return context.Response.Body.WriteAsync(Encoding.UTF8.GetBytes("Hello World")).AsTask();
}
};

var sut = CreateProxy();
var client = MockHttpHandler.CreateClient(
(HttpRequestMessage request, CancellationToken cancellationToken) =>
{
throw new NotImplementedException();
});

var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client, ForwarderRequestConfig.Empty, transforms);

Assert.Equal(ForwarderError.None, proxyError);
Assert.Equal(StatusCodes.Status200OK, httpContext.Response.StatusCode);
Assert.True(httpContext.Response.HasStarted);
Assert.Equal("Hello World", Encoding.UTF8.GetString(resultStream.ToArray()));

AssertProxyStartStop(events, destinationPrefix, httpContext.Response.StatusCode);
events.AssertContainProxyStages(new ForwarderStage[0]);
}

// Tests proxying an upgradeable request.
[Theory]
[InlineData("WebSocket")]
Expand Down Expand Up @@ -1887,11 +2006,10 @@ public async Task ResponseBodyCancelledAfterStart_Aborted()
var httpContext = new DefaultHttpContext();
httpContext.Request.Method = "GET";
httpContext.Request.Host = new HostString("example.com:3456");
var responseBody = new TestResponseBody() { HasStarted = true };
var responseBody = new TestResponseBody();
httpContext.Features.Set<IHttpResponseFeature>(responseBody);
httpContext.Features.Set<IHttpResponseBodyFeature>(responseBody);
httpContext.Features.Set<IHttpRequestLifetimeFeature>(responseBody);
httpContext.RequestAborted = new CancellationToken(canceled: true);

var destinationPrefix = "https://localhost:123/";
var sut = CreateProxy();
Expand All @@ -1900,7 +2018,11 @@ public async Task ResponseBodyCancelledAfterStart_Aborted()
{
var message = new HttpResponseMessage()
{
Content = new StreamContent(new MemoryStream(new byte[1]))
Content = new StreamContent(new CallbackReadStream((_, _) =>
{
responseBody.HasStarted = true;
throw new TaskCanceledException();
}))
};
message.Headers.AcceptRanges.Add("bytes");
return Task.FromResult(message);
Expand Down Expand Up @@ -2828,7 +2950,8 @@ public Task SendFileAsync(string path, long offset, long? count, CancellationTok

public Task StartAsync(CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
OnStart();
return Task.CompletedTask;
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
Expand Down