Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit amount of tokens used to calculate LCS distance #68151

Merged
merged 3 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

using System;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis.Differencing;
using Microsoft.CodeAnalysis.Test.Utilities;
using Roslyn.Test.Utilities;
Expand Down Expand Up @@ -93,16 +94,6 @@ public void ComputeDistance1()
Assert.Equal(0.67, Math.Round(distance, 2));
}

[Fact]
public void ComputeDistance2()
{
    // Old sequence holds the literals 0, 1, 2; new sequence holds the literals 1, 3.
    var oldLiterals = ImmutableArray.Create(MakeLiteral(0), MakeLiteral(1), MakeLiteral(2));
    var newLiterals = ImmutableArray.Create(MakeLiteral(1), MakeLiteral(3));

    var actual = SyntaxComparer.ComputeDistance(oldLiterals, newLiterals);

    // The distance is compared after rounding to two decimal places.
    Assert.Equal(0.67, Math.Round(actual, 2));
}

[Fact]
public void ComputeDistance3()
{
Expand All @@ -113,16 +104,6 @@ public void ComputeDistance3()
Assert.Equal(0.33, Math.Round(distance, 2));
}

[Fact]
public void ComputeDistance4()
{
    // Same three modifier keywords in both sequences; the first two are swapped in the new one.
    var oldModifiers = ImmutableArray.Create(
        SyntaxFactory.Token(SyntaxKind.PublicKeyword),
        SyntaxFactory.Token(SyntaxKind.StaticKeyword),
        SyntaxFactory.Token(SyntaxKind.AsyncKeyword));

    var newModifiers = ImmutableArray.Create(
        SyntaxFactory.Token(SyntaxKind.StaticKeyword),
        SyntaxFactory.Token(SyntaxKind.PublicKeyword),
        SyntaxFactory.Token(SyntaxKind.AsyncKeyword));

    var actual = SyntaxComparer.ComputeDistance(oldModifiers, newModifiers);

    // The distance is compared after rounding to two decimal places.
    Assert.Equal(0.33, Math.Round(actual, 2));
}

[Fact]
public void ComputeDistance_Token()
{
Expand All @@ -141,18 +122,6 @@ public void ComputeDistance_Node()
public void ComputeDistance_Null()
{
var distance = SyntaxComparer.ComputeDistance(
default,
ImmutableArray.Create(SyntaxFactory.Token(SyntaxKind.StaticKeyword)));

Assert.Equal(1, Math.Round(distance, 2));

distance = SyntaxComparer.ComputeDistance(
default,
ImmutableArray.Create(MakeLiteral(0)));

Assert.Equal(1, Math.Round(distance, 2));

distance = SyntaxComparer.ComputeDistance(
null,
Array.Empty<SyntaxNode>());

Expand All @@ -176,5 +145,20 @@ public void ComputeDistance_Null()

Assert.Equal(0, Math.Round(distance, 2));
}

// Verifies that the distance calculation only looks at a bounded-length prefix of each sequence.
[Fact]
public void ComputeDistance_LongSequences()
{
// Three distinct keyword tokens used to build the two long sequences.
var t1 = SyntaxFactory.Token(SyntaxKind.PublicKeyword);
var t2 = SyntaxFactory.Token(SyntaxKind.PrivateKeyword);
var t3 = SyntaxFactory.Token(SyntaxKind.ProtectedKeyword);

// Both sequences contain 10000 tokens: the first 2000 are identical (t1),
// the remaining 8000 differ (t2 in the old sequence, t3 in the new one).
var distance = SyntaxComparer.ComputeDistance(
Enumerable.Range(0, 10000).Select(i => i < 2000 ? t1 : t2),
Enumerable.Range(0, 10000).Select(i => i < 2000 ? t1 : t3));

// Only the first LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation tokens of each
// sequence are compared, so long sequences are indistinguishable when their common prefix is
// at least as long as that threshold:
Assert.Equal(0, distance);
}
}
}
30 changes: 9 additions & 21 deletions src/Features/CSharp/Portable/EditAndContinue/SyntaxComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Differencing;
Expand Down Expand Up @@ -1593,25 +1594,21 @@ public static double ComputeDistance(SyntaxNode? oldNode, SyntaxNode? newNode)
/// Distance is a number within [0, 1], the smaller the more similar the tokens are.
/// </remarks>
public static double ComputeDistance(SyntaxToken oldToken, SyntaxToken newToken)
=> LongestCommonSubstring.ComputeDistance(oldToken.Text, newToken.Text);
=> LongestCommonSubstring.ComputePrefixDistance(
oldToken.Text, Math.Min(oldToken.Text.Length, LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation),
newToken.Text, Math.Min(newToken.Text.Length, LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation));

