Skip to content

Commit

Permalink
Fix nunit tests adapter losing async locals (#16157)
Browse files Browse the repository at this point in the history
* Fix NUnit test context not being properly set

* Add failing tests

* Capture ExecutionContext to keep async locals

* Remove explicit EstablishExecutionEnvironment call, as it was a bad idea

* Make ExecutionContext usage disabled by default, and only enabled for NUnit
  • Loading branch information
maxkatz6 authored Jul 17, 2024
1 parent 945b371 commit 8ea60fe
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ internal class AvaloniaTestMethodCommand : TestCommand
.GetField("BeforeTest", BindingFlags.Instance | BindingFlags.NonPublic)!;
private static FieldInfo s_afterTest = typeof(BeforeAndAfterTestCommand)
.GetField("AfterTest", BindingFlags.Instance | BindingFlags.NonPublic)!;

private AvaloniaTestMethodCommand(
HeadlessUnitTestSession session,
TestCommand innerCommand,
Expand All @@ -47,7 +47,7 @@ public static TestCommand ProcessCommand(HeadlessUnitTestSession session, TestCo
{
return ProcessCommand(session, command, new List<Action>(), new List<Action>());
}

private static TestCommand ProcessCommand(HeadlessUnitTestSession session, TestCommand command, List<Action> before, List<Action> after)
{
if (command is BeforeAndAfterTestCommand beforeAndAfterTestCommand)
Expand Down Expand Up @@ -79,7 +79,7 @@ private static TestCommand ProcessCommand(HeadlessUnitTestSession session, TestC

public override TestResult Execute(TestExecutionContext context)
{
return _session.Dispatch(() => ExecuteTestMethod(context), default).GetAwaiter().GetResult();
return _session.DispatchCore(() => ExecuteTestMethod(context), true, default).GetAwaiter().GetResult();
}

// Unfortunately, NUnit has issues with custom synchronization contexts, which means we need to add some hacks to make it work.
Expand Down
1 change: 1 addition & 0 deletions src/Headless/Avalonia.Headless/Avalonia.Headless.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
</ItemGroup>

<ItemGroup Label="InternalsVisibleTo">
<InternalsVisibleTo Include="Avalonia.Headless.NUnit, PublicKey=$(AvaloniaPublicKey)" />
<InternalsVisibleTo Include="Avalonia.Headless.Vnc, PublicKey=$(AvaloniaPublicKey)" />
<InternalsVisibleTo Include="Avalonia.UnitTests, PublicKey=$(AvaloniaPublicKey)" />
<InternalsVisibleTo Include="Avalonia.Base.UnitTests, PublicKey=$(AvaloniaPublicKey)" />
Expand Down
42 changes: 28 additions & 14 deletions src/Headless/Avalonia.Headless/HeadlessUnitTestSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public sealed class HeadlessUnitTestSession : IDisposable

private readonly AppBuilder _appBuilder;
private readonly CancellationTokenSource _cancellationTokenSource;
private readonly BlockingCollection<Action> _queue;
private readonly BlockingCollection<(Action, ExecutionContext?)> _queue;
private readonly Task _dispatchTask;

internal const DynamicallyAccessedMemberTypes DynamicallyAccessed =
Expand All @@ -32,50 +32,58 @@ public sealed class HeadlessUnitTestSession : IDisposable
DynamicallyAccessedMemberTypes.PublicParameterlessConstructor;

private HeadlessUnitTestSession(AppBuilder appBuilder, CancellationTokenSource cancellationTokenSource,
BlockingCollection<Action> queue, Task dispatchTask)
BlockingCollection<(Action, ExecutionContext?)> queue, Task dispatchTask)
{
_appBuilder = appBuilder;
_cancellationTokenSource = cancellationTokenSource;
_queue = queue;
_dispatchTask = dispatchTask;
}

/// <inheritdoc cref="Dispatch{TResult}(Func{Task{TResult}}, CancellationToken)"/>
/// <inheritdoc cref="DispatchCore{TResult}"/>
public Task Dispatch(Action action, CancellationToken cancellationToken)
{
return Dispatch(() =>
return DispatchCore(() =>
{
action();
return Task.FromResult(0);
}, cancellationToken);
}, false ,cancellationToken);
}

/// <inheritdoc cref="Dispatch{TResult}(Func{Task{TResult}}, CancellationToken)"/>
/// <inheritdoc cref="DispatchCore{TResult}"/>
public Task<TResult> Dispatch<TResult>(Func<TResult> action, CancellationToken cancellationToken)
{
return Dispatch(() => Task.FromResult(action()), cancellationToken);
return DispatchCore(() => Task.FromResult(action()), false, cancellationToken);
}

/// <inheritdoc cref="DispatchCore{TResult}"/>
public Task<TResult> Dispatch<TResult>(Func<Task<TResult>> action, CancellationToken cancellationToken)
{
return DispatchCore(action, false, cancellationToken);
}

/// <summary>
/// Dispatch method queues an async operation on the dispatcher thread, creates a new application instance,
/// setting app avalonia services, and runs <paramref name="action"/> parameter.
/// </summary>
/// <param name="action">Action to execute on the dispatcher thread with avalonia services.</param>
/// <param name="captureExecutionContext">Whether dispatch should capture ExecutionContext.</param>
/// <param name="cancellationToken">Cancellation token to cancel execution.</param>
/// <exception cref="ObjectDisposedException">
/// If global session was already cancelled and thread killed, it's not possible to dispatch any actions again
/// </exception>
public Task<TResult> Dispatch<TResult>(Func<Task<TResult>> action, CancellationToken cancellationToken)
internal Task<TResult> DispatchCore<TResult>(Func<Task<TResult>> action, bool captureExecutionContext, CancellationToken cancellationToken)
{
if (_cancellationTokenSource.IsCancellationRequested)
{
throw new ObjectDisposedException("Session was already disposed.");
}

var token = _cancellationTokenSource.Token;
var executionContext = captureExecutionContext ? ExecutionContext.Capture() : null;

var tcs = new TaskCompletionSource<TResult>();
_queue.Add(() =>
_queue.Add((() =>
{
var cts = new CancellationTokenSource();
using var globalCts = token.Register(s => ((CancellationTokenSource)s!).Cancel(), cts, true);
Expand All @@ -84,7 +92,6 @@ public Task<TResult> Dispatch<TResult>(Func<Task<TResult>> action, CancellationT
try
{
using var application = EnsureApplication();
var task = action();
if (task.Status != TaskStatus.RanToCompletion)
{
Expand All @@ -110,7 +117,7 @@ public Task<TResult> Dispatch<TResult>(Func<Task<TResult>> action, CancellationT
{
tcs.TrySetException(ex);
}
});
}, executionContext));
return tcs.Task;
}

Expand Down Expand Up @@ -157,7 +164,7 @@ public static HeadlessUnitTestSession StartNew(
{
var tcs = new TaskCompletionSource<HeadlessUnitTestSession>();
var cancellationTokenSource = new CancellationTokenSource();
var queue = new BlockingCollection<Action>();
var queue = new BlockingCollection<(Action, ExecutionContext?)>();

Task? task = null;
task = Task.Run(() =>
Expand Down Expand Up @@ -185,8 +192,15 @@ public static HeadlessUnitTestSession StartNew(
{
try
{
var action = queue.Take(cancellationTokenSource.Token);
action();
var (action, executionContext) = queue.Take(cancellationTokenSource.Token);
if (executionContext is not null)
{
ExecutionContext.Run(executionContext, a => ((Action)a!).Invoke(), action);
}
else
{
action();
}
}
catch (OperationCanceledException)
{
Expand Down
35 changes: 29 additions & 6 deletions tests/Avalonia.Headless.UnitTests/ThreadingTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Avalonia.Threading;
Expand All @@ -14,6 +16,7 @@ public class ThreadingTests
#endif
public void Should_Be_On_Dispatcher_Thread()
{
ValidateTestContext();
Dispatcher.UIThread.VerifyAccess();
}

Expand All @@ -34,20 +37,40 @@ public void Should_Fail_Test_On_Delayed_Post_When_FlushDispatcher()
#endif
public async Task DispatcherTimer_Works_On_The_Same_Thread(int interval)
{
Assert.NotNull(SynchronizationContext.Current);
ValidateTestContext();
var currentThread = Thread.CurrentThread;

await Task.Delay(100);

var currentThread = Thread.CurrentThread;
ValidateTestContext();
Assert.True(currentThread == Thread.CurrentThread);

var tcs = new TaskCompletionSource();
var hasCompleted = false;

DispatcherTimer.RunOnce(() =>
{
hasCompleted = currentThread == Thread.CurrentThread;
tcs.SetResult();
try
{
ValidateTestContext();
Assert.True(currentThread == Thread.CurrentThread);
tcs.SetResult();
}
catch (Exception ex)
{
tcs.SetException(ex);
}
}, TimeSpan.FromTicks(interval));

await tcs.Task;
Assert.True(hasCompleted);
}

private void ValidateTestContext([CallerMemberName] string runningMethodName = null)
{
#if NUNIT
var testName = TestContext.CurrentContext.Test.Name;
// Test.Name also includes parameters.
Assert.AreEqual(testName.Split('(').First(), runningMethodName);
#endif
}
}

0 comments on commit 8ea60fe

Please sign in to comment.