Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More Mlir fixes for 17.x release #673

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,26 @@ static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
Value operandA = op.getOperand(0);
Value operandB = op.getOperand(1);
Type opType = operandA.getType();
Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);

Value logA = b.create<math::LogOp>(opType, operandA);
Value mult = b.create<arith::MulFOp>(opType, logA, operandB);
Value logA = b.create<math::LogOp>(opType, opASquared);
Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
Value expResult = b.create<math::ExpOp>(opType, mult);
rewriter.replaceOp(op, expResult);
Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
Value negCheck =
b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
Value oddPower =
b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);

Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
expResult);
rewriter.replaceOp(op, res);
return success();
}

Expand Down
64 changes: 33 additions & 31 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,17 @@ namespace {
namespace saturated_arith {
struct Wrapper {
static Wrapper stride(int64_t v) {
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
: Wrapper{false, v};
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
static Wrapper offset(int64_t v) {
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
: Wrapper{false, v};
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
static Wrapper size(int64_t v) {
return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
}
int64_t asOffset() {
return saturated ? ShapedType::kDynamic : v;
}
int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
int64_t asStride() {
return saturated ? ShapedType::kDynamic : v;
}
int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
bool operator==(Wrapper other) {
return (saturated && other.saturated) ||
(!saturated && !other.saturated && v == other.v);
Expand Down Expand Up @@ -732,8 +726,7 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
for (auto it : llvm::zip(sourceStrides, resultStrides)) {
auto ss = std::get<0>(it), st = std::get<1>(it);
if (ss != st)
if (ShapedType::isDynamic(ss) &&
!ShapedType::isDynamic(st))
if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
return false;
}

Expand Down Expand Up @@ -766,8 +759,7 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// same. They are also compatible if either one is dynamic (see
// description of MemRefCastOp for details).
auto checkCompatible = [](int64_t a, int64_t b) {
return (ShapedType::isDynamic(a) ||
ShapedType::isDynamic(b) || a == b);
return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
};
if (!checkCompatible(aOffset, bOffset))
return false;
Expand Down Expand Up @@ -1890,8 +1882,7 @@ LogicalResult ReinterpretCastOp::verify() {
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = getStaticOffsets().front();
if (!ShapedType::isDynamic(resultOffset) &&
!ShapedType::isDynamic(expectedOffset) &&
resultOffset != expectedOffset)
!ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
return emitError("expected result type with offset = ")
<< expectedOffset << " instead of " << resultOffset;

Expand Down Expand Up @@ -2945,18 +2936,6 @@ static MemRefType getCanonicalSubViewResultType(
nonRankReducedType.getMemorySpace());
}

/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type. Additionally, reduce the rank of the inferred
/// result type if `currentResultType` is lower rank than `sourceType`.
static MemRefType getCanonicalSubViewResultType(
MemRefType currentResultType, MemRefType sourceType,
ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return getCanonicalSubViewResultType(currentResultType, sourceType,
sourceType, mixedOffsets, mixedSizes,
mixedStrides);
}

Value mlir::memref::createCanonicalRankReducingSubViewOp(
OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
auto memrefType = llvm::cast<MemRefType>(memref.getType());
Expand Down Expand Up @@ -3109,9 +3088,32 @@ struct SubViewReturnTypeCanonicalizer {
MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
mixedOffsets, mixedSizes,
mixedStrides);
// Infer a memref type without taking into account any rank reductions.
MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));

// Directly return the non-rank reduced type if there are no dropped dims.
llvm::SmallBitVector droppedDims = op.getDroppedDims();
if (droppedDims.empty())
return nonReducedType;

// Take the strides and offset from the non-rank reduced type.
auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);

// Drop dims from shape and strides.
SmallVector<int64_t> targetShape;
SmallVector<int64_t> targetStrides;
for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
if (droppedDims.test(i))
continue;
targetStrides.push_back(nonReducedStrides[i]);
targetShape.push_back(nonReducedType.getDimSize(i));
}

return MemRefType::get(targetShape, nonReducedType.getElementType(),
StridedLayoutAttr::get(nonReducedType.getContext(),
offset, targetStrides),
nonReducedType.getMemorySpace());
}
};

Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,12 @@ std::pair<size_t, size_t> AliasInitializer::visitImpl(

void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
auto it = std::next(aliases.begin(), aliasIndex);

// If already marked non-deferrable stop the recursion.
// All children should already be marked non-deferrable as well.
if (!it->second.canBeDeferred)
return;

it->second.canBeDeferred = false;

// Propagate the non-deferrable flag to any child aliases.
Expand Down
13 changes: 8 additions & 5 deletions mlir/lib/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
for (auto *dialect : ctx->getLoadedDialects()) {
#ifndef NDEBUG
dialect->handleUseOfUndefinedPromisedInterface(interfaceKind, interfaceName);
dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
interfaceName);
#endif
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
interfaces.insert(interface);
Expand Down Expand Up @@ -243,8 +244,9 @@ void DialectRegistry::applyExtensions(Dialect *dialect) const {
extension.apply(ctx, requiredDialects);
};

for (const auto &extension : extensions)
applyExtension(*extension);
// Note: Additional extensions may be added while applying an extension.
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
applyExtension(*extensions[i]);
}

void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
Expand All @@ -264,8 +266,9 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
extension.apply(ctx, requiredDialects);
};