/// <summary>
/// Calculates the distance between two sequences of syntax tokens, disregarding trivia.
/// </summary>
/// <remarks>
/// Distance is a number within [0, 1], the smaller the more similar the sequences are.
/// </remarks>
public static double ComputeDistance(IEnumerable<SyntaxToken>? oldTokens, IEnumerable<SyntaxToken>? newTokens)
=> LcsTokens.Instance.ComputeDistance(oldTokens.AsImmutableOrEmpty(), newTokens.AsImmutableOrEmpty());
private static ImmutableArray<T> CreateArrayForDistanceCalculation<T>(IEnumerable<T>? enumerable)
{
    // A null sequence is treated the same as an empty one.
    if (enumerable is null)
    {
        return ImmutableArray<T>.Empty;
    }

    // Keep at most MaxSequenceLengthForDistanceCalculation leading items; the rest is ignored.
    var prefix = enumerable.Take(LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation);
    return prefix.ToImmutableArray();
}

/// <summary>
/// Calculates the distance between two sequences of syntax tokens, disregarding trivia.
/// </summary>
/// <remarks>
/// Distance is a number within [0, 1], the smaller the more similar the sequences are.
/// </remarks>
public static double ComputeDistance(ImmutableArray<SyntaxToken> oldTokens, ImmutableArray<SyntaxToken> newTokens)
=> LcsTokens.Instance.ComputeDistance(oldTokens.NullToEmpty(), newTokens.NullToEmpty());
public static double ComputeDistance(IEnumerable<SyntaxToken>? oldTokens, IEnumerable<SyntaxToken>? newTokens)
{
    // Materialize a length-limited prefix of each sequence (null becomes empty) before comparing.
    var oldPrefix = CreateArrayForDistanceCalculation(oldTokens);
    var newPrefix = CreateArrayForDistanceCalculation(newTokens);

    return LcsTokens.Instance.ComputeDistance(oldPrefix, newPrefix);
}

