Skip to content

Commit

Permalink
Rust extf mul divmod (#167)
Browse files Browse the repository at this point in the history
Co-authored-by: feltroidprime <[email protected]>
  • Loading branch information
raugfer and feltroidprime authored Aug 22, 2024
1 parent 2986b7b commit 933784e
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 84 deletions.
6 changes: 1 addition & 5 deletions hydra/garaga/hints/ecip.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,8 @@ def zk_ecip_hint(
if ec_group_class == G1Point and use_rust:
pts = []
c_id = Bs[0].curve_id
if c_id == CurveID.BLS12_381:
nb = 48
else:
nb = 32
for pt in Bs:
pts.extend([pt.x.to_bytes(nb, "big"), pt.y.to_bytes(nb, "big")])
pts.extend([pt.x, pt.y])
field_type = get_field_type_from_ec_point(Bs[0])
field = get_base_field(c_id.value, field_type)

Expand Down
29 changes: 7 additions & 22 deletions hydra/garaga/hints/extf_mul.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import operator
from functools import reduce

from garaga import garaga_rs
from garaga.algebra import ModuloCircuitElement, Polynomial, PyFelt
from garaga.definitions import (
direct_to_tower,
Expand All @@ -19,29 +20,13 @@ def nondeterministic_extension_field_mul_divmod(
curve_id: int,
extension_degree: int,
) -> tuple[list[PyFelt], list[PyFelt]]:

Ps = [Polynomial(P) for P in Ps]
field = get_base_field(curve_id)

P_irr = get_irreducible_poly(curve_id, extension_degree)

z_poly = reduce(operator.mul, Ps) # Π(Pi)
z_polyq, z_polyr = divmod(z_poly, P_irr)

z_polyr_coeffs = z_polyr.get_coeffs()
z_polyq_coeffs = z_polyq.get_coeffs()
# assert len(z_polyq_coeffs) <= (
# extension_degree - 1
# ), f"len z_polyq_coeffs={len(z_polyq_coeffs)}, degree: {z_polyq.degree()}"
assert (
len(z_polyr_coeffs) <= extension_degree
), f"len z_polyr_coeffs={len(z_polyr_coeffs)}, degree: {z_polyr.degree()}"

# Extend polynomials with 0 coefficients to match the expected lengths.
# TODO : pass exact expected max degree when len(Ps)>2.
z_polyq_coeffs += [field(0)] * (extension_degree - 1 - len(z_polyq_coeffs))
z_polyr_coeffs += [field(0)] * (extension_degree - len(z_polyr_coeffs))

ps = [[c.value for c in P] for P in Ps]
q, r = garaga_rs.nondeterministic_extension_field_mul_divmod(
curve_id, extension_degree, ps
)
z_polyq_coeffs = [field(c) for c in q] if len(q) > 0 else [field.zero()]
z_polyr_coeffs = [field(c) for c in r] if len(r) > 0 else [field.zero()]
return (z_polyq_coeffs, z_polyr_coeffs)


Expand Down
71 changes: 21 additions & 50 deletions tools/garaga_rs/src/ecip/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,118 +3,89 @@ use lambdaworks_math::elliptic_curve::short_weierstrass::curves::bls12_381::fiel
use lambdaworks_math::elliptic_curve::short_weierstrass::curves::bn_254::field_extension::BN254PrimeField;
use lambdaworks_math::field::element::FieldElement;
use lambdaworks_math::field::traits::IsPrimeField;
use lambdaworks_math::traits::ByteConversion;

use crate::ecip::curve::{SECP256K1PrimeField, SECP256R1PrimeField, X25519PrimeField};
use crate::ecip::ff::FF;
use crate::ecip::g1point::G1Point;
use crate::ecip::rational_function::FunctionFelt;
use crate::ecip::rational_function::RationalFunction;
use crate::io::parse_field_elements_from_list;

use num_bigint::{BigInt, BigUint, ToBigInt};

use super::curve::CurveParamsProvider;

pub fn zk_ecip_hint(
list_bytes: Vec<Vec<u8>>,
list_values: Vec<BigUint>,
list_scalars: Vec<BigUint>,
curve_id: usize,
) -> Result<[Vec<String>; 5], String> {
match curve_id {
0 => {
let list_felts: Vec<FieldElement<BN254PrimeField>> = list_bytes
.into_iter()
.map(|x| {
FieldElement::<BN254PrimeField>::from_bytes_be(&x)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect::<Result<Vec<FieldElement<BN254PrimeField>>, _>>()?;
let list_felts = parse_field_elements_from_list::<BN254PrimeField>(&list_values)?;

let points: Vec<G1Point<BN254PrimeField>> = list_felts
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<BN254PrimeField>(list_scalars);
Ok(run_ecip::<BN254PrimeField>(points, scalars))
let dss: Vec<Vec<i8>> = construct_digits_vectors::<BN254PrimeField>(list_scalars);
Ok(run_ecip::<BN254PrimeField>(points, dss))
}
1 => {
let list_felts: Vec<FieldElement<BLS12381PrimeField>> = list_bytes
.into_iter()
.map(|x| {
FieldElement::<BLS12381PrimeField>::from_bytes_be(&x)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect::<Result<Vec<FieldElement<BLS12381PrimeField>>, _>>()?;
let list_felts = parse_field_elements_from_list::<BLS12381PrimeField>(&list_values)?;

let points: Vec<G1Point<BLS12381PrimeField>> = list_felts
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<BLS12381PrimeField>(list_scalars);
Ok(run_ecip::<BLS12381PrimeField>(points, scalars))
let dss: Vec<Vec<i8>> = construct_digits_vectors::<BLS12381PrimeField>(list_scalars);
Ok(run_ecip::<BLS12381PrimeField>(points, dss))
}
2 => {
let list_felts: Vec<FieldElement<SECP256K1PrimeField>> = list_bytes
.into_iter()
.map(|x| {
FieldElement::<SECP256K1PrimeField>::from_bytes_be(&x)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect::<Result<Vec<FieldElement<SECP256K1PrimeField>>, _>>()?;
let list_felts = parse_field_elements_from_list::<SECP256K1PrimeField>(&list_values)?;

let points: Vec<G1Point<SECP256K1PrimeField>> = list_felts
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<SECP256K1PrimeField>(list_scalars);
Ok(run_ecip::<SECP256K1PrimeField>(points, scalars))
let dss: Vec<Vec<i8>> = construct_digits_vectors::<SECP256K1PrimeField>(list_scalars);
Ok(run_ecip::<SECP256K1PrimeField>(points, dss))
}
3 => {
let list_felts: Vec<FieldElement<SECP256R1PrimeField>> = list_bytes
.into_iter()
.map(|x| {
FieldElement::<SECP256R1PrimeField>::from_bytes_be(&x)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect::<Result<Vec<FieldElement<SECP256R1PrimeField>>, _>>()?;
let list_felts = parse_field_elements_from_list::<SECP256R1PrimeField>(&list_values)?;

let points: Vec<G1Point<SECP256R1PrimeField>> = list_felts
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<SECP256R1PrimeField>(list_scalars);
Ok(run_ecip::<SECP256R1PrimeField>(points, scalars))
let dss: Vec<Vec<i8>> = construct_digits_vectors::<SECP256R1PrimeField>(list_scalars);
Ok(run_ecip::<SECP256R1PrimeField>(points, dss))
}
4 => {
let list_felts: Vec<FieldElement<X25519PrimeField>> = list_bytes
.into_iter()
.map(|x| {
FieldElement::<X25519PrimeField>::from_bytes_be(&x)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect::<Result<Vec<FieldElement<X25519PrimeField>>, _>>()?;
let list_felts = parse_field_elements_from_list::<X25519PrimeField>(&list_values)?;

let points: Vec<G1Point<X25519PrimeField>> = list_felts
.chunks(2)
.map(|chunk| G1Point::new(chunk[0].clone(), chunk[1].clone()))
.collect();

let scalars: Vec<Vec<i8>> = extract_scalars::<X25519PrimeField>(list_scalars);
Ok(run_ecip::<X25519PrimeField>(points, scalars))
let dss: Vec<Vec<i8>> = construct_digits_vectors::<X25519PrimeField>(list_scalars);
Ok(run_ecip::<X25519PrimeField>(points, dss))
}
_ => Err(String::from("Invalid curve ID")),
}
}

fn extract_scalars<F: IsPrimeField + CurveParamsProvider<F>>(list: Vec<BigUint>) -> Vec<Vec<i8>> {
fn construct_digits_vectors<F: IsPrimeField + CurveParamsProvider<F>>(
list: Vec<BigUint>,
) -> Vec<Vec<i8>> {
let mut dss_ = Vec::new();

for i in 0..list.len() {
let scalar_biguint = list[i].clone();
for scalar_biguint in list {
let neg_3_digits = neg_3_base_le(scalar_biguint);
dss_.push(neg_3_digits);
}
Expand Down
28 changes: 28 additions & 0 deletions tools/garaga_rs/src/ecip/curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ use lambdaworks_math::field::fields::montgomery_backed_prime_fields::{
IsModulus, MontgomeryBackendPrimeField,
};

use crate::ecip::polynomial::Polynomial;
use lambdaworks_math::field::traits::IsPrimeField;
use lambdaworks_math::unsigned_integer::element::U256;
use num_bigint::BigUint;
use std::cmp::PartialEq;
use std::collections::HashMap;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CurveID {
Expand Down Expand Up @@ -72,6 +74,21 @@ pub struct CurveParams<F: IsPrimeField> {
pub g_y: FieldElement<F>,
pub n: FieldElement<F>, // Order of the curve
pub h: u32, // Cofactor
pub irreducible_polys: HashMap<usize, &'static [i8]>,
}

pub fn get_irreducible_poly<F: IsPrimeField + CurveParamsProvider<F>>(
ext_degree: usize,
) -> Polynomial<F> {
let coeffs = (F::get_curve_params().irreducible_polys)[&ext_degree];
fn lift<F: IsPrimeField>(c: i8) -> FieldElement<F> {
if c >= 0 {
FieldElement::from(c as u64)
} else {
-FieldElement::from(-c as u64)
}
}
return Polynomial::new(coeffs.iter().map(|x| lift::<F>(*x)).collect());
}

/// A trait that provides curve parameters for a specific field type.
Expand Down Expand Up @@ -99,6 +116,7 @@ impl CurveParamsProvider<SECP256K1PrimeField> for SECP256K1PrimeField {
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141",
),
h: 1,
irreducible_polys: HashMap::from([]), // Provide appropriate values here
}
}
}
Expand All @@ -122,6 +140,7 @@ impl CurveParamsProvider<SECP256R1PrimeField> for SECP256R1PrimeField {
"FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551",
),
h: 1,
irreducible_polys: HashMap::from([]), // Provide appropriate values here
}
}
}
Expand All @@ -143,6 +162,7 @@ impl CurveParamsProvider<X25519PrimeField> for X25519PrimeField {
"1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED",
),
h: 8,
irreducible_polys: HashMap::from([]), // Provide appropriate values here
}
}
}
Expand All @@ -158,6 +178,10 @@ impl CurveParamsProvider<BN254PrimeField> for BN254PrimeField {
g_y: FieldElement::from_hex_unchecked("2"), // Replace with actual 'g_y'
n: FieldElement::from_hex_unchecked("1"), // Replace with actual 'n'
h: 1, // Replace with actual 'h'
irreducible_polys: HashMap::from([
(6, [82, 0, 0, -18, 0, 0, 1].as_slice()),
(12, [82, 0, 0, 0, 0, 0, -18, 0, 0, 0, 0, 0, 1].as_slice()),
]),
}
}
}
Expand All @@ -173,6 +197,10 @@ impl CurveParamsProvider<BLS12381PrimeField> for BLS12381PrimeField {
g_y: FieldElement::from_hex_unchecked("2"), // Replace with actual 'g_y'
n: FieldElement::from_hex_unchecked("1"), // Replace with actual 'n'
h: 1, // Replace with actual 'h'
irreducible_polys: HashMap::from([
(6, [2, 0, 0, -2, 0, 0, 1].as_slice()),
(12, [2, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 1].as_slice()),
]),
}
}
}
6 changes: 5 additions & 1 deletion tools/garaga_rs/src/ecip/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ impl<F: IsPrimeField> Polynomial<F> {
Polynomial::new(vec![FieldElement::<F>::zero()])
}

pub fn one() -> Self {
Polynomial::new(vec![FieldElement::<F>::one()])
}

pub fn mul_with_ref(&self, other: &Polynomial<F>) -> Polynomial<F> {
if self.degree() == -1 || other.degree() == -1 {
return Polynomial::zero();
Expand Down Expand Up @@ -142,7 +146,7 @@ impl<F: IsPrimeField> Polynomial<F> {
for (i, coeff) in self.coefficients.iter().enumerate().skip(1) {
let u_64 = i as u64;
let degree = &FieldElement::<F>::from(u_64);
new_coeffs[i - 1] = *(&coeff) * degree;
new_coeffs[i - 1] = coeff * degree;
}
Polynomial::new(new_coeffs)
}
Expand Down
30 changes: 30 additions & 0 deletions tools/garaga_rs/src/extf_mul.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use crate::ecip::{
curve::{get_irreducible_poly, CurveParamsProvider},
polynomial::{pad_with_zero_coefficients_to_length, Polynomial},
};
use lambdaworks_math::field::traits::IsPrimeField;

// Returns (Q(X), R(X)) such that Π(Pi)(X) = Q(X) * P_irr(X) + R(X), for a given curve and extension degree.
// R(X) is the result of the multiplication in the extension field.
// Q(X) is used for verification.
pub fn nondeterministic_extension_field_mul_divmod<F: IsPrimeField + CurveParamsProvider<F>>(
ext_degree: usize,
ps: Vec<Polynomial<F>>,
) -> (Polynomial<F>, Polynomial<F>) {
let mut z_poly = Polynomial::one();
for poly in ps {
z_poly = z_poly.mul_with_ref(&poly);
}

let p_irr = get_irreducible_poly(ext_degree);

let (z_polyq, mut z_polyr) = z_poly.divmod(&p_irr);
assert!(z_polyr.coefficients.len() <= ext_degree);

// Extend polynomial with 0 coefficients to match the expected length.
if z_polyr.coefficients.len() < ext_degree {
pad_with_zero_coefficients_to_length(&mut z_polyr, ext_degree);
}

(z_polyq, z_polyr)
}
24 changes: 24 additions & 0 deletions tools/garaga_rs/src/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use lambdaworks_math::field::element::FieldElement;
use lambdaworks_math::field::traits::IsPrimeField;
use lambdaworks_math::traits::ByteConversion;
use num_bigint::BigUint;

pub fn parse_field_elements_from_list<F: IsPrimeField>(
coeffs: &[BigUint],
) -> Result<Vec<FieldElement<F>>, String>
where
FieldElement<F>: ByteConversion,
{
let length = (F::field_bit_size() + 7) / 8;
coeffs
.iter()
.map(|x| {
let bytes = x.to_bytes_be();
let pad_length = length.saturating_sub(bytes.len());
let mut padded_bytes = vec![0u8; pad_length];
padded_bytes.extend(bytes);
FieldElement::from_bytes_be(&padded_bytes)
.map_err(|e| format!("Byte conversion error: {:?}", e))
})
.collect()
}
Loading

0 comments on commit 933784e

Please sign in to comment.