Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add struct base enumerator for interval trees. #73877

Merged
merged 19 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ namespace Microsoft.CodeAnalysis.Shared.Collections;
/// Generic function representing the type of interval testing operation that can be performed on an interval tree. For
/// example checking if an interval 'contains', 'intersects', or 'overlaps' with a requested span.
/// </summary>
internal delegate bool TestInterval<T, TIntrospector>(T value, int start, int length, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>;
internal interface IIntervalTester<T, TIntrospector>
CyrusNajmabadi marked this conversation as resolved.
Show resolved Hide resolved
where TIntrospector : struct, IIntervalIntrospector<T>
{
bool Test(T value, int start, int length, in TIntrospector introspector);
}

/// <summary>
/// Base interface all interval trees need to implement to get full functionality. Callers are not expected to use
Expand All @@ -26,22 +29,24 @@ internal interface IIntervalTree<T> : IEnumerable<T>
{
/// <summary>
/// Adds all intervals within the tree within the given start/length pair that match the given <paramref
/// name="testInterval"/> predicate. Results are added to the <paramref name="builder"/> array. The <paramref
/// name="intervalTester"/> predicate. Results are added to the <paramref name="builder"/> array. The <paramref
/// name="stopAfterFirst"/> indicates if the search should stop after the first interval is found. Results will be
/// returned in a sorted order based on the start point of the interval.
/// </summary>
/// <returns>The number of matching intervals found by the method.</returns>
int FillWithIntervalsThatMatch<TIntrospector>(
int start, int length, TestInterval<T, TIntrospector> testInterval,
ref TemporaryArray<T> builder, in TIntrospector introspector,
int FillWithIntervalsThatMatch<TIntrospector, TIntervalTester>(
int start, int length, ref TemporaryArray<T> builder,
in TIntrospector introspector, in TIntervalTester intervalTester,
bool stopAfterFirst)
where TIntrospector : struct, IIntervalIntrospector<T>;
where TIntrospector : struct, IIntervalIntrospector<T>
where TIntervalTester : struct, IIntervalTester<T, TIntrospector>;

/// <summary>
/// Practically equivalent to <see cref="FillWithIntervalsThatMatch{TIntrospector}"/> with a check that at least one
/// item was found. However, separated out as a separate method as implementations can often be more efficient just
/// Practically equivalent to <see cref="FillWithIntervalsThatMatch"/> with a check that at least one item was
/// found. However, separated out as a separate method as implementations can often be more efficient just
/// answering this question, versus the more complex "fill with intervals" question above.
/// </summary>
bool Any<TIntrospector>(int start, int length, TestInterval<T, TIntrospector> testInterval, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>;
bool Any<TIntrospector, TIntervalTester>(int start, int length, in TIntrospector introspector, in TIntervalTester intervalTester)
where TIntrospector : struct, IIntervalIntrospector<T>
where TIntervalTester : struct, IIntervalTester<T, TIntrospector>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,28 +261,31 @@ private static int GetLeftChildIndex(int nodeIndex)
private static int GetRightChildIndex(int nodeIndex)
=> (2 * nodeIndex) + 2;

bool IIntervalTree<T>.Any<TIntrospector>(int start, int length, TestInterval<T, TIntrospector> testInterval, in TIntrospector introspector)
=> IntervalTreeHelpers<T, ImmutableIntervalTree<T>, /*TNode*/ int, FlatArrayIntervalTreeWitness>.Any(this, start, length, testInterval, in introspector);
bool IIntervalTree<T>.Any<TIntrospector, TIntervalTester>(int start, int length, in TIntrospector introspector, in TIntervalTester intervalTester)
=> IntervalTreeHelpers<T, ImmutableIntervalTree<T>, /*TNode*/ int, FlatArrayIntervalTreeWitness>.Any(this, start, length, introspector, intervalTester);

int IIntervalTree<T>.FillWithIntervalsThatMatch<TIntrospector>(
int start, int length, TestInterval<T, TIntrospector> testInterval,
ref TemporaryArray<T> builder, in TIntrospector introspector,
int IIntervalTree<T>.FillWithIntervalsThatMatch<TIntrospector, TIntervalTester>(
int start, int length, ref TemporaryArray<T> builder,
in TIntrospector introspector, in TIntervalTester intervalTester,
bool stopAfterFirst)
{
return IntervalTreeHelpers<T, ImmutableIntervalTree<T>, /*TNode*/ int, FlatArrayIntervalTreeWitness>.FillWithIntervalsThatMatch(
this, start, length, testInterval, ref builder, in introspector, stopAfterFirst);
this, start, length, ref builder, in introspector, in intervalTester, stopAfterFirst);
}

IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();

public IEnumerator<T> GetEnumerator()
IEnumerator<T> IEnumerable<T>.GetEnumerator()
=> GetEnumerator();

public IntervalTreeHelpers<T, ImmutableIntervalTree<T>, /*TNode*/ int, FlatArrayIntervalTreeWitness>.Enumerator GetEnumerator()
=> IntervalTreeHelpers<T, ImmutableIntervalTree<T>, /*TNode*/ int, FlatArrayIntervalTreeWitness>.GetEnumerator(this);

/// <summary>
/// Wrapper type to allow the IntervalTreeHelpers type to work with this type.
/// </summary>
private readonly struct FlatArrayIntervalTreeWitness : IIntervalTreeWitness<T, ImmutableIntervalTree<T>, int>
internal readonly struct FlatArrayIntervalTreeWitness : IIntervalTreeWitness<T, ImmutableIntervalTree<T>, int>
{
public T GetValue(ImmutableIntervalTree<T> tree, int node)
=> tree._array[node].Value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,55 +15,56 @@ namespace Microsoft.CodeAnalysis.Shared.Collections;
/// </summary>
internal readonly struct IntervalTreeAlgorithms<T, TIntervalTree>(TIntervalTree tree) where TIntervalTree : IIntervalTree<T>
{
public ImmutableArray<T> GetIntervalsThatMatch<TIntrospector>(
int start, int length, TestInterval<T, TIntrospector> testInterval, in TIntrospector introspector)
public ImmutableArray<T> GetIntervalsThatMatch<TIntrospector, TIntervalTester>(
int start, int length, in TIntrospector introspector, in TIntervalTester intervalTester)
where TIntrospector : struct, IIntervalIntrospector<T>
where TIntervalTester : struct, IIntervalTester<T, TIntrospector>
{
using var result = TemporaryArray<T>.Empty;
tree.FillWithIntervalsThatMatch(start, length, testInterval, ref result.AsRef(), in introspector, stopAfterFirst: false);
tree.FillWithIntervalsThatMatch(start, length, ref result.AsRef(), in introspector, in intervalTester, stopAfterFirst: false);
return result.ToImmutableAndClear();
}

public ImmutableArray<T> GetIntervalsThatOverlapWith<TIntrospector>(
int start, int length, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>
{
return GetIntervalsThatMatch(start, length, Tests<TIntrospector>.OverlapsWithTest, in introspector);
return GetIntervalsThatMatch(start, length, in introspector, default(Tests<TIntrospector>.OverlapsWithIntervalTester));
}

public ImmutableArray<T> GetIntervalsThatIntersectWith<TIntrospector>(
int start, int length, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>
{
return GetIntervalsThatMatch(start, length, Tests<TIntrospector>.IntersectsWithTest, in introspector);
return GetIntervalsThatMatch(start, length, in introspector, default(Tests<TIntrospector>.IntersectsWithIntervalTester));
}

public ImmutableArray<T> GetIntervalsThatContain<TIntrospector>(
int start, int length, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>
{
return GetIntervalsThatMatch(start, length, Tests<TIntrospector>.ContainsTest, in introspector);
return GetIntervalsThatMatch(start, length, in introspector, default(Tests<TIntrospector>.ContainsIntervalTester));
}

public void FillWithIntervalsThatOverlapWith<TIntrospector>(
int start, int length, ref TemporaryArray<T> builder, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>
{
tree.FillWithIntervalsThatMatch(start, length, Tests<TIntrospector>.OverlapsWithTest, ref builder, in introspector, stopAfterFirst: false);
tree.FillWithIntervalsThatMatch(start, length, ref builder, in introspector, default(Tests<TIntrospector>.OverlapsWithIntervalTester),stopAfterFirst: false);
}

public void FillWithIntervalsThatIntersectWith<TIntrospector>(
int start, int length, ref TemporaryArray<T> builder, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>
{
tree.FillWithIntervalsThatMatch(start, length, Tests<TIntrospector>.IntersectsWithTest, ref builder, in introspector, stopAfterFirst: false);
tree.FillWithIntervalsThatMatch(start, length, ref builder, in introspector, default(Tests<TIntrospector>.IntersectsWithIntervalTester), stopAfterFirst: false);
}

public void FillWithIntervalsThatContain<TIntrospector>(
int start, int length, ref TemporaryArray<T> builder, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>
{
tree.FillWithIntervalsThatMatch(start, length, Tests<TIntrospector>.ContainsTest, ref builder, in introspector, stopAfterFirst: false);
tree.FillWithIntervalsThatMatch(start, length, ref builder, in introspector, default(Tests<TIntrospector>.ContainsIntervalTester), stopAfterFirst: false);
}

public bool HasIntervalThatIntersectsWith<TIntrospector>(
Expand All @@ -77,21 +78,21 @@ public bool HasIntervalThatIntersectsWith<TIntrospector>(
int start, int length, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>
{
return tree.Any(start, length, Tests<TIntrospector>.IntersectsWithTest, in introspector);
return tree.Any(start, length, in introspector, default(Tests<TIntrospector>.IntersectsWithIntervalTester));
}

public bool HasIntervalThatOverlapsWith<TIntrospector>(
int start, int length, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>
{
return tree.Any(start, length, Tests<TIntrospector>.OverlapsWithTest, in introspector);
return tree.Any(start, length, in introspector, default(Tests<TIntrospector>.OverlapsWithIntervalTester));
}

public bool HasIntervalThatContains<TIntrospector>(
int start, int length, in TIntrospector introspector)
where TIntrospector : struct, IIntervalIntrospector<T>
{
return tree.Any(start, length, Tests<TIntrospector>.ContainsTest, in introspector);
return tree.Any(start, length, in introspector, default(Tests<TIntrospector>.ContainsIntervalTester));
}

public static bool Contains<TIntrospector>(T value, int start, int length, in TIntrospector introspector)
Expand Down Expand Up @@ -151,8 +152,22 @@ private static bool OverlapsWith<TIntrospector>(T value, int start, int length,
private static class Tests<TIntrospector>
where TIntrospector : struct, IIntervalIntrospector<T>
{
public static readonly TestInterval<T, TIntrospector> ContainsTest = Contains;
public static readonly TestInterval<T, TIntrospector> IntersectsWithTest = IntersectsWith;
public static readonly TestInterval<T, TIntrospector> OverlapsWithTest = OverlapsWith;
public readonly struct ContainsIntervalTester : IIntervalTester<T, TIntrospector>
{
public bool Test(T value, int start, int length, in TIntrospector introspector)
=> Contains(value, start, length, in introspector);
}

public readonly struct IntersectsWithIntervalTester : IIntervalTester<T, TIntrospector>
{
public bool Test(T value, int start, int length, in TIntrospector introspector)
=> IntersectsWith(value, start, length, in introspector);
}

public readonly struct OverlapsWithIntervalTester : IIntervalTester<T, TIntrospector>
{
public bool Test(T value, int start, int length, in TIntrospector introspector)
=> OverlapsWith(value, start, length, in introspector);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections;
using System.Collections.Generic;
using Microsoft.CodeAnalysis.Text;

namespace Microsoft.CodeAnalysis.Shared.Collections;

internal static partial class IntervalTreeHelpers<T, TIntervalTree, TNode, TIntervalTreeWitness>
where TIntervalTree : IIntervalTree<T>
where TIntervalTreeWitness : struct, IIntervalTreeWitness<T, TIntervalTree, TNode>
{
public struct Enumerator(TIntervalTree tree) : IEnumerator<T>
{
/// <summary>
/// An introspector that always throws. Used when we need to call an api that takes this, but we know will never
/// call into it due to other arguments we pass along.
/// </summary>
private readonly struct AlwaysThrowIntrospector : IIntervalIntrospector<T>
{
public TextSpan GetSpan(T value) => throw new System.NotImplementedException();
}

/// <summary>
/// Because we're passing the full span of all ints, we know that we'll never call into the introspector. Since
/// all intervals will always be in that span.
/// </summary>
private NodeEnumerator<AlwaysThrowIntrospector> _nodeEnumerator = new(tree, start: int.MinValue, end: int.MaxValue, default);

public readonly T Current => default(TIntervalTreeWitness).GetValue(tree, _nodeEnumerator.Current);

readonly object IEnumerator.Current => this.Current!;

public bool MoveNext() => _nodeEnumerator.MoveNext();
public readonly void Reset() => _nodeEnumerator.Reset();
public readonly void Dispose() => _nodeEnumerator.Dispose();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections;
using System.Collections.Generic;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.Shared.Collections;

internal static partial class IntervalTreeHelpers<T, TIntervalTree, TNode, TIntervalTreeWitness>
where TIntervalTree : IIntervalTree<T>
where TIntervalTreeWitness : struct, IIntervalTreeWitness<T, TIntervalTree, TNode>
{
/// <summary>
/// Struct based enumerator, so we can iterate an interval tree without allocating.
/// </summary>
private struct NodeEnumerator<TIntrospector> : IEnumerator<TNode>
where TIntrospector : struct, IIntervalIntrospector<T>
{
private readonly TIntervalTree _tree;
private readonly TIntrospector _introspector;
private readonly int _start;
private readonly int _end;

private readonly PooledObject<Stack<TNode>> _pooledStack;
private readonly Stack<TNode>? _stack;

private bool _started;
private TNode? _currentNode;
private bool _currentNodeHasValue;

public NodeEnumerator(TIntervalTree tree, int start, int end, in TIntrospector introspector)
{
_tree = tree;
_start = start;
_end = end;
_introspector = introspector;

_currentNodeHasValue = default(TIntervalTreeWitness).TryGetRoot(_tree, out _currentNode);

// Avoid any pooling work if we don't even have a root.
if (_currentNodeHasValue)
{
_pooledStack = s_nodeStackPool.GetPooledObject();
_stack = _pooledStack.Object;
}
}

readonly object IEnumerator.Current => this.Current!;

public readonly TNode Current => _currentNode!;

public bool MoveNext()
{
// Trivial empty case
if (_stack is null)
return false;

// The first time through, we just want to start processing with the root node. Every other time through,
// after we've yielded the current element, we want to walk down the right side of it.
if (_started)
_currentNodeHasValue = ShouldExamineRight(_tree, _start, _end, _currentNode!, _introspector, out _currentNode);

// After we're called once, we're in the started point.
_started = true;

if (!_currentNodeHasValue && _stack.Count <= 0)
return false;

// Traverse all the way down the left side of the tree, pushing nodes onto the stack as we go.
while (_currentNodeHasValue)
{
_stack.Push(_currentNode!);
_currentNodeHasValue = ShouldExamineLeft(_tree, _start, _currentNode!, _introspector, out _currentNode);
}

Contract.ThrowIfTrue(_currentNodeHasValue);
Contract.ThrowIfTrue(_stack.Count == 0);
_currentNode = _stack.Pop();
return true;
}

public readonly void Dispose()
=> _pooledStack.Dispose();
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is fine on an unintialized _pooledStack value. Dispose for it checks if it actually has an object from the pool, and no-ops if not.


public readonly void Reset()
=> throw new System.NotImplementedException();
}
}
Loading
Loading