Skip to content
This repository has been archived by the owner on Aug 8, 2024. It is now read-only.

Commit

Permalink
Improve ManualResetValueTaskSource prototype (dotnet#29468)
Browse files Browse the repository at this point in the history
In .NET Core 2.1 we added the public `IValueTaskSource` and `IValueTaskSource<T>` interfaces, with associated support in `ValueTask` and `ValueTask<T>`, and while we implemented the interfaces on several types internally, we didn't expose any public implementations.

We should consider exposing several in the future, including a manual-reset and an auto-reset IValueTaskSource implementation.  We already have a ManualResetValueTaskSource implementation in our tests.  This commit improves upon it in a few ways:
- Separates out the logic into a separate public struct.  The ManualResetValueTaskSource class wraps the struct, giving developers a choice to either use the class directly, or to embed the struct in their own implementation.
- Fixes context capture to behave more similarly to Task, handling both SynchronizationContext and TaskSchedulers
- Adds a prototype implementation of an IAsyncEnumerable, demonstrating how the compiler could utilize ManualResetValueTaskSourceLogic in its implementation.

This is all still prototype, used only in tests.

(cherry picked from commit df43abb)
  • Loading branch information
stephentoub authored and Martin Baulig committed Feb 4, 2019
1 parent e74a664 commit 826ac29
Show file tree
Hide file tree
Showing 8 changed files with 528 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,80 +2,108 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Threading.Tasks.Sources;

namespace System.Threading.Tasks.Tests
namespace System.Runtime.CompilerServices
{
internal static class ManualResetValueTaskSource
public interface IStrongBox<T>
{
public static ManualResetValueTaskSource<T> Completed<T>(T result, Exception error = null)
{
var vts = new ManualResetValueTaskSource<T>();
if (error != null)
{
vts.SetException(error);
}
else
{
vts.SetResult(result);
}
return vts;
}
ref T Value { get; }
}
}

public static ManualResetValueTaskSource<T> Delay<T>(int delayMs, T result, Exception error = null)
{
var vts = new ManualResetValueTaskSource<T>();
Task.Delay(delayMs).ContinueWith(_ =>
{
if (error != null)
{
vts.SetException(error);
}
else
{
vts.SetResult(result);
}
});
return vts;
}
namespace System.Threading.Tasks.Sources
{
public sealed class ManualResetValueTaskSource<T> : IStrongBox<ManualResetValueTaskSourceLogic<T>>, IValueTaskSource<T>, IValueTaskSource
{
private ManualResetValueTaskSourceLogic<T> _logic; // mutable struct; do not make this readonly

public ManualResetValueTaskSource() => _logic = new ManualResetValueTaskSourceLogic<T>(this);

public short Version => _logic.Version;

public void Reset() => _logic.Reset();

public void SetResult(T result) => _logic.SetResult(result);

public void SetException(Exception error) => _logic.SetException(error);

public T GetResult(short token) => _logic.GetResult(token);
void IValueTaskSource.GetResult(short token) => _logic.GetResult(token);

public ValueTaskSourceStatus GetStatus(short token) => _logic.GetStatus(token);

public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) => _logic.OnCompleted(continuation, state, token, flags);

ref ManualResetValueTaskSourceLogic<T> IStrongBox<ManualResetValueTaskSourceLogic<T>>.Value => ref _logic;
}

internal sealed class ManualResetValueTaskSource<T> : IValueTaskSource<T>, IValueTaskSource
public struct ManualResetValueTaskSourceLogic<TResult>
{
private static readonly Action<object> s_sentinel = new Action<object>(s => { });
private static readonly Action<object> s_sentinel = new Action<object>(s => throw new InvalidOperationException());

private readonly IStrongBox<ManualResetValueTaskSourceLogic<TResult>> _parent;
private Action<object> _continuation;
private object _continuationState;
private SynchronizationContext _capturedContext;
private object _capturedContext;
private ExecutionContext _executionContext;
private bool _completed;
private T _result;
private TResult _result;
private ExceptionDispatchInfo _error;
private short _version;

public ManualResetValueTaskSourceLogic(IStrongBox<ManualResetValueTaskSourceLogic<TResult>> parent)
{
_parent = parent ?? throw new ArgumentNullException(nameof(parent));
_continuation = null;
_continuationState = null;
_capturedContext = null;
_executionContext = null;
_completed = false;
_result = default;
_error = null;
_version = 0;
}

public ValueTaskSourceStatus GetStatus(short token) =>
!_completed ? ValueTaskSourceStatus.Pending :
_error == null ? ValueTaskSourceStatus.Succeeded :
_error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled :
ValueTaskSourceStatus.Faulted;
public short Version => _version;

public T GetResult(short token)
private void ValidateToken(short token)
{
if (!_completed)
if (token != _version)
{
throw new Exception("Not completed");
throw new InvalidOperationException();
}
}

_error?.Throw();
return _result;
public ValueTaskSourceStatus GetStatus(short token)
{
ValidateToken(token);

return
!_completed ? ValueTaskSourceStatus.Pending :
_error == null ? ValueTaskSourceStatus.Succeeded :
_error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled :
ValueTaskSourceStatus.Faulted;
}

void IValueTaskSource.GetResult(short token)
public TResult GetResult(short token)
{
GetResult(token);
ValidateToken(token);

if (!_completed)
{
throw new InvalidOperationException();
}

_error?.Throw();
return _result;
}

public void Reset()
{
_version++;

_completed = false;
_continuation = null;
_continuationState = null;
Expand All @@ -87,36 +115,64 @@ public void Reset()

public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags)
{
if (continuation == null)
{
throw new ArgumentNullException(nameof(continuation));
}
ValidateToken(token);

if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0)
{
_executionContext = ExecutionContext.Capture();
}

if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0)
{
_capturedContext = SynchronizationContext.Current;
SynchronizationContext sc = SynchronizationContext.Current;
if (sc != null && sc.GetType() != typeof(SynchronizationContext))
{
_capturedContext = sc;
}
else
{
TaskScheduler ts = TaskScheduler.Current;
if (ts != TaskScheduler.Default)
{
_capturedContext = ts;
}
}
}

_continuationState = state;
if (Interlocked.CompareExchange(ref _continuation, continuation, null) != null)
{
SynchronizationContext sc = _capturedContext;
if (sc != null)
{
sc.Post(s =>
{
var tuple = (Tuple<Action<object>, object>)s;
tuple.Item1(tuple.Item2);
}, Tuple.Create(continuation, state));
}
else
_executionContext = null;

object cc = _capturedContext;
_capturedContext = null;

switch (cc)
{
Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default);
case null:
Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default);
break;

case SynchronizationContext sc:
sc.Post(s =>
{
var tuple = (Tuple<Action<object>, object>)s;
tuple.Item1(tuple.Item2);
}, Tuple.Create(continuation, state));
break;

case TaskScheduler ts:
Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts);
break;
}
}
}

