Skip to content

Commit

Permalink
convert lower_intt and runner test
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Nov 20, 2024
1 parent 829a27b commit beab499
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1294,9 +1294,11 @@ struct ConvertNTT : public OpConversionPattern<NTTOp> {
op.emitError("expected coefficient type to be mod_arith type");
return failure();
}
auto inputConvertedFromModArith =
b.create<mod_arith::ExtractOp>(inputType, adaptor.getInput());
auto nttResult = fastNTT<false>(
b, ring, op.getRoot().value(), inputType,
computeReverseBitOrder(b, inputType, adaptor.getInput()));
computeReverseBitOrder(b, inputType, inputConvertedFromModArith));

// Insert the ring encoding here to the input type
auto resultType = RankedTensorType::get(inputType.getShape(),
Expand Down Expand Up @@ -1338,8 +1340,6 @@ struct ConvertINTT : public OpConversionPattern<INTTOp> {
// Remove the encoded ring from the input tensor type
auto resultType =
RankedTensorType::get(inputType.getShape(), inputType.getElementType());
auto input = b.create<tensor::CastOp>(resultType, adaptor.getInput());

auto coeffType =
dyn_cast<ModArithType>(polyTy.getRing().getCoefficientType());
// FIXME: file an issue
Expand All @@ -1348,10 +1348,16 @@ struct ConvertINTT : public OpConversionPattern<INTTOp> {
op.emitError("expected coefficient type to be mod_arith type");
return failure();
}

auto input = b.create<tensor::CastOp>(resultType, adaptor.getInput());
auto nttResult =
fastNTT<true>(b, ring, op.getRoot().value(), resultType, input);

rewriter.replaceOp(op, computeReverseBitOrder(b, resultType, nttResult));
auto reversedBitOrder = computeReverseBitOrder(b, resultType, nttResult);
auto outputType = typeConverter->convertType(op.getOutput().getType());
auto converted =
b.create<mod_arith::EncapsulateOp>(outputType, reversedBitOrder);
rewriter.replaceOp(op, converted);

return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
// CHECK-DAG: #[[ADD_DIV_MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 floordiv 2)>
// CHECK-DAG: #[[ROOT_MAP:.*]] = affine_map<(d0, d1) -> ((d0 * 2 + 1) * d1)>

// CHECK: func.func @lower_intt() -> [[OUTPUT_TYPE:.*]] {
// CHECK: func.func @lower_intt() -> [[MOD_ARITH_OUTPUT_TYPE:.*]] {
// CHECK: %[[COEFFS:.*]] = arith.constant dense<[1, 2, 3, 4]> : [[INPUT_TYPE:.*]]
// CHECK: %[[CAST:.*]] = tensor.cast %[[COEFFS]] : [[INPUT_TYPE]] to [[OUTPUT_TYPE]]
// CHECK: %[[CAST:.*]] = tensor.cast %[[COEFFS]] : [[INPUT_TYPE]] to [[OUTPUT_TYPE:.*]]
// CHECK-DAG: %[[INITIAL_VALUE:.*]] = arith.extui %[[CAST]] : [[OUTPUT_TYPE]] to [[INTER_TYPE:.*]]
// CHECK-DAG: %[[CMOD:.*]] = arith.constant 7681 : [[ELEM_TYPE:i64]]
// CHECK-DAG: %[[ROOTS:.*]] = arith.constant dense<[1, 1213, 4298, 5756]> : [[INTER_TYPE]]
Expand Down Expand Up @@ -76,8 +76,8 @@
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[RES_TRUNC]][%[[REV_INDEX]]] : [[OUTPUT_TYPE]]
// CHECK: linalg.yield %[[EXTRACTED]] : i32
// CHECK: } -> [[OUTPUT_TYPE]]

// CHECK: return %[[ORDERED_OUTPUT]] : [[OUTPUT_TYPE]]
// CHECK: %[[CONVERTED_OUTPUT:.*]] = mod_arith.encapsulate %[[ORDERED_OUTPUT]] : [[OUTPUT_TYPE]] -> [[MOD_ARITH_OUTPUT_TYPE:.*]]
// CHECK: return %[[CONVERTED_OUTPUT]] : [[MOD_ARITH_OUTPUT_TYPE]]

func.func @lower_intt() -> !poly_ty {
%ntt_coeffs = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32, #ring>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

#cycl = #polynomial.int_polynomial<1 + x**4>
#ring = #polynomial.ring<coefficientType=!mod_arith.int<7681:i32>, polynomialModulus=#cycl>
!coeff_ty = !mod_arith.int<7681:i32>
#ring = #polynomial.ring<coefficientType=!coeff_ty, polynomialModulus=#cycl>
#root = #polynomial.primitive_root<value=1925:i32, degree=8:i32>
!poly_ty = !polynomial.polynomial<ring=#ring>

Expand All @@ -18,9 +19,10 @@ func.func @test_poly_ntt() {
%ntt_coeffs = tensor.cast %coeffs : tensor<4xi32> to tensor<4xi32, #ring>
%0 = polynomial.intt %ntt_coeffs {root=#root} : tensor<4xi32, #ring> -> !poly_ty

%1 = polynomial.to_tensor %0 : !poly_ty -> tensor<4xi32>
%2 = bufferization.to_memref %1 : memref<4xi32>
%U = memref.cast %2 : memref<4xi32> to memref<*xi32>
%1 = polynomial.to_tensor %0 : !poly_ty -> tensor<4x!coeff_ty>
%2 = mod_arith.extract %1 : tensor<4x!coeff_ty> -> tensor<4xi32>
%3 = bufferization.to_memref %2 : memref<4xi32>
%U = memref.cast %3 : memref<4xi32> to memref<*xi32>
func.call @printMemrefI32(%U) : (memref<*xi32>) -> ()
return
}
Expand Down

0 comments on commit beab499

Please sign in to comment.