Skip to content

Commit

Permalink
feat(avm-simulator): msm blackbox (#7048)
Browse files Browse the repository at this point in the history
Please read [contributing guidelines](CONTRIBUTING.md) and remove this
line.
  • Loading branch information
IlyasRidhuan authored Jun 17, 2024
1 parent bf6b8b8 commit 0ce27e0
Show file tree
Hide file tree
Showing 12 changed files with 314 additions and 2 deletions.
2 changes: 2 additions & 0 deletions avm-transpiler/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub enum AvmOpcode {
SHA256, // temp - may be removed, but alot of contracts rely on it
PEDERSEN, // temp - may be removed, but alot of contracts rely on it
ECADD,
MSM,
// Conversions
TORADIXLE,
}
Expand Down Expand Up @@ -165,6 +166,7 @@ impl AvmOpcode {
AvmOpcode::SHA256 => "SHA256 ",
AvmOpcode::PEDERSEN => "PEDERSEN",
AvmOpcode::ECADD => "ECADD",
AvmOpcode::MSM => "MSM",
// Conversions
AvmOpcode::TORADIXLE => "TORADIXLE",
}
Expand Down
24 changes: 24 additions & 0 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,30 @@ fn handle_black_box_function(avm_instrs: &mut Vec<AvmInstruction>, operation: &B
],
..Default::default()
}),
// Temporary while we dont have efficient noir implementations
BlackBoxOp::MultiScalarMul { points, scalars, outputs } => {
// The length of the scalars vector is 2x the length of the points vector due to limb
// decomposition
let points_offset = points.pointer.0;
let num_points = points.size.0;
let scalars_offset = scalars.pointer.0;
// Output array is fixed to 3
assert_eq!(outputs.size, 3, "Output array size must be equal to 3");
let outputs_offset = outputs.pointer.0;
avm_instrs.push(AvmInstruction {
opcode: AvmOpcode::MSM,
indirect: Some(
ZEROTH_OPERAND_INDIRECT | FIRST_OPERAND_INDIRECT | SECOND_OPERAND_INDIRECT,
),
operands: vec![
AvmOperand::U32 { value: points_offset as u32 },
AvmOperand::U32 { value: scalars_offset as u32 },
AvmOperand::U32 { value: outputs_offset as u32 },
AvmOperand::U32 { value: num_points as u32 },
],
..Default::default()
});
}
_ => panic!("Transpiler doesn't know how to process {:?}", operation),
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ const std::unordered_map<OpCode, std::vector<OperandType>> OPCODE_WIRE_FORMAT =
OperandType::UINT32, // rhs.y
OperandType::UINT32, // rhs.is_infinite
OperandType::UINT32 } }, // dst_offset
{ OpCode::MSM,
{ OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } },
// Gadget - Conversion
{ OpCode::TORADIXLE,
{ OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ enum class OpCode : uint8_t {
SHA256,
PEDERSEN,
ECADD,
MSM,
// Conversions
TORADIXLE,
// Future Gadgets -- pending changes in noir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ contract AvmTest {
global big_field_136_bits: Field = 0x991234567890abcdef1234567890abcdef;

// Libs
use dep::std::embedded_curve_ops::EmbeddedCurvePoint;
use dep::std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul};
use dep::aztec::protocol_types::constants::CONTRACT_INSTANCE_LENGTH;
use dep::aztec::prelude::{Map, Deserialize};
use dep::aztec::state_vars::PublicMutable;
Expand Down Expand Up @@ -144,6 +144,15 @@ contract AvmTest {
added
}

#[aztec(public)]
fn variable_base_msm() -> [Field; 3] {
let g = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false };
let scalar = EmbeddedCurveScalar { lo: 3, hi: 0 };
let scalar2 = EmbeddedCurveScalar { lo: 20, hi: 0 };
let triple_g = multi_scalar_mul([g, g], [scalar, scalar2]);
triple_g
}

/************************************************************************
* Misc
************************************************************************/
Expand Down
14 changes: 13 additions & 1 deletion yarn-project/foundation/src/fields/point.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,20 @@ export class Point {
return poseidon2Hash(this.toFields());
}

/**
* Check if this is point at infinity.
* Check this is consistent with how bb is encoding the point at infinity
*/
public get inf() {
return this.x == Fr.ZERO;
}
public toFieldsWithInf() {
return [this.x, this.y, new Fr(this.inf)];
}

isOnGrumpkin() {
if (this.isZero()) {
// TODO: Check this against how bb handles curve check and infinity point check
if (this.inf) {
return true;
}

Expand Down
1 change: 1 addition & 0 deletions yarn-project/simulator/src/avm/avm_gas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ const BaseGasCosts: Record<Opcode, Gas> = {
[Opcode.SHA256]: DefaultBaseGasCost,
[Opcode.PEDERSEN]: DefaultBaseGasCost,
[Opcode.ECADD]: DefaultBaseGasCost,
[Opcode.MSM]: DefaultBaseGasCost,
// Conversions
[Opcode.TORADIXLE]: DefaultBaseGasCost,
};
Expand Down
14 changes: 14 additions & 0 deletions yarn-project/simulator/src/avm/avm_simulator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ describe('AVM simulator: transpiled Noir contracts', () => {
expect(results.output).toEqual([g3.x, g3.y, Fr.ZERO]);
});

it('variable msm operations', async () => {
const context = initContext();

const bytecode = getAvmTestContractBytecode('variable_base_msm');
const results = await new AvmSimulator(context).executeBytecode(bytecode);

expect(results.reverted).toBe(false);
const grumpkin = new Grumpkin();
const g3 = grumpkin.mul(grumpkin.generator(), new Fq(3));
const g20 = grumpkin.mul(grumpkin.generator(), new Fq(20));
const expectedResult = grumpkin.add(g3, g20);
expect(results.output).toEqual([expectedResult.x, expectedResult.y, Fr.ZERO]);
});

describe('U128 addition and overflows', () => {
it('U128 addition', async () => {
const calldata: Fr[] = [
Expand Down
130 changes: 130 additions & 0 deletions yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import { Fq, Fr } from '@aztec/circuits.js';
import { Grumpkin } from '@aztec/circuits.js/barretenberg';

import { type AvmContext } from '../avm_context.js';
import { Field, type MemoryValue, Uint8, Uint32 } from '../avm_memory_types.js';
import { initContext } from '../fixtures/index.js';
import { MultiScalarMul } from './multi_scalar_mul.js';

describe('MultiScalarMul Opcode', () => {
let context: AvmContext;

beforeEach(async () => {
context = initContext();
});
it('Should (de)serialize correctly', () => {
const buf = Buffer.from([
MultiScalarMul.opcode, // opcode
7, // indirect
...Buffer.from('12345678', 'hex'), // pointsOffset
...Buffer.from('23456789', 'hex'), // scalars Offset
...Buffer.from('3456789a', 'hex'), // outputOffset
...Buffer.from('456789ab', 'hex'), // pointsLengthOffset
]);
const inst = new MultiScalarMul(
/*indirect=*/ 7,
/*pointsOffset=*/ 0x12345678,
/*scalarsOffset=*/ 0x23456789,
/*outputOffset=*/ 0x3456789a,
/*pointsLengthOffset=*/ 0x456789ab,
);

expect(MultiScalarMul.deserialize(buf)).toEqual(inst);
expect(inst.serialize()).toEqual(buf);
});

it('Should perform msm correctly - direct', async () => {
const indirect = 0;
const grumpkin = new Grumpkin();
// We need to ensure points are actually on curve, so we just use the generator
// In future we could use a random point, for now we create an array of [G, 2G, 3G]
const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1)));

// Pick some big scalars to test the edge cases
const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)];
const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory
const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory
// Transform the points and scalars into the format that we will write to memory
// We just store the x and y coordinates here, and handle the infinities when we write to memory
const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]);
// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...]
const storedPoints: MemoryValue[] = points
.map(p => p.toFieldsWithInf())
.flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]);
const pointsOffset = 0;
context.machineState.memory.setSlice(pointsOffset, storedPoints);
// Store scalars
const scalarsOffset = pointsOffset + pointsReadLength;
context.machineState.memory.setSlice(scalarsOffset, storedScalars);
// Store length of points to read
const pointsLengthOffset = scalarsOffset + scalarsLength;
context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength));
const outputOffset = pointsLengthOffset + 1;

await new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context);

