Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optimisation of incrementing SELECT at lowering on arm64. #91262

Merged
merged 2 commits into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/coreclr/jit/lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3858,13 +3858,13 @@ GenTree* Lowering::LowerSelect(GenTreeConditional* select)
}

#ifdef TARGET_ARM64
if (trueVal->OperIs(GT_NOT, GT_NEG) || falseVal->OperIs(GT_NOT, GT_NEG))
if (trueVal->OperIs(GT_NOT, GT_NEG, GT_ADD) || falseVal->OperIs(GT_NOT, GT_NEG, GT_ADD))
{
TryLowerCselToCinvOrCneg(select, cond);
TryLowerCselToCSOp(select, cond);
}
else if (trueVal->IsCnsIntOrI() && falseVal->IsCnsIntOrI())
{
TryLowerCselToCinc(select, cond);
TryLowerCnsIntCselToCinc(select, cond);
}
#endif

Expand Down
4 changes: 2 additions & 2 deletions src/coreclr/jit/lower.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ class Lowering final : public Phase
insCflags TruthifyingFlags(GenCondition cond);
void ContainCheckConditionalCompare(GenTreeCCMP* ccmp);
void ContainCheckNeg(GenTreeOp* neg);
void TryLowerCselToCinc(GenTreeOp* select, GenTree* cond);
void TryLowerCselToCinvOrCneg(GenTreeOp* select, GenTree* cond);
void TryLowerCnsIntCselToCinc(GenTreeOp* select, GenTree* cond);
void TryLowerCselToCSOp(GenTreeOp* select, GenTree* cond);
#endif
void ContainCheckSelect(GenTreeOp* select);
void ContainCheckBitCast(GenTree* node);
Expand Down
121 changes: 85 additions & 36 deletions src/coreclr/jit/lowerarmarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2551,41 +2551,55 @@ void Lowering::ContainCheckNeg(GenTreeOp* neg)
}

//----------------------------------------------------------------------------------------------
// TryLowerCselToCinvOrCneg: Try converting SELECT/SELECTCC to SELECT_INV/SELECT_INVCC. Conversion is possible only if
// one of the operands of the select node is inverted.
// TryLowerCselToCSOp: Try converting SELECT/SELECTCC to SELECT_?/SELECT_?CC. Conversion is possible only if
// one of the operands of the select node is one of GT_NEG, GT_NOT or GT_ADD.
//
// Arguments:
// select - The select node that is now SELECT or SELECTCC
// cond - The condition node that SELECT or SELECTCC uses
//
void Lowering::TryLowerCselToCinvOrCneg(GenTreeOp* select, GenTree* cond)
void Lowering::TryLowerCselToCSOp(GenTreeOp* select, GenTree* cond)
{
assert(select->OperIs(GT_SELECT, GT_SELECTCC));

bool shouldReverseCondition;
GenTree* invertedOrNegatedVal;
GenTree* nonInvertedOrNegatedVal;
GenTree* operatedVal;
GenTree* nonOperatedVal;
GenTree* nodeToRemove;
GenTree* trueVal = select->gtOp1;
GenTree* falseVal = select->gtOp2;

GenTree* trueVal = select->gtOp1;
GenTree* falseVal = select->gtOp2;
const bool isCneg = trueVal->OperIs(GT_NEG) || falseVal->OperIs(GT_NEG);

assert(trueVal->OperIs(GT_NOT, GT_NEG) || falseVal->OperIs(GT_NOT, GT_NEG));
// Determine the resulting operation type.
genTreeOps resultingOp;
if (trueVal->OperIs(GT_NEG) || falseVal->OperIs(GT_NEG))
{
resultingOp = GT_SELECT_NEG;
shouldReverseCondition = trueVal->OperIs(GT_NEG);
}
else if (trueVal->OperIs(GT_NOT) || falseVal->OperIs(GT_NOT))
{
resultingOp = GT_SELECT_INV;
shouldReverseCondition = trueVal->OperIs(GT_NOT);
}
else
{
assert(trueVal->OperIs(GT_ADD) || falseVal->OperIs(GT_ADD));
resultingOp = GT_SELECT_INC;
shouldReverseCondition = trueVal->OperIs(GT_ADD);
}

if ((isCneg && trueVal->OperIs(GT_NEG)) || (!isCneg && trueVal->OperIs(GT_NOT)))
// Values to which the operation are applied must come last.
if (shouldReverseCondition)
{
shouldReverseCondition = true;
invertedOrNegatedVal = trueVal->gtGetOp1();
nonInvertedOrNegatedVal = falseVal;
nodeToRemove = trueVal;
operatedVal = trueVal->gtGetOp1();
nonOperatedVal = falseVal;
nodeToRemove = trueVal;
}
else
{
shouldReverseCondition = false;
invertedOrNegatedVal = falseVal->gtGetOp1();
nonInvertedOrNegatedVal = trueVal;
nodeToRemove = falseVal;
operatedVal = falseVal->gtGetOp1();
nonOperatedVal = trueVal;
nodeToRemove = falseVal;
}

if (shouldReverseCondition && !cond->OperIsCompare() && select->OperIs(GT_SELECT))
Expand All @@ -2595,17 +2609,33 @@ void Lowering::TryLowerCselToCinvOrCneg(GenTreeOp* select, GenTree* cond)
return;
}

