Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Refactor LoopFuseSiblingOp and support parallel fusion (#94391)" #97523

Merged
merged 1 commit into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,7 @@ def ForallOp : SCF_Op<"forall", [
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
"replaceWithAdditionalYields", "promoteIfSingleIteration",
"yieldTiledValuesAndReplace"]>,
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"scf::InParallelOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
Expand Down
20 changes: 0 additions & 20 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,6 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
scf::ForOp root);

//===----------------------------------------------------------------------===//
// Fusion related helpers
//===----------------------------------------------------------------------===//

/// Check structural compatibility between two loops such as iteration space
/// and dominance.
bool checkFusionStructuralLegality(LoopLikeOpInterface target,
LoopLikeOpInterface source,
Diagnostic &diag);

/// Given two scf.forall loops, `target` and `source`, fuses `target` into
/// `source`. Assumes that the given loops are siblings and are independent of
/// each other.
Expand All @@ -212,16 +202,6 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
RewriterBase &rewriter);

/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
/// `source`. Assumes that the given loops are siblings and are independent of
/// each other.
///
/// This function does not perform any legality checks and simply fuses the
/// loops. The caller is responsible for ensuring that the loops are legal to
/// fuse.
scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
scf::ParallelOp source,
RewriterBase &rewriter);
} // namespace mlir

#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
20 changes: 0 additions & 20 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,24 +90,4 @@ struct JamBlockGatherer {
/// Include the generated interface declarations.
#include "mlir/Interfaces/LoopLikeInterface.h.inc"

namespace mlir {
/// A function that rewrites `target`'s terminator as a teminator obtained by
/// fusing `source` into `target`.
using FuseTerminatorFn =
function_ref<void(RewriterBase &rewriter, LoopLikeOpInterface source,
LoopLikeOpInterface &target, IRMapping mapping)>;

/// Returns a fused `LoopLikeOpInterface` created by fusing `source` to
/// `target`. The `NewYieldValuesFn` callback is used to pass to the
/// `replaceWithAdditionalYields` interface method to replace the loop with a
/// new loop with (possibly) additional yields, while the `FuseTerminatorFn`
/// callback is repsonsible for updating the fused loop terminator.
LoopLikeOpInterface createFused(LoopLikeOpInterface target,
LoopLikeOpInterface source,
RewriterBase &rewriter,
NewYieldValuesFn newYieldValuesFn,
FuseTerminatorFn fuseTerminatorFn);

} // namespace mlir

#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
38 changes: 0 additions & 38 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,44 +618,6 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,

SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }

