diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index f69a10334050b9c..3594b9669e3c6d7 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1083,17 +1083,34 @@ struct DimOfMemRefReshape : public OpRewritePattern { return rewriter.notifyMatchFailure( dim, "Dim op is not defined by a reshape op."); + // dim of a memref reshape can be folded if dim.getIndex() dominates the + // reshape. Instead of using `DominanceInfo` (which is usually costly) we + // cheaply check that either of the following conditions hold: + // 1. dim.getIndex() is defined in the same block as reshape but before + // reshape. + // 2. dim.getIndex() is defined in a parent block of + // reshape. + + // Check condition 1 if (dim.getIndex().getParentBlock() == reshape->getBlock()) { if (auto *definingOp = dim.getIndex().getDefiningOp()) { - if (reshape->isBeforeInBlock(definingOp)) + if (reshape->isBeforeInBlock(definingOp)) { return rewriter.notifyMatchFailure( dim, "dim.getIndex is not defined before reshape in the same block."); - } // else dim.getIndex is a block argument to reshape->getBlock - } else if (!dim.getIndex().getParentRegion()->isProperAncestor( - reshape->getParentRegion())) + } + } // else dim.getIndex is a block argument to reshape->getBlock and + // dominates reshape + } // Check condition 2 + else if (dim->getBlock() != reshape->getBlock() && + !dim.getIndex().getParentRegion()->isProperAncestor( + reshape->getParentRegion())) { + // If dim and reshape are in the same block but dim.getIndex() isn't, we + // already know dim.getIndex() dominates reshape without calling + // `isProperAncestor` return rewriter.notifyMatchFailure( dim, "dim.getIndex does not dominate reshape."); + } // Place the load directly after the reshape to ensure that the shape memref // was not mutated.