Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move to a callback-style approach to deserializing objects (part2) #72944

Merged
merged 13 commits into from
Apr 9, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,13 @@ private sealed class AssetProvider(SerializationValidator validator) : AbstractA
public override async ValueTask<T> GetAssetAsync<T>(AssetPath assetPath, Checksum checksum, CancellationToken cancellationToken)
=> await validator.GetValueAsync<T>(checksum).ConfigureAwait(false);

public override async ValueTask<ImmutableArray<(Checksum checksum, T asset)>> GetAssetsAsync<T>(AssetPath assetPath, HashSet<Checksum> checksums, CancellationToken cancellationToken)
public override async ValueTask GetAssetsAsync<T, TArg>(AssetPath assetPath, HashSet<Checksum> checksums, Action<Checksum, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
using var _ = ArrayBuilder<(Checksum checksum, T asset)>.GetInstance(out var result);

foreach (var checksum in checksums)
{
var value = await GetAssetAsync<T>(assetPath, checksum, cancellationToken).ConfigureAwait(false);
result.Add((checksum, value));
callback(checksum, value, arg);
}

return result.ToImmutable();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ private static async Task TestAssetAsync(object data)
var stored = await provider.GetAssetAsync<object>(AssetPath.FullLookupForTesting, checksum, CancellationToken.None);
Assert.Equal(data, stored);

var stored2 = await provider.GetAssetsAsync<object>(AssetPath.FullLookupForTesting, new HashSet<Checksum> { checksum }, CancellationToken.None);
Assert.Equal(1, stored2.Length);
var stored2 = new List<(Checksum, object)>();
await provider.GetAssetsAsync<object, VoidResult>(AssetPath.FullLookupForTesting, new HashSet<Checksum> { checksum }, (checksum, asset, _) => stored2.Add((checksum, asset)), default, CancellationToken.None);
Assert.Equal(1, stored2.Count);

Assert.Equal(checksum, stored2[0].Item1);
Assert.Equal(data, stored2[0].Item2);
Expand All @@ -83,7 +84,7 @@ public async Task TestAssetSynchronization()
var assetSource = new SimpleAssetSource(workspace.Services.GetService<ISerializerService>(), map);

var service = new AssetProvider(sessionId, storage, assetSource, remoteWorkspace.Services.GetService<ISerializerService>());
await service.SynchronizeAssetsAsync<object>(AssetPath.FullLookupForTesting, new HashSet<Checksum>(map.Keys), results: null, CancellationToken.None);
await service.SynchronizeAssetsAsync<object, VoidResult>(AssetPath.FullLookupForTesting, new HashSet<Checksum>(map.Keys), callback: null, arg: default, CancellationToken.None);

foreach (var kv in map)
{
Expand Down
6 changes: 3 additions & 3 deletions src/Workspaces/CoreTestUtilities/Fakes/SimpleAssetSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ namespace Microsoft.CodeAnalysis.Remote.Testing;
/// </summary>
internal sealed class SimpleAssetSource(ISerializerService serializerService, IReadOnlyDictionary<Checksum, object> map) : IAssetSource
{
public ValueTask GetAssetsAsync<T>(
Checksum solutionChecksum, AssetPath assetPath, ReadOnlyMemory<Checksum> checksums, ISerializerService deserializerService, Action<int, T> callback, CancellationToken cancellationToken)
public ValueTask GetAssetsAsync<T, TArg>(
Checksum solutionChecksum, AssetPath assetPath, ReadOnlyMemory<Checksum> checksums, ISerializerService deserializerService, Action<int, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
var index = 0;
foreach (var checksum in checksums.Span)
Expand All @@ -38,7 +38,7 @@ public ValueTask GetAssetsAsync<T>(
using var reader = ObjectReader.GetReader(stream, leaveOpen: true, cancellationToken);
var asset = deserializerService.Deserialize(data.GetWellKnownSynchronizationKind(), reader, cancellationToken);
Contract.ThrowIfNull(asset);
callback(index, (T)asset);
callback(index, (T)asset, arg);
index++;
}

Expand Down
15 changes: 12 additions & 3 deletions src/Workspaces/Remote/Core/AbstractAssetProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Runtime.InteropServices;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed

using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.Diagnostics;
Expand All @@ -22,7 +24,7 @@ internal abstract class AbstractAssetProvider
/// return data of type T whose checksum is the given checksum
/// </summary>
public abstract ValueTask<T> GetAssetAsync<T>(AssetPath assetPath, Checksum checksum, CancellationToken cancellationToken);
public abstract ValueTask<ImmutableArray<(Checksum checksum, T asset)>> GetAssetsAsync<T>(AssetPath assetPath, HashSet<Checksum> checksums, CancellationToken cancellationToken);
public abstract ValueTask GetAssetsAsync<T, TArg>(AssetPath assetPath, HashSet<Checksum> checksums, Action<Checksum, T, TArg> callback, TArg arg, CancellationToken cancellationToken);

public async Task<SolutionInfo> CreateSolutionInfoAsync(Checksum solutionChecksum, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -110,7 +112,14 @@ public async Task<ImmutableArray<T>> GetAssetsAsync<T>(
using var _ = PooledHashSet<Checksum>.GetInstance(out var checksumSet);
checksumSet.AddAll(checksums.Children);

var results = await this.GetAssetsAsync<T>(assetPath, checksumSet, cancellationToken).ConfigureAwait(false);
return results.SelectAsArray(static t => t.asset);
var results = ImmutableArray.CreateBuilder<T>(checksumSet.Count);
CyrusNajmabadi marked this conversation as resolved.
Show resolved Hide resolved

await this.GetAssetsAsync<T, ImmutableArray<T>.Builder>(
assetPath, checksumSet,
static (checksum, asset, results) => results.Add(asset),
results,
cancellationToken).ConfigureAwait(false);

return results.MoveToImmutable();
}
}
19 changes: 8 additions & 11 deletions src/Workspaces/Remote/Core/RemoteHostAssetSerialization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Serialization;
using Nerdbank.Streams;
using Roslyn.Utilities;
Expand Down Expand Up @@ -67,8 +64,8 @@ static void WriteAsset(ObjectWriter writer, ISerializerService serializer, Solut
}
}

public static ValueTask ReadDataAsync<T>(
PipeReader pipeReader, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T> callback, CancellationToken cancellationToken)
public static ValueTask ReadDataAsync<T, TArg>(
PipeReader pipeReader, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
// Suppress ExecutionContext flow for asynchronous operations operate on the pipe. In addition to avoiding
// ExecutionContext allocations, this clears the LogicalCallContext and avoids the need to clone data set by
Expand All @@ -77,18 +74,18 @@ public static ValueTask ReadDataAsync<T>(
// ⚠ DO NOT AWAIT INSIDE THE USING. The Dispose method that restores ExecutionContext flow must run on the
// same thread where SuppressFlow was originally run.
using var _ = FlowControlHelper.TrySuppressFlow();
return ReadDataSuppressedFlowAsync(pipeReader, solutionChecksum, objectCount, serializerService, callback, cancellationToken);
return ReadDataSuppressedFlowAsync(pipeReader, solutionChecksum, objectCount, serializerService, callback, arg, cancellationToken);

static async ValueTask ReadDataSuppressedFlowAsync(
PipeReader pipeReader, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T> callback, CancellationToken cancellationToken)
PipeReader pipeReader, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
using var stream = await pipeReader.AsPrebufferedStreamAsync(cancellationToken).ConfigureAwait(false);
ReadData<T>(stream, solutionChecksum, objectCount, serializerService, callback, cancellationToken);
ReadData(stream, solutionChecksum, objectCount, serializerService, callback, arg, cancellationToken);
}
}

public static void ReadData<T>(
Stream stream, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T> callback, CancellationToken cancellationToken)
public static void ReadData<T, TArg>(
Stream stream, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
using var reader = ObjectReader.GetReader(stream, leaveOpen: true, cancellationToken);

Expand All @@ -104,7 +101,7 @@ public static void ReadData<T>(
// in service hub, cancellation means simply closed stream
var result = serializerService.Deserialize(kind, reader, cancellationToken);
Contract.ThrowIfNull(result);
callback(i, (T)result);
callback(i, (T)result, arg);
}
}
}
Expand Down
99 changes: 47 additions & 52 deletions src/Workspaces/Remote/ServiceHub/Host/AssetProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,25 @@ public override async ValueTask<T> GetAssetAsync<T>(
using var _1 = PooledHashSet<Checksum>.GetInstance(out var checksums);
checksums.Add(checksum);

using var _2 = PooledDictionary<Checksum, T>.GetInstance(out var results);
await this.SynchronizeAssetsAsync(assetPath, checksums, results, cancellationToken).ConfigureAwait(false);
var called = false;
T? result = default;
await this.SynchronizeAssetsAsync<T, VoidResult>(assetPath, checksums, (_, asset, _) =>
{
Contract.ThrowIfTrue(called);
called = true;
CyrusNajmabadi marked this conversation as resolved.
Show resolved Hide resolved
result = asset;
}, default, cancellationToken).ConfigureAwait(false);

Contract.ThrowIfFalse(called);
Contract.ThrowIfNull((object?)result);

return results[checksum];
return result;
}

public override async ValueTask<ImmutableArray<(Checksum checksum, T asset)>> GetAssetsAsync<T>(
AssetPath assetPath, HashSet<Checksum> checksums, CancellationToken cancellationToken)
public override async ValueTask GetAssetsAsync<T, TArg>(
AssetPath assetPath, HashSet<Checksum> checksums, Action<Checksum, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
using var _ = PooledDictionary<Checksum, T>.GetInstance(out var results);

await this.SynchronizeAssetsAsync(assetPath, checksums, results, cancellationToken).ConfigureAwait(false);

var result = new (Checksum checksum, T asset)[checksums.Count];
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no intermediary alloc!

var index = 0;
foreach (var (checksum, assetObject) in results)
{
result[index] = (checksum, assetObject);
index++;
}

return ImmutableCollectionsMarshal.AsImmutableArray(result);
await this.SynchronizeAssetsAsync(assetPath, checksums, callback, arg, cancellationToken).ConfigureAwait(false);
}

public async ValueTask SynchronizeSolutionAssetsAsync(Checksum solutionChecksum, CancellationToken cancellationToken)
Expand Down Expand Up @@ -88,17 +85,15 @@ public async ValueTask SynchronizeSolutionAssetsAsync(Checksum solutionChecksum,

async ValueTask SynchronizeSolutionAssetsWorkerAsync()
{
using var _1 = PooledDictionary<Checksum, object>.GetInstance(out var checksumToObjects);

// first, get top level solution state for the given solution checksum
var compilationStateChecksums = await this.GetAssetAsync<SolutionCompilationStateChecksums>(
assetPath: AssetPath.SolutionOnly, solutionChecksum, cancellationToken).ConfigureAwait(false);

using var _2 = PooledHashSet<Checksum>.GetInstance(out var checksums);
using var _1 = PooledHashSet<Checksum>.GetInstance(out var checksums);

// second, get direct children of the solution compilation state.
compilationStateChecksums.AddAllTo(checksums);
await this.SynchronizeAssetsAsync<object>(assetPath: AssetPath.SolutionOnly, checksums, results: null, cancellationToken).ConfigureAwait(false);
await this.SynchronizeAssetsAsync<object, VoidResult>(assetPath: AssetPath.SolutionOnly, checksums, callback: null, arg: default, cancellationToken).ConfigureAwait(false);

// third, get direct children of the solution state.
var stateChecksums = await this.GetAssetAsync<SolutionStateChecksums>(
Expand All @@ -108,7 +103,13 @@ async ValueTask SynchronizeSolutionAssetsWorkerAsync()
// the project states and we want to get that all in one batch.
checksums.Clear();
stateChecksums.AddAllTo(checksums);
await this.SynchronizeAssetsAsync(assetPath: AssetPath.SolutionAndTopLevelProjectsOnly, checksums, checksumToObjects, cancellationToken).ConfigureAwait(false);

using var _2 = PooledDictionary<Checksum, object>.GetInstance(out var checksumToObjects);

await this.SynchronizeAssetsAsync<object, Dictionary<Checksum, object>>(
assetPath: AssetPath.SolutionAndTopLevelProjectsOnly, checksums,
CyrusNajmabadi marked this conversation as resolved.
Show resolved Hide resolved
static (checksum, asset, checksumToObjects) => checksumToObjects.Add(checksum, asset),
arg: checksumToObjects, cancellationToken).ConfigureAwait(false);

// fourth, get all projects and documents in the solution
foreach (var (projectChecksum, _) in stateChecksums.Projects)
Expand Down Expand Up @@ -151,8 +152,8 @@ async ValueTask SynchronizeProjectAssetsWorkerAsync()
AddAll(checksums, projectChecksums.AnalyzerConfigDocuments.Checksums);

// First synchronize all the top-level info about this project.
await this.SynchronizeAssetsAsync<object>(
assetPath: AssetPath.ProjectAndDocuments(projectChecksums.ProjectId), checksums, results: null, cancellationToken).ConfigureAwait(false);
await this.SynchronizeAssetsAsync<object, VoidResult>(
assetPath: AssetPath.ProjectAndDocuments(projectChecksums.ProjectId), checksums, callback: null, arg: default, cancellationToken).ConfigureAwait(false);

checksums.Clear();

Expand All @@ -161,8 +162,8 @@ await this.SynchronizeAssetsAsync<object>(
await CollectChecksumChildrenAsync(checksums, projectChecksums.AdditionalDocuments).ConfigureAwait(false);
await CollectChecksumChildrenAsync(checksums, projectChecksums.AnalyzerConfigDocuments).ConfigureAwait(false);

await this.SynchronizeAssetsAsync<object>(
assetPath: AssetPath.ProjectAndDocuments(projectChecksums.ProjectId), checksums, results: null, cancellationToken).ConfigureAwait(false);
await this.SynchronizeAssetsAsync<object, VoidResult>(
assetPath: AssetPath.ProjectAndDocuments(projectChecksums.ProjectId), checksums, callback: null, arg: default, cancellationToken).ConfigureAwait(false);
}

async ValueTask CollectChecksumChildrenAsync(HashSet<Checksum> checksums, ChecksumsAndIds<DocumentId> collection)
Expand All @@ -185,8 +186,8 @@ static void AddAll(HashSet<Checksum> checksums, ChecksumCollection checksumColle
}
}

public async ValueTask SynchronizeAssetsAsync<T>(
AssetPath assetPath, HashSet<Checksum> checksums, Dictionary<Checksum, T>? results, CancellationToken cancellationToken)
public async ValueTask SynchronizeAssetsAsync<T, TArg>(
AssetPath assetPath, HashSet<Checksum> checksums, Action<Checksum, T, TArg>? callback, TArg? arg, CancellationToken cancellationToken)
{
Contract.ThrowIfTrue(checksums.Contains(Checksum.Null));
if (checksums.Count == 0)
Expand All @@ -211,7 +212,7 @@ public async ValueTask SynchronizeAssetsAsync<T>(
{
if (_assetCache.TryGetAsset<T>(checksum, out var existing))
{
AddResult(checksum, existing);
callback?.Invoke(checksum, existing, arg!);
}
else
{
Expand Down Expand Up @@ -239,37 +240,31 @@ public async ValueTask SynchronizeAssetsAsync<T>(
{
var missingChecksumsMemory = new ReadOnlyMemory<Checksum>(missingChecksums, 0, missingChecksumsCount);

var currentIndex = 0;
await RequestAssetsAsync(assetPath, missingChecksumsMemory, (int index, T missingAsset) =>
{
Contract.ThrowIfTrue(currentIndex != index);

var missingChecksum = missingChecksums[index];

AddResult(missingChecksum, missingAsset);
_assetCache.GetOrAdd(missingChecksum, missingAsset!);

currentIndex++;
}, cancellationToken).ConfigureAwait(false);
await RequestAssetsAsync(
assetPath, missingChecksumsMemory,
static (
int index,
T missingAsset,
(AssetProvider assetProvider, Checksum[] missingChecksums, Action<Checksum, T, TArg>? callback, TArg? arg) tuple) =>
{
var missingChecksum = tuple.missingChecksums[index];

Contract.ThrowIfTrue(currentIndex != missingChecksumsCount);
tuple.callback?.Invoke(missingChecksum, missingAsset, tuple.arg!);
tuple.assetProvider._assetCache.GetOrAdd(missingChecksum, missingAsset!);
},
(this, missingChecksums, callback, arg),
cancellationToken).ConfigureAwait(false);
}

if (usePool)
s_checksumPool.Free(missingChecksums);
}

return;

void AddResult(Checksum checksum, T result)
{
if (results != null)
results[checksum] = result;
}
}

private async ValueTask RequestAssetsAsync<T>(
AssetPath assetPath, ReadOnlyMemory<Checksum> checksums, Action<int, T> callback, CancellationToken cancellationToken)
private async ValueTask RequestAssetsAsync<T, TArg>(
AssetPath assetPath, ReadOnlyMemory<Checksum> checksums, Action<int, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
#if NETCOREAPP
Contract.ThrowIfTrue(checksums.Span.Contains(Checksum.Null));
Expand All @@ -280,6 +275,6 @@ private async ValueTask RequestAssetsAsync<T>(
if (checksums.Length == 0)
return;

await _assetSource.GetAssetsAsync(_solutionChecksum, assetPath, checksums, _serializerService, callback, cancellationToken).ConfigureAwait(false);
await _assetSource.GetAssetsAsync(_solutionChecksum, assetPath, checksums, _serializerService, callback, arg, cancellationToken).ConfigureAwait(false);
}
}
6 changes: 4 additions & 2 deletions src/Workspaces/Remote/ServiceHub/Host/IAssetSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ namespace Microsoft.CodeAnalysis.Remote;
/// </summary>
internal interface IAssetSource
{
ValueTask GetAssetsAsync<T>(
/// <param name="callback">Will be called back once per checksum in <paramref name="checksums"/> in the exact order of that array.</param>
ValueTask GetAssetsAsync<T, TArg>(
Checksum solutionChecksum,
AssetPath assetPath,
ReadOnlyMemory<Checksum> checksums,
ISerializerService serializerService,
Action<int, T> callback,
Action<int, T, TArg> callback,
TArg arg,
CancellationToken cancellationToken);
}
Loading
Loading