Skip to content

Commit

Permalink
fix: do not accept invalid maci keys
Browse files Browse the repository at this point in the history
  • Loading branch information
ctrlc03 committed Apr 30, 2024
1 parent eb75e40 commit feb1eec
Show file tree
Hide file tree
Showing 11 changed files with 257 additions and 41 deletions.
13 changes: 10 additions & 3 deletions contracts/contracts/MACI.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { Params } from "./utilities/Params.sol";
import { TopupCredit } from "./TopupCredit.sol";
import { Utilities } from "./utilities/Utilities.sol";
import { DomainObjs } from "./utilities/DomainObjs.sol";
import { CurveBabyJubJub } from "./crypto/BabyJubJub.sol";
import { Ownable } from "@openzeppelin/contracts/access/Ownable.sol";

/// @title MACI - Minimum Anti-Collusion Infrastructure Version 1
Expand Down Expand Up @@ -101,7 +102,7 @@ contract MACI is IMACI, DomainObjs, Params, Utilities, Ownable(msg.sender) {
error CallerMustBePoll(address _caller);
error PoseidonHashLibrariesNotLinked();
error TooManySignups();
error MaciPubKeyLargerThanSnarkFieldSize();
error InvalidPubKey();
error PreviousPollNotCompleted(uint256 pollId);
error PollDoesNotExist(uint256 pollId);
error SignupTemporaryBlocked();
Expand Down Expand Up @@ -168,8 +169,9 @@ contract MACI is IMACI, DomainObjs, Params, Utilities, Ownable(msg.sender) {
// ensure we do not have more signups than what the circuits support
if (numSignUps >= uint256(TREE_ARITY) ** uint256(stateTreeDepth)) revert TooManySignups();

if (_pubKey.x >= SNARK_SCALAR_FIELD || _pubKey.y >= SNARK_SCALAR_FIELD) {
revert MaciPubKeyLargerThanSnarkFieldSize();
// ensure that the public key is on the baby jubjub curve
if (!CurveBabyJubJub.isOnCurve(_pubKey.x, _pubKey.y)) {
revert InvalidPubKey();
}

// Increment the number of signups
Expand Down Expand Up @@ -219,6 +221,11 @@ contract MACI is IMACI, DomainObjs, Params, Utilities, Ownable(msg.sender) {
nextPollId++;
}

// check coordinator key is a valid point on the curve
if (!CurveBabyJubJub.isOnCurve(_coordinatorPubKey.x, _coordinatorPubKey.y)) {
revert InvalidPubKey();
}

if (pollId > 0) {
if (!stateAq.treeMerged()) revert PreviousPollNotCompleted(pollId);
}
Expand Down
13 changes: 7 additions & 6 deletions contracts/contracts/Poll.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { SafeERC20 } from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.s
import { EmptyBallotRoots } from "./trees/EmptyBallotRoots.sol";
import { IPoll } from "./interfaces/IPoll.sol";
import { Utilities } from "./utilities/Utilities.sol";
import { CurveBabyJubJub } from "./crypto/BabyJubJub.sol";

/// @title Poll
/// @notice A Poll contract allows voters to submit encrypted messages
Expand Down Expand Up @@ -68,7 +69,7 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable(msg.sender), EmptyBallo
error VotingPeriodNotOver();
error PollAlreadyInit();
error TooManyMessages();
error MaciPubKeyLargerThanSnarkFieldSize();
error InvalidPubKey();
error StateAqAlreadyMerged();
error StateAqSubtreesNeedMerge();
error InvalidBatchLength();
Expand All @@ -95,8 +96,8 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable(msg.sender), EmptyBallo
ExtContracts memory _extContracts
) payable {
// check that the coordinator public key is valid
if (_coordinatorPubKey.x >= SNARK_SCALAR_FIELD || _coordinatorPubKey.y >= SNARK_SCALAR_FIELD) {
revert MaciPubKeyLargerThanSnarkFieldSize();
if (!CurveBabyJubJub.isOnCurve(_coordinatorPubKey.x, _coordinatorPubKey.y)) {
revert InvalidPubKey();
}

// store the pub key as object then calculate the hash
Expand Down Expand Up @@ -183,9 +184,9 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable(msg.sender), EmptyBallo
// we check that we do not exceed the max number of messages
if (numMessages >= maxValues.maxMessages) revert TooManyMessages();

// validate that the public key is valid
if (_encPubKey.x >= SNARK_SCALAR_FIELD || _encPubKey.y >= SNARK_SCALAR_FIELD) {
revert MaciPubKeyLargerThanSnarkFieldSize();
// check if the public key is on the curve
if (!CurveBabyJubJub.isOnCurve(_encPubKey.x, _encPubKey.y)) {
revert InvalidPubKey();
}

// cannot realistically overflow
Expand Down
138 changes: 138 additions & 0 deletions contracts/contracts/crypto/BabyJubJub.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// @note This code was taken from
// https://github.com/yondonfu/sol-baby-jubjub/blob/master/contracts/CurveBabyJubJub.sol
// Thanks to yondonfu for the code
// Implementation cited on baby-jubjub's paper
// https://eips.ethereum.org/EIPS/eip-2494#implementation

// SPDX-License-Identifier: MIT
pragma solidity ^0.8.20;

library CurveBabyJubJub {
// Curve parameters
// E: 168700x^2 + y^2 = 1 + 168696x^2y^2
// A = 168700
uint256 public constant A = 0x292FC;
// D = 168696
uint256 public constant D = 0x292F8;
// Prime Q = 21888242871839275222246405745257275088548364400416034343698204186575808495617
uint256 public constant Q = 0x30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001;

/**
* @dev Add 2 points on baby jubjub curve
* Formula for adding 2 points on a twisted Edwards curve:
* x3 = (x1y2 + y1x2) / (1 + dx1x2y1y2)
* y3 = (y1y2 - ax1x2) / (1 - dx1x2y1y2)
*/
function pointAdd(uint256 _x1, uint256 _y1, uint256 _x2, uint256 _y2) internal view returns (uint256 x3, uint256 y3) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

if (_x1 == 0 && _y1 == 0) {
return (_x2, _y2);
}

if (_x2 == 0 && _y1 == 0) {
return (_x1, _y1);
}

uint256 x1x2 = mulmod(_x1, _x2, Q);
uint256 y1y2 = mulmod(_y1, _y2, Q);
uint256 dx1x2y1y2 = mulmod(D, mulmod(x1x2, y1y2, Q), Q);
uint256 x3Num = addmod(mulmod(_x1, _y2, Q), mulmod(_y1, _x2, Q), Q);
uint256 y3Num = submod(y1y2, mulmod(A, x1x2, Q), Q);

x3 = mulmod(x3Num, inverse(addmod(1, dx1x2y1y2, Q)), Q);
y3 = mulmod(y3Num, inverse(submod(1, dx1x2y1y2, Q)), Q);
}

Check warning

Code scanning / Slither

Dead-code Warning

CurveBabyJubJub.pointAdd(uint256,uint256,uint256,uint256) is never used and should be removed

/**
* @dev Double a point on baby jubjub curve
* Doubling can be performed with the same formula as addition
*/
function pointDouble(uint256 _x1, uint256 _y1) internal view returns (uint256 x2, uint256 y2) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Parameter CurveBabyJubJub.pointDouble(uint256,uint256)._y1 is not in mixedCase

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Parameter CurveBabyJubJub.pointDouble(uint256,uint256)._x1 is not in mixedCase
return pointAdd(_x1, _y1, _x1, _y1);
}

Check warning

Code scanning / Slither

Dead-code Warning

CurveBabyJubJub.pointDouble(uint256,uint256) is never used and should be removed

/**
* @dev Multiply a point on baby jubjub curve by a scalar
* Use the double and add algorithm
*/
function pointMul(uint256 _x1, uint256 _y1, uint256 _d) internal view returns (uint256 x2, uint256 y2) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 remaining = _d;

uint256 px = _x1;
uint256 py = _y1;
uint256 ax = 0;
uint256 ay = 0;

while (remaining != 0) {
if ((remaining & 1) != 0) {
// Binary digit is 1 so add
(ax, ay) = pointAdd(ax, ay, px, py);
}

(px, py) = pointDouble(px, py);

remaining = remaining / 2;
}

x2 = ax;
y2 = ay;
}

Check warning

Code scanning / Slither

Dead-code Warning

CurveBabyJubJub.pointMul(uint256,uint256,uint256) is never used and should be removed

/**
* @dev Check if a given point is on the curve
* (168700x^2 + y^2) - (1 + 168696x^2y^2) == 0
*/
function isOnCurve(uint256 _x, uint256 _y) internal pure returns (bool) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Parameter CurveBabyJubJub.isOnCurve(uint256,uint256)._x is not in mixedCase

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Parameter CurveBabyJubJub.isOnCurve(uint256,uint256)._y is not in mixedCase
uint256 xSq = mulmod(_x, _x, Q);
uint256 ySq = mulmod(_y, _y, Q);
uint256 lhs = addmod(mulmod(A, xSq, Q), ySq, Q);
uint256 rhs = addmod(1, mulmod(mulmod(D, xSq, Q), ySq, Q), Q);
return submod(lhs, rhs, Q) == 0;
}

/**
* @dev Perform modular subtraction
*/
function submod(uint256 _a, uint256 _b, uint256 _mod) internal pure returns (uint256) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 aNN = _a;

if (_a <= _b) {
aNN += _mod;
}

return addmod(aNN - _b, 0, _mod);
}

/**
* @dev Compute modular inverse of a number
*/
function inverse(uint256 _a) internal view returns (uint256) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Parameter CurveBabyJubJub.inverse(uint256)._a is not in mixedCase
// We can use Euler's theorem instead of the extended Euclidean algorithm
// Since m = Q and Q is prime we have: a^-1 = a^(m - 2) (mod m)
return expmod(_a, Q - 2, Q);
}

