diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go index fd60988506..9df87a4afa 100644 --- a/dot/rpc/modules/api.go +++ b/dot/rpc/modules/api.go @@ -19,6 +19,7 @@ type StorageAPI interface { RegisterStorageChangeChannel(sub state.StorageSubscription) (byte, error) UnregisterStorageChangeChannel(id byte) GetStateRootFromBlock(bhash *common.Hash) (*common.Hash, error) + GetKeysWithPrefix(root *common.Hash, prefix []byte) ([][]byte, error) } // BlockAPI is the interface for the block state diff --git a/dot/rpc/modules/state.go b/dot/rpc/modules/state.go index 4a10e143e6..fcc5e9741a 100644 --- a/dot/rpc/modules/state.go +++ b/dot/rpc/modules/state.go @@ -18,7 +18,9 @@ package modules import ( "encoding/hex" + "fmt" "net/http" + "strings" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/runtime" @@ -41,8 +43,10 @@ type StateChildStorageRequest struct { // StateStorageKeyRequest holds json fields type StateStorageKeyRequest struct { - Key []byte `json:"key"` - Block *common.Hash `json:"block"` + Prefix string `json:"prefix"` + Qty uint32 `json:"qty"` + AfterKey string `json:"afterKey"` + Block *common.Hash `json:"block"` } // StateRuntimeMetadataQuery is a hash value @@ -117,7 +121,7 @@ type StateStorageResponse string type StatePairResponse []interface{} // StateStorageKeysResponse field for storage keys -type StateStorageKeysResponse [][]byte +type StateStorageKeysResponse []string // StateMetadataResponse holds the metadata //TODO: Determine actual type @@ -233,10 +237,30 @@ func (sm *StateModule) GetChildStorageSize(r *http.Request, req *StateChildStora return nil } -// GetKeys isn't implemented properly yet. -func (sm *StateModule) GetKeys(r *http.Request, req *StateStorageKeyRequest, res *StateStorageKeysResponse) error { - // TODO implement change storage trie so that block hash parameter works (See issue #834) - return nil +// GetKeysPaged Returns the keys with prefix with pagination support. +func (sm *StateModule) GetKeysPaged(r *http.Request, req *StateStorageKeyRequest, res *StateStorageKeysResponse) error { + if len(req.Prefix) == 0 { + req.Prefix = "0x" + } + hPrefix, err := common.HexToBytes(req.Prefix) + if err != nil { + return err + } + keys, err := sm.storageAPI.GetKeysWithPrefix(req.Block, hPrefix) + resCount := uint32(0) + for _, k := range keys { + fKey := fmt.Sprintf("0x%x", k) + if strings.Compare(fKey, req.AfterKey) == 1 { + // sm.storageAPI.Keys sorts keys in lexicographical order, so we know that keys where strings.Compare = 1 + // are after the requested after key. + if resCount >= req.Qty { + break + } + *res = append(*res, fKey) + resCount++ + } + } + return err } // GetMetadata calls runtime Metadata_metadata function diff --git a/dot/rpc/modules/state_test.go b/dot/rpc/modules/state_test.go index 9d64b790a2..08fb9a7599 100644 --- a/dot/rpc/modules/state_test.go +++ b/dot/rpc/modules/state_test.go @@ -57,7 +57,7 @@ func TestStateModule_GetRuntimeVersion(t *testing.T) { }, } - sm, hash := setupStateModule(t) + sm, hash, _ := setupStateModule(t) randomHash, err := common.HexToHash(RandomHash) require.NoError(t, err) @@ -98,7 +98,7 @@ func TestStateModule_GetRuntimeVersion(t *testing.T) { } func TestStateModule_GetPairs(t *testing.T) { - sm, hash := setupStateModule(t) + sm, hash, _ := setupStateModule(t) randomHash, err := common.HexToHash(RandomHash) require.NoError(t, err) @@ -166,7 +166,7 @@ func TestStateModule_GetPairs(t *testing.T) { } func TestStateModule_GetStorage(t *testing.T) { - sm, hash := setupStateModule(t) + sm, hash, _ := setupStateModule(t) randomHash, err := common.HexToHash(RandomHash) require.NoError(t, err) @@ -216,7 +216,7 @@ func TestStateModule_GetStorage(t *testing.T) { } func TestStateModule_GetStorageHash(t *testing.T) { - sm, hash := setupStateModule(t) + sm, hash, _ := setupStateModule(t) randomHash, err := common.HexToHash(RandomHash) require.NoError(t, err) @@ -268,7 +268,7 @@ func TestStateModule_GetStorageHash(t *testing.T) { } func TestStateModule_GetStorageSize(t *testing.T) { - sm, hash := setupStateModule(t) + sm, hash, _ := setupStateModule(t) randomHash, err := common.HexToHash(RandomHash) require.NoError(t, err) @@ -313,7 +313,7 @@ func TestStateModule_GetStorageSize(t *testing.T) { } func TestStateModule_GetMetadata(t *testing.T) { - sm, hash := setupStateModule(t) + sm, hash, _ := setupStateModule(t) randomHash, err := common.HexToHash(RandomHash) require.NoError(t, err) @@ -353,7 +353,68 @@ func TestStateModule_GetMetadata(t *testing.T) { } } -func setupStateModule(t *testing.T) (*StateModule, *common.Hash) { +func TestStateModule_GetKeysPaged(t *testing.T) { + sm, _, stateRootHash := setupStateModule(t) + + testCases := []struct { + name string + params StateStorageKeyRequest + expected []string + }{ + {name: "allKeysNilBlockHash", + params: StateStorageKeyRequest{ + Qty: 10, + Block: nil, + }, expected: []string{"0x3a6b657931", "0x3a6b657932"}}, + {name: "allKeysTestBlockHash", + params: StateStorageKeyRequest{ + Qty: 10, + Block: stateRootHash, + }, expected: []string{"0x3a6b657931", "0x3a6b657932"}}, + {name: "prefixMatchAll", + params: StateStorageKeyRequest{ + Prefix: "0x3a6b6579", + Qty: 10, + }, expected: []string{"0x3a6b657931", "0x3a6b657932"}}, + {name: "prefixMatchOne", + params: StateStorageKeyRequest{ + Prefix: "0x3a6b657931", + Qty: 10, + }, expected: []string{"0x3a6b657931"}}, + {name: "prefixMatchNone", + params: StateStorageKeyRequest{ + Prefix: "0x00", + Qty: 10, + }, expected: nil}, + {name: "qtyOne", + params: StateStorageKeyRequest{ + Qty: 1, + }, expected: []string{"0x3a6b657931"}}, + {name: "afterKey", + params: StateStorageKeyRequest{ + Qty: 10, + AfterKey: "0x3a6b657931", + }, expected: []string{"0x3a6b657932"}}, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + var res StateStorageKeysResponse + + err := sm.GetKeysPaged(nil, &test.params, &res) + require.NoError(t, err) + + if test.expected == nil { + require.Empty(t, res) + return + } + + require.Equal(t, StateStorageKeysResponse(test.expected), res) + }) + } +} + +func setupStateModule(t *testing.T) (*StateModule, *common.Hash, *common.Hash) { // setup service net := newNetworkService(t) chain := newTestStateService(t) @@ -361,10 +422,10 @@ func setupStateModule(t *testing.T) (*StateModule, *common.Hash) { ts, err := chain.Storage.TrieState(nil) require.NoError(t, err) - err = ts.Set([]byte(`:key1`), []byte(`value1`)) - require.NoError(t, err) err = ts.Set([]byte(`:key2`), []byte(`value2`)) require.NoError(t, err) + err = ts.Set([]byte(`:key1`), []byte(`value1`)) + require.NoError(t, err) sr1, err := ts.Root() require.NoError(t, err) @@ -383,5 +444,5 @@ func setupStateModule(t *testing.T) (*StateModule, *common.Hash) { hash, _ := chain.Block.GetBlockHash(big.NewInt(2)) core := newCoreService(t, chain) - return NewStateModule(net, chain.Storage, core), hash + return NewStateModule(net, chain.Storage, core), hash, &sr1 } diff --git a/dot/rpc/websocket_test.go b/dot/rpc/websocket_test.go index dc613f3a4a..abfd9f6f60 100644 --- a/dot/rpc/websocket_test.go +++ b/dot/rpc/websocket_test.go @@ -140,3 +140,6 @@ func (m *MockStorageAPI) UnregisterStorageChangeChannel(id byte) { func (m *MockStorageAPI) GetStateRootFromBlock(bhash *common.Hash) (*common.Hash, error) { return nil, nil } +func (m *MockStorageAPI) GetKeysWithPrefix(root *common.Hash, prefix []byte) ([][]byte, error) { + return nil, nil +} diff --git a/dot/state/storage.go b/dot/state/storage.go index 7abf31dc3b..c588b28d1f 100644 --- a/dot/state/storage.go +++ b/dot/state/storage.go @@ -321,6 +321,22 @@ func (s *StorageState) Entries(hash *common.Hash) (map[string][]byte, error) { return s.tries[*hash].Entries(), nil } +// GetKeysWithPrefix returns all that match the given prefix for the given hash (or best block state root if hash is nil) in lexicographic order +func (s *StorageState) GetKeysWithPrefix(hash *common.Hash, prefix []byte) ([][]byte, error) { + if hash == nil { + sr, err := s.blockState.BestBlockStateRoot() + if err != nil { + return nil, err + } + hash = &sr + } + t := s.tries[*hash] + if t == nil { + return nil, fmt.Errorf("unable to retrieve trie with hash %x", *hash) + } + return t.GetKeysWithPrefix(prefix), nil +} + // GetStorageChild return GetChild from the trie func (s *StorageState) GetStorageChild(hash *common.Hash, keyToChild []byte) (*trie.Trie, error) { if hash == nil { diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 2791d68f53..e2b7a85d95 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -379,10 +379,14 @@ func (t *Trie) Load(data map[string]string) error { // GetKeysWithPrefix returns all keys in the trie that have the given prefix func (t *Trie) GetKeysWithPrefix(prefix []byte) [][]byte { - p := keyToNibbles(prefix) - if p[len(p)-1] == 0 { - p = p[:len(p)-1] + p := []byte{} + if len(prefix) != 0 { + p = keyToNibbles(prefix) + if p[len(p)-1] == 0 { + p = p[:len(p)-1] + } } + return t.getKeysWithPrefix(t.root, []byte{}, p, [][]byte{}) } @@ -397,6 +401,12 @@ func (t *Trie) getKeysWithPrefix(parent node, prefix, key []byte, keys [][]byte) return keys } + if len(key) <= len(p.key) { + // no prefixed keys to be found here, return + return keys + } + + key = key[len(p.key):] keys = t.getKeysWithPrefix(p.children[key[0]], append(append(prefix, p.key...), key[0]), key[1:], keys) case *leaf: keys = append(keys, nibblesToKeyLE(append(prefix, p.key...))) diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index dae45b9681..22fdc25662 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -24,7 +24,6 @@ import ( "math/rand" "os" "path/filepath" - "reflect" "strconv" "strings" "testing" @@ -553,6 +552,8 @@ func TestGetKeysWithPrefix(t *testing.T) { {key: []byte{0x07, 0x3a}, value: []byte("ramen"), op: PUT}, {key: []byte{0x07, 0x3b}, value: []byte("noodles"), op: PUT}, {key: []byte{0xf2}, value: []byte("pho"), op: PUT}, + {key: []byte(":key1"), value: []byte("value1"), op: PUT}, + {key: []byte(":key2"), value: []byte("value2"), op: PUT}, } for _, test := range tests { @@ -561,21 +562,19 @@ func TestGetKeysWithPrefix(t *testing.T) { expected := [][]byte{{0x01, 0x35}, {0x01, 0x35, 0x79}} keys := trie.GetKeysWithPrefix([]byte{0x01}) - if !reflect.DeepEqual(keys, expected) { - t.Fatalf("Fail: got %v expected %v", keys, expected) - } + require.Equal(t, expected, keys) expected = [][]byte{{0x01, 0x35}, {0x01, 0x35, 0x79}, {0x07, 0x3a}, {0x07, 0x3b}} keys = trie.GetKeysWithPrefix([]byte{0x0}) - if !reflect.DeepEqual(keys, expected) { - t.Fatalf("Fail: got %v expected %v", keys, expected) - } + require.Equal(t, expected, keys) expected = [][]byte{{0x07, 0x3a}, {0x07, 0x3b}} keys = trie.GetKeysWithPrefix([]byte{0x07, 0x30}) - if !reflect.DeepEqual(keys, expected) { - t.Fatalf("Fail: got %v expected %v", keys, expected) - } + require.Equal(t, expected, keys) + + expected = [][]byte{[]byte(":key1")} + keys = trie.GetKeysWithPrefix([]byte(":key1")) + require.Equal(t, expected, keys) } func TestNextKey(t *testing.T) {