Skip to content

Commit

Permalink
fix(float.circom): rewrite float.circom
Browse files Browse the repository at this point in the history
use a similar but require less constraint for IntegerDivision
  • Loading branch information
chaosma committed Sep 21, 2022
1 parent c07d5d9 commit efd4617
Showing 1 changed file with 79 additions and 27 deletions.
106 changes: 79 additions & 27 deletions circuits/circom/float.circom
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,93 @@ include "../node_modules/circomlib/circuits/bitify.circom";
include "../node_modules/circomlib/circuits/comparators.circom";
include "../node_modules/circomlib/circuits/mux1.circom";

template msb(n) {
// require in < 2**n
signal input in;
signal output out;
component n2b = Num2Bits(n);
n2b.in <== in;
n2b.out[n-1] ==> out;
}

template shift(n) {
// shift divident and partial rem together
// divident will reduce by 1 bit each call
// require divident < 2**n
signal input divident;
signal input rem;
signal output divident1;
signal output rem1;

component lmsb = msb(n);
lmsb.in <== divident;
rem1 <== rem * 2 + lmsb.out;
divident1 <== divident - lmsb.out * 2**(n-1);
}

template IntegerDivision(n) {
// require max(a, b) < 2**n
signal input a;
signal input b;
signal output c;
assert (n < 253);
assert (a < 2**n);
assert (b < 2**n);

var r = a;
var d = b * 2**n;
component lta = LessThan(252);
lta.in[0] <== a;
lta.in[1] <== 2**n;
lta.out === 1;
component ltb = LessThan(252);
ltb.in[0] <== b;
ltb.in[1] <== 2**n;
ltb.out === 1;

component isz = IsZero();
isz.in <== b;
isz.out === 0;

var divident = a;
var rem = 0;

component b2n = Bits2Num(n);
component shf[n];
component lt[n];
component mux[n];
component mux1[n];

for (var i = n - 1; i >= 0; i--) {
lt[i] = LessThan(2*n);
shf[i] = shift(i+1);
lt[i] = LessEqThan(n);
mux[i] = Mux1();
mux1[i] = Mux1();
}

for (var i = n-1; i >= 0; i--) {
lt[i].in[0] <== 2 * r;
lt[i].in[1] <== d;
shf[i].divident <== divident;
shf[i].rem <== rem;
divident = shf[i].divident1;
rem = shf[i].rem1;

mux[i].s <== lt[i].out;
mux[i].c[0] <== 1;
mux[i].c[1] <== 0;
lt[i].in[0] <== b;
lt[i].in[1] <== rem;

mux1[i].s <== lt[i].out;
mux1[i].c[0] <== 2 * r - d;
mux1[i].c[1] <== 2 * r;
mux[i].s <== lt[i].out;
mux[i].c[0] <== 0;
mux[i].c[1] <== 1;
mux[i].out ==> b2n.in[i];

b2n.in[i] <== mux[i].out;
r = mux1[i].out;
rem = rem - b * lt[i].out;
}
c <== b2n.out;
b2n.out ==> c;
}

template ToFloat(W) {
// W is the number of digits in decimal part
assert (W <= 76); // 10^76 < 2^253
// 10^75 < 2^252
assert(W < 75);

// in*10^W <= 10^75
signal input in;
signal output out;
assert (in < (10**(76-W)));
component lt = LessEqThan(252);
lt.in[0] <== in;
lt.in[1] <== 10**(75-W);
lt.out === 1;
out <== in * (10**W);
}

Expand All @@ -59,11 +100,19 @@ template DivisionFromFloat(W, n) {
signal input a;
signal input b;
signal output c;

assert(W < 75);
assert(n < 252);

component lt = LessThan(252);
lt.in[0] <== a;
lt.in[1] <== 10 ** (75 - W);
lt.out === 1;

component div = IntegerDivision(n);
div.a <== a * (10 ** W);
div.b <== b;
c <== div.c;
log(c);
}

template DivisionFromNormal(W, n) {
Expand All @@ -90,11 +139,16 @@ template MultiplicationFromFloat(W, n) {
signal input a;
signal input b;
signal output c;
component div = IntegerDivision(n+4*W);

assert(W < 75);
assert(n < 252);
assert(10**W < 2**n);

component div = IntegerDivision(n);
// TODO: check a*b is not overflow
div.a <== a * b;
div.b <== 10**W;
c <== div.c;
log(c);
}

template MultiplicationFromNormal(W, n) {
Expand All @@ -113,5 +167,3 @@ template MultiplicationFromNormal(W, n) {
mul.b <== tfb.out;
c <== mul.c;
}


0 comments on commit efd4617

Please sign in to comment.