Check warning

Code scanning / Slither

Dead-code Warning

CurveBabyJubJub.inverse(uint256) is never used and should be removed

/**
* @dev Helper function to call the bigModExp precompile
*/
function expmod(uint256 _b, uint256 _e, uint256 _m) internal view returns (uint256 o) {

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

assembly {

Check warning on line 118 in contracts/contracts/crypto/BabyJubJub.sol

View workflow job for this annotation

GitHub Actions / check (lint:sol)

Avoid to use inline assembly. It is acceptable only in rare cases
let memPtr := mload(0x40)
mstore(memPtr, 0x20) // Length of base _b
mstore(add(memPtr, 0x20), 0x20) // Length of exponent _e
mstore(add(memPtr, 0x40), 0x20) // Length of modulus _m
mstore(add(memPtr, 0x60), _b) // Base _b
mstore(add(memPtr, 0x80), _e) // Exponent _e
mstore(add(memPtr, 0xa0), _m) // Modulus _m

// The bigModExp precompile is at 0x05
let success := staticcall(gas(), 0x05, memPtr, 0xc0, memPtr, 0x20)
switch success
case 0 {
revert(0x0, 0x0)
}
default {
o := mload(memPtr)
}
}
}

Check warning

Code scanning / Slither

Assembly usage Warning

