Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not crash server if we fail to determine language for non-mutating request #75509

Merged
merged 4 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
using System.Threading;
Expand Down Expand Up @@ -172,10 +173,11 @@ protected IRequestExecutionQueue<TRequestContext> GetRequestExecutionQueue()
return _queue.Value;
}

public virtual string GetLanguageForRequest(string methodName, object? serializedRequest)
public virtual bool TryGetLanguageForRequest(string methodName, object? serializedRequest, [NotNullWhen(true)] out string? language)
{
Logger.LogInformation($"Using default language handler for {methodName}");
return LanguageServerConstants.DefaultLanguageName;
language = LanguageServerConstants.DefaultLanguageName;
return true;
}

protected abstract DelegatingEntryPoint CreateDelegatingEntryPoint(string method);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Threading;
using System.Threading.Tasks;

Expand Down Expand Up @@ -37,6 +38,12 @@ internal interface IQueueItem<TRequestContext>
/// </summary>
Task<(TRequestContext, TRequest)?> CreateRequestContextAsync<TRequest>(IMethodHandler handler, RequestHandlerMetadata requestHandlerMetadata, AbstractLanguageServer<TRequestContext> languageServer, CancellationToken cancellationToken);

/// <summary>
/// Handles when the queue needs to manually fail a request before the
/// handler is invoked without shutting down the entire queue.
/// </summary>
void FailRequest(string message);

/// <summary>
/// Provides access to LSP services.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
#nullable enable

using System;
using System.Collections.Frozen;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.Threading;
Expand All @@ -30,6 +26,13 @@ internal class QueueItem<TRequestContext> : IQueueItem<TRequestContext>
private readonly ILspLogger _logger;
private readonly AbstractRequestScope? _requestTelemetryScope;

/// <summary>
/// True if this queue item has actually started handling the request
/// by delegating to the handler. False while the item is still being
/// processed by the queue.
/// </summary>
private bool _requestHandlingStarted = false;
dibarbet marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// A task completion source representing the result of this queue item's work.
/// This is the task that the client is waiting on.
Expand Down Expand Up @@ -158,6 +161,7 @@ private bool TryDeserializeRequest<TRequest>(
/// </summary>
public async Task StartRequestAsync<TRequest, TResponse>(TRequest request, TRequestContext? context, IMethodHandler handler, string language, CancellationToken cancellationToken)
{
_requestHandlingStarted = true;
_logger.LogStartContext($"{MethodName}");

try
Expand Down Expand Up @@ -240,4 +244,20 @@ public async Task StartRequestAsync<TRequest, TResponse>(TRequest request, TRequ
// so it can decide how to handle the result / exception.
await _completionSource.Task.ConfigureAwait(false);
}

public void FailRequest(string message)
{
// This is not valid to call after StartRequestAsync starts as they both access the same state.
// StartRequestAsync handles any failures internally once it runs.
if (_requestHandlingStarted)
{
throw new InvalidOperationException("Cannot manually fail queue item after it has started");
dibarbet marked this conversation as resolved.
Show resolved Hide resolved
}
var exception = new Exception(message);
_requestTelemetryScope?.RecordException(exception);
_logger.LogException(exception);

_completionSource.TrySetException(exception);
_requestTelemetryScope?.Dispose();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Collections.Frozen;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
using System.Threading;
Expand Down Expand Up @@ -245,10 +246,27 @@ private async Task ProcessQueueAsync()
// notifications have been completed by the time we attempt to determine the language, so we have the up to date map of URI to language.
// Since didOpen notifications are marked as mutating, the queue will not advance to the next request until the server has finished processing
// the didOpen, ensuring that this line will only run once all prior didOpens have completed.
var language = _languageServer.GetLanguageForRequest(work.MethodName, work.SerializedRequest);
var didGetLanguage = _languageServer.TryGetLanguageForRequest(work.MethodName, work.SerializedRequest, out var language);

// Now that we know the actual language, we can deserialize the request and start creating the request context.
var (metadata, handler, methodInfo) = GetHandlerForRequest(work, language);
var (metadata, handler, methodInfo) = GetHandlerForRequest(work, language ?? LanguageServerConstants.DefaultLanguageName);

// We had an issue determining the language. Generally this is very rare and only occurs
// if there is a mis-behaving client that sends us requests for files where we haven't saved the languageId.
// We should only crash if this was a mutating method, otherwise we should just fail the single request.
if (!didGetLanguage)
{
var message = $"Failed to get language for {work.MethodName}";
if (handler.MutatesSolutionState)
{
throw new InvalidOperationException(message);
}
else
{
work.FailRequest(message);
return;
}
}

// We now have the actual handler and language, so we can process the work item using the concrete types defined by the metadata.
await InvokeProcessCoreAsync(work, metadata, handler, methodInfo, concurrentlyExecutingTasks, currentWorkCts, cancellationToken).ConfigureAwait(false);
Expand Down
4 changes: 2 additions & 2 deletions src/LanguageServer/Protocol/ILanguageInfoProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.Features.Workspaces;

namespace Microsoft.CodeAnalysis.LanguageServer
Expand All @@ -19,7 +20,6 @@ internal interface ILanguageInfoProvider : ILspService
/// It is totally possible to not find language based on the file path (e.g. a newly created file that hasn't been saved to disk).
/// In that case, we use the language Id that the LSP client gave us.
/// </remarks>
/// <exception cref="InvalidOperationException">Thrown when the language information cannot be determined.</exception>
LanguageInformation GetLanguageInformation(Uri documentUri, string? lspLanguageId);
bool TryGetLanguageInformation(Uri uri, string? lspLanguageId, [NotNullWhen(true)] out LanguageInformation? languageInformation);
}
}
13 changes: 8 additions & 5 deletions src/LanguageServer/Protocol/LanguageInfoProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using Microsoft.CodeAnalysis.Features.Workspaces;

Expand Down Expand Up @@ -43,22 +44,22 @@ internal class LanguageInfoProvider : ILanguageInfoProvider
{ ".mts", s_typeScriptLanguageInformation },
};

