Skip to content

Commit

Permalink
Improve AddImport conflict detection performance (#73780)
Browse files Browse the repository at this point in the history
* Improve add import conflict detection performance

We've recently obtained some traces from customers around ui delays during completion commit in C# files. The first profile I took a look at indicated that the delay was due to committing a snippet, and in particular during the code used to prevent adding conflicting imports.

This code has been improved in 3 ways, all relating to accessing the semantic model (where the *vast* majority of the time was spent):

1) Reduce the number of nodes for which semantic information need be obtained. The code already collects a couple dictionaries keyed on type names which could cause conflicts. This simply changes the code to filter syntax nodes based on this.

2) Change the kind of semantic model used. About half the time used during the binding process was due to the model dealing with work around nullability. We don't need that infomation, so we're good to use the much faster nullabledisabled semantic model.

3) Process all the binding of nodes to their semantic info in parallel. This does a simple collection of nodes of interest, then uses the ProducerConsumer mechanism to obtain semantic information in parallel.

As a minor tweak, I aslo did away with the TreeWalker derivation, instead just doing a simple walk over the results of GetDescendantNodes. Should be more performant, and I find it a bit easier to understand when the needs are this simple.
  • Loading branch information
ToddGrun authored May 30, 2024
1 parent 6f2f8aa commit b55a71a
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8508,6 +8508,8 @@ void M(nint i, nuint i2)
}
""",
"""
using System;
class Class
{
void M(nint i, nuint i2)
Expand All @@ -8517,7 +8519,7 @@ void M(nint i, nuint i2)
private (nint, nuint) NewMethod(nint i, nuint i2)
{
throw new System.NotImplementedException();
throw new NotImplementedException();
}
}
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1347,14 +1347,16 @@ End Class")
[|[Me]|]([string])
End Sub
End Module",
"Module Program
"Imports System

Module Program
Sub Main(args As String())
Dim [string] As String = ""hello""
[Me]([string])
End Sub

Private Sub [Me]([string] As String)
Throw New System.NotImplementedException()
Throw New NotImplementedException()
End Sub
End Module")
End Function
Expand Down
187 changes: 110 additions & 77 deletions src/Workspaces/CSharp/Portable/Editing/CSharpImportAdder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
using System.Collections.Immutable;
using System.Composition;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Host.Mef;
using Microsoft.CodeAnalysis.PooledObjects;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Shared.Utilities;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.CSharp.Editing;
Expand All @@ -39,20 +42,19 @@ public CSharpImportAdder()
return null;
}

