Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: do not accept invalid maci keys #1408

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions contracts/contracts/MACI.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { Params } from "./utilities/Params.sol";
import { TopupCredit } from "./TopupCredit.sol";
import { Utilities } from "./utilities/Utilities.sol";
import { DomainObjs } from "./utilities/DomainObjs.sol";
import { CurveBabyJubJub } from "./crypto/BabyJubJub.sol";
import { Ownable } from "@openzeppelin/contracts/access/Ownable.sol";

/// @title MACI - Minimum Anti-Collusion Infrastructure Version 1
Expand Down Expand Up @@ -101,7 +102,7 @@ contract MACI is IMACI, DomainObjs, Params, Utilities, Ownable(msg.sender) {
error CallerMustBePoll(address _caller);
error PoseidonHashLibrariesNotLinked();
error TooManySignups();
error MaciPubKeyLargerThanSnarkFieldSize();
error InvalidPubKey();
error PreviousPollNotCompleted(uint256 pollId);
error PollDoesNotExist(uint256 pollId);
error SignupTemporaryBlocked();
Expand Down Expand Up @@ -168,8 +169,9 @@ contract MACI is IMACI, DomainObjs, Params, Utilities, Ownable(msg.sender) {
// ensure we do not have more signups than what the circuits support
if (numSignUps >= uint256(TREE_ARITY) ** uint256(stateTreeDepth)) revert TooManySignups();

if (_pubKey.x >= SNARK_SCALAR_FIELD || _pubKey.y >= SNARK_SCALAR_FIELD) {
revert MaciPubKeyLargerThanSnarkFieldSize();
// ensure that the public key is on the baby jubjub curve
if (!CurveBabyJubJub.isOnCurve(_pubKey.x, _pubKey.y)) {
kittybest marked this conversation as resolved.
Show resolved Hide resolved
revert InvalidPubKey();
}

// Increment the number of signups
Expand Down Expand Up @@ -219,6 +221,11 @@ contract MACI is IMACI, DomainObjs, Params, Utilities, Ownable(msg.sender) {
nextPollId++;
}

// check coordinator key is a valid point on the curve
if (!CurveBabyJubJub.isOnCurve(_coordinatorPubKey.x, _coordinatorPubKey.y)) {
revert InvalidPubKey();
}

if (pollId > 0) {
if (!stateAq.treeMerged()) revert PreviousPollNotCompleted(pollId);
}
Expand Down
13 changes: 7 additions & 6 deletions contracts/contracts/Poll.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { SafeERC20 } from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.s
import { EmptyBallotRoots } from "./trees/EmptyBallotRoots.sol";
import { IPoll } from "./interfaces/IPoll.sol";
import { Utilities } from "./utilities/Utilities.sol";
import { CurveBabyJubJub } from "./crypto/BabyJubJub.sol";

/// @title Poll
/// @notice A Poll contract allows voters to submit encrypted messages
Expand Down Expand Up @@ -68,7 +69,7 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable(msg.sender), EmptyBallo
error VotingPeriodNotOver();
error PollAlreadyInit();
error TooManyMessages();
error MaciPubKeyLargerThanSnarkFieldSize();
error InvalidPubKey();
error StateAqAlreadyMerged();
error StateAqSubtreesNeedMerge();
error InvalidBatchLength();
Expand All @@ -95,8 +96,8 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable(msg.sender), EmptyBallo
ExtContracts memory _extContracts
) payable {
// check that the coordinator public key is valid
if (_coordinatorPubKey.x >= SNARK_SCALAR_FIELD || _coordinatorPubKey.y >= SNARK_SCALAR_FIELD) {
revert MaciPubKeyLargerThanSnarkFieldSize();
if (!CurveBabyJubJub.isOnCurve(_coordinatorPubKey.x, _coordinatorPubKey.y)) {
revert InvalidPubKey();
}

// store the pub key as object then calculate the hash
Expand Down Expand Up @@ -183,9 +184,9 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable(msg.sender), EmptyBallo
// we check that we do not exceed the max number of messages
if (numMessages >= maxValues.maxMessages) revert TooManyMessages();

// validate that the public key is valid
if (_encPubKey.x >= SNARK_SCALAR_FIELD || _encPubKey.y >= SNARK_SCALAR_FIELD) {
revert MaciPubKeyLargerThanSnarkFieldSize();
// check if the public key is on the curve
if (!CurveBabyJubJub.isOnCurve(_encPubKey.x, _encPubKey.y)) {
revert InvalidPubKey();
}

// cannot realistically overflow
Expand Down
138 changes: 138 additions & 0 deletions contracts/contracts/crypto/BabyJubJub.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// @note This code was taken from
// https://github.com/yondonfu/sol-baby-jubjub/blob/master/contracts/CurveBabyJubJub.sol
// Thanks to yondonfu for the code
// Implementation cited on baby-jubjub's paper
// https://eips.ethereum.org/EIPS/eip-2494#implementation

// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;

library CurveBabyJubJub {
// Curve parameters
// E: 168700x^2 + y^2 = 1 + 168696x^2y^2
// A = 168700
uint256 public constant A = 0x292FC;
// D = 168696
uint256 public constant D = 0x292F8;
// Prime Q = 21888242871839275222246405745257275088548364400416034343698204186575808495617
uint256 public constant Q = 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001;

/**
* @dev Add 2 points on baby jubjub curve
* Formula for adding 2 points on a twisted Edwards curve:
* x3 = (x1y2 + y1x2) / (1 + dx1x2y1y2)
* y3 = (y1y2 - ax1x2) / (1 - dx1x2y1y2)
*/
function pointAdd(uint256 _x1, uint256 _y1, uint256 _x2, uint256 _y2) internal view returns (uint256 x3, uint256 y3) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

if (_x1 == 0 && _y1 == 0) {
return (_x2, _y2);
}

if (_x2 == 0 && _y1 == 0) {
return (_x1, _y1);
}

uint256 x1x2 = mulmod(_x1, _x2, Q);
uint256 y1y2 = mulmod(_y1, _y2, Q);
uint256 dx1x2y1y2 = mulmod(D, mulmod(x1x2, y1y2, Q), Q);
uint256 x3Num = addmod(mulmod(_x1, _y2, Q), mulmod(_y1, _x2, Q), Q);
uint256 y3Num = submod(y1y2, mulmod(A, x1x2, Q), Q);

x3 = mulmod(x3Num, inverse(addmod(1, dx1x2y1y2, Q)), Q);
y3 = mulmod(y3Num, inverse(submod(1, dx1x2y1y2, Q)), Q);
}

Check warning

Code scanning / Slither

Dead-code Warning

CurveBabyJubJub.pointAdd(uint256,uint256,uint256,uint256) is never used and should be removed

/**
* @dev Double a point on baby jubjub curve
* Doubling can be performed with the same formula as addition
*/
function pointDouble(uint256 _x1, uint256 _y1) internal view returns (uint256 x2, uint256 y2) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Parameter CurveBabyJubJub.pointDouble(uint256,uint256)._y1 is not in mixedCase

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Parameter CurveBabyJubJub.pointDouble(uint256,uint256)._x1 is not in mixedCase
return pointAdd(_x1, _y1, _x1, _y1);
}

Check warning

Code scanning / Slither

Dead-code Warning

CurveBabyJubJub.pointDouble(uint256,uint256) is never used and should be removed

/**
* @dev Multiply a point on baby jubjub curve by a scalar
* Use the double and add algorithm
*/
function pointMul(uint256 _x1, uint256 _y1, uint256 _d) internal view returns (uint256 x2, uint256 y2) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 remaining = _d;

uint256 px = _x1;
uint256 py = _y1;
uint256 ax = 0;
uint256 ay = 0;

while (remaining != 0) {
if ((remaining & 1) != 0) {
// Binary digit is 1 so add
(ax, ay) = pointAdd(ax, ay, px, py);
}

(px, py) = pointDouble(px, py);

remaining = remaining / 2;
}

x2 = ax;
y2 = ay;
}

Check warning

Code scanning / Slither

Dead-code Warning

CurveBabyJubJub.pointMul(uint256,uint256,uint256) is never used and should be removed

/**
* @dev Check if a given point is on the curve
* (168700x^2 + y^2) - (1 + 168696x^2y^2) == 0
*/
function isOnCurve(uint256 _x, uint256 _y) internal pure returns (bool) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Parameter CurveBabyJubJub.isOnCurve(uint256,uint256)._x is not in mixedCase

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Parameter CurveBabyJubJub.isOnCurve(uint256,uint256)._y is not in mixedCase
uint256 xSq = mulmod(_x, _x, Q);
uint256 ySq = mulmod(_y, _y, Q);
uint256 lhs = addmod(mulmod(A, xSq, Q), ySq, Q);
uint256 rhs = addmod(1, mulmod(mulmod(D, xSq, Q), ySq, Q), Q);
return submod(lhs, rhs, Q) == 0;
}

/**
* @dev Perform modular subtraction
*/
function submod(uint256 _a, uint256 _b, uint256 _mod) internal pure returns (uint256) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 aNN = _a;

if (_a <= _b) {
aNN += _mod;
}

return addmod(aNN - _b, 0, _mod);
}

/**
* @dev Compute modular inverse of a number
*/
function inverse(uint256 _a) internal view returns (uint256) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Parameter CurveBabyJubJub.inverse(uint256)._a is not in mixedCase
// We can use Euler's theorem instead of the extended Euclidean algorithm
// Since m = Q and Q is prime we have: a^-1 = a^(m - 2) (mod m)
return expmod(_a, Q - 2, Q);
}

