Skip to content

Commit

Permalink
refactor(contracts): remove maxValues and batchSizes from Maci.deploy…
Browse files Browse the repository at this point in the history
…Poll() and instead calculate

Simplify logic and save gas by computing maxValues within MACI vs passing as parameter to
deployPoll. Also remove batchSizes and compute within the smart
contracts, as well as remove large number of calls to the poll contract,
and instead pass parameters around functions.

fix #1066
  • Loading branch information
ctrlc03 committed Jan 24, 2024
1 parent bb6c728 commit 5393868
Show file tree
Hide file tree
Showing 24 changed files with 116 additions and 152 deletions.
2 changes: 0 additions & 2 deletions cli/tests/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ export const deployArgs: DeployArgs = {

export const deployPollArgs: DeployPollArgs = {
pollDuration,
maxMessages,
maxVoteOptions,
intStateTreeDepth: INT_STATE_TREE_DEPTH,
messageTreeSubDepth: MSG_BATCH_DEPTH,
messageTreeDepth: MSG_TREE_DEPTH,
Expand Down
11 changes: 0 additions & 11 deletions cli/ts/commands/deployPoll.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import {
*/
export const deployPoll = async ({
pollDuration,
maxMessages,
maxVoteOptions,
intStateTreeDepth,
messageTreeSubDepth,
messageTreeDepth,
Expand Down Expand Up @@ -57,14 +55,6 @@ export const deployPoll = async ({
if (pollDuration <= 0) {
logError("Duration cannot be <= 0");
}
// require arg -> max messages
if (maxMessages <= 0) {
logError("Max messages cannot be <= 0");
}
// required arg -> max vote options
if (maxVoteOptions <= 0) {
logError("Max vote options cannot be <= 0");
}

// required arg -> int state tree depth
if (intStateTreeDepth <= 0) {
Expand Down Expand Up @@ -111,7 +101,6 @@ export const deployPoll = async ({
// deploy the poll contract via the maci contract
const tx = await maciContract.deployPoll(
pollDuration,
{ maxMessages, maxVoteOptions },
{
intStateTreeDepth,
messageTreeSubDepth,
Expand Down
21 changes: 15 additions & 6 deletions cli/ts/commands/proveOnChain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
parseArtifact,
getDefaultNetwork,
} from "maci-contracts";
import { STATE_TREE_ARITY } from "maci-core";
import { G1Point, G2Point, hashLeftRight } from "maci-crypto";
import { VerifyingKey } from "maci-domainobjs";

Expand Down Expand Up @@ -188,13 +189,13 @@ export const proveOnChain = async ({
});

// retrieve the values we need from the smart contracts
const treeDepths = await pollContract.treeDepths();
const numSignUpsAndMessages = await pollContract.numSignUpsAndMessages();
const numSignUps = Number(numSignUpsAndMessages[0]);
const numMessages = Number(numSignUpsAndMessages[1]);
const batchSizes = await pollContract.batchSizes();
const messageBatchSize = Number(batchSizes.messageBatchSize);
const tallyBatchSize = Number(batchSizes.tallyBatchSize);
const subsidyBatchSize = Number(batchSizes.subsidyBatchSize);
const messageBatchSize = STATE_TREE_ARITY ** Number(treeDepths.messageTreeSubDepth);
const tallyBatchSize = STATE_TREE_ARITY ** Number(treeDepths.intStateTreeDepth);
const subsidyBatchSize = STATE_TREE_ARITY ** Number(treeDepths.intStateTreeDepth);
let totalMessageBatches = numMessages <= messageBatchSize ? 1 : Math.floor(numMessages / messageBatchSize);

if (numMessages > messageBatchSize && numMessages % messageBatchSize > 0) {
Expand All @@ -212,7 +213,6 @@ export const proveOnChain = async ({
);
}

const treeDepths = await pollContract.treeDepths();
let numberBatchesProcessed = Number(await mpContract.numBatchesProcessed());
const messageRootOnChain = await messageAqContract.getMainRoot(Number(treeDepths.messageTreeDepth));

Expand Down Expand Up @@ -286,7 +286,13 @@ export const proveOnChain = async ({
}

const packedValsOnChain = BigInt(
await mpContract.genProcessMessagesPackedVals(currentMessageBatchIndex, numSignUps),
await mpContract.genProcessMessagesPackedVals(
currentMessageBatchIndex,
numSignUps,
numMessages,
treeDepths.messageTreeSubDepth,
treeDepths.voteOptionTreeDepth,
),
).toString();

if (circuitInputs.packedVals !== packedValsOnChain) {
Expand All @@ -300,8 +306,11 @@ export const proveOnChain = async ({
currentMessageBatchIndex,
messageRootOnChain.toString(),
numSignUps,
numMessages,
circuitInputs.currentSbCommitment as BigNumberish,
circuitInputs.newSbCommitment as BigNumberish,
treeDepths.messageTreeSubDepth,
treeDepths.voteOptionTreeDepth,
),
);

Expand Down
4 changes: 0 additions & 4 deletions cli/ts/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,6 @@ program
.description("deploy a new poll")
.option("-vk, --vkRegistryAddress <vkRegistryAddress>", "the vk registry contract address")
.requiredOption("-t, --duration <pollDuration>", "the poll duration", parseInt)
.requiredOption("-g, --max-messages <maxMessages>", "the max messages", parseInt)
.requiredOption("-mv, --max-vote-options <maxVoteOptions>", "the max vote options", parseInt)
.requiredOption("-i, --int-state-tree-depth <intStateTreeDepth>", "the int state tree depth", parseInt)
.requiredOption("-b, --msg-batch-depth <messageTreeSubDepth>", "the message tree sub depth", parseInt)
.requiredOption("-m, --msg-tree-depth <messageTreeDepth>", "the message tree depth", parseInt)
Expand All @@ -198,8 +196,6 @@ program
try {
await deployPoll({
pollDuration: cmdObj.duration,
maxMessages: cmdObj.maxMessages,
maxVoteOptions: cmdObj.maxVoteOptions,
intStateTreeDepth: cmdObj.intStateTreeDepth,
messageTreeSubDepth: cmdObj.msgBatchDepth,
messageTreeDepth: cmdObj.msgTreeDepth,
Expand Down
10 changes: 0 additions & 10 deletions cli/ts/utils/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,6 @@ export interface DeployPollArgs {
*/
pollDuration: number;

/**
* The maximum number of messages that can be submitted
*/
maxMessages: number;

/**
* The maximum number of vote options
*/
maxVoteOptions: number;

/**
* The depth of the intermediate state tree
*/
Expand Down
23 changes: 5 additions & 18 deletions contracts/contracts/MACI.sol
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ contract MACI is IMACI, Params, Utilities, Ownable {

/// @notice Deploy a new Poll contract.
/// @param _duration How long should the Poll last for
/// @param _maxValues The maximum number of vote options, and messages
/// @param _treeDepths The depth of the Merkle trees
/// @param _coordinatorPubKey The coordinator's public key
/// @param _verifier The Verifier Contract
Expand All @@ -196,7 +195,6 @@ contract MACI is IMACI, Params, Utilities, Ownable {
/// @return pollAddr a new Poll contract address
function deployPoll(
uint256 _duration,
MaxValues memory _maxValues,
TreeDepths memory _treeDepths,
PubKey memory _coordinatorPubKey,
address _verifier,
Expand All @@ -216,25 +214,14 @@ contract MACI is IMACI, Params, Utilities, Ownable {
if (!stateAq.treeMerged()) revert PreviousPollNotCompleted(pollId);
}

// The message batch size and the tally batch size
BatchSizes memory batchSizes = BatchSizes(
uint24(TREE_ARITY) ** _treeDepths.messageTreeSubDepth,
uint24(TREE_ARITY) ** _treeDepths.intStateTreeDepth,
uint24(TREE_ARITY) ** _treeDepths.intStateTreeDepth
);
MaxValues memory maxValues = MaxValues({
maxMessages: uint256(TREE_ARITY) ** _treeDepths.messageTreeDepth,
maxVoteOptions: uint256(TREE_ARITY) ** _treeDepths.voteOptionTreeDepth
});

address _owner = owner();

address p = pollFactory.deploy(
_duration,
_maxValues,
_treeDepths,
batchSizes,
_coordinatorPubKey,
this,
topupCredit,
_owner
);
address p = pollFactory.deploy(_duration, maxValues, _treeDepths, _coordinatorPubKey, this, topupCredit, _owner);

address mp = messageProcessorFactory.deploy(_verifier, _vkRegistry, p, _owner);
address tally = tallyFactory.deploy(_verifier, _vkRegistry, p, mp, _owner);
Expand Down
90 changes: 66 additions & 24 deletions contracts/contracts/MessageProcessor.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
error CurrentMessageBatchIndexTooLarge();
error BatchEndIndexTooLarge();

// the number of children per node in the merkle trees
uint256 private constant TREE_ARITY = 5;

/// @inheritdoc IMessageProcessor
bool public processingComplete;

Expand Down Expand Up @@ -61,6 +64,7 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
/// after all messages are processed
/// @param _proof The zk-SNARK proof
function processMessages(uint256 _newSbCommitment, uint256[8] memory _proof) external onlyOwner {
// ensure the voting period is over
_votingPeriodOver(poll);

// There must be unprocessed messages
Expand All @@ -74,11 +78,11 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
}

// Retrieve stored vals
(, , uint8 messageTreeDepth, ) = poll.treeDepths();
(uint256 messageBatchSize, , ) = poll.batchSizes();
(, uint8 messageTreeSubDepth, uint8 messageTreeDepth, uint8 voteOptionTreeDepth) = poll.treeDepths();
// calculate the message batch size from the message tree subdepth
uint256 messageBatchSize = TREE_ARITY ** messageTreeSubDepth;

AccQueue messageAq;
(, messageAq, ) = poll.extContracts();
(, AccQueue messageAq, ) = poll.extContracts();

// Require that the message queue has been merged
uint256 messageRoot = messageAq.getMainRoot(messageTreeDepth);
Expand All @@ -105,8 +109,18 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
}
}

bool isValid = verifyProcessProof(currentMessageBatchIndex, messageRoot, sbCommitment, _newSbCommitment, _proof);
if (!isValid) {
if (
!verifyProcessProof(
currentMessageBatchIndex,
messageRoot,
sbCommitment,
_newSbCommitment,
messageTreeSubDepth,
messageTreeDepth,
voteOptionTreeDepth,
_proof
)
) {
revert InvalidProcessMessageProof();
}

Expand All @@ -132,35 +146,45 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
/// @param _messageRoot The message tree root
/// @param _currentSbCommitment The current sbCommitment (state and ballot)
/// @param _newSbCommitment The new sbCommitment after we update this message batch
/// @param _messageTreeSubDepth The message tree subdepth
/// @param _messageTreeDepth The message tree depth
/// @param _voteOptionTreeDepth The vote option tree depth
/// @param _proof The zk-SNARK proof
/// @return isValid Whether the proof is valid
function verifyProcessProof(
uint256 _currentMessageBatchIndex,
uint256 _messageRoot,
uint256 _currentSbCommitment,
uint256 _newSbCommitment,
uint8 _messageTreeSubDepth,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint8 _messageTreeDepth,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint8 _voteOptionTreeDepth,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256[8] memory _proof
) internal view returns (bool isValid) {
(, , uint8 messageTreeDepth, uint8 voteOptionTreeDepth) = poll.treeDepths();
(uint256 messageBatchSize, , ) = poll.batchSizes();
(uint256 numSignUps, ) = poll.numSignUpsAndMessages();
// get the tree depths
// get the message batch size from the message tree subdepth
// get the number of signups
(uint256 numSignUps, uint256 numMessages) = poll.numSignUpsAndMessages();
(IMACI maci, , ) = poll.extContracts();

// Calculate the public input hash (a SHA256 hash of several values)
uint256 publicInputHash = genProcessMessagesPublicInputHash(
_currentMessageBatchIndex,
_messageRoot,
numSignUps,
numMessages,
_currentSbCommitment,
_newSbCommitment
_newSbCommitment,
_messageTreeSubDepth,
_voteOptionTreeDepth
);

// Get the verifying key from the VkRegistry
VerifyingKey memory vk = vkRegistry.getProcessVk(
maci.stateTreeDepth(),
messageTreeDepth,
voteOptionTreeDepth,
messageBatchSize
_messageTreeDepth,
_voteOptionTreeDepth,
TREE_ARITY ** _messageTreeSubDepth
);

isValid = verifier.verify(_proof, vk, publicInputHash);
Expand All @@ -175,22 +199,35 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
/// higher and proving time will be longer.
/// @param _currentMessageBatchIndex The batch index of current message batch
/// @param _numSignUps The number of users that signup
/// @param _numMessages The number of messages
/// @param _currentSbCommitment The current sbCommitment (state and ballot root)
/// @param _newSbCommitment The new sbCommitment after we update this message batch
/// @param _messageTreeSubDepth The message tree subdepth
/// @return inputHash Returns the SHA256 hash of the packed values
function genProcessMessagesPublicInputHash(
uint256 _currentMessageBatchIndex,
uint256 _messageRoot,
uint256 _numSignUps,
uint256 _numMessages,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 _currentSbCommitment,
uint256 _newSbCommitment
uint256 _newSbCommitment,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint8 _messageTreeSubDepth,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint8 _voteOptionTreeDepth

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

) public view returns (uint256 inputHash) {
uint256 coordinatorPubKeyHash = poll.coordinatorPubKeyHash();

uint256 packedVals = genProcessMessagesPackedVals(_currentMessageBatchIndex, _numSignUps);
// pack the values
uint256 packedVals = genProcessMessagesPackedVals(
_currentMessageBatchIndex,
_numSignUps,
_numMessages,
_messageTreeSubDepth,
_voteOptionTreeDepth
);

(uint256 deployTime, uint256 duration) = poll.getDeployTimeAndDuration();

// generate the circuit only public input
uint256[] memory input = new uint256[](6);
input[0] = packedVals;
input[1] = coordinatorPubKeyHash;
Expand All @@ -208,19 +245,24 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes
/// the current batch.
/// @param _currentMessageBatchIndex batch index of current message batch
/// @param _numSignUps number of users that signup
/// @param _numMessages number of messages
/// @param _messageTreeSubDepth message tree subdepth
/// @param _voteOptionTreeDepth vote option tree depth
/// @return result The packed value
function genProcessMessagesPackedVals(
uint256 _currentMessageBatchIndex,
uint256 _numSignUps
) public view returns (uint256 result) {
(, uint256 maxVoteOptions) = poll.maxValues();
(, uint256 numMessages) = poll.numSignUpsAndMessages();
(uint24 mbs, , ) = poll.batchSizes();
uint256 messageBatchSize = uint256(mbs);

uint256 _numSignUps,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint256 _numMessages,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint8 _messageTreeSubDepth,

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

uint8 _voteOptionTreeDepth

Check warning

Code scanning / Slither

Conformance to Solidity naming conventions Warning

) public pure returns (uint256 result) {
uint256 maxVoteOptions = TREE_ARITY ** _voteOptionTreeDepth;

// calculate the message batch size from the message tree subdepth
uint256 messageBatchSize = TREE_ARITY ** _messageTreeSubDepth;
uint256 batchEndIndex = _currentMessageBatchIndex + messageBatchSize;
if (batchEndIndex > numMessages) {
batchEndIndex = numMessages;
if (batchEndIndex > _numMessages) {
batchEndIndex = _numMessages;
}

if (maxVoteOptions >= 2 ** 50) revert MaxVoteOptionsTooLarge();
Expand Down
5 changes: 0 additions & 5 deletions contracts/contracts/Poll.sol
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable, EmptyBallotRoots, IPol
/// @notice Depths of the merkle trees
TreeDepths public treeDepths;

/// @notice Batch sizes for processing
BatchSizes public batchSizes;
/// @notice The contracts used by the Poll
ExtContracts public extContracts;

Expand All @@ -79,14 +77,12 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable, EmptyBallotRoots, IPol
/// @param _duration The duration of the voting period, in seconds
/// @param _maxValues The maximum number of messages and vote options
/// @param _treeDepths The depths of the merkle trees
/// @param _batchSizes The batch sizes for processing
/// @param _coordinatorPubKey The coordinator's public key
/// @param _extContracts The external contracts
constructor(
uint256 _duration,
MaxValues memory _maxValues,
TreeDepths memory _treeDepths,
BatchSizes memory _batchSizes,
PubKey memory _coordinatorPubKey,
ExtContracts memory _extContracts
) payable {
Expand All @@ -96,7 +92,6 @@ contract Poll is Params, Utilities, SnarkCommon, Ownable, EmptyBallotRoots, IPol
coordinatorPubKeyHash = hashLeftRight(_coordinatorPubKey.x, _coordinatorPubKey.y);
duration = _duration;
maxValues = _maxValues;
batchSizes = _batchSizes;
treeDepths = _treeDepths;
// Record the current timestamp
deployTime = block.timestamp;
Expand Down
Loading

0 comments on commit 5393868

Please sign in to comment.