Skip to content

Commit

Permalink
Ensure that NI_Vector_Dot is always handled as TYP_SIMD (#88447)
Browse files Browse the repository at this point in the history
* Ensure that NI_Vector_Dot is always handled as TYP_SIMD

* Apply formatting patch

* Ensure simdType is passed in to gtNewSimdDotProdNode

* Apply suggestions from code review

Co-authored-by: Kunal Pathak <[email protected]>

---------

Co-authored-by: Kunal Pathak <[email protected]>
  • Loading branch information
tannergooding and kunalspathak authored Jul 10, 2023
1 parent acafb20 commit ab7c452
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 98 deletions.
4 changes: 1 addition & 3 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22052,9 +22052,7 @@ GenTree* Compiler::gtNewSimdDotProdNode(
assert(op2->TypeIs(simdType));

var_types simdBaseType = JitType2PreciseVarType(simdBaseJitType);

// We support the return type being a SIMD for floating-point as a special optimization
assert(varTypeIsArithmetic(type) || (varTypeIsSIMD(type) && varTypeIsFloating(simdBaseType)));
assert(varTypeIsSIMD(type));

NamedIntrinsic intrinsic = NI_Illegal;

Expand Down
3 changes: 2 additions & 1 deletion src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,8 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
op2 = impSIMDPopStack();
op1 = impSIMDPopStack();

retNode = gtNewSimdDotProdNode(retType, op1, op2, simdBaseJitType, simdSize);
retNode = gtNewSimdDotProdNode(simdType, op1, op2, simdBaseJitType, simdSize);
retNode = gtNewSimdGetElementNode(retType, retNode, gtNewIconNode(0), simdBaseJitType, simdSize);
}
break;
}
Expand Down
3 changes: 2 additions & 1 deletion src/coreclr/jit/hwintrinsicxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1778,7 +1778,8 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
op2 = impSIMDPopStack();
op1 = impSIMDPopStack();

retNode = gtNewSimdDotProdNode(retType, op1, op2, simdBaseJitType, simdSize);
retNode = gtNewSimdDotProdNode(simdType, op1, op2, simdBaseJitType, simdSize);
retNode = gtNewSimdGetElementNode(retType, retNode, gtNewIconNode(0), simdBaseJitType, simdSize);
break;
}

Expand Down
38 changes: 9 additions & 29 deletions src/coreclr/jit/lowerarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1628,9 +1628,7 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
assert(varTypeIsSIMD(simdType));
assert(varTypeIsArithmetic(simdBaseType));
assert(simdSize != 0);

// We support the return type being a SIMD for floating-point as a special optimization
assert(varTypeIsArithmetic(node) || (varTypeIsSIMD(node) && varTypeIsFloating(simdBaseType)));
assert(varTypeIsSIMD(node));

GenTree* op1 = node->Op(1);
GenTree* op2 = node->Op(2);
Expand All @@ -1647,7 +1645,7 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)

// For 12 byte SIMD, we need to clear the upper 4 bytes:
// idx = CNS_INT int 0x03
// tmp1 = * CNS_DLB float 0.0
// tmp1 = * CNS_DBL float 0.0
// /--* op1 simd16
// +--* idx int
// +--* tmp1 simd16
Expand Down Expand Up @@ -1887,34 +1885,16 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
}
}

if (varTypeIsSIMD(node->gtType))
{
// We're producing a vector result, so just return the result directly

LIR::Use use;

if (BlockRange().TryGetUse(node, &use))
{
use.ReplaceWith(tmp2);
}
// We're producing a vector result, so just return the result directly
LIR::Use use;

BlockRange().Remove(node);
return tmp2->gtNext;
}
else
if (BlockRange().TryGetUse(node, &use))
{
// We will be constructing the following parts:
// ...
// /--* tmp2 simd16
// node = * HWINTRINSIC simd16 T ToScalar

// This is roughly the following managed code:
// ...
// return tmp2.ToScalar();

node->ResetHWIntrinsicId((simdSize == 8) ? NI_Vector64_ToScalar : NI_Vector128_ToScalar, tmp2);
return LowerNode(node);
use.ReplaceWith(tmp2);
}

BlockRange().Remove(node);
return tmp2->gtNext;
}
#endif // FEATURE_HW_INTRINSICS

Expand Down
55 changes: 18 additions & 37 deletions src/coreclr/jit/lowerxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4428,9 +4428,7 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
assert(varTypeIsSIMD(simdType));
assert(varTypeIsArithmetic(simdBaseType));
assert(simdSize != 0);

// We support the return type being a SIMD for floating-point as a special optimization
assert(varTypeIsArithmetic(node) || (varTypeIsSIMD(node) && varTypeIsFloating(simdBaseType)));
assert(varTypeIsSIMD(node));

