Skip to content

Commit

Permalink
core/state: better randomized testing (postcheck) on journalling (eth…
Browse files Browse the repository at this point in the history
…ereum#29627)

This PR fixes some flaws with the existing tests.

The randomized testing (TestSnapshotRandom) executes a series of steps which modify the state and create journal-events. Later on, we compare the forward-going-states against the backwards-unrolling-journal-states, and check that they are identical.

The "identical" check is performed using various accessors. It turned out that we failed to check some things: 
- the accesslist contents
- the transient storage contents
- the 'newContract' flag
- the dirty storage map

This change adds these new checks
  • Loading branch information
holiman authored and stwiname committed Sep 9, 2024
1 parent 9e24770 commit 78e9b38
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 17 deletions.
35 changes: 35 additions & 0 deletions core/state/access_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
package state

import (
"fmt"
"maps"
"slices"
"strings"

"github.com/ethereum/go-ethereum/common"
)
Expand Down Expand Up @@ -130,3 +133,35 @@ func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) {
func (al *accessList) DeleteAddress(address common.Address) {
delete(al.addresses, address)
}

// Equal returns true if the two access lists are identical
func (al *accessList) Equal(other *accessList) bool {
if !maps.Equal(al.addresses, other.addresses) {
return false
}
return slices.EqualFunc(al.slots, other.slots,
func(m map[common.Hash]struct{}, m2 map[common.Hash]struct{}) bool {
return maps.Equal(m, m2)
})
}

// PrettyPrint prints the contents of the access list in a human-readable form
func (al *accessList) PrettyPrint() string {
out := new(strings.Builder)
var sortedAddrs []common.Address
for addr := range al.addresses {
sortedAddrs = append(sortedAddrs, addr)
}
slices.SortFunc(sortedAddrs, common.Address.Cmp)
for _, addr := range sortedAddrs {
idx := al.addresses[addr]
fmt.Fprintf(out, "%#x : (idx %d)\n", addr, idx)
if idx >= 0 {
slotmap := al.slots[idx]
for h := range slotmap {
fmt.Fprintf(out, " %#x\n", h)
}
}
}
return out.String()
}
24 changes: 12 additions & 12 deletions core/state/state_object.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,22 +459,22 @@ func (s *stateObject) setBalance(amount *uint256.Int) {

func (s *stateObject) deepCopy(db *StateDB) *stateObject {
obj := &stateObject{
db: db,
address: s.address,
addrHash: s.addrHash,
origin: s.origin,
data: s.data,
db: db,
address: s.address,
addrHash: s.addrHash,
origin: s.origin,
data: s.data,
code: s.code,
originStorage: s.originStorage.Copy(),
pendingStorage: s.pendingStorage.Copy(),
dirtyStorage: s.dirtyStorage.Copy(),
dirtyCode: s.dirtyCode,
selfDestructed: s.selfDestructed,
newContract: s.newContract,
}
if s.trie != nil {
obj.trie = db.db.CopyTrie(s.trie)
}
obj.code = s.code
obj.originStorage = s.originStorage.Copy()
obj.pendingStorage = s.pendingStorage.Copy()
obj.dirtyStorage = s.dirtyStorage.Copy()
obj.dirtyCode = s.dirtyCode
obj.selfDestructed = s.selfDestructed
obj.newContract = s.newContract
return obj
}

Expand Down
68 changes: 66 additions & 2 deletions core/state/statedb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ import (
"encoding/binary"
"errors"
"fmt"
"maps"
"math"
"math/rand"
"reflect"
"slices"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -557,10 +559,14 @@ func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.H
if err != nil {
return err
}
it := trie.NewIterator(trieIt)
var (
it = trie.NewIterator(trieIt)
visited = make(map[common.Hash]bool)
)

for it.Next() {
key := common.BytesToHash(s.trie.GetKey(it.Key))
visited[key] = true
if value, dirty := so.dirtyStorage[key]; dirty {
if !cb(key, value) {
return nil
Expand Down Expand Up @@ -600,6 +606,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
// Check newContract-flag
if obj := state.getStateObject(addr); obj != nil {
checkeq("IsNewContract", obj.newContract, checkstate.getStateObject(addr).newContract)
}
// Check storage.
if obj := state.getStateObject(addr); obj != nil {
forEachStorage(state, addr, func(key, value common.Hash) bool {
Expand All @@ -608,12 +618,49 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
forEachStorage(checkstate, addr, func(key, value common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
})
other := checkstate.getStateObject(addr)
// Check dirty storage which is not in trie
if !maps.Equal(obj.dirtyStorage, other.dirtyStorage) {
print := func(dirty map[common.Hash]common.Hash) string {
var keys []common.Hash
out := new(strings.Builder)
for key := range dirty {
keys = append(keys, key)
}
slices.SortFunc(keys, common.Hash.Cmp)
for i, key := range keys {
fmt.Fprintf(out, " %d. %v %v\n", i, key, dirty[key])
}
return out.String()
}
return fmt.Errorf("dirty storage err, have\n%v\nwant\n%v",
print(obj.dirtyStorage),
print(other.dirtyStorage))
}
}
// Check transient storage.
{
have := state.transientStorage
want := checkstate.transientStorage
eq := maps.EqualFunc(have, want,
func(a Storage, b Storage) bool {
return maps.Equal(a, b)
})
if !eq {
return fmt.Errorf("transient storage differs ,have\n%v\nwant\n%v",
have.PrettyPrint(),
want.PrettyPrint())
}
}
if err != nil {
return err
}
}

if !checkstate.accessList.Equal(state.accessList) { // Check access lists
return fmt.Errorf("AccessLists are wrong, have \n%v\nwant\n%v",
checkstate.accessList.PrettyPrint(),
state.accessList.PrettyPrint())
}
if state.GetRefund() != checkstate.GetRefund() {
return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
state.GetRefund(), checkstate.GetRefund())
Expand All @@ -622,6 +669,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}))
}
if !maps.Equal(state.journal.dirties, checkstate.journal.dirties) {
getKeys := func(dirty map[common.Address]int) string {
var keys []common.Address
out := new(strings.Builder)
for key := range dirty {
keys = append(keys, key)
}
slices.SortFunc(keys, common.Address.Cmp)
for i, key := range keys {
fmt.Fprintf(out, " %d. %v\n", i, key)
}
return out.String()
}
have := getKeys(state.journal.dirties)
want := getKeys(checkstate.journal.dirties)
return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", have, want)
}
return nil
}

Expand Down
43 changes: 40 additions & 3 deletions core/state/transient_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
package state

import (
"fmt"
"slices"
"strings"

"github.com/ethereum/go-ethereum/common"
)

Expand All @@ -30,10 +34,19 @@ func newTransientStorage() transientStorage {

// Set sets the transient-storage `value` for `key` at the given `addr`.
func (t transientStorage) Set(addr common.Address, key, value common.Hash) {
if _, ok := t[addr]; !ok {
t[addr] = make(Storage)
if value == (common.Hash{}) { // this is a 'delete'
if _, ok := t[addr]; ok {
delete(t[addr], key)
if len(t[addr]) == 0 {
delete(t, addr)
}
}
} else {
if _, ok := t[addr]; !ok {
t[addr] = make(Storage)
}
t[addr][key] = value
}
t[addr][key] = value
}

// Get gets the transient storage for `key` at the given `addr`.
Expand All @@ -53,3 +66,27 @@ func (t transientStorage) Copy() transientStorage {
}
return storage
}

// PrettyPrint prints the contents of the access list in a human-readable form
func (t transientStorage) PrettyPrint() string {
out := new(strings.Builder)
var sortedAddrs []common.Address
for addr := range t {
sortedAddrs = append(sortedAddrs, addr)
slices.SortFunc(sortedAddrs, common.Address.Cmp)
}

for _, addr := range sortedAddrs {
fmt.Fprintf(out, "%#x:", addr)
var sortedKeys []common.Hash
storage := t[addr]
for key := range storage {
sortedKeys = append(sortedKeys, key)
}
slices.SortFunc(sortedKeys, common.Hash.Cmp)
for _, key := range sortedKeys {
fmt.Fprintf(out, " %X : %X\n", key, storage[key])
}
}
return out.String()
}

0 comments on commit 78e9b38

Please sign in to comment.