Skip to content

Commit

Permalink
Add optimisation of incrementing SELECT at lowering on arm64. (#91262)
Browse files Browse the repository at this point in the history
This patch adds support for optimising SELECT or SELECT_CC nodes
which contain a child of type GT_ADD with a second operand constant
value of 1 (increment) to be optimised into a SELECT_INC or
SELECT_INCCC node.

Change-Id: Ia26da6f4c0e3a75143e133964cbb76243d0822de
  • Loading branch information
c272 authored Sep 17, 2023
1 parent ebe6f54 commit 1972683
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 41 deletions.
6 changes: 3 additions & 3 deletions src/coreclr/jit/lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3858,13 +3858,13 @@ GenTree* Lowering::LowerSelect(GenTreeConditional* select)
}

#ifdef TARGET_ARM64
if (trueVal->OperIs(GT_NOT, GT_NEG) || falseVal->OperIs(GT_NOT, GT_NEG))
if (trueVal->OperIs(GT_NOT, GT_NEG, GT_ADD) || falseVal->OperIs(GT_NOT, GT_NEG, GT_ADD))
{
TryLowerCselToCinvOrCneg(select, cond);
TryLowerCselToCSOp(select, cond);
}
else if (trueVal->IsCnsIntOrI() && falseVal->IsCnsIntOrI())
{
TryLowerCselToCinc(select, cond);
TryLowerCnsIntCselToCinc(select, cond);
}
#endif

Expand Down
4 changes: 2 additions & 2 deletions src/coreclr/jit/lower.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ class Lowering final : public Phase
insCflags TruthifyingFlags(GenCondition cond);
void ContainCheckConditionalCompare(GenTreeCCMP* ccmp);
void ContainCheckNeg(GenTreeOp* neg);
void TryLowerCselToCinc(GenTreeOp* select, GenTree* cond);
void TryLowerCselToCinvOrCneg(GenTreeOp* select, GenTree* cond);
void TryLowerCnsIntCselToCinc(GenTreeOp* select, GenTree* cond);
void TryLowerCselToCSOp(GenTreeOp* select, GenTree* cond);
#endif
void ContainCheckSelect(GenTreeOp* select);
void ContainCheckBitCast(GenTree* node);
Expand Down
121 changes: 85 additions & 36 deletions src/coreclr/jit/lowerarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2551,41 +2551,55 @@ void Lowering::ContainCheckNeg(GenTreeOp* neg)
}

//----------------------------------------------------------------------------------------------
// TryLowerCselToCinvOrCneg: Try converting SELECT/SELECTCC to SELECT_INV/SELECT_INVCC. Conversion is possible only if
// one of the operands of the select node is inverted.
// TryLowerCselToCSOp: Try converting SELECT/SELECTCC to SELECT_?/SELECT_?CC. Conversion is possible only if
// one of the operands of the select node is one of GT_NEG, GT_NOT or GT_ADD.
//
// Arguments:
// select - The select node that is now SELECT or SELECTCC
// cond - The condition node that SELECT or SELECTCC uses
//
void Lowering::TryLowerCselToCinvOrCneg(GenTreeOp* select, GenTree* cond)
void Lowering::TryLowerCselToCSOp(GenTreeOp* select, GenTree* cond)
{
assert(select->OperIs(GT_SELECT, GT_SELECTCC));

bool shouldReverseCondition;
GenTree* invertedOrNegatedVal;
GenTree* nonInvertedOrNegatedVal;
GenTree* operatedVal;
GenTree* nonOperatedVal;
GenTree* nodeToRemove;
GenTree* trueVal = select->gtOp1;
GenTree* falseVal = select->gtOp2;

GenTree* trueVal = select->gtOp1;
GenTree* falseVal = select->gtOp2;
const bool isCneg = trueVal->OperIs(GT_NEG) || falseVal->OperIs(GT_NEG);

assert(trueVal->OperIs(GT_NOT, GT_NEG) || falseVal->OperIs(GT_NOT, GT_NEG));
// Determine the resulting operation type.
genTreeOps resultingOp;
if (trueVal->OperIs(GT_NEG) || falseVal->OperIs(GT_NEG))
{
resultingOp = GT_SELECT_NEG;
shouldReverseCondition = trueVal->OperIs(GT_NEG);
}
else if (trueVal->OperIs(GT_NOT) || falseVal->OperIs(GT_NOT))
{
resultingOp = GT_SELECT_INV;
shouldReverseCondition = trueVal->OperIs(GT_NOT);
}
else
{
assert(trueVal->OperIs(GT_ADD) || falseVal->OperIs(GT_ADD));
resultingOp = GT_SELECT_INC;
shouldReverseCondition = trueVal->OperIs(GT_ADD);
}

if ((isCneg && trueVal->OperIs(GT_NEG)) || (!isCneg && trueVal->OperIs(GT_NOT)))
// Values to which the operation are applied must come last.
if (shouldReverseCondition)
{
shouldReverseCondition = true;
invertedOrNegatedVal = trueVal->gtGetOp1();
nonInvertedOrNegatedVal = falseVal;
nodeToRemove = trueVal;
operatedVal = trueVal->gtGetOp1();
nonOperatedVal = falseVal;
nodeToRemove = trueVal;
}
else
{
shouldReverseCondition = false;
invertedOrNegatedVal = falseVal->gtGetOp1();
nonInvertedOrNegatedVal = trueVal;
nodeToRemove = falseVal;
operatedVal = falseVal->gtGetOp1();
nonOperatedVal = trueVal;
nodeToRemove = falseVal;
}

if (shouldReverseCondition && !cond->OperIsCompare() && select->OperIs(GT_SELECT))
Expand All @@ -2595,17 +2609,33 @@ void Lowering::TryLowerCselToCinvOrCneg(GenTreeOp* select, GenTree* cond)
return;
}