protected override void AddPotentiallyConflictingImports(
protected override Task AddPotentiallyConflictingImportsAsync(
SemanticModel model,
SyntaxNode container,
ImmutableArray<INamespaceSymbol> namespaceSymbols,
HashSet<INamespaceSymbol> conflicts,
CancellationToken cancellationToken)
{
var rewriter = new ConflictWalker(model, namespaceSymbols, conflicts, cancellationToken);
rewriter.Visit(container);
var conflictFinder = new ConflictFinder(model, namespaceSymbols);
return conflictFinder.AddPotentiallyConflictingImportsAsync(container, conflicts, cancellationToken);
}

private static INamespaceSymbol? GetExplicitNamespaceSymbol(ExpressionSyntax fullName, ExpressionSyntax namespacePart, SemanticModel model)
{

// name must refer to something that is not a namespace, but be qualified with a namespace.
var symbol = model.GetSymbolInfo(fullName).Symbol;
if (symbol != null && symbol.Kind != SymbolKind.Namespace && model.GetSymbolInfo(namespacePart).Symbol is INamespaceSymbol)
Expand All @@ -76,10 +78,9 @@ protected override void AddPotentiallyConflictingImports(
/// no users being hit, then that's far less important than if we have a reasonable coding pattern that would be
/// impacted by adding an import to a normal namespace.
/// </summary>
private class ConflictWalker : CSharpSyntaxWalker
private class ConflictFinder
{
private readonly SemanticModel _model;
private readonly CancellationToken _cancellationToken;

/// <summary>
/// A mapping containing the simple names and arity of all imported types, mapped to the import that they're
Expand All @@ -99,25 +100,11 @@ private class ConflictWalker : CSharpSyntaxWalker
/// </remarks>
private readonly MultiDictionary<string, INamespaceSymbol> _importedExtensionMethods = [];

private readonly HashSet<INamespaceSymbol> _conflictNamespaces;

/// <summary>
/// Track if we're in an anonymous method or not. If so, because of how the language binds lambdas and
/// overloads, we'll assume any method access we see inside (instance or otherwise) could end up conflicting
/// with an extension method we might pull in.
/// </summary>
private bool _inAnonymousMethod;

public ConflictWalker(
public ConflictFinder(
SemanticModel model,
ImmutableArray<INamespaceSymbol> namespaceSymbols,
HashSet<INamespaceSymbol> conflictNamespaces,
CancellationToken cancellationToken)
: base(SyntaxWalkerDepth.StructuredTrivia)
ImmutableArray<INamespaceSymbol> namespaceSymbols)
{
_model = model;
_cancellationToken = cancellationToken;
_conflictNamespaces = conflictNamespaces;

AddImportedMembers(namespaceSymbols);
}
Expand All @@ -142,91 +129,137 @@ private void AddImportedMembers(ImmutableArray<INamespaceSymbol> namespaceSymbol
}
}

public override void VisitSimpleLambdaExpression(SimpleLambdaExpressionSyntax node)
public async Task AddPotentiallyConflictingImportsAsync(SyntaxNode container, HashSet<INamespaceSymbol> conflicts, CancellationToken cancellationToken)
{
// lambdas are interesting. Say you have:
//
// Goo(x => x.M());
//
// void Goo(Action<C> act) { }
// void Goo(Action<int> act) { }
//
// class C { public void M() { } }
//
// This is legal code where the lambda body is calling the instance method. However, if we introduce a
// using that brings in an extension method 'M' on 'int', then the above will become ambiguous. This is
// because lambda binding will try each interpretation separately and eliminate the ones that fail.
// Adding the import will make the int form succeed, causing ambiguity.
//
// To deal with that, we keep track of if we're in a lambda, and we conservatively assume that a method
// access (even to a non-extension method) could conflict with an extension method brought in.

var previousInAnonymousMethod = _inAnonymousMethod;
_inAnonymousMethod = true;
base.VisitSimpleLambdaExpression(node);
_inAnonymousMethod = previousInAnonymousMethod;
}
using var _ = ArrayBuilder<SyntaxNode>.GetInstance(out var nodes);

public override void VisitParenthesizedLambdaExpression(ParenthesizedLambdaExpressionSyntax node)
{
var previousInAnonymousMethod = _inAnonymousMethod;
_inAnonymousMethod = true;
base.VisitParenthesizedLambdaExpression(node);
_inAnonymousMethod = previousInAnonymousMethod;
CollectInfoFromContainer(container, nodes, out var containsAnonymousMethods);

await ProducerConsumer<INamespaceSymbol>.RunParallelAsync(
source: nodes,
produceItems: static (node, onItemsFound, args, cancellationToken) =>
{
var (self, containsAnonymousMethods, _) = args;
if (node is SimpleNameSyntax nameSyntaxNode)
self.ProduceConflicts(nameSyntaxNode, onItemsFound, cancellationToken);
else if (node is MemberAccessExpressionSyntax memberAccessExpressionNode)
self.ProduceConflicts(memberAccessExpressionNode, containsAnonymousMethods, onItemsFound, cancellationToken);
else
throw ExceptionUtilities.Unreachable();

return Task.CompletedTask;
},
consumeItems: static async (items, args, cancellationToken) =>
{
var (_, _, conflicts) = args;
await foreach (var conflict in items)
conflicts.Add(conflict);
},
args: (self: this, containsAnonymousMethods, conflicts),
cancellationToken).ConfigureAwait(false);
}

public override void VisitAnonymousMethodExpression(AnonymousMethodExpressionSyntax node)
private void CollectInfoFromContainer(SyntaxNode container, ArrayBuilder<SyntaxNode> nodes, out bool containsAnonymousMethods)
{
var previousInAnonymousMethod = _inAnonymousMethod;
_inAnonymousMethod = true;
base.VisitAnonymousMethodExpression(node);
_inAnonymousMethod = previousInAnonymousMethod;
containsAnonymousMethods = false;

foreach (var node in container.DescendantNodesAndSelf())
{
switch (node.Kind())
{
case SyntaxKind.IdentifierName:
case SyntaxKind.GenericName:
if (IsPotentialConflictWithImportedType((SimpleNameSyntax)node))
nodes.Add(node);
break;
case SyntaxKind.SimpleMemberAccessExpression:
case SyntaxKind.PointerMemberAccessExpression:
if (IsPotentialConflictWithImportedExtensionMethod((MemberAccessExpressionSyntax)node))
nodes.Add(node);
break;
case SyntaxKind.SimpleLambdaExpression:
case SyntaxKind.ParenthesizedLambdaExpression:
case SyntaxKind.AnonymousMethodExpression:
// Track if we've seen an anonymous method or not. If so, because of how the language binds lambdas and
// overloads, we'll assume any method access we see inside (instance or otherwise) could end up conflicting
// with an extension method we might pull in.
containsAnonymousMethods = true;
break;
}
}
}

private void CheckName(NameSyntax node)
private bool IsPotentialConflictWithImportedType(SimpleNameSyntax node)
{
// Check to see if we have an standalone identifier (or identifier on the left of a dot). If so, if that
// identifier binds to a type, then we don't want to bring in any imports that would bring in the same
// name and could then potentially conflict here.

if (node.IsRightSideOfDotOrArrowOrColonColon())
return;
return false;

// Check to see if we have a var. If so, then nothing assigned to a var
// would bring any imports that could cause a potential conflict.
if (node.IsVar)
return;
return false;

var symbol = _model.GetSymbolInfo(node, _cancellationToken).GetAnySymbol();
if (symbol?.Kind == SymbolKind.NamedType)
_conflictNamespaces.AddRange(_importedTypes[(symbol.Name, node.Arity)]);
}
// Drastically reduce the number of nodes that need to be inspected by filtering
// out nodes whose identifier isn't a potential conflict.
if (!_importedTypes.ContainsKey((node.Identifier.Text, node.Arity)))
return false;

public override void VisitIdentifierName(IdentifierNameSyntax node)
{
base.VisitIdentifierName(node);
CheckName(node);
return true;
}

public override void VisitGenericName(GenericNameSyntax node)
private bool IsPotentialConflictWithImportedExtensionMethod(MemberAccessExpressionSyntax node)
=> _importedExtensionMethods.ContainsKey(node.Name.Identifier.Text);

private void ProduceConflicts(SimpleNameSyntax node, Action<INamespaceSymbol> addConflict, CancellationToken cancellationToken)
{
base.VisitGenericName(node);
CheckName(node);
var symbol = _model.GetSymbolInfo(node, cancellationToken).GetAnySymbol();
if (symbol?.Kind == SymbolKind.NamedType)
{
foreach (var conflictingSymbol in _importedTypes[(symbol.Name, node.Arity)])
addConflict(conflictingSymbol);
}
}

public override void VisitMemberAccessExpression(MemberAccessExpressionSyntax node)
private void ProduceConflicts(MemberAccessExpressionSyntax node, bool containsAnonymousMethods, Action<INamespaceSymbol> addConflict, CancellationToken cancellationToken)
{
base.VisitMemberAccessExpression(node);

// Check to see if we have a reference to an extension method. If so, then pulling in an import could
// bring in an extension that conflicts with that.

var symbol = _model.GetSymbolInfo(node.Name, _cancellationToken).GetAnySymbol();
var symbol = _model.GetSymbolInfo(node.Name, cancellationToken).GetAnySymbol();
if (symbol is IMethodSymbol method)
{
// see explanation in VisitSimpleLambdaExpression for the _inAnonymousMethod check
if (method.IsReducedExtension() || _inAnonymousMethod)
_conflictNamespaces.AddRange(_importedExtensionMethods[method.Name]);
var isConflicting = method.IsReducedExtension();

if (!isConflicting && containsAnonymousMethods)
{
// lambdas are interesting. Say you have:
//
// Goo(x => x.M());
//
// void Goo(Action<C> act) { }
// void Goo(Action<int> act) { }
//
// class C { public void M() { } }
//
// This is legal code where the lambda body is calling the instance method. However, if we introduce a
// using that brings in an extension method 'M' on 'int', then the above will become ambiguous. This is
// because lambda binding will try each interpretation separately and eliminate the ones that fail.
// Adding the import will make the int form succeed, causing ambiguity.
//
// To deal with that, we keep track of if we're in a lambda, and we conservatively assume that a method
// access (even to a non-extension method) could conflict with an extension method brought in.
isConflicting = node.HasAncestor<AnonymousFunctionExpressionSyntax>();
}

if (isConflicting)
{
foreach (var conflictingSymbol in _importedExtensionMethods[method.Name])
addConflict(conflictingSymbol);
}
}
}
}
Expand Down
24 changes: 14 additions & 10 deletions src/Workspaces/Core/Portable/Editing/ImportAdderService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ public async Task<Document> AddImportsAsync(
// Create a simple interval tree for simplification spans.
var spansTree = new TextSpanIntervalTree(spans);

Func<SyntaxNode, bool> overlapsWithSpan = n => spansTree.HasIntervalThatOverlapsWith(n.FullSpan.Start, n.FullSpan.Length);

// Only dive deeper into nodes that actually overlap with the span we care about. And also only include
// those child nodes that themselves overlap with the span. i.e. if we have:
//
Expand All @@ -55,7 +53,7 @@ public async Task<Document> AddImportsAsync(
//
// We'll dive under the parent because it overlaps with the span. But we only want to include (and dive
// into) B and C not A and D.
var nodes = root.DescendantNodesAndSelf(overlapsWithSpan).Where(overlapsWithSpan);
var nodes = root.DescendantNodesAndSelf(OverlapsWithSpan).Where(OverlapsWithSpan);

if (strategy == Strategy.AddImportsFromSymbolAnnotations)
return await AddImportDirectivesFromSymbolAnnotationsAsync(document, nodes, addImportsService, generator, options, cancellationToken).ConfigureAwait(false);
Expand All @@ -64,19 +62,21 @@ public async Task<Document> AddImportsAsync(
return await AddImportDirectivesFromSyntaxesAsync(document, nodes, addImportsService, generator, options, cancellationToken).ConfigureAwait(false);

throw ExceptionUtilities.UnexpectedValue(strategy);

bool OverlapsWithSpan(SyntaxNode n) => spansTree.HasIntervalThatOverlapsWith(n.FullSpan.Start, n.FullSpan.Length);
}

protected abstract INamespaceSymbol? GetExplicitNamespaceSymbol(SyntaxNode node, SemanticModel model);

private ISet<INamespaceSymbol> GetSafeToAddImports(
private async Task<ISet<INamespaceSymbol>> GetSafeToAddImportsAsync(
ImmutableArray<INamespaceSymbol> namespaceSymbols,
SyntaxNode container,
SemanticModel model,
CancellationToken cancellationToken)
{
using var _ = PooledHashSet<INamespaceSymbol>.GetInstance(out var conflicts);
AddPotentiallyConflictingImports(
model, container, namespaceSymbols, conflicts, cancellationToken);
await AddPotentiallyConflictingImportsAsync(
model, container, namespaceSymbols, conflicts, cancellationToken).ConfigureAwait(false);
return namespaceSymbols.Except(conflicts).ToSet();
}

Expand All @@ -86,7 +86,7 @@ private ISet<INamespaceSymbol> GetSafeToAddImports(
/// <paramref name="container"/> is the node that the import will be added to. This will either be the
/// compilation-unit node, or one of the namespace-blocks in the file.
/// </summary>
protected abstract void AddPotentiallyConflictingImports(
protected abstract Task AddPotentiallyConflictingImportsAsync(
SemanticModel model,
SyntaxNode container,
ImmutableArray<INamespaceSymbol> namespaceSymbols,
Expand Down Expand Up @@ -174,7 +174,7 @@ private async Task<Document> AddImportDirectivesFromSymbolAnnotationsAsync(
using var _ = PooledDictionary<INamespaceSymbol, SyntaxNode>.GetInstance(out var importToSyntax);

var root = await document.GetRequiredSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var model = await document.GetRequiredSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var model = await document.GetRequiredNullableDisabledSemanticModelAsync(cancellationToken).ConfigureAwait(false);

SyntaxNode? first = null, last = null;
var annotatedNodes = syntaxNodes.Where(x => x.HasAnnotations(SymbolAnnotation.Kind));
Expand All @@ -183,7 +183,7 @@ private async Task<Document> AddImportDirectivesFromSymbolAnnotationsAsync(
{
cancellationToken.ThrowIfCancellationRequested();

if (annotatedNode.GetAnnotations(DoNotAddImportsAnnotation.Kind).Any())
if (annotatedNode.HasAnnotations(DoNotAddImportsAnnotation.Kind))
continue;

var annotations = annotatedNode.GetAnnotations(SymbolAnnotation.Kind);
Expand Down Expand Up @@ -228,7 +228,11 @@ private async Task<Document> AddImportDirectivesFromSymbolAnnotationsAsync(
var importContainer = addImportsService.GetImportContainer(root, context, importToSyntax.First().Value, options);

// Now remove any imports we think can cause conflicts in that container.
var safeImportsToAdd = GetSafeToAddImports([.. importToSyntax.Keys], importContainer, model, cancellationToken);
var safeImportsToAdd = await GetSafeToAddImportsAsync(
[.. importToSyntax.Keys],
importContainer,
model,
cancellationToken).ConfigureAwait(false);

var importsToAdd = importToSyntax.Where(kvp => safeImportsToAdd.Contains(kvp.Key)).Select(kvp => kvp.Value).ToImmutableArray();
if (importsToAdd.Length == 0)
Expand Down
Loading

0 comments on commit b55a71a

Please sign in to comment.