GenTree* op1 = node->Op(1);
GenTree* op2 = node->Op(2);
Expand Down Expand Up @@ -4479,15 +4477,12 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
// tmp2 = * HWINTRINSIC simd16 T GetUpper
// /--* tmp1 simd16
// +--* tmp2 simd16
// tmp3 = * HWINTRINSIC simd16 T Add
// /--* tmp3 simd16
// node = * HWINTRINSIC simd16 T ToScalar
// node = * HWINTRINSIC simd16 T Add

// This is roughly the following managed code:
// var tmp1 = Avx.DotProduct(op1, op2, 0xFF);
// var tmp2 = tmp1.GetUpper();
// var tmp3 = Sse.Add(tmp1, tmp2);
// return tmp3.ToScalar();
// return Sse.Add(tmp1, tmp2);

idx = comp->gtNewIconNode(0xF1, TYP_INT);
BlockRange().InsertBefore(node, idx);
Expand Down Expand Up @@ -4515,13 +4510,17 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)

tmp2 = comp->gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, tmp3, tmp1, simdBaseJitType, 16);
BlockRange().InsertAfter(tmp1, tmp2);
LowerNode(tmp2);

node->SetSimdSize(16);
// We're producing a vector result, so just return the result directly
LIR::Use use;

node->ResetHWIntrinsicId(NI_Vector128_ToScalar, tmp2);
if (BlockRange().TryGetUse(node, &use))
{
use.ReplaceWith(tmp2);
}

return LowerNode(node);
BlockRange().Remove(node);
return LowerNode(tmp2);
}

case TYP_DOUBLE:
Expand Down Expand Up @@ -5043,34 +5042,16 @@ GenTree* Lowering::LowerHWIntrinsicDot(GenTreeHWIntrinsic* node)
tmp1 = tmp2;
}

if (varTypeIsSIMD(node->gtType))
{
// We're producing a vector result, so just return the result directly

LIR::Use use;

if (BlockRange().TryGetUse(node, &use))
{
use.ReplaceWith(tmp1);
}
// We're producing a vector result, so just return the result directly
LIR::Use use;

BlockRange().Remove(node);
return tmp1->gtNext;
}
else
if (BlockRange().TryGetUse(node, &use))
{
// We will be constructing the following parts:
// ...
// /--* tmp1 simd16
// node = * HWINTRINSIC simd16 T ToScalar

// This is roughly the following managed code:
// ...
// return tmp1.ToScalar();

node->ResetHWIntrinsicId(NI_Vector128_ToScalar, tmp1);
return LowerNode(node);
use.ReplaceWith(tmp1);
}

BlockRange().Remove(node);
return tmp1->gtNext;
}

