Skip to content

Commit

Permalink
convert lower_intt
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Nov 20, 2024
1 parent 829a27b commit 3a71467
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1294,9 +1294,11 @@ struct ConvertNTT : public OpConversionPattern<NTTOp> {
op.emitError("expected coefficient type to be mod_arith type");
return failure();
}
auto inputConvertedFromModArith =
b.create<mod_arith::ExtractOp>(inputType, adaptor.getInput());
auto nttResult = fastNTT<false>(
b, ring, op.getRoot().value(), inputType,
computeReverseBitOrder(b, inputType, adaptor.getInput()));
computeReverseBitOrder(b, inputType, inputConvertedFromModArith));

// Insert the ring encoding here to the input type
auto resultType = RankedTensorType::get(inputType.getShape(),
Expand Down Expand Up @@ -1338,8 +1340,6 @@ struct ConvertINTT : public OpConversionPattern<INTTOp> {
// Remove the encoded ring from the input tensor type
auto resultType =
RankedTensorType::get(inputType.getShape(), inputType.getElementType());
auto input = b.create<tensor::CastOp>(resultType, adaptor.getInput());

auto coeffType =
dyn_cast<ModArithType>(polyTy.getRing().getCoefficientType());
// FIXME: file an issue
Expand All @@ -1348,10 +1348,16 @@ struct ConvertINTT : public OpConversionPattern<INTTOp> {
op.emitError("expected coefficient type to be mod_arith type");
return failure();
}

auto input = b.create<tensor::CastOp>(resultType, adaptor.getInput());
auto nttResult =
fastNTT<true>(b, ring, op.getRoot().value(), resultType, input);

rewriter.replaceOp(op, computeReverseBitOrder(b, resultType, nttResult));
auto reversedBitOrder = computeReverseBitOrder(b, resultType, nttResult);
auto outputType = typeConverter->convertType(op.getOutput().getType());
auto converted =
b.create<mod_arith::EncapsulateOp>(outputType, reversedBitOrder);
rewriter.replaceOp(op, converted);

return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
// CHECK-DAG: #[[ADD_DIV_MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 floordiv 2)>
// CHECK-DAG: #[[ROOT_MAP:.*]] = affine_map<(d0, d1) -> ((d0 * 2 + 1) * d1)>

// CHECK: func.func @lower_intt() -> [[OUTPUT_TYPE:.*]] {
// CHECK: func.func @lower_intt() -> [[MOD_ARITH_OUTPUT_TYPE:.*]] {
// CHECK: %[[COEFFS:.*]] = arith.constant dense<[1, 2, 3, 4]> : [[INPUT_TYPE:.*]]
// CHECK: %[[CAST:.*]] = tensor.cast %[[COEFFS]] : [[INPUT_TYPE]] to [[OUTPUT_TYPE]]
// CHECK: %[[CAST:.*]] = tensor.cast %[[COEFFS]] : [[INPUT_TYPE]] to [[OUTPUT_TYPE:.*]]
// CHECK-DAG: %[[INITIAL_VALUE:.*]] = arith.extui %[[CAST]] : [[OUTPUT_TYPE]] to [[INTER_TYPE:.*]]
// CHECK-DAG: %[[CMOD:.*]] = arith.constant 7681 : [[ELEM_TYPE:i64]]
// CHECK-DAG: %[[ROOTS:.*]] = arith.constant dense<[1, 1213, 4298, 5756]> : [[INTER_TYPE]]
Expand Down Expand Up @@ -76,8 +76,8 @@
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[RES_TRUNC]][%[[REV_INDEX]]] : [[OUTPUT_TYPE]]
// CHECK: linalg.yield %[[EXTRACTED]] : i32
// CHECK: } -> [[OUTPUT_TYPE]]

// CHECK: return %[[ORDERED_OUTPUT]] : [[OUTPUT_TYPE]]
// CHECK: %[[CONVERTED_OUTPUT:.*]] = mod_arith.encapsulate %[[ORDERED_OUTPUT]] : [[OUTPUT_TYPE]] -> [[MOD_ARITH_OUTPUT_TYPE:.*]]
// CHECK: return %[[CONVERTED_OUTPUT]] : [[MOD_ARITH_OUTPUT_TYPE]]

func.func @lower_intt() -> !poly_ty {
%ntt_coeffs = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32, #ring>
Expand Down

0 comments on commit 3a71467

Please sign in to comment.