for (const auto &extension : extensions)
applyExtension(*extension);
// Note: Additional extensions may be added while applying an extension.
for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
applyExtension(*extensions[i]);
}

bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
Expand Down
17 changes: 14 additions & 3 deletions mlir/test/Dialect/Math/expand-math.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,21 @@ func.func @roundf_func(%a: f32) -> f32 {
// CHECK-LABEL: func @powf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64, [[ARG1:%.+]]: f64)
func.func @powf_func(%a: f64, %b: f64) ->f64 {
// CHECK-DAG: [[LOG:%.+]] = math.log [[ARG0]]
// CHECK-DAG: [[MULT:%.+]] = arith.mulf [[LOG]], [[ARG1]]
// CHECK-DAG = [[CST0:%.+]] = arith.constant 0.000000e+00
// CHECK-DAG: [[TWO:%.+]] = arith.constant 2.000000e+00
// CHECK-DAG: [[NEGONE:%.+]] = arith.constant -1.000000e+00
// CHECK-DAG: [[SQR:%.+]] = arith.mulf [[ARG0]], [[ARG0]]
// CHECK-DAG: [[HALF:%.+]] = arith.divf [[ARG1]], [[TWO]]
// CHECK-DAG: [[LOG:%.+]] = math.log [[SQR]]
// CHECK-DAG: [[MULT:%.+]] = arith.mulf [[HALF]], [[LOG]]
// CHECK-DAG: [[EXPR:%.+]] = math.exp [[MULT]]
// CHECK: return [[EXPR]]
// CHECK-DAG: [[NEGEXPR:%.+]] = arith.mulf [[EXPR]], [[NEGONE]]
// CHECK-DAG: [[REMF:%.+]] = arith.remf [[ARG1]], [[TWO]]
// CHECK-DAG: [[CMPNEG:%.+]] = arith.cmpf olt, [[ARG0]]
// CHECK-DAG: [[CMPZERO:%.+]] = arith.cmpf one, [[REMF]]
// CHECK-DAG: [[AND:%.+]] = arith.andi [[CMPZERO]], [[CMPNEG]]
// CHECK-DAG: [[SEL:%.+]] = arith.select [[AND]], [[NEGEXPR]], [[EXPR]]
// CHECK: return [[SEL]]
%ret = math.powf %a, %b : f64
return %ret : f64
}
Expand Down
17 changes: 16 additions & 1 deletion mlir/test/Dialect/MemRef/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,7 @@ func.func @fold_multiple_memory_space_cast(%arg : memref<?xf32>) -> memref<?xf32

// -----

// CHECK-lABEL: func @ub_negative_alloc_size
// CHECK-LABEL: func private @ub_negative_alloc_size
func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
%idx1 = index.constant 1
%c-2 = arith.constant -2 : index
Expand All @@ -940,3 +940,18 @@ func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
%alloc = memref.alloc(%c15, %c-2, %idx1) : memref<?x?x?xi1>
return %alloc : memref<?x?x?xi1>
}

