Skip to content

Commit

Permalink
Support self calling contract on instantiation (#300)
Browse files Browse the repository at this point in the history
* Support self calling contract on instantiation

* Review feedback

* Review feedback
  • Loading branch information
alpe authored Nov 9, 2020
1 parent fbd7168 commit 4fb3a50
Show file tree
Hide file tree
Showing 15 changed files with 299 additions and 89 deletions.
4 changes: 2 additions & 2 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ func NewWasmApp(logger log.Logger, db dbm.DB, traceStore io.Writer, loadLatest b
distr.NewAppModule(appCodec, app.distrKeeper, app.accountKeeper, app.bankKeeper, app.stakingKeeper),
staking.NewAppModule(appCodec, app.stakingKeeper, app.accountKeeper, app.bankKeeper),
upgrade.NewAppModule(app.upgradeKeeper),
wasm.NewAppModule(app.wasmKeeper),
wasm.NewAppModule(&app.wasmKeeper),
evidence.NewAppModule(app.evidenceKeeper),
ibc.NewAppModule(app.ibcKeeper),
params.NewAppModule(app.paramsKeeper),
Expand Down Expand Up @@ -472,7 +472,7 @@ func NewWasmApp(logger log.Logger, db dbm.DB, traceStore io.Writer, loadLatest b
distr.NewAppModule(appCodec, app.distrKeeper, app.accountKeeper, app.bankKeeper, app.stakingKeeper),
slashing.NewAppModule(appCodec, app.slashingKeeper, app.accountKeeper, app.bankKeeper, app.stakingKeeper),
params.NewAppModule(app.paramsKeeper),
wasm.NewAppModule(app.wasmKeeper),
wasm.NewAppModule(&app.wasmKeeper),
evidence.NewAppModule(app.evidenceKeeper),
ibc.NewAppModule(app.ibcKeeper),
transferModule,
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ func TestInitGenesis(t *testing.T) {
})

// export into genstate
genState := ExportGenesis(data.ctx, data.keeper)
genState := ExportGenesis(data.ctx, &data.keeper)

// create new app to import genstate into
newData := setupTest(t)
q2 := newData.module.LegacyQuerierHandler(nil)

// initialize new app with genstate
InitGenesis(newData.ctx, newData.keeper, *genState)
InitGenesis(newData.ctx, &newData.keeper, *genState)

// run same checks again on newdata, to make sure it was reinitialized correctly
assertCodeList(t, q2, newData.ctx, 1)
Expand Down
14 changes: 7 additions & 7 deletions x/wasm/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

// NewHandler returns a handler for "bank" type messages.
func NewHandler(k Keeper) sdk.Handler {
func NewHandler(k *Keeper) sdk.Handler {
return func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
ctx = ctx.WithEventManager(sdk.NewEventManager())

Expand Down Expand Up @@ -47,7 +47,7 @@ func filteredMessageEvents(manager *sdk.EventManager) []abci.Event {
return res
}

func handleStoreCode(ctx sdk.Context, k Keeper, msg *MsgStoreCode) (*sdk.Result, error) {
func handleStoreCode(ctx sdk.Context, k *Keeper, msg *MsgStoreCode) (*sdk.Result, error) {
err := msg.ValidateBasic()
if err != nil {
return nil, err
Expand All @@ -73,7 +73,7 @@ func handleStoreCode(ctx sdk.Context, k Keeper, msg *MsgStoreCode) (*sdk.Result,
}, nil
}

func handleInstantiate(ctx sdk.Context, k Keeper, msg *MsgInstantiateContract) (*sdk.Result, error) {
func handleInstantiate(ctx sdk.Context, k *Keeper, msg *MsgInstantiateContract) (*sdk.Result, error) {
contractAddr, err := k.Instantiate(ctx, msg.CodeID, msg.Sender, msg.Admin, msg.InitMsg, msg.Label, msg.InitFunds)
if err != nil {
return nil, err
Expand All @@ -95,7 +95,7 @@ func handleInstantiate(ctx sdk.Context, k Keeper, msg *MsgInstantiateContract) (
}, nil
}

func handleExecute(ctx sdk.Context, k Keeper, msg *MsgExecuteContract) (*sdk.Result, error) {
func handleExecute(ctx sdk.Context, k *Keeper, msg *MsgExecuteContract) (*sdk.Result, error) {
res, err := k.Execute(ctx, msg.Contract, msg.Sender, msg.Msg, msg.SentFunds)
if err != nil {
return nil, err
Expand All @@ -115,7 +115,7 @@ func handleExecute(ctx sdk.Context, k Keeper, msg *MsgExecuteContract) (*sdk.Res
return res, nil
}

func handleMigration(ctx sdk.Context, k Keeper, msg *MsgMigrateContract) (*sdk.Result, error) {
func handleMigration(ctx sdk.Context, k *Keeper, msg *MsgMigrateContract) (*sdk.Result, error) {
res, err := k.Migrate(ctx, msg.Contract, msg.Sender, msg.CodeID, msg.MigrateMsg)
if err != nil {
return nil, err
Expand All @@ -133,7 +133,7 @@ func handleMigration(ctx sdk.Context, k Keeper, msg *MsgMigrateContract) (*sdk.R
return res, nil
}

func handleUpdateContractAdmin(ctx sdk.Context, k Keeper, msg *MsgUpdateAdmin) (*sdk.Result, error) {
func handleUpdateContractAdmin(ctx sdk.Context, k *Keeper, msg *MsgUpdateAdmin) (*sdk.Result, error) {
if err := k.UpdateContractAdmin(ctx, msg.Contract, msg.Sender, msg.NewAdmin); err != nil {
return nil, err
}
Expand All @@ -149,7 +149,7 @@ func handleUpdateContractAdmin(ctx sdk.Context, k Keeper, msg *MsgUpdateAdmin) (
}, nil
}

func handleClearContractAdmin(ctx sdk.Context, k Keeper, msg *MsgClearAdmin) (*sdk.Result, error) {
func handleClearContractAdmin(ctx sdk.Context, k *Keeper, msg *MsgClearAdmin) (*sdk.Result, error) {
if err := k.ClearContractAdmin(ctx, msg.Contract, msg.Sender); err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/internal/keeper/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
// InitGenesis sets supply information for genesis.
//
// CONTRACT: all types of accounts must have been already initialized/created
func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) error {
func InitGenesis(ctx sdk.Context, keeper *Keeper, data types.GenesisState) error {
var maxCodeID uint64
for i, code := range data.Codes {
err := keeper.importCode(ctx, code.CodeID, code.CodeInfo, code.CodeBytes)
Expand Down Expand Up @@ -52,7 +52,7 @@ func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) error
}

// ExportGenesis returns a GenesisState for a given context and keeper.
func ExportGenesis(ctx sdk.Context, keeper Keeper) *types.GenesisState {
func ExportGenesis(ctx sdk.Context, keeper *Keeper) *types.GenesisState {
var genState types.GenesisState

genState.Params = keeper.GetParams(ctx)
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/internal/keeper/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ func TestImportContractWithCodeHistoryReset(t *testing.T) {
assert.Equal(t, expHistory, keeper.GetContractHistory(ctx, contractAddr).CodeHistoryEntries)
}

func setupKeeper(t *testing.T) (Keeper, sdk.Context, []sdk.StoreKey, func()) {
func setupKeeper(t *testing.T) (*Keeper, sdk.Context, []sdk.StoreKey, func()) {
t.Helper()
tempDir, err := ioutil.TempDir("", "wasm")
require.NoError(t, err)
Expand Down Expand Up @@ -504,5 +504,5 @@ func setupKeeper(t *testing.T) (Keeper, sdk.Context, []sdk.StoreKey, func()) {
srcKeeper := NewKeeper(encodingConfig.Marshaler, keyWasm, pk.Subspace(wasmTypes.DefaultParamspace), authkeeper.AccountKeeper{}, nil, stakingkeeper.Keeper{}, distributionkeeper.Keeper{}, nil, tempDir, wasmConfig, "", nil, nil)
srcKeeper.setParams(ctx, wasmTypes.DefaultParams())

return srcKeeper, ctx, []sdk.StoreKey{keyWasm, keyParams}, cleanup
return &srcKeeper, ctx, []sdk.StoreKey{keyWasm, keyParams}, cleanup
}
16 changes: 9 additions & 7 deletions x/wasm/internal/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type Keeper struct {
accountKeeper authkeeper.AccountKeeper
bankKeeper bankkeeper.Keeper

wasmer wasm.Wasmer
wasmer types.WasmerEngine
queryPlugins QueryPlugins
messenger MessageHandler
// queryGasLimit is the max wasm gas that can be spent on executing a query with a contract
Expand Down Expand Up @@ -86,7 +86,7 @@ func NewKeeper(
keeper := Keeper{
storeKey: storeKey,
cdc: cdc,
wasmer: *wasmer,
wasmer: wasmer,
accountKeeper: accountKeeper,
bankKeeper: bankKeeper,
messenger: NewMessageHandler(router, customEncoders),
Expand Down Expand Up @@ -254,16 +254,18 @@ func (k Keeper) instantiate(ctx sdk.Context, codeID uint64, creator, admin sdk.A
events := types.ParseEvents(res.Attributes, contractAddress)
ctx.EventManager().EmitEvents(events)

// persist instance first
createdAt := types.NewAbsoluteTxPosition(ctx)
instance := types.NewContractInfo(codeID, creator, admin, label, createdAt)
store.Set(types.GetContractAddressKey(contractAddress), k.cdc.MustMarshalBinaryBare(&instance))
k.appendToContractHistory(ctx, contractAddress, instance.InitialHistory(initMsg))

// then dispatch so that contract could be called back
err = k.dispatchMessages(ctx, contractAddress, res.Messages)
if err != nil {
return nil, err
}

// persist instance
createdAt := types.NewAbsoluteTxPosition(ctx)
instance := types.NewContractInfo(codeID, creator, admin, label, createdAt)
store.Set(types.GetContractAddressKey(contractAddress), k.cdc.MustMarshalBinaryBare(&instance))
k.appendToContractHistory(ctx, contractAddress, instance.InitialHistory(initMsg))
return contractAddress, nil
}

Expand Down
15 changes: 15 additions & 0 deletions x/wasm/internal/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,21 @@ func TestInstantiateWithNonExistingCodeID(t *testing.T) {
require.Nil(t, addr)
}

func TestInstantiateWithCallbackToContract(t *testing.T) {
ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil)
var (
executeCalled bool
err error
)
wasmerMock := selfCallingInstMockWasmer(&executeCalled)

keepers.WasmKeeper.wasmer = wasmerMock
example := StoreHackatomExampleContract(t, ctx, keepers)
_, err = keepers.WasmKeeper.Instantiate(ctx, example.CodeID, example.CreatorAddr, nil, nil, "test", nil)
require.NoError(t, err)
assert.True(t, executeCalled)
}

func TestExecute(t *testing.T) {
ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, nil)
accKeeper, keeper, bankKeeper := keepers.AccountKeeper, keepers.WasmKeeper, keepers.BankKeeper
Expand Down
12 changes: 6 additions & 6 deletions x/wasm/internal/keeper/legacy_querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const (
)

// NewLegacyQuerier creates a new querier
func NewLegacyQuerier(keeper Keeper) sdk.Querier {
func NewLegacyQuerier(keeper *Keeper) sdk.Querier {
return func(ctx sdk.Context, path []string, req abci.RequestQuery) ([]byte, error) {
var (
rsp interface{}
Expand All @@ -39,13 +39,13 @@ func NewLegacyQuerier(keeper Keeper) sdk.Querier {
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, err.Error())
}
rsp, err = queryContractInfo(ctx, addr, keeper)
rsp, err = queryContractInfo(ctx, addr, *keeper)
case QueryListContractByCode:
codeID, err := strconv.ParseUint(path[1], 10, 64)
if err != nil {
return nil, sdkerrors.Wrapf(types.ErrInvalid, "code id: %s", err.Error())
}
rsp, err = queryContractListByCode(ctx, codeID, keeper)
rsp, err = queryContractListByCode(ctx, codeID, *keeper)
case QueryGetContractState:
if len(path) < 3 {
return nil, sdkerrors.Wrap(sdkerrors.ErrUnknownRequest, "unknown data query endpoint")
Expand All @@ -58,13 +58,13 @@ func NewLegacyQuerier(keeper Keeper) sdk.Querier {
}
rsp, err = queryCode(ctx, codeID, keeper)
case QueryListCode:
rsp, err = queryCodeList(ctx, keeper)
rsp, err = queryCodeList(ctx, *keeper)
case QueryContractHistory:
contractAddr, err := sdk.AccAddressFromBech32(path[1])
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, err.Error())
}
rsp, err = queryContractHistory(ctx, contractAddr, keeper)
rsp, err = queryContractHistory(ctx, contractAddr, *keeper)
default:
return nil, sdkerrors.Wrap(sdkerrors.ErrUnknownRequest, "unknown data query endpoint")
}
Expand All @@ -82,7 +82,7 @@ func NewLegacyQuerier(keeper Keeper) sdk.Querier {
}
}

func queryContractState(ctx sdk.Context, bech, queryMethod string, data []byte, keeper Keeper) (json.RawMessage, error) {
func queryContractState(ctx sdk.Context, bech, queryMethod string, data []byte, keeper *Keeper) (json.RawMessage, error) {
contractAddr, err := sdk.AccAddressFromBech32(bech)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidAddress, bech)
Expand Down
14 changes: 7 additions & 7 deletions x/wasm/internal/keeper/querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ import (
)

type grpcQuerier struct {
keeper Keeper
keeper *Keeper
}

// todo: this needs proper tests and doc
func NewQuerier(keeper Keeper) grpcQuerier {
func NewQuerier(keeper *Keeper) grpcQuerier {
return grpcQuerier{keeper: keeper}
}

func (q grpcQuerier) ContractInfo(c context.Context, req *types.QueryContractInfoRequest) (*types.QueryContractInfoResponse, error) {
if err := sdk.VerifyAddressFormat(req.Address); err != nil {
return nil, err
}
rsp, err := queryContractInfo(sdk.UnwrapSDKContext(c), req.Address, q.keeper)
rsp, err := queryContractInfo(sdk.UnwrapSDKContext(c), req.Address, *q.keeper)
switch {
case err != nil:
return nil, err
Expand All @@ -40,7 +40,7 @@ func (q grpcQuerier) ContractHistory(c context.Context, req *types.QueryContract
if err := sdk.VerifyAddressFormat(req.Address); err != nil {
return nil, err
}
rsp, err := queryContractHistory(sdk.UnwrapSDKContext(c), req.Address, q.keeper)
rsp, err := queryContractHistory(sdk.UnwrapSDKContext(c), req.Address, *q.keeper)
switch {
case err != nil:
return nil, err
Expand All @@ -56,7 +56,7 @@ func (q grpcQuerier) ContractsByCode(c context.Context, req *types.QueryContract
if req.CodeId == 0 {
return nil, sdkerrors.Wrap(types.ErrInvalid, "code id")
}
rsp, err := queryContractListByCode(sdk.UnwrapSDKContext(c), req.CodeId, q.keeper)
rsp, err := queryContractListByCode(sdk.UnwrapSDKContext(c), req.CodeId, *q.keeper)
switch {
case err != nil:
return nil, err
Expand Down Expand Up @@ -134,7 +134,7 @@ func (q grpcQuerier) Code(c context.Context, req *types.QueryCodeRequest) (*type
}

func (q grpcQuerier) Codes(c context.Context, _ *empty.Empty) (*types.QueryCodesResponse, error) {
rsp, err := queryCodeList(sdk.UnwrapSDKContext(c), q.keeper)
rsp, err := queryCodeList(sdk.UnwrapSDKContext(c), *q.keeper)
switch {
case err != nil:
return nil, err
Expand Down Expand Up @@ -182,7 +182,7 @@ func queryContractListByCode(ctx sdk.Context, codeID uint64, keeper Keeper) ([]t
return contracts, nil
}

func queryCode(ctx sdk.Context, codeID uint64, keeper Keeper) (*types.QueryCodeResponse, error) {
func queryCode(ctx sdk.Context, codeID uint64, keeper *Keeper) (*types.QueryCodeResponse, error) {
if codeID == 0 {
return nil, nil
}
Expand Down
4 changes: 2 additions & 2 deletions x/wasm/internal/keeper/recurse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type recurseResponse struct {
// number os wasm queries called from a contract
var totalWasmQueryCounter int

func initRecurseContract(t *testing.T) (contract sdk.AccAddress, creator sdk.AccAddress, ctx sdk.Context, keeper Keeper) {
func initRecurseContract(t *testing.T) (contract sdk.AccAddress, creator sdk.AccAddress, ctx sdk.Context, keeper *Keeper) {
// we do one basic setup before all test cases (which are read-only and don't change state)
var realWasmQuerier func(ctx sdk.Context, request *wasmTypes.WasmQuery) ([]byte, error)
countingQuerier := &QueryPlugins{
Expand All @@ -48,7 +48,7 @@ func initRecurseContract(t *testing.T) (contract sdk.AccAddress, creator sdk.Acc

ctx, keepers := CreateTestInput(t, false, SupportedFeatures, nil, countingQuerier)
keeper = keepers.WasmKeeper
realWasmQuerier = WasmQuerier(&keeper)
realWasmQuerier = WasmQuerier(keeper)

exampleContract := InstantiateHackatomExampleContract(t, ctx, keepers)
return exampleContract.Contract, exampleContract.CreatorAddr, ctx, keeper
Expand Down
2 changes: 1 addition & 1 deletion x/wasm/internal/keeper/staking_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func initializeStaking(t *testing.T) initInfo {
ctx: ctx,
accKeeper: accKeeper,
stakingKeeper: stakingKeeper,
wasmKeeper: keeper,
wasmKeeper: *keeper,
distKeeper: k.DistKeeper,
bankKeeper: bankKeeper,
}
Expand Down
Loading

0 comments on commit 4fb3a50

Please sign in to comment.