Skip to content

Commit

Permalink
Merge pull request #1089 from privacy-scaling-explorations/fix/1066
Browse files Browse the repository at this point in the history
refactor(contracts): remove maxValues from Maci.deployPoll() and instead calculate
  • Loading branch information
ctrlc03 authored Jan 24, 2024
2 parents bb6c728 + 5393868 commit 013baa4
Show file tree
Hide file tree
Showing 24 changed files with 116 additions and 152 deletions.
2 changes: 0 additions & 2 deletions cli/tests/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ export const deployArgs: DeployArgs = {

export const deployPollArgs: DeployPollArgs = {
pollDuration,
maxMessages,
maxVoteOptions,
intStateTreeDepth: INT_STATE_TREE_DEPTH,
messageTreeSubDepth: MSG_BATCH_DEPTH,
messageTreeDepth: MSG_TREE_DEPTH,
Expand Down
11 changes: 0 additions & 11 deletions cli/ts/commands/deployPoll.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import {
*/
export const deployPoll = async ({
pollDuration,
maxMessages,
maxVoteOptions,
intStateTreeDepth,
messageTreeSubDepth,
messageTreeDepth,
Expand Down Expand Up @@ -57,14 +55,6 @@ export const deployPoll = async ({
if (pollDuration <= 0) {
logError("Duration cannot be <= 0");
}
// require arg -> max messages
if (maxMessages <= 0) {
logError("Max messages cannot be <= 0");
}
// required arg -> max vote options
if (maxVoteOptions <= 0) {
logError("Max vote options cannot be <= 0");
}

// required arg -> int state tree depth
if (intStateTreeDepth <= 0) {
Expand Down Expand Up @@ -111,7 +101,6 @@ export const deployPoll = async ({
// deploy the poll contract via the maci contract
const tx = await maciContract.deployPoll(
pollDuration,
{ maxMessages, maxVoteOptions },
{
intStateTreeDepth,
messageTreeSubDepth,
Expand Down
21 changes: 15 additions & 6 deletions cli/ts/commands/proveOnChain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
parseArtifact,
getDefaultNetwork,
} from "maci-contracts";
import { STATE_TREE_ARITY } from "maci-core";
import { G1Point, G2Point, hashLeftRight } from "maci-crypto";
import { VerifyingKey } from "maci-domainobjs";

Expand Down Expand Up @@ -188,13 +189,13 @@ export const proveOnChain = async ({
});

// retrieve the values we need from the smart contracts
const treeDepths = await pollContract.treeDepths();
const numSignUpsAndMessages = await pollContract.numSignUpsAndMessages();
const numSignUps = Number(numSignUpsAndMessages[0]);
const numMessages = Number(numSignUpsAndMessages[1]);
const batchSizes = await pollContract.batchSizes();
const messageBatchSize = Number(batchSizes.messageBatchSize);
const tallyBatchSize = Number(batchSizes.tallyBatchSize);
const subsidyBatchSize = Number(batchSizes.subsidyBatchSize);
const messageBatchSize = STATE_TREE_ARITY ** Number(treeDepths.messageTreeSubDepth);
const tallyBatchSize = STATE_TREE_ARITY ** Number(treeDepths.intStateTreeDepth);
const subsidyBatchSize = STATE_TREE_ARITY ** Number(treeDepths.intStateTreeDepth);
let totalMessageBatches = numMessages <= messageBatchSize ? 1 : Math.floor(numMessages / messageBatchSize);

if (numMessages > messageBatchSize && numMessages % messageBatchSize > 0) {
Expand All @@ -212,7 +213,6 @@ export const proveOnChain = async ({
);
}

const treeDepths = await pollContract.treeDepths();
let numberBatchesProcessed = Number(await mpContract.numBatchesProcessed());
const messageRootOnChain = await messageAqContract.getMainRoot(Number(treeDepths.messageTreeDepth));

Expand Down Expand Up @@ -286,7 +286,13 @@ export const proveOnChain = async ({
}

const packedValsOnChain = BigInt(
await mpContract.genProcessMessagesPackedVals(currentMessageBatchIndex, numSignUps),
await mpContract.genProcessMessagesPackedVals(
currentMessageBatchIndex,
numSignUps,
numMessages,
treeDepths.messageTreeSubDepth,
treeDepths.voteOptionTreeDepth,
),
).toString();

if (circuitInputs.packedVals !== packedValsOnChain) {
Expand All @@ -300,8 +306,11 @@ export const proveOnChain = async ({
currentMessageBatchIndex,
messageRootOnChain.toString(),
numSignUps,
numMessages,
circuitInputs.currentSbCommitment as BigNumberish,
circuitInputs.newSbCommitment as BigNumberish,
treeDepths.messageTreeSubDepth,
treeDepths.voteOptionTreeDepth,
),
);

Expand Down
4 changes: 0 additions & 4 deletions cli/ts/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,6 @@ program
.description("deploy a new poll")
.option("-vk, --vkRegistryAddress <vkRegistryAddress>", "the vk registry contract address")
.requiredOption("-t, --duration <pollDuration>", "the poll duration", parseInt)
.requiredOption("-g, --max-messages <maxMessages>", "the max messages", parseInt)
.requiredOption("-mv, --max-vote-options <maxVoteOptions>", "the max vote options", parseInt)
.requiredOption("-i, --int-state-tree-depth <intStateTreeDepth>", "the int state tree depth", parseInt)
.requiredOption("-b, --msg-batch-depth <messageTreeSubDepth>", "the message tree sub depth", parseInt)
.requiredOption("-m, --msg-tree-depth <messageTreeDepth>", "the message tree depth", parseInt)
Expand All @@ -198,8 +196,6 @@ program
try {
await deployPoll({
pollDuration: cmdObj.duration,
maxMessages: cmdObj.maxMessages,
maxVoteOptions: cmdObj.maxVoteOptions,
intStateTreeDepth: cmdObj.intStateTreeDepth,
messageTreeSubDepth: cmdObj.msgBatchDepth,
messageTreeDepth: cmdObj.msgTreeDepth,
Expand Down
10 changes: 0 additions & 10 deletions cli/ts/utils/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,6 @@ export interface DeployPollArgs {
*/
pollDuration: number;

/**
* The maximum number of messages that can be submitted
*/
maxMessages: number;

/**
* The maximum number of vote options
*/
maxVoteOptions: number;

/**
* The depth of the intermediate state tree
*/
Expand Down
23 changes: 5 additions & 18 deletions contracts/contracts/MACI.sol
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ contract MACI is IMACI, Params, Utilities, Ownable {

/// @notice Deploy a new Poll contract.
/// @param _duration How long should the Poll last for
/// @param _maxValues The maximum number of vote options, and messages
/// @param _treeDepths The depth of the Merkle trees
/// @param _coordinatorPubKey The coordinator's public key
/// @param _verifier The Verifier Contract
Expand All @@ -196,7 +195,6 @@ contract MACI is IMACI, Params, Utilities, Ownable {
/// @return pollAddr a new Poll contract address
function deployPoll(
uint256 _duration,
MaxValues memory _maxValues,
TreeDepths memory _treeDepths,
PubKey memory _coordinatorPubKey,
address _verifier,
Expand All @@ -216,25 +214,14 @@ contract MACI is IMACI, Params, Utilities, Ownable {
if (!stateAq.treeMerged()) revert PreviousPollNotCompleted(pollId);
}

// The message batch size and the tally batch size
BatchSizes memory batchSizes = BatchSizes(
uint24(TREE_ARITY) ** _treeDepths.messageTreeSubDepth,
uint24(TREE_ARITY) ** _treeDepths.intStateTreeDepth,
uint24(TREE_ARITY) ** _treeDepths.intStateTreeDepth
);
MaxValues memory maxValues = MaxValues({
maxMessages: uint256(TREE_ARITY) ** _treeDepths.messageTreeDepth,
maxVoteOptions: uint256(TREE_ARITY) ** _treeDepths.voteOptionTreeDepth
});

address _owner = owner();

address p = pollFactory.deploy(
_duration,
_maxValues,
_treeDepths,
batchSizes,
_coordinatorPubKey,
this,
topupCredit,
_owner
);
address p = pollFactory.deploy(_duration, maxValues, _treeDepths, _coordinatorPubKey, this, topupCredit, _owner);

address mp = messageProcessorFactory.deploy(_verifier, _vkRegistry, p, _owner);
address tally = tallyFactory.deploy(_verifier, _vkRegistry, p, mp, _owner);
Expand Down
90 changes: 66 additions & 24 deletions contracts/contracts/MessageProcessor.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
error CurrentMessageBatchIndexTooLarge();
error BatchEndIndexTooLarge();

// the number of children per node in the merkle trees
uint256 private constant TREE_ARITY = 5;

/// @inheritdoc IMessageProcessor
bool public processingComplete;

Expand Down Expand Up @@ -61,6 +64,7 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
/// after all messages are processed
/// @param _proof The zk-SNARK proof
function processMessages(uint256 _newSbCommitment, uint256[8] memory _proof) external onlyOwner {
// ensure the voting period is over
_votingPeriodOver(poll);

// There must be unprocessed messages
Expand All @@ -74,11 +78,11 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
}

// Retrieve stored vals
(, , uint8 messageTreeDepth, ) = poll.treeDepths();
(uint256 messageBatchSize, , ) = poll.batchSizes();
(, uint8 messageTreeSubDepth, uint8 messageTreeDepth, uint8 voteOptionTreeDepth) = poll.treeDepths();
// calculate the message batch size from the message tree subdepth
uint256 messageBatchSize = TREE_ARITY ** messageTreeSubDepth;

AccQueue messageAq;
(, messageAq, ) = poll.extContracts();
(, AccQueue messageAq, ) = poll.extContracts();

// Require that the message queue has been merged
uint256 messageRoot = messageAq.getMainRoot(messageTreeDepth);
Expand All @@ -105,8 +109,18 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
}
}

bool isValid = verifyProcessProof(currentMessageBatchIndex, messageRoot, sbCommitment, _newSbCommitment, _proof);
if (!isValid) {
if (
!verifyProcessProof(
currentMessageBatchIndex,
messageRoot,
sbCommitment,
_newSbCommitment,
messageTreeSubDepth,
messageTreeDepth,
voteOptionTreeDepth,
_proof
)
) {
revert InvalidProcessMessageProof();
}

Expand All @@ -132,35 +146,45 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
/// @param _messageRoot The message tree root
/// @param _currentSbCommitment The current sbCommitment (state and ballot)
/// @param _newSbCommitment The new sbCommitment after we update this message batch
/// @param _messageTreeSubDepth The message tree subdepth
/// @param _messageTreeDepth The message tree depth
/// @param _voteOptionTreeDepth The vote option tree depth
/// @param _proof The zk-SNARK proof
/// @return isValid Whether the proof is valid
function verifyProcessProof(
uint256 _currentMessageBatchIndex,
uint256 _messageRoot,
uint256 _currentSbCommitment,
uint256 _newSbCommitment,
uint8 _messageTreeSubDepth,
uint8 _messageTreeDepth,
uint8 _voteOptionTreeDepth,
uint256[8] memory _proof
) internal view returns (bool isValid) {
(, , uint8 messageTreeDepth, uint8 voteOptionTreeDepth) = poll.treeDepths();
(uint256 messageBatchSize, , ) = poll.batchSizes();
(uint256 numSignUps, ) = poll.numSignUpsAndMessages();
// get the tree depths
// get the message batch size from the message tree subdepth
// get the number of signups
(uint256 numSignUps, uint256 numMessages) = poll.numSignUpsAndMessages();
(IMACI maci, , ) = poll.extContracts();

// Calculate the public input hash (a SHA256 hash of several values)
uint256 publicInputHash = genProcessMessagesPublicInputHash(
_currentMessageBatchIndex,
_messageRoot,
numSignUps,
numMessages,
_currentSbCommitment,
_newSbCommitment
_newSbCommitment,
_messageTreeSubDepth,
_voteOptionTreeDepth
);

// Get the verifying key from the VkRegistry
VerifyingKey memory vk = vkRegistry.getProcessVk(
maci.stateTreeDepth(),
messageTreeDepth,
voteOptionTreeDepth,
messageBatchSize
_messageTreeDepth,
_voteOptionTreeDepth,
TREE_ARITY ** _messageTreeSubDepth
);

isValid = verifier.verify(_proof, vk, publicInputHash);
Expand All @@ -175,22 +199,35 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
/// higher and proving time will be longer.
/// @param _currentMessageBatchIndex The batch index of current message batch
/// @param _numSignUps The number of users that signup
/// @param _numMessages The number of messages
/// @param _currentSbCommitment The current sbCommitment (state and ballot root)
/// @param _newSbCommitment The new sbCommitment after we update this message batch
/// @param _messageTreeSubDepth The message tree subdepth
/// @return inputHash Returns the SHA256 hash of the packed values
function genProcessMessagesPublicInputHash(
uint256 _currentMessageBatchIndex,
uint256 _messageRoot,
uint256 _numSignUps,
uint256 _numMessages,
uint256 _currentSbCommitment,
uint256 _newSbCommitment
uint256 _newSbCommitment,
uint8 _messageTreeSubDepth,
uint8 _voteOptionTreeDepth
) public view returns (uint256 inputHash) {
uint256 coordinatorPubKeyHash = poll.coordinatorPubKeyHash();

uint256 packedVals = genProcessMessagesPackedVals(_currentMessageBatchIndex, _numSignUps);
// pack the values
uint256 packedVals = genProcessMessagesPackedVals(
_currentMessageBatchIndex,
_numSignUps,
_numMessages,
_messageTreeSubDepth,
_voteOptionTreeDepth
);

(uint256 deployTime, uint256 duration) = poll.getDeployTimeAndDuration();

// generate the circuit only public input
uint256[] memory input = new uint256[](6);
input[0] = packedVals;
input[1] = coordinatorPubKeyHash;
Expand All @@ -208,19 +245,24 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
/// the current batch.
/// @param _currentMessageBatchIndex batch index of current message batch
/// @param _numSignUps number of users that signup
/// @param _numMessages number of messages
/// @param _messageTreeSubDepth message tree subdepth
/// @param _voteOptionTreeDepth vote option tree depth
/// @return result The packed value
function genProcessMessagesPackedVals(
uint256 _currentMessageBatchIndex,
uint256 _numSignUps
) public view returns (uint256 result) {
(, uint256 maxVoteOptions) = poll.maxValues();
(, uint256 numMessages) = poll.numSignUpsAndMessages();
(uint24 mbs, , ) = poll.batchSizes();
uint256 messageBatchSize = uint256(mbs);

uint256 _numSignUps,
uint256 _numMessages,
uint8 _messageTreeSubDepth,
uint8 _voteOptionTreeDepth
) public pure returns (uint256 result) {
uint256 maxVoteOptions = TREE_ARITY ** _voteOptionTreeDepth;

// calculate the message batch size from the message tree subdepth
uint256 messageBatchSize = TREE_ARITY ** _messageTreeSubDepth;
uint256 batchEndIndex = _currentMessageBatchIndex + messageBatchSize;
if (batchEndIndex > numMessages) {
batchEndIndex = numMessages;
if (batchEndIndex > _numMessages) {
batchEndIndex = _numMessages;
}

if (maxVoteOptions >= 2 ** 50) revert MaxVoteOptionsTooLarge();
Expand Down
5 changes: 0 additions & 5 deletions contracts/contracts/Poll.sol
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable, EmptyBallotRoots, IPol
/// @notice Depths of the merkle trees
TreeDepths public treeDepths;

/// @notice Batch sizes for processing
BatchSizes public batchSizes;
/// @notice The contracts used by the Poll
ExtContracts public extContracts;

Expand All @@ -79,14 +77,12 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable, EmptyBallotRoots, IPol
/// @param _duration The duration of the voting period, in seconds
/// @param _maxValues The maximum number of messages and vote options
/// @param _treeDepths The depths of the merkle trees
/// @param _batchSizes The batch sizes for processing
/// @param _coordinatorPubKey The coordinator's public key
/// @param _extContracts The external contracts
constructor(
uint256 _duration,
MaxValues memory _maxValues,
TreeDepths memory _treeDepths,
BatchSizes memory _batchSizes,
PubKey memory _coordinatorPubKey,
ExtContracts memory _extContracts
) payable {
Expand All @@ -96,7 +92,6 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable, EmptyBallotRoots, IPol
coordinatorPubKeyHash = hashLeftRight(_coordinatorPubKey.x, _coordinatorPubKey.y);
duration = _duration;
maxValues = _maxValues;
batchSizes = _batchSizes;
treeDepths = _treeDepths;
// Record the current timestamp
deployTime = block.timestamp;
Expand Down
Loading

0 comments on commit 013baa4

Please sign in to comment.