Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR for llvm/llvm-project#64553 #653

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,8 @@ static mlir::ParseResult parseAllocatableOp(FN wrapResultType,
parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
return mlir::failure();
}
result.addAttribute(
"operand_segment_sizes",
builder.getDenseI32ArrayAttr({typeparamsSize, shapeSize}));
result.addAttribute("operandSegmentSizes", builder.getDenseI32ArrayAttr(
{typeparamsSize, shapeSize}));
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.addTypeToList(restype, result.types))
return mlir::failure();
Expand All @@ -149,7 +148,7 @@ static void printAllocatableOp(mlir::OpAsmPrinter &p, OP &op) {
p << ", ";
p.printOperand(sh);
}
p.printOptionalAttrDict(op->getAttrs(), {"in_type", "operand_segment_sizes"});
p.printOptionalAttrDict(op->getAttrs(), {"in_type", "operandSegmentSizes"});
}

//===----------------------------------------------------------------------===//
Expand Down
22 changes: 11 additions & 11 deletions flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
// CHECK: %[[ONE_2:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: omp.parallel {
// CHECK: %[[ONE_3:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[I_VAR:.*]] = llvm.alloca %[[ONE_3]] x i32 {adapt.valuebyref, in_type = i32, operand_segment_sizes = array<i32: 0, 0>, pinned} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[I_VAR:.*]] = llvm.alloca %[[ONE_3]] x i32 {adapt.valuebyref, in_type = i32, operandSegmentSizes = array<i32: 0, 0>, pinned} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[N:.*]] = llvm.load %[[N_REF]] : !llvm.ptr<i32>
// CHECK: omp.wsloop nowait
// CHECK-SAME: for (%[[I:.*]]) : i32 = (%[[ONE_2]]) to (%[[N]]) inclusive step (%[[ONE_2]]) {
Expand Down Expand Up @@ -200,7 +200,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
// CHECK: %[[ONE_2:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: omp.parallel {
// CHECK: %[[ONE_3:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[I_VAR:.*]] = llvm.alloca %[[ONE_3]] x i32 {adapt.valuebyref, in_type = i32, operand_segment_sizes = array<i32: 0, 0>, pinned} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[I_VAR:.*]] = llvm.alloca %[[ONE_3]] x i32 {adapt.valuebyref, in_type = i32, operandSegmentSizes = array<i32: 0, 0>, pinned} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[N:.*]] = llvm.load %[[N_REF]] : !llvm.ptr<i32>
// CHECK: omp.simdloop
// CHECK-SAME: (%[[I:.*]]) : i32 = (%[[ONE_2]]) to (%[[N]]) step (%[[ONE_2]]) {
Expand Down Expand Up @@ -231,13 +231,13 @@ func.func @_QPomp_target_data() {

// CHECK-LABEL: llvm.func @_QPomp_target_data() {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_1:.*]] = llvm.alloca %[[VAL_0]] x !llvm.array<1024 x i32> {bindc_name = "a", in_type = !fir.array<1024xi32>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEa"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_1:.*]] = llvm.alloca %[[VAL_0]] x !llvm.array<1024 x i32> {bindc_name = "a", in_type = !fir.array<1024xi32>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEa"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_3:.*]] = llvm.alloca %[[VAL_2]] x !llvm.array<1024 x i32> {bindc_name = "b", in_type = !fir.array<1024xi32>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEb"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_3:.*]] = llvm.alloca %[[VAL_2]] x !llvm.array<1024 x i32> {bindc_name = "b", in_type = !fir.array<1024xi32>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEb"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.array<1024 x i32> {bindc_name = "c", in_type = !fir.array<1024xi32>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEc"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.array<1024 x i32> {bindc_name = "c", in_type = !fir.array<1024xi32>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEc"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_7:.*]] = llvm.alloca %[[VAL_6]] x !llvm.array<1024 x i32> {bindc_name = "d", in_type = !fir.array<1024xi32>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEd"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_7:.*]] = llvm.alloca %[[VAL_6]] x !llvm.array<1024 x i32> {bindc_name = "d", in_type = !fir.array<1024xi32>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_dataEd"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: omp.target_enter_data map((to -> %[[VAL_1]] : !llvm.ptr<array<1024 x i32>>), (to -> %[[VAL_3]] : !llvm.ptr<array<1024 x i32>>), (always, alloc -> %[[VAL_5]] : !llvm.ptr<array<1024 x i32>>))
// CHECK: omp.target_exit_data map((from -> %[[VAL_1]] : !llvm.ptr<array<1024 x i32>>), (from -> %[[VAL_3]] : !llvm.ptr<array<1024 x i32>>), (release -> %[[VAL_5]] : !llvm.ptr<array<1024 x i32>>), (always, delete -> %[[VAL_7]] : !llvm.ptr<array<1024 x i32>>))
// CHECK: llvm.return
Expand Down Expand Up @@ -278,9 +278,9 @@ func.func @_QPopenmp_target_data_region() {

// CHECK-LABEL: llvm.func @_QPopenmp_target_data_region() {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_1:.*]] = llvm.alloca %[[VAL_0]] x !llvm.array<1024 x i32> {bindc_name = "a", in_type = !fir.array<1024xi32>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_data_regionEa"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_1:.*]] = llvm.alloca %[[VAL_0]] x !llvm.array<1024 x i32> {bindc_name = "a", in_type = !fir.array<1024xi32>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_data_regionEa"} : (i64) -> !llvm.ptr<array<1024 x i32>>
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_3:.*]] = llvm.alloca %[[VAL_2]] x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_data_regionEi"} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[VAL_3:.*]] = llvm.alloca %[[VAL_2]] x i32 {bindc_name = "i", in_type = i32, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_data_regionEi"} : (i64) -> !llvm.ptr<i32>
// CHECK: omp.target_data map((tofrom -> %[[VAL_1]] : !llvm.ptr<array<1024 x i32>>)) {
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_5:.*]] = llvm.sext %[[VAL_4]] : i32 to i64
Expand Down Expand Up @@ -338,7 +338,7 @@ func.func @_QPomp_target() {

// CHECK-LABEL: llvm.func @_QPomp_target() {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VAL_1:.*]] = llvm.alloca %[[VAL_0]] x !llvm.array<512 x i32> {bindc_name = "a", in_type = !fir.array<512xi32>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_targetEa"} : (i64) -> !llvm.ptr<array<512 x i32>>
// CHECK: %[[VAL_1:.*]] = llvm.alloca %[[VAL_0]] x !llvm.array<512 x i32> {bindc_name = "a", in_type = !fir.array<512xi32>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_targetEa"} : (i64) -> !llvm.ptr<array<512 x i32>>
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(64 : i32) : i32
// CHECK: omp.target thread_limit(%[[VAL_2]] : i32) map((tofrom -> %[[VAL_1]] : !llvm.ptr<array<512 x i32>>)) {
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(10 : i32) : i32
Expand Down Expand Up @@ -544,7 +544,7 @@ func.func @_QPsb() {
// CHECK: llvm.func @_QPsb() {
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[LI_REF:.*]] = llvm.alloca %6 x i32 {bindc_name = "li", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFsbEli"} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[LI_REF:.*]] = llvm.alloca %6 x i32 {bindc_name = "li", in_type = i32, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFsbEli"} : (i64) -> !llvm.ptr<i32>
// CHECK: omp.sections {
// CHECK: omp.section {
// CHECK: llvm.br ^[[BB_ENTRY:.*]]({{.*}})
Expand Down Expand Up @@ -582,7 +582,7 @@ func.func @_QPsb() {
// CHECK: }
// CHECK-LABEL: @_QPsimple_reduction
// CHECK-SAME: %[[ARRAY_REF:.*]]: !llvm.ptr<array<100 x i32>>
// CHECK: %[[RED_ACCUMULATOR:.*]] = llvm.alloca %2 x i32 {bindc_name = "x", in_type = !fir.logical<4>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFsimple_reductionEx"} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[RED_ACCUMULATOR:.*]] = llvm.alloca %2 x i32 {bindc_name = "x", in_type = !fir.logical<4>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFsimple_reductionEx"} : (i64) -> !llvm.ptr<i32>
// CHECK: omp.parallel {
// CHECK: omp.wsloop reduction(@[[EQV_REDUCTION]] -> %[[RED_ACCUMULATOR]] : !llvm.ptr<i32>) for
// CHECK: %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr<array<100 x i32>>, i64) -> !llvm.ptr<i32>
Expand Down
8 changes: 4 additions & 4 deletions flang/test/Fir/convert-to-llvm.fir
Original file line number Diff line number Diff line change
Expand Up @@ -1748,7 +1748,7 @@ func.func @no_reassoc(%arg0: !fir.ref<i32>) {
// CHECK-LABEL: llvm.func @no_reassoc(
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<i32>) {
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ALLOC:.*]] = llvm.alloca %[[C1]] x i32 {in_type = i32, operand_segment_sizes = array<i32: 0, 0>} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[ALLOC:.*]] = llvm.alloca %[[C1]] x i32 {in_type = i32, operandSegmentSizes = array<i32: 0, 0>} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[LOAD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr<i32>
// CHECK: llvm.store %[[LOAD]], %[[ALLOC]] : !llvm.ptr<i32>
// CHECK: llvm.return
Expand Down Expand Up @@ -1868,7 +1868,7 @@ func.func private @_QPxb(!fir.box<!fir.array<?x?xf64>>)
// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ARR_SIZE_TMP1:.*]] = llvm.mul %[[C1_0]], %[[N1]] : i64
// CHECK: %[[ARR_SIZE:.*]] = llvm.mul %[[ARR_SIZE_TMP1]], %[[N2]] : i64
// CHECK: %[[ARR:.*]] = llvm.alloca %[[ARR_SIZE]] x f64 {bindc_name = "arr", in_type = !fir.array<?x?xf64>, operand_segment_sizes = array<i32: 0, 2>, uniq_name = "_QFsbEarr"} : (i64) -> !llvm.ptr<f64>
// CHECK: %[[ARR:.*]] = llvm.alloca %[[ARR_SIZE]] x f64 {bindc_name = "arr", in_type = !fir.array<?x?xf64>, operandSegmentSizes = array<i32: 0, 2>, uniq_name = "_QFsbEarr"} : (i64) -> !llvm.ptr<f64>
// CHECK: %[[TYPE_CODE:.*]] = llvm.mlir.constant(28 : i32) : i32
// CHECK: %[[NULL:.*]] = llvm.mlir.null : !llvm.ptr<f64>
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[NULL]][1]
Expand Down Expand Up @@ -1945,9 +1945,9 @@ func.func private @_QPtest_dt_callee(%arg0: !fir.box<!fir.array<?xi32>>)
// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : i64) : i64
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[ALLOCA_SIZE_V:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[V:.*]] = llvm.alloca %[[ALLOCA_SIZE_V]] x i32 {bindc_name = "v", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFtest_dt_sliceEv"} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[V:.*]] = llvm.alloca %[[ALLOCA_SIZE_V]] x i32 {bindc_name = "v", in_type = i32, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFtest_dt_sliceEv"} : (i64) -> !llvm.ptr<i32>
// CHECK: %[[ALLOCA_SIZE_X:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[X:.*]] = llvm.alloca %[[ALLOCA_SIZE_X]] x !llvm.array<20 x struct<"_QFtest_dt_sliceTt", (i32, i32)>> {bindc_name = "x", in_type = !fir.array<20x!fir.type<_QFtest_dt_sliceTt{i:i32,j:i32}>>, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFtest_dt_sliceEx"} : (i64) -> !llvm.ptr<array<20 x struct<"_QFtest_dt_sliceTt", (i32, i32)>>>
// CHECK: %[[X:.*]] = llvm.alloca %[[ALLOCA_SIZE_X]] x !llvm.array<20 x struct<"_QFtest_dt_sliceTt", (i32, i32)>> {bindc_name = "x", in_type = !fir.array<20x!fir.type<_QFtest_dt_sliceTt{i:i32,j:i32}>>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFtest_dt_sliceEx"} : (i64) -> !llvm.ptr<array<20 x struct<"_QFtest_dt_sliceTt", (i32, i32)>>>
// CHECK: %[[TYPE_CODE:.*]] = llvm.mlir.constant(9 : i32) : i32
// CHECK: %[[NULL:.*]] = llvm.mlir.null : !llvm.ptr<i32>
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[NULL]][1]
Expand Down
2 changes: 1 addition & 1 deletion mlir/docs/PatternRewriter.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ Example output is shown below:
```
//===-------------------------------------------===//
Processing operation : 'cf.cond_br'(0x60f000001120) {
"cf.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = array<i32: 1, 0, 0>} : (i1) -> ()
"cf.cond_br"(%arg0)[^bb2, ^bb2] {operandSegmentSizes = array<i32: 1, 0, 0>} : (i1) -> ()

* Pattern SimplifyConstCondBranchPred : 'cf.cond_br -> ()' {
} -> failure : pattern failed to match
Expand Down
24 changes: 24 additions & 0 deletions mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ def IRDL_OperandsOp : IRDL_Op<"operands", [HasParent<"OperationOp">]> {

The `mul` operation will expect two operands of type `cmath.complex`, that
have the same type, and return a result of the same type.

The operands can also be marked as variadic or optional:
```mlir
irdl.operands(%0, single %1, optional %2, variadic %3)
```

Here, %0 and %1 are required single operands, %2 is an optional operand,
and %3 is a variadic operand.

When more than one operand is marked as optional or variadic, the operation
will expect a 'operandSegmentSizes' attribute that defines the number of
operands in each segment.
}];

let arguments = (ins Variadic<IRDL_AttributeType>:$args);
Expand Down Expand Up @@ -254,6 +266,18 @@ def IRDL_ResultsOp : IRDL_Op<"results", [HasParent<"OperationOp">]> {

The operation will expect one operand of the `cmath.complex` type, and two
results that have the underlying type of the `cmath.complex`.

The results can also be marked as variadic or optional:
```mlir
irdl.results(%0, single %1, optional %2, variadic %3)
```

Here, %0 and %1 are required single results, %2 is an optional result,
and %3 is a variadic result.

When more than one result is marked as optional or variadic, the operation
will expect a 'resultSegmentSizes' attribute that defines the number of
results in each segment.
}];

let arguments = (ins Variadic<IRDL_AttributeType>:$args);
Expand Down
17 changes: 0 additions & 17 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -874,23 +874,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return cast<DestinationStyleOpInterface>(*this->getOperation())
.hasTensorSemantics();
}

//========================================================================//
// Helper functions to mutate the `operand_segment_sizes` attribute.
// These are useful when cloning and changing operand types.
//========================================================================//
void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); }
void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); }

private:
void setOperandSegmentAt(unsigned idx, unsigned val) {
auto attr = ::llvm::cast<DenseIntElementsAttr>(
(*this)->getAttr("operand_segment_sizes"));
unsigned i = 0;
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
getOperation()->setAttr("operand_segment_sizes", newAttr);
}
}];

let verify = [{ return detail::verifyStructuredOpInterface($_op); }];
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -2178,7 +2178,7 @@ def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">;
// to have the same array size.
def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;

// Uses an attribute named `operand_segment_sizes` to specify how many actual
// Uses an attribute named `operandSegmentSizes` to specify how many actual
// operand each ODS-declared operand (variadic or not) corresponds to.
// This trait is used for ops that have multiple variadic operands but do
// not know statically their size relationship. The attribute must be a 1D
Expand All @@ -2188,7 +2188,7 @@ def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
def AttrSizedOperandSegments :
NativeOpTrait<"AttrSizedOperandSegments">, StructuralOpTrait;
// Similar to AttrSizedOperandSegments, but used for results. The attribute
// should be named as `result_segment_sizes`.
// should be named as `resultSegmentSizes`.
def AttrSizedResultSegments :
NativeOpTrait<"AttrSizedResultSegments">, StructuralOpTrait;

Expand Down
8 changes: 3 additions & 5 deletions mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -1331,17 +1331,15 @@ struct HasParent {
/// relationship is not always known statically. For such cases, we need
/// a per-op-instance specification to divide the operands into logical groups
/// or segments. This can be modeled by attributes. The attribute will be named
/// as `operand_segment_sizes`.
/// as `operandSegmentSizes`.
///
/// This trait verifies the attribute for specifying operand segments has
/// the correct type (1D vector) and values (non-negative), etc.
template <typename ConcreteType>
class AttrSizedOperandSegments
: public TraitBase<ConcreteType, AttrSizedOperandSegments> {
public:
static StringRef getOperandSegmentSizeAttr() {
return "operand_segment_sizes";
}
static StringRef getOperandSegmentSizeAttr() { return "operandSegmentSizes"; }

static LogicalResult verifyTrait(Operation *op) {
return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
Expand All @@ -1354,7 +1352,7 @@ template <typename ConcreteType>
class AttrSizedResultSegments
: public TraitBase<ConcreteType, AttrSizedResultSegments> {
public:
static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
static StringRef getResultSegmentSizeAttr() { return "resultSegmentSizes"; }

static LogicalResult verifyTrait(Operation *op) {
return ::mlir::OpTrait::impl::verifyResultSizeAttr(
Expand Down
14 changes: 8 additions & 6 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -715,18 +715,20 @@ class AsmParser {
//===--------------------------------------------------------------------===//

/// This class represents a StringSwitch like class that is useful for parsing
/// expected keywords. On construction, it invokes `parseKeyword` and
/// processes each of the provided cases statements until a match is hit. The
/// provided `ResultT` must be assignable from `failure()`.
/// expected keywords. On construction, unless a non-empty keyword is
/// provided, it invokes `parseKeyword` and processes each of the provided
/// cases statements until a match is hit. The provided `ResultT` must be
/// assignable from `failure()`.
template <typename ResultT = ParseResult>
class KeywordSwitch {
public:
KeywordSwitch(AsmParser &parser)
KeywordSwitch(AsmParser &parser, StringRef *keyword = nullptr)
: parser(parser), loc(parser.getCurrentLocation()) {
if (failed(parser.parseKeywordOrCompletion(&keyword)))
if (keyword && !keyword->empty())
this->keyword = *keyword;
else if (failed(parser.parseKeywordOrCompletion(&this->keyword)))
result = failure();
}

/// Case that uses the provided value when true.
KeywordSwitch &Case(StringLiteral str, ResultT value) {
return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
Expand Down
Loading