Skip to content

Commit

Permalink
feat: SMT Verification Module Update (#6849)
Browse files Browse the repository at this point in the history
This PR slightly updates several classes in `smt_verification` directory

## CircuitSchema

- new structure of CircuitSchema object for compatibility with
UltraArith and UltraTraceBlocks
- also moved it to `circuit_builder_base` due to `UltraCircuitBuilder`'s
`operator==` definition
- UltraCircuitBuilder now has `export_circuit()` method

## Solver

- new table creation function
- cvc5 patch fix
- BitVector output values

## Circuit

`circuit.cpp/circuit.hpp` is moving to `standard_circuit.*`. More
significant changes in the follow up pr

## Terms

- fixed BitVector shift operation
- Added tests for all sorts of operations with bitvectors
- Added test for table inclusion with FFTerm

## Util

Slightly changed `default_model` function. Now it doesn't need solver
parameter + fixed the memory related issues.
  • Loading branch information
Sarkoxed authored Jun 5, 2024
1 parent 3a97f08 commit 6c98529
Show file tree
Hide file tree
Showing 20 changed files with 424 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ template <typename FF_> class UltraArithmeticRelationImpl {
* at the next gate. Then we can treat (q_arith - 1) as a simulated q_6 selector and scale q_m to handle (q_arith -
* 3) at product.
*
* The The relation is
* The relation is
* defined as C(in(X)...) = q_arith * [ -1/2(q_arith - 3)(q_m * w_r * w_l) + (q_l * w_l) + (q_r * w_r) +
* (q_o * w_o) + (q_4 * w_4) + q_c + (q_arith - 1)w_4_shift ]
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ add_dependencies(cvc5-lib cvc5)
include_directories(${CVC5_INCLUDE})
set_target_properties(cvc5-lib PROPERTIES IMPORTED_LOCATION ${CVC5_LIB})

barretenberg_module(smt_verification common proof_system stdlib_primitives stdlib_sha256 circuit_checker transcript plonk cvc5-lib)
barretenberg_module(smt_verification common stdlib_primitives stdlib_sha256 circuit_checker transcript plonk cvc5-lib)
# We have no easy way to add a dependency to an external target, we list the built targets explicit. Could be cleaner.
add_dependencies(smt_verification cvc5)
add_dependencies(smt_verification_objects cvc5)
add_dependencies(smt_verification_tests cvc5)
add_dependencies(smt_verification_test_objects cvc5)
add_dependencies(smt_verification_test_objects cvc5)
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ CircuitSchema unpack_from_file(const std::string& filename)
return cir;
}

/**
* @brief Get the CircuitSchema object
* @details Initialize the CircuitSchema from the msgpack compatible buffer.
*
* @param buf
* @return CircuitSchema
*/
CircuitSchema unpack_from_buffer(const msgpack::sbuffer& buf)
{
CircuitSchema cir;
msgpack::unpack(buf.data(), buf.size()).get().convert(cir);
return cir;
}

/**
* @brief Translates the schema to python format
* @details Returns the contents of the .py file
Expand All @@ -42,6 +56,7 @@ CircuitSchema unpack_from_file(const std::string& filename)
* gates = [
* [[0x000...0, 0x000...1, 0x000...0, 0x000...0, 0x000...0], [0, 0, 0]], ...
* ]
* @todo UltraCircuitSchema output
*/
void print_schema_for_use_in_python(CircuitSchema& cir)
{
Expand All @@ -64,37 +79,23 @@ void print_schema_for_use_in_python(CircuitSchema& cir)
for (size_t i = 0; i < cir.selectors.size(); i++) {
info("[",
"[",
cir.selectors[i][0],
cir.selectors[0][i][0],
", ",
cir.selectors[i][1],
cir.selectors[0][i][1],
", ",
cir.selectors[i][2],
cir.selectors[0][i][2],
", ",
cir.selectors[i][3],
cir.selectors[0][i][3],
", ",
cir.selectors[i][4],
cir.selectors[0][i][4],
"], [",
cir.wires[i][0],
cir.wires[0][i][0],
", ",
cir.wires[i][1],
cir.wires[0][i][1],
", ",
cir.wires[i][2],
cir.wires[0][i][2],
"]],");
}
info("]");
}

/**
* @brief Get the CircuitSchema object
* @details Initialize the CircuitSchema from the msgpack compatible buffer.
*
* @param buf
* @return CircuitSchema
*/
CircuitSchema unpack_from_buffer(const msgpack::sbuffer& buf)
{
CircuitSchema cir;
msgpack::unpack(buf.data(), buf.size()).get().convert(cir);
return cir;
}
} // namespace smt_circuit_schema
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ struct CircuitSchema {
std::vector<uint32_t> public_inps;
std::unordered_map<uint32_t, std::string> vars_of_interest;
std::vector<bb::fr> variables;
std::vector<std::vector<bb::fr>> selectors;
std::vector<std::vector<uint32_t>> wires;
std::vector<std::vector<std::vector<bb::fr>>> selectors;
std::vector<std::vector<std::vector<uint32_t>>> wires;
std::vector<uint32_t> real_variable_index;
MSGPACK_FIELDS(modulus, public_inps, vars_of_interest, variables, selectors, wires, real_variable_index);
std::vector<std::vector<std::vector<bb::fr>>> lookup_tables;
MSGPACK_FIELDS(
modulus, public_inps, vars_of_interest, variables, selectors, wires, real_variable_index, lookup_tables);
};