FailureOr<LoopLikeOpInterface> ForallOp::replaceWithAdditionalYields(
RewriterBase &rewriter, ValueRange newInitOperands,
bool replaceInitOperandUsesInLoop,
const NewYieldValuesFn &newYieldValuesFn) {
// Create a new loop before the existing one, with the extra operands.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(getOperation());
SmallVector<Value> inits(getOutputs());
llvm::append_range(inits, newInitOperands);
scf::ForallOp newLoop = rewriter.create<scf::ForallOp>(
getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
inits, getMapping(),
/*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {});

// Move the loop body to the new op.
rewriter.mergeBlocks(getBody(), newLoop.getBody(),
newLoop.getBody()->getArguments().take_front(
getBody()->getNumArguments()));

if (replaceInitOperandUsesInLoop) {
// Replace all uses of `newInitOperands` with the corresponding basic block
// arguments.
for (auto &&[newOperand, oldOperand] :
llvm::zip(newInitOperands, newLoop.getBody()->getArguments().take_back(
newInitOperands.size()))) {
rewriter.replaceUsesWithIf(newOperand, oldOperand, [&](OpOperand &use) {
Operation *user = use.getOwner();
return newLoop->isProperAncestor(user);
});
}
}

// Replace the old loop.
rewriter.replaceOp(getOperation(),
newLoop->getResults().take_front(getNumResults()));
return cast<LoopLikeOpInterface>(newLoop.getOperation());
}

/// Promotes the loop body of a forallOp to its containing block if it can be
/// determined that the loop has a single iteration.
LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
Expand Down
140 changes: 119 additions & 21 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,8 @@ loopScheduling(scf::ForOp forOp,
return 1;
};

std::optional<int64_t> ubConstant =
getConstantIntValue(forOp.getUpperBound());
std::optional<int64_t> lbConstant =
getConstantIntValue(forOp.getLowerBound());
std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
DenseMap<Operation *, unsigned> opCycles;
std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
for (Operation &op : forOp.getBody()->getOperations()) {
Expand Down Expand Up @@ -449,6 +447,113 @@ void transform::TakeAssumedBranchOp::getEffects(
// LoopFuseSiblingOp
//===----------------------------------------------------------------------===//

/// Check if `target` and `source` are siblings, in the context that `target`
/// is being fused into `source`.
///
/// This is a simple check that just checks if both operations are in the same
/// block and some checks to ensure that the fused IR does not violate
/// dominance.
static DiagnosedSilenceableFailure isOpSibling(Operation *target,
Operation *source) {
// Check if both operations are same.
if (target == source)
return emitSilenceableFailure(source)
<< "target and source need to be different loops";

// Check if both operations are in the same block.
if (target->getBlock() != source->getBlock())
return emitSilenceableFailure(source)
<< "target and source are not in the same block";

// Check if fusion will violate dominance.
DominanceInfo domInfo(source);
if (target->isBeforeInBlock(source)) {
// Since `target` is before `source`, all users of results of `target`
// need to be dominated by `source`.
for (Operation *user : target->getUsers()) {
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
return emitSilenceableFailure(target)
<< "user of results of target should be properly dominated by "
"source";
}
}
} else {
// Since `target` is after `source`, all values used by `target` need
// to dominate `source`.

// Check if operands of `target` are dominated by `source`.
for (Value operand : target->getOperands()) {
Operation *operandOp = operand.getDefiningOp();
// Operands without defining operations are block arguments. When `target`
// and `source` occur in the same block, these operands dominate `source`.
if (!operandOp)
continue;

// Operand's defining operation should properly dominate `source`.
if (!domInfo.properlyDominates(operandOp, source,
/*enclosingOpOk=*/false))
return emitSilenceableFailure(target)
<< "operands of target should be properly dominated by source";
}

// Check if values used by `target` are dominated by `source`.
bool failed = false;
OpOperand *failedValue = nullptr;
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
Operation *operandOp = operand->get().getDefiningOp();
if (operandOp && !domInfo.properlyDominates(operandOp, source,
/*enclosingOpOk=*/false)) {
// `operand` is not an argument of an enclosing block and the defining
// op of `operand` is outside `target` but does not dominate `source`.
failed = true;
failedValue = operand;
}
});

if (failed)
return emitSilenceableFailure(failedValue->getOwner())
<< "values used inside regions of target should be properly "
"dominated by source";
}

return DiagnosedSilenceableFailure::success();
}

/// Check if `target` scf.forall can be fused into `source` scf.forall.
///
/// This simply checks if both loops have the same bounds, steps and mapping.
/// No attempt is made at checking that the side effects of `target` and
/// `source` are independent of each other.
static bool isForallWithIdenticalConfiguration(Operation *target,
Operation *source) {
auto targetOp = dyn_cast<scf::ForallOp>(target);
auto sourceOp = dyn_cast<scf::ForallOp>(source);
if (!targetOp || !sourceOp)
return false;

return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
targetOp.getMapping() == sourceOp.getMapping();
}

/// Check if `target` scf.for can be fused into `source` scf.for.
///
/// This simply checks if both loops have the same bounds and steps. No attempt
/// is made at checking that the side effects of `target` and `source` are
/// independent of each other.
static bool isForWithIdenticalConfiguration(Operation *target,
Operation *source) {
auto targetOp = dyn_cast<scf::ForOp>(target);
auto sourceOp = dyn_cast<scf::ForOp>(source);
if (!targetOp || !sourceOp)
return false;

return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
targetOp.getStep() == sourceOp.getStep();
}

DiagnosedSilenceableFailure
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
Expand All @@ -464,32 +569,25 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
<< "source handle (got " << llvm::range_size(sourceOps) << ")";
}

auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
if (!target || !source)
return emitSilenceableFailure(target->getLoc())
<< "target or source is not a loop op";
Operation *target = *targetOps.begin();
Operation *source = *sourceOps.begin();