Check warning

Code scanning / Slither

Dead-code Warning

CurveBabyJubJub.inverse(uint256) is never used and should be removed

/**
* @dev Helper function to call the bigModExp precompile
*/
function expmod(uint256 _b, uint256 _e, uint256 _m) internal view returns (uint256 o) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

assembly {

Check warning on line 118 in contracts/contracts/crypto/BabyJubJub.sol

View workflow job for this annotation

GitHub Actions / check (lint:sol)

Avoid to use inline assembly. It is acceptable only in rare cases
let memPtr := mload(0x40)
mstore(memPtr, 0x20) // Length of base _b
mstore(add(memPtr, 0x20), 0x20) // Length of exponent _e
mstore(add(memPtr, 0x40), 0x20) // Length of modulus _m
mstore(add(memPtr, 0x60), _b) // Base _b
mstore(add(memPtr, 0x80), _e) // Exponent _e
mstore(add(memPtr, 0xa0), _m) // Modulus _m

// The bigModExp precompile is at 0x05
let success := staticcall(gas(), 0x05, memPtr, 0xc0, memPtr, 0x20)
switch success
case 0 {
revert(0x0, 0x0)
}
default {
o := mload(memPtr)
}
}
}

Check warning

Code scanning / Slither

Assembly usage Warning

