Skip to content

Commit

Permalink
Merge pull request #1078 from akshita31/order_codeProviders
Browse files Browse the repository at this point in the history
Order code actions
  • Loading branch information
filipw authored Jan 10, 2018
2 parents d5e8d70 + 7ba0db9 commit e00f370
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@ public abstract class BaseCodeActionService<TRequest, TResponse> : IRequestHandl

private static readonly Func<TextSpan, List<Diagnostic>> s_createDiagnosticList = _ => new List<Diagnostic>();

protected Lazy<List<CodeFixProvider>> OrderedCodeFixProviders;
protected Lazy<List<CodeRefactoringProvider>> OrderedCodeRefactoringProviders;

protected BaseCodeActionService(OmniSharpWorkspace workspace, CodeActionHelper helper, IEnumerable<ICodeActionProvider> providers, ILogger logger)
{
this.Workspace = workspace;
this.Providers = providers;
this.Logger = logger;
this._helper = helper;

OrderedCodeFixProviders = new Lazy<List<CodeFixProvider>>(() => GetSortedCodeFixProviders());
OrderedCodeRefactoringProviders = new Lazy<List<CodeRefactoringProvider>>(() => GetSortedCodeRefactoringProviders());

// Sadly, the CodeAction.NestedCodeActions property is still internal.
var nestedCodeActionsProperty = typeof(CodeAction).GetProperty("NestedCodeActions", BindingFlags.NonPublic | BindingFlags.Instance);
if (nestedCodeActionsProperty == null)
Expand Down Expand Up @@ -69,9 +75,6 @@ protected async Task<IEnumerable<AvailableCodeAction>> GetAvailableCodeActions(I
await CollectCodeFixesActions(document, span, codeActions);
await CollectRefactoringActions(document, span, codeActions);

// TODO: Determine good way to order code actions.
codeActions.Reverse();

// Be sure to filter out any code actions that inherit from CodeActionWithOptions.
// This isn't a great solution and might need changing later, but every Roslyn code action
// derived from this type tries to display a dialog. For now, this is a reasonable solution.
Expand Down Expand Up @@ -128,28 +131,67 @@ private async Task CollectCodeFixesActions(Document document, TextSpan span, Lis

private async Task AppendFixesAsync(Document document, TextSpan span, IEnumerable<Diagnostic> diagnostics, List<CodeAction> codeActions)
{
foreach (var provider in this.Providers)
foreach (var codeFixProvider in OrderedCodeFixProviders.Value)
{
foreach (var codeFixProvider in provider.CodeFixProviders)
var fixableDiagnostics = diagnostics.Where(d => HasFix(codeFixProvider, d.Id)).ToImmutableArray();
if (fixableDiagnostics.Length > 0)
{
var fixableDiagnostics = diagnostics.Where(d => HasFix(codeFixProvider, d.Id)).ToImmutableArray();
if (fixableDiagnostics.Length > 0)
var context = new CodeFixContext(document, span, fixableDiagnostics, (a, _) => codeActions.Add(a), CancellationToken.None);

try
{
var context = new CodeFixContext(document, span, fixableDiagnostics, (a, _) => codeActions.Add(a), CancellationToken.None);

try
{
await codeFixProvider.RegisterCodeFixesAsync(context);
}
catch (Exception ex)
{
this.Logger.LogError(ex, $"Error registering code fixes for {codeFixProvider.GetType().FullName}");
}
await codeFixProvider.RegisterCodeFixesAsync(context);
}
catch (Exception ex)
{
this.Logger.LogError(ex, $"Error registering code fixes for {codeFixProvider.GetType().FullName}");
}
}
}
}

private List<CodeFixProvider> GetSortedCodeFixProviders()
{
List<ProviderNode<CodeFixProvider>> nodesList = new List<ProviderNode<CodeFixProvider>>();
List<CodeFixProvider> providerList = new List<CodeFixProvider>();
foreach (var provider in this.Providers)
{
foreach (var codeFixProvider in provider.CodeFixProviders)
{
providerList.Add(codeFixProvider);
nodesList.Add(ProviderNode<CodeFixProvider>.From(codeFixProvider));
}
}

var graph = Graph<CodeFixProvider>.GetGraph(nodesList);
if (graph.HasCycles())
{
return providerList;
}
return graph.TopologicalSort();
}

private List<CodeRefactoringProvider> GetSortedCodeRefactoringProviders()
{
List<ProviderNode<CodeRefactoringProvider>> nodesList = new List<ProviderNode<CodeRefactoringProvider>>();
List<CodeRefactoringProvider> providerList = new List<CodeRefactoringProvider>();
foreach (var provider in this.Providers)
{
foreach (var codeRefactoringProvider in provider.CodeRefactoringProviders)
{
providerList.Add(codeRefactoringProvider);
nodesList.Add(ProviderNode<CodeRefactoringProvider>.From(codeRefactoringProvider));
}
}

var graph = Graph<CodeRefactoringProvider>.GetGraph(nodesList);
if (graph.HasCycles())
{
return providerList;
}
return graph.TopologicalSort();
}

private bool HasFix(CodeFixProvider codeFixProvider, string diagnosticId)
{
var typeName = codeFixProvider.GetType().FullName;
Expand Down Expand Up @@ -179,25 +221,22 @@ private bool HasFix(CodeFixProvider codeFixProvider, string diagnosticId)

private async Task CollectRefactoringActions(Document document, TextSpan span, List<CodeAction> codeActions)
{
foreach (var provider in this.Providers)
foreach (var codeRefactoringProvider in OrderedCodeRefactoringProviders.Value)
{
foreach (var codeRefactoringProvider in provider.CodeRefactoringProviders)
if (_helper.IsDisallowed(codeRefactoringProvider))
{
if (_helper.IsDisallowed(codeRefactoringProvider))
{
continue;
}
continue;
}

var context = new CodeRefactoringContext(document, span, a => codeActions.Add(a), CancellationToken.None);
var context = new CodeRefactoringContext(document, span, a => codeActions.Add(a), CancellationToken.None);

try
{
await codeRefactoringProvider.ComputeRefactoringsAsync(context);
}
catch (Exception ex)
{
this.Logger.LogError(ex, $"Error computing refactorings for {codeRefactoringProvider.GetType().FullName}");
}
try
{
await codeRefactoringProvider.ComputeRefactoringsAsync(context);
}
catch (Exception ex)
{
this.Logger.LogError(ex, $"Error computing refactorings for {codeRefactoringProvider.GetType().FullName}");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Adapted from ExtensionOrderer in Roslyn
using System.Collections.Generic;

namespace OmniSharp.Roslyn.CSharp.Services.Refactoring.V2
{
internal class Graph<T>
{
//Dictionary to map between nodes and the names
private Dictionary<string, ProviderNode<T>> Nodes { get; }
private List<ProviderNode<T>> AllNodes { get; }
private Graph(List<ProviderNode<T>> nodesList)
{
Nodes = new Dictionary<string, ProviderNode<T>>();
AllNodes = nodesList;
}
internal static Graph<T> GetGraph(List<ProviderNode<T>> nodesList)
{
var graph = new Graph<T>(nodesList);

foreach (ProviderNode<T> node in graph.AllNodes)
{
graph.Nodes[node.ProviderName] = node;
}

foreach (ProviderNode<T> node in graph.AllNodes)
{
foreach (var before in node.Before)
{
if (graph.Nodes.ContainsKey(before))
{
var beforeNode = graph.Nodes[before];
beforeNode.NodesBeforeMeSet.Add(node);
}
}

foreach (var after in node.After)
{
if (graph.Nodes.ContainsKey(after))
{
var afterNode = graph.Nodes[after];
node.NodesBeforeMeSet.Add(afterNode);
}
}
}

return graph;
}

public bool HasCycles()
{
foreach (var node in this.AllNodes)
{
if (node.CheckForCycles())
return true;
}
return false;
}

public List<T> TopologicalSort()
{
List<T> result = new List<T>();
var seenNodes = new HashSet<ProviderNode<T>>();

foreach (var node in AllNodes)
{
Visit(node, result, seenNodes);
}

return result;
}

private void Visit(ProviderNode<T> node, List<T> result, HashSet<ProviderNode<T>> seenNodes)
{
if (seenNodes.Add(node))
{
foreach (var before in node.NodesBeforeMeSet)
{
Visit(before, result, seenNodes);
}

result.Add(node.Provider);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Adapted from ExtensionOrderer in Roslyn
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CodeRefactorings;

namespace OmniSharp.Roslyn.CSharp.Services.Refactoring.V2
{
internal class ProviderNode<TProvider>
{
public string ProviderName { get; set; }
public List<string> Before { get; set; }
public List<string> After { get; set; }
public TProvider Provider { get; set; }
public HashSet<ProviderNode<TProvider>> NodesBeforeMeSet { get; set; }

public static ProviderNode<TProvider> From(TProvider provider)
{
string providerName = "";
if (provider is CodeFixProvider)
{
var exportAttribute = provider.GetType().GetCustomAttribute(typeof(ExportCodeFixProviderAttribute));
if (exportAttribute is ExportCodeFixProviderAttribute fixAttribute && fixAttribute.Name != null)
{
providerName = fixAttribute.Name;
}
}
else
{
var exportAttribute = provider.GetType().GetCustomAttribute(typeof(ExportCodeRefactoringProviderAttribute));
if (exportAttribute is ExportCodeRefactoringProviderAttribute refactoringAttribute && refactoringAttribute.Name != null)
{
providerName = refactoringAttribute.Name;
}
}

var orderAttributes = provider.GetType().GetCustomAttributes(typeof(ExtensionOrderAttribute), true).Select(attr => (ExtensionOrderAttribute)attr).ToList();
return new ProviderNode<TProvider>(provider, providerName, orderAttributes);
}

private ProviderNode(TProvider provider, string providerName, List<ExtensionOrderAttribute> orderAttributes)
{
Provider = provider;
ProviderName = providerName;
Before = new List<string>();
After = new List<string>();
NodesBeforeMeSet = new HashSet<ProviderNode<TProvider>>();
orderAttributes.ForEach(attr => AddAttribute(attr));
}

private void AddAttribute(ExtensionOrderAttribute attribute)
{
if (attribute.Before != null)
Before.Add(attribute.Before);
if (attribute.After != null)
After.Add(attribute.After);
}

internal bool CheckForCycles()
{
return CheckForCycles(new HashSet<ProviderNode<TProvider>>());
}

private bool CheckForCycles(HashSet<ProviderNode<TProvider>> seenNodes)
{
if (!seenNodes.Add(this))
{
//Cycle detected
return true;
}

foreach (var before in this.NodesBeforeMeSet)
{
if (before.CheckForCycles(seenNodes))
return true;
}

seenNodes.Remove(this);
return false;
}
}
}
31 changes: 30 additions & 1 deletion tests/OmniSharp.Roslyn.CSharp.Tests/CodeActionsV2Facts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,35 @@ public void Whatever()
Assert.Contains("Extract Method", refactorings);
}

[Fact]
public async Task Returns_ordered_code_actions()
{
const string code =
@"public class Class1
{
public void Whatever()
{
[|Console.Write(""should be using System;"");|]
}
}";

var refactorings = await FindRefactoringNamesAsync(code);
List<string> expected = new List<string>
{
"using System;",
"System.Console",
"Generate variable 'Console' -> Generate property 'Class1.Console'",
"Generate variable 'Console' -> Generate field 'Class1.Console'",
"Generate variable 'Console' -> Generate read-only field 'Class1.Console'",
"Generate variable 'Console' -> Generate local 'Console'",
"Generate type 'Console' -> Generate class 'Console' in new file",
"Generate type 'Console' -> Generate class 'Console'",
"Generate type 'Console' -> Generate nested class 'Console'",
"Extract Method"
};
Assert.Equal(expected, refactorings);
}

[Fact]
public async Task Can_extract_method()
{
Expand Down Expand Up @@ -190,7 +219,7 @@ private async Task<IEnumerable<OmniSharpCodeAction>> FindRefactoringsAsync(strin
{
var testFile = new TestFile(BufferPath, code);

using (var host = CreateOmniSharpHost(new [] { testFile }, configurationData))
using (var host = CreateOmniSharpHost(new[] { testFile }, configurationData))
{
var requestHandler = host.GetRequestHandler<GetCodeActionsService>(OmniSharpEndpoints.V2.GetCodeActions);

Expand Down

0 comments on commit e00f370

Please sign in to comment.