From 4178a9756078ee80b9625d02cb1c84db09b8099c Mon Sep 17 00:00:00 2001 From: Wang Gerui Date: Sat, 6 Apr 2024 19:29:31 +0800 Subject: [PATCH] feat: extend eth_call api with state override; extend statedb and state_object with necessary code --- core/state/state_object.go | 306 +++++++++++++++++++++---------------- core/state/statedb.go | 9 ++ internal/ethapi/api.go | 63 +++++++- 3 files changed, 240 insertions(+), 138 deletions(-) diff --git a/core/state/state_object.go b/core/state/state_object.go index a5c7ce0bc964..99763db267e2 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -31,23 +31,23 @@ var emptyCodeHash = crypto.Keccak256(nil) type Code []byte -func (self Code) String() string { - return string(self) //strings.Join(Disassemble(self), " ") +func (c Code) String() string { + return string(c) //strings.Join(Disassemble(c), " ") } type Storage map[common.Hash]common.Hash -func (self Storage) String() (str string) { - for key, value := range self { +func (s Storage) String() (str string) { + for key, value := range s { str += fmt.Sprintf("%X : %X\n", key, value) } return } -func (self Storage) Copy() Storage { +func (s Storage) Copy() Storage { cpy := make(Storage) - for key, value := range self { + for key, value := range s { cpy[key] = value } @@ -79,6 +79,7 @@ type stateObject struct { cachedStorage Storage // Storage entry cache to avoid duplicate reads dirtyStorage Storage // Storage entries that need to be flushed to disk + fakeStorage Storage // Fake storage which constructed by caller for debugging purpose. // Cache flags. // When an object is marked suicided it will be delete from the trie @@ -124,199 +125,236 @@ func newObject(db *StateDB, address common.Address, data Account, onDirty func(a } // EncodeRLP implements rlp.Encoder. -func (c *stateObject) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, c.data) +func (s *stateObject) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, s.data) } // setError remembers the first non-nil error it is called with. -func (self *stateObject) setError(err error) { - if self.dbErr == nil { - self.dbErr = err +func (s *stateObject) setError(err error) { + if s.dbErr == nil { + s.dbErr = err } } -func (self *stateObject) markSuicided() { - self.suicided = true - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil +func (s *stateObject) markSuicided() { + s.suicided = true + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } } -func (c *stateObject) touch() { - c.db.journal = append(c.db.journal, touchChange{ - account: &c.address, - prev: c.touched, - prevDirty: c.onDirty == nil, +func (s *stateObject) touch() { + s.db.journal = append(s.db.journal, touchChange{ + account: &s.address, + prev: s.touched, + prevDirty: s.onDirty == nil, }) - if c.onDirty != nil { - c.onDirty(c.Address()) - c.onDirty = nil + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } - c.touched = true + s.touched = true } -func (c *stateObject) getTrie(db Database) Trie { - if c.trie == nil { +func (s *stateObject) getTrie(db Database) Trie { + if s.trie == nil { var err error - c.trie, err = db.OpenStorageTrie(c.addrHash, c.data.Root) + s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root) if err != nil { - c.trie, _ = db.OpenStorageTrie(c.addrHash, common.Hash{}) - c.setError(fmt.Errorf("can't create storage trie: %v", err)) + s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) + s.setError(fmt.Errorf("can't create storage trie: %v", err)) } } - return c.trie + return s.trie } -func (self *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { +func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { + // If the fake storage is set, only lookup the state here(in the debugging mode) + if s.fakeStorage != nil { + return s.fakeStorage[key] + } value := common.Hash{} // Load from DB in case it is missing. - enc, err := self.getTrie(db).TryGet(key[:]) + enc, err := s.getTrie(db).TryGet(key[:]) if err != nil { - self.setError(err) + s.setError(err) return common.Hash{} } if len(enc) > 0 { _, content, _, err := rlp.Split(enc) if err != nil { - self.setError(err) + s.setError(err) } value.SetBytes(content) } return value } -func (self *stateObject) GetState(db Database, key common.Hash) common.Hash { - value, exists := self.cachedStorage[key] +func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { + // If the fake storage is set, only lookup the state here(in the debugging mode) + if s.fakeStorage != nil { + return s.fakeStorage[key] + } + value, exists := s.cachedStorage[key] if exists { return value } // Load from DB in case it is missing. - enc, err := self.getTrie(db).TryGet(key[:]) + enc, err := s.getTrie(db).TryGet(key[:]) if err != nil { - self.setError(err) + s.setError(err) return common.Hash{} } if len(enc) > 0 { _, content, _, err := rlp.Split(enc) if err != nil { - self.setError(err) + s.setError(err) } value.SetBytes(content) } if (value != common.Hash{}) { - self.cachedStorage[key] = value + s.cachedStorage[key] = value } return value } // SetState updates a value in account storage. -func (self *stateObject) SetState(db Database, key, value common.Hash) { - self.db.journal = append(self.db.journal, storageChange{ - account: &self.address, +func (s *stateObject) SetState(db Database, key, value common.Hash) { + // If the fake storage is set, put the temporary state update here. + if s.fakeStorage != nil { + s.fakeStorage[key] = value + return + } + // If the new value is the same as old, don't set + prev := s.GetState(db, key) + if prev == value { + return + } + // New value is different, update and journal the change + s.db.journal = append(s.db.journal, storageChange{ + account: &s.address, key: key, - prevalue: self.GetState(db, key), + prevalue: prev, }) - self.setState(key, value) + s.setState(key, value) +} + +// SetStorage replaces the entire state storage with the given one. +// +// After this function is called, all original state will be ignored and state +// lookup only happens in the fake state storage. +// +// Note this function should only be used for debugging purpose. +func (s *stateObject) SetStorage(storage map[common.Hash]common.Hash) { + // Allocate fake storage if it's nil. + if s.fakeStorage == nil { + s.fakeStorage = make(Storage) + } + for key, value := range storage { + s.fakeStorage[key] = value + } + // Don't bother journal since this function should only be used for + // debugging and the `fake` storage won't be committed to database. } -func (self *stateObject) setState(key, value common.Hash) { - self.cachedStorage[key] = value - self.dirtyStorage[key] = value +func (s *stateObject) setState(key, value common.Hash) { + s.cachedStorage[key] = value + s.dirtyStorage[key] = value - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } } // updateTrie writes cached storage modifications into the object's storage trie. -func (self *stateObject) updateTrie(db Database) Trie { - tr := self.getTrie(db) - for key, value := range self.dirtyStorage { - delete(self.dirtyStorage, key) +func (s *stateObject) updateTrie(db Database) Trie { + tr := s.getTrie(db) + for key, value := range s.dirtyStorage { + delete(s.dirtyStorage, key) if (value == common.Hash{}) { - self.setError(tr.TryDelete(key[:])) + s.setError(tr.TryDelete(key[:])) continue } // Encoding []byte cannot fail, ok to ignore the error. v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) - self.setError(tr.TryUpdate(key[:], v)) + s.setError(tr.TryUpdate(key[:], v)) } return tr } // UpdateRoot sets the trie root to the current root hash of -func (self *stateObject) updateRoot(db Database) { - self.updateTrie(db) - self.data.Root = self.trie.Hash() +func (s *stateObject) updateRoot(db Database) { + s.updateTrie(db) + s.data.Root = s.trie.Hash() } // CommitTrie the storage trie of the object to dwb. // This updates the trie root. -func (self *stateObject) CommitTrie(db Database) error { - self.updateTrie(db) - if self.dbErr != nil { - return self.dbErr +func (s *stateObject) CommitTrie(db Database) error { + s.updateTrie(db) + if s.dbErr != nil { + return s.dbErr } - root, err := self.trie.Commit(nil) + root, err := s.trie.Commit(nil) if err == nil { - self.data.Root = root + s.data.Root = root } return err } // AddBalance removes amount from c's balance. // It is used to add funds to the destination account of a transfer. -func (c *stateObject) AddBalance(amount *big.Int) { +func (s *stateObject) AddBalance(amount *big.Int) { // EIP158: We must check emptiness for the objects such that the account // clearing (0,0,0 objects) can take effect. if amount.Sign() == 0 { - if c.empty() { - c.touch() + if s.empty() { + s.touch() } return } - c.SetBalance(new(big.Int).Add(c.Balance(), amount)) + s.SetBalance(new(big.Int).Add(s.Balance(), amount)) } // SubBalance removes amount from c's balance. // It is used to remove funds from the origin account of a transfer. -func (c *stateObject) SubBalance(amount *big.Int) { +func (s *stateObject) SubBalance(amount *big.Int) { if amount.Sign() == 0 { return } - c.SetBalance(new(big.Int).Sub(c.Balance(), amount)) + s.SetBalance(new(big.Int).Sub(s.Balance(), amount)) } -func (self *stateObject) SetBalance(amount *big.Int) { - self.db.journal = append(self.db.journal, balanceChange{ - account: &self.address, - prev: new(big.Int).Set(self.data.Balance), +func (s *stateObject) SetBalance(amount *big.Int) { + s.db.journal = append(s.db.journal, balanceChange{ + account: &s.address, + prev: new(big.Int).Set(s.data.Balance), }) - self.setBalance(amount) + s.setBalance(amount) } -func (self *stateObject) setBalance(amount *big.Int) { - self.data.Balance = amount - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil +func (s *stateObject) setBalance(amount *big.Int) { + s.data.Balance = amount + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } } -func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { - stateObject := newObject(db, self.address, self.data, onDirty) - if self.trie != nil { - stateObject.trie = db.db.CopyTrie(self.trie) +func (s *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { + stateObject := newObject(db, s.address, s.data, onDirty) + if s.trie != nil { + stateObject.trie = db.db.CopyTrie(s.trie) } - stateObject.code = self.code - stateObject.dirtyStorage = self.dirtyStorage.Copy() - stateObject.cachedStorage = self.dirtyStorage.Copy() - stateObject.suicided = self.suicided - stateObject.dirtyCode = self.dirtyCode - stateObject.deleted = self.deleted + stateObject.code = s.code + stateObject.dirtyStorage = s.dirtyStorage.Copy() + stateObject.cachedStorage = s.dirtyStorage.Copy() + stateObject.suicided = s.suicided + stateObject.dirtyCode = s.dirtyCode + stateObject.deleted = s.deleted return stateObject } @@ -325,81 +363,81 @@ func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address) // // Returns the address of the contract/account -func (c *stateObject) Address() common.Address { - return c.address +func (s *stateObject) Address() common.Address { + return s.address } // Code returns the contract code associated with this object, if any. -func (self *stateObject) Code(db Database) []byte { - if self.code != nil { - return self.code +func (s *stateObject) Code(db Database) []byte { + if s.code != nil { + return s.code } - if bytes.Equal(self.CodeHash(), emptyCodeHash) { + if bytes.Equal(s.CodeHash(), emptyCodeHash) { return nil } - code, err := db.ContractCode(self.addrHash, common.BytesToHash(self.CodeHash())) + code, err := db.ContractCode(s.addrHash, common.BytesToHash(s.CodeHash())) if err != nil { - self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err)) + s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) } - self.code = code + s.code = code return code } -func (self *stateObject) SetCode(codeHash common.Hash, code []byte) { - prevcode := self.Code(self.db.db) - self.db.journal = append(self.db.journal, codeChange{ - account: &self.address, - prevhash: self.CodeHash(), +func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { + prevcode := s.Code(s.db.db) + s.db.journal = append(s.db.journal, codeChange{ + account: &s.address, + prevhash: s.CodeHash(), prevcode: prevcode, }) - self.setCode(codeHash, code) + s.setCode(codeHash, code) } -func (self *stateObject) setCode(codeHash common.Hash, code []byte) { - self.code = code - self.data.CodeHash = codeHash[:] - self.dirtyCode = true - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil +func (s *stateObject) setCode(codeHash common.Hash, code []byte) { + s.code = code + s.data.CodeHash = codeHash[:] + s.dirtyCode = true + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } } -func (self *stateObject) SetNonce(nonce uint64) { - self.db.journal = append(self.db.journal, nonceChange{ - account: &self.address, - prev: self.data.Nonce, +func (s *stateObject) SetNonce(nonce uint64) { + s.db.journal = append(s.db.journal, nonceChange{ + account: &s.address, + prev: s.data.Nonce, }) - self.setNonce(nonce) + s.setNonce(nonce) } -func (self *stateObject) setNonce(nonce uint64) { - self.data.Nonce = nonce - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil +func (s *stateObject) setNonce(nonce uint64) { + s.data.Nonce = nonce + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } } -func (self *stateObject) CodeHash() []byte { - return self.data.CodeHash +func (s *stateObject) CodeHash() []byte { + return s.data.CodeHash } -func (self *stateObject) Balance() *big.Int { - return self.data.Balance +func (s *stateObject) Balance() *big.Int { + return s.data.Balance } -func (self *stateObject) Nonce() uint64 { - return self.data.Nonce +func (s *stateObject) Nonce() uint64 { + return s.data.Nonce } -func (self *stateObject) Root() common.Hash { - return self.data.Root +func (s *stateObject) Root() common.Hash { + return s.data.Root } // Never called, but must be present to allow stateObject to be used // as a vm.Account interface that also satisfies the vm.ContractRef // interface. Interfaces are awesome. -func (self *stateObject) Value() *big.Int { +func (s *stateObject) Value() *big.Int { panic("Value on stateObject should never be called") } diff --git a/core/state/statedb.go b/core/state/statedb.go index 0afd7554dcee..e73273a71069 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -387,6 +387,15 @@ func (self *StateDB) SetState(addr common.Address, key, value common.Hash) { } } +// SetStorage replaces the entire storage for the specified account with given +// storage. This function should only be used for debugging. +func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) { + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetStorage(storage) + } +} + // Suicide marks the given account as suicided. // This clears the account balance. // diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index 886a495eb911..272e64ff7b55 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -653,6 +653,58 @@ func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.A return res[:], state.Error() } +// OverrideAccount indicates the overriding fields of account during the execution +// of a message call. +// Note, state and stateDiff can't be specified at the same time. If state is +// set, message execution will only use the data in the given state. Otherwise +// if statDiff is set, all diff will be applied first and then execute the call +// message. +type OverrideAccount struct { + Nonce *hexutil.Uint64 `json:"nonce"` + Code *hexutil.Bytes `json:"code"` + Balance **hexutil.Big `json:"balance"` + State *map[common.Hash]common.Hash `json:"state"` + StateDiff *map[common.Hash]common.Hash `json:"stateDiff"` +} + +// StateOverride is the collection of overridden accounts. +type StateOverride map[common.Address]OverrideAccount + +// Apply overrides the fields of specified accounts into the given state. +func (diff *StateOverride) Apply(state *state.StateDB) error { + if diff == nil { + return nil + } + for addr, account := range *diff { + // Override account nonce. + if account.Nonce != nil { + state.SetNonce(addr, uint64(*account.Nonce)) + } + // Override account(contract) code. + if account.Code != nil { + state.SetCode(addr, *account.Code) + } + // Override account balance. + if account.Balance != nil { + state.SetBalance(addr, (*big.Int)(*account.Balance)) + } + if account.State != nil && account.StateDiff != nil { + return fmt.Errorf("account %s has both 'state' and 'stateDiff'", addr.Hex()) + } + // Replace entire state if caller requires. + if account.State != nil { + state.SetStorage(addr, *account.State) + } + // Apply state diff into specified accounts. + if account.StateDiff != nil { + for key, value := range *account.StateDiff { + state.SetState(addr, key, value) + } + } + } + return nil +} + func (s *PublicBlockChainAPI) GetBlockSignersByHash(ctx context.Context, blockHash common.Hash) ([]common.Address, error) { block, err := s.b.GetBlock(ctx, blockHash) if err != nil || block == nil { @@ -1113,13 +1165,16 @@ type CallArgs struct { Data hexutil.Bytes `json:"data"` } -func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, vmCfg vm.Config, timeout time.Duration) ([]byte, uint64, bool, error, error) { +func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, vmCfg vm.Config, timeout time.Duration) ([]byte, uint64, bool, error, error) { defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) statedb, header, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if statedb == nil || err != nil { return nil, 0, false, err, nil } + if err := overrides.Apply(statedb); err != nil { + return nil, 0, false, err, nil + } // Set sender address or use a default if none specified addr := args.From if addr == (common.Address{}) { @@ -1229,12 +1284,12 @@ func (e *revertError) ErrorData() interface{} { // Call executes the given transaction on the state for the given block number. // It doesn't make and changes in the state/blockchain and is useful to execute and retrieve values. -func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNrOrHash *rpc.BlockNumberOrHash) (hexutil.Bytes, error) { +func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *StateOverride) (hexutil.Bytes, error) { if blockNrOrHash == nil { latest := rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber) blockNrOrHash = &latest } - result, _, failed, err, vmErr := DoCall(ctx, s.b, args, *blockNrOrHash, vm.Config{}, 5*time.Second) + result, _, failed, err, vmErr := DoCall(ctx, s.b, args, *blockNrOrHash, overrides, vm.Config{}, 5*time.Second) if err != nil { return nil, err } @@ -1275,7 +1330,7 @@ func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNrOrHash executable := func(gas uint64) (bool, []byte, error, error) { args.Gas = hexutil.Uint64(gas) - res, _, failed, err, vmErr := DoCall(ctx, b, args, blockNrOrHash, vm.Config{}, 0) + res, _, failed, err, vmErr := DoCall(ctx, b, args, blockNrOrHash, nil, vm.Config{}, 0) if err != nil { if errors.Is(err, vm.ErrOutOfGas) || errors.Is(err, core.ErrIntrinsicGas) { return false, nil, nil, nil // Special case, raise gas limit