Skip to content

Commit

Permalink
Merge pull request #71737 from jasonmalinowski/fix-cancellation-handling
Browse files Browse the repository at this point in the history
Ensure we don't leak the request info if a request is cancelled
  • Loading branch information
jasonmalinowski authored Jan 31, 2024
2 parents 45e9054 + 9bb45d5 commit be8d58b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
32 changes: 27 additions & 5 deletions src/Workspaces/Core/MSBuild/Rpc/RpcClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,36 @@ public async Task<T> InvokeAsync<T>(int targetObject, string methodName, List<ob

var requestJson = JsonConvert.SerializeObject(request, JsonSettings.SingleLineSerializerSettings);

using (await _sendingStreamSemaphore.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
try
{
await _sendingStream.WriteLineAsync(requestJson).ConfigureAwait(false);
#pragma warning disable CA2016 // https://github.com/dotnet/roslyn/issues/71580
await _sendingStream.FlushAsync().ConfigureAwait(false);
#pragma warning restore CA2016
// The only cancellation we support is cancelling before we are able to write the request to the stream; once it's been written
// the other side will execute it to completion. Thus cancellationToken is checked here, but nowhere else.
using (await _sendingStreamSemaphore.DisposableWaitAsync(cancellationToken).ConfigureAwait(false))
{
await _sendingStream.WriteLineAsync(requestJson).ConfigureAwait(false);
#if NET8_0_OR_GREATER
await _sendingStream.FlushAsync(CancellationToken.None).ConfigureAwait(false);
#else
await _sendingStream.FlushAsync().ConfigureAwait(false);
#endif
}
}
catch (OperationCanceledException)
{
// The request was cancelled, so we don't need to hold it around anymore.
_outstandingRequests.TryRemove(requestId, out _);
throw;
}

return await requestCompletionSource.Task.ConfigureAwait(false);
}

internal TestAccessor GetTestAccessor()
=> new(this);

internal readonly struct TestAccessor(RpcClient client)
{
public int GetOutstandingRequestCount()
=> client._outstandingRequests.Count;
}
}
15 changes: 15 additions & 0 deletions src/Workspaces/MSBuildTest/RpcTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.MSBuild.Rpc;
using Microsoft.VisualStudio.Telemetry;
using Nerdbank.Streams;
using Xunit;

Expand Down Expand Up @@ -190,6 +191,20 @@ public async Task ExceptionHandling()
Assert.Contains("Exception thrown by test method!", exception.Message);
}

[Fact]
public async Task CancelledTaskDoesNotLeakRequest()
{
await using var rpcPair = new RpcPair();

var tokenSource = new CancellationTokenSource();
tokenSource.Cancel();

rpcPair.Server.AddTarget(new ObjectWithHelloMethod());
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await rpcPair.Client.InvokeAsync<string>(targetObject: 0, nameof(ObjectWithHelloMethod.Hello), [], tokenSource.Token));

Assert.Equal(0, rpcPair.Client.GetTestAccessor().GetOutstandingRequestCount());
}

#pragma warning disable CA1822 // Mark members as static

private sealed class ObjectWithHelloMethod { public string Hello(string name) { return "Hello " + name; } }
Expand Down

0 comments on commit be8d58b

Please sign in to comment.