//----------------------------------------------------------------------------------------------
Expand Down
69 changes: 48 additions & 21 deletions src/coreclr/jit/morph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10670,39 +10670,45 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
#endif // TARGET_ARM64
case NI_Vector128_Create:
{
// The `Dot` API returns a scalar. However, many common usages require it to
// be then immediately broadcast back to a vector so that it can be used in
// a subsequent operation. One of the most common is normalizing a vector
// The managed `Dot` API returns a scalar. However, many common usages require
// it to be then immediately broadcast back to a vector so that it can be used
// in a subsequent operation. One of the most common is normalizing a vector
// which is effectively `value / value.Length` where `Length` is
// `Sqrt(Dot(value, value))`
// `Sqrt(Dot(value, value))`. Because of this, and because of how a lot of
// hardware works, we treat `NI_Vector_Dot` as returning a SIMD type and then
// also wrap it in `ToScalar` where required.
//
// In order to ensure that developers can still utilize this efficiently, we
// will look for two common patterns:
// then look for four common patterns:
// * Create(Dot(..., ...))
// * Create(Sqrt(Dot(..., ...)))
// * Create(ToScalar(Dot(..., ...)))
// * Create(ToScalar(Sqrt(Dot(..., ...))))
//
// When these exist, we'll avoid converting to a scalar at all and just
// keep everything as a vector. However, we only do this for Vector64/Vector128
// and only for float/double.
// When these exist, we'll avoid converting to a scalar and hence, avoid broadcasting
// the value back into a vector. Instead we'll just keep everything as a vector.
//
// We don't do this for Vector256 since that is xarch only and doesn't trivially
// support operations which cross the upper and lower 128-bit lanes
// We only do this for Vector64/Vector128 today. We could expand this more in
// the future but it would require additional hand handling for Vector256
// (since a 256-bit result requires more work). We do some integer handling
// when the value is trivially replicated to all elements without extra work.

if (node->GetOperandCount() != 1)
{
break;
}

if (!varTypeIsFloating(node->GetSimdBaseType()))
{
break;
}

GenTree* op1 = node->Op(1);
GenTree* sqrt = nullptr;
GenTree* op1 = node->Op(1);
GenTree* sqrt = nullptr;
GenTree* toScalar = nullptr;

if (op1->OperIs(GT_INTRINSIC))
{
if (!varTypeIsFloating(node->GetSimdBaseType()))
{
break;
}

if (op1->AsIntrinsic()->gtIntrinsicName != NI_System_Math_Sqrt)
{
break;
Expand All @@ -10719,6 +10725,24 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)

GenTreeHWIntrinsic* hwop1 = op1->AsHWIntrinsic();

#if defined(TARGET_ARM64)
if ((hwop1->GetHWIntrinsicId() == NI_Vector64_ToScalar) ||
(hwop1->GetHWIntrinsicId() == NI_Vector128_ToScalar))
#else
if (hwop1->GetHWIntrinsicId() == NI_Vector128_ToScalar)
#endif
{
op1 = hwop1->Op(1);

if (!op1->OperIs(GT_HWINTRINSIC))
{
break;
}

toScalar = hwop1;
hwop1 = op1->AsHWIntrinsic();
}

#if defined(TARGET_ARM64)
if ((hwop1->GetHWIntrinsicId() != NI_Vector64_Dot) && (hwop1->GetHWIntrinsicId() != NI_Vector128_Dot))
#else
Expand All @@ -10728,13 +10752,16 @@ GenTree* Compiler::fgOptimizeHWIntrinsic(GenTreeHWIntrinsic* node)
break;
}

unsigned simdSize = node->GetSimdSize();
var_types simdType = getSIMDTypeForSize(simdSize);

hwop1->gtType = simdType;
if (toScalar != nullptr)
{
DEBUG_DESTROY_NODE(toScalar);
}

if (sqrt != nullptr)
{
unsigned simdSize = node->GetSimdSize();
var_types simdType = getSIMDTypeForSize(simdSize);

node = gtNewSimdSqrtNode(simdType, hwop1, node->GetSimdBaseJitType(), simdSize)->AsHWIntrinsic();
DEBUG_DESTROY_NODE(sqrt);
}
Expand Down
13 changes: 7 additions & 6 deletions src/coreclr/jit/simdashwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1065,8 +1065,7 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
vecCon->gtSimdVal.f32[3] = +1.0f;

GenTree* conjugate = gtNewSimdBinOpNode(GT_MUL, retType, op1, vecCon, simdBaseJitType, simdSize);

op1 = gtNewSimdDotProdNode(retType, clonedOp1, clonedOp2, simdBaseJitType, simdSize);
op1 = gtNewSimdDotProdNode(retType, clonedOp1, clonedOp2, simdBaseJitType, simdSize);

return gtNewSimdBinOpNode(GT_DIV, retType, conjugate, op1, simdBaseJitType, simdSize);
}
Expand Down Expand Up @@ -1095,7 +1094,8 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
op1 = impCloneExpr(op1, &clonedOp1, CHECK_SPILL_ALL,
nullptr DEBUGARG("Clone op1 for vector length squared"));

return gtNewSimdDotProdNode(retType, op1, clonedOp1, simdBaseJitType, simdSize);
op1 = gtNewSimdDotProdNode(simdType, op1, clonedOp1, simdBaseJitType, simdSize);
return gtNewSimdGetElementNode(retType, op1, gtNewIconNode(0), simdBaseJitType, simdSize);
}

case NI_VectorT128_Load:
Expand Down Expand Up @@ -1174,7 +1174,6 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
nullptr DEBUGARG("Clone op1 for vector normalize (2)"));

op1 = gtNewSimdDotProdNode(retType, op1, clonedOp1, simdBaseJitType, simdSize);

op1 = gtNewSimdSqrtNode(retType, op1, simdBaseJitType, simdSize);

return gtNewSimdBinOpNode(GT_DIV, retType, clonedOp2, op1, simdBaseJitType, simdSize);
Expand Down Expand Up @@ -1462,7 +1461,8 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
op1 = impCloneExpr(op1, &clonedOp1, CHECK_SPILL_ALL,
nullptr DEBUGARG("Clone diff for vector distance squared"));

return gtNewSimdDotProdNode(retType, op1, clonedOp1, simdBaseJitType, simdSize);
op1 = gtNewSimdDotProdNode(simdType, op1, clonedOp1, simdBaseJitType, simdSize);
return gtNewSimdGetElementNode(retType, op1, gtNewIconNode(0), simdBaseJitType, simdSize);
}

case NI_Quaternion_Divide:
Expand Down Expand Up @@ -1492,7 +1492,8 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
case NI_VectorT256_Dot:
#endif // TARGET_XARCH
{
return gtNewSimdDotProdNode(retType, op1, op2, simdBaseJitType, simdSize);
op1 = gtNewSimdDotProdNode(simdType, op1, op2, simdBaseJitType, simdSize);
return gtNewSimdGetElementNode(retType, op1, gtNewIconNode(0), simdBaseJitType, simdSize);
}

case NI_VectorT128_Equals:
Expand Down

0 comments on commit ab7c452

Please sign in to comment.