Skip to content

Commit

Permalink
refactor: middleware refactor to change tx.Handler interface (#10527)
Browse files Browse the repository at this point in the history
<!--
The default pull request template is for types feat, fix, or refactor.
For other templates, add one of the following parameters to the url:
- template=docs.md
- template=other.md
-->

## Description

Closes: #10484 

This PR makes the following big changes:

### 1. Change the tx.Handler interface

```diff
-	CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error)
+	CheckTx(ctx context.Context, req tx.Request, checkReq tx.RequestCheckTx) (tx.Response, tx.ResponseCheckTx, error)
// same for Deliver and Simulate
```
 where:

```go
type Response struct {
	GasWanted uint64
	GasUsed   uint64
	// MsgResponses is an array containing each Msg service handler's response
	// type, packed in an Any. This will get proto-serialized into the `Data` field
	// in the ABCI Check/DeliverTx responses.
	MsgResponses []*codectypes.Any
	Log          string
	Events       []abci.Event
}
```

### 2. Change what gets passed into the ABCI Check/DeliverTx `Data` field

Before, we were passing the concatenation of MsgResponse bytes into the `Data`. Now we are passing the proto-serialiazation of this struct:

```proto
message TxMsgData {
  repeated google.protobuf.Any msg_responses = 2;
}
```

<!-- Add a description of the changes that this PR introduces and the files that
are the most critical to review. -->

---

### Author Checklist

*All items are required. Please add a note to the item if the item is not applicable and
please add links to any relevant follow up issues.*

I have...

- [ ] included the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title
- [ ] added `!` to the type prefix if API or client breaking change
- [ ] targeted the correct branch (see [PR Targeting](https://github.com/cosmos/cosmos-sdk/blob/master/CONTRIBUTING.md#pr-targeting))
- [ ] provided a link to the relevant issue or specification
- [ ] followed the guidelines for [building modules](https://github.com/cosmos/cosmos-sdk/blob/master/docs/building-modules)
- [ ] included the necessary unit and integration [tests](https://github.com/cosmos/cosmos-sdk/blob/master/CONTRIBUTING.md#testing)
- [ ] added a changelog entry to `CHANGELOG.md`
- [ ] included comments for [documenting Go code](https://blog.golang.org/godoc)
- [ ] updated the relevant documentation or specification
- [ ] reviewed "Files changed" and left comments if necessary
- [ ] confirmed all CI checks have passed

### Reviewers Checklist

*All items are required. Please add a note if the item is not applicable and please add
your handle next to the items reviewed if you only reviewed selected items.*

I have...

- [ ] confirmed the correct [type prefix](https://github.com/commitizen/conventional-commit-types/blob/v3.0.0/index.json) in the PR title
- [ ] confirmed `!` in the type prefix if API or client breaking change
- [ ] confirmed all author checklist items have been addressed 
- [ ] reviewed state machine logic
- [ ] reviewed API design and naming
- [ ] reviewed documentation is accurate
- [ ] reviewed tests and test coverage
- [ ] manually tested (if applicable)
  • Loading branch information
atheeshp authored Dec 2, 2021
1 parent 25f3af2 commit 5d86db3
Show file tree
Hide file tree
Showing 37 changed files with 886 additions and 572 deletions.
70 changes: 58 additions & 12 deletions baseapp/abci.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/cosmos/cosmos-sdk/telemetry"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/cosmos/cosmos-sdk/types/tx"
)

// InitChain implements the ABCI interface. It runs the initialization logic
Expand Down Expand Up @@ -249,18 +250,23 @@ func (app *BaseApp) CheckTx(req abci.RequestCheckTx) abci.ResponseCheckTx {
panic(fmt.Sprintf("unknown RequestCheckTx type: %s", req.Type))
}

tx, err := app.txDecoder(req.Tx)
reqTx, err := app.txDecoder(req.Tx)
if err != nil {
return sdkerrors.ResponseCheckTx(err, 0, 0, app.trace)
}

ctx := app.getContextForTx(mode, req.Tx)
res, err := app.txHandler.CheckTx(ctx, tx, req)
res, checkRes, err := app.txHandler.CheckTx(ctx, tx.Request{Tx: reqTx, TxBytes: req.Tx}, tx.RequestCheckTx{Type: req.Type})
if err != nil {
return sdkerrors.ResponseCheckTx(err, uint64(res.GasUsed), uint64(res.GasWanted), app.trace)
}

return res
abciRes, err := convertTxResponseToCheckTx(res, checkRes)
if err != nil {
return sdkerrors.ResponseCheckTx(err, uint64(res.GasUsed), uint64(res.GasWanted), app.trace)
}

return abciRes
}

// DeliverTx implements the ABCI interface and executes a tx in DeliverTx mode.
Expand All @@ -271,28 +277,34 @@ func (app *BaseApp) CheckTx(req abci.RequestCheckTx) abci.ResponseCheckTx {
func (app *BaseApp) DeliverTx(req abci.RequestDeliverTx) abci.ResponseDeliverTx {
defer telemetry.MeasureSince(time.Now(), "abci", "deliver_tx")

var res abci.ResponseDeliverTx
var abciRes abci.ResponseDeliverTx
defer func() {
for _, streamingListener := range app.abciListeners {
if err := streamingListener.ListenDeliverTx(app.deliverState.ctx, req, res); err != nil {
if err := streamingListener.ListenDeliverTx(app.deliverState.ctx, req, abciRes); err != nil {
app.logger.Error("DeliverTx listening hook failed", "err", err)
}
}
}()
tx, err := app.txDecoder(req.Tx)
reqTx, err := app.txDecoder(req.Tx)
if err != nil {
res = sdkerrors.ResponseDeliverTx(err, 0, 0, app.trace)
return res
abciRes = sdkerrors.ResponseDeliverTx(err, 0, 0, app.trace)
return abciRes
}

ctx := app.getContextForTx(runTxModeDeliver, req.Tx)
res, err = app.txHandler.DeliverTx(ctx, tx, req)
res, err := app.txHandler.DeliverTx(ctx, tx.Request{Tx: reqTx, TxBytes: req.Tx})
if err != nil {
res = sdkerrors.ResponseDeliverTx(err, uint64(res.GasUsed), uint64(res.GasWanted), app.trace)
return res
abciRes = sdkerrors.ResponseDeliverTx(err, uint64(res.GasUsed), uint64(res.GasWanted), app.trace)
return abciRes
}

return res
abciRes, err = convertTxResponseToDeliverTx(res)
if err != nil {
return sdkerrors.ResponseDeliverTx(err, uint64(res.GasUsed), uint64(res.GasWanted), app.trace)
}

return abciRes

}

// Commit implements the ABCI interface. It will commit all state that exists in
Expand Down Expand Up @@ -894,3 +906,37 @@ func splitPath(requestPath string) (path []string) {

return path
}

// makeABCIData generates the Data field to be sent to ABCI Check/DeliverTx.
func makeABCIData(txRes tx.Response) ([]byte, error) {
return proto.Marshal(&sdk.TxMsgData{MsgResponses: txRes.MsgResponses})
}

// convertTxResponseToCheckTx converts a tx.Response into a abci.ResponseCheckTx.
func convertTxResponseToCheckTx(txRes tx.Response, checkRes tx.ResponseCheckTx) (abci.ResponseCheckTx, error) {
data, err := makeABCIData(txRes)
if err != nil {
return abci.ResponseCheckTx{}, nil
}

return abci.ResponseCheckTx{
Data: data,
Log: txRes.Log,
Events: txRes.Events,
Priority: checkRes.Priority,
}, nil
}

// convertTxResponseToDeliverTx converts a tx.Response into a abci.ResponseDeliverTx.
func convertTxResponseToDeliverTx(txRes tx.Response) (abci.ResponseDeliverTx, error) {
data, err := makeABCIData(txRes)
if err != nil {
return abci.ResponseDeliverTx{}, nil
}

return abci.ResponseDeliverTx{
Data: data,
Log: txRes.Log,
Events: txRes.Events,
}, nil
}
101 changes: 77 additions & 24 deletions baseapp/baseapp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/cosmos/cosmos-sdk/baseapp"
"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/snapshots"
snapshottypes "github.com/cosmos/cosmos-sdk/snapshots/types"
"github.com/cosmos/cosmos-sdk/store/rootmulti"
Expand Down Expand Up @@ -148,7 +149,14 @@ func setupBaseAppWithSnapshots(t *testing.T, blocks uint, blockTxs int, options
legacyRouter.AddRoute(sdk.NewRoute(routeMsgKeyValue, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
kv := msg.(*msgKeyValue)
bapp.CMS().GetCommitKVStore(capKey2).Set(kv.Key, kv.Value)
return &sdk.Result{}, nil
any, err := codectypes.NewAnyWithValue(msg)
if err != nil {
return nil, err
}

return &sdk.Result{
MsgResponses: []*codectypes.Any{any},
}, nil
}))
txHandler := testTxHandler(
middleware.TxHandlerOptions{
Expand Down Expand Up @@ -716,7 +724,7 @@ func (tx *txTest) setFailOnAnte(fail bool) {

func (tx *txTest) setFailOnHandler(fail bool) {
for i, msg := range tx.Msgs {
tx.Msgs[i] = msgCounter{msg.(msgCounter).Counter, fail}
tx.Msgs[i] = &msgCounter{msg.(*msgCounter).Counter, fail}
}
}

Expand Down Expand Up @@ -744,16 +752,16 @@ type msgCounter struct {
}

// dummy implementation of proto.Message
func (msg msgCounter) Reset() {}
func (msg msgCounter) String() string { return "TODO" }
func (msg msgCounter) ProtoMessage() {}
func (msg *msgCounter) Reset() {}
func (msg *msgCounter) String() string { return "TODO" }
func (msg *msgCounter) ProtoMessage() {}

// Implements Msg
func (msg msgCounter) Route() string { return routeMsgCounter }
func (msg msgCounter) Type() string { return "counter1" }
func (msg msgCounter) GetSignBytes() []byte { return nil }
func (msg msgCounter) GetSigners() []sdk.AccAddress { return nil }
func (msg msgCounter) ValidateBasic() error {
func (msg *msgCounter) Route() string { return routeMsgCounter }
func (msg *msgCounter) Type() string { return "counter1" }
func (msg *msgCounter) GetSignBytes() []byte { return nil }
func (msg *msgCounter) GetSigners() []sdk.AccAddress { return nil }
func (msg *msgCounter) ValidateBasic() error {
if msg.Counter >= 0 {
return nil
}
Expand All @@ -763,7 +771,7 @@ func (msg msgCounter) ValidateBasic() error {
func newTxCounter(counter int64, msgCounters ...int64) txTest {
msgs := make([]sdk.Msg, 0, len(msgCounters))
for _, c := range msgCounters {
msgs = append(msgs, msgCounter{c, false})
msgs = append(msgs, &msgCounter{c, false})
}

return txTest{msgs, counter, false, math.MaxUint64}
Expand Down Expand Up @@ -903,6 +911,14 @@ func handlerMsgCounter(t *testing.T, capKey storetypes.StoreKey, deliverKey []by
}

res.Events = ctx.EventManager().Events().ToABCIEvents()

any, err := codectypes.NewAnyWithValue(msg)
if err != nil {
return nil, err
}

res.MsgResponses = []*codectypes.Any{any}

return res, nil
}
}
Expand Down Expand Up @@ -1153,7 +1169,15 @@ func TestSimulateTx(t *testing.T) {
legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
ctx.GasMeter().ConsumeGas(gasConsumed, "test")
return &sdk.Result{}, nil
// Return dummy MsgResponse for msgCounter.
any, err := codectypes.NewAnyWithValue(&testdata.Dog{})
if err != nil {
return nil, err
}

return &sdk.Result{
MsgResponses: []*codectypes.Any{any},
}, nil
})
legacyRouter.AddRoute(r)
txHandler := testTxHandler(
Expand Down Expand Up @@ -1221,7 +1245,14 @@ func TestRunInvalidTransaction(t *testing.T) {
txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
return &sdk.Result{}, nil
any, err := codectypes.NewAnyWithValue(msg)
if err != nil {
return nil, err
}

return &sdk.Result{
MsgResponses: []*codectypes.Any{any},
}, nil
})
legacyRouter.AddRoute(r)
txHandler := testTxHandler(
Expand Down Expand Up @@ -1269,7 +1300,7 @@ func TestRunInvalidTransaction(t *testing.T) {

for _, testCase := range testCases {
tx := testCase.tx
_, result, err := app.SimDeliver(aminoTxEncoder(), tx)
_, _, err := app.SimDeliver(aminoTxEncoder(), tx)

if testCase.fail {
require.Error(t, err)
Expand All @@ -1278,14 +1309,14 @@ func TestRunInvalidTransaction(t *testing.T) {
require.EqualValues(t, sdkerrors.ErrInvalidSequence.Codespace(), space, err)
require.EqualValues(t, sdkerrors.ErrInvalidSequence.ABCICode(), code, err)
} else {
require.NotNil(t, result)
require.NoError(t, err)
}
}
}

// transaction with no known route
{
unknownRouteTx := txTest{[]sdk.Msg{msgNoRoute{}}, 0, false, math.MaxUint64}
unknownRouteTx := txTest{[]sdk.Msg{&msgNoRoute{}}, 0, false, math.MaxUint64}
_, result, err := app.SimDeliver(aminoTxEncoder(), unknownRouteTx)
require.Error(t, err)
require.Nil(t, result)
Expand All @@ -1294,7 +1325,7 @@ func TestRunInvalidTransaction(t *testing.T) {
require.EqualValues(t, sdkerrors.ErrUnknownRequest.Codespace(), space, err)
require.EqualValues(t, sdkerrors.ErrUnknownRequest.ABCICode(), code, err)

unknownRouteTx = txTest{[]sdk.Msg{msgCounter{}, msgNoRoute{}}, 0, false, math.MaxUint64}
unknownRouteTx = txTest{[]sdk.Msg{&msgCounter{}, &msgNoRoute{}}, 0, false, math.MaxUint64}
_, result, err = app.SimDeliver(aminoTxEncoder(), unknownRouteTx)
require.Error(t, err)
require.Nil(t, result)
Expand All @@ -1307,7 +1338,7 @@ func TestRunInvalidTransaction(t *testing.T) {
// Transaction with an unregistered message
{
tx := newTxCounter(0, 0)
tx.Msgs = append(tx.Msgs, msgNoDecode{})
tx.Msgs = append(tx.Msgs, &msgNoDecode{})

// new codec so we can encode the tx, but we shouldn't be able to decode
newCdc := codec.NewLegacyAmino()
Expand Down Expand Up @@ -1336,9 +1367,16 @@ func TestTxGasLimits(t *testing.T) {
txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
count := msg.(msgCounter).Counter
count := msg.(*msgCounter).Counter
ctx.GasMeter().ConsumeGas(uint64(count), "counter-handler")
return &sdk.Result{}, nil
any, err := codectypes.NewAnyWithValue(msg)
if err != nil {
return nil, err
}

return &sdk.Result{
MsgResponses: []*codectypes.Any{any},
}, nil
})
legacyRouter.AddRoute(r)
txHandler := testTxHandler(
Expand Down Expand Up @@ -1415,9 +1453,17 @@ func TestMaxBlockGasLimits(t *testing.T) {
txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
count := msg.(msgCounter).Counter
count := msg.(*msgCounter).Counter
ctx.GasMeter().ConsumeGas(uint64(count), "counter-handler")
return &sdk.Result{}, nil

any, err := codectypes.NewAnyWithValue(msg)
if err != nil {
return nil, err
}

return &sdk.Result{
MsgResponses: []*codectypes.Any{any},
}, nil
})
legacyRouter.AddRoute(r)
txHandler := testTxHandler(
Expand Down Expand Up @@ -1595,7 +1641,7 @@ func TestGasConsumptionBadTx(t *testing.T) {
txHandlerOpt := func(bapp *baseapp.BaseApp) {
legacyRouter := middleware.NewLegacyRouter()
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
count := msg.(msgCounter).Counter
count := msg.(*msgCounter).Counter
ctx.GasMeter().ConsumeGas(uint64(count), "counter-handler")
return &sdk.Result{}, nil
})
Expand Down Expand Up @@ -1651,7 +1697,14 @@ func TestQuery(t *testing.T) {
r := sdk.NewRoute(routeMsgCounter, func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
store := ctx.KVStore(capKey1)
store.Set(key, value)
return &sdk.Result{}, nil

any, err := codectypes.NewAnyWithValue(msg)
if err != nil {
return nil, err
}
return &sdk.Result{
MsgResponses: []*codectypes.Any{any},
}, nil
})
legacyRouter.AddRoute(r)
txHandler := testTxHandler(
Expand Down
25 changes: 12 additions & 13 deletions baseapp/custom_txhandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/types/tx"
abci "github.com/tendermint/tendermint/abci/types"
"github.com/tendermint/tendermint/crypto/tmhash"
)

Expand All @@ -31,33 +30,33 @@ func CustomTxHandlerMiddleware(handler handlerFun) tx.Middleware {
}

// CheckTx implements tx.Handler.CheckTx method.
func (txh customTxHandler) CheckTx(ctx context.Context, tx sdk.Tx, req abci.RequestCheckTx) (abci.ResponseCheckTx, error) {
sdkCtx, err := txh.runHandler(ctx, tx, req.Tx, false)
func (txh customTxHandler) CheckTx(ctx context.Context, req tx.Request, checkReq tx.RequestCheckTx) (tx.Response, tx.ResponseCheckTx, error) {
sdkCtx, err := txh.runHandler(ctx, req.Tx, req.TxBytes, false)
if err != nil {
return abci.ResponseCheckTx{}, err
return tx.Response{}, tx.ResponseCheckTx{}, err
}

return txh.next.CheckTx(sdk.WrapSDKContext(sdkCtx), tx, req)
return txh.next.CheckTx(sdk.WrapSDKContext(sdkCtx), req, checkReq)
}

// DeliverTx implements tx.Handler.DeliverTx method.
func (txh customTxHandler) DeliverTx(ctx context.Context, tx sdk.Tx, req abci.RequestDeliverTx) (abci.ResponseDeliverTx, error) {
sdkCtx, err := txh.runHandler(ctx, tx, req.Tx, false)
func (txh customTxHandler) DeliverTx(ctx context.Context, req tx.Request) (tx.Response, error) {
sdkCtx, err := txh.runHandler(ctx, req.Tx, req.TxBytes, false)
if err != nil {
return abci.ResponseDeliverTx{}, err
return tx.Response{}, err
}

return txh.next.DeliverTx(sdk.WrapSDKContext(sdkCtx), tx, req)
return txh.next.DeliverTx(sdk.WrapSDKContext(sdkCtx), req)
}

// SimulateTx implements tx.Handler.SimulateTx method.
func (txh customTxHandler) SimulateTx(ctx context.Context, sdkTx sdk.Tx, req tx.RequestSimulateTx) (tx.ResponseSimulateTx, error) {
sdkCtx, err := txh.runHandler(ctx, sdkTx, req.TxBytes, true)
func (txh customTxHandler) SimulateTx(ctx context.Context, req tx.Request) (tx.Response, error) {
sdkCtx, err := txh.runHandler(ctx, req.Tx, req.TxBytes, true)
if err != nil {
return tx.ResponseSimulateTx{}, err
return tx.Response{}, err
}

return txh.next.SimulateTx(sdk.WrapSDKContext(sdkCtx), sdkTx, req)
return txh.next.SimulateTx(sdk.WrapSDKContext(sdkCtx), req)
}

func (txh customTxHandler) runHandler(ctx context.Context, tx sdk.Tx, txBytes []byte, isSimulate bool) (sdk.Context, error) {
Expand Down
Loading

0 comments on commit 5d86db3

Please sign in to comment.