Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make garaga-rs an importable Rust lib and add get_groth16_calldata rust function #229

Merged
merged 13 commits into from
Oct 25, 2024
24 changes: 23 additions & 1 deletion hydra/garaga/starknet/groth16_contract_generator/calldata.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from garaga import garaga_rs
from garaga.definitions import G1G2Pair, G1Point
from garaga.starknet.groth16_contract_generator.parsing_utils import (
Groth16Proof,
Expand All @@ -6,10 +7,15 @@
from garaga.starknet.tests_and_calldata_generators.mpcheck import MPCheckCalldataBuilder
from garaga.starknet.tests_and_calldata_generators.msm import MSMCalldataBuilder

garaga_rs.get_groth16_calldata


def groth16_calldata_from_vk_and_proof(
vk: Groth16VerifyingKey, proof: Groth16Proof
vk: Groth16VerifyingKey, proof: Groth16Proof, use_rust: bool = True
) -> list[int]:
if use_rust:
return _groth16_calldata_from_vk_and_proof_rust(vk, proof)

assert (
vk.curve_id == proof.curve_id
), f"Curve ID mismatch: {vk.curve_id} != {proof.curve_id}"
Expand Down Expand Up @@ -68,6 +74,22 @@ def groth16_calldata_from_vk_and_proof(
return [len(calldata)] + calldata


def _groth16_calldata_from_vk_and_proof_rust(
vk: Groth16VerifyingKey, proof: Groth16Proof
) -> list[int]:
assert (
vk.curve_id == proof.curve_id
), f"Curve ID mismatch: {vk.curve_id} != {proof.curve_id}"

return garaga_rs.get_groth16_calldata(
proof.flatten(),
vk.flatten(),
proof.curve_id.value,
proof.image_id,
proof.journal,
)


if __name__ == "__main__":
VK_PATH = "hydra/garaga/starknet/groth16_contract_generator/examples/snarkjs_vk_bn254.json"
PROOF_PATH = "hydra/garaga/starknet/groth16_contract_generator/examples/snarkjs_proof_bn254.json"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,16 @@ def serialize_to_cairo(self) -> str:
"""
return code

def flatten(self) -> list[int]:
lst = []
lst.extend([self.alpha.x, self.alpha.y])
lst.extend([self.beta.x[0], self.beta.x[1], self.beta.y[0], self.beta.y[1]])
lst.extend([self.gamma.x[0], self.gamma.x[1], self.gamma.y[0], self.gamma.y[1]])
lst.extend([self.delta.x[0], self.delta.x[1], self.delta.y[0], self.delta.y[1]])
for point in self.ic:
lst.extend([point.x, point.y])
return lst


def reverse_byte_order_uint256(value: int | bytes) -> int:
if isinstance(value, int):
Expand Down Expand Up @@ -469,6 +479,16 @@ def serialize_to_calldata(self) -> list[int]:
cd.extend(io.bigint_split(pub, 2, 2**128))
return cd

def flatten(self, include_public_inputs: bool = True) -> list[int]:
lst = []
lst.extend([self.a.x, self.a.y])
lst.extend([self.b.x[0], self.b.x[1]])
lst.extend([self.b.y[0], self.b.y[1]])
lst.extend([self.c.x, self.c.y])
if include_public_inputs:
lst.extend(self.public_inputs)
return lst


class ExitCode:
def __init__(self, system, user):
Expand Down
41 changes: 41 additions & 0 deletions tests/hydra/starknet/test_groth16_vk_proof_parsing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest

from garaga.starknet.groth16_contract_generator.calldata import (
groth16_calldata_from_vk_and_proof,
)
from garaga.starknet.groth16_contract_generator.parsing_utils import (
Groth16Proof,
Groth16VerifyingKey,
Expand Down Expand Up @@ -45,3 +48,41 @@ def test_proof_parsing_with_public_input(proof_path: str, pub_inputs_path: str):
proof = Groth16Proof.from_json(proof_path, pub_inputs_path)

print(proof)


@pytest.mark.parametrize(
"proof_path, vk_path, pub_inputs_path",
[
(f"{PATH}/proof_bn254.json", f"{PATH}/vk_bn254.json", None),
(f"{PATH}/proof_bls.json", f"{PATH}/vk_bls.json", None),
(
f"{PATH}/gnark_proof_bn254.json",
f"{PATH}/gnark_vk_bn254.json",
f"{PATH}/gnark_public_bn254.json",
),
(
f"{PATH}/snarkjs_proof_bn254.json",
f"{PATH}/snarkjs_vk_bn254.json",
f"{PATH}/snarkjs_public_bn254.json",
),
(f"{PATH}/proof_risc0.json", f"{PATH}/vk_risc0.json", None),
],
)
def test_calldata_generation(
proof_path: str, vk_path: str, pub_inputs_path: str | None
):
import time

vk = Groth16VerifyingKey.from_json(vk_path)
proof = Groth16Proof.from_json(proof_path, pub_inputs_path)

start = time.time()
calldata = groth16_calldata_from_vk_and_proof(vk, proof, use_rust=False)
end = time.time()
print(f"Python time: {end - start}")

start = time.time()
calldata_rust = groth16_calldata_from_vk_and_proof(vk, proof, use_rust=True)
end = time.time()
print(f"Rust time: {end - start}")
assert calldata == calldata_rust
8 changes: 4 additions & 4 deletions tools/garaga_rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tools/garaga_rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "garaga_rs"
crate-type = ["cdylib"]
crate-type = ["cdylib", "rlib"]

[profile.release]
lto = true
Expand Down
20 changes: 13 additions & 7 deletions tools/garaga_rs/src/algebra/g1point.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::definitions::{CurveParamsProvider, FieldElement};
use lambdaworks_math::field::traits::IsPrimeField;
use num_bigint::{BigInt, BigUint, Sign};

#[derive(Debug, Clone)]
pub struct G1Point<F: IsPrimeField> {
pub x: FieldElement<F>,
Expand Down Expand Up @@ -48,15 +47,16 @@ impl<F: IsPrimeField + CurveParamsProvider<F>> G1Point<F> {

let lambda = if self.eq(other) {
let alpha = F::get_curve_params().a;

(FieldElement::<F>::from(3_u64) * self.x.square() + alpha)
/ (FieldElement::<F>::from(2_u64) * self.y.clone())
let numerator = FieldElement::<F>::from(3_u64) * &self.x.square() + alpha;
let denominator = FieldElement::<F>::from(2_u64) * &self.y;
numerator / denominator
} else {
(other.y.clone() - self.y.clone()) / (other.x.clone() - self.x.clone())
(&other.y - &self.y) / (&other.x - &self.x)
};

let x3 = lambda.square() - self.x.clone() - other.x.clone();
let y3 = lambda * (self.x.clone() - x3.clone()) - self.y.clone();
let x3 = &lambda.square() - &self.x - &other.x;

let y3 = &lambda * &(self.x.clone() - &x3) - &self.y;

G1Point::new_unchecked(x3, y3)
}
Expand Down Expand Up @@ -128,6 +128,12 @@ impl<F: IsPrimeField + CurveParamsProvider<F>> G1Point<F> {
self.y.representative().to_string()
);
}
pub fn generator() -> Self {
let curve_params = F::get_curve_params();
let generator_x = curve_params.g_x;
let generator_y = curve_params.g_y;
G1Point::new(generator_x, generator_y).unwrap()
}
}

impl<F: IsPrimeField> PartialEq for G1Point<F> {
Expand Down
33 changes: 16 additions & 17 deletions tools/garaga_rs/src/algebra/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,32 +229,31 @@ impl<F: IsPrimeField> std::ops::Add<&Polynomial<F>> for &Polynomial<F> {

fn add(self, a_polynomial: &Polynomial<F>) -> Self::Output {
let (pa, pb) = pad_with_zero_coefficients(self, a_polynomial);
let iter_coeff_pa = pa.coefficients.iter();
let iter_coeff_pb = pb.coefficients.iter();
let new_coefficients = iter_coeff_pa.zip(iter_coeff_pb).map(|(x, y)| x + y);
let new_coefficients_vec = new_coefficients.collect::<Vec<FieldElement<F>>>();
Polynomial::new(new_coefficients_vec)
let new_coefficients = pa
.coefficients
.iter()
.zip(pb.coefficients.iter())
.map(|(x, y)| x + y)
.collect();
Polynomial::new(new_coefficients)
}
}

impl<F: IsPrimeField> std::ops::Add for Polynomial<F> {
type Output = Polynomial<F>;

fn add(self, other: Polynomial<F>) -> Polynomial<F> {
let (ns, no) = (self.coefficients.len(), other.coefficients.len());
if ns >= no {
let mut coeffs = self.coefficients.clone();
for (i, coeff) in other.coefficients.iter().enumerate() {
coeffs[i] += coeff.clone();
}
Polynomial::new(coeffs)
let (mut longer, shorter) = if self.coefficients.len() >= other.coefficients.len() {
(self.coefficients, other.coefficients)
} else {
let mut coeffs = other.coefficients.clone();
for (i, coeff) in self.coefficients.iter().enumerate() {
coeffs[i] += coeff.clone();
}
Polynomial::new(coeffs)
(other.coefficients, self.coefficients)
};

for (i, coeff) in shorter.into_iter().enumerate() {
longer[i] += coeff;
}

Polynomial::new(longer)
}
}

Expand Down
1 change: 1 addition & 0 deletions tools/garaga_rs/src/calldata/full_proof_with_hints.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod groth16;
Loading
Loading