Skip to content

Commit

Permalink
[MLIR] Add AtomicRMWRegionOp.
Browse files Browse the repository at this point in the history
  • Loading branch information
pifon2a committed Apr 20, 2020
1 parent 5247499 commit 871beba
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 1 deletion.
71 changes: 71 additions & 0 deletions mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,77 @@ def AtomicRMWOp : Std_Op<"atomic_rmw", [
}];
}

def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
SingleBlockImplicitTerminator<"AtomicYieldOp">,
TypesMatchWith<"result type matches element type of memref",
"memref", "result",
"$_self.cast<MemRefType>().getElementType()">
]> {
let summary = "atomic read-modify-write operation with a region";
let description = [{
The `atomic_rmw` operation provides a way to perform a read-modify-write
sequence that is free from data races. The memref operand represents the
buffer that the read and write will be performed against, as accessed by
the specified indices. The arity of the indices is the rank of the memref.
The result represents the latest value that was stored. The region contains
the code for the modification itself. The entry block has a single argument
that represents the value stored in `memref[indices]` before the write is
performed.

Example:

```mlir
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%current_value : f32):
%c1 = constant 1.0 : f32
%inc = addf %c1, %current_value : f32
atomic_yield %inc : f32
}
```
}];

let arguments = (ins
MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
Variadic<Index>:$indices);

let results = (outs
AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);

let regions = (region AnyRegion:$body);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"Builder *builder, OperationState &result, "
"Value memref, ValueRange ivs">
];

let extraClassDeclaration = [{
OpBuilder getBodyBuilder() {
assert(!body().empty() && "Unexpected empty 'body' region.");
Block &block = body().front();
return OpBuilder(&block, block.end());
}
// The value stored in memref[ivs].
Value getCurrentValue() {
return body().front().getArgument(0);
}
}];
}

def AtomicYieldOp : Std_Op<"atomic_yield", [
HasParent<"GenericAtomicRMWOp">,
NoSideEffect,
Terminator
]> {
let summary = "yield operation for GenericAtomicRMWOp";
let description = [{
"atomic_yield" yields an SSA value from a GenericAtomicRMWOp region.
}];

let arguments = (ins AnyType:$result);
let assemblyFormat = "$result attr-dict `:` type($result)";
}

//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
Expand Down
71 changes: 71 additions & 0 deletions mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,77 @@ static LogicalResult verify(AtomicRMWOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(Builder *builder, OperationState &result,
Value memref, ValueRange ivs) {
result.addOperands(memref);
result.addOperands(ivs);

if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
Type elementType = memrefType.getElementType();
result.addTypes(elementType);

Region *bodyRegion = result.addRegion();
bodyRegion->push_back(new Block());
bodyRegion->front().addArgument(elementType);
}
}

static LogicalResult verify(GenericAtomicRMWOp op) {
auto &block = op.body().front();
if (block.getNumArguments() != 1)
return op.emitOpError("expected single number of entry block arguments");

if (op.getResult().getType() != block.getArgument(0).getType())
return op.emitOpError(
"expected block argument of the same type result type");
return success();
}

static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType memref;
Type memrefType;
SmallVector<OpAsmParser::OperandType, 4> ivs;

Type indexType = parser.getBuilder().getIndexType();
if (parser.parseOperand(memref) ||
parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
parser.parseColonType(memrefType) ||
parser.resolveOperand(memref, memrefType, result.operands) ||
parser.resolveOperands(ivs, indexType, result.operands))
return failure();

Region *body = result.addRegion();
if (parser.parseRegion(*body, llvm::None, llvm::None))
return failure();
result.types.push_back(memrefType.cast<MemRefType>().getElementType());
return success();
}

static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
p << op.getOperationName() << ' ' << op.memref() << "[" << op.indices()
<< "] : " << op.memref().getType();
p.printRegion(op.body());
p.printOptionalAttrDict(op.getAttrs());
}

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AtomicYieldOp op) {
Type parentType = op.getParentOp()->getResultTypes().front();
Type resultType = op.result().getType();
if (parentType != resultType)
return op.emitOpError() << "types mismatch between yield op: " << resultType
<< " and its parent: " << parentType;
return success();
}

//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 15 additions & 1 deletion mlir/test/IR/core-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -751,9 +751,23 @@ func @tensor_load_store(%0 : memref<4x4xi32>) {
}

// CHECK-LABEL: func @atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
// CHECK: %{{.*}} = atomic_rmw "addf" %{{.*}}, %{{.*}}[%{{.*}}]
%x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<10xf32>) -> f32
// CHECK: atomic_rmw "addf" [[VAL]], [[BUF]]{{\[}}[[I]]]
return
}

// CHECK-LABEL: func @generic_atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
%x = generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
// CHECK-NEXT: generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref
^bb0(%old_value : f32):
%c1 = constant 1.0 : f32
%out = addf %c1, %old_value : f32
atomic_yield %out : f32
}
return
}

Expand Down
48 changes: 48 additions & 0 deletions mlir/test/IR/invalid-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,54 @@ func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {

// -----

func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected single number of entry block arguments}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%arg0 : f32, %arg1 : f32):
%c1 = constant 1.0 : f32
atomic_yield %c1 : f32
}
return
}

// -----

func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected block argument of the same type result type}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : i32):
%c1 = constant 1.0 : f32
atomic_yield %c1 : f32
}
return
}

// -----

func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{failed to verify that result type matches element type of memref}}
%0 = "std.generic_atomic_rmw"(%I, %i) ( {
^bb0(%old_value: f32):
%c1 = constant 1.0 : f32
atomic_yield %c1 : f32
}) : (memref<10xf32>, index) -> i32
return
}

// -----

func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
// expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : f32):
%c1 = constant 1 : i32
atomic_yield %c1 : i32
}
return
}

// -----

// alignment is not power of 2.
func @assume_alignment(%0: memref<4x4xf16>) {
// expected-error@+1 {{alignment must be power of 2}}
Expand Down

0 comments on commit 871beba

Please sign in to comment.