diff --git a/CHANGELOG.md b/CHANGELOG.md index db41a1eda64a..35d58bf9a621 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,7 +49,7 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i ### Bug Fixes * (baseapp) [#21256](https://github.com/cosmos/cosmos-sdk/pull/21256) Halt height will not commit the block indicated, meaning that if halt-height is set to 10, only blocks until 9 (included) will be committed. This is to go back to the original behavior before a change was introduced in v0.50.0. - +* (baseapp) [#]() Fix data race in sdk mempool. ### API Breaking Changes diff --git a/baseapp/abci_utils.go b/baseapp/abci_utils.go index 6da80906fab5..fef240a4aa36 100644 --- a/baseapp/abci_utils.go +++ b/baseapp/abci_utils.go @@ -285,14 +285,16 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan return &abci.PrepareProposalResponse{Txs: h.txSelector.SelectedTxs(ctx)}, nil } - iterator := h.mempool.Select(ctx, req.Txs) selectedTxsSignersSeqs := make(map[string]uint64) - var selectedTxsNums int - for iterator != nil { - memTx := iterator.Tx() - signerData, err := h.signerExtAdapter.GetSigners(memTx) + var ( + err error + selectedTxsNums int + ) + h.mempool.SelectBy(ctx, req.Txs, func(memTx sdk.Tx) bool { + var signerData []mempool.SignerData + signerData, err = h.signerExtAdapter.GetSigners(memTx) if err != nil { - return nil, err + return false } // If the signers aren't in selectedTxsSignersSeqs then we haven't seen them before @@ -316,24 +318,24 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan txSignersSeqs[signer.Signer.String()] = signer.Sequence } if !shouldAdd { - iterator = iterator.Next() - continue + return true } // NOTE: Since transaction verification was already executed in CheckTx, // which calls mempool.Insert, in theory everything in the pool should be // valid. But some mempool implementations may insert invalid txs, so we // check again. - txBz, err := h.txVerifier.PrepareProposalVerifyTx(memTx) + var txBz []byte + txBz, err = h.txVerifier.PrepareProposalVerifyTx(memTx) if err != nil { - err := h.mempool.Remove(memTx) + err = h.mempool.Remove(memTx) if err != nil && !errors.Is(err, mempool.ErrTxNotFound) { - return nil, err + return false } } else { stop := h.txSelector.SelectTxForProposal(ctx, uint64(req.MaxTxBytes), maxBlockGas, memTx, txBz) if stop { - break + return false } txsLen := len(h.txSelector.SelectedTxs(ctx)) @@ -354,8 +356,8 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan selectedTxsNums = txsLen } - iterator = iterator.Next() - } + return true + }) return &abci.PrepareProposalResponse{Txs: h.txSelector.SelectedTxs(ctx)}, nil } diff --git a/types/mempool/mempool.go b/types/mempool/mempool.go index 7051c93e3146..f3a12ed5e88c 100644 --- a/types/mempool/mempool.go +++ b/types/mempool/mempool.go @@ -17,6 +17,9 @@ type Mempool interface { // closed by the caller. Select(context.Context, [][]byte) Iterator + // SelectBy use callback to iterate over the mempool, it's thread-safe to use. + SelectBy(context.Context, [][]byte, func(sdk.Tx) bool) + // CountTx returns the number of transactions currently in the mempool. CountTx() int diff --git a/types/mempool/noop.go b/types/mempool/noop.go index 73c12639d1d6..33c002080f82 100644 --- a/types/mempool/noop.go +++ b/types/mempool/noop.go @@ -16,7 +16,8 @@ var _ Mempool = (*NoOpMempool)(nil) // is FIFO-ordered by default. type NoOpMempool struct{} -func (NoOpMempool) Insert(context.Context, sdk.Tx) error { return nil } -func (NoOpMempool) Select(context.Context, [][]byte) Iterator { return nil } -func (NoOpMempool) CountTx() int { return 0 } -func (NoOpMempool) Remove(sdk.Tx) error { return nil } +func (NoOpMempool) Insert(context.Context, sdk.Tx) error { return nil } +func (NoOpMempool) Select(context.Context, [][]byte) Iterator { return nil } +func (NoOpMempool) SelectBy(context.Context, [][]byte, func(sdk.Tx) bool) {} +func (NoOpMempool) CountTx() int { return 0 } +func (NoOpMempool) Remove(sdk.Tx) error { return nil } diff --git a/types/mempool/priority_nonce.go b/types/mempool/priority_nonce.go index a927693410ef..216cd8148cae 100644 --- a/types/mempool/priority_nonce.go +++ b/types/mempool/priority_nonce.go @@ -368,6 +368,27 @@ func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterato return iterator.iteratePriority() } +// SelectBy will hold the mutex during the iteration, callback returns if continue. +func (mp *PriorityNonceMempool[C]) SelectBy(_ context.Context, _ [][]byte, callback func(sdk.Tx) bool) { + mp.mtx.Lock() + defer mp.mtx.Unlock() + if mp.priorityIndex.Len() == 0 { + return + } + + mp.reorderPriorityTies() + + iterator := &PriorityNonceIterator[C]{ + mempool: mp, + senderCursors: make(map[string]*skiplist.Element), + } + + iter := iterator.iteratePriority() + for iter != nil && callback(iter.Tx()) { + iter = iter.Next() + } +} + type reorderKey[C comparable] struct { deleteKey txMeta[C] insertKey txMeta[C] diff --git a/types/mempool/priority_nonce_test.go b/types/mempool/priority_nonce_test.go index 0a2f40355fbd..a5cf1a29249e 100644 --- a/types/mempool/priority_nonce_test.go +++ b/types/mempool/priority_nonce_test.go @@ -1,9 +1,11 @@ package mempool_test import ( + "context" "fmt" "math" "math/rand" + "sync" "testing" "time" @@ -395,6 +397,89 @@ func (s *MempoolTestSuite) TestIterator() { } } +func (s *MempoolTestSuite) TestIteratorConcurrency() { + t := s.T() + ctx := sdk.NewContext(nil, false, log.NewNopLogger()) + accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2) + sa := accounts[0].Address + sb := accounts[1].Address + + tests := []struct { + txs []txSpec + fail bool + }{ + { + txs: []txSpec{ + {p: 20, n: 1, a: sa}, + {p: 15, n: 1, a: sb}, + {p: 6, n: 2, a: sa}, + {p: 21, n: 4, a: sa}, + {p: 8, n: 2, a: sb}, + }, + }, + { + txs: []txSpec{ + {p: 20, n: 1, a: sa}, + {p: 15, n: 1, a: sb}, + {p: 6, n: 2, a: sa}, + {p: 21, n: 4, a: sa}, + {p: math.MinInt64, n: 2, a: sb}, + }, + }, + } + + for i, tt := range tests { + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + pool := mempool.DefaultPriorityMempool() + + // create test txs and insert into mempool + for i, ts := range tt.txs { + tx := testTx{id: i, priority: int64(ts.p), nonce: uint64(ts.n), address: ts.a} + c := ctx.WithPriority(tx.priority) + err := pool.Insert(c, tx) + require.NoError(t, err) + } + + // iterate through txs + stdCtx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + id := len(tt.txs) + for { + select { + case <-stdCtx.Done(): + return + default: + id++ + tx := testTx{id: id, priority: int64(rand.Intn(100)), nonce: uint64(id), address: sa} + c := ctx.WithPriority(tx.priority) + err := pool.Insert(c, tx) + require.NoError(t, err) + } + } + }() + + var i int + pool.SelectBy(ctx, nil, func(memTx sdk.Tx) bool { + tx := memTx.(testTx) + if tx.id < len(tt.txs) { + require.Equal(t, tt.txs[tx.id].p, int(tx.priority)) + require.Equal(t, tt.txs[tx.id].n, int(tx.nonce)) + require.Equal(t, tt.txs[tx.id].a, tx.address) + i++ + } + return i < len(tt.txs) + }) + require.Equal(t, i, len(tt.txs)) + cancel() + wg.Wait() + }) + } +} + func (s *MempoolTestSuite) TestPriorityTies() { ctx := sdk.NewContext(nil, false, log.NewNopLogger()) accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 3) diff --git a/types/mempool/sender_nonce.go b/types/mempool/sender_nonce.go index fc4902f64792..371e701d280a 100644 --- a/types/mempool/sender_nonce.go +++ b/types/mempool/sender_nonce.go @@ -189,6 +189,40 @@ func (snm *SenderNonceMempool) Select(_ context.Context, _ [][]byte) Iterator { return iter.Next() } +// SelectBy will hold the mutex during the iteration, callback returns if continue. +func (snm *SenderNonceMempool) SelectBy(_ context.Context, _ [][]byte, callback func(sdk.Tx) bool) { + snm.mtx.Lock() + defer snm.mtx.Unlock() + var senders []string + + senderCursors := make(map[string]*skiplist.Element) + orderedSenders := skiplist.New(skiplist.String) + + // #nosec + for s := range snm.senders { + orderedSenders.Set(s, s) + } + + s := orderedSenders.Front() + for s != nil { + sender := s.Value.(string) + senders = append(senders, sender) + senderCursors[sender] = snm.senders[sender].Front() + s = s.Next() + } + + iterator := &senderNonceMempoolIterator{ + senders: senders, + rnd: snm.rnd, + senderCursors: senderCursors, + } + + iter := iterator.Next() + for iter != nil && callback(iter.Tx()) { + iter = iter.Next() + } +} + // CountTx returns the total count of txs in the mempool. func (snm *SenderNonceMempool) CountTx() int { snm.mtx.Lock()