diff --git a/src/OmniSharp.Abstractions/Models/ICanBeEmptyResponse.cs b/src/OmniSharp.Abstractions/Models/ICanBeEmptyResponse.cs new file mode 100644 index 0000000000..a4d46a70a3 --- /dev/null +++ b/src/OmniSharp.Abstractions/Models/ICanBeEmptyResponse.cs @@ -0,0 +1,7 @@ +namespace OmniSharp.Models +{ + public interface ICanBeEmptyResponse + { + bool IsEmpty { get; } + } +} diff --git a/src/OmniSharp.Abstractions/Models/v1/GotoDefinition/GotoDefinitionResponse.cs b/src/OmniSharp.Abstractions/Models/v1/GotoDefinition/GotoDefinitionResponse.cs index b9bfca5a0d..ba6695e7bb 100644 --- a/src/OmniSharp.Abstractions/Models/v1/GotoDefinition/GotoDefinitionResponse.cs +++ b/src/OmniSharp.Abstractions/Models/v1/GotoDefinition/GotoDefinitionResponse.cs @@ -3,7 +3,7 @@ namespace OmniSharp.Models.GotoDefinition { - public class GotoDefinitionResponse + public class GotoDefinitionResponse : ICanBeEmptyResponse { public string FileName { get; set; } [JsonConverter(typeof(ZeroBasedIndexConverter))] @@ -11,5 +11,6 @@ public class GotoDefinitionResponse [JsonConverter(typeof(ZeroBasedIndexConverter))] public int Column { get; set; } public MetadataSource MetadataSource { get; set; } + public bool IsEmpty => FileName == null || FileName == string.Empty; } } diff --git a/src/OmniSharp.Host/AssemblyInfo.cs b/src/OmniSharp.Host/AssemblyInfo.cs index aa4d184407..28fc7fdf1a 100644 --- a/src/OmniSharp.Host/AssemblyInfo.cs +++ b/src/OmniSharp.Host/AssemblyInfo.cs @@ -3,4 +3,5 @@ [assembly: InternalsVisibleTo("OmniSharp.Http.Tests")] [assembly: InternalsVisibleTo("OmniSharp.MSBuild.Tests")] [assembly: InternalsVisibleTo("OmniSharp.Roslyn.CSharp.Tests")] +[assembly: InternalsVisibleTo("OmniSharp.Stdio.Tests")] [assembly: InternalsVisibleTo("TestUtility")] diff --git a/src/OmniSharp.Host/Endpoint/EndpointHandler.cs b/src/OmniSharp.Host/Endpoint/EndpointHandler.cs index 5426649783..b83e1d9a55 100644 --- a/src/OmniSharp.Host/Endpoint/EndpointHandler.cs +++ b/src/OmniSharp.Host/Endpoint/EndpointHandler.cs @@ -48,7 +48,7 @@ public class EndpointHandler : EndpointHandler { private readonly CompositionHost _host; private readonly IPredicateHandler _languagePredicateHandler; - private readonly Lazy>>> _exports; + private readonly Lazy[]>>> _exports; private readonly OmniSharpWorkspace _workspace; private readonly bool _hasLanguageProperty; private readonly bool _hasFileNameProperty; @@ -71,10 +71,10 @@ public EndpointHandler(IPredicateHandler languagePredicateHandler, CompositionHo _canBeAggregated = typeof(IAggregateResponse).IsAssignableFrom(metadata.ResponseType); _updateBufferHandler = updateBufferHandler; - _exports = new Lazy>>>(() => LoadExportHandlers(handlers)); + _exports = new Lazy[]>>>(() => LoadExportHandlers(handlers)); } - private Task>> LoadExportHandlers(IEnumerable> handlers) + private Task[]>> LoadExportHandlers(IEnumerable> handlers) { var interfaceHandlers = handlers .Select(export => new RequestHandlerExportHandler(export.Metadata.Language, (IRequestHandler)export.Value)) @@ -84,9 +84,13 @@ private Task>> LoadExportH .Select(plugin => new PluginExportHandler(EndpointName, plugin)) .Cast>(); + // Group handlers by language and sort each group for consistency return Task.FromResult(interfaceHandlers - .Concat(plugins) - .ToDictionary(export => export.Language)); + .Concat(plugins) + .GroupBy(export => export.Language, StringComparer.OrdinalIgnoreCase) + .ToDictionary( + group => group.Key, + group => group.OrderBy(g => g).ToArray())); } public string EndpointName { get; } @@ -142,18 +146,75 @@ private Task HandleLanguageRequest(string language, TRequest request, Re { if (!string.IsNullOrEmpty(language)) { - return HandleSingleRequest(language, request, packet); + return HandleRequestForLanguage(language, request, packet); } return HandleAllRequest(request, packet); } - private async Task HandleSingleRequest(string language, TRequest request, RequestPacket packet) + private async Task AggregateResponsesFromLanguageHandlers(ExportHandler[] handlers, TRequest request) + { + IAggregateResponse aggregateResponse = null; + + var responses = new List>(); + foreach (var handler in handlers) + { + responses.Add(handler.Handle(request)); + } + + foreach (IAggregateResponse response in await Task.WhenAll(responses)) + { + if (aggregateResponse != null) + { + aggregateResponse = aggregateResponse.Merge(response); + } + else + { + aggregateResponse = response; + } + } + + return aggregateResponse; + } + + private async Task GetFirstNotEmptyResponseFromHandlers(ExportHandler[] handlers, TRequest request) + { + var responses = new List>(); + foreach (var handler in handlers) + { + responses.Add(handler.Handle(request)); + } + + foreach (object response in await Task.WhenAll(responses)) + { + var canBeEmptyResponse = response as ICanBeEmptyResponse; + if (canBeEmptyResponse != null) + { + if (!canBeEmptyResponse.IsEmpty) + { + return response; + } + } + else if (response != null) + { + return response; + } + } + + return null; + } + + private async Task HandleRequestForLanguage(string language, TRequest request, RequestPacket packet) { var exports = await _exports.Value; - if (exports.TryGetValue(language, out var handler)) + if (exports.TryGetValue(language, out var handlers)) { - return await handler.Handle(request); + if (_canBeAggregated) + { + return await AggregateResponsesFromLanguageHandlers(handlers, request); + } + + return await GetFirstNotEmptyResponseFromHandlers(handlers, request); } throw new NotSupportedException($"{language} does not support {EndpointName}"); @@ -169,11 +230,10 @@ private async Task HandleAllRequest(TRequest request, RequestPacket pack var exports = await _exports.Value; IAggregateResponse aggregateResponse = null; - - var responses = new List>(); - foreach (var handler in exports.Values) + var responses = new List>(); + foreach (var export in exports) { - responses.Add(handler.Handle(request)); + responses.Add(AggregateResponsesFromLanguageHandlers(export.Value, request)); } foreach (IAggregateResponse exportResponse in await Task.WhenAll(responses)) diff --git a/src/OmniSharp.Host/Endpoint/Exports/ExportHandler.cs b/src/OmniSharp.Host/Endpoint/Exports/ExportHandler.cs index 877254da8e..db068fca3c 100644 --- a/src/OmniSharp.Host/Endpoint/Exports/ExportHandler.cs +++ b/src/OmniSharp.Host/Endpoint/Exports/ExportHandler.cs @@ -1,8 +1,9 @@ +using System; using System.Threading.Tasks; namespace OmniSharp.Endpoint.Exports { - abstract class ExportHandler + abstract class ExportHandler : IComparable> { protected ExportHandler(string language) { @@ -10,6 +11,8 @@ protected ExportHandler(string language) } public string Language { get; } + + public abstract int CompareTo(ExportHandler other); public abstract Task Handle(TRequest request); } } diff --git a/src/OmniSharp.Host/Endpoint/Exports/PluginExportHandler.cs b/src/OmniSharp.Host/Endpoint/Exports/PluginExportHandler.cs index 290be277fc..22dfb90c48 100644 --- a/src/OmniSharp.Host/Endpoint/Exports/PluginExportHandler.cs +++ b/src/OmniSharp.Host/Endpoint/Exports/PluginExportHandler.cs @@ -14,6 +14,17 @@ public PluginExportHandler(string endpoint, Plugin plugin) : base(plugin.Config. _plugin = plugin; } + public override int CompareTo(ExportHandler other) + { + var otherPlugin = other as PluginExportHandler; + if (otherPlugin == null) + { + return 1; + } + + return _plugin.Key.CompareTo(otherPlugin._plugin.Key); + } + public override Task Handle(TRequest request) { return _plugin.Handle(_endpoint, request); diff --git a/src/OmniSharp.Host/Endpoint/Exports/RequestHandlerExportHandler.cs b/src/OmniSharp.Host/Endpoint/Exports/RequestHandlerExportHandler.cs index 58c5cdc147..93cc9797ac 100644 --- a/src/OmniSharp.Host/Endpoint/Exports/RequestHandlerExportHandler.cs +++ b/src/OmniSharp.Host/Endpoint/Exports/RequestHandlerExportHandler.cs @@ -13,6 +13,17 @@ public RequestHandlerExportHandler(string language, IRequestHandler other) + { + var otherHandler = other as RequestHandlerExportHandler; + if (otherHandler == null) + { + return 1; + } + + return _handler.GetType().ToString().CompareTo(otherHandler._handler.GetType().ToString()); + } + public override Task Handle(TRequest request) { return _handler.Handle(request); diff --git a/src/OmniSharp.Host/Mef/MefValueProvider.cs b/src/OmniSharp.Host/Mef/MefValueProvider.cs index 38ba78957c..be725be02c 100644 --- a/src/OmniSharp.Host/Mef/MefValueProvider.cs +++ b/src/OmniSharp.Host/Mef/MefValueProvider.cs @@ -7,10 +7,12 @@ namespace OmniSharp.Mef internal class MefValueProvider : ExportDescriptorProvider { private readonly T _item; + private readonly IDictionary _metadata; - public MefValueProvider(T item) + public MefValueProvider(T item, IDictionary metadata) { _item = item; + _metadata = metadata; } public override IEnumerable GetExportDescriptors(CompositionContract contract, DependencyAccessor descriptorAccessor) @@ -19,16 +21,16 @@ public override IEnumerable GetExportDescriptors(Compos { yield return new ExportDescriptorPromise(contract, string.Empty, true, () => Enumerable.Empty(), - deps => ExportDescriptor.Create((context, operation) => _item, new Dictionary())); + deps => ExportDescriptor.Create((context, operation) => _item, _metadata ?? new Dictionary())); } } } internal static class MefValueProvider { - public static MefValueProvider From(T value) + public static MefValueProvider From(T value, IDictionary metadata = null) { - return new MefValueProvider(value); + return new MefValueProvider(value, metadata); } } } diff --git a/tests/OmniSharp.Stdio.Tests/EndpointHandlerFacts.cs b/tests/OmniSharp.Stdio.Tests/EndpointHandlerFacts.cs new file mode 100644 index 0000000000..d06b9b5ae9 --- /dev/null +++ b/tests/OmniSharp.Stdio.Tests/EndpointHandlerFacts.cs @@ -0,0 +1,278 @@ +using System; +using System.Collections.Generic; +using System.Composition.Hosting.Core; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis; +using Microsoft.Extensions.Configuration; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using OmniSharp.Mef; +using OmniSharp.Models; +using OmniSharp.Models.FindSymbols; +using OmniSharp.Models.GotoDefinition; +using OmniSharp.Models.WorkspaceInformation; +using OmniSharp.Protocol; +using OmniSharp.Services; +using Xunit; + +namespace OmniSharp.Stdio.Tests +{ + public class EndpointHandlerFacts + { + private abstract class FakeFindSymbolsServiceBase : IRequestHandler + { + private readonly string _name; + + public FakeFindSymbolsServiceBase(string name) + { + _name = name; + } + + public Task Handle(FindSymbolsRequest request = null) + { + return Task.FromResult(new QuickFixResponse + { + QuickFixes = new[] + { + new QuickFix + { + FileName = $"{_name}.cs", + Line = 1, + Column = 1, + EndLine = 1, + EndColumn = 4, + Text = _name + } + } + }); + } + } + + private class AAAFakeFindSymbolsService : FakeFindSymbolsServiceBase + { + public AAAFakeFindSymbolsService() : base(nameof(AAAFakeFindSymbolsService)) { } + } + + private class BBBFakeFindSymbolsService : FakeFindSymbolsServiceBase + { + public BBBFakeFindSymbolsService() : base(nameof(BBBFakeFindSymbolsService)) { } + } + + private class CCCFakeFindSymbolsService : FakeFindSymbolsServiceBase + { + public CCCFakeFindSymbolsService() : base(nameof(CCCFakeFindSymbolsService)) { } + } + + private class FakeGotoDefinitionService : IRequestHandler + { + private readonly string _name; + private readonly bool _returnEmptyResponse; + + public FakeGotoDefinitionService(string name, bool returnEmptyResponse) + { + _name = name; + _returnEmptyResponse = returnEmptyResponse; + } + + public Task Handle(GotoDefinitionRequest request) + { + if (_returnEmptyResponse) + { + Task.FromResult(new GotoDefinitionResponse()); + } + + return Task.FromResult(new GotoDefinitionResponse + { + FileName = $"{_name}.cs", + Line = 1, + Column = 2 + }); + } + } + + private class FakeProjectSystem : IProjectSystem + { + public string Key => "FakeProjectSystem"; + + public string Language => LanguageNames.CSharp; + + public IEnumerable Extensions => new[] { string.Empty }; + + public bool EnabledByDefault => true; + + public bool Initialized => true; + + public Task GetProjectModelAsync(string filePath) + { + throw new NotImplementedException(); + } + + public Task GetWorkspaceModelAsync(WorkspaceInformationRequest request) + { + throw new NotImplementedException(); + } + + public void Initalize(IConfiguration configuration) + { + } + } + + private class TestRequestPacket : RequestPacket + { + public string Arguments { get; set; } + } + + [Fact] + public void HandleAggregatableResponsesForSingleLanguage() + { + var request = new TestRequestPacket() + { + Seq = 99, + Command = OmniSharpEndpoints.FindSymbols, + Arguments = JsonConvert.SerializeObject(new FindSymbolsRequest { Language = LanguageNames.CSharp }) + }; + + var writer = new TestTextWriter( + value => + { + var packet = JsonConvert.DeserializeObject(value); + Assert.Equal("started", packet.Event); + }, + value => + { + var packet = JsonConvert.DeserializeObject(value); + Assert.Equal(request.Seq, packet.Request_seq); + Assert.Equal(request.Command, packet.Command); + Assert.True(packet.Success); + Assert.True(packet.Running); + Assert.Null(packet.Message); + var quickFixResponse = ((JObject)packet.Body).ToObject(); + Assert.Equal(3, quickFixResponse.QuickFixes.Count()); + Assert.Equal("AAAFakeFindSymbolsService", quickFixResponse.QuickFixes.ElementAt(0).Text); + Assert.Equal("BBBFakeFindSymbolsService", quickFixResponse.QuickFixes.ElementAt(1).Text); + Assert.Equal("CCCFakeFindSymbolsService", quickFixResponse.QuickFixes.ElementAt(2).Text); + } + ); + + var exports = new ExportDescriptorProvider[] + { + MefValueProvider.From( + new BBBFakeFindSymbolsService(), + new Dictionary { ["EndpointName"] = OmniSharpEndpoints.FindSymbols, ["Language"] = LanguageNames.CSharp }), + MefValueProvider.From( + new CCCFakeFindSymbolsService(), + new Dictionary { ["EndpointName"] = OmniSharpEndpoints.FindSymbols, ["Language"] = LanguageNames.CSharp }), + MefValueProvider.From( + new AAAFakeFindSymbolsService(), + new Dictionary { ["EndpointName"] = OmniSharpEndpoints.FindSymbols, ["Language"] = LanguageNames.CSharp }), + }; + + using (StdioServerFacts.BuildTestServerAndStart(new StringReader(JsonConvert.SerializeObject(request) + "\r\n"), writer, additionalExports: exports)) + { + Assert.True(writer.Completion.WaitOne(TimeSpan.FromSeconds(60)), "Timeout"); + Assert.Null(writer.Exception); + } + } + + [Fact] + public void HandleAggregatableResponsesForMultipleLanguages() + { + var request = new TestRequestPacket() + { + Seq = 99, + Command = OmniSharpEndpoints.FindSymbols, + Arguments = JsonConvert.SerializeObject(new FindSymbolsRequest()) + }; + + var writer = new TestTextWriter( + value => + { + var packet = JsonConvert.DeserializeObject(value); + Assert.Equal("started", packet.Event); + }, + value => + { + var packet = JsonConvert.DeserializeObject(value); + Assert.Equal(request.Seq, packet.Request_seq); + Assert.Equal(request.Command, packet.Command); + Assert.True(packet.Success); + Assert.True(packet.Running); + Assert.Null(packet.Message); + var quickFixResponse = ((JObject)packet.Body).ToObject(); + Assert.Equal(3, quickFixResponse.QuickFixes.Count()); + Assert.Equal("AAAFakeFindSymbolsService", quickFixResponse.QuickFixes.ElementAt(0).Text); + Assert.Equal("BBBFakeFindSymbolsService", quickFixResponse.QuickFixes.ElementAt(1).Text); + Assert.Equal("CCCFakeFindSymbolsService", quickFixResponse.QuickFixes.ElementAt(2).Text); + } + ); + + var exports = new ExportDescriptorProvider[] + { + MefValueProvider.From( + new BBBFakeFindSymbolsService(), + new Dictionary { ["EndpointName"] = OmniSharpEndpoints.FindSymbols, ["Language"] = LanguageNames.CSharp }), + MefValueProvider.From( + new CCCFakeFindSymbolsService(), + new Dictionary { ["EndpointName"] = OmniSharpEndpoints.FindSymbols, ["Language"] = LanguageNames.VisualBasic }), + MefValueProvider.From( + new AAAFakeFindSymbolsService(), + new Dictionary { ["EndpointName"] = OmniSharpEndpoints.FindSymbols, ["Language"] = LanguageNames.CSharp }), + }; + + using (StdioServerFacts.BuildTestServerAndStart(new StringReader(JsonConvert.SerializeObject(request) + "\r\n"), writer, additionalExports: exports)) + { + Assert.True(writer.Completion.WaitOne(TimeSpan.FromSeconds(60)), "Timeout"); + Assert.Null(writer.Exception); + } + } + + [Fact] + public void HandleNonAggregatableResponses() + { + var request = new TestRequestPacket() + { + Seq = 99, + Command = OmniSharpEndpoints.GotoDefinition, + Arguments = JsonConvert.SerializeObject(new GotoDefinitionRequest { FileName = "foo.cs" }) + }; + + var writer = new TestTextWriter( + value => + { + var packet = JsonConvert.DeserializeObject(value); + Assert.Equal("started", packet.Event); + }, + value => + { + var packet = JsonConvert.DeserializeObject(value); + Assert.Equal(request.Seq, packet.Request_seq); + Assert.Equal(request.Command, packet.Command); + Assert.True(packet.Success); + Assert.True(packet.Running); + Assert.Null(packet.Message); + var gotoDefinitionResponse = ((JObject)packet.Body).ToObject(); + Assert.Equal("ZZZFake.cs", gotoDefinitionResponse.FileName); + } + ); + + var exports = new ExportDescriptorProvider[] + { + MefValueProvider.From(new FakeProjectSystem()), + MefValueProvider.From( + new FakeGotoDefinitionService("ZZZFake", false), + new Dictionary { ["EndpointName"] = OmniSharpEndpoints.GotoDefinition, ["Language"] = LanguageNames.CSharp }), + MefValueProvider.From( + new FakeGotoDefinitionService("AAAFake", true), + new Dictionary { ["EndpointName"] = OmniSharpEndpoints.GotoDefinition, ["Language"] = LanguageNames.CSharp }), + }; + + using (StdioServerFacts.BuildTestServerAndStart(new StringReader(JsonConvert.SerializeObject(request) + "\r\n"), writer, additionalExports: exports)) + { + Assert.True(writer.Completion.WaitOne(TimeSpan.FromSeconds(60)), "Timeout"); + Assert.Null(writer.Exception); + } + } + } +} diff --git a/tests/OmniSharp.Stdio.Tests/StdioServerFacts.cs b/tests/OmniSharp.Stdio.Tests/StdioServerFacts.cs index bfe751e74d..2b6a535a8a 100644 --- a/tests/OmniSharp.Stdio.Tests/StdioServerFacts.cs +++ b/tests/OmniSharp.Stdio.Tests/StdioServerFacts.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.Composition.Hosting.Core; using System.IO; using System.Threading; using Microsoft.Extensions.DependencyInjection; @@ -13,7 +15,8 @@ namespace OmniSharp.Stdio.Tests { public class StdioServerFacts { - private Host BuildTestServerAndStart(TextReader reader, ISharedTextWriter writer, Action programDelegate = null) + internal static Host BuildTestServerAndStart(TextReader reader, ISharedTextWriter writer, Action programDelegate = null, + IEnumerable additionalExports = null) { var configuration = new Microsoft.Extensions.Configuration.ConfigurationBuilder().Build(); var environment = new OmniSharpEnvironment(); @@ -22,7 +25,7 @@ private Host BuildTestServerAndStart(TextReader reader, ISharedTextWriter writer var host = new Host(reader, writer, environment, serviceProvider, - new CompositionHostBuilder(serviceProvider), + new CompositionHostBuilder(serviceProvider, exportDescriptorProviders: additionalExports), serviceProvider.GetRequiredService(), cancelationTokenSource);