Skip to content

Commit

Permalink
add a ModArith constant op
Browse files Browse the repository at this point in the history
This custom-parsed op and attribute make it relatively easy to
systematically change an `arith.constant` op to `mod_arith`:

```
    arith.constant 17 : i32

-->

mod_arith.constant 17 : !mod_arith.mod_arith<234 : i32>
```

- comment on getValue nightmare
- mod_arith.mod_arith -> mod_arith.int
  • Loading branch information
j2kun committed Nov 19, 2024
1 parent 6d91164 commit 10721ef
Show file tree
Hide file tree
Showing 17 changed files with 198 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class ModArithToArithTypeConverter : public TypeConverter {
}
};

// A herlper function to generate the attribute or type
// needed to represent the result of modarith op as an integer
// A helper function to generate the attribute or type
// needed to represent the result of mod_arith op as an integer
// before applying a remainder operation
template <typename Op>
TypedAttr modulusAttr(Op op, bool mul = false) {
Expand Down
33 changes: 33 additions & 0 deletions lib/Dialect/ModArith/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ cc_library(
"ModArithDialect.cpp",
],
hdrs = [
"ModArithAttributes.h",
"ModArithDialect.h",
"ModArithOps.h",
"ModArithTypes.h",
],
deps = [
":attributes_inc_gen",
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
Expand All @@ -32,6 +34,7 @@ cc_library(
td_library(
name = "td_files",
srcs = [
"ModArithAttributes.td",
"ModArithDialect.td",
"ModArithOps.td",
"ModArithTypes.td",
Expand Down Expand Up @@ -70,6 +73,36 @@ gentbl_cc_library(
],
)

gentbl_cc_library(
name = "attributes_inc_gen",
tbl_outs = [
(
[
"-gen-attrdef-decls",
"-attrdefs-dialect=mod_arith",
],
"ModArithAttributes.h.inc",
),
(
[
"-gen-attrdef-defs",
"-attrdefs-dialect=mod_arith",
],
"ModArithAttributes.cpp.inc",
),
(
["-gen-attrdef-doc"],
"ModArithAttributes.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ModArithAttributes.td",
deps = [
":dialect_inc_gen",
":td_files",
],
)

gentbl_cc_library(
name = "types_inc_gen",
tbl_outs = [
Expand Down
10 changes: 10 additions & 0 deletions lib/Dialect/ModArith/IR/ModArithAttributes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#ifndef LIB_DIALECT_MODARITH_IR_MODARITHATTRIBUTES_H_
#define LIB_DIALECT_MODARITH_IR_MODARITHATTRIBUTES_H_

#include "lib/Dialect/ModArith/IR/ModArithDialect.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"

#define GET_ATTRDEF_CLASSES
#include "lib/Dialect/ModArith/IR/ModArithAttributes.h.inc"

#endif // LIB_DIALECT_MODARITH_IR_MODARITHATTRIBUTES_H_
43 changes: 43 additions & 0 deletions lib/Dialect/ModArith/IR/ModArithAttributes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef LIB_DIALECT_MODARITH_IR_MODARITHATTRS_TD_
#define LIB_DIALECT_MODARITH_IR_MODARITHATTRS_TD_

include "lib/Dialect/ModArith/IR/ModArithDialect.td"
include "lib/Dialect/ModArith/IR/ModArithTypes.td"
include "mlir/IR/BuiltinAttributes.td"


class ModArith_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<ModArith_Dialect, name, traits> {
let mnemonic = attrMnemonic;
}


def ModArith_ModArithAttr : ModArith_Attr<
"ModArith", "int", [TypedAttrInterface]> {
let summary = "a typed mod_arith attribute";
let description = [{
Example:

```mlir
#attr = 123:i32
#attr_verbose = #mod_arith.int<123:i32>
```
}];
let parameters = (ins "::mlir::heir::mod_arith::ModArithType":$type, "mlir::IntegerAttr":$value);
let assemblyFormat = "$value `:` $type";
let builders = [
AttrBuilderWithInferredContext<(ins "ModArithType":$type,
"const int64_t":$value), [{
return $_get(
type.getContext(),
type,
IntegerAttr::get(type.getModulus().getType(), value));
}]>
];
let extraClassDeclaration = [{
using ValueType = ::mlir::Attribute;
}];
}


#endif // LIB_DIALECT_MODARITH_IR_MODARITHATTRS_TD_
58 changes: 57 additions & 1 deletion lib/Dialect/ModArith/IR/ModArithDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

// NOLINTBEGIN(misc-include-cleaner): Required to define ModArithDialect,
// ModArithTypes, ModArithOps
// ModArithTypes, ModArithOps, ModArithAttributes
#include "lib/Dialect/ModArith/IR/ModArithAttributes.h"
#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
Expand All @@ -19,6 +20,9 @@
// Generated definitions
#include "lib/Dialect/ModArith/IR/ModArithDialect.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "lib/Dialect/ModArith/IR/ModArithAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/ModArith/IR/ModArithTypes.cpp.inc"

Expand All @@ -34,6 +38,10 @@ void ModArithDialect::initialize() {
#define GET_TYPEDEF_LIST
#include "lib/Dialect/ModArith/IR/ModArithTypes.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "lib/Dialect/ModArith/IR/ModArithAttributes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "lib/Dialect/ModArith/IR/ModArithOps.cpp.inc"
Expand Down Expand Up @@ -128,6 +136,54 @@ LogicalResult BarrettReduceOp::verify() {
return success();
}

ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
APInt parsedValue(64, 0);
Type parsedType;

if (failed(parser.parseInteger(parsedValue))) {
parser.emitError(parser.getCurrentLocation(),
"found invalid integer value");
return failure();
}

if (parser.parseColon() || parser.parseType(parsedType)) return failure();

auto modArithType = dyn_cast<ModArithType>(parsedType);
if (!modArithType) return failure();

auto outputBitWidth =
modArithType.getModulus().getType().getIntOrFloatBitWidth();
if (parsedValue.getActiveBits() > outputBitWidth)
return parser.emitError(parser.getCurrentLocation(),
"constant value is too large for the modulus");

auto intValue = IntegerAttr::get(modArithType.getModulus().getType(),
parsedValue.trunc(outputBitWidth));
result.addAttribute(
"value", ModArithAttr::get(parser.getContext(), modArithType, intValue));
result.addTypes(modArithType);
return success();
}

void ConstantOp::print(OpAsmPrinter &p) {
p << " ";
// getValue chain:
// op's ModArithAttribute value
// -> ModArithAttribute's IntegerAttr value
// -> IntegerAttr's APInt value
getValue().getValue().getValue().print(p.getStream(), true);
p << " : ";
p.printType(getOutput().getType());
}

LogicalResult ConstantOp::inferReturnTypes(
mlir::MLIRContext *context, std::optional<mlir::Location> loc,
ConstantOpAdaptor adaptor, llvm::SmallVectorImpl<mlir::Type> &returnTypes) {
adaptor.getValue().dump();
returnTypes.push_back(adaptor.getValue().getType());
return success();
}

} // namespace mod_arith
} // namespace heir
} // namespace mlir
1 change: 1 addition & 0 deletions lib/Dialect/ModArith/IR/ModArithDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def ModArith_Dialect : Dialect {

let cppNamespace = "::mlir::heir::mod_arith";
let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;

let dependentDialects = [
"arith::ArithDialect",
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/ModArith/IR/ModArithOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define LIB_DIALECT_MODARITH_IR_MODARITHOPS_H_

// NOLINTBEGIN(misc-include-cleaner): Required to define ModArithOps
#include "lib/Dialect/ModArith/IR/ModArithAttributes.h"
#include "lib/Dialect/ModArith/IR/ModArithDialect.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
Expand Down
39 changes: 31 additions & 8 deletions lib/Dialect/ModArith/IR/ModArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define LIB_DIALECT_MODARITH_IR_MODARITHOPS_TD_

include "lib/Dialect/ModArith/IR/ModArithDialect.td"
include "lib/Dialect/ModArith/IR/ModArithAttributes.td"
include "lib/Dialect/ModArith/IR/ModArithTypes.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonTypeConstraints.td"
Expand All @@ -28,8 +29,8 @@ def ModArith_EncapsulateOp : ModArith_Op<"encapsulate", [Pure, ElementwiseMappab

Examples:
```
mod_arith.encapsulate %c0 : i32 -> mod_arith.mod_arith<65537 : i32>
mod_arith.encapsulate %c1 : i64 -> mod_arith.mod_arith<65537>
mod_arith.encapsulate %c0 : i32 -> mod_arith.int<65537 : i32>
mod_arith.encapsulate %c1 : i64 -> mod_arith.int<65537>
```
}];

Expand All @@ -52,10 +53,10 @@ def ModArith_ExtractOp : ModArith_Op<"extract", [Pure, ElementwiseMappable]> {

Examples:
```
%m0 = mod_arith.encapsulate %c0 : i32 -> mod_arith.mod_arith<65537 : i32>
%m1 = mod_arith.encapsulate %c1 : i64 -> mod_arith.mod_arith<65537>
%c2 = mod_arith.extract %m0 : mod_arith.mod_arith<65537 : i32> -> i32
%c3 = mod_arith.extract %m1 : mod_arith.mod_arith<65537> -> i64
%m0 = mod_arith.encapsulate %c0 : i32 -> mod_arith.int<65537 : i32>
%m1 = mod_arith.encapsulate %c1 : i64 -> mod_arith.int<65537>
%c2 = mod_arith.extract %m0 : mod_arith.int<65537 : i32> -> i32
%c3 = mod_arith.extract %m1 : mod_arith.int<65537> -> i64
```
}];

Expand All @@ -67,6 +68,28 @@ def ModArith_ExtractOp : ModArith_Op<"extract", [Pure, ElementwiseMappable]> {
let assemblyFormat = "operands attr-dict `:` type($input) `->` type($output)";
}

def ModArith_ConstantOp : Op<ModArith_Dialect, "constant",
[Pure, InferTypeOpAdaptor]> {
let summary = "Define a constant value via an attribute.";
let description = [{
Example:

```mlir
%0 = mod_arith.constant 123 : !mod_arith.int<65537:i32>
```
}];
let arguments = (ins ModArith_ModArithAttr:$value);
let results = (outs ModArith_ModArithType:$output);
let hasCustomAssemblyFormat = 1;

let builders = [
OpBuilder<(ins "::mlir::heir::mod_arith::ModArithType":$ty, "int64_t":$value), [{
return build($_builder, $_state, ty, mod_arith::ModArithAttr::get(ty, value));
}]>
];
}


def ModArith_ReduceOp : ModArith_Op<"reduce", [Pure, ElementwiseMappable, SameOperandsAndResultType]> {
let summary = "reduce the mod arith type to its canonical representative";

Expand All @@ -77,9 +100,9 @@ def ModArith_ReduceOp : ModArith_Op<"reduce", [Pure, ElementwiseMappable, SameOp
Examples:
```
%c0 = arith.constant 65538 : i32
%m0 = mod_arith.encapsulate %c0 : i32 -> mod_arith.mod_arith<65537 : i32>
%m0 = mod_arith.encapsulate %c0 : i32 -> mod_arith.int<65537 : i32>
// mod_arith.extract %m0 produces 65538
%m1 = mod_arith.reduce %m0 : mod_arith.mod_arith<65537: i32>
%m1 = mod_arith.reduce %m0 : mod_arith.int<65537: i32>
// mod_arith.extract %m1 produces 1
```
}];
Expand Down
14 changes: 7 additions & 7 deletions lib/Dialect/ModArith/IR/ModArithTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ class ModArith_Type<string name, string typeMnemonic>
let mnemonic = typeMnemonic;
}

def ModArith_ModArith : ModArith_Type<"ModArith", "mod_arith"> {
def ModArith_ModArithType : ModArith_Type<"ModArith", "int"> {
let summary = "Integer type with modular arithmetic";
let description = [{
`mod_arith.mod_arith<p>` represents an element of the ring of integers modulo $p$.
`mod_arith.int<p>` represents an element of the ring of integers modulo $p$.
The `modulus` attribute is the ring modulus, and `mod_arith` operations lower to
`arith` operations that produce results in the range `[0, modulus)`, often called
the _canonical representative_.

`modulus` is specified with an integer type suffix, for example,
`mod_arith.mod_arith<65537 : i32>`. This corresponds to the storage type for the
`mod_arith.int<65537 : i32>`. This corresponds to the storage type for the
modulus, and is `i64` by default.

It is required that the underlying integer type should be larger than
Expand All @@ -40,9 +40,9 @@ def ModArith_ModArith : ModArith_Type<"ModArith", "mod_arith"> {

Examples:
```
!Zp1 = !mod_arith.mod_arith<7> // implicitly being i64
!Zp2 = !mod_arith.mod_arith<65537 : i32>
!Zp3 = !mod_arith.mod_arith<536903681 : i64>
!Zp1 = !mod_arith.int<7> // implicitly being i64
!Zp2 = !mod_arith.int<65537 : i32>
!Zp3 = !mod_arith.int<536903681 : i64>
```
}];
let parameters = (ins
Expand All @@ -51,6 +51,6 @@ def ModArith_ModArith : ModArith_Type<"ModArith", "mod_arith"> {
let assemblyFormat = "`<` $modulus `>`";
}

def ModArithLike: TypeOrContainer<ModArith_ModArith, "mod_arith-like">;
def ModArithLike: TypeOrContainer<ModArith_ModArithType, "mod_arith-like">;

#endif // LIB_TYPES_MODARITH_IR_MODARITHTYPES_TD_
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: heir-opt -mod-arith-to-arith --split-input-file %s | FileCheck %s --enable-var-scope

!Zp = !mod_arith.mod_arith<65537 : i32>
!Zp = !mod_arith.int<65537 : i32>
!Zpv = tensor<4x!Zp>

// CHECK-LABEL: @test_lower_encapsulate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

!Zp = !mod_arith.mod_arith<7681 : i26>
!Zp = !mod_arith.int<7681 : i26>
!Zpv = tensor<4x!Zp>

func.func @test_lower_add() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

!Zp = !mod_arith.mod_arith<7681 : i26>
!Zp = !mod_arith.int<7681 : i26>
!Zpv = tensor<4x!Zp>

func.func @test_lower_mac() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

!Zp = !mod_arith.mod_arith<7681 : i26>
!Zp = !mod_arith.int<7681 : i26>
!Zpv = tensor<4x!Zp>

func.func @test_lower_mul() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

!Zp1 = !mod_arith.mod_arith<7681 : i26>
!Zp1 = !mod_arith.int<7681 : i26>
!Zp1v = tensor<6x!Zp1>
// 33554431 = 2 ** 25 - 1
!Zp2 = !mod_arith.mod_arith<33554431 : i26>
!Zp2 = !mod_arith.int<33554431 : i26>
!Zp2v = tensor<6x!Zp2>

func.func @test_lower_reduce() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

!Zp = !mod_arith.mod_arith<7681 : i26>
!Zp = !mod_arith.int<7681 : i26>
!Zpv = tensor<4x!Zp>

func.func @test_lower_sub() {
Expand Down
Loading

0 comments on commit 10721ef

Please sign in to comment.