Unroll SequenceEqual(ref byte, ref byte, nuint) in JIT #83945

Merged · 15 commits · Mar 29, 2023
Changes from 13 commits
8 changes: 8 additions & 0 deletions src/coreclr/jit/importercalls.cpp
@@ -3782,6 +3782,7 @@ GenTree* Compiler::impIntrinsic(GenTree* newobjThis,
break;
}

case NI_System_SpanHelpers_SequenceEqual:
case NI_System_Buffer_Memmove:
{
// We'll try to unroll this in lower for constant input.
@@ -8139,6 +8140,13 @@ NamedIntrinsic Compiler::lookupNamedIntrinsic(CORINFO_METHOD_HANDLE method)
result = NI_System_Span_get_Length;
}
}
else if (strcmp(className, "SpanHelpers") == 0)
{
if (strcmp(methodName, "SequenceEqual") == 0)
Contributor:

I know this is completely unrelated to the fix or changes being done, sorry... but has anyone ever tried reversing the methodName and className tests for a performance hack in the JIT itself?

Member Author:

@IDisposable lookupNamedIntrinsic never shows up in our JIT traces, so we don't bother. This code is only executed for methods with the [Intrinsic] attribute, so for 99% of methods it doesn't kick in.

We could use a trie/binary search here if it were a real problem.

{
result = NI_System_SpanHelpers_SequenceEqual;
}
}
else if (strcmp(className, "String") == 0)
{
if (strcmp(methodName, "Equals") == 0)
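(Illustration, not part of the PR: the trie/binary-search alternative mentioned in the thread above could look roughly like the C# sketch below. The table contents and names are hypothetical; the JIT's actual lookup is a chain of strcmp calls in C++.)

using System;

internal static class IntrinsicLookup
{
    // Hypothetical name table; must stay sorted in ordinal order
    // for Array.BinarySearch to be valid.
    private static readonly string[] s_spanHelpersIntrinsics =
    {
        "IndexOf", "Memmove", "SequenceEqual",
    };

    // O(log n) comparisons per lookup instead of one strcmp per candidate.
    internal static bool IsIntrinsic(string methodName) =>
        Array.BinarySearch(s_spanHelpersIntrinsics, methodName, StringComparer.Ordinal) >= 0;
}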
202 changes: 194 additions & 8 deletions src/coreclr/jit/lower.cpp
@@ -1865,6 +1865,185 @@ GenTree* Lowering::LowerCallMemmove(GenTreeCall* call)
return nullptr;
}

//------------------------------------------------------------------------
// LowerCallMemcmp: Replace SpanHelpers.SequenceEqual(left, right, CNS_SIZE)
// with a series of merged comparisons (via GT_IND nodes)
//
// Arguments:
// tree - GenTreeCall node to unroll as memcmp
//
// Return Value:
// nullptr if no changes were made
//
GenTree* Lowering::LowerCallMemcmp(GenTreeCall* call)
{
JITDUMP("Considering Memcmp [%06d] for unrolling.. ", comp->dspTreeID(call))
assert(comp->lookupNamedIntrinsic(call->gtCallMethHnd) == NI_System_SpanHelpers_SequenceEqual);
assert(call->gtArgs.CountUserArgs() == 3);
assert(TARGET_POINTER_SIZE == 8);

if (!comp->opts.OptimizationEnabled())
{
JITDUMP("Optimizations aren't allowed - bail out.\n")
return nullptr;
}

if (comp->info.compHasNextCallRetAddr)
{
JITDUMP("compHasNextCallRetAddr=true so we won't be able to remove the call - bail out.\n")
return nullptr;
}

GenTree* lengthArg = call->gtArgs.GetUserArgByIndex(2)->GetNode();
if (lengthArg->IsIntegralConst())
{
ssize_t cnsSize = lengthArg->AsIntCon()->IconValue();
JITDUMP("Size=%ld.. ", (LONG)cnsSize);
// TODO-CQ: drop the whole thing in case of 0
if (cnsSize > 0)
{
GenTree* lArg = call->gtArgs.GetUserArgByIndex(0)->GetNode();
GenTree* rArg = call->gtArgs.GetUserArgByIndex(1)->GetNode();
// TODO: Add SIMD path for [16..128] via GT_HWINTRINSIC nodes
if (cnsSize <= 16)
{
unsigned loadWidth = 1 << BitOperations::Log2((unsigned)cnsSize);
var_types loadType;
if (loadWidth == 1)
{
loadType = TYP_UBYTE;
}
else if (loadWidth == 2)
{
loadType = TYP_USHORT;
}
else if (loadWidth == 4)
{
loadType = TYP_INT;
}
else if ((loadWidth == 8) || (loadWidth == 16))
{
loadWidth = 8;
loadType = TYP_LONG;
}
else
{
unreached();
}
var_types actualLoadType = genActualType(loadType);

GenTree* result = nullptr;

// loadWidth == cnsSize means a single load is enough for both args
if ((loadWidth == (unsigned)cnsSize) && (loadWidth <= 8))
{
// We're going to emit something like the following:
//
// bool result = *(int*)leftArg == *(int*)rightArg
//
// ^ in the given example we unroll for length=4
//
GenTree* lIndir = comp->gtNewIndir(loadType, lArg);
GenTree* rIndir = comp->gtNewIndir(loadType, rArg);
result = comp->gtNewOperNode(GT_EQ, TYP_INT, lIndir, rIndir);

BlockRange().InsertAfter(lArg, lIndir);
BlockRange().InsertAfter(rArg, rIndir);
BlockRange().InsertBefore(call, result);
}
else
{
// First, make both args multi-use:
LIR::Use lArgUse;
LIR::Use rArgUse;
bool lFoundUse = BlockRange().TryGetUse(lArg, &lArgUse);
bool rFoundUse = BlockRange().TryGetUse(rArg, &rArgUse);
assert(lFoundUse && rFoundUse);
Comment on lines +1957 to +1961
Member:
It's a bit wasteful to go looking for the uses of these given that we know the args they come from. E.g. you could do:

Suggested change:

- LIR::Use lArgUse;
- LIR::Use rArgUse;
- bool lFoundUse = BlockRange().TryGetUse(lArg, &lArgUse);
- bool rFoundUse = BlockRange().TryGetUse(rArg, &rArgUse);
- assert(lFoundUse && rFoundUse);
+ CallArg* lArg = call->gtArgs.GetUserArgByIndex(0);
+ GenTree*& lArgNode = lArg->GetLateNode() == nullptr ? lArg->EarlyNodeRef() : lArg->LateNodeRef();
+ ...
+ LIR::Use lArgUse(BlockRange(), &lArgNode, call);

I don't have a super strong opinion on it.

Member Author:

Thanks, will check in a follow-up once SPMI is collected; want to see if it's worth the effort to improve this expansion. The jit-diff utility found only around 30 affected methods.

GenTree* lArgClone = comp->gtNewLclLNode(lArgUse.ReplaceWithLclVar(comp), lArg->TypeGet());
GenTree* rArgClone = comp->gtNewLclLNode(rArgUse.ReplaceWithLclVar(comp), rArg->TypeGet());
BlockRange().InsertBefore(call, lArgClone, rArgClone);

// We're going to emit something like the following:
//
// bool result = ((*(int*)leftArg ^ *(int*)rightArg) |
// (*(int*)(leftArg + 1) ^ *((int*)(rightArg + 1)))) == 0;
//
// ^ in the given example we unroll for length=5
//
// In IR:
//
// * EQ int
// +--* OR int
// | +--* XOR int
// | | +--* IND int
// | | | \--* LCL_VAR byref V1
// | | \--* IND int
// | | \--* LCL_VAR byref V2
// | \--* XOR int
// | +--* IND int
// | | \--* ADD byref
// | | +--* LCL_VAR byref V1
// | | \--* CNS_INT int 1
// | \--* IND int
// | \--* ADD byref
// | +--* LCL_VAR byref V2
// | \--* CNS_INT int 1
// \--* CNS_INT int 0
//
GenTree* l1Indir = comp->gtNewIndir(loadType, lArgUse.Def());
GenTree* r1Indir = comp->gtNewIndir(loadType, rArgUse.Def());
GenTree* lXor = comp->gtNewOperNode(GT_XOR, actualLoadType, l1Indir, r1Indir);
GenTree* l2Offs = comp->gtNewIconNode(cnsSize - loadWidth);
GenTree* l2AddOffs = comp->gtNewOperNode(GT_ADD, lArg->TypeGet(), lArgClone, l2Offs);
GenTree* l2Indir = comp->gtNewIndir(loadType, l2AddOffs);
GenTree* r2Offs = comp->gtCloneExpr(l2Offs); // offset is the same
GenTree* r2AddOffs = comp->gtNewOperNode(GT_ADD, rArg->TypeGet(), rArgClone, r2Offs);
GenTree* r2Indir = comp->gtNewIndir(loadType, r2AddOffs);
GenTree* rXor = comp->gtNewOperNode(GT_XOR, actualLoadType, l2Indir, r2Indir);
GenTree* resultOr = comp->gtNewOperNode(GT_OR, actualLoadType, lXor, rXor);
GenTree* zeroCns = comp->gtNewIconNode(0, actualLoadType);
result = comp->gtNewOperNode(GT_EQ, TYP_INT, resultOr, zeroCns);

BlockRange().InsertAfter(rArgClone, l1Indir, r1Indir, l2Offs, l2AddOffs);
BlockRange().InsertAfter(l2AddOffs, l2Indir, r2Offs, r2AddOffs, r2Indir);
BlockRange().InsertAfter(r2Indir, lXor, rXor, resultOr, zeroCns);
BlockRange().InsertAfter(zeroCns, result);
}

JITDUMP("\nUnrolled to:\n");
DISPTREE(result);

LIR::Use use;
if (BlockRange().TryGetUse(call, &use))
{
use.ReplaceWith(result);
}
BlockRange().Remove(lengthArg);
BlockRange().Remove(call);

// Remove all non-user args (e.g. r2r cell)
for (CallArg& arg : call->gtArgs.Args())
{
if (!arg.IsUserArg())
{
arg.GetNode()->SetUnusedValue();
}
}
return lArg;
}
}
else
{
JITDUMP("Size is either 0 or too big to unroll.\n")
}
}
else
{
JITDUMP("size is not a constant.\n")
}
return nullptr;
}

// do lowering steps for a call
// this includes:
// - adding the placement nodes (either stack or register variety) for arguments
@@ -1883,19 +2062,26 @@ GenTree* Lowering::LowerCall(GenTree* node)
// All runtime lookups are expected to be expanded in fgExpandRuntimeLookups
assert(!call->IsExpRuntimeLookup());

+ #if defined(TARGET_AMD64) || defined(TARGET_ARM64)
if (call->gtCallMoreFlags & GTF_CALL_M_SPECIAL_INTRINSIC)
{
- #if defined(TARGET_AMD64) || defined(TARGET_ARM64)
- if (comp->lookupNamedIntrinsic(call->gtCallMethHnd) == NI_System_Buffer_Memmove)
+ GenTree* newNode = nullptr;
+ NamedIntrinsic ni = comp->lookupNamedIntrinsic(call->gtCallMethHnd);
+ if (ni == NI_System_Buffer_Memmove)
{
- GenTree* newNode = LowerCallMemmove(call);
- if (newNode != nullptr)
- {
- return newNode->gtNext;
- }
+ newNode = LowerCallMemmove(call);
+ }
+ else if (ni == NI_System_SpanHelpers_SequenceEqual)
+ {
+ newNode = LowerCallMemcmp(call);
}
+
+ if (newNode != nullptr)
+ {
+ return newNode->gtNext;
+ }
- #endif
}
+ #endif

call->ClearOtherRegs();
LowerArgsForCall(call);
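(For intuition, the length=5 unrolling in the IR comment above corresponds to the following C# sketch. Illustrative only: the JIT builds IR nodes directly, and the helper below is hypothetical.)

using System.Runtime.CompilerServices;

internal static class MemcmpUnrollSketch
{
    // What the unrolled IR computes for cnsSize == 5: loadWidth = 4 (the
    // largest power of two <= 5), and the second load starts at offset
    // 5 - 4 = 1, so the two loads overlap and together cover all 5 bytes.
    internal static bool SequenceEqual5(ref byte left, ref byte right)
    {
        uint l0 = Unsafe.ReadUnaligned<uint>(ref left);
        uint r0 = Unsafe.ReadUnaligned<uint>(ref right);
        uint l1 = Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref left, 1));
        uint r1 = Unsafe.ReadUnaligned<uint>(ref Unsafe.Add(ref right, 1));

        // Any differing byte leaves a non-zero bit in one of the XORs, so
        // the OR is zero iff all 5 bytes match.
        return ((l0 ^ r0) | (l1 ^ r1)) == 0;
    }
}

(For lengths that are exact powers of two up to 8, the lowering instead emits a single load per argument and one compare, per the first branch in the code above.)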
1 change: 1 addition & 0 deletions src/coreclr/jit/lower.h
@@ -128,6 +128,7 @@ class Lowering final : public Phase
// ------------------------------
GenTree* LowerCall(GenTree* call);
GenTree* LowerCallMemmove(GenTreeCall* call);
GenTree* LowerCallMemcmp(GenTreeCall* call);
void LowerCFGCall(GenTreeCall* call);
void MoveCFGCallArg(GenTreeCall* call, GenTree* node);
#ifndef TARGET_64BIT
1 change: 1 addition & 0 deletions src/coreclr/jit/namedintrinsiclist.h
@@ -104,6 +104,7 @@ enum NamedIntrinsic : unsigned short
NI_System_String_StartsWith,
NI_System_Span_get_Item,
NI_System_Span_get_Length,
NI_System_SpanHelpers_SequenceEqual,
NI_System_ReadOnlySpan_get_Item,
NI_System_ReadOnlySpan_get_Length,

src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs
@@ -1429,12 +1429,11 @@ public static unsafe bool SequenceEqual<T>(this Span<T> span, ReadOnlySpan<T> ot

if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return length == other.Length &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
- ((uint)length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
+ ((uint)length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
}

return length == other.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(other), length);
@@ -2164,12 +2163,11 @@ public static unsafe bool SequenceEqual<T>(this ReadOnlySpan<T> span, ReadOnlySp
int length = span.Length;
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return length == other.Length &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
- ((uint)length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
+ ((uint)length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
}

return length == other.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(other), length);
@@ -2207,11 +2205,10 @@ public static unsafe bool SequenceEqual<T>(this ReadOnlySpan<T> span, ReadOnlySp
// If no comparer was supplied and the type is bitwise equatable, take the fast path doing a bitwise comparison.
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
- ((uint)span.Length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
+ ((uint)span.Length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
}

// Otherwise, compare each element using EqualityComparer<T>.Default.Equals in a way that will enable it to devirtualize.
@@ -2277,12 +2274,11 @@ public static unsafe bool StartsWith<T>(this Span<T> span, ReadOnlySpan<T> value
int valueLength = value.Length;
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return valueLength <= span.Length &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
- ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
+ ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
}

return valueLength <= span.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(value), valueLength);
@@ -2298,12 +2294,11 @@ public static unsafe bool StartsWith<T>(this ReadOnlySpan<T> span, ReadOnlySpan<
int valueLength = value.Length;
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return valueLength <= span.Length &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
- ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
+ ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
}

return valueLength <= span.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(value), valueLength);
@@ -2319,12 +2314,11 @@ public static unsafe bool EndsWith<T>(this Span<T> span, ReadOnlySpan<T> value)
int valueLength = value.Length;
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return valueLength <= spanLength &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref Unsafe.Add(ref MemoryMarshal.GetReference(span), (nint)(uint)(spanLength - valueLength) /* force zero-extension */)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
- ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
+ ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
}

return valueLength <= spanLength &&
@@ -2344,12 +2338,11 @@ public static unsafe bool EndsWith<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T>
int valueLength = value.Length;
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
- nuint size = (nuint)sizeof(T);
return valueLength <= spanLength &&
SpanHelpers.SequenceEqual(
ref Unsafe.As<T, byte>(ref Unsafe.Add(ref MemoryMarshal.GetReference(span), (nint)(uint)(spanLength - valueLength) /* force zero-extension */)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
- ((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
+ ((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
}

return valueLength <= spanLength &&
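(A hedged example of the call shape this library change targets; the call site below is hypothetical. With (nuint)sizeof(T) inlined, a comparison against a constant-length u8 literal passes a JIT-time-constant length to SpanHelpers.SequenceEqual, assuming the literal's length survives inlining as a constant.)

using System;

internal static class UnrollExample
{
    // "GET "u8 is a ReadOnlySpan<byte> of constant length 4; byte is
    // bitwise-equatable, so StartsWith funnels into
    // SpanHelpers.SequenceEqual(ref byte, ref byte, nuint) with
    // length = 4 * sizeof(byte) = 4, small enough for the new unrolling.
    internal static bool IsGet(ReadOnlySpan<byte> request) =>
        request.StartsWith("GET "u8);
}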
src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
@@ -566,6 +566,7 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)

// Optimized byte-based SequenceEquals. The "length" parameter for this one is declared a nuint rather than int as we also use it for types other than byte
// where the length can exceed 2Gb once scaled by sizeof(T).
[Intrinsic] // Unrolled for constant length
public static unsafe bool SequenceEqual(ref byte first, ref byte second, nuint length)
{
bool result;