if (!(IsInvariantInRange(invertedOrNegatedVal, select) && IsInvariantInRange(nonInvertedOrNegatedVal, select)))
// For Csinc candidates, the second argument of the GT_ADD must be +1 (increment).
if (resultingOp == GT_SELECT_INC &&
!(nodeToRemove->gtGetOp2()->IsCnsIntOrI() && nodeToRemove->gtGetOp2()->AsIntCon()->IconValue() == 1))
{
return;
}

// As the select node would handle the negation/inversion, the op is not required.
// If a value is contained in the negate/invert op, it cannot be contained anymore.
// Check that we are safe to move both values.
if (!(IsInvariantInRange(operatedVal, select) && IsInvariantInRange(nonOperatedVal, select)))
{
return;
}

// Passed all checks, move on to block modification.
// If this is a Cinc candidate, we must remove the dangling second argument node.
if (resultingOp == GT_SELECT_INC)
{
BlockRange().Remove(nodeToRemove->gtGetOp2());
nodeToRemove->AsOp()->gtOp2 = nullptr;
}

// As the select node would handle the operation, the op is not required.
// If a value is contained in the negate/invert/increment op, it cannot be contained anymore.
BlockRange().Remove(nodeToRemove);
invertedOrNegatedVal->ClearContained();
select->gtOp1 = nonInvertedOrNegatedVal;
select->gtOp2 = invertedOrNegatedVal;
operatedVal->ClearContained();
select->gtOp1 = nonOperatedVal;
select->gtOp2 = operatedVal;

if (select->OperIs(GT_SELECT))
{
Expand All @@ -2614,10 +2644,7 @@ void Lowering::TryLowerCselToCinvOrCneg(GenTreeOp* select, GenTree* cond)
GenTree* revCond = comp->gtReverseCond(cond);
assert(cond == revCond); // Ensure `gtReverseCond` did not create a new node.
}
select->SetOper(isCneg ? GT_SELECT_NEG : GT_SELECT_INV);
JITDUMP("Converted to: %s\n", isCneg ? "SELECT_NEG" : "SELECT_INV");
DISPTREERANGE(BlockRange(), select);
JITDUMP("\n");
select->SetOper(resultingOp);
}
else
{
Expand All @@ -2628,22 +2655,44 @@ void Lowering::TryLowerCselToCinvOrCneg(GenTreeOp* select, GenTree* cond)
// Reverse the condition so that op2 will be selected
selectcc->gtCondition = GenCondition::Reverse(selectCond);
}
selectcc->SetOper(isCneg ? GT_SELECT_NEGCC : GT_SELECT_INVCC);
JITDUMP("Converted to: %s\n", isCneg ? "SELECT_NEGCC" : "SELECT_INVCC");
DISPTREERANGE(BlockRange(), selectcc);
JITDUMP("\n");

