Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert polynomial coefficient type to use modarith #1096

Draft
wants to merge 48 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
05e1775
update mlir::polynomial -> mlir::heir::polynomial and start migrating
j2kun Nov 16, 2024
c4f1a69
migrate polynomial off coefficientType
j2kun Nov 18, 2024
c6f7cd0
migrate tests off coefficientType
j2kun Nov 18, 2024
3210e8f
copy some more tests from upstream
j2kun Nov 19, 2024
80cf66d
modarith -> mod_arith
j2kun Nov 19, 2024
07e6b11
try fixing polynomial patterns and op arg types
j2kun Nov 19, 2024
994afa7
fix some more tests
j2kun Nov 19, 2024
a8b06c6
fix lwe-to-polynomial conversion
j2kun Nov 19, 2024
c914b60
fix polynomial fromtensor builders and verifiers
j2kun Nov 19, 2024
4eb3d34
fix tests broken by syntax or missing mod_arith dep
j2kun Nov 19, 2024
3c977c6
fix more tests for mod_arith check lines
j2kun Nov 19, 2024
3f740c2
fix a test for from_tensor type change
j2kun Nov 19, 2024
ef43489
fix mod_arith conversions in ntt benchmark
j2kun Nov 19, 2024
28488f8
fix canonicalization test and use signed printer
j2kun Nov 19, 2024
e9bfece
fix two tests to align from_tensor types
j2kun Nov 19, 2024
3bb2c87
change test from a parser error to an impl-needs-update error
j2kun Nov 19, 2024
26839f8
fix LWEToPolynomial
j2kun Nov 19, 2024
3d5d4ea
fix bad regex breaking test
j2kun Nov 19, 2024
94e9756
start fixing polynomial-to-llvm
j2kun Nov 19, 2024
d462816
fix lower_add, lower constant and add to mod_arith
j2kun Nov 20, 2024
49a7a9b
mod_arith.mod_arith -> mod_arith.int
j2kun Nov 20, 2024
c4842e2
fix storage type on two tests
j2kun Nov 20, 2024
54e24e5
allow toTensor to use AnyType for elt ty
j2kun Nov 20, 2024
db1878e
add mod_arith lowering to poly-to-llvm pipeline
j2kun Nov 20, 2024
68e9cf4
use mod arith in lower_add_runner
j2kun Nov 20, 2024
115067e
augment lower_add for mod_arith
j2kun Nov 20, 2024
5c7f63e
convert lower_intt and runner test
j2kun Nov 20, 2024
fba5c84
more coefficientModulus syntax removal
j2kun Nov 20, 2024
b6788e2
convert lower_leading_term
j2kun Nov 21, 2024
8a1a103
lower monomial to mod_arith
j2kun Nov 21, 2024
3f85a83
fix test for lower_monomial_mul
j2kun Nov 21, 2024
5e1e717
lower to_tensor to mod_arith
j2kun Nov 21, 2024
e622522
reduce scope of one test to test type lowering only
j2kun Nov 21, 2024
b27e56f
migrate lower_mul_scalar
j2kun Nov 21, 2024
60bf141
convert lower_ntt
j2kun Nov 21, 2024
7ed7d91
migrate lower_ntt
j2kun Nov 21, 2024
ad550d3
start converting ConvertMul to mod_arith
j2kun Nov 21, 2024
793aceb
migrate lower_mul
j2kun Nov 23, 2024
69f47e3
lower ModArith.constant, convert lower_mul_runner
j2kun Nov 23, 2024
408890c
fix some, but not all, lower_mul runner tests
j2kun Nov 23, 2024
b174779
fix lower_ntt_perf_runner
j2kun Nov 23, 2024
7b20069
fix ntt_benchmark
j2kun Nov 23, 2024
239fabb
Fix more tests by removing sub canonicalization
j2kun Nov 23, 2024
e20d695
fix materialization of negative constants
j2kun Nov 23, 2024
9bef19e
fix remaining tests
j2kun Nov 23, 2024
b1ce8f4
polynomial-to-standard -> polynomial-to-mod-arith
j2kun Nov 23, 2024
fd37924
rename PolynomialToStandard -> PolynomialToModArith
j2kun Nov 23, 2024
3d1386e
add TODO
j2kun Nov 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lib/Dialect/BGV/Conversions/BGVToLWE/BGVToLWE.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def BGVToLWE : Pass<"bgv-to-lwe"> {

let dependentDialects = [
"mlir::heir::bgv::BGVDialect",
"mlir::polynomial::PolynomialDialect",
"mlir::heir::lwe::LWEDialect",
"mlir::tensor::TensorDialect",
];
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ cc_library(
hdrs = ["FHEHelpers.h"],
deps = [
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
Expand Down
5 changes: 4 additions & 1 deletion lib/Dialect/CGGI/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ def SetDefaultParameters : Pass<"cggi-set-default-parameters"> {
The specific parameters are hard-coded in
`lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp`.
}];
let dependentDialects = ["mlir::heir::cggi::CGGIDialect"];
let dependentDialects = [
"mlir::heir::cggi::CGGIDialect",
"mlir::heir::mod_arith::ModArithDialect",
];
}

def BooleanVectorizer : Pass<"cggi-boolean-vectorize"> {
Expand Down
13 changes: 8 additions & 5 deletions lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "lib/Dialect/CGGI/IR/CGGIAttributes.h"
#include "lib/Dialect/CGGI/IR/CGGIOps.h"
#include "lib/Dialect/LWE/IR/LWEAttributes.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Dialect/Polynomial/IR/Polynomial.h"
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
Expand All @@ -29,11 +30,12 @@ struct SetDefaultParameters
MLIRContext &context = getContext();
unsigned defaultRlweDimension = 1;
APInt defaultCmod = APInt::getOneBitSet(64, 32);
std::vector<::mlir::polynomial::IntMonomial> monomials;
std::vector<::mlir::heir::polynomial::IntMonomial> monomials;
monomials.emplace_back(1, 1024);
monomials.emplace_back(1, 0);
::mlir::polynomial::IntPolynomial defaultPolyIdeal =
::mlir::polynomial::IntPolynomial::fromMonomials(monomials).value();
::mlir::heir::polynomial::IntPolynomial defaultPolyIdeal =
::mlir::heir::polynomial::IntPolynomial::fromMonomials(monomials)
.value();

// https://github.com/google/jaxite/blob/main/jaxite/jaxite_bool/bool_params.py
unsigned defaultBskNoiseVariance = 65536; // stdev = 2**8, var = 2**16
Expand All @@ -46,8 +48,9 @@ struct SetDefaultParameters

lwe::RLWEParamsAttr defaultRlweParams = lwe::RLWEParamsAttr::get(
&context, defaultRlweDimension,
::mlir::polynomial::RingAttr::get(
intType, IntegerAttr::get(intType, defaultCmod),
::mlir::heir::polynomial::RingAttr::get(
mod_arith::ModArithType::get(
&context, IntegerAttr::get(intType, defaultCmod)),
polynomial::IntPolynomialAttr::get(&context, defaultPolyIdeal)));
CGGIParamsAttr defaultParams =
CGGIParamsAttr::get(&context, defaultRlweParams,
Expand Down
20 changes: 16 additions & 4 deletions lib/Dialect/FHEHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "lib/Dialect/LWE/IR/LWEAttributes.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
Expand Down Expand Up @@ -64,14 +65,25 @@ LogicalResult verifyModulusSwitchOrRescaleOp(Op* op) {
if (outRing != op->getToRing()) {
return op->emitOpError() << "output ring should match to_ring";
}
if (xRing.getCoefficientModulus().getValue().ule(
outRing.getCoefficientModulus().getValue())) {

auto xRingCoeffType =
dyn_cast<mod_arith::ModArithType>(xRing.getCoefficientType());
auto outRingCoeffType =
dyn_cast<mod_arith::ModArithType>(outRing.getCoefficientType());

if (!xRingCoeffType || !outRingCoeffType) {
return op->emitOpError()
<< "input and output rings should have mod_arith coefficient types";
}

if (xRingCoeffType.getModulus().getValue().ule(
outRingCoeffType.getModulus().getValue())) {
return op->emitOpError()
<< "output ring modulus should be less than the input ring modulus";
}
if (!xRing.getCoefficientModulus()
if (!xRingCoeffType.getModulus()
.getValue()
.urem(outRing.getCoefficientModulus().getValue())
.urem(outRingCoeffType.getModulus().getValue())
.isZero()) {
return op->emitOpError()
<< "output ring modulus should divide the input ring modulus";
Expand Down
93 changes: 63 additions & 30 deletions lib/Dialect/LWE/Conversions/LWEToPolynomial/LWEToPolynomial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "lib/Dialect/LWE/IR/LWEAttributes.h"
#include "lib/Dialect/LWE/IR/LWEOps.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Dialect/ModArith/IR/ModArithOps.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Dialect/Polynomial/IR/Polynomial.h"
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "lib/Dialect/Polynomial/IR/PolynomialOps.h"
Expand Down Expand Up @@ -47,26 +49,26 @@ class CiphertextTypeConverter : public TypeConverter {
addConversion([ctx](lwe::RLWECiphertextType type) -> Type {
auto rlweParams = type.getRlweParams();
auto ring = rlweParams.getRing();
auto polyTy = ::mlir::polynomial::PolynomialType::get(ctx, ring);
auto polyTy = ::mlir::heir::polynomial::PolynomialType::get(ctx, ring);

return RankedTensorType::get({rlweParams.getDimension()}, polyTy);
});
addConversion([ctx](lwe::RLWEPlaintextType type) -> Type {
auto ring = type.getRing();
auto polyTy = ::mlir::polynomial::PolynomialType::get(ctx, ring);
auto polyTy = ::mlir::heir::polynomial::PolynomialType::get(ctx, ring);
return polyTy;
});
addConversion([ctx](lwe::RLWESecretKeyType type) -> Type {
auto rlweParams = type.getRlweParams();
auto ring = rlweParams.getRing();
auto polyTy = ::mlir::polynomial::PolynomialType::get(ctx, ring);
auto polyTy = ::mlir::heir::polynomial::PolynomialType::get(ctx, ring);

return RankedTensorType::get({rlweParams.getDimension() - 1}, polyTy);
});
addConversion([ctx](lwe::RLWEPublicKeyType type) -> Type {
auto rlweParams = type.getRlweParams();
auto ring = rlweParams.getRing();
auto polyTy = ::mlir::polynomial::PolynomialType::get(ctx, ring);
auto polyTy = ::mlir::heir::polynomial::PolynomialType::get(ctx, ring);

return RankedTensorType::get({rlweParams.getDimension()}, polyTy);
});
Expand Down Expand Up @@ -211,12 +213,16 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
auto dimension =
inputT.getRing().getPolynomialModulus().getPolynomial().getDegree();

auto elementType = rewriter.getIntegerType(inputT.getRing()
.getCoefficientModulus()
.getType()
.getIntOrFloatBitWidth());
auto coefficientType = inputT.getRing().getCoefficientType();
auto modArithType = dyn_cast<mod_arith::ModArithType>(coefficientType);
if (!modArithType) {
op.emitError() << "Unsupported coefficient type: " << coefficientType;
return failure();
}

auto tensorParams = RankedTensorType::get({dimension}, elementType);
Type tensorEltTy = modArithType.getModulus().getType();
auto tensorParams = RankedTensorType::get({dimension}, tensorEltTy);
auto modArithTensorType = RankedTensorType::get({dimension}, modArithType);

// TODO (#881): Add pass options to change the seed (which is currently
// hardcoded to 0 with index).
Expand All @@ -240,8 +246,11 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
// Generate random u polynomial from uniform random ternary distribution
auto uTensor =
builder.create<random::SampleOp>(tensorParams, uniformDistribution);
auto u =
builder.create<polynomial::FromTensorOp>(uTensor, inputT.getRing());
// Convert the tensor of ints to a tensor of mod_arith, then a polynomial
auto modArithUTensor =
builder.create<mod_arith::EncapsulateOp>(modArithTensorType, uTensor);
auto u = builder.create<polynomial::FromTensorOp>(modArithUTensor,
inputT.getRing());

// Create a discrete Gaussian distribution
auto discreteGaussianDistributionType = random::DistributionType::get(
Expand All @@ -265,20 +274,24 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
// multiplication.
// TODO(#876): Migrate to using the plaintext modulus of the encoding info
// attributes.
auto constantT = builder.create<arith::ConstantOp>(
builder.getI32IntegerAttr(1 << cleartextBitwidth));
auto constantT = builder.create<mod_arith::ConstantOp>(
modArithType, 1 << cleartextBitwidth);

// generate random e0 polynomial from discrete gaussian distribution
auto e0Tensor = builder.create<random::SampleOp>(
tensorParams, discreteGaussianDistribution);
auto e0 =
builder.create<polynomial::FromTensorOp>(e0Tensor, inputT.getRing());
auto modArithE0Tensor = builder.create<mod_arith::EncapsulateOp>(
modArithTensorType, e0Tensor);
auto e0 = builder.create<polynomial::FromTensorOp>(modArithE0Tensor,
inputT.getRing());

// generate random e1 polynomial from discrete gaussian distribution
auto e1Tensor = builder.create<random::SampleOp>(
tensorParams, discreteGaussianDistribution);
auto e1 =
builder.create<polynomial::FromTensorOp>(e1Tensor, inputT.getRing());
auto modArithE1Tensor = builder.create<mod_arith::EncapsulateOp>(
modArithTensorType, e1Tensor);
auto e1 = builder.create<polynomial::FromTensorOp>(modArithE1Tensor,
inputT.getRing());

// TODO (#882): Other encryption schemes (e.g. CKKS) may multiply the
// noise or key differently. Add support for those cases.
Expand Down Expand Up @@ -312,8 +325,10 @@ struct ConvertRLWEEncrypt : public OpConversionPattern<RLWEEncryptOp> {
// Generate random e polynomial from discrete gaussian distribution
auto eTensor = builder.create<random::SampleOp>(
tensorParams, discreteGaussianDistribution);
auto e =
builder.create<polynomial::FromTensorOp>(eTensor, inputT.getRing());
auto modArithETensor =
builder.create<mod_arith::EncapsulateOp>(modArithTensorType, eTensor);
auto e = builder.create<polynomial::FromTensorOp>(modArithETensor,
inputT.getRing());

// TODO (#882): Other encryption schemes (e.g. CKKS) may multiply the
// noise or key differently. Add support for those cases.
Expand Down Expand Up @@ -343,7 +358,7 @@ struct ConvertRAdd : public OpConversionPattern<RAddOp> {
LogicalResult matchAndRewrite(
RAddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<::mlir::polynomial::AddOp>(
rewriter.replaceOpWithNewOp<::mlir::heir::polynomial::AddOp>(
op, adaptor.getOperands()[0], adaptor.getOperands()[1]);
return success();
}
Expand All @@ -358,7 +373,7 @@ struct ConvertRSub : public OpConversionPattern<RSubOp> {
LogicalResult matchAndRewrite(
RSubOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<::mlir::polynomial::SubOp>(
rewriter.replaceOpWithNewOp<::mlir::heir::polynomial::SubOp>(
op, adaptor.getOperands()[0], adaptor.getOperands()[1]);
return success();
}
Expand All @@ -377,10 +392,28 @@ struct ConvertRNegate : public OpConversionPattern<RNegateOp> {
auto arg = adaptor.getOperands()[0];
polynomial::PolynomialType polyType = cast<polynomial::PolynomialType>(
cast<RankedTensorType>(arg.getType()).getElementType());
auto neg = rewriter.create<arith::ConstantIntOp>(
loc, -1, polyType.getRing().getCoefficientType());
rewriter.replaceOp(op, rewriter.create<::mlir::polynomial::MulScalarOp>(
loc, arg.getType(), arg, neg));
FailureOr<Value> neg =
llvm::TypeSwitch<Type, FailureOr<Value>>(
polyType.getRing().getCoefficientType())
.Case<mod_arith::ModArithType>(
[&](mod_arith::ModArithType type) -> Value {
return rewriter.create<mod_arith::ConstantOp>(loc, type, -1);
})
.Case<IntegerType>([&](IntegerType type) -> Value {
return rewriter.create<arith::ConstantIntOp>(loc, -1, type);
})
.Default([&](Type type) -> FailureOr<Value> {
op.emitError() << "Unsupported coefficient type: " << type;
return failure();
});

if (failed(neg)) {
return failure();
}

rewriter.replaceOp(op,
rewriter.create<::mlir::heir::polynomial::MulScalarOp>(
loc, arg.getType(), arg, neg.value()));
return success();
}
};
Expand Down Expand Up @@ -425,11 +458,11 @@ struct ConvertRMul : public OpConversionPattern<RMulOp> {
auto y1 =
b.create<tensor::ExtractOp>(yT.getElementType(), y, ValueRange{i1});

auto z0 = b.create<::mlir::polynomial::MulOp>(x0, y0);
auto x0y1 = b.create<::mlir::polynomial::MulOp>(x0, y1);
auto x1y0 = b.create<::mlir::polynomial::MulOp>(x1, y0);
auto z1 = b.create<::mlir::polynomial::AddOp>(x0y1, x1y0);
auto z2 = b.create<::mlir::polynomial::MulOp>(x1, y1);
auto z0 = b.create<::mlir::heir::polynomial::MulOp>(x0, y0);
auto x0y1 = b.create<::mlir::heir::polynomial::MulOp>(x0, y1);
auto x1y0 = b.create<::mlir::heir::polynomial::MulOp>(x1, y0);
auto z1 = b.create<::mlir::heir::polynomial::AddOp>(x0y1, x1y0);
auto z2 = b.create<::mlir::heir::polynomial::MulOp>(x1, y1);

auto z = b.create<tensor::FromElementsOp>(ArrayRef<Value>({z0, z1, z2}));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def LWEToPolynomial : Pass<"lwe-to-polynomial"> {
}];

let dependentDialects = [
"mlir::polynomial::PolynomialDialect",
"mlir::heir::polynomial::PolynomialDialect",
"mlir::tensor::TensorDialect",
"mlir::heir::random::RandomDialect"
];
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/LWE/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ cc_library(
":dialect_inc_gen",
":ops_inc_gen",
":types_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Dialect/Polynomial/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/LWE/IR/LWEAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def LWE_RLWEParams : AttrDef<LWE_Dialect, "RLWEParams"> {

let parameters = (ins
DefaultValuedParameter<"unsigned", "2">:$dimension,
"::mlir::polynomial::RingAttr":$ring
"::mlir::heir::polynomial::RingAttr":$ring
);

let assemblyFormat = "`<` struct(params) `>`";
Expand Down
38 changes: 24 additions & 14 deletions lib/Dialect/LWE/IR/LWEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "lib/Dialect/LWE/IR/LWEAttributes.h"
#include "lib/Dialect/LWE/IR/LWEOps.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "lib/Dialect/Polynomial/IR/PolynomialTypes.h"
#include "llvm/include/llvm/ADT/STLFunctionalExtras.h" // from @llvm-project
Expand Down Expand Up @@ -129,20 +130,25 @@ LogicalResult requirePolynomialElementTypeFits(
Type elementType, llvm::StringRef encodingName, unsigned cleartextBitwidth,
unsigned cleartextStart,
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
if (!mlir::isa<::mlir::polynomial::PolynomialType>(elementType)) {
if (!mlir::isa<::mlir::heir::polynomial::PolynomialType>(elementType)) {
return emitError()
<< "Tensors with encoding " << encodingName
<< " must have `polynomial.polynomial` element type, but found "
<< elementType << "\n";
}
::mlir::polynomial::PolynomialType polyType =
llvm::cast<::mlir::polynomial::PolynomialType>(elementType);
::mlir::heir::polynomial::PolynomialType polyType =
llvm::cast<::mlir::heir::polynomial::PolynomialType>(elementType);
// The coefficient modulus takes the place of the plaintext bitwidth for
// RLWE.
unsigned plaintextBitwidth = polyType.getRing()
.getCoefficientModulus()
.getType()
.getIntOrFloatBitWidth();
auto coeffType = dyn_cast<mod_arith::ModArithType>(
polyType.getRing().getCoefficientType());
if (!coeffType) {
return emitError()
<< "The polys in this tensor have a mod_arith coefficient type"
<< " but found " << polyType.getRing().getCoefficientType();
}
unsigned plaintextBitwidth =
coeffType.getModulus().getType().getIntOrFloatBitWidth();

if (plaintextBitwidth < cleartextBitwidth)
return emitError() << "The polys in this tensor have a coefficient "
Expand Down Expand Up @@ -245,7 +251,7 @@ LogicalResult ApplicationDataAttr::verify(

LogicalResult PlaintextSpaceAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
mlir::polynomial::RingAttr ring, Attribute encoding) {
mlir::heir::polynomial::RingAttr ring, Attribute encoding) {
if (mlir::isa<FullCRTPackingEncodingAttr>(encoding)) {
// For full CRT packing, the ring must be of the form x^n + 1 and the
// modulus must be 1 mod n.
Expand All @@ -264,12 +270,16 @@ LogicalResult PlaintextSpaceAttr::verify(
<< "but found " << polyMod << "\n";
}
// Check that the modulus is 1 mod n.
APInt modulus = ring.getCoefficientModulus().getValue();
unsigned n = poly.getDegree();
if (!modulus.urem(APInt(modulus.getBitWidth(), n)).isOne()) {
return emitError()
<< "modulus must be 1 mod n for full CRT packing, mod = "
<< ring.getCoefficientModulus() << " n = " << n << "\n";
auto modCoeffTy =
llvm::dyn_cast<mod_arith::ModArithType>(ring.getCoefficientType());
if (modCoeffTy) {
APInt modulus = modCoeffTy.getModulus().getValue();
unsigned n = poly.getDegree();
if (!modulus.urem(APInt(modulus.getBitWidth(), n)).isOne()) {
return emitError()
<< "modulus must be 1 mod n for full CRT packing, mod = "
<< modulus.getZExtValue() << " n = " << n << "\n";
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/LWE/IR/LWETraits.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class SameOperandsAndResultRings
: public OpTrait::TraitBase<ConcreteType, SameOperandsAndResultRings> {
public:
static LogicalResult verifyTrait(Operation *op) {
::mlir::polynomial::RingAttr rings = nullptr;
::mlir::heir::polynomial::RingAttr rings = nullptr;
for (auto rTy : op->getResultTypes()) {
auto ct = dyn_cast<lwe::RLWECiphertextType>(rTy);
if (!ct) continue;
Expand Down
Loading
Loading