public LanguageInformation GetLanguageInformation(Uri uri, string? lspLanguageId)
public bool TryGetLanguageInformation(Uri uri, string? lspLanguageId, [NotNullWhen(true)] out LanguageInformation? languageInformation)
{
// First try to get language information from the URI path.
// We can do this for File uris and absolute uris. We use local path to get the value without any query parameters.
if (uri.IsFile || uri.IsAbsoluteUri)
{
var localPath = uri.LocalPath;
var extension = Path.GetExtension(localPath);
if (s_extensionToLanguageInformation.TryGetValue(extension, out var languageInformation))
if (s_extensionToLanguageInformation.TryGetValue(extension, out languageInformation))
{
return languageInformation;
return true;
}
}

// If the URI file path mapping failed, use the languageId from the LSP client (if any).
return lspLanguageId switch
languageInformation = lspLanguageId switch
{
"csharp" => s_csharpLanguageInformation,
"fsharp" => s_fsharpLanguageInformation,
Expand All @@ -67,8 +68,10 @@ public LanguageInformation GetLanguageInformation(Uri uri, string? lspLanguageId
"xaml" => s_xamlLanguageInformation,
"typescript" => s_typeScriptLanguageInformation,
"javascript" => s_typeScriptLanguageInformation,
_ => throw new InvalidOperationException($"Unable to determine language for '{uri}' with LSP language id '{lspLanguageId}'")
_ => null,
};

return languageInformation != null;
}
}
}
50 changes: 30 additions & 20 deletions src/LanguageServer/Protocol/RoslynLanguageServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Frozen;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -164,12 +165,13 @@ public Task OnInitializedAsync(ClientCapabilities clientCapabilities, RequestCon
return Task.CompletedTask;
}