Check warning

Code scanning / Slither

Dead-code Warning

CurveBabyJubJub.expmod(uint256,uint256,uint256) is never used and should be removed
}

Check warning

Code scanning / Slither

Too many digits Warning

46 changes: 44 additions & 2 deletions contracts/tests/MACI.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ describe("MACI", () => {
}
});

it("should fail when given an invalid pubkey", async () => {
it("should fail when given an invalid pubkey (x >= p)", async () => {
await expect(
maciContract.signUp(
{
Expand All @@ -132,7 +132,49 @@ describe("MACI", () => {
AbiCoder.defaultAbiCoder().encode(["uint256"], [0]),
signUpTxOpts,
),
).to.be.revertedWithCustomError(maciContract, "MaciPubKeyLargerThanSnarkFieldSize");
).to.be.revertedWithCustomError(maciContract, "InvalidPubKey");
});

it("should fail when given an invalid pubkey (y >= p)", async () => {
await expect(
maciContract.signUp(
{
x: "1",
y: "21888242871839275222246405745257275088548364400416034343698204186575808495617",
},
AbiCoder.defaultAbiCoder().encode(["uint256"], [1]),
AbiCoder.defaultAbiCoder().encode(["uint256"], [0]),
signUpTxOpts,
),
).to.be.revertedWithCustomError(maciContract, "InvalidPubKey");
});

it("should fail when given an invalid pubkey (x >= p and y >= p)", async () => {
await expect(
maciContract.signUp(
{
x: "21888242871839275222246405745257275088548364400416034343698204186575808495617",
y: "21888242871839275222246405745257275088548364400416034343698204186575808495617",
},
AbiCoder.defaultAbiCoder().encode(["uint256"], [1]),
AbiCoder.defaultAbiCoder().encode(["uint256"], [0]),
signUpTxOpts,
),
).to.be.revertedWithCustomError(maciContract, "InvalidPubKey");
});

it("should fail when given an invalid public key (not on the curve)", async () => {
await expect(
maciContract.signUp(
{
x: "1",
y: "1",
},
AbiCoder.defaultAbiCoder().encode(["uint256"], [1]),
AbiCoder.defaultAbiCoder().encode(["uint256"], [0]),
signUpTxOpts,
),
).to.be.revertedWithCustomError(maciContract, "InvalidPubKey");
});

it("should not allow to sign up more than the supported amount of users (5 ** stateTreeDepth)", async () => {
Expand Down
31 changes: 31 additions & 0 deletions contracts/tests/Poll.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,26 @@ describe("Poll", () => {
const numMessages = await pollContract.numMessages();
expect(numMessages.toString()).to.eq("1");
});

it("should fail when passing an invalid coordinator public key", async () => {
const r = await deployTestContracts(initialVoiceCreditBalance, STATE_TREE_DEPTH, signer, true);
const testMaciContract = r.maciContract;

// deploy on chain poll
await expect(
testMaciContract.deployPoll(
duration,
treeDepths,
new PubKey([100n, 1n]).asContractParam(),
r.mockVerifierContract as Verifier,
r.vkRegistryContract,
EMode.QV,
{
gasLimit: 10000000,
},
),
).to.be.revertedWithCustomError(testMaciContract, "InvalidPubKey");
});
});

describe("topup", () => {
Expand Down Expand Up @@ -202,6 +222,17 @@ describe("Poll", () => {
maciState.polls.get(pollId)?.publishMessage(message, keypair.pubKey);
});

it("should throw when the encPubKey is not a point on the baby jubjub curve", async () => {
const command = new PCommand(1n, keypair.pubKey, 0n, 9n, 1n, pollId, 0n);

const signature = command.sign(keypair.privKey);
const sharedKey = Keypair.genEcdhSharedKey(keypair.privKey, coordinator.pubKey);
const message = command.encrypt(signature, sharedKey);
await expect(
pollContract.publishMessage(message.asContractParam(), new PubKey([1n, 1n]).asContractParam()),
).to.be.revertedWithCustomError(pollContract, "InvalidPubKey");
});

it("should emit an event when publishing a message", async () => {
const command = new PCommand(1n, keypair.pubKey, 0n, 9n, 1n, pollId, 0n);

Expand Down
2 changes: 2 additions & 0 deletions crypto/ts/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ export { G1Point, G2Point, genRandomBabyJubValue } from "./babyjub";

export { sha256Hash, hashLeftRight, hashN, hash2, hash3, hash4, hash5, hash13, hashOne } from "./hashing";

export { inCurve } from "@zk-kit/baby-jubjub";

export { poseidonDecrypt, poseidonDecryptWithoutCheck, poseidonEncrypt } from "@zk-kit/poseidon-cipher";

export { verifySignature, signMessage as sign } from "@zk-kit/eddsa-poseidon";
Expand Down
6 changes: 6 additions & 0 deletions domainobjs/ts/__tests__/keypair.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ describe("keypair", () => {
expect(k1.privKey.rawPrivKey).not.to.eq(k2.privKey.rawPrivKey);
});

it("should always generate a valid keypair", () => {
for (let i = 0; i < 100; i += 1) {
expect(() => new Keypair()).to.not.throw();
}
});

it("should generate the correct public key given a private key", () => {
const rawKeyPair = genKeypair();
const k = new Keypair(new PrivKey(rawKeyPair.privKey));
Expand Down
9 changes: 0 additions & 9 deletions domainobjs/ts/__tests__/publicKey.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,6 @@ describe("public key", () => {
expect(unpacked[0].toString()).to.eq(pk1.rawPubKey[0].toString());
expect(unpacked[1].toString()).to.eq(pk1.rawPubKey[1].toString());
});
it("should serialize into an invalid serialized key when the key is [bigint(0), bigint(0)]", () => {
const pk1 = new PubKey([BigInt(0), BigInt(0)]);
const s = pk1.serialize();
expect(s).to.eq("macipk.z");
});
});

describe("deserialize", () => {
Expand All @@ -74,10 +69,6 @@ describe("public key", () => {
const pk2 = PubKey.deserialize(s);
expect(pk1.rawPubKey.toString()).to.eq(pk2.rawPubKey.toString());
});
it("should deserialize into an invalid serialized key when the key is equal to macipk.z", () => {
const pk1 = PubKey.deserialize("macipk.z");
expect(pk1.rawPubKey.toString()).to.eq([BigInt(0), BigInt(0)].toString());
});
});

describe("isValidSerializedPubKey", () => {
Expand Down
10 changes: 6 additions & 4 deletions domainobjs/ts/commands/PCommand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ export class PCommand implements ICommand {
* Decrypts a Message to produce a Command.
* @dev You can force decrypt the message by setting `force` to true.
* This is useful in case you don't want an invalid message to throw an error.
* @param {Message} message - the message to decrypt
* @param {EcdhSharedKey} sharedKey - the shared key to use for decryption
* @param {boolean} force - whether to force decryption or not
* @param message - the message to decrypt
* @param sharedKey - the shared key to use for decryption
* @param force - whether to force decryption or not
*/
static decrypt = (message: Message, sharedKey: EcdhSharedKey, force = false): IDecryptMessage => {
const decrypted = force
Expand Down Expand Up @@ -207,7 +207,9 @@ export class PCommand implements ICommand {
const nonce = extract(p, 150);
const pollId = extract(p, 200);

const newPubKey = new PubKey([decrypted[1], decrypted[2]]);
// create new public key but allow it to be invalid (as when passing an mismatched
// encPubKey, a message will not decrypt resulting in potentially invalid public keys)
const newPubKey = new PubKey([decrypted[1], decrypted[2]], true);
const salt = decrypted[3];

const command = new PCommand(stateIndex, newPubKey, voteOptionIndex, newVoteWeight, nonce, pollId, salt);
Expand Down
Loading

0 comments on commit feb1eec

Please sign in to comment.