if (!(IsInvariantInRange(invertedOrNegatedVal, select) && IsInvariantInRange(nonInvertedOrNegatedVal, select)))
// For Csinc candidates, the second argument of the GT_ADD must be +1 (increment).
if (resultingOp == GT_SELECT_INC &&
!(nodeToRemove->gtGetOp2()->IsCnsIntOrI() && nodeToRemove->gtGetOp2()->AsIntCon()->IconValue() == 1))
{
return;
}

// As the select node would handle the negation/inversion, the op is not required.
// If a value is contained in the negate/invert op, it cannot be contained anymore.
// Check that we are safe to move both values.
if (!(IsInvariantInRange(operatedVal, select) && IsInvariantInRange(nonOperatedVal, select)))
{
return;
}

// Passed all checks, move on to block modification.
// If this is a Cinc candidate, we must remove the dangling second argument node.
if (resultingOp == GT_SELECT_INC)
{
BlockRange().Remove(nodeToRemove->gtGetOp2());
nodeToRemove->AsOp()->gtOp2 = nullptr;
}

// As the select node would handle the operation, the op is not required.
// If a value is contained in the negate/invert/increment op, it cannot be contained anymore.
BlockRange().Remove(nodeToRemove);
invertedOrNegatedVal->ClearContained();
select->gtOp1 = nonInvertedOrNegatedVal;
select->gtOp2 = invertedOrNegatedVal;
operatedVal->ClearContained();
select->gtOp1 = nonOperatedVal;
select->gtOp2 = operatedVal;

if (select->OperIs(GT_SELECT))
{
Expand All @@ -2614,10 +2644,7 @@ void Lowering::TryLowerCselToCinvOrCneg(GenTreeOp* select, GenTree* cond)
GenTree* revCond = comp->gtReverseCond(cond);
assert(cond == revCond); // Ensure `gtReverseCond` did not create a new node.
}
select->SetOper(isCneg ? GT_SELECT_NEG : GT_SELECT_INV);
JITDUMP("Converted to: %s\n", isCneg ? "SELECT_NEG" : "SELECT_INV");
DISPTREERANGE(BlockRange(), select);
JITDUMP("\n");
select->SetOper(resultingOp);
}
else
{
Expand All @@ -2628,22 +2655,44 @@ void Lowering::TryLowerCselToCinvOrCneg(GenTreeOp* select, GenTree* cond)
// Reverse the condition so that op2 will be selected
selectcc->gtCondition = GenCondition::Reverse(selectCond);
}
selectcc->SetOper(isCneg ? GT_SELECT_NEGCC : GT_SELECT_INVCC);
JITDUMP("Converted to: %s\n", isCneg ? "SELECT_NEGCC" : "SELECT_INVCC");
DISPTREERANGE(BlockRange(), selectcc);
JITDUMP("\n");

// Convert the resulting operation into the equivalent CC form.
switch (resultingOp)
{
case GT_SELECT_NEG:
resultingOp = GT_SELECT_NEGCC;
break;
case GT_SELECT_INV:
resultingOp = GT_SELECT_INVCC;
break;
case GT_SELECT_INC:
resultingOp = GT_SELECT_INCCC;
break;
default:
assert(false);
}
selectcc->SetOper(resultingOp);
}

