diff --git a/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/AbstractLanguageServer.cs b/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/AbstractLanguageServer.cs index 90492b7d7fa00..a375878201a7e 100644 --- a/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/AbstractLanguageServer.cs +++ b/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/AbstractLanguageServer.cs @@ -7,6 +7,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Reflection; using System.Threading; @@ -172,10 +173,11 @@ protected IRequestExecutionQueue GetRequestExecutionQueue() return _queue.Value; } - public virtual string GetLanguageForRequest(string methodName, object? serializedRequest) + public virtual bool TryGetLanguageForRequest(string methodName, object? serializedRequest, [NotNullWhen(true)] out string? language) { Logger.LogInformation($"Using default language handler for {methodName}"); - return LanguageServerConstants.DefaultLanguageName; + language = LanguageServerConstants.DefaultLanguageName; + return true; } protected abstract DelegatingEntryPoint CreateDelegatingEntryPoint(string method); diff --git a/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/IQueueItem.cs b/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/IQueueItem.cs index f49d36b7b13ed..36e80aee8e73a 100644 --- a/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/IQueueItem.cs +++ b/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/IQueueItem.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Threading; using System.Threading.Tasks; @@ -37,6 +38,12 @@ internal interface IQueueItem /// Task<(TRequestContext, TRequest)?> CreateRequestContextAsync(IMethodHandler handler, RequestHandlerMetadata requestHandlerMetadata, AbstractLanguageServer languageServer, CancellationToken cancellationToken); + /// + /// Handles when the queue needs to manually fail a request before the + /// handler is invoked without shutting down the entire queue. + /// + void FailRequest(string message); + /// /// Provides access to LSP services. /// diff --git a/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/QueueItem.cs b/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/QueueItem.cs index 8b04b82bc394c..dc8b05b5a346b 100644 --- a/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/QueueItem.cs +++ b/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/QueueItem.cs @@ -6,11 +6,7 @@ #nullable enable using System; -using System.Collections.Frozen; using System.Diagnostics.CodeAnalysis; -using System.Diagnostics.Contracts; -using System.Linq; -using System.Reflection; using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.Threading; @@ -30,6 +26,13 @@ internal class QueueItem : IQueueItem private readonly ILspLogger _logger; private readonly AbstractRequestScope? _requestTelemetryScope; + /// + /// True if this queue item has actually started handling the request + /// by delegating to the handler. False while the item is still being + /// processed by the queue. + /// + private bool _requestHandlingStarted = false; + /// /// A task completion source representing the result of this queue item's work. /// This is the task that the client is waiting on. @@ -158,6 +161,7 @@ private bool TryDeserializeRequest( /// public async Task StartRequestAsync(TRequest request, TRequestContext? context, IMethodHandler handler, string language, CancellationToken cancellationToken) { + _requestHandlingStarted = true; _logger.LogStartContext($"{MethodName}"); try @@ -240,4 +244,20 @@ public async Task StartRequestAsync(TRequest request, TRequ // so it can decide how to handle the result / exception. await _completionSource.Task.ConfigureAwait(false); } + + public void FailRequest(string message) + { + // This is not valid to call after StartRequestAsync starts as they both access the same state. + // StartRequestAsync handles any failures internally once it runs. + if (_requestHandlingStarted) + { + throw new InvalidOperationException("Cannot manually fail queue item after it has started"); + } + var exception = new Exception(message); + _requestTelemetryScope?.RecordException(exception); + _logger.LogException(exception); + + _completionSource.TrySetException(exception); + _requestTelemetryScope?.Dispose(); + } } diff --git a/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs b/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs index edbbc2d4297d4..0f987aa808ff2 100644 --- a/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs +++ b/src/LanguageServer/Microsoft.CommonLanguageServerProtocol.Framework/RequestExecutionQueue.cs @@ -10,6 +10,7 @@ using System.Collections.Frozen; using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Reflection; using System.Threading; @@ -245,10 +246,27 @@ private async Task ProcessQueueAsync() // notifications have been completed by the time we attempt to determine the language, so we have the up to date map of URI to language. // Since didOpen notifications are marked as mutating, the queue will not advance to the next request until the server has finished processing // the didOpen, ensuring that this line will only run once all prior didOpens have completed. - var language = _languageServer.GetLanguageForRequest(work.MethodName, work.SerializedRequest); + var didGetLanguage = _languageServer.TryGetLanguageForRequest(work.MethodName, work.SerializedRequest, out var language); // Now that we know the actual language, we can deserialize the request and start creating the request context. - var (metadata, handler, methodInfo) = GetHandlerForRequest(work, language); + var (metadata, handler, methodInfo) = GetHandlerForRequest(work, language ?? LanguageServerConstants.DefaultLanguageName); + + // We had an issue determining the language. Generally this is very rare and only occurs + // if there is a mis-behaving client that sends us requests for files where we haven't saved the languageId. + // We should only crash if this was a mutating method, otherwise we should just fail the single request. + if (!didGetLanguage) + { + var message = $"Failed to get language for {work.MethodName}"; + if (handler.MutatesSolutionState) + { + throw new InvalidOperationException(message); + } + else + { + work.FailRequest(message); + return; + } + } // We now have the actual handler and language, so we can process the work item using the concrete types defined by the metadata. await InvokeProcessCoreAsync(work, metadata, handler, methodInfo, concurrentlyExecutingTasks, currentWorkCts, cancellationToken).ConfigureAwait(false); diff --git a/src/LanguageServer/Protocol/ILanguageInfoProvider.cs b/src/LanguageServer/Protocol/ILanguageInfoProvider.cs index 9673195278cdd..315e6ed6e22fb 100644 --- a/src/LanguageServer/Protocol/ILanguageInfoProvider.cs +++ b/src/LanguageServer/Protocol/ILanguageInfoProvider.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Diagnostics.CodeAnalysis; using Microsoft.CodeAnalysis.Features.Workspaces; namespace Microsoft.CodeAnalysis.LanguageServer @@ -19,7 +20,6 @@ internal interface ILanguageInfoProvider : ILspService /// It is totally possible to not find language based on the file path (e.g. a newly created file that hasn't been saved to disk). /// In that case, we use the language Id that the LSP client gave us. /// - /// Thrown when the language information cannot be determined. - LanguageInformation GetLanguageInformation(Uri documentUri, string? lspLanguageId); + bool TryGetLanguageInformation(Uri uri, string? lspLanguageId, [NotNullWhen(true)] out LanguageInformation? languageInformation); } } diff --git a/src/LanguageServer/Protocol/LanguageInfoProvider.cs b/src/LanguageServer/Protocol/LanguageInfoProvider.cs index 0758584cad42f..56e9c506aa809 100644 --- a/src/LanguageServer/Protocol/LanguageInfoProvider.cs +++ b/src/LanguageServer/Protocol/LanguageInfoProvider.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.IO; using Microsoft.CodeAnalysis.Features.Workspaces; @@ -43,7 +44,7 @@ internal class LanguageInfoProvider : ILanguageInfoProvider { ".mts", s_typeScriptLanguageInformation }, }; - public LanguageInformation GetLanguageInformation(Uri uri, string? lspLanguageId) + public bool TryGetLanguageInformation(Uri uri, string? lspLanguageId, [NotNullWhen(true)] out LanguageInformation? languageInformation) { // First try to get language information from the URI path. // We can do this for File uris and absolute uris. We use local path to get the value without any query parameters. @@ -51,14 +52,14 @@ public LanguageInformation GetLanguageInformation(Uri uri, string? lspLanguageId { var localPath = uri.LocalPath; var extension = Path.GetExtension(localPath); - if (s_extensionToLanguageInformation.TryGetValue(extension, out var languageInformation)) + if (s_extensionToLanguageInformation.TryGetValue(extension, out languageInformation)) { - return languageInformation; + return true; } } // If the URI file path mapping failed, use the languageId from the LSP client (if any). - return lspLanguageId switch + languageInformation = lspLanguageId switch { "csharp" => s_csharpLanguageInformation, "fsharp" => s_fsharpLanguageInformation, @@ -67,8 +68,10 @@ public LanguageInformation GetLanguageInformation(Uri uri, string? lspLanguageId "xaml" => s_xamlLanguageInformation, "typescript" => s_typeScriptLanguageInformation, "javascript" => s_typeScriptLanguageInformation, - _ => throw new InvalidOperationException($"Unable to determine language for '{uri}' with LSP language id '{lspLanguageId}'") + _ => null, }; + + return languageInformation != null; } } } diff --git a/src/LanguageServer/Protocol/RoslynLanguageServer.cs b/src/LanguageServer/Protocol/RoslynLanguageServer.cs index 37c845e4e7578..a68788ad03c87 100644 --- a/src/LanguageServer/Protocol/RoslynLanguageServer.cs +++ b/src/LanguageServer/Protocol/RoslynLanguageServer.cs @@ -6,6 +6,7 @@ using System.Collections.Frozen; using System.Collections.Generic; using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Threading; using System.Threading.Tasks; @@ -164,12 +165,13 @@ public Task OnInitializedAsync(ClientCapabilities clientCapabilities, RequestCon return Task.CompletedTask; } - public override string GetLanguageForRequest(string methodName, object? serializedParameters) + public override bool TryGetLanguageForRequest(string methodName, object? serializedParameters, [NotNullWhen(true)] out string? language) { if (serializedParameters == null) { Logger.LogInformation("No request parameters given, using default language handler"); - return LanguageServerConstants.DefaultLanguageName; + language = LanguageServerConstants.DefaultLanguageName; + return true; } // We implement the STJ language server so this must be a JsonElement. @@ -179,7 +181,8 @@ public override string GetLanguageForRequest(string methodName, object? serializ // as we do not want languages to be able to override them. if (ShouldUseDefaultLanguage(methodName)) { - return LanguageServerConstants.DefaultLanguageName; + language = LanguageServerConstants.DefaultLanguageName; + return true; } var lspWorkspaceManager = GetLspServices().GetRequiredService(); @@ -188,34 +191,41 @@ public override string GetLanguageForRequest(string methodName, object? serializ // { "textDocument": { "uri": "" ... } ... } // // We can easily identify the URI for the request by looking for this structure + Uri? uri = null; if (parameters.TryGetProperty("textDocument", out var textDocumentToken) || parameters.TryGetProperty("_vs_textDocument", out textDocumentToken)) { var uriToken = textDocumentToken.GetProperty("uri"); - var uri = JsonSerializer.Deserialize(uriToken, ProtocolConversions.LspJsonSerializerOptions); + uri = JsonSerializer.Deserialize(uriToken, ProtocolConversions.LspJsonSerializerOptions); Contract.ThrowIfNull(uri, "Failed to deserialize uri property"); - var language = lspWorkspaceManager.GetLanguageForUri(uri); - Logger.LogInformation($"Using {language} from request text document"); - return language; } - - // All the LSP resolve params have the following known json structure - // { "data": { "TextDocument": { "uri": "" ... } ... } ... } - // - // We can deserialize the data object using our unified DocumentResolveData. - //var dataToken = parameters["data"]; - if (parameters.TryGetProperty("data", out var dataToken)) + else if (parameters.TryGetProperty("data", out var dataToken)) { + // All the LSP resolve params have the following known json structure + // { "data": { "TextDocument": { "uri": "" ... } ... } ... } + // + // We can deserialize the data object using our unified DocumentResolveData. + //var dataToken = parameters["data"]; var data = JsonSerializer.Deserialize(dataToken, ProtocolConversions.LspJsonSerializerOptions); Contract.ThrowIfNull(data, "Failed to document resolve data object"); - var language = lspWorkspaceManager.GetLanguageForUri(data.TextDocument.Uri); - Logger.LogInformation($"Using {language} from data text document"); - return language; + uri = data.TextDocument.Uri; + } + + if (uri == null) + { + // This request is not for a textDocument and is not a resolve request. + Logger.LogInformation("Request did not contain a textDocument, using default language handler"); + language = LanguageServerConstants.DefaultLanguageName; + return true; + } + + if (!lspWorkspaceManager.TryGetLanguageForUri(uri, out language)) + { + Logger.LogError($"Failed to get language for {uri} with language {language}"); + return false; } - // This request is not for a textDocument and is not a resolve request. - Logger.LogInformation("Request did not contain a textDocument, using default language handler"); - return LanguageServerConstants.DefaultLanguageName; + return true; static bool ShouldUseDefaultLanguage(string methodName) { diff --git a/src/LanguageServer/Protocol/Workspaces/LspMiscellaneousFilesWorkspace.cs b/src/LanguageServer/Protocol/Workspaces/LspMiscellaneousFilesWorkspace.cs index d6d9c2809bb8d..289bda83594e4 100644 --- a/src/LanguageServer/Protocol/Workspaces/LspMiscellaneousFilesWorkspace.cs +++ b/src/LanguageServer/Protocol/Workspaces/LspMiscellaneousFilesWorkspace.cs @@ -50,8 +50,7 @@ internal sealed class LspMiscellaneousFilesWorkspace(ILspServices lspServices, I } var languageInfoProvider = lspServices.GetRequiredService(); - var languageInformation = languageInfoProvider.GetLanguageInformation(uri, languageId); - if (languageInformation == null) + if (!languageInfoProvider.TryGetLanguageInformation(uri, languageId, out var languageInformation)) { // Only log here since throwing here could take down the LSP server. logger.LogError($"Could not find language information for {uri} with absolute path {documentFilePath}"); diff --git a/src/LanguageServer/Protocol/Workspaces/LspWorkspaceManager.cs b/src/LanguageServer/Protocol/Workspaces/LspWorkspaceManager.cs index 41864db53d6dc..1121ca2c28c49 100644 --- a/src/LanguageServer/Protocol/Workspaces/LspWorkspaceManager.cs +++ b/src/LanguageServer/Protocol/Workspaces/LspWorkspaceManager.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Collections.Immutable; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -532,7 +533,7 @@ private static async ValueTask AreChecksumsEqualAsync(TextDocument documen /// /// Returns a Roslyn language name for the given URI. /// - internal string GetLanguageForUri(Uri uri) + internal bool TryGetLanguageForUri(Uri uri, [NotNullWhen(true)] out string? language) { string? languageId = null; if (_trackedDocuments.TryGetValue(uri, out var trackedDocument)) @@ -540,7 +541,14 @@ internal string GetLanguageForUri(Uri uri) languageId = trackedDocument.LanguageId; } - return _languageInfoProvider.GetLanguageInformation(uri, languageId).LanguageName; + if (_languageInfoProvider.TryGetLanguageInformation(uri, languageId, out var languageInfo)) + { + language = languageInfo.LanguageName; + return true; + } + + language = null; + return false; } /// diff --git a/src/LanguageServer/ProtocolUnitTests/HandlerTests.cs b/src/LanguageServer/ProtocolUnitTests/HandlerTests.cs index d79b96023a6c2..402b78d7ab003 100644 --- a/src/LanguageServer/ProtocolUnitTests/HandlerTests.cs +++ b/src/LanguageServer/ProtocolUnitTests/HandlerTests.cs @@ -17,6 +17,7 @@ using Roslyn.Test.Utilities; using Xunit; using Xunit.Abstractions; +using static Microsoft.CodeAnalysis.LanguageServer.UnitTests.LocaleTests; namespace Microsoft.CodeAnalysis.LanguageServer.UnitTests { @@ -295,6 +296,24 @@ await Assert.ThrowsAnyAsync(async () Assert.False(didReport); } + [Theory, CombinatorialData] + public async Task TestMutatingHandlerCrashesIfUnableToDetermineLanguage(bool mutatingLspWorkspace) + { + await using var testLspServer = await CreateTestLspServerAsync(string.Empty, mutatingLspWorkspace, new InitializationOptions { ServerKind = WellKnownLspServerKinds.CSharpVisualBasicLspServer }); + + // Run a mutating request against a file which we have no saved languageId for + // and where the language cannot be determined from the URI. + // This should crash the server. + var looseFileUri = ProtocolConversions.CreateAbsoluteUri(@"untitled:untitledFile"); + var request = new TestRequestTypeOne(new TextDocumentIdentifier + { + Uri = looseFileUri + }); + + await Assert.ThrowsAnyAsync(async () => await testLspServer.ExecuteRequestAsync(TestDocumentHandler.MethodName, request, CancellationToken.None)).ConfigureAwait(false); + await testLspServer.AssertServerShuttingDownAsync(); + } + internal record TestRequestTypeOne([property: JsonPropertyName("textDocument"), JsonRequired] TextDocumentIdentifier TextDocumentIdentifier); internal record TestRequestTypeTwo([property: JsonPropertyName("textDocument"), JsonRequired] TextDocumentIdentifier TextDocumentIdentifier); diff --git a/src/LanguageServer/ProtocolUnitTests/UriTests.cs b/src/LanguageServer/ProtocolUnitTests/UriTests.cs index 8ad1df6cca445..8fe9ef3348d4b 100644 --- a/src/LanguageServer/ProtocolUnitTests/UriTests.cs +++ b/src/LanguageServer/ProtocolUnitTests/UriTests.cs @@ -295,6 +295,34 @@ await Assert.ThrowsAnyAsync(async () new CustomResolveParams(new LSP.TextDocumentIdentifier { Uri = lowerCaseUri }), CancellationToken.None)); } + [Theory, CombinatorialData] + public async Task TestDoesNotCrashIfUnableToDetermineLanguageInfo(bool mutatingLspWorkspace) + { + // Create a server that supports LSP misc files and verify no misc files present. + await using var testLspServer = await CreateTestLspServerAsync(string.Empty, mutatingLspWorkspace, new InitializationOptions { ServerKind = WellKnownLspServerKinds.CSharpVisualBasicLspServer }); + + // Open an empty loose file that hasn't been saved with a name. + var looseFileUri = ProtocolConversions.CreateAbsoluteUri(@"untitled:untitledFile"); + await testLspServer.OpenDocumentAsync(looseFileUri, "hello", languageId: "csharp").ConfigureAwait(false); + + // Verify file is added to the misc file workspace. + var (workspace, _, document) = await testLspServer.GetManager().GetLspDocumentInfoAsync(new LSP.TextDocumentIdentifier { Uri = looseFileUri }, CancellationToken.None); + Assert.True(workspace is LspMiscellaneousFilesWorkspace); + AssertEx.NotNull(document); + Assert.Equal(looseFileUri, document.GetURI()); + Assert.Equal(looseFileUri.OriginalString, document.FilePath); + + // Close the document (deleting the saved language information) + await testLspServer.CloseDocumentAsync(looseFileUri); + + // Assert that the request throws but the server does not crash. + await Assert.ThrowsAnyAsync(async () + => await testLspServer.ExecuteRequestAsync(CustomResolveHandler.MethodName, + new CustomResolveParams(new LSP.TextDocumentIdentifier { Uri = looseFileUri }), CancellationToken.None)); + Assert.False(testLspServer.GetServerAccessor().HasShutdownStarted()); + Assert.False(testLspServer.GetQueueAccessor()!.Value.IsComplete()); + } + private record class ResolvedDocumentInfo(string WorkspaceKind, string ProjectLanguage); private record class CustomResolveParams([property: JsonPropertyName("textDocument")] LSP.TextDocumentIdentifier TextDocument);