Check warning

Code scanning / Slither

Dead-code Warning

CurveBabyJubJub.expmod(uint256,uint256,uint256) is never used and should be removed
}

Check warning

Code scanning / Slither

Too many digits Warning

46 changes: 44 additions & 2 deletions contracts/tests/MACI.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ describe("MACI", () => {
}
});

it("should fail when given an invalid pubkey", async () => {
it("should fail when given an invalid pubkey (x >= p)", async () => {
await expect(
maciContract.signUp(
{
Expand All @@ -132,7 +132,49 @@ describe("MACI", () => {
AbiCoder.defaultAbiCoder().encode(["uint256"], [0]),
signUpTxOpts,
),
).to.be.revertedWithCustomError(maciContract, "MaciPubKeyLargerThanSnarkFieldSize");
).to.be.revertedWithCustomError(maciContract, "InvalidPubKey");
});

it("should fail when given an invalid pubkey (y >= p)", async () => {
await expect(
maciContract.signUp(
{
x: "1",
y: "21888242871839275222246405745257275088548364400416034343698204186575808495617",
},
AbiCoder.defaultAbiCoder().encode(["uint256"], [1]),
AbiCoder.defaultAbiCoder().encode(["uint256"], [0]),
signUpTxOpts,
),
).to.be.revertedWithCustomError(maciContract, "InvalidPubKey");
});

it("should fail when given an invalid pubkey (x >= p and y >= p)", async () => {
await expect(
maciContract.signUp(
{
x: "21888242871839275222246405745257275088548364400416034343698204186575808495617",
y: "21888242871839275222246405745257275088548364400416034343698204186575808495617",
},
AbiCoder.defaultAbiCoder().encode(["uint256"], [1]),
AbiCoder.defaultAbiCoder().encode(["uint256"], [0]),
signUpTxOpts,
),
).to.be.revertedWithCustomError(maciContract, "InvalidPubKey");
});

it("should fail when given an invalid public key (not on the curve)", async () => {
await expect(
maciContract.signUp(
{
x: "1",
y: "1",
},
AbiCoder.defaultAbiCoder().encode(["uint256"], [1]),
AbiCoder.defaultAbiCoder().encode(["uint256"], [0]),
signUpTxOpts,
),
).to.be.revertedWithCustomError(maciContract, "InvalidPubKey");
});

it("should not allow to sign up more than the supported amount of users (5 ** stateTreeDepth)", async () => {
Expand Down
45 changes: 40 additions & 5 deletions contracts/tests/Poll.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,29 @@ describe("Poll", () => {
const numMessages = await pollContract.numMessages();
expect(numMessages.toString()).to.eq("1");
});

it("should fail when passing an invalid coordinator public key", async () => {
const r = await deployTestContracts(initialVoiceCreditBalance, STATE_TREE_DEPTH, signer, true);
const testMaciContract = r.maciContract;

// deploy on chain poll
await expect(
testMaciContract.deployPoll(
duration,
treeDepths,
{
x: "100",
y: "1",
},
r.mockVerifierContract as Verifier,
r.vkRegistryContract,
EMode.QV,
{
gasLimit: 10000000,
},
),
).to.be.revertedWithCustomError(testMaciContract, "InvalidPubKey");
});
});

describe("topup", () => {
Expand Down Expand Up @@ -202,6 +225,20 @@ describe("Poll", () => {
maciState.polls.get(pollId)?.publishMessage(message, keypair.pubKey);
});

it("should throw when the encPubKey is not a point on the baby jubjub curve", async () => {
const command = new PCommand(1n, keypair.pubKey, 0n, 9n, 1n, pollId, 0n);

const signature = command.sign(keypair.privKey);
const sharedKey = Keypair.genEcdhSharedKey(keypair.privKey, coordinator.pubKey);
const message = command.encrypt(signature, sharedKey);
await expect(
pollContract.publishMessage(message.asContractParam(), {
x: "1",
y: "1",
}),
).to.be.revertedWithCustomError(pollContract, "InvalidPubKey");
});

it("should emit an event when publishing a message", async () => {
const command = new PCommand(1n, keypair.pubKey, 0n, 9n, 1n, pollId, 0n);

Expand Down Expand Up @@ -252,7 +289,7 @@ describe("Poll", () => {

it("should not allow to publish a message after the voting period ends", async () => {
const dd = await pollContract.getDeployTimeAndDuration();
await timeTravel(signer.provider as unknown as EthereumProvider, Number(dd[0]) + 1);
await timeTravel(signer.provider as unknown as EthereumProvider, Number(dd[0]) + 10);

const command = new PCommand(1n, keypair.pubKey, 0n, 9n, 1n, pollId, 0n);

Expand All @@ -261,7 +298,7 @@ describe("Poll", () => {
const message = command.encrypt(signature, sharedKey);

await expect(
pollContract.publishMessage(message.asContractParam(), keypair.pubKey.asContractParam(), { gasLimit: 300000 }),
pollContract.publishMessage(message.asContractParam(), keypair.pubKey.asContractParam()),
).to.be.revertedWithCustomError(pollContract, "VotingPeriodOver");
});

Expand All @@ -271,9 +308,7 @@ describe("Poll", () => {
const sharedKey = Keypair.genEcdhSharedKey(keypair.privKey, coordinator.pubKey);
const message = command.encrypt(signature, sharedKey);
await expect(
pollContract.publishMessageBatch([message.asContractParam()], [keypair.pubKey.asContractParam()], {
gasLimit: 300000,
}),
pollContract.publishMessageBatch([message.asContractParam()], [keypair.pubKey.asContractParam()]),
).to.be.revertedWithCustomError(pollContract, "VotingPeriodOver");
});
});
Expand Down
2 changes: 2 additions & 0 deletions crypto/ts/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ export { G1Point, G2Point, genRandomBabyJubValue } from "./babyjub";

export { sha256Hash, hashLeftRight, hashN, hash2, hash3, hash4, hash5, hash13, hashOne } from "./hashing";

export { inCurve } from "@zk-kit/baby-jubjub";

export { poseidonDecrypt, poseidonDecryptWithoutCheck, poseidonEncrypt } from "@zk-kit/poseidon-cipher";

export { verifySignature, signMessage as sign } from "@zk-kit/eddsa-poseidon";
Expand Down
9 changes: 8 additions & 1 deletion domainobjs/ts/__tests__/keypair.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ import { genKeypair, genPrivKey } from "maci-crypto";

import { Keypair, PrivKey } from "..";

describe("keypair", () => {
describe("keypair", function test() {
this.timeout(900000);
describe("constructor", () => {
it("should generate a random keypair if not provided a private key", () => {
const k1 = new Keypair();
Expand All @@ -14,6 +15,12 @@ describe("keypair", () => {
expect(k1.privKey.rawPrivKey).not.to.eq(k2.privKey.rawPrivKey);
});

it("should always generate a valid keypair", () => {
for (let i = 0; i < 100; i += 1) {
expect(() => new Keypair()).to.not.throw();
}
});

it("should generate the correct public key given a private key", () => {
const rawKeyPair = genKeypair();
const k = new Keypair(new PrivKey(rawKeyPair.privKey));
Expand Down
Loading
Loading