/// <summary>
/// Calculates the distance between two sequences of syntax nodes, disregarding trivia.
Expand All @@ -1620,16 +1617,7 @@ public static double ComputeDistance(ImmutableArray<SyntaxToken> oldTokens, Immu
/// Distance is a number within [0, 1], the smaller the more similar the sequences are.
/// </remarks>
public static double ComputeDistance(IEnumerable<SyntaxNode>? oldNodes, IEnumerable<SyntaxNode>? newNodes)
=> LcsNodes.Instance.ComputeDistance(oldNodes.AsImmutableOrEmpty(), newNodes.AsImmutableOrEmpty());

/// <summary>
/// Calculates the distance between two sequences of syntax tokens, disregarding trivia.
/// </summary>
/// <remarks>
/// Distance is a number within [0, 1], the smaller the more similar the sequences are.
/// </remarks>
public static double ComputeDistance(ImmutableArray<SyntaxNode> oldNodes, ImmutableArray<SyntaxNode> newNodes)
=> LcsNodes.Instance.ComputeDistance(oldNodes.NullToEmpty(), newNodes.NullToEmpty());
=> LcsNodes.Instance.ComputeDistance(CreateArrayForDistanceCalculation(oldNodes), CreateArrayForDistanceCalculation(newNodes));

/// <summary>
/// Calculates the edits that transform one sequence of syntax nodes to another, disregarding trivia.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1389,17 +1389,13 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.EditAndContinue
''' Distance is a number within [0, 1], the smaller the more similar the tokens are.
''' </remarks>
Public Overloads Shared Function ComputeDistance(oldToken As SyntaxToken, newToken As SyntaxToken) As Double
Return LongestCommonSubstring.ComputeDistance(oldToken.ValueText, newToken.ValueText)
Return LongestCommonSubstring.ComputePrefixDistance(
oldToken.Text, Math.Min(oldToken.Text.Length, LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation),
newToken.Text, Math.Min(newToken.Text.Length, LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation))
End Function

''' <summary>
''' Calculates the distance between two sequences of syntax tokens, disregarding trivia.
''' </summary>
''' <remarks>
''' Distance is a number within [0, 1], the smaller the more similar the sequences are.
''' </remarks>
Public Overloads Shared Function ComputeDistance(oldTokens As IEnumerable(Of SyntaxToken), newTokens As IEnumerable(Of SyntaxToken)) As Double
Return ComputeDistance(oldTokens.AsImmutableOrNull(), newTokens.AsImmutableOrNull())
Private Shared Function CreateArrayForDistanceCalculation(Of T)(enumerable As IEnumerable(Of T)) As ImmutableArray(Of T)
    ' A missing sequence is treated the same as an empty one.
    If enumerable Is Nothing Then
        Return ImmutableArray(Of T).Empty
    End If

    ' Keep at most MaxSequenceLengthForDistanceCalculation leading items; the rest is ignored.
    Dim prefix = enumerable.Take(LongestCommonSubsequence.MaxSequenceLengthForDistanceCalculation)
    Return prefix.ToImmutableArray()
End Function

''' <summary>
Expand All @@ -1408,8 +1404,8 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.EditAndContinue
''' <remarks>
''' Distance is a number within [0, 1], the smaller the more similar the sequences are.
''' </remarks>
Public Overloads Shared Function ComputeDistance(oldTokens As ImmutableArray(Of SyntaxToken), newTokens As ImmutableArray(Of SyntaxToken)) As Double
Return LcsTokens.Instance.ComputeDistance(oldTokens.NullToEmpty(), newTokens.NullToEmpty())
Public Overloads Shared Function ComputeDistance(oldTokens As IEnumerable(Of SyntaxToken), newTokens As IEnumerable(Of SyntaxToken)) As Double
Return ComputeDistance(CreateArrayForDistanceCalculation(oldTokens), CreateArrayForDistanceCalculation(newTokens))
End Function

''' <summary>
Expand All @@ -1419,17 +1415,7 @@ Namespace Microsoft.CodeAnalysis.VisualBasic.EditAndContinue
''' Distance is a number within [0, 1], the smaller the more similar the sequences are.
''' </remarks>
Public Overloads Shared Function ComputeDistance(oldTokens As IEnumerable(Of SyntaxNode), newTokens As IEnumerable(Of SyntaxNode)) As Double
Return ComputeDistance(oldTokens.AsImmutableOrNull(), newTokens.AsImmutableOrNull())
End Function

''' <summary>
''' Calculates the distance between two sequences of syntax nodes, disregarding trivia.
''' </summary>
''' <remarks>
''' Distance is a number within [0, 1], the smaller the more similar the sequences are.
''' </remarks>
Public Overloads Shared Function ComputeDistance(oldTokens As ImmutableArray(Of SyntaxNode), newTokens As ImmutableArray(Of SyntaxNode)) As Double
Return LcsNodes.Instance.ComputeDistance(oldTokens.NullToEmpty(), newTokens.NullToEmpty())
Return ComputeDistance(CreateArrayForDistanceCalculation(oldTokens), CreateArrayForDistanceCalculation(newTokens))
End Function

''' <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@
using System.Collections.Generic;
using System.Diagnostics;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.Differencing
{
internal abstract class LongestCommonSubsequence
{
/// <summary>
/// Limit the number of tokens used to compute distance between sequences of tokens so that
/// we always use the pooled buffers. The combined length of the two sequences being compared
/// must be less than <see cref="VBuffer.PooledSegmentMaxDepthThreshold"/>.
/// </summary>
public const int MaxSequenceLengthForDistanceCalculation = VBuffer.PooledSegmentMaxDepthThreshold / 2;

// Define the pool in a non-generic base class to allow sharing among instantiations.
private static readonly ObjectPool<VBuffer> s_pool = new(() => new VBuffer());

Expand All @@ -40,17 +46,31 @@ protected sealed class VBuffer
/// For 150 it'd be 91KB, which would be allocated on LOH.
/// The buffers grow by factor of <see cref="GrowFactor"/>, so the next buffer will be allocated on LOH.
/// </summary>
private const int FirstBufferMaxDepth = 100;
public const int FirstSegmentMaxDepth = 100;

// 3 + Sum { d = 1..maxDepth : 2*d+1 } = (maxDepth + 1)^2 + 2
private const int FirstBufferLength = (FirstBufferMaxDepth + 1) * (FirstBufferMaxDepth + 1) + 2;

internal const int GrowFactor = 2;
private const int FirstSegmentLength = (FirstSegmentMaxDepth + 1) * (FirstSegmentMaxDepth + 1) + 2;

// Segment Segment Total buffer
// MaxDepth length length
// ---------------------------------------
// 100 10,204 10,204
// 150 12,600 22,804
// 225 28,275 51,079
// 338 63,845 114,924
// 507 143,143 258,067
// 761 322,580 580,647
// 1142 725,805 1,306,452
// 1713 1,631,347 2,937,799 <-- last pooled segment
// 2570 3,672,245 6,610,044
// 3855 8,258,695 14,868,739
internal const double GrowFactor = 1.5;

/// <summary>
/// Do not pool segments that are too large.
/// Do not expand pooled buffers to more than ~12 MB total size (sum of all linked segment sizes).
/// This threshold is achieved when <see cref="MaxDepth"/> is greater than <see cref="PooledSegmentMaxDepthThreshold"/> = sqrt(size_limit / sizeof(int)).
/// </summary>
internal const int MaxPooledBufferSize = 1024 * 1024;
internal const int PooledSegmentMaxDepthThreshold = 1800;

public VBuffer Previous { get; private set; }
public VBuffer Next { get; private set; }
Expand All @@ -62,22 +82,22 @@ protected sealed class VBuffer

public VBuffer()
{
_array = new int[FirstBufferLength];
MaxDepth = FirstBufferMaxDepth;
_array = new int[FirstSegmentLength];
MaxDepth = FirstSegmentMaxDepth;
}

public VBuffer(VBuffer previous)
{
Debug.Assert(previous != null);

var minDepth = previous.MaxDepth + 1;
var maxDepth = previous.MaxDepth * GrowFactor;
var maxDepth = (int)(previous.MaxDepth * GrowFactor);

Debug.Assert(minDepth > 0);
Debug.Assert(minDepth <= maxDepth);

Previous = previous;
_array = new int[GetNextBufferLength(minDepth - 1, maxDepth)];
_array = new int[GetNextSegmentLength(minDepth - 1, maxDepth)];
MinDepth = minDepth;
MaxDepth = maxDepth;

Expand All @@ -95,7 +115,7 @@ public VArray GetVArray(int depth)
}

public bool IsTooLargeToPool
=> _array.Length > MaxPooledBufferSize;
=> MaxDepth > PooledSegmentMaxDepthThreshold;

private static int GetVArrayLength(int depth)
{
    // Depths below 1 are sized as if they were depth 1, so the result is never less than 3.
    var effectiveDepth = Math.Max(depth, 1);

    // 2 * depth + 1 entries for the given depth.
    return 2 * effectiveDepth + 1;
}
Expand All @@ -105,7 +125,7 @@ private static int GetVArrayStart(int depth)
=> (depth == 0) ? 0 : depth * depth + 2;

// Sum { d = precedingBufferMaxDepth..maxDepth : 2*d+1 } = (maxDepth + 1)^2 - precedingBufferMaxDepth^2
private static int GetNextBufferLength(int precedingBufferMaxDepth, int maxDepth)
private static int GetNextSegmentLength(int precedingBufferMaxDepth, int maxDepth)
=> (maxDepth + 1) * (maxDepth + 1) - precedingBufferMaxDepth * precedingBufferMaxDepth;

public void Unlink()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ private LongestCommonSubstring()
protected override bool ItemsEqual(string oldSequence, int oldIndex, string newSequence, int newIndex)
{
    // Items are the individual characters of the two strings, compared by value.
    var oldChar = oldSequence[oldIndex];
    var newChar = newSequence[newIndex];
    return oldChar == newChar;
}

public static double ComputeDistance(string oldValue, string newValue)
=> s_instance.ComputeDistance(oldValue, oldValue.Length, newValue, newValue.Length);
public static double ComputePrefixDistance(string oldValue, int oldLength, string newValue, int newLength)
=> s_instance.ComputeDistance(oldValue, oldLength, newValue, newLength);

public static IEnumerable<SequenceEdit> GetEdits(string oldValue, string newValue)
{
    // Edits are computed over the full length of both strings.
    var oldLength = oldValue.Length;
    var newLength = newValue.Length;

    return s_instance.GetEdits(oldValue, oldLength, newValue, newLength);
}
Expand Down