// Check if loops can be fused
Diagnostic diag(target.getLoc(), DiagnosticSeverity::Error);
if (!mlir::checkFusionStructuralLegality(target, source, diag))
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
// Check if the target and source are siblings.
DiagnosedSilenceableFailure diag = isOpSibling(target, source);
if (!diag.succeeded())
return diag;

Operation *fusedLoop;
// TODO: Support fusion for loop-like ops besides scf.for, scf.forall
// and scf.parallel.
if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
/// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
if (isForWithIdenticalConfiguration(target, source)) {
fusedLoop = fuseIndependentSiblingForLoops(
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
} else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
} else if (isForallWithIdenticalConfiguration(target, source)) {
fusedLoop = fuseIndependentSiblingForallLoops(
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
} else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
fusedLoop = fuseIndependentSiblingParallelLoops(
cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
} else
return emitSilenceableFailure(target->getLoc())
<< "unsupported loop type for fusion";
<< "operations cannot be fused";

assert(fusedLoop && "failed to fuse operations");

Expand Down
80 changes: 74 additions & 6 deletions mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
Expand All @@ -38,6 +37,24 @@ static bool hasNestedParallelOp(ParallelOp ploop) {
return walkResult.wasInterrupted();
}

/// Verify equal iteration spaces.
static bool equalIterationSpaces(ParallelOp firstPloop,
ParallelOp secondPloop) {
if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
return false;

auto matchOperands = [&](const OperandRange &lhs,
const OperandRange &rhs) -> bool {
// TODO: Extend this to support aliases and equal constants.
return std::equal(lhs.begin(), lhs.end(), rhs.begin());
};
return matchOperands(firstPloop.getLowerBound(),
secondPloop.getLowerBound()) &&
matchOperands(firstPloop.getUpperBound(),
secondPloop.getUpperBound()) &&
matchOperands(firstPloop.getStep(), secondPloop.getStep());
}

/// Checks if the parallel loops have mixed access to the same buffers. Returns
/// `true` if the first parallel loop writes to the same indices that the second
/// loop reads.
Expand Down Expand Up @@ -136,10 +153,9 @@ verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
const IRMapping &firstToSecondPloopIndices,
llvm::function_ref<bool(Value, Value)> mayAlias) {
Diagnostic diag(firstPloop.getLoc(), DiagnosticSeverity::Remark);
return !hasNestedParallelOp(firstPloop) &&
!hasNestedParallelOp(secondPloop) &&
checkFusionStructuralLegality(firstPloop, secondPloop, diag) &&
equalIterationSpaces(firstPloop, secondPloop) &&
succeeded(verifyDependencies(firstPloop, secondPloop,
firstToSecondPloopIndices, mayAlias));
}
Expand All @@ -158,9 +174,61 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
mayAlias))
return;

IRRewriter rewriter(builder);
secondPloop = mlir::fuseIndependentSiblingParallelLoops(
firstPloop, secondPloop, rewriter);
DominanceInfo dom;
// We are fusing first loop into second, make sure there are no users of the
// first loop results between loops.
for (Operation *user : firstPloop->getUsers())
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
return;

ValueRange inits1 = firstPloop.getInitVals();
ValueRange inits2 = secondPloop.getInitVals();

SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
newInitVars.append(inits2.begin(), inits2.end());

IRRewriter b(builder);
b.setInsertionPoint(secondPloop);
auto newSecondPloop = b.create<ParallelOp>(
secondPloop.getLoc(), secondPloop.getLowerBound(),
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);

Block *newBlock = newSecondPloop.getBody();
auto term1 = cast<ReduceOp>(block1->getTerminator());
auto term2 = cast<ReduceOp>(block2->getTerminator());

b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
newBlock->getArguments());
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
newBlock->getArguments());

ValueRange results = newSecondPloop.getResults();
if (!results.empty()) {
b.setInsertionPointToEnd(newBlock);

ValueRange reduceArgs1 = term1.getOperands();
ValueRange reduceArgs2 = term2.getOperands();
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());

auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);

for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
term1.getReductions(), term2.getReductions()))) {
Block &oldRedBlock = reg.front();
Block &newRedBlock = newReduceOp.getReductions()[i].front();
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
newRedBlock.getArguments());
}

firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
}
term1->erase();
term2->erase();
firstPloop.erase();
secondPloop.erase();
secondPloop = newSecondPloop;
}

void mlir::scf::naivelyFuseParallelOps(
Expand Down
Loading
Loading