public void SetResult(T result)
public void SetResult(TResult result)
{
_result = result;
SignalCompletion();
Expand All @@ -130,12 +186,20 @@ public void SetException(Exception error)

private void SignalCompletion()
{
if (_completed)
{
throw new InvalidOperationException();
}
_completed = true;

if (Interlocked.CompareExchange(ref _continuation, s_sentinel, null) != null)
{
if (_executionContext != null)
{
ExecutionContext.Run(_executionContext, s => ((ManualResetValueTaskSource<T>)s).InvokeContinuation(), this);
ExecutionContext.Run(
_executionContext,
s => ((IStrongBox<ManualResetValueTaskSourceLogic<TResult>>)s).Value.InvokeContinuation(),
_parent ?? throw new InvalidOperationException());
}
else
{
Expand All @@ -146,18 +210,26 @@ private void SignalCompletion()

private void InvokeContinuation()
{
SynchronizationContext sc = _capturedContext;
if (sc != null)
{
sc.Post(s =>
{
var thisRef = (ManualResetValueTaskSource<T>)s;
thisRef._continuation(thisRef._continuationState);
}, this);
}
else
object cc = _capturedContext;
_capturedContext = null;

switch (cc)
{
_continuation(_continuationState);
case null:
_continuation(_continuationState);
break;

case SynchronizationContext sc:
sc.Post(s =>
{
ref ManualResetValueTaskSourceLogic<TResult> logicRef = ref ((IStrongBox<ManualResetValueTaskSourceLogic<TResult>>)s).Value;
logicRef._continuation(logicRef._continuationState);
}, _parent ?? throw new InvalidOperationException());
break;

case TaskScheduler ts:
Task.Factory.StartNew(_continuation, _continuationState, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts);
break;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.ExceptionServices;
using System.Threading.Tasks.Sources;

namespace System.Threading.Tasks.Tests
{
internal static class ManualResetValueTaskSourceFactory
{
public static ManualResetValueTaskSource<T> Completed<T>(T result, Exception error = null)
{
var vts = new ManualResetValueTaskSource<T>();
if (error != null)
{
vts.SetException(error);
}
else
{
vts.SetResult(result);
}
return vts;
}

public static ManualResetValueTaskSource<T> Delay<T>(int delayMs, T result, Exception error = null)
{
var vts = new ManualResetValueTaskSource<T>();
Task.Delay(delayMs).ContinueWith(_ =>
{
if (error != null)
{
vts.SetException(error);
}
else
{
vts.SetResult(result);
}
});
return vts;
}
}
}
Loading

0 comments on commit 826ac29

Please sign in to comment.