From a543aa53a178764bd711cd9a1c84d0d9605a5893 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Fri, 25 Mar 2022 10:56:21 +0100 Subject: [PATCH] Use callback pattern for contract state iterator --- x/wasm/keeper/genesis.go | 15 ++++----------- x/wasm/keeper/keeper.go | 13 +++++++++++-- x/wasm/keeper/legacy_querier.go | 13 ++++--------- x/wasm/types/exported_keepers.go | 2 +- 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/x/wasm/keeper/genesis.go b/x/wasm/keeper/genesis.go index 3c95e1f8cd..7fa5280b9e 100644 --- a/x/wasm/keeper/genesis.go +++ b/x/wasm/keeper/genesis.go @@ -102,15 +102,11 @@ func ExportGenesis(ctx sdk.Context, keeper *Keeper) *types.GenesisState { }) keeper.IterateContractInfo(ctx, func(addr sdk.AccAddress, contract types.ContractInfo) bool { - contractStateIterator := keeper.GetContractState(ctx, addr) var state []types.Model - for ; contractStateIterator.Valid(); contractStateIterator.Next() { - m := types.Model{ - Key: contractStateIterator.Key(), - Value: contractStateIterator.Value(), - } - state = append(state, m) - } + keeper.IterateContractState(ctx, addr, func(key, value []byte) bool { + state = append(state, types.Model{Key: key, Value: value}) + return false + }) // redact contract info contract.Created = nil genState.Contracts = append(genState.Contracts, types.Contract{ @@ -118,9 +114,6 @@ func ExportGenesis(ctx sdk.Context, keeper *Keeper) *types.GenesisState { ContractInfo: contract, ContractState: state, }) - - contractStateIterator.Close() - return false }) diff --git a/x/wasm/keeper/keeper.go b/x/wasm/keeper/keeper.go index 2f897855df..ba47392396 100644 --- a/x/wasm/keeper/keeper.go +++ b/x/wasm/keeper/keeper.go @@ -688,10 +688,19 @@ func (k Keeper) IterateContractInfo(ctx sdk.Context, cb func(sdk.AccAddress, typ } } -func (k Keeper) GetContractState(ctx sdk.Context, contractAddress sdk.AccAddress) sdk.Iterator { +// IterateContractState iterates through all elements of the key value store for the given contract address and passes +// them to the provided callback function. The callback method can return true to abort early. +func (k Keeper) IterateContractState(ctx sdk.Context, contractAddress sdk.AccAddress, cb func(key, value []byte) bool) { prefixStoreKey := types.GetContractStorePrefix(contractAddress) prefixStore := prefix.NewStore(ctx.KVStore(k.storeKey), prefixStoreKey) - return prefixStore.Iterator(nil, nil) + iter := prefixStore.Iterator(nil, nil) + defer iter.Close() + + for ; iter.Valid(); iter.Next() { + if cb(iter.Key(), iter.Value()) { + break + } + } } func (k Keeper) importContractState(ctx sdk.Context, contractAddress sdk.AccAddress, models []types.Model) error { diff --git a/x/wasm/keeper/legacy_querier.go b/x/wasm/keeper/legacy_querier.go index 78e23f0690..f12e8b0ebc 100644 --- a/x/wasm/keeper/legacy_querier.go +++ b/x/wasm/keeper/legacy_querier.go @@ -93,15 +93,10 @@ func queryContractState(ctx sdk.Context, bech, queryMethod string, data []byte, case QueryMethodContractStateAll: resultData := make([]types.Model, 0) // this returns a serialized json object (which internally encoded binary fields properly) - iter := keeper.GetContractState(ctx, contractAddr) - defer iter.Close() - - for ; iter.Valid(); iter.Next() { - resultData = append(resultData, types.Model{ - Key: iter.Key(), - Value: iter.Value(), - }) - } + keeper.IterateContractState(ctx, contractAddr, func(key, value []byte) bool { + resultData = append(resultData, types.Model{Key: key, Value: value}) + return false + }) bz, err := json.Marshal(resultData) if err != nil { return nil, sdkerrors.Wrap(sdkerrors.ErrJSONMarshal, err.Error()) diff --git a/x/wasm/types/exported_keepers.go b/x/wasm/types/exported_keepers.go index 525d219404..515b946fa8 100644 --- a/x/wasm/types/exported_keepers.go +++ b/x/wasm/types/exported_keepers.go @@ -15,7 +15,7 @@ type ViewKeeper interface { GetContractInfo(ctx sdk.Context, contractAddress sdk.AccAddress) *ContractInfo IterateContractInfo(ctx sdk.Context, cb func(sdk.AccAddress, ContractInfo) bool) IterateContractsByCode(ctx sdk.Context, codeID uint64, cb func(address sdk.AccAddress) bool) - GetContractState(ctx sdk.Context, contractAddress sdk.AccAddress) sdk.Iterator + IterateContractState(ctx sdk.Context, contractAddress sdk.AccAddress, cb func(key, value []byte) bool) GetCodeInfo(ctx sdk.Context, codeID uint64) *CodeInfo IterateCodeInfos(ctx sdk.Context, cb func(uint64, CodeInfo) bool) GetByteCode(ctx sdk.Context, codeID uint64) ([]byte, error)