// -----

// CHECK-LABEL: func @subview_rank_reduction(
// CHECK-SAME: %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
-> memref<?x?xf32, strided<[384, 1], offset: ?>> {
%c1 = arith.constant 1 : index
// CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
// CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[384, 1], offset: ?>>
%0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
: memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
// CHECK: return %[[cast]]
return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
}
12 changes: 12 additions & 0 deletions mlir/test/IR/recursive-type.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s

// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>

// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
Expand All @@ -12,6 +14,16 @@ func.func @roundtrip() {
// into inifinite recursion.
// CHECK: !testrec
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>

// CHECK: () -> ![[$NAME]]
// CHECK: () -> ![[$NAME]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>

// CHECK: () -> ![[$NAME2]]
// CHECK: () -> ![[$NAME2]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
return
}

Expand Down
4 changes: 4 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
os << recAliasType.getName();
return AliasResult::FinalAlias;
}
return AliasResult::NoAlias;
}

Expand Down
22 changes: 22 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypeDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -369,4 +369,26 @@ def TestTypeElseAnchorStruct : Test_Type<"TestTypeElseAnchorStruct"> {
let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
}

def TestI32 : Test_Type<"TestI32"> {
let mnemonic = "i32";
}

def TestRecursiveAlias
: Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> {
let mnemonic = "test_rec_alias";
let storageClass = "TestRecursiveTypeStorage";
let storageNamespace = "test";
let genStorageClass = 0;

let parameters = (ins "llvm::StringRef":$name);

let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
Type getBody() const;

void setBody(Type type);
}];
}

#endif // TEST_TYPEDEFS
51 changes: 51 additions & 0 deletions mlir/test/lib/Dialect/Test/TestTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,54 @@ void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
SetVector<Type> stack;
printTestType(type, printer, stack);
}

Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }

void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }

StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }

Type TestRecursiveAliasType::parse(AsmParser &parser) {
thread_local static SetVector<Type> stack;

StringRef name;
if (parser.parseLess() || parser.parseKeyword(&name))
return Type();
auto rec = TestRecursiveAliasType::get(parser.getContext(), name);

// If this type already has been parsed above in the stack, expect just the
// name.
if (stack.contains(rec)) {
if (failed(parser.parseGreater()))
return Type();
return rec;
}

// Otherwise, parse the body and update the type.
if (failed(parser.parseComma()))
return Type();
stack.insert(rec);
Type subtype;
if (parser.parseType(subtype))
return nullptr;
stack.pop_back();
if (!subtype || failed(parser.parseGreater()))
return Type();

rec.setBody(subtype);

return rec;
}

void TestRecursiveAliasType::print(AsmPrinter &printer) const {
thread_local static SetVector<Type> stack;

printer << "<" << getName();
if (!stack.contains(*this)) {
printer << ", ";
stack.insert(*this);
printer << getBody();
stack.pop_back();
}
printer << ">";
}
6 changes: 3 additions & 3 deletions mlir/test/lib/Dialect/Test/TestTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ struct FieldParser<std::optional<int>> {

#include "TestTypeInterfaces.h.inc"

#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.h.inc"

namespace test {

/// Storage for simple named recursive types, where the type is identified by
Expand Down Expand Up @@ -150,4 +147,7 @@ class TestRecursiveType

} // namespace test

#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.h.inc"

#endif // MLIR_TESTTYPES_H
Loading