Skip to content

Commit

Permalink
Merge pull request #72944 from CyrusNajmabadi/avoidArray2
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrusNajmabadi authored Apr 9, 2024
2 parents 8438e56 + b54747d commit b67df4c
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,13 @@ private sealed class AssetProvider(SerializationValidator validator) : AbstractA
public override async ValueTask<T> GetAssetAsync<T>(AssetPath assetPath, Checksum checksum, CancellationToken cancellationToken)
=> await validator.GetValueAsync<T>(checksum).ConfigureAwait(false);

public override async ValueTask<ImmutableArray<(Checksum checksum, T asset)>> GetAssetsAsync<T>(AssetPath assetPath, HashSet<Checksum> checksums, CancellationToken cancellationToken)
public override async ValueTask GetAssetsAsync<T, TArg>(AssetPath assetPath, HashSet<Checksum> checksums, Action<Checksum, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
using var _ = ArrayBuilder<(Checksum checksum, T asset)>.GetInstance(out var result);

foreach (var checksum in checksums)
{
var value = await GetAssetAsync<T>(assetPath, checksum, cancellationToken).ConfigureAwait(false);
result.Add((checksum, value));
callback(checksum, value, arg);
}

return result.ToImmutable();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ private static async Task TestAssetAsync(object data)
var stored = await provider.GetAssetAsync<object>(AssetPath.FullLookupForTesting, checksum, CancellationToken.None);
Assert.Equal(data, stored);

var stored2 = await provider.GetAssetsAsync<object>(AssetPath.FullLookupForTesting, new HashSet<Checksum> { checksum }, CancellationToken.None);
Assert.Equal(1, stored2.Length);
var stored2 = new List<(Checksum, object)>();
await provider.GetAssetsAsync<object, VoidResult>(AssetPath.FullLookupForTesting, new HashSet<Checksum> { checksum }, (checksum, asset, _) => stored2.Add((checksum, asset)), default, CancellationToken.None);
Assert.Equal(1, stored2.Count);

Assert.Equal(checksum, stored2[0].Item1);
Assert.Equal(data, stored2[0].Item2);
Expand All @@ -83,7 +84,7 @@ public async Task TestAssetSynchronization()
var assetSource = new SimpleAssetSource(workspace.Services.GetService<ISerializerService>(), map);

var service = new AssetProvider(sessionId, storage, assetSource, remoteWorkspace.Services.GetService<ISerializerService>());
await service.SynchronizeAssetsAsync<object>(AssetPath.FullLookupForTesting, new HashSet<Checksum>(map.Keys), results: null, CancellationToken.None);
await service.SynchronizeAssetsAsync<object, VoidResult>(AssetPath.FullLookupForTesting, new HashSet<Checksum>(map.Keys), callback: null, arg: default, CancellationToken.None);

foreach (var kv in map)
{
Expand Down
6 changes: 3 additions & 3 deletions src/Workspaces/CoreTestUtilities/Fakes/SimpleAssetSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ namespace Microsoft.CodeAnalysis.Remote.Testing;
/// </summary>
internal sealed class SimpleAssetSource(ISerializerService serializerService, IReadOnlyDictionary<Checksum, object> map) : IAssetSource
{
public ValueTask GetAssetsAsync<T>(
Checksum solutionChecksum, AssetPath assetPath, ReadOnlyMemory<Checksum> checksums, ISerializerService deserializerService, Action<int, T> callback, CancellationToken cancellationToken)
public ValueTask GetAssetsAsync<T, TArg>(
Checksum solutionChecksum, AssetPath assetPath, ReadOnlyMemory<Checksum> checksums, ISerializerService deserializerService, Action<int, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
var index = 0;
foreach (var checksum in checksums.Span)
Expand All @@ -38,7 +38,7 @@ public ValueTask GetAssetsAsync<T>(
using var reader = ObjectReader.GetReader(stream, leaveOpen: true, cancellationToken);
var asset = deserializerService.Deserialize(data.GetWellKnownSynchronizationKind(), reader, cancellationToken);
Contract.ThrowIfNull(asset);
callback(index, (T)asset);
callback(index, (T)asset, arg);
index++;
}

Expand Down
16 changes: 12 additions & 4 deletions src/Workspaces/Remote/Core/AbstractAssetProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Threading;
Expand All @@ -22,7 +23,7 @@ internal abstract class AbstractAssetProvider
/// return data of type T whose checksum is the given checksum
/// </summary>
public abstract ValueTask<T> GetAssetAsync<T>(AssetPath assetPath, Checksum checksum, CancellationToken cancellationToken);
public abstract ValueTask<ImmutableArray<(Checksum checksum, T asset)>> GetAssetsAsync<T>(AssetPath assetPath, HashSet<Checksum> checksums, CancellationToken cancellationToken);
public abstract ValueTask GetAssetsAsync<T, TArg>(AssetPath assetPath, HashSet<Checksum> checksums, Action<Checksum, T, TArg> callback, TArg arg, CancellationToken cancellationToken);

public async Task<SolutionInfo> CreateSolutionInfoAsync(Checksum solutionChecksum, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -107,10 +108,17 @@ public async Task<DocumentInfo> CreateDocumentInfoAsync(
public async Task<ImmutableArray<T>> GetAssetsAsync<T>(
AssetPath assetPath, ChecksumCollection checksums, CancellationToken cancellationToken) where T : class
{
using var _ = PooledHashSet<Checksum>.GetInstance(out var checksumSet);
using var _1 = PooledHashSet<Checksum>.GetInstance(out var checksumSet);
checksumSet.AddAll(checksums.Children);

var results = await this.GetAssetsAsync<T>(assetPath, checksumSet, cancellationToken).ConfigureAwait(false);
return results.SelectAsArray(static t => t.asset);
using var _2 = ArrayBuilder<T>.GetInstance(checksumSet.Count, out var builder);

await this.GetAssetsAsync<T, ArrayBuilder<T>>(
assetPath, checksumSet,
static (checksum, asset, builder) => builder.Add(asset),
builder,
cancellationToken).ConfigureAwait(false);

return builder.ToImmutableAndClear();
}
}
19 changes: 8 additions & 11 deletions src/Workspaces/Remote/Core/RemoteHostAssetSerialization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Serialization;
using Nerdbank.Streams;
using Roslyn.Utilities;
Expand Down Expand Up @@ -67,8 +64,8 @@ static void WriteAsset(ObjectWriter writer, ISerializerService serializer, Solut
}
}

public static ValueTask ReadDataAsync<T>(
PipeReader pipeReader, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T> callback, CancellationToken cancellationToken)
public static ValueTask ReadDataAsync<T, TArg>(
PipeReader pipeReader, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
// Suppress ExecutionContext flow for asynchronous operations operate on the pipe. In addition to avoiding
// ExecutionContext allocations, this clears the LogicalCallContext and avoids the need to clone data set by
Expand All @@ -77,18 +74,18 @@ public static ValueTask ReadDataAsync<T>(
// ⚠ DO NOT AWAIT INSIDE THE USING. The Dispose method that restores ExecutionContext flow must run on the
// same thread where SuppressFlow was originally run.
using var _ = FlowControlHelper.TrySuppressFlow();
return ReadDataSuppressedFlowAsync(pipeReader, solutionChecksum, objectCount, serializerService, callback, cancellationToken);
return ReadDataSuppressedFlowAsync(pipeReader, solutionChecksum, objectCount, serializerService, callback, arg, cancellationToken);

static async ValueTask ReadDataSuppressedFlowAsync(
PipeReader pipeReader, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T> callback, CancellationToken cancellationToken)
PipeReader pipeReader, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
using var stream = await pipeReader.AsPrebufferedStreamAsync(cancellationToken).ConfigureAwait(false);
ReadData<T>(stream, solutionChecksum, objectCount, serializerService, callback, cancellationToken);
ReadData(stream, solutionChecksum, objectCount, serializerService, callback, arg, cancellationToken);
}
}

public static void ReadData<T>(
Stream stream, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T> callback, CancellationToken cancellationToken)
public static void ReadData<T, TArg>(
Stream stream, Checksum solutionChecksum, int objectCount, ISerializerService serializerService, Action<int, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
using var reader = ObjectReader.GetReader(stream, leaveOpen: true, cancellationToken);

Expand All @@ -104,7 +101,7 @@ public static void ReadData<T>(
// in service hub, cancellation means simply closed stream
var result = serializerService.Deserialize(kind, reader, cancellationToken);
Contract.ThrowIfNull(result);
callback(i, (T)result);
callback(i, (T)result, arg);
}
}
}
Expand Down
96 changes: 44 additions & 52 deletions src/Workspaces/Remote/ServiceHub/Host/AssetProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,21 @@ public override async ValueTask<T> GetAssetAsync<T>(
using var _1 = PooledHashSet<Checksum>.GetInstance(out var checksums);
checksums.Add(checksum);

using var _2 = PooledDictionary<Checksum, T>.GetInstance(out var results);
await this.SynchronizeAssetsAsync(assetPath, checksums, results, cancellationToken).ConfigureAwait(false);
using var _2 = ArrayBuilder<T>.GetInstance(1, out var builder);
await this.SynchronizeAssetsAsync<T, ArrayBuilder<T>>(
assetPath, checksums,
static (_, asset, builder) => builder.Add(asset),
builder, cancellationToken).ConfigureAwait(false);

return results[checksum];
Contract.ThrowIfTrue(builder.Count != 1);

return builder[0];
}

public override async ValueTask<ImmutableArray<(Checksum checksum, T asset)>> GetAssetsAsync<T>(
AssetPath assetPath, HashSet<Checksum> checksums, CancellationToken cancellationToken)
public override async ValueTask GetAssetsAsync<T, TArg>(
AssetPath assetPath, HashSet<Checksum> checksums, Action<Checksum, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
using var _ = PooledDictionary<Checksum, T>.GetInstance(out var results);

await this.SynchronizeAssetsAsync(assetPath, checksums, results, cancellationToken).ConfigureAwait(false);

var result = new (Checksum checksum, T asset)[checksums.Count];
var index = 0;
foreach (var (checksum, assetObject) in results)
{
result[index] = (checksum, assetObject);
index++;
}

return ImmutableCollectionsMarshal.AsImmutableArray(result);
await this.SynchronizeAssetsAsync(assetPath, checksums, callback, arg, cancellationToken).ConfigureAwait(false);
}

public async ValueTask SynchronizeSolutionAssetsAsync(Checksum solutionChecksum, CancellationToken cancellationToken)
Expand Down Expand Up @@ -88,17 +81,15 @@ public async ValueTask SynchronizeSolutionAssetsAsync(Checksum solutionChecksum,

async ValueTask SynchronizeSolutionAssetsWorkerAsync()
{
using var _1 = PooledDictionary<Checksum, object>.GetInstance(out var checksumToObjects);

// first, get top level solution state for the given solution checksum
var compilationStateChecksums = await this.GetAssetAsync<SolutionCompilationStateChecksums>(
assetPath: AssetPath.SolutionOnly, solutionChecksum, cancellationToken).ConfigureAwait(false);

using var _2 = PooledHashSet<Checksum>.GetInstance(out var checksums);
using var _1 = PooledHashSet<Checksum>.GetInstance(out var checksums);

// second, get direct children of the solution compilation state.
compilationStateChecksums.AddAllTo(checksums);
await this.SynchronizeAssetsAsync<object>(assetPath: AssetPath.SolutionOnly, checksums, results: null, cancellationToken).ConfigureAwait(false);
await this.SynchronizeAssetsAsync<object, VoidResult>(assetPath: AssetPath.SolutionOnly, checksums, callback: null, arg: default, cancellationToken).ConfigureAwait(false);

// third, get direct children of the solution state.
var stateChecksums = await this.GetAssetAsync<SolutionStateChecksums>(
Expand All @@ -108,7 +99,14 @@ async ValueTask SynchronizeSolutionAssetsWorkerAsync()
// the project states and we want to get that all in one batch.
checksums.Clear();
stateChecksums.AddAllTo(checksums);
await this.SynchronizeAssetsAsync(assetPath: AssetPath.SolutionAndTopLevelProjectsOnly, checksums, checksumToObjects, cancellationToken).ConfigureAwait(false);

using var _2 = PooledDictionary<Checksum, object>.GetInstance(out var checksumToObjects);

await this.SynchronizeAssetsAsync<object, Dictionary<Checksum, object>>(
assetPath: AssetPath.SolutionAndTopLevelProjectsOnly,
checksums,
static (checksum, asset, checksumToObjects) => checksumToObjects.Add(checksum, asset),
arg: checksumToObjects, cancellationToken).ConfigureAwait(false);

// fourth, get all projects and documents in the solution
foreach (var (projectChecksum, _) in stateChecksums.Projects)
Expand Down Expand Up @@ -151,8 +149,8 @@ async ValueTask SynchronizeProjectAssetsWorkerAsync()
AddAll(checksums, projectChecksums.AnalyzerConfigDocuments.Checksums);

// First synchronize all the top-level info about this project.
await this.SynchronizeAssetsAsync<object>(
assetPath: AssetPath.ProjectAndDocuments(projectChecksums.ProjectId), checksums, results: null, cancellationToken).ConfigureAwait(false);
await this.SynchronizeAssetsAsync<object, VoidResult>(
assetPath: AssetPath.ProjectAndDocuments(projectChecksums.ProjectId), checksums, callback: null, arg: default, cancellationToken).ConfigureAwait(false);

checksums.Clear();

Expand All @@ -161,8 +159,8 @@ await this.SynchronizeAssetsAsync<object>(
await CollectChecksumChildrenAsync(checksums, projectChecksums.AdditionalDocuments).ConfigureAwait(false);
await CollectChecksumChildrenAsync(checksums, projectChecksums.AnalyzerConfigDocuments).ConfigureAwait(false);

await this.SynchronizeAssetsAsync<object>(
assetPath: AssetPath.ProjectAndDocuments(projectChecksums.ProjectId), checksums, results: null, cancellationToken).ConfigureAwait(false);
await this.SynchronizeAssetsAsync<object, VoidResult>(
assetPath: AssetPath.ProjectAndDocuments(projectChecksums.ProjectId), checksums, callback: null, arg: default, cancellationToken).ConfigureAwait(false);
}

async ValueTask CollectChecksumChildrenAsync(HashSet<Checksum> checksums, ChecksumsAndIds<DocumentId> collection)
Expand All @@ -185,8 +183,8 @@ static void AddAll(HashSet<Checksum> checksums, ChecksumCollection checksumColle
}
}

public async ValueTask SynchronizeAssetsAsync<T>(
AssetPath assetPath, HashSet<Checksum> checksums, Dictionary<Checksum, T>? results, CancellationToken cancellationToken)
public async ValueTask SynchronizeAssetsAsync<T, TArg>(
AssetPath assetPath, HashSet<Checksum> checksums, Action<Checksum, T, TArg>? callback, TArg? arg, CancellationToken cancellationToken)
{
Contract.ThrowIfTrue(checksums.Contains(Checksum.Null));
if (checksums.Count == 0)
Expand All @@ -211,7 +209,7 @@ public async ValueTask SynchronizeAssetsAsync<T>(
{
if (_assetCache.TryGetAsset<T>(checksum, out var existing))
{
AddResult(checksum, existing);
callback?.Invoke(checksum, existing, arg!);
}
else
{
Expand Down Expand Up @@ -239,37 +237,31 @@ public async ValueTask SynchronizeAssetsAsync<T>(
{
var missingChecksumsMemory = new ReadOnlyMemory<Checksum>(missingChecksums, 0, missingChecksumsCount);

var currentIndex = 0;
await RequestAssetsAsync(assetPath, missingChecksumsMemory, (int index, T missingAsset) =>
{
Contract.ThrowIfTrue(currentIndex != index);
var missingChecksum = missingChecksums[index];
AddResult(missingChecksum, missingAsset);
_assetCache.GetOrAdd(missingChecksum, missingAsset!);
currentIndex++;
}, cancellationToken).ConfigureAwait(false);
await RequestAssetsAsync(
assetPath, missingChecksumsMemory,
static (
int index,
T missingAsset,
(AssetProvider assetProvider, Checksum[] missingChecksums, Action<Checksum, T, TArg>? callback, TArg? arg) tuple) =>
{
var missingChecksum = tuple.missingChecksums[index];
Contract.ThrowIfTrue(currentIndex != missingChecksumsCount);
tuple.callback?.Invoke(missingChecksum, missingAsset, tuple.arg!);
tuple.assetProvider._assetCache.GetOrAdd(missingChecksum, missingAsset!);
},
(this, missingChecksums, callback, arg),
cancellationToken).ConfigureAwait(false);
}

if (usePool)
s_checksumPool.Free(missingChecksums);
}

return;

void AddResult(Checksum checksum, T result)
{
if (results != null)
results[checksum] = result;
}
}

private async ValueTask RequestAssetsAsync<T>(
AssetPath assetPath, ReadOnlyMemory<Checksum> checksums, Action<int, T> callback, CancellationToken cancellationToken)
private async ValueTask RequestAssetsAsync<T, TArg>(
AssetPath assetPath, ReadOnlyMemory<Checksum> checksums, Action<int, T, TArg> callback, TArg arg, CancellationToken cancellationToken)
{
#if NETCOREAPP
Contract.ThrowIfTrue(checksums.Span.Contains(Checksum.Null));
Expand All @@ -280,6 +272,6 @@ private async ValueTask RequestAssetsAsync<T>(
if (checksums.Length == 0)
return;

await _assetSource.GetAssetsAsync(_solutionChecksum, assetPath, checksums, _serializerService, callback, cancellationToken).ConfigureAwait(false);
await _assetSource.GetAssetsAsync(_solutionChecksum, assetPath, checksums, _serializerService, callback, arg, cancellationToken).ConfigureAwait(false);
}
}
6 changes: 4 additions & 2 deletions src/Workspaces/Remote/ServiceHub/Host/IAssetSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ namespace Microsoft.CodeAnalysis.Remote;
/// </summary>
internal interface IAssetSource
{
ValueTask GetAssetsAsync<T>(
/// <param name="callback">Will be called back once per checksum in <paramref name="checksums"/> in the exact order of that array.</param>
ValueTask GetAssetsAsync<T, TArg>(
Checksum solutionChecksum,
AssetPath assetPath,
ReadOnlyMemory<Checksum> checksums,
ISerializerService serializerService,
Action<int, T> callback,
Action<int, T, TArg> callback,
TArg arg,
CancellationToken cancellationToken);
}
Loading

0 comments on commit b67df4c

Please sign in to comment.