Skip to content

Commit

Permalink
fix(trie): use Merkle value for database keys
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Aug 4, 2022
1 parent 4f311a5 commit 849f798
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 155 deletions.
73 changes: 37 additions & 36 deletions dot/state/pruner/pruner.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,33 +51,34 @@ type Config struct {

// Pruner is implemented by FullNode and ArchiveNode.
type Pruner interface {
StoreJournalRecord(deletedHashesSet, insertedHashesSet map[common.Hash]struct{},
StoreJournalRecord(deletedMerkleValues, insertedMerkleValues map[string]struct{},
blockHash common.Hash, blockNum int64) error
}

// ArchiveNode is a no-op since we don't prune nodes in archive mode.
type ArchiveNode struct{}

// StoreJournalRecord for archive node doesn't do anything.
func (*ArchiveNode) StoreJournalRecord(_, _ map[common.Hash]struct{},
func (*ArchiveNode) StoreJournalRecord(_, _ map[string]struct{},
_ common.Hash, _ int64) error {
return nil
}

type deathRecord struct {
blockHash common.Hash
deletedKeys map[common.Hash]int64 // Mapping from deleted key hash to block number.
blockHash common.Hash
deletedMerkleValueToBlockNumber map[string]int64
}

type deathRow []*deathRecord

// FullNode stores state trie diff and allows online state trie pruning
type FullNode struct {
logger log.LeveledLogger
deathList []deathRow
storageDB chaindb.Database
journalDB chaindb.Database
deathIndex map[common.Hash]int64 // Mapping from deleted key hash to block number.
logger log.LeveledLogger
deathList []deathRow
storageDB chaindb.Database
journalDB chaindb.Database
// deathIndex is the mapping from deleted node Merkle value to block number.
deathIndex map[string]int64
// pendingNumber is the block number to be pruned.
// Initial value is set to 1 and is incremented after every block pruning.
pendingNumber int64
Expand All @@ -88,31 +89,31 @@ type FullNode struct {
type journalRecord struct {
// blockHash of the block corresponding to journal record
blockHash common.Hash
// Hash of keys that are inserted into state trie of the block
insertedHashesSet map[common.Hash]struct{}
// Hash of keys that are deleted from state trie of the block
deletedHashesSet map[common.Hash]struct{}
// Merkle values of nodes inserted in the state trie of the block
insertedMerkleValues map[string]struct{}
// Merkle values of nodes deleted from the state trie of the block
deletedMerkleValues map[string]struct{}
}

type journalKey struct {
blockNum int64
blockHash common.Hash
}

func newJournalRecord(hash common.Hash, insertedHashesSet,
deletedHashesSet map[common.Hash]struct{}) *journalRecord {
func newJournalRecord(hash common.Hash, insertedMerkleValues,
deletedMerkleValues map[string]struct{}) *journalRecord {
return &journalRecord{
blockHash: hash,
insertedHashesSet: insertedHashesSet,
deletedHashesSet: deletedHashesSet,
blockHash: hash,
insertedMerkleValues: insertedMerkleValues,
deletedMerkleValues: deletedMerkleValues,
}
}

// NewFullNode creates a Pruner for full node.
func NewFullNode(db, storageDB chaindb.Database, retainBlocks int64, l log.LeveledLogger) (Pruner, error) {
p := &FullNode{
deathList: make([]deathRow, 0),
deathIndex: make(map[common.Hash]int64),
deathIndex: make(map[string]int64),
storageDB: storageDB,
journalDB: chaindb.NewTable(db, journalPrefix),
retainBlocks: retainBlocks,
Expand Down Expand Up @@ -140,9 +141,9 @@ func NewFullNode(db, storageDB chaindb.Database, retainBlocks int64, l log.Level
}

// StoreJournalRecord stores journal record into DB and adds deathRow into deathList
func (p *FullNode) StoreJournalRecord(deletedHashesSet, insertedHashesSet map[common.Hash]struct{},
func (p *FullNode) StoreJournalRecord(deletedMerkleValues, insertedMerkleValues map[string]struct{},
blockHash common.Hash, blockNum int64) error {
jr := newJournalRecord(blockHash, insertedHashesSet, deletedHashesSet)
jr := newJournalRecord(blockHash, insertedMerkleValues, deletedMerkleValues)

key := &journalKey{blockNum, blockHash}
err := p.storeJournal(key, jr)
Expand All @@ -168,13 +169,13 @@ func (p *FullNode) addDeathRow(jr *journalRecord, blockNum int64) {
return
}

p.processInsertedKeys(jr.insertedHashesSet, jr.blockHash)
p.processInsertedKeys(jr.insertedMerkleValues, jr.blockHash)

// add deleted keys from journal to death index
deletedKeys := make(map[common.Hash]int64, len(jr.deletedHashesSet))
for k := range jr.deletedHashesSet {
// add deleted node Merkle values from journal to death index
deletedMerkleValueToBlockNumber := make(map[string]int64, len(jr.deletedMerkleValues))
for k := range jr.deletedMerkleValues {
p.deathIndex[k] = blockNum
deletedKeys[k] = blockNum
deletedMerkleValueToBlockNumber[k] = blockNum
}

blockIndex := blockNum - p.pendingNumber
Expand All @@ -183,25 +184,25 @@ func (p *FullNode) addDeathRow(jr *journalRecord, blockNum int64) {
}

record := &deathRecord{
blockHash: jr.blockHash,
deletedKeys: deletedKeys,
blockHash: jr.blockHash,
deletedMerkleValueToBlockNumber: deletedMerkleValueToBlockNumber,
}

// add deathRow to deathList
p.deathList[blockIndex] = append(p.deathList[blockIndex], record)
}

// Remove re-inserted keys
func (p *FullNode) processInsertedKeys(insertedHashesSet map[common.Hash]struct{}, blockHash common.Hash) {
for k := range insertedHashesSet {
func (p *FullNode) processInsertedKeys(insertedMerkleValues map[string]struct{}, blockHash common.Hash) {
for k := range insertedMerkleValues {
num, ok := p.deathIndex[k]
if !ok {
continue
}
records := p.deathList[num-p.pendingNumber]
for _, v := range records {
if v.blockHash == blockHash {
delete(v.deletedKeys, k)
delete(v.deletedMerkleValueToBlockNumber, k)
}
}
delete(p.deathIndex, k)
Expand Down Expand Up @@ -229,14 +230,14 @@ func (p *FullNode) start() {

sdbBatch := p.storageDB.NewBatch()
for _, record := range row {
err := p.deleteKeys(sdbBatch, record.deletedKeys)
err := p.deleteKeys(sdbBatch, record.deletedMerkleValueToBlockNumber)
if err != nil {
p.logger.Warnf("failed to prune keys for block number %d: %s", blockNum, err)
sdbBatch.Reset()
return
}

for k := range record.deletedKeys {
for k := range record.deletedMerkleValueToBlockNumber {
delete(p.deathIndex, k)
}
}
Expand Down Expand Up @@ -373,9 +374,9 @@ func (p *FullNode) getLastPrunedIndex() (int64, error) {
return blockNum, nil
}

func (*FullNode) deleteKeys(b chaindb.Batch, nodesHash map[common.Hash]int64) error {
for k := range nodesHash {
err := b.Del(k.ToBytes())
func (*FullNode) deleteKeys(b chaindb.Batch, deletedMerkleValueToBlockNumber map[string]int64) error {
for merkleValue := range deletedMerkleValueToBlockNumber {
err := b.Del([]byte(merkleValue))
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions dot/state/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ func (s *StorageState) StoreTrie(ts *rtstorage.TrieState, header *types.Header)
}

if header != nil {
insertedNodeHashes, err := ts.GetInsertedNodeHashes()
insertedMerkleValues, err := ts.GetInsertedMerkleValues()
if err != nil {
return fmt.Errorf("failed to get state trie inserted keys: block %s %w", header.Hash(), err)
}

deletedNodeHashes := ts.GetDeletedNodeHashes()
err = s.pruner.StoreJournalRecord(deletedNodeHashes, insertedNodeHashes, header.Hash(), int64(header.Number))
deletedMerkleValues := ts.GetDeletedMerkleValues()
err = s.pruner.StoreJournalRecord(deletedMerkleValues, insertedMerkleValues, header.Hash(), int64(header.Number))
if err != nil {
return err
}
Expand Down
10 changes: 5 additions & 5 deletions internal/trie/node/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

// EncodeAndHash returns the encoding of the node and
// the Merkle value of the node.
func (n *Node) EncodeAndHash() (encoding, hash []byte, err error) {
func (n *Node) EncodeAndHash() (encoding, merkleValue []byte, err error) {
if !n.Dirty && n.Encoding != nil && n.HashDigest != nil {
return n.Encoding, n.HashDigest, nil
}
Expand All @@ -37,8 +37,8 @@ func (n *Node) EncodeAndHash() (encoding, hash []byte, err error) {
if buffer.Len() < 32 {
n.HashDigest = make([]byte, len(bufferBytes))
copy(n.HashDigest, bufferBytes)
hash = n.HashDigest // no need to copy
return encoding, hash, nil
merkleValue = n.HashDigest // no need to copy
return encoding, merkleValue, nil
}

// Note: using the sync.Pool's buffer is useful here.
Expand All @@ -47,9 +47,9 @@ func (n *Node) EncodeAndHash() (encoding, hash []byte, err error) {
return nil, nil, err
}
n.HashDigest = hashArray[:]
hash = n.HashDigest // no need to copy
merkleValue = n.HashDigest // no need to copy

return encoding, hash, nil
return encoding, merkleValue, nil
}

// EncodeAndHashRoot returns the encoding of the root node and
Expand Down
14 changes: 7 additions & 7 deletions lib/runtime/storage/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,18 +271,18 @@ func (s *TrieState) LoadCodeHash() (common.Hash, error) {
return common.Blake2bHash(code)
}

// GetInsertedNodeHashes returns a set of hashes of all nodes
// that were inserted into state trie since the last block produced.
func (s *TrieState) GetInsertedNodeHashes() (hashesSet map[common.Hash]struct{}, err error) {
// GetInsertedMerkleValues returns the set of all node Merkle values inserted
// into the state trie since the last block produced.
func (s *TrieState) GetInsertedMerkleValues() (merkleValues map[string]struct{}, err error) {
s.lock.RLock()
defer s.lock.RUnlock()
return s.t.GetInsertedNodeHashes()
return s.t.GetInsertedMerkleValues()
}

// GetDeletedNodeHashes returns the hash of nodes that were deleted
// GetDeletedMerkleValues returns the set of all node Merkle values deleted
// from the state trie since the last block produced.
func (s *TrieState) GetDeletedNodeHashes() (hashesSet map[common.Hash]struct{}) {
func (s *TrieState) GetDeletedMerkleValues() (merkleValues map[string]struct{}) {
s.lock.RLock()
defer s.lock.RUnlock()
return s.t.GetDeletedNodeHashes()
return s.t.GetDeletedMerkleValues()
}
42 changes: 20 additions & 22 deletions lib/trie/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,37 +379,35 @@ func (t *Trie) writeDirtyNode(db chaindb.Batch, n *Node) (err error) {
return nil
}

// GetInsertedNodeHashes returns a set of hashes with all
// the hashes of all nodes that were inserted in the state trie
// since the last snapshot.
// We need to compute the hash values of each newly inserted node.
func (t *Trie) GetInsertedNodeHashes() (hashesSet map[common.Hash]struct{}, err error) {
hashesSet = make(map[common.Hash]struct{})
err = t.getInsertedNodeHashesAtNode(t.root, hashesSet)
// GetInsertedMerkleValues returns the set of node Merkle values
// for each node that was inserted in the state trie since the last snapshot.
func (t *Trie) GetInsertedMerkleValues() (merkleValues map[string]struct{}, err error) {
merkleValues = make(map[string]struct{})
err = t.getInsertedNodeHashesAtNode(t.root, merkleValues)
if err != nil {
return nil, err
}
return hashesSet, nil
return merkleValues, nil
}

func (t *Trie) getInsertedNodeHashesAtNode(n *Node, hashes map[common.Hash]struct{}) (err error) {
func (t *Trie) getInsertedNodeHashesAtNode(n *Node, merkleValues map[string]struct{}) (err error) {
if n == nil || !n.Dirty {
return nil
}

var hash []byte
var merkleValue []byte
if n == t.root {
_, hash, err = n.EncodeAndHashRoot()
_, merkleValue, err = n.EncodeAndHashRoot()
} else {
_, hash, err = n.EncodeAndHash()
_, merkleValue, err = n.EncodeAndHash()
}
if err != nil {
return fmt.Errorf(
"cannot encode and hash node with hash 0x%x: %w",
n.HashDigest, err)
}

hashes[common.BytesToHash(hash)] = struct{}{}
merkleValues[string(merkleValue)] = struct{}{}

if n.Kind() != node.Branch {
return nil
Expand All @@ -420,7 +418,7 @@ func (t *Trie) getInsertedNodeHashesAtNode(n *Node, hashes map[common.Hash]struc
continue
}

err := t.getInsertedNodeHashesAtNode(child, hashes)
err := t.getInsertedNodeHashesAtNode(child, merkleValues)
if err != nil {
// Note: do not wrap error since this is called recursively.
return err
Expand All @@ -430,13 +428,13 @@ func (t *Trie) getInsertedNodeHashesAtNode(n *Node, hashes map[common.Hash]struc
return nil
}

// GetDeletedNodeHashes returns a set of all the hashes of nodes that were
// deleted from the trie since the last snapshot was made.
// The returned set is a copy of the internal set to prevent data races.
func (t *Trie) GetDeletedNodeHashes() (hashesSet map[common.Hash]struct{}) {
hashesSet = make(map[common.Hash]struct{}, len(t.deletedKeys))
for k := range t.deletedKeys {
hashesSet[k] = struct{}{}
// GetDeletedMerkleValues returns a set of all the node Merkle values for each
// node that was deleted from the trie since the last snapshot was made.
// The returned set is a copy of the internal set to prevent data corruption.
func (t *Trie) GetDeletedMerkleValues() (merkleValues map[string]struct{}) {
merkleValues = make(map[string]struct{}, len(t.deletedMerkleValues))
for k := range t.deletedMerkleValues {
merkleValues[k] = struct{}{}
}
return hashesSet
return merkleValues
}
Loading

0 comments on commit 849f798

Please sign in to comment.