#ifdef DEBUG
JITDUMP("Converted to ");
if (comp->verbose)
comp->gtDispNodeName(select);
JITDUMP(":\n");
DISPTREERANGE(BlockRange(), select);
JITDUMP("\n");
#endif
}

//----------------------------------------------------------------------------------------------
// TryLowerCselToCinc: Try converting SELECT/SELECTCC to SELECT_INC/SELECT_INCCC. Conversion is possible only if both
// the trueVal and falseVal are integral constants and abs(trueVal - falseVal) = 1.
// TryLowerCnsIntCselToCinc: Try converting SELECT/SELECTCC to SELECT_INC/SELECT_INCCC.
// Conversion is possible only if both the trueVal and falseVal are integer constants and abs(trueVal - falseVal) = 1.
//
// Arguments:
// select - The select node that is now SELECT or SELECTCC
// cond - The condition node that SELECT or SELECTCC uses
//
void Lowering::TryLowerCselToCinc(GenTreeOp* select, GenTree* cond)
void Lowering::TryLowerCnsIntCselToCinc(GenTreeOp* select, GenTree* cond)
{
assert(select->OperIs(GT_SELECT, GT_SELECTCC));

Expand Down
132 changes: 132 additions & 0 deletions src/tests/JIT/opt/Compares/conditionalIncrements.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,30 @@ public static void cinc_byte(byte op1, int expected)
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_byte(byte op1, int expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
int result = op1 <= 44 ? op1 + 1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_byte_reverse(byte op1, int expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
int result = op1 <= 44 ? 1 + op1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(72, byte.MinValue)]
[InlineData(32, byte.MaxValue)]
Expand Down Expand Up @@ -58,6 +82,30 @@ public static void cinc_short(short op1, short expected)
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_short(short op1, short expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
short result = (short) (op1 <= 44 ? op1 + 1 : 5);
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_short_reverse(short op1, short expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
short result = (short) (op1 <= 44 ? 1 + op1 : 5);
Assert.Equal(expected, result);
}

[Theory]
[InlineData(76, short.MinValue)]
[InlineData(-35, short.MaxValue)]
Expand All @@ -82,6 +130,66 @@ public static void cinc_int(int op1, int expected)
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_int(int op1, int expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
int result = op1 <= 44 ? op1 + 1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_int_reverse(int op1, int expected)
{
//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{gt|le}}
int result = op1 <= 44 ? 1 + op1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 69)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_int_postinc(int op1, int expected)
{
int result = 2 * op1;

//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{ge|lt}}
if (op1 < 44) {
result++;
} else {
result = 5;
}
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 149)]
[InlineData(34, 5)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_int_postinc_false(int op1, int expected)
{
int result = 2 * op1;

//ARM64-FULL-LINE: cmp {{w[0-9]+}}, #44
//ARM64-FULL-LINE-NEXT: csinc {{w[0-9]+}}, {{w[0-9]+}}, {{w[0-9]+}}, {{ge|lt}}
if (op1 < 44) {
result = 5;
} else {
result++;
}
Assert.Equal(expected, result);
}

[Theory]
[InlineData(77, int.MinValue)]
[InlineData(37, int.MaxValue)]
Expand All @@ -107,6 +215,30 @@ public static void cinc_long(long op1, long expected)
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_long(long op1, long expected)
{
//ARM64-FULL-LINE: cmp {{x[0-9]+}}, #48
//ARM64-FULL-LINE-NEXT: csinc {{x[0-9]+}}, {{x[0-9]+}}, {{x[0-9]+}}, {{gt|le}}
long result = op1 <= 48 ? op1 + 1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(74, 5)]
[InlineData(34, 35)]
[MethodImpl(MethodImplOptions.NoInlining)]
public static void csinc_long_reverse(long op1, long expected)
{
//ARM64-FULL-LINE: cmp {{x[0-9]+}}, #48
//ARM64-FULL-LINE-NEXT: csinc {{x[0-9]+}}, {{x[0-9]+}}, {{x[0-9]+}}, {{gt|le}}
long result = op1 <= 48 ? 1 + op1 : 5;
Assert.Equal(expected, result);
}

[Theory]
[InlineData(79, long.MaxValue)]
[InlineData(39, long.MinValue)]
Expand Down