const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr());

// We write it out explicitly here
let expectedResult = grumpkin.mul(points[0], scalars[0]);
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1]));
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2]));

expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]);
});

it('Should perform msm correctly - indirect', async () => {
const indirect = 7;
const grumpkin = new Grumpkin();
// We need to ensure points are actually on curve, so we just use the generator
// In future we could use a random point, for now we create an array of [G, 2G, 3G]
const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1)));

// Pick some big scalars to test the edge cases
const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)];
const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory
const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory
// Transform the points and scalars into the format that we will write to memory
// We just store the x and y coordinates here, and handle the infinities when we write to memory
const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]);
// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...]
const storedPoints: MemoryValue[] = points
.map(p => p.toFieldsWithInf())
.flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]);
const pointsOffset = 0;
context.machineState.memory.setSlice(pointsOffset, storedPoints);
// Store scalars
const scalarsOffset = pointsOffset + pointsReadLength;
context.machineState.memory.setSlice(scalarsOffset, storedScalars);
// Store length of points to read
const pointsLengthOffset = scalarsOffset + scalarsLength;
context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength));
const outputOffset = pointsLengthOffset + 1;

// Set up the indirect pointers
const pointsIndirectOffset = outputOffset + 3; /* 3 since the output is a triplet */
const scalarsIndirectOffset = pointsIndirectOffset + 1;
const outputIndirectOffset = scalarsIndirectOffset + 1;