// Convert the resulting operation into the equivalent CC form.
switch (resultingOp)
{
case GT_SELECT_NEG:
resultingOp = GT_SELECT_NEGCC;
break;
case GT_SELECT_INV:
resultingOp = GT_SELECT_INVCC;
break;
case GT_SELECT_INC:
resultingOp = GT_SELECT_INCCC;
break;
default:
assert(false);
}
selectcc->SetOper(resultingOp);
}

#ifdef DEBUG
JITDUMP("Converted to ");
if (comp->verbose)
comp->gtDispNodeName(select);
JITDUMP(":\n");
DISPTREERANGE(BlockRange(), select);
JITDUMP("\n");
#endif
}

//----------------------------------------------------------------------------------------------
// TryLowerCselToCinc: Try converting SELECT/SELECTCC to SELECT_INC/SELECT_INCCC. Conversion is possible only if both
// the trueVal and falseVal are integral constants and abs(trueVal - falseVal) = 1.
// TryLowerCnsIntCselToCinc: Try converting SELECT/SELECTCC to SELECT_INC/SELECT_INCCC.
// Conversion is possible only if both the trueVal and falseVal are integer constants and abs(trueVal - falseVal) = 1.
//
// Arguments:
// select - The select node that is now SELECT or SELECTCC
// cond - The condition node that SELECT or SELECTCC uses
//
void Lowering::TryLowerCselToCinc(GenTreeOp* select, GenTree* cond)
void Lowering::TryLowerCnsIntCselToCinc(GenTreeOp* select, GenTree* cond)
{
assert(select->OperIs(GT_SELECT, GT_SELECTCC));

Expand Down
132 changes: 132 additions & 0 deletions src/tests/JIT/opt/Compares/conditionalIncrements.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,30 @@ public static void cinc_byte(byte op1, int expected)
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_byte(byte op1, int expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
int result = op1 <= 44 ? op1 + 1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_byte_reverse(byte op1, int expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
int result = op1 <= 44 ? 1 + op1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(72, byte.MinValue)]
[InlineData(32, byte.MaxValue)]
Expand Down Expand Up @@ -58,6 +82,30 @@ public static void cinc_short(short op1, short expected)
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_short(short op1, short expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
short result = (short) (op1 <= 44 ? op1 + 1 : 5);
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_short_reverse(short op1, short expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
short result = (short) (op1 <= 44 ? 1 + op1 : 5);
Assert.Equal(expected, result);
}

[Theory]
[InlineData(76, short.MinValue)]
[InlineData(-35, short.MaxValue)]
Expand All @@ -82,6 +130,66 @@ public static void cinc_int(int op1, int expected)
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_int(int op1, int expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
int result = op1 <= 44 ? op1 + 1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_int_reverse(int op1, int expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
int result = op1 <= 44 ? 1 + op1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 69)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_int_postinc(int op1, int expected)
{
int result = 2 * op1;

//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{ge|lt}}
if (op1 < 44) {
result++;
} else {
result = 5;
}
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 149)]
[InlineData(34, 5)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_int_postinc_false(int op1, int expected)
{
int result = 2 * op1;

//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{ge|lt}}
if (op1 < 44) {
result = 5;
} else {
result++;
}
Assert.Equal(expected, result);
}

[Theory]
[InlineData(77, int.MinValue)]
[InlineData(37, int.MaxValue)]
Expand All @@ -107,6 +215,30 @@ public static void cinc_long(long op1, long expected)
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_long(long op1, long expected)
{
//ARM64-FULL-LINE: cmp {{x[0-9]+}}, #48
//ARM64-FULL-LINE-NEXT: csinc {{x[0-9]+}}, {{x[0-9]+}}, {{x[0-9]+}}, {{gt|le}}
long result = op1 <= 48 ? op1 + 1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_long_reverse(long op1, long expected)
{
//ARM64-FULL-LINE: cmp {{x[0-9]+}}, #48
//ARM64-FULL-LINE-NEXT: csinc {{x[0-9]+}}, {{x[0-9]+}}, {{x[0-9]+}}, {{gt|le}}
long result = op1 <= 48 ? 1 + op1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(79, long.MaxValue)]
[InlineData(39, long.MinValue)]
Expand Down

0 comments on commit 1972683

Please sign in to comment.