CircuitSchema unpack_from_buffer(const msgpack::sbuffer& buf);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "solver.hpp"
#include <iostream>
#include "barretenberg/common/log.hpp"

namespace smt_solver {

Expand Down Expand Up @@ -47,6 +47,8 @@ std::unordered_map<std::string, std::string> Solver::model(std::unordered_map<st
str_val = val.getIntegerValue();
} else if (val.isFiniteFieldValue()) {
str_val = val.getFiniteFieldValue();
} else if (val.isBitVectorValue()) {
str_val = "0b" + val.getBitVectorValue();
} else {
throw std::invalid_argument("Expected Integer or FiniteField sorts. Got: " + val.getSort().toString());
}
Expand Down Expand Up @@ -82,6 +84,8 @@ std::unordered_map<std::string, std::string> Solver::model(std::vector<cvc5::Ter
str_val = val.getIntegerValue();
} else if (val.isFiniteFieldValue()) {
str_val = val.getFiniteFieldValue();
} else if (val.isBitVectorValue()) {
str_val = "0b" + val.getBitVectorValue();
} else {
throw std::invalid_argument("Expected Integer or FiniteField sorts. Got: " + val.getSort().toString());
}
Expand Down Expand Up @@ -181,11 +185,11 @@ std::string stringify_term(const cvc5::Term& term, bool parenthesis)
break;
case cvc5::Kind::BITVECTOR_SHL:
back = true;
op = " << " + term.getOp()[0].toString();
op = " << ";
break;
case cvc5::Kind::BITVECTOR_LSHR:
back = true;
op = " >> " + term.getOp()[0].toString();
op = " >> ";
break;
case cvc5::Kind::BITVECTOR_ROTATE_LEFT:
back = true;
Expand Down Expand Up @@ -237,8 +241,31 @@ std::string stringify_term(const cvc5::Term& term, bool parenthesis)
* */
void Solver::print_assertions() const
{
for (auto& t : this->solver.getAssertions()) {
for (const auto& t : this->solver.getAssertions()) {
info(stringify_term(t));
}
}

cvc5::Term Solver::create_lookup_table(std::vector<std::vector<cvc5::Term>>& table)
{
if (!lookup_enabled) {
this->solver.setLogic("ALL");
this->solver.setOption("finite-model-find", "true");
this->solver.setOption("sets-ext", "true");
lookup_enabled = true;
}

cvc5::Term tmp = table[0][0];
cvc5::Sort tuple_sort = this->term_manager.mkTupleSort({ tmp.getSort(), tmp.getSort(), tmp.getSort() });
cvc5::Sort relation = this->term_manager.mkSetSort(tuple_sort);
cvc5::Term resulting_table = this->term_manager.mkEmptySet(relation);

std::vector<cvc5::Term> children;
for (auto& table_entry : table) {
cvc5::Term entry = this->term_manager.mkTuple(table_entry);
children.push_back(entry);
}
children.push_back(resulting_table);
return this->term_manager.mkTerm(cvc5::Kind::SET_INSERT, children);
}
}; // namespace smt_solver
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct SolverConfiguration {
uint64_t timeout;
uint32_t debug;

bool ff_disjunctive_bit;
bool ff_elim_disjunctive_bit;
std::string ff_solver;
};

Expand Down Expand Up @@ -76,8 +76,8 @@ class Solver {
// Cause bit constraints are part of the split-gb optimization
// and without them it will probably perform less efficient
// TODO(alex): test this `probably` after finishing the pr sequence
if (config.ff_disjunctive_bit) {
solver.setOption("ff-disjunctive-bit", "true");
if (!config.ff_elim_disjunctive_bit) {
solver.setOption("ff-elim-disjunctive-bit", "false");
}
// split-gb is an updated version of gb ff-solver
// It basically SPLITS the polynomials in the system into subsets
Expand All @@ -96,6 +96,10 @@ class Solver {
Solver& operator=(const Solver& other) = delete;
Solver& operator=(Solver&& other) = delete;

bool lookup_enabled = false;

cvc5::Term create_lookup_table(std::vector<std::vector<cvc5::Term>>& table);

void assertFormula(const cvc5::Term& term) const { this->solver.assertFormula(term); }

cvc5::Term getValue(const cvc5::Term& term) const { return this->solver.getValue(term); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,145 @@ TEST(BVTerm, rotl)
ASSERT_EQ(bvals, xvals);
}

// MUL, LSH, RSH, AND and OR are not tested, since they are not bijective
// non bijective operators
TEST(BVTerm, mul)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct c = a * b;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
modulus_base,
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
STerm z = x * y;

x == a.get_value();
y == b.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(z.term).getBitVectorValue();
STerm bval = STerm(c.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}

TEST(BVTerm, and)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct c = a & b;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
modulus_base,
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
STerm z = x & y;

x == a.get_value();
y == b.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(z.term).getBitVectorValue();
STerm bval = STerm(c.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}

TEST(BVTerm, or)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct c = a | b;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
modulus_base,
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = BVVar("y", &s);
STerm z = x | y;

x == a.get_value();
y == b.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(z.term).getBitVectorValue();
STerm bval = STerm(c.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}

TEST(BVTerm, shr)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = a >> 5;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
modulus_base,
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = x >> 5;

x == a.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(y.term).getBitVectorValue();
STerm bval = STerm(b.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}

TEST(BVTerm, shl)
{
StandardCircuitBuilder builder;
uint_ct a = witness_ct(&builder, static_cast<uint32_t>(fr::random_element()));
uint_ct b = a << 5;

uint32_t modulus_base = 16;
uint32_t bitvector_size = 32;
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001",
default_solver_config,
modulus_base,
bitvector_size);

STerm x = BVVar("x", &s);
STerm y = x << 5;

x == a.get_value();

ASSERT_TRUE(s.check());

std::string xvals = s.getValue(y.term).getBitVectorValue();
STerm bval = STerm(b.get_value(), &s, TermType::BVTerm);
std::string bvals = s.getValue(bval.term).getBitVectorValue();
ASSERT_EQ(bvals, xvals);
}

// This test aims to check for the absence of unintended
// behavior. If an unsupported operator is called, an info message appears in stderr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,31 @@ TEST(FFTerm, division)
ASSERT_EQ(bvals, yvals);
}

TEST(FFTerm, set_inclusion)
{
Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001");

std::vector<std::vector<cvc5::Term>> table = { { FFConst("1", &s), FFConst("2", &s), FFConst("3", &s) },
{ FFConst("4", &s), FFConst("5", &s), FFConst("6", &s) } };
cvc5::Term symbolic_table = s.create_lookup_table(table);

STerm x = FFVar("x", &s);
STerm y = FFVar("y", &s);
STerm z = FFVar("z", &s);
std::vector<STerm> tmp_vec = { x, y, z };
STerm::in_table(tmp_vec, symbolic_table);
x != 4;

ASSERT_TRUE(s.check());

std::string xval = s.getValue(x).getFiniteFieldValue();
ASSERT_EQ(xval, "1");
std::string yval = s.getValue(y).getFiniteFieldValue();
ASSERT_EQ(yval, "2");
std::string zval = s.getValue(z).getFiniteFieldValue();
ASSERT_EQ(zval, "3");
}

// This test aims to check for the absence of unintended
// behavior. If an unsupported operator is called, an info message appears in stderr
// and the value is supposed to remain unchanged.
Expand Down
Loading

0 comments on commit 6c98529

Please sign in to comment.