-
Notifications
You must be signed in to change notification settings - Fork 301
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(avm-simulator): msm blackbox (#7048)
Please read [contributing guidelines](CONTRIBUTING.md) and remove this line.
- Loading branch information
1 parent
bf6b8b8
commit 0ce27e0
Showing
12 changed files
with
314 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
130 changes: 130 additions & 0 deletions
130
yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import { Fq, Fr } from '@aztec/circuits.js'; | ||
import { Grumpkin } from '@aztec/circuits.js/barretenberg'; | ||
|
||
import { type AvmContext } from '../avm_context.js'; | ||
import { Field, type MemoryValue, Uint8, Uint32 } from '../avm_memory_types.js'; | ||
import { initContext } from '../fixtures/index.js'; | ||
import { MultiScalarMul } from './multi_scalar_mul.js'; | ||
|
||
describe('MultiScalarMul Opcode', () => { | ||
let context: AvmContext; | ||
|
||
beforeEach(async () => { | ||
context = initContext(); | ||
}); | ||
it('Should (de)serialize correctly', () => { | ||
const buf = Buffer.from([ | ||
MultiScalarMul.opcode, // opcode | ||
7, // indirect | ||
...Buffer.from('12345678', 'hex'), // pointsOffset | ||
...Buffer.from('23456789', 'hex'), // scalars Offset | ||
...Buffer.from('3456789a', 'hex'), // outputOffset | ||
...Buffer.from('456789ab', 'hex'), // pointsLengthOffset | ||
]); | ||
const inst = new MultiScalarMul( | ||
/*indirect=*/ 7, | ||
/*pointsOffset=*/ 0x12345678, | ||
/*scalarsOffset=*/ 0x23456789, | ||
/*outputOffset=*/ 0x3456789a, | ||
/*pointsLengthOffset=*/ 0x456789ab, | ||
); | ||
|
||
expect(MultiScalarMul.deserialize(buf)).toEqual(inst); | ||
expect(inst.serialize()).toEqual(buf); | ||
}); | ||
|
||
it('Should perform msm correctly - direct', async () => { | ||
const indirect = 0; | ||
const grumpkin = new Grumpkin(); | ||
// We need to ensure points are actually on curve, so we just use the generator | ||
// In future we could use a random point, for now we create an array of [G, 2G, 3G] | ||
const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); | ||
|
||
// Pick some big scalars to test the edge cases | ||
const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; | ||
const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory | ||
const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory | ||
// Transform the points and scalars into the format that we will write to memory | ||
// We just store the x and y coordinates here, and handle the infinities when we write to memory | ||
const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); | ||
// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] | ||
const storedPoints: MemoryValue[] = points | ||
.map(p => p.toFieldsWithInf()) | ||
.flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]); | ||
const pointsOffset = 0; | ||
context.machineState.memory.setSlice(pointsOffset, storedPoints); | ||
// Store scalars | ||
const scalarsOffset = pointsOffset + pointsReadLength; | ||
context.machineState.memory.setSlice(scalarsOffset, storedScalars); | ||
// Store length of points to read | ||
const pointsLengthOffset = scalarsOffset + scalarsLength; | ||
context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); | ||
const outputOffset = pointsLengthOffset + 1; | ||
|
||
await new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context); | ||
|
||
const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr()); | ||
|
||
// We write it out explicitly here | ||
let expectedResult = grumpkin.mul(points[0], scalars[0]); | ||
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); | ||
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); | ||
|
||
expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); | ||
}); | ||
|
||
it('Should perform msm correctly - indirect', async () => { | ||
const indirect = 7; | ||
const grumpkin = new Grumpkin(); | ||
// We need to ensure points are actually on curve, so we just use the generator | ||
// In future we could use a random point, for now we create an array of [G, 2G, 3G] | ||
const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); | ||
|
||
// Pick some big scalars to test the edge cases | ||
const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; | ||
const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory | ||
const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory | ||
// Transform the points and scalars into the format that we will write to memory | ||
// We just store the x and y coordinates here, and handle the infinities when we write to memory | ||
const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); | ||
// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] | ||
const storedPoints: MemoryValue[] = points | ||
.map(p => p.toFieldsWithInf()) | ||
.flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]); | ||
const pointsOffset = 0; | ||
context.machineState.memory.setSlice(pointsOffset, storedPoints); | ||
// Store scalars | ||
const scalarsOffset = pointsOffset + pointsReadLength; | ||
context.machineState.memory.setSlice(scalarsOffset, storedScalars); | ||
// Store length of points to read | ||
const pointsLengthOffset = scalarsOffset + scalarsLength; | ||
context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); | ||
const outputOffset = pointsLengthOffset + 1; | ||
|
||
// Set up the indirect pointers | ||
const pointsIndirectOffset = outputOffset + 3; /* 3 since the output is a triplet */ | ||
const scalarsIndirectOffset = pointsIndirectOffset + 1; | ||
const outputIndirectOffset = scalarsIndirectOffset + 1; | ||
|
||
context.machineState.memory.set(pointsIndirectOffset, new Uint32(pointsOffset)); | ||
context.machineState.memory.set(scalarsIndirectOffset, new Uint32(scalarsOffset)); | ||
context.machineState.memory.set(outputIndirectOffset, new Uint32(outputOffset)); | ||
|
||
await new MultiScalarMul( | ||
indirect, | ||
pointsIndirectOffset, | ||
scalarsIndirectOffset, | ||
outputIndirectOffset, | ||
pointsLengthOffset, | ||
).execute(context); | ||
|
||
const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr()); | ||
|
||
// We write it out explicitly here | ||
let expectedResult = grumpkin.mul(points[0], scalars[0]); | ||
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); | ||
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); | ||
|
||
expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); | ||
}); | ||
}); |
114 changes: 114 additions & 0 deletions
114
yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import { Fq, Point } from '@aztec/circuits.js'; | ||
import { Grumpkin } from '@aztec/circuits.js/barretenberg'; | ||
|
||
import { strict as assert } from 'assert'; | ||
|
||
import { type AvmContext } from '../avm_context.js'; | ||
import { Field, TypeTag } from '../avm_memory_types.js'; | ||
import { InstructionExecutionError } from '../errors.js'; | ||
import { Opcode, OperandType } from '../serialization/instruction_serialization.js'; | ||
import { Addressing } from './addressing_mode.js'; | ||
import { Instruction } from './instruction.js'; | ||
|
||
export class MultiScalarMul extends Instruction { | ||
static type: string = 'MultiScalarMul'; | ||
static readonly opcode: Opcode = Opcode.MSM; | ||
|
||
// Informs (de)serialization. See Instruction.deserialize. | ||
static readonly wireFormat: OperandType[] = [ | ||
OperandType.UINT8 /* opcode */, | ||
OperandType.UINT8 /* indirect */, | ||
OperandType.UINT32 /* points vector offset */, | ||
OperandType.UINT32 /* scalars vector offset */, | ||
OperandType.UINT32 /* output offset (fixed triplet) */, | ||
OperandType.UINT32 /* points length offset */, | ||
]; | ||
|
||
constructor( | ||
private indirect: number, | ||
private pointsOffset: number, | ||
private scalarsOffset: number, | ||
private outputOffset: number, | ||
private pointsLengthOffset: number, | ||
) { | ||
super(); | ||
} | ||
|
||
public async execute(context: AvmContext): Promise<void> { | ||
const memory = context.machineState.memory.track(this.type); | ||
// Resolve indirects | ||
const [pointsOffset, scalarsOffset, outputOffset] = Addressing.fromWire(this.indirect).resolve( | ||
[this.pointsOffset, this.scalarsOffset, this.outputOffset], | ||
memory, | ||
); | ||
|
||
// Length of the points vector should be U32 | ||
memory.checkTag(TypeTag.UINT32, this.pointsLengthOffset); | ||
// Get the size of the unrolled (x, y , inf) points vector | ||
const pointsReadLength = memory.get(this.pointsLengthOffset).toNumber(); | ||
assert(pointsReadLength % 3 === 0, 'Points vector offset should be a multiple of 3'); | ||
// Divide by 3 since each point is represented as a triplet to get the number of points | ||
const numPoints = pointsReadLength / 3; | ||
// The tag for each triplet will be (Field, Field, Uint8) | ||
for (let i = 0; i < numPoints; i++) { | ||
const offset = pointsOffset + i * 3; | ||
// Check (Field, Field) | ||
memory.checkTagsRange(TypeTag.FIELD, offset, 2); | ||
// Check Uint8 (inf flag) | ||
memory.checkTag(TypeTag.UINT8, offset + 2); | ||
} | ||
// Get the unrolled (x, y, inf) representing the points | ||
const pointsVector = memory.getSlice(pointsOffset, pointsReadLength); | ||
|
||
// The size of the scalars vector is twice the NUMBER of points because of the scalar limb decomposition | ||
const scalarReadLength = numPoints * 2; | ||
// Consume gas prior to performing work | ||
const memoryOperations = { | ||
reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */, | ||
writes: 3 /* output triplet */, | ||
indirect: this.indirect, | ||
}; | ||
context.machineState.consumeGas(this.gasCost(memoryOperations)); | ||
// Get the unrolled scalar (lo & hi) representing the scalars | ||
const scalarsVector = memory.getSlice(scalarsOffset, scalarReadLength); | ||
memory.checkTagsRange(TypeTag.FIELD, scalarsOffset, scalarReadLength); | ||
|
||
// Now we need to reconstruct the points and scalars into something we can operate on. | ||
const grumpkinPoints: Point[] = []; | ||
for (let i = 0; i < numPoints; i++) { | ||
const p: Point = new Point(pointsVector[3 * i].toFr(), pointsVector[3 * i + 1].toFr()); | ||
// Include this later when we have a standard for representing infinity | ||
// const isInf = pointsVector[i + 2].toBoolean(); | ||
|
||
if (!p.isOnGrumpkin()) { | ||
throw new InstructionExecutionError(`Point ${p.toString()} is not on the curve.`); | ||
} | ||
grumpkinPoints.push(p); | ||
} | ||
// The scalars are read from memory as Fr elements, which are limbs of Fq elements | ||
// So we need to reconstruct them before performing the scalar multiplications | ||
const scalarFqVector: Fq[] = []; | ||
for (let i = 0; i < numPoints; i++) { | ||
const scalarLo = scalarsVector[2 * i].toFr(); | ||
const scalarHi = scalarsVector[2 * i + 1].toFr(); | ||
const fqScalar = Fq.fromHighLow(scalarHi, scalarLo); | ||
scalarFqVector.push(fqScalar); | ||
} | ||
// TODO: Is there an efficient MSM implementation in ts that we can replace this by? | ||
const grumpkin = new Grumpkin(); | ||
// Zip the points and scalars into pairs | ||
const [firstBaseScalarPair, ...rest]: Array<[Point, Fq]> = grumpkinPoints.map((p, idx) => [p, scalarFqVector[idx]]); | ||
// Fold the points and scalars into a single point | ||
// We have to ensure get the first point, since the identity element (point at infinity) isn't quite working in ts | ||
const outputPoint = rest.reduce( | ||
(acc, curr) => grumpkin.add(acc, grumpkin.mul(curr[0], curr[1])), | ||
grumpkin.mul(firstBaseScalarPair[0], firstBaseScalarPair[1]), | ||
); | ||
const output = outputPoint.toFieldsWithInf().map(f => new Field(f)); | ||
|
||
memory.setSlice(outputOffset, output); | ||
|
||
memory.assert(memoryOperations); | ||
context.machineState.incrementPc(); | ||
} | ||
} |
Oops, something went wrong.