public override string GetLanguageForRequest(string methodName, object? serializedParameters)
public override bool TryGetLanguageForRequest(string methodName, object? serializedParameters, [NotNullWhen(true)] out string? language)
{
if (serializedParameters == null)
{
Logger.LogInformation("No request parameters given, using default language handler");
return LanguageServerConstants.DefaultLanguageName;
language = LanguageServerConstants.DefaultLanguageName;
return true;
}

// We implement the STJ language server so this must be a JsonElement.
Expand All @@ -179,7 +181,8 @@ public override string GetLanguageForRequest(string methodName, object? serializ
// as we do not want languages to be able to override them.
if (ShouldUseDefaultLanguage(methodName))
{
return LanguageServerConstants.DefaultLanguageName;
language = LanguageServerConstants.DefaultLanguageName;
return true;
}

var lspWorkspaceManager = GetLspServices().GetRequiredService<LspWorkspaceManager>();
Expand All @@ -188,34 +191,41 @@ public override string GetLanguageForRequest(string methodName, object? serializ
// { "textDocument": { "uri": "<uri>" ... } ... }
//
// We can easily identify the URI for the request by looking for this structure
Uri? uri = null;
if (parameters.TryGetProperty("textDocument", out var textDocumentToken) ||
parameters.TryGetProperty("_vs_textDocument", out textDocumentToken))
{
var uriToken = textDocumentToken.GetProperty("uri");
var uri = JsonSerializer.Deserialize<Uri>(uriToken, ProtocolConversions.LspJsonSerializerOptions);
uri = JsonSerializer.Deserialize<Uri>(uriToken, ProtocolConversions.LspJsonSerializerOptions);
Contract.ThrowIfNull(uri, "Failed to deserialize uri property");
var language = lspWorkspaceManager.GetLanguageForUri(uri);
Logger.LogInformation($"Using {language} from request text document");
return language;
}

// All the LSP resolve params have the following known json structure
// { "data": { "TextDocument": { "uri": "<uri>" ... } ... } ... }
//
// We can deserialize the data object using our unified DocumentResolveData.
//var dataToken = parameters["data"];
if (parameters.TryGetProperty("data", out var dataToken))
else if (parameters.TryGetProperty("data", out var dataToken))
{
// All the LSP resolve params have the following known json structure
// { "data": { "TextDocument": { "uri": "<uri>" ... } ... } ... }
//
// We can deserialize the data object using our unified DocumentResolveData.
//var dataToken = parameters["data"];
var data = JsonSerializer.Deserialize<DocumentResolveData>(dataToken, ProtocolConversions.LspJsonSerializerOptions);
Contract.ThrowIfNull(data, "Failed to document resolve data object");
var language = lspWorkspaceManager.GetLanguageForUri(data.TextDocument.Uri);
Logger.LogInformation($"Using {language} from data text document");
return language;
uri = data.TextDocument.Uri;
}

if (uri == null)
{
// This request is not for a textDocument and is not a resolve request.
Logger.LogInformation("Request did not contain a textDocument, using default language handler");
language = LanguageServerConstants.DefaultLanguageName;
return true;
}

if (!lspWorkspaceManager.TryGetLanguageForUri(uri, out language))
{
Logger.LogError($"Failed to get language for {uri} with language {language}");
return false;
}

// This request is not for a textDocument and is not a resolve request.
Logger.LogInformation("Request did not contain a textDocument, using default language handler");
return LanguageServerConstants.DefaultLanguageName;
return true;

static bool ShouldUseDefaultLanguage(string methodName)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ internal sealed class LspMiscellaneousFilesWorkspace(ILspServices lspServices, I
}

var languageInfoProvider = lspServices.GetRequiredService<ILanguageInfoProvider>();
var languageInformation = languageInfoProvider.GetLanguageInformation(uri, languageId);
if (languageInformation == null)
if (!languageInfoProvider.TryGetLanguageInformation(uri, languageId, out var languageInformation))
{
// Only log here since throwing here could take down the LSP server.
logger.LogError($"Could not find language information for {uri} with absolute path {documentFilePath}");
Expand Down
12 changes: 10 additions & 2 deletions src/LanguageServer/Protocol/Workspaces/LspWorkspaceManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -532,15 +533,22 @@ private static async ValueTask<bool> AreChecksumsEqualAsync(TextDocument documen
/// <summary>
/// Returns a Roslyn language name for the given URI.
/// </summary>
internal string GetLanguageForUri(Uri uri)
internal bool TryGetLanguageForUri(Uri uri, [NotNullWhen(true)] out string? language)
{
string? languageId = null;
if (_trackedDocuments.TryGetValue(uri, out var trackedDocument))
{
languageId = trackedDocument.LanguageId;
}

return _languageInfoProvider.GetLanguageInformation(uri, languageId).LanguageName;
if (_languageInfoProvider.TryGetLanguageInformation(uri, languageId, out var languageInfo))
{
language = languageInfo.LanguageName;
return true;
}

language = null;
return false;
}

/// <summary>
Expand Down
19 changes: 19 additions & 0 deletions src/LanguageServer/ProtocolUnitTests/HandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using Roslyn.Test.Utilities;
using Xunit;
using Xunit.Abstractions;
using static Microsoft.CodeAnalysis.LanguageServer.UnitTests.LocaleTests;

namespace Microsoft.CodeAnalysis.LanguageServer.UnitTests
{
Expand Down Expand Up @@ -295,6 +296,24 @@ await Assert.ThrowsAnyAsync<Exception>(async ()
Assert.False(didReport);
}

[Theory, CombinatorialData]
public async Task TestMutatingHandlerCrashesIfUnableToDetermineLanguage(bool mutatingLspWorkspace)
{
await using var testLspServer = await CreateTestLspServerAsync(string.Empty, mutatingLspWorkspace, new InitializationOptions { ServerKind = WellKnownLspServerKinds.CSharpVisualBasicLspServer });

// Run a mutating request against a file which we have no saved languageId for
// and where the language cannot be determined from the URI.
// This should crash the server.
var looseFileUri = ProtocolConversions.CreateAbsoluteUri(@"untitled:untitledFile");
var request = new TestRequestTypeOne(new TextDocumentIdentifier
{
Uri = looseFileUri
});

await Assert.ThrowsAnyAsync<Exception>(async () => await testLspServer.ExecuteRequestAsync<TestRequestTypeOne, string>(TestDocumentHandler.MethodName, request, CancellationToken.None)).ConfigureAwait(false);
await testLspServer.AssertServerShuttingDownAsync();
}

internal record TestRequestTypeOne([property: JsonPropertyName("textDocument"), JsonRequired] TextDocumentIdentifier TextDocumentIdentifier);

internal record TestRequestTypeTwo([property: JsonPropertyName("textDocument"), JsonRequired] TextDocumentIdentifier TextDocumentIdentifier);
Expand Down
Loading
Loading