context.machineState.memory.set(pointsIndirectOffset, new Uint32(pointsOffset));
context.machineState.memory.set(scalarsIndirectOffset, new Uint32(scalarsOffset));
context.machineState.memory.set(outputIndirectOffset, new Uint32(outputOffset));

await new MultiScalarMul(
indirect,
pointsIndirectOffset,
scalarsIndirectOffset,
outputIndirectOffset,
pointsLengthOffset,
).execute(context);

const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr());

// We write it out explicitly here
let expectedResult = grumpkin.mul(points[0], scalars[0]);
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1]));
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2]));

expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]);
});
});
114 changes: 114 additions & 0 deletions yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import { Fq, Point } from '@aztec/circuits.js';
import { Grumpkin } from '@aztec/circuits.js/barretenberg';

import { strict as assert } from 'assert';

import { type AvmContext } from '../avm_context.js';
import { Field, TypeTag } from '../avm_memory_types.js';
import { InstructionExecutionError } from '../errors.js';
import { Opcode, OperandType } from '../serialization/instruction_serialization.js';
import { Addressing } from './addressing_mode.js';
import { Instruction } from './instruction.js';

export class MultiScalarMul extends Instruction {
static type: string = 'MultiScalarMul';
static readonly opcode: Opcode = Opcode.MSM;

// Informs (de)serialization. See Instruction.deserialize.
static readonly wireFormat: OperandType[] = [
OperandType.UINT8 /* opcode */,
OperandType.UINT8 /* indirect */,
OperandType.UINT32 /* points vector offset */,
OperandType.UINT32 /* scalars vector offset */,
OperandType.UINT32 /* output offset (fixed triplet) */,
OperandType.UINT32 /* points length offset */,
];

constructor(
private indirect: number,
private pointsOffset: number,
private scalarsOffset: number,
private outputOffset: number,
private pointsLengthOffset: number,
) {
super();
}

public async execute(context: AvmContext): Promise<void> {
const memory = context.machineState.memory.track(this.type);
// Resolve indirects
const [pointsOffset, scalarsOffset, outputOffset] = Addressing.fromWire(this.indirect).resolve(
[this.pointsOffset, this.scalarsOffset, this.outputOffset],
memory,
);

// Length of the points vector should be U32
memory.checkTag(TypeTag.UINT32, this.pointsLengthOffset);
// Get the size of the unrolled (x, y , inf) points vector
const pointsReadLength = memory.get(this.pointsLengthOffset).toNumber();
assert(pointsReadLength % 3 === 0, 'Points vector offset should be a multiple of 3');
// Divide by 3 since each point is represented as a triplet to get the number of points
const numPoints = pointsReadLength / 3;
// The tag for each triplet will be (Field, Field, Uint8)
for (let i = 0; i < numPoints; i++) {
const offset = pointsOffset + i * 3;
// Check (Field, Field)
memory.checkTagsRange(TypeTag.FIELD, offset, 2);
// Check Uint8 (inf flag)
memory.checkTag(TypeTag.UINT8, offset + 2);
}
// Get the unrolled (x, y, inf) representing the points
const pointsVector = memory.getSlice(pointsOffset, pointsReadLength);

// The size of the scalars vector is twice the NUMBER of points because of the scalar limb decomposition
const scalarReadLength = numPoints * 2;
// Consume gas prior to performing work
const memoryOperations = {
reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */,
writes: 3 /* output triplet */,
indirect: this.indirect,
};
context.machineState.consumeGas(this.gasCost(memoryOperations));
// Get the unrolled scalar (lo & hi) representing the scalars
const scalarsVector = memory.getSlice(scalarsOffset, scalarReadLength);
memory.checkTagsRange(TypeTag.FIELD, scalarsOffset, scalarReadLength);

// Now we need to reconstruct the points and scalars into something we can operate on.
const grumpkinPoints: Point[] = [];
for (let i = 0; i < numPoints; i++) {
const p: Point = new Point(pointsVector[3 * i].toFr(), pointsVector[3 * i + 1].toFr());
// Include this later when we have a standard for representing infinity
// const isInf = pointsVector[i + 2].toBoolean();

if (!p.isOnGrumpkin()) {
throw new InstructionExecutionError(`Point ${p.toString()} is not on the curve.`);
}
grumpkinPoints.push(p);
}
// The scalars are read from memory as Fr elements, which are limbs of Fq elements
// So we need to reconstruct them before performing the scalar multiplications
const scalarFqVector: Fq[] = [];
for (let i = 0; i < numPoints; i++) {
const scalarLo = scalarsVector[2 * i].toFr();
const scalarHi = scalarsVector[2 * i + 1].toFr();
const fqScalar = Fq.fromHighLow(scalarHi, scalarLo);
scalarFqVector.push(fqScalar);
}
// TODO: Is there an efficient MSM implementation in ts that we can replace this by?
const grumpkin = new Grumpkin();
// Zip the points and scalars into pairs
const [firstBaseScalarPair, ...rest]: Array<[Point, Fq]> = grumpkinPoints.map((p, idx) => [p, scalarFqVector[idx]]);
// Fold the points and scalars into a single point
// We have to ensure get the first point, since the identity element (point at infinity) isn't quite working in ts
const outputPoint = rest.reduce(
(acc, curr) => grumpkin.add(acc, grumpkin.mul(curr[0], curr[1])),
grumpkin.mul(firstBaseScalarPair[0], firstBaseScalarPair[1]),
);
const output = outputPoint.toFieldsWithInf().map(f => new Field(f));

memory.setSlice(outputOffset, output);

memory.assert(memoryOperations);
context.machineState.incrementPc();
}
}
Loading

0 comments on commit 0ce27e0

Please sign in to comment.