diff --git a/src/OmniSharp.Roslyn.CSharp/Services/Refactoring/V2/BaseCodeActionService.cs b/src/OmniSharp.Roslyn.CSharp/Services/Refactoring/V2/BaseCodeActionService.cs index 3dc6718078..1945ee3d42 100644 --- a/src/OmniSharp.Roslyn.CSharp/Services/Refactoring/V2/BaseCodeActionService.cs +++ b/src/OmniSharp.Roslyn.CSharp/Services/Refactoring/V2/BaseCodeActionService.cs @@ -30,6 +30,9 @@ public abstract class BaseCodeActionService : IRequestHandl private static readonly Func> s_createDiagnosticList = _ => new List(); + protected Lazy> OrderedCodeFixProviders; + protected Lazy> OrderedCodeRefactoringProviders; + protected BaseCodeActionService(OmniSharpWorkspace workspace, CodeActionHelper helper, IEnumerable providers, ILogger logger) { this.Workspace = workspace; @@ -37,6 +40,9 @@ protected BaseCodeActionService(OmniSharpWorkspace workspace, CodeActionHelper h this.Logger = logger; this._helper = helper; + OrderedCodeFixProviders = new Lazy>(() => GetSortedCodeFixProviders()); + OrderedCodeRefactoringProviders = new Lazy>(() => GetSortedCodeRefactoringProviders()); + // Sadly, the CodeAction.NestedCodeActions property is still internal. var nestedCodeActionsProperty = typeof(CodeAction).GetProperty("NestedCodeActions", BindingFlags.NonPublic | BindingFlags.Instance); if (nestedCodeActionsProperty == null) @@ -69,9 +75,6 @@ protected async Task> GetAvailableCodeActions(I await CollectCodeFixesActions(document, span, codeActions); await CollectRefactoringActions(document, span, codeActions); - // TODO: Determine good way to order code actions. - codeActions.Reverse(); - // Be sure to filter out any code actions that inherit from CodeActionWithOptions. // This isn't a great solution and might need changing later, but every Roslyn code action // derived from this type tries to display a dialog. For now, this is a reasonable solution. @@ -128,28 +131,67 @@ private async Task CollectCodeFixesActions(Document document, TextSpan span, Lis private async Task AppendFixesAsync(Document document, TextSpan span, IEnumerable diagnostics, List codeActions) { - foreach (var provider in this.Providers) + foreach (var codeFixProvider in OrderedCodeFixProviders.Value) { - foreach (var codeFixProvider in provider.CodeFixProviders) + var fixableDiagnostics = diagnostics.Where(d => HasFix(codeFixProvider, d.Id)).ToImmutableArray(); + if (fixableDiagnostics.Length > 0) { - var fixableDiagnostics = diagnostics.Where(d => HasFix(codeFixProvider, d.Id)).ToImmutableArray(); - if (fixableDiagnostics.Length > 0) + var context = new CodeFixContext(document, span, fixableDiagnostics, (a, _) => codeActions.Add(a), CancellationToken.None); + + try { - var context = new CodeFixContext(document, span, fixableDiagnostics, (a, _) => codeActions.Add(a), CancellationToken.None); - - try - { - await codeFixProvider.RegisterCodeFixesAsync(context); - } - catch (Exception ex) - { - this.Logger.LogError(ex, $"Error registering code fixes for {codeFixProvider.GetType().FullName}"); - } + await codeFixProvider.RegisterCodeFixesAsync(context); + } + catch (Exception ex) + { + this.Logger.LogError(ex, $"Error registering code fixes for {codeFixProvider.GetType().FullName}"); } } } } + private List GetSortedCodeFixProviders() + { + List> nodesList = new List>(); + List providerList = new List(); + foreach (var provider in this.Providers) + { + foreach (var codeFixProvider in provider.CodeFixProviders) + { + providerList.Add(codeFixProvider); + nodesList.Add(ProviderNode.From(codeFixProvider)); + } + } + + var graph = Graph.GetGraph(nodesList); + if (graph.HasCycles()) + { + return providerList; + } + return graph.TopologicalSort(); + } + + private List GetSortedCodeRefactoringProviders() + { + List> nodesList = new List>(); + List providerList = new List(); + foreach (var provider in this.Providers) + { + foreach (var codeRefactoringProvider in provider.CodeRefactoringProviders) + { + providerList.Add(codeRefactoringProvider); + nodesList.Add(ProviderNode.From(codeRefactoringProvider)); + } + } + + var graph = Graph.GetGraph(nodesList); + if (graph.HasCycles()) + { + return providerList; + } + return graph.TopologicalSort(); + } + private bool HasFix(CodeFixProvider codeFixProvider, string diagnosticId) { var typeName = codeFixProvider.GetType().FullName; @@ -179,25 +221,22 @@ private bool HasFix(CodeFixProvider codeFixProvider, string diagnosticId) private async Task CollectRefactoringActions(Document document, TextSpan span, List codeActions) { - foreach (var provider in this.Providers) + foreach (var codeRefactoringProvider in OrderedCodeRefactoringProviders.Value) { - foreach (var codeRefactoringProvider in provider.CodeRefactoringProviders) + if (_helper.IsDisallowed(codeRefactoringProvider)) { - if (_helper.IsDisallowed(codeRefactoringProvider)) - { - continue; - } + continue; + } - var context = new CodeRefactoringContext(document, span, a => codeActions.Add(a), CancellationToken.None); + var context = new CodeRefactoringContext(document, span, a => codeActions.Add(a), CancellationToken.None); - try - { - await codeRefactoringProvider.ComputeRefactoringsAsync(context); - } - catch (Exception ex) - { - this.Logger.LogError(ex, $"Error computing refactorings for {codeRefactoringProvider.GetType().FullName}"); - } + try + { + await codeRefactoringProvider.ComputeRefactoringsAsync(context); + } + catch (Exception ex) + { + this.Logger.LogError(ex, $"Error computing refactorings for {codeRefactoringProvider.GetType().FullName}"); } } } diff --git a/src/OmniSharp.Roslyn.CSharp/Services/Refactoring/V2/CodeActionsOrder.Graph.cs b/src/OmniSharp.Roslyn.CSharp/Services/Refactoring/V2/CodeActionsOrder.Graph.cs new file mode 100644 index 0000000000..79c8a007c5 --- /dev/null +++ b/src/OmniSharp.Roslyn.CSharp/Services/Refactoring/V2/CodeActionsOrder.Graph.cs @@ -0,0 +1,85 @@ +// Adapted from ExtensionOrderer in Roslyn +using System.Collections.Generic; + +namespace OmniSharp.Roslyn.CSharp.Services.Refactoring.V2 +{ + internal class Graph + { + //Dictionary to map between nodes and the names + private Dictionary> Nodes { get; } + private List> AllNodes { get; } + private Graph(List> nodesList) + { + Nodes = new Dictionary>(); + AllNodes = nodesList; + } + internal static Graph GetGraph(List> nodesList) + { + var graph = new Graph(nodesList); + + foreach (ProviderNode node in graph.AllNodes) + { + graph.Nodes[node.ProviderName] = node; + } + + foreach (ProviderNode node in graph.AllNodes) + { + foreach (var before in node.Before) + { + if (graph.Nodes.ContainsKey(before)) + { + var beforeNode = graph.Nodes[before]; + beforeNode.NodesBeforeMeSet.Add(node); + } + } + + foreach (var after in node.After) + { + if (graph.Nodes.ContainsKey(after)) + { + var afterNode = graph.Nodes[after]; + node.NodesBeforeMeSet.Add(afterNode); + } + } + } + + return graph; + } + + public bool HasCycles() + { + foreach (var node in this.AllNodes) + { + if (node.CheckForCycles()) + return true; + } + return false; + } + + public List TopologicalSort() + { + List result = new List(); + var seenNodes = new HashSet>(); + + foreach (var node in AllNodes) + { + Visit(node, result, seenNodes); + } + + return result; + } + + private void Visit(ProviderNode node, List result, HashSet> seenNodes) + { + if (seenNodes.Add(node)) + { + foreach (var before in node.NodesBeforeMeSet) + { + Visit(before, result, seenNodes); + } + + result.Add(node.Provider); + } + } + } +} diff --git a/src/OmniSharp.Roslyn.CSharp/Services/Refactoring/V2/CodeActionsOrder.ProviderNode.cs b/src/OmniSharp.Roslyn.CSharp/Services/Refactoring/V2/CodeActionsOrder.ProviderNode.cs new file mode 100644 index 0000000000..92be73ac8a --- /dev/null +++ b/src/OmniSharp.Roslyn.CSharp/Services/Refactoring/V2/CodeActionsOrder.ProviderNode.cs @@ -0,0 +1,85 @@ +// Adapted from ExtensionOrderer in Roslyn +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CodeFixes; +using Microsoft.CodeAnalysis.CodeRefactorings; + +namespace OmniSharp.Roslyn.CSharp.Services.Refactoring.V2 +{ + internal class ProviderNode + { + public string ProviderName { get; set; } + public List Before { get; set; } + public List After { get; set; } + public TProvider Provider { get; set; } + public HashSet> NodesBeforeMeSet { get; set; } + + public static ProviderNode From(TProvider provider) + { + string providerName = ""; + if (provider is CodeFixProvider) + { + var exportAttribute = provider.GetType().GetCustomAttribute(typeof(ExportCodeFixProviderAttribute)); + if (exportAttribute is ExportCodeFixProviderAttribute fixAttribute && fixAttribute.Name != null) + { + providerName = fixAttribute.Name; + } + } + else + { + var exportAttribute = provider.GetType().GetCustomAttribute(typeof(ExportCodeRefactoringProviderAttribute)); + if (exportAttribute is ExportCodeRefactoringProviderAttribute refactoringAttribute && refactoringAttribute.Name != null) + { + providerName = refactoringAttribute.Name; + } + } + + var orderAttributes = provider.GetType().GetCustomAttributes(typeof(ExtensionOrderAttribute), true).Select(attr => (ExtensionOrderAttribute)attr).ToList(); + return new ProviderNode(provider, providerName, orderAttributes); + } + + private ProviderNode(TProvider provider, string providerName, List orderAttributes) + { + Provider = provider; + ProviderName = providerName; + Before = new List(); + After = new List(); + NodesBeforeMeSet = new HashSet>(); + orderAttributes.ForEach(attr => AddAttribute(attr)); + } + + private void AddAttribute(ExtensionOrderAttribute attribute) + { + if (attribute.Before != null) + Before.Add(attribute.Before); + if (attribute.After != null) + After.Add(attribute.After); + } + + internal bool CheckForCycles() + { + return CheckForCycles(new HashSet>()); + } + + private bool CheckForCycles(HashSet> seenNodes) + { + if (!seenNodes.Add(this)) + { + //Cycle detected + return true; + } + + foreach (var before in this.NodesBeforeMeSet) + { + if (before.CheckForCycles(seenNodes)) + return true; + } + + seenNodes.Remove(this); + return false; + } + } +} diff --git a/tests/OmniSharp.Roslyn.CSharp.Tests/CodeActionsV2Facts.cs b/tests/OmniSharp.Roslyn.CSharp.Tests/CodeActionsV2Facts.cs index 63c09cea6c..e7073a43a8 100644 --- a/tests/OmniSharp.Roslyn.CSharp.Tests/CodeActionsV2Facts.cs +++ b/tests/OmniSharp.Roslyn.CSharp.Tests/CodeActionsV2Facts.cs @@ -99,6 +99,35 @@ public void Whatever() Assert.Contains("Extract Method", refactorings); } + [Fact] + public async Task Returns_ordered_code_actions() + { + const string code = + @"public class Class1 + { + public void Whatever() + { + [|Console.Write(""should be using System;"");|] + } + }"; + + var refactorings = await FindRefactoringNamesAsync(code); + List expected = new List + { + "using System;", + "System.Console", + "Generate variable 'Console' -> Generate property 'Class1.Console'", + "Generate variable 'Console' -> Generate field 'Class1.Console'", + "Generate variable 'Console' -> Generate read-only field 'Class1.Console'", + "Generate variable 'Console' -> Generate local 'Console'", + "Generate type 'Console' -> Generate class 'Console' in new file", + "Generate type 'Console' -> Generate class 'Console'", + "Generate type 'Console' -> Generate nested class 'Console'", + "Extract Method" + }; + Assert.Equal(expected, refactorings); + } + [Fact] public async Task Can_extract_method() { @@ -190,7 +219,7 @@ private async Task> FindRefactoringsAsync(strin { var testFile = new TestFile(BufferPath, code); - using (var host = CreateOmniSharpHost(new [] { testFile }, configurationData)) + using (var host = CreateOmniSharpHost(new[] { testFile }, configurationData)) { var requestHandler = host.GetRequestHandler(OmniSharpEndpoints.V2.GetCodeActions);