From efd4617724e956d2566062c6fe882e1d45cba7c4 Mon Sep 17 00:00:00 2001 From: Chao Ma Date: Wed, 21 Sep 2022 22:25:29 +0800 Subject: [PATCH] fix(float.circom): rewrite float.circom use a similar but require less constraint for IntegerDivision --- circuits/circom/float.circom | 106 ++++++++++++++++++++++++++--------- 1 file changed, 79 insertions(+), 27 deletions(-) diff --git a/circuits/circom/float.circom b/circuits/circom/float.circom index bca660a1e1..f07b9802a4 100644 --- a/circuits/circom/float.circom +++ b/circuits/circom/float.circom @@ -3,52 +3,93 @@ include "../node_modules/circomlib/circuits/bitify.circom"; include "../node_modules/circomlib/circuits/comparators.circom"; include "../node_modules/circomlib/circuits/mux1.circom"; +template msb(n) { + // require in < 2**n + signal input in; + signal output out; + component n2b = Num2Bits(n); + n2b.in <== in; + n2b.out[n-1] ==> out; +} + +template shift(n) { + // shift divident and partial rem together + // divident will reduce by 1 bit each call + // require divident < 2**n + signal input divident; + signal input rem; + signal output divident1; + signal output rem1; + + component lmsb = msb(n); + lmsb.in <== divident; + rem1 <== rem * 2 + lmsb.out; + divident1 <== divident - lmsb.out * 2**(n-1); +} + template IntegerDivision(n) { - // require max(a, b) < 2**n signal input a; signal input b; signal output c; - assert (n < 253); - assert (a < 2**n); - assert (b < 2**n); - var r = a; - var d = b * 2**n; + component lta = LessThan(252); + lta.in[0] <== a; + lta.in[1] <== 2**n; + lta.out === 1; + component ltb = LessThan(252); + ltb.in[0] <== b; + ltb.in[1] <== 2**n; + ltb.out === 1; + + component isz = IsZero(); + isz.in <== b; + isz.out === 0; + + var divident = a; + var rem = 0; + component b2n = Bits2Num(n); + component shf[n]; component lt[n]; component mux[n]; - component mux1[n]; for (var i = n - 1; i >= 0; i--) { - lt[i] = LessThan(2*n); + shf[i] = shift(i+1); + lt[i] = LessEqThan(n); mux[i] = Mux1(); - mux1[i] = Mux1(); } for (var i = n-1; i >= 0; i--) { - lt[i].in[0] <== 2 * r; - lt[i].in[1] <== d; + shf[i].divident <== divident; + shf[i].rem <== rem; + divident = shf[i].divident1; + rem = shf[i].rem1; - mux[i].s <== lt[i].out; - mux[i].c[0] <== 1; - mux[i].c[1] <== 0; + lt[i].in[0] <== b; + lt[i].in[1] <== rem; - mux1[i].s <== lt[i].out; - mux1[i].c[0] <== 2 * r - d; - mux1[i].c[1] <== 2 * r; + mux[i].s <== lt[i].out; + mux[i].c[0] <== 0; + mux[i].c[1] <== 1; + mux[i].out ==> b2n.in[i]; - b2n.in[i] <== mux[i].out; - r = mux1[i].out; + rem = rem - b * lt[i].out; } - c <== b2n.out; + b2n.out ==> c; } template ToFloat(W) { // W is the number of digits in decimal part - assert (W <= 76); // 10^76 < 2^253 + // 10^75 < 2^252 + assert(W < 75); + + // in*10^W <= 10^75 signal input in; signal output out; - assert (in < (10**(76-W))); + component lt = LessEqThan(252); + lt.in[0] <== in; + lt.in[1] <== 10**(75-W); + lt.out === 1; out <== in * (10**W); } @@ -59,11 +100,19 @@ template DivisionFromFloat(W, n) { signal input a; signal input b; signal output c; + + assert(W < 75); + assert(n < 252); + + component lt = LessThan(252); + lt.in[0] <== a; + lt.in[1] <== 10 ** (75 - W); + lt.out === 1; + component div = IntegerDivision(n); div.a <== a * (10 ** W); div.b <== b; c <== div.c; - log(c); } template DivisionFromNormal(W, n) { @@ -90,11 +139,16 @@ template MultiplicationFromFloat(W, n) { signal input a; signal input b; signal output c; - component div = IntegerDivision(n+4*W); + + assert(W < 75); + assert(n < 252); + assert(10**W < 2**n); + + component div = IntegerDivision(n); + // TODO: check a*b is not overflow div.a <== a * b; div.b <== 10**W; c <== div.c; - log(c); } template MultiplicationFromNormal(W, n) { @@ -113,5 +167,3 @@ template MultiplicationFromNormal(W, n) { mul.b <== tfb.out; c <== mul.c; } - -