diff --git a/CHANGELOG.md b/CHANGELOG.md index d2256a14f5..608836d720 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,8 @@ See [RELEASE](./RELEASE.md) for workflow instructions. * [#5562](https://github.com/spacemeshos/go-spacemesh/pull/5562) Add streaming mode for fetcher. This should lessen GC pressure during sync +* [#5718](https://github.com/spacemeshos/go-spacemesh/pull/5718) Sync malfeasance proofs continuously. + ## Release v1.4.0 ### Upgrade information diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index fcb2f7f6b8..57ca8b5ecb 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -23,6 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxsync" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/malsync" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/recovery" ) @@ -116,9 +117,12 @@ func Recover( return nil, fmt.Errorf("open old local database: %w", err) } defer localDB.Close() - logger.With().Info("clearing atx sync metadata from local database") + logger.With().Info("clearing atx and malfeasance sync metadata from local database") if err := localDB.WithTx(ctx, func(tx *sql.Tx) error { - return atxsync.Clear(tx) + if err := atxsync.Clear(tx); err != nil { + return err + } + return malsync.Clear(tx) }); err != nil { return nil, fmt.Errorf("clear atxsync: %w", err) } diff --git a/config/mainnet.go b/config/mainnet.go index f5768e919e..42ef05c154 100644 --- a/config/mainnet.go +++ b/config/mainnet.go @@ -25,6 +25,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/syncer" "github.com/spacemeshos/go-spacemesh/syncer/atxsync" + "github.com/spacemeshos/go-spacemesh/syncer/malsync" timeConfig "github.com/spacemeshos/go-spacemesh/timesync/config" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -191,6 +192,7 @@ func MainnetConfig() Config { OutOfSyncThresholdLayers: 36, // 3h DisableMeshAgreement: true, AtxSync: atxsync.DefaultConfig(), + MalSync: malsync.DefaultConfig(), }, Recovery: checkpoint.DefaultConfig(), Cache: datastore.DefaultConfig(), diff --git a/config/presets/fastnet.go b/config/presets/fastnet.go index 6a7ca92ca3..76f469e260 100644 --- a/config/presets/fastnet.go +++ b/config/presets/fastnet.go @@ -52,6 +52,7 @@ func fastnet() config.Config { conf.Sync.Interval = 5 * time.Second conf.Sync.GossipDuration = 10 * time.Second conf.Sync.AtxSync.EpochInfoInterval = 20 * time.Second + conf.Sync.MalSync.IDRequestInterval = 20 * time.Second conf.LayersPerEpoch = 4 conf.RegossipAtxInterval = 30 * time.Second conf.FETCH.RequestTimeout = 2 * time.Second diff --git a/fetch/mesh_data.go b/fetch/mesh_data.go index 1aed0be449..868cefc3df 100644 --- a/fetch/mesh_data.go +++ b/fetch/mesh_data.go @@ -226,27 +226,30 @@ func (f *Fetch) GetPoetProof(ctx context.Context, id types.Hash32) error { } } -func (f *Fetch) GetMaliciousIDs(ctx context.Context, peer p2p.Peer) ([]byte, error) { +func (f *Fetch) GetMaliciousIDs(ctx context.Context, peer p2p.Peer) ([]types.NodeID, error) { + var malIDs MaliciousIDs if f.cfg.Streaming { - var b []byte if err := f.meteredStreamRequest( ctx, malProtocol, peer, []byte{}, func(ctx context.Context, s io.ReadWriter) (int, error) { return server.ReadResponse(s, func(respLen uint32) (n int, err error) { - b = make([]byte, respLen) - if _, err := io.ReadFull(s, b); err != nil { - return 0, err - } - return int(respLen), nil + return 
codec.DecodeFrom(s, &malIDs) }) }, ); err != nil { return nil, err } - return b, nil } else { - return f.meteredRequest(ctx, malProtocol, peer, []byte{}) + data, err := f.meteredRequest(ctx, malProtocol, peer, []byte{}) + if err != nil { + return nil, err + } + if err := codec.Decode(data, &malIDs); err != nil { + return nil, err + } } + f.RegisterPeerHashes(peer, types.NodeIDsToHashes(malIDs.NodeIDs)) + return malIDs.NodeIDs, nil } // GetLayerData get layer data from peers. diff --git a/fetch/mesh_data_test.go b/fetch/mesh_data_test.go index ae47f1e12c..f1d78cac3a 100644 --- a/fetch/mesh_data_test.go +++ b/fetch/mesh_data_test.go @@ -85,15 +85,13 @@ func startTestLoop(t *testing.T, f *Fetch, eg *errgroup.Group, stop chan struct{ }) } -func generateMaliciousIDs(t *testing.T) []byte { +func generateMaliciousIDs(t *testing.T) []types.NodeID { t.Helper() - var malicious MaliciousIDs - for i := 0; i < numMalicious; i++ { - malicious.NodeIDs = append(malicious.NodeIDs, types.RandomNodeID()) + malIDs := make([]types.NodeID, numMalicious) + for i := range malIDs { + malIDs[i] = types.RandomNodeID() } - data, err := codec.Encode(&malicious) - require.NoError(t, err) - return data + return malIDs } func generateLayerContent(t *testing.T) []byte { @@ -511,7 +509,9 @@ func TestFetch_GetMaliciousIDs(t *testing.T) { t.Parallel() f := createFetch(t) expectedIds := generateMaliciousIDs(t) - f.mMalS.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), []byte{}).Return(expectedIds, nil) + resp := codec.MustEncode(&MaliciousIDs{NodeIDs: expectedIds}) + f.mh.EXPECT().ID().Return("self").AnyTimes() + f.mMalS.EXPECT().Request(gomock.Any(), p2p.Peer("p0"), []byte{}).Return(resp, nil) ids, err := f.GetMaliciousIDs(context.Background(), "p0") require.NoError(t, err) require.Equal(t, expectedIds, ids) diff --git a/fetch/p2p_test.go b/fetch/p2p_test.go index eae260d940..073ad88e7e 100644 --- a/fetch/p2p_test.go +++ b/fetch/p2p_test.go @@ -323,12 +323,10 @@ func TestP2PMaliciousIDs(t *testing.T) { tpf.serverDB.Close() } - out, err := tpf.clientFetch.GetMaliciousIDs(context.Background(), tpf.serverID) + malIDs, err := tpf.clientFetch.GetMaliciousIDs(context.Background(), tpf.serverID) if errStr == "" { require.NoError(t, err) - var got MaliciousIDs - require.NoError(t, codec.Decode(out, &got)) - require.ElementsMatch(t, bad, got.NodeIDs) + require.ElementsMatch(t, bad, malIDs) } else { require.ErrorContains(t, err, errStr) } diff --git a/node/node.go b/node/node.go index 5f5c412dfa..160edd2201 100644 --- a/node/node.go +++ b/node/node.go @@ -81,6 +81,7 @@ import ( "github.com/spacemeshos/go-spacemesh/syncer" "github.com/spacemeshos/go-spacemesh/syncer/atxsync" "github.com/spacemeshos/go-spacemesh/syncer/blockssync" + "github.com/spacemeshos/go-spacemesh/syncer/malsync" "github.com/spacemeshos/go-spacemesh/system" "github.com/spacemeshos/go-spacemesh/timesync" timeCfg "github.com/spacemeshos/go-spacemesh/timesync/config" @@ -862,6 +863,9 @@ func (app *App) initServices(ctx context.Context) error { syncerConf.SyncCertDistance = app.Config.Tortoise.Hdist syncerConf.Standalone = app.Config.Standalone + if app.Config.P2P.MinPeers < app.Config.Sync.MalSync.MinSyncPeers { + app.Config.Sync.MalSync.MinSyncPeers = max(1, app.Config.P2P.MinPeers) + } app.syncLogger = app.addLogger(SyncLogger, lg) newSyncer := syncer.NewSyncer( app.cachedDB, @@ -876,6 +880,11 @@ func (app *App) initServices(ctx context.Context) error { atxsync.WithConfig(app.Config.Sync.AtxSync), atxsync.WithLogger(app.syncLogger.Zap()), ), + malsync.New(fetcher, 
app.db, app.localDB, + malsync.WithConfig(app.Config.Sync.MalSync), + malsync.WithLogger(app.syncLogger.Zap()), + malsync.WithPeerErrMetric(syncer.MalPeerError), + ), syncer.WithConfig(syncerConf), syncer.WithLogger(app.syncLogger), ) diff --git a/sql/malsync/malsync.go b/sql/malsync/malsync.go new file mode 100644 index 0000000000..9ae572b0ad --- /dev/null +++ b/sql/malsync/malsync.go @@ -0,0 +1,45 @@ +package malsync + +import ( + "fmt" + "time" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +func GetSyncState(db sql.Executor) (time.Time, error) { + var timestamp time.Time + rows, err := db.Exec("select timestamp from malfeasance_sync_state", + nil, func(stmt *sql.Statement) bool { + v := stmt.ColumnInt64(0) + if v > 0 { + timestamp = time.Unix(v, 0) + } + return true + }) + if err != nil { + return time.Time{}, fmt.Errorf("error getting malfeasance sync state: %w", err) + } else if rows != 1 { + return time.Time{}, fmt.Errorf("expected malfeasance_sync_state to have 1 row but got %d rows", rows) + } + return timestamp, nil +} + +func updateSyncState(db sql.Executor, ts int64) error { + _, err := db.Exec("update malfeasance_sync_state set timestamp = ?1", + func(stmt *sql.Statement) { + stmt.BindInt64(1, ts) + }, nil) + if err != nil { + return fmt.Errorf("error updating malfeasance sync state: %w", err) + } + return nil +} + +func UpdateSyncState(db sql.Executor, timestamp time.Time) error { + return updateSyncState(db, timestamp.Unix()) +} + +func Clear(db sql.Executor) error { + return updateSyncState(db, 0) +} diff --git a/sql/malsync/malsync_test.go b/sql/malsync/malsync_test.go new file mode 100644 index 0000000000..103ec95b24 --- /dev/null +++ b/sql/malsync/malsync_test.go @@ -0,0 +1,29 @@ +package malsync + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sql/localsql" +) + +func TestMalfeasanceSyncState(t *testing.T) { + db := localsql.InMemory() + timestamp, err := GetSyncState(db) + require.NoError(t, err) + require.Equal(t, time.Time{}, timestamp) + ts := time.Now() + for i := 0; i < 3; i++ { + require.NoError(t, UpdateSyncState(db, ts)) + timestamp, err = GetSyncState(db) + require.NoError(t, err) + require.Equal(t, ts.Truncate(time.Second), timestamp) + ts = ts.Add(3 * time.Minute) + } + require.NoError(t, Clear(db)) + timestamp, err = GetSyncState(db) + require.NoError(t, err) + require.Equal(t, time.Time{}, timestamp) +} diff --git a/sql/migrations/local/0007_malfeasance_sync.sql b/sql/migrations/local/0007_malfeasance_sync.sql new file mode 100644 index 0000000000..0385eac87f --- /dev/null +++ b/sql/migrations/local/0007_malfeasance_sync.sql @@ -0,0 +1,6 @@ +CREATE TABLE malfeasance_sync_state +( + timestamp INT NOT NULL +); + +INSERT INTO malfeasance_sync_state (timestamp) VALUES (0); diff --git a/syncer/data_fetch.go b/syncer/data_fetch.go index 9cb087553e..b52331e4fc 100644 --- a/syncer/data_fetch.go +++ b/syncer/data_fetch.go @@ -61,68 +61,6 @@ func (e *threadSafeErr) join(err error) { e.err = errors.Join(e.err, err) } -// PollMaliciousProofs polls all peers for malicious NodeIDs. 
-func (d *DataFetch) PollMaliciousProofs(ctx context.Context) error { - peers := d.fetcher.SelectBestShuffled(fetch.RedundantPeers) - logger := d.logger.WithContext(ctx) - - maliciousIDs := make(chan fetch.MaliciousIDs, len(peers)) - var eg errgroup.Group - fetchErr := threadSafeErr{} - for _, peer := range peers { - peer := peer - eg.Go(func() error { - data, err := d.fetcher.GetMaliciousIDs(ctx, peer) - if err != nil { - malPeerError.Inc() - logger.With().Debug("failed to get malicious IDs", log.Err(err), log.Stringer("peer", peer)) - fetchErr.join(err) - return nil - } - var malIDs fetch.MaliciousIDs - if err := codec.Decode(data, &malIDs); err != nil { - logger.With().Debug("failed to decode", log.Err(err)) - fetchErr.join(err) - return nil - } - logger.With().Debug("received malicious id from peer", log.Stringer("peer", peer)) - maliciousIDs <- malIDs - return nil - }) - } - _ = eg.Wait() - close(maliciousIDs) - - allIds := make(map[types.NodeID]struct{}) - success := false - for ids := range maliciousIDs { - success = true - for _, id := range ids.NodeIDs { - allIds[id] = struct{}{} - } - } - if !success { - return fetchErr.err - } - - var idsToFetch []types.NodeID - for nodeID := range allIds { - if exists, err := d.ids.IdentityExists(nodeID); err != nil { - logger.With().Error("failed to check identity", log.Err(err)) - continue - } else if !exists { - logger.With().Info("malicious identity does not exist", log.Stringer("identity", nodeID)) - continue - } - idsToFetch = append(idsToFetch, nodeID) - } - - if err := d.fetcher.GetMalfeasanceProofs(ctx, idsToFetch); err != nil { - return fmt.Errorf("getting malfeasance proofs: %w", err) - } - return nil -} - // PollLayerData polls all peers for data in the specified layer. func (d *DataFetch) PollLayerData(ctx context.Context, lid types.LayerID, peers ...p2p.Peer) error { if len(peers) == 0 { diff --git a/syncer/data_fetch_test.go b/syncer/data_fetch_test.go index e52af87c10..d2cf967297 100644 --- a/syncer/data_fetch_test.go +++ b/syncer/data_fetch_test.go @@ -45,17 +45,6 @@ const ( numMalicious = 11 ) -func generateMaliciousIDs(t *testing.T) ([]types.NodeID, []byte) { - t.Helper() - var malicious fetch.MaliciousIDs - for i := 0; i < numMalicious; i++ { - malicious.NodeIDs = append(malicious.NodeIDs, types.RandomNodeID()) - } - data, err := codec.Encode(&malicious) - require.NoError(t, err) - return malicious.NodeIDs, data -} - func generateLayerOpinions(t *testing.T, bid *types.BlockID) []byte { t.Helper() lo := &fetch.LayerOpinion{ @@ -98,99 +87,6 @@ func GenPeers(num int) []p2p.Peer { return peers } -func TestDataFetch_PollMaliciousIDs(t *testing.T) { - numPeers := 4 - peers := GenPeers(numPeers) - errUnknown := errors.New("unknown") - newTestDataFetchWithMocks := func(_ *testing.T, exists bool) *testDataFetch { - td := newTestDataFetch(t) - td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) - for _, peer := range peers { - td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), peer).DoAndReturn( - func(_ context.Context, peer p2p.Peer) ([]byte, error) { - ids, data := generateMaliciousIDs(t) - for _, id := range ids { - td.mIDs.EXPECT().IdentityExists(id).Return(exists, nil) - } - return data, nil - }) - } - return td - } - t.Run("getting malfeasance proofs success", func(t *testing.T) { - t.Parallel() - td := newTestDataFetchWithMocks(t, true) - td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), gomock.Any()) - require.NoError(t, td.PollMaliciousProofs(context.Background())) - }) - t.Run("getting proofs failure", 
func(t *testing.T) { - t.Parallel() - td := newTestDataFetchWithMocks(t, true) - td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), gomock.Any()).Return(errUnknown) - require.ErrorIs(t, td.PollMaliciousProofs(context.Background()), errUnknown) - }) - t.Run("ids do not exist", func(t *testing.T) { - t.Parallel() - td := newTestDataFetchWithMocks(t, false) - td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), nil) - require.NoError(t, td.PollMaliciousProofs(context.Background())) - }) -} - -func TestDataFetch_PollMaliciousIDs_PeerErrors(t *testing.T) { - t.Run("malformed data in response", func(t *testing.T) { - t.Parallel() - peers := []p2p.Peer{"p0"} - td := newTestDataFetch(t) - td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) - td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p0")).Return([]byte("malformed"), nil) - err := td.PollMaliciousProofs(context.Background()) - require.ErrorContains(t, err, "decode") - }) - t.Run("peer fails", func(t *testing.T) { - t.Parallel() - peers := []p2p.Peer{"p0"} - expectedErr := errors.New("peer failure") - td := newTestDataFetch(t) - td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) - td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p0")).Return(nil, expectedErr) - err := td.PollMaliciousProofs(context.Background()) - require.ErrorIs(t, err, expectedErr) - }) - t.Run("one peer sends malformed data (succeed anyway)", func(t *testing.T) { - t.Parallel() - peers := []p2p.Peer{"p0", "p1"} - td := newTestDataFetch(t) - maliciousIds, data := generateMaliciousIDs(t) - for _, id := range maliciousIds { - td.mIDs.EXPECT().IdentityExists(id).Return(true, nil) - } - - td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) - td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p0")).Return(data, nil) - td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p1")).Return([]byte("malformed"), nil) - td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), gomock.Any()) - err := td.PollMaliciousProofs(context.Background()) - require.NoError(t, err) - }) - t.Run("one peer fails (succeed anyway)", func(t *testing.T) { - t.Parallel() - peers := []p2p.Peer{"p0", "p1"} - expectedErr := errors.New("peer failure") - td := newTestDataFetch(t) - maliciousIds, data := generateMaliciousIDs(t) - for _, id := range maliciousIds { - td.mIDs.EXPECT().IdentityExists(id).Return(true, nil) - } - td.mFetcher.EXPECT().SelectBestShuffled(gomock.Any()).Return(peers) - td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p0")).Return(data, nil) - td.mFetcher.EXPECT().GetMaliciousIDs(gomock.Any(), p2p.Peer("p1")).Return(nil, expectedErr) - td.mFetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), gomock.Any()) - err := td.PollMaliciousProofs(context.Background()) - require.NoError(t, err) - }) -} - func TestDataFetch_PollLayerData(t *testing.T) { numPeers := 4 peers := GenPeers(numPeers) diff --git a/syncer/interface.go b/syncer/interface.go index 5b5efb8321..2250bd528d 100644 --- a/syncer/interface.go +++ b/syncer/interface.go @@ -26,7 +26,6 @@ type meshProvider interface { type fetchLogic interface { fetcher - PollMaliciousProofs(ctx context.Context) error PollLayerData(context.Context, types.LayerID, ...p2p.Peer) error PollLayerOpinions( context.Context, @@ -40,15 +39,20 @@ type atxSyncer interface { Download(context.Context, types.EpochID, time.Time) error } +type malSyncer interface { + EnsureInSync(parent context.Context, epochStart, epochEnd time.Time) error + DownloadLoop(parent 
context.Context) error +} + // fetcher is the interface to the low-level fetching. type fetcher interface { - GetMaliciousIDs(context.Context, p2p.Peer) ([]byte, error) + GetMaliciousIDs(context.Context, p2p.Peer) ([]types.NodeID, error) GetLayerData(context.Context, p2p.Peer, types.LayerID) ([]byte, error) GetLayerOpinions(context.Context, p2p.Peer, types.LayerID) ([]byte, error) GetCert(context.Context, types.LayerID, types.BlockID, []p2p.Peer) (*types.Certificate, error) - GetMalfeasanceProofs(context.Context, []types.NodeID) error system.AtxFetcher + system.MalfeasanceProofFetcher GetBallots(context.Context, []types.BallotID) error GetBlocks(context.Context, []types.BlockID) error RegisterPeerHashes(peer p2p.Peer, hashes []types.Hash32) diff --git a/syncer/malsync/mocks/mocks.go b/syncer/malsync/mocks/mocks.go new file mode 100644 index 0000000000..549d0b28b5 --- /dev/null +++ b/syncer/malsync/mocks/mocks.go @@ -0,0 +1,216 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./syncer.go +// +// Generated by this command: +// +// mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./syncer.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + types "github.com/spacemeshos/go-spacemesh/common/types" + p2p "github.com/spacemeshos/go-spacemesh/p2p" + gomock "go.uber.org/mock/gomock" +) + +// Mockfetcher is a mock of fetcher interface. +type Mockfetcher struct { + ctrl *gomock.Controller + recorder *MockfetcherMockRecorder +} + +// MockfetcherMockRecorder is the mock recorder for Mockfetcher. +type MockfetcherMockRecorder struct { + mock *Mockfetcher +} + +// NewMockfetcher creates a new mock instance. +func NewMockfetcher(ctrl *gomock.Controller) *Mockfetcher { + mock := &Mockfetcher{ctrl: ctrl} + mock.recorder = &MockfetcherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *Mockfetcher) EXPECT() *MockfetcherMockRecorder { + return m.recorder +} + +// GetMalfeasanceProofs mocks base method. +func (m *Mockfetcher) GetMalfeasanceProofs(arg0 context.Context, arg1 []types.NodeID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMalfeasanceProofs", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// GetMalfeasanceProofs indicates an expected call of GetMalfeasanceProofs. 
+func (mr *MockfetcherMockRecorder) GetMalfeasanceProofs(arg0, arg1 any) *MockfetcherGetMalfeasanceProofsCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMalfeasanceProofs", reflect.TypeOf((*Mockfetcher)(nil).GetMalfeasanceProofs), arg0, arg1) + return &MockfetcherGetMalfeasanceProofsCall{Call: call} +} + +// MockfetcherGetMalfeasanceProofsCall wrap *gomock.Call +type MockfetcherGetMalfeasanceProofsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockfetcherGetMalfeasanceProofsCall) Return(arg0 error) *MockfetcherGetMalfeasanceProofsCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockfetcherGetMalfeasanceProofsCall) Do(f func(context.Context, []types.NodeID) error) *MockfetcherGetMalfeasanceProofsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockfetcherGetMalfeasanceProofsCall) DoAndReturn(f func(context.Context, []types.NodeID) error) *MockfetcherGetMalfeasanceProofsCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// GetMaliciousIDs mocks base method. +func (m *Mockfetcher) GetMaliciousIDs(arg0 context.Context, arg1 p2p.Peer) ([]types.NodeID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMaliciousIDs", arg0, arg1) + ret0, _ := ret[0].([]types.NodeID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMaliciousIDs indicates an expected call of GetMaliciousIDs. +func (mr *MockfetcherMockRecorder) GetMaliciousIDs(arg0, arg1 any) *MockfetcherGetMaliciousIDsCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMaliciousIDs", reflect.TypeOf((*Mockfetcher)(nil).GetMaliciousIDs), arg0, arg1) + return &MockfetcherGetMaliciousIDsCall{Call: call} +} + +// MockfetcherGetMaliciousIDsCall wrap *gomock.Call +type MockfetcherGetMaliciousIDsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockfetcherGetMaliciousIDsCall) Return(arg0 []types.NodeID, arg1 error) *MockfetcherGetMaliciousIDsCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockfetcherGetMaliciousIDsCall) Do(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetcherGetMaliciousIDsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockfetcherGetMaliciousIDsCall) DoAndReturn(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetcherGetMaliciousIDsCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// SelectBestShuffled mocks base method. +func (m *Mockfetcher) SelectBestShuffled(arg0 int) []p2p.Peer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SelectBestShuffled", arg0) + ret0, _ := ret[0].([]p2p.Peer) + return ret0 +} + +// SelectBestShuffled indicates an expected call of SelectBestShuffled. 
+func (mr *MockfetcherMockRecorder) SelectBestShuffled(arg0 any) *MockfetcherSelectBestShuffledCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SelectBestShuffled", reflect.TypeOf((*Mockfetcher)(nil).SelectBestShuffled), arg0) + return &MockfetcherSelectBestShuffledCall{Call: call} +} + +// MockfetcherSelectBestShuffledCall wrap *gomock.Call +type MockfetcherSelectBestShuffledCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockfetcherSelectBestShuffledCall) Return(arg0 []p2p.Peer) *MockfetcherSelectBestShuffledCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockfetcherSelectBestShuffledCall) Do(f func(int) []p2p.Peer) *MockfetcherSelectBestShuffledCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockfetcherSelectBestShuffledCall) DoAndReturn(f func(int) []p2p.Peer) *MockfetcherSelectBestShuffledCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Mockcounter is a mock of counter interface. +type Mockcounter struct { + ctrl *gomock.Controller + recorder *MockcounterMockRecorder +} + +// MockcounterMockRecorder is the mock recorder for Mockcounter. +type MockcounterMockRecorder struct { + mock *Mockcounter +} + +// NewMockcounter creates a new mock instance. +func NewMockcounter(ctrl *gomock.Controller) *Mockcounter { + mock := &Mockcounter{ctrl: ctrl} + mock.recorder = &MockcounterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *Mockcounter) EXPECT() *MockcounterMockRecorder { + return m.recorder +} + +// Inc mocks base method. +func (m *Mockcounter) Inc() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Inc") +} + +// Inc indicates an expected call of Inc. 
+func (mr *MockcounterMockRecorder) Inc() *MockcounterIncCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Inc", reflect.TypeOf((*Mockcounter)(nil).Inc)) + return &MockcounterIncCall{Call: call} +} + +// MockcounterIncCall wrap *gomock.Call +type MockcounterIncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockcounterIncCall) Return() *MockcounterIncCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockcounterIncCall) Do(f func()) *MockcounterIncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockcounterIncCall) DoAndReturn(f func()) *MockcounterIncCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/syncer/malsync/syncer.go b/syncer/malsync/syncer.go new file mode 100644 index 0000000000..6e73b50026 --- /dev/null +++ b/syncer/malsync/syncer.go @@ -0,0 +1,461 @@ +package malsync + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/jonboulle/clockwork" + "go.uber.org/zap" + "golang.org/x/exp/maps" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/fetch" + "github.com/spacemeshos/go-spacemesh/log" + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/p2p/pubsub" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/malsync" + "github.com/spacemeshos/go-spacemesh/system" +) + +//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./syncer.go + +type fetcher interface { + SelectBestShuffled(int) []p2p.Peer + GetMaliciousIDs(context.Context, p2p.Peer) ([]types.NodeID, error) + system.MalfeasanceProofFetcher +} + +type Opt func(*Syncer) + +func WithLogger(logger *zap.Logger) Opt { + return func(s *Syncer) { + s.logger = logger + } +} + +type counter interface { + Inc() +} + +type noCounter struct{} + +func (noCounter) Inc() {} + +func WithPeerErrMetric(counter counter) Opt { + return func(s *Syncer) { + s.peerErrMetric = counter + } +} + +func withClock(clock clockwork.Clock) Opt { + return func(s *Syncer) { + s.clock = clock + } +} + +func DefaultConfig() Config { + return Config{ + IDRequestInterval: 30 * time.Minute, + MalfeasanceIDPeers: 3, + MinSyncPeers: 3, + MaxEpochFraction: 0.25, + MaxBatchSize: 1000, + RequestsLimit: 20, + RetryInterval: time.Minute, + } +} + +type Config struct { + // IDRequestInterval specifies the interval for malfeasance proof id requests to the network. + IDRequestInterval time.Duration `mapstructure:"id-request-interval"` + + // MalfeasanceIDPeers is the number of peers to fetch node IDs for malfeasance proofs from. + MalfeasanceIDPeers int `mapstructure:"malfeasance-id-peers"` + + // Minimum number of peers to sync against for initial sync to be considered complete. + MinSyncPeers int `mapstructure:"min-sync-peers"` + + // MaxEpochFraction specifies maximum fraction of epoch to expire before + // synchronous malfeasance proof sync is needed upon startup. + MaxEpochFraction float64 `mapstructure:"max-epoch-fraction"` + + // MaxBatchSize is the maximum number of node IDs to sync in a single request. + MaxBatchSize int `mapstructure:"max-batch-size"` + + // RequestsLimit is the maximum number of requests for a single malfeasance proof. 
+ // + // The purpose of it is to prevent peers from advertising invalid node ID and disappearing. + // Which will make node ask other peers for invalid malfeasance proofs. + // It will be reset to 0 once malfeasance proof is advertised again. + RequestsLimit int `mapstructure:"requests-limit"` + + // RetryInterval specifies retry interval for the initial sync. + RetryInterval time.Duration `mapstructure:"retry-interval"` +} + +func WithConfig(cfg Config) Opt { + return func(s *Syncer) { + s.cfg = cfg + } +} + +type syncPeerSet map[p2p.Peer]struct{} + +func (sps syncPeerSet) add(peer p2p.Peer) { + sps[peer] = struct{}{} +} + +func (sps syncPeerSet) clear() { + maps.Clear(sps) +} + +func (sps syncPeerSet) updateFrom(other syncPeerSet) { + maps.Copy(sps, other) +} + +// syncState stores malfeasance sync state. +type syncState struct { + limit int + initial bool + state map[types.NodeID]int + syncingPeers syncPeerSet + syncedPeers syncPeerSet +} + +func newSyncState(limit int, initial bool) *syncState { + return &syncState{ + limit: limit, + initial: initial, + state: make(map[types.NodeID]int), + syncedPeers: make(syncPeerSet), + syncingPeers: make(syncPeerSet), + } +} + +func (sst *syncState) done() { + if sst.initial { + sst.syncedPeers.updateFrom(sst.syncingPeers) + sst.syncingPeers.clear() + } + maps.Clear(sst.state) +} + +func (sst *syncState) numSyncedPeers() int { + return len(sst.syncedPeers) +} + +func (sst *syncState) update(update malUpdate) { + if sst.initial { + sst.syncingPeers.add(update.peer) + } + for _, id := range update.nodeIDs { + if _, found := sst.state[id]; !found { + sst.state[id] = 0 + } + } +} + +func (sst *syncState) has(nodeID types.NodeID) bool { + _, found := sst.state[nodeID] + return found +} + +func (sst *syncState) failed(nodeID types.NodeID) { + // possibly temporary failure, count failed attempt + n := sst.state[nodeID] + if n >= 0 { + sst.state[nodeID] = n + 1 + } +} + +func (sst *syncState) rejected(nodeID types.NodeID) { + // malfeasance proof didn't pass validation, no sense in requesting it anymore + n := sst.state[nodeID] + if n >= 0 { + sst.state[nodeID] = sst.limit + } +} + +func (sst *syncState) downloaded(nodeID types.NodeID) { + sst.state[nodeID] = -1 +} + +func (sst *syncState) missing(max int, has func(nodeID types.NodeID) (bool, error)) ([]types.NodeID, error) { + r := make([]types.NodeID, 0, max) + for nodeID, count := range sst.state { + if count < 0 { + continue // already downloaded + } + exists, err := has(nodeID) + if err != nil { + return nil, err + } + if exists { + sst.downloaded(nodeID) + continue + } + if count >= sst.limit { + // unsuccessfully requested too many times + delete(sst.state, nodeID) + continue + } + r = append(r, nodeID) + if len(r) == cap(r) { + break + } + } + return r, nil +} + +type Syncer struct { + logger *zap.Logger + cfg Config + fetcher fetcher + db sql.Executor + localdb *localsql.Database + clock clockwork.Clock + peerErrMetric counter +} + +func New(fetcher fetcher, db sql.Executor, localdb *localsql.Database, opts ...Opt) *Syncer { + s := &Syncer{ + logger: zap.NewNop(), + cfg: DefaultConfig(), + fetcher: fetcher, + db: db, + localdb: localdb, + clock: clockwork.NewRealClock(), + peerErrMetric: noCounter{}, + } + for _, opt := range opts { + opt(s) + } + return s +} + +func (s *Syncer) shouldSync(epochStart, epochEnd time.Time) (bool, error) { + timestamp, err := malsync.GetSyncState(s.localdb) + if err != nil { + return false, fmt.Errorf("error getting malfeasance sync state: %w", err) + } + if 
timestamp.Before(epochStart) {
+		return true, nil
+	}
+	cutoff := epochEnd.Sub(epochStart).Seconds() * s.cfg.MaxEpochFraction
+	return s.clock.Now().Sub(timestamp).Seconds() > cutoff, nil
+}
+
+func (s *Syncer) download(parent context.Context, initial bool) error {
+	s.logger.Info("starting malfeasance proof sync", log.ZContext(parent))
+	defer s.logger.Debug("malfeasance proof sync terminated", log.ZContext(parent))
+	ctx, cancel := context.WithCancel(parent)
+	eg, ctx := errgroup.WithContext(ctx)
+	updates := make(chan malUpdate, s.cfg.MalfeasanceIDPeers)
+	eg.Go(func() error {
+		return s.downloadNodeIDs(ctx, initial, updates)
+	})
+	eg.Go(func() error {
+		defer cancel()
+		return s.downloadMalfeasanceProofs(ctx, initial, updates)
+	})
+	if err := eg.Wait(); err != nil {
+		return err
+	}
+	return parent.Err()
+}
+
+func (s *Syncer) downloadNodeIDs(ctx context.Context, initial bool, updates chan<- malUpdate) error {
+	interval := s.cfg.IDRequestInterval
+	if initial {
+		interval = 0
+	}
+	for {
+		if interval != 0 {
+			s.logger.Debug(
+				"pausing between malfeasant node ID requests",
+				zap.Duration("duration", interval))
+			select {
+			case <-ctx.Done():
+				return nil
+			// TODO(ivan4th) this has to be randomized in a followup
+			// when sync will be scheduled in advance, in order to smooth out request rate across the network
+			case <-s.clock.After(interval):
+			}
+		}
+
+		peers := s.fetcher.SelectBestShuffled(s.cfg.MalfeasanceIDPeers)
+		if len(peers) == 0 {
+			s.logger.Debug(
+				"don't have enough peers for malfeasance sync",
+				zap.Int("nPeers", s.cfg.MalfeasanceIDPeers),
+			)
+			if interval == 0 {
+				interval = s.cfg.RetryInterval
+			}
+			continue
+		}
+
+		var eg errgroup.Group
+		for _, peer := range peers {
+			peer := peer
+			eg.Go(func() error {
+				malIDs, err := s.fetcher.GetMaliciousIDs(ctx, peer)
+				if err != nil {
+					if errors.Is(err, context.Canceled) {
+						return ctx.Err()
+					}
+					s.peerErrMetric.Inc()
+					s.logger.Warn("failed to download malfeasant node IDs",
+						log.ZContext(ctx),
+						zap.String("peer", peer.String()),
+						zap.Error(err),
+					)
+					return nil
+				}
+				s.logger.Info("downloaded malfeasant node IDs",
+					log.ZContext(ctx),
+					zap.String("peer", peer.String()),
+					zap.Int("ids", len(malIDs)),
+				)
+				select {
+				case <-ctx.Done():
+					return ctx.Err()
+				case updates <- malUpdate{peer: peer, nodeIDs: malIDs}:
+				}
+				return nil
+			})
+		}
+
+		if err := eg.Wait(); err != nil {
+			return err
+		}
+
+		if interval == 0 {
+			interval = s.cfg.RetryInterval
+		}
+	}
+}
+
+func (s *Syncer) updateState() error {
+	if err := malsync.UpdateSyncState(s.localdb, s.clock.Now()); err != nil {
+		return fmt.Errorf("error updating malsync state: %w", err)
+	}
+
+	return nil
+}
+
+func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, updates <-chan malUpdate) error {
+	var (
+		update            malUpdate
+		sst               = newSyncState(s.cfg.RequestsLimit, initial)
+		nothingToDownload = true
+		gotUpdate         = false
+	)
+	for {
+		if nothingToDownload {
+			sst.done()
+			if initial && sst.numSyncedPeers() >= s.cfg.MinSyncPeers {
+				if err := s.updateState(); err != nil {
+					return nil
+				}
+				s.logger.Info("initial sync of malfeasance proofs completed", log.ZContext(ctx))
+				return nil
+			} else if !initial && gotUpdate {
+				if err := s.updateState(); err != nil {
+					return nil
+				}
+			}
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case update = <-updates:
+				s.logger.Debug("malfeasance sync update",
+					log.ZContext(ctx), zap.Int("count", len(update.nodeIDs)))
+				sst.update(update)
+				gotUpdate = true
+			}
+		} else {
+			select {
+			case <-ctx.Done():
return ctx.Err() + case update = <-updates: + s.logger.Debug("malfeasance sync update", + log.ZContext(ctx), zap.Int("count", len(update.nodeIDs))) + sst.update(update) + gotUpdate = true + default: + // If we have some hashes to fetch already, don't wait for + // another update + } + } + batch, err := sst.missing(s.cfg.MaxBatchSize, func(nodeID types.NodeID) (bool, error) { + // TODO(ivan4th): check multiple node IDs at once in a single SQL query + isMalicious, err := identities.IsMalicious(s.db, nodeID) + if err != nil && errors.Is(err, sql.ErrNotFound) { + return false, nil + } + return isMalicious, err + }) + if err != nil { + return fmt.Errorf("error checking malfeasant node IDs: %w", err) + } + + nothingToDownload = len(batch) == 0 + + if len(batch) != 0 { + s.logger.Debug("retrieving malfeasant identities", + log.ZContext(ctx), + zap.Int("count", len(batch))) + err := s.fetcher.GetMalfeasanceProofs(ctx, batch) + if err != nil { + if errors.Is(err, context.Canceled) { + return ctx.Err() + } + s.logger.Debug("failed to download malfeasance proofs", + log.ZContext(ctx), + log.NiceZapError(err), + ) + } + batchError := &fetch.BatchError{} + if errors.As(err, &batchError) { + for hash, err := range batchError.Errors { + nodeID := types.NodeID(hash) + switch { + case !sst.has(nodeID): + continue + case errors.Is(err, fetch.ErrExceedMaxRetries): + sst.failed(nodeID) + case errors.Is(err, pubsub.ErrValidationReject): + sst.rejected(nodeID) + } + } + } + } else { + s.logger.Debug("no new malfeasant identities", log.ZContext(ctx)) + } + } +} + +func (s *Syncer) EnsureInSync(parent context.Context, epochStart, epochEnd time.Time) error { + if shouldSync, err := s.shouldSync(epochStart, epochEnd); err != nil { + return err + } else if !shouldSync { + return nil + } + return s.download(parent, true) +} + +func (s *Syncer) DownloadLoop(parent context.Context) error { + return s.download(parent, false) +} + +type malUpdate struct { + peer p2p.Peer + nodeIDs []types.NodeID +} diff --git a/syncer/malsync/syncer_test.go b/syncer/malsync/syncer_test.go new file mode 100644 index 0000000000..281f3a14de --- /dev/null +++ b/syncer/malsync/syncer_test.go @@ -0,0 +1,387 @@ +package malsync + +import ( + "context" + "errors" + "slices" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/exp/maps" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/codec" + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/fetch" + "github.com/spacemeshos/go-spacemesh/log/logtest" + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/p2p/pubsub" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/syncer/malsync/mocks" +) + +type fakeCounter struct { + n int +} + +func (fc *fakeCounter) Inc() { fc.n++ } + +func genNodeIDs(n int) []types.NodeID { + ids := make([]types.NodeID, n) + for i := range ids { + ids[i] = types.RandomNodeID() + } + return ids +} + +func TestSyncState(t *testing.T) { + nodeIDs := genNodeIDs(5) + sst := newSyncState(3, true) + require.Zero(t, sst.numSyncedPeers()) + require.False(t, sst.has(nodeIDs[0])) + sst.update(malUpdate{ + peer: "a", + nodeIDs: slices.Clone(nodeIDs[:4]), + }) + for _, id := range nodeIDs[:4] { + require.True(t, sst.has(id)) + } + require.False(t, sst.has(nodeIDs[4])) + 
ids, err := sst.missing(10, func(nodeID types.NodeID) (bool, error) { return false, nil }) + require.NoError(t, err) + require.ElementsMatch(t, nodeIDs[:4], ids) + + testErr := errors.New("fail") + _, err = sst.missing(10, func(nodeID types.NodeID) (bool, error) { return false, testErr }) + require.ErrorIs(t, err, testErr) + + sst.downloaded(nodeIDs[0]) + sst.failed(nodeIDs[1]) + sst.rejected(nodeIDs[2]) + + ids, err = sst.missing(10, func(nodeID types.NodeID) (bool, error) { return false, nil }) + require.NoError(t, err) + require.ElementsMatch(t, []types.NodeID{nodeIDs[1], nodeIDs[3]}, ids) + + // make nodeIDs[1] fail too many times + sst.failed(nodeIDs[1]) + sst.failed(nodeIDs[1]) + + ids, err = sst.missing(10, func(nodeID types.NodeID) (bool, error) { return false, nil }) + require.NoError(t, err) + require.ElementsMatch(t, []types.NodeID{nodeIDs[3]}, ids) + + for i := 0; i < 2; i++ { + ids, err = sst.missing(10, func(nodeID types.NodeID) (bool, error) { + // nodeIDs[3] will be marked as downloaded + return nodeID == nodeIDs[3], nil + }) + require.NoError(t, err) + require.Empty(t, ids) + } + + require.Zero(t, sst.numSyncedPeers()) + sst.done() + require.Equal(t, 1, sst.numSyncedPeers()) + + sst.update(malUpdate{ + peer: "b", + }) + require.Equal(t, 1, sst.numSyncedPeers()) + sst.done() + require.Equal(t, 2, sst.numSyncedPeers()) +} + +func mproof(nodeID types.NodeID) *types.MalfeasanceProof { + var ballotProof types.BallotProof + for i := 0; i < 2; i++ { + ballotProof.Messages[i] = types.BallotProofMsg{ + InnerMsg: types.BallotMetadata{ + Layer: types.LayerID(9), + MsgHash: types.RandomHash(), + }, + Signature: types.RandomEdSignature(), + SmesherID: nodeID, + } + } + + return &types.MalfeasanceProof{ + Layer: types.LayerID(11), + Proof: types.Proof{ + Type: types.MultipleBallots, + Data: &ballotProof, + }, + } +} + +func nid(id string) types.NodeID { + var nodeID types.NodeID + copy(nodeID[:], id) + return nodeID +} + +func malData(ids ...string) []types.NodeID { + malIDs := make([]types.NodeID, len(ids)) + for n, id := range ids { + malIDs[n] = nid(id) + } + return malIDs +} + +type tester struct { + tb testing.TB + syncer *Syncer + localdb *localsql.Database + db *sql.Database + cfg Config + ctrl *gomock.Controller + fetcher *mocks.Mockfetcher + clock clockwork.FakeClock + received map[types.NodeID]bool + attempts map[types.NodeID]int + peers []p2p.Peer + peerErrCount *fakeCounter +} + +func newTester(tb testing.TB, cfg Config) *tester { + localdb := localsql.InMemory() + db := sql.InMemory() + ctrl := gomock.NewController(tb) + fetcher := mocks.NewMockfetcher(ctrl) + clock := clockwork.NewFakeClock() + peerErrCount := &fakeCounter{} + syncer := New(fetcher, db, localdb, + withClock(clock), + WithConfig(cfg), + WithLogger(logtest.New(tb).Zap()), + WithPeerErrMetric(peerErrCount), + ) + return &tester{ + tb: tb, + syncer: syncer, + localdb: localdb, + db: db, + cfg: cfg, + ctrl: ctrl, + fetcher: fetcher, + clock: clock, + received: make(map[types.NodeID]bool), + attempts: make(map[types.NodeID]int), + peers: []p2p.Peer{"a", "b", "c"}, + peerErrCount: peerErrCount, + } +} + +func (tester *tester) expectGetMaliciousIDs() { + // "2" comes just from a single peer + tester.fetcher.EXPECT(). + GetMaliciousIDs(gomock.Any(), tester.peers[0]). + Return(malData("4", "1", "3", "2"), nil) + for _, p := range tester.peers[1:] { + tester.fetcher.EXPECT(). + GetMaliciousIDs(gomock.Any(), p). 
+ Return(malData("4", "1", "3"), nil) + } +} + +func (tester *tester) expectGetProofs(errMap map[types.NodeID]error) { + tester.fetcher.EXPECT(). + GetMalfeasanceProofs(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, ids []types.NodeID) error { + batchErr := &fetch.BatchError{ + Errors: make(map[types.Hash32]error), + } + for _, id := range ids { + tester.attempts[id]++ + require.NotContains(tester.tb, tester.received, id) + if err := errMap[id]; err != nil { + batchErr.Errors[types.Hash32(id)] = err + continue + } + tester.received[id] = true + proofData := codec.MustEncode(mproof(id)) + require.NoError(tester.tb, identities.SetMalicious( + tester.db, id, proofData, tester.syncer.clock.Now())) + } + if len(batchErr.Errors) != 0 { + return batchErr + } + return nil + }).AnyTimes() +} + +func (tester *tester) expectPeers(peers []p2p.Peer) { + tester.fetcher.EXPECT().SelectBestShuffled(tester.cfg.MalfeasanceIDPeers).Return(peers).AnyTimes() +} + +func TestSyncer(t *testing.T) { + t.Run("EnsureInSync", func(t *testing.T) { + tester := newTester(t, DefaultConfig()) + tester.expectPeers(tester.peers) + tester.expectGetMaliciousIDs() + tester.expectGetProofs(nil) + epochStart := tester.clock.Now().Truncate(time.Second) + epochEnd := epochStart.Add(10 * time.Minute) + require.NoError(t, + tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.ElementsMatch(t, []types.NodeID{ + nid("1"), nid("2"), nid("3"), nid("4"), + }, maps.Keys(tester.received)) + require.Equal(t, map[types.NodeID]int{ + nid("1"): 1, + nid("2"): 1, + nid("3"): 1, + nid("4"): 1, + }, tester.attempts) + tester.clock.Advance(1 * time.Minute) + // second call does nothing after recent sync + require.NoError(t, + tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.Zero(t, tester.peerErrCount.n) + }) + t.Run("EnsureInSync with no malfeasant identities", func(t *testing.T) { + tester := newTester(t, DefaultConfig()) + tester.expectPeers(tester.peers) + for _, p := range tester.peers { + tester.fetcher.EXPECT(). + GetMaliciousIDs(gomock.Any(), p). + Return(nil, nil) + } + epochStart := tester.clock.Now().Truncate(time.Second) + epochEnd := epochStart.Add(10 * time.Minute) + require.NoError(t, + tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.Zero(t, tester.peerErrCount.n) + }) + t.Run("interruptible", func(t *testing.T) { + tester := newTester(t, DefaultConfig()) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + tester.expectPeers([]p2p.Peer{"a"}) + tester.fetcher.EXPECT(). + GetMaliciousIDs(gomock.Any(), gomock.Any()). + Return(malData("1"), nil).AnyTimes() + tester.fetcher.EXPECT(). + GetMalfeasanceProofs(gomock.Any(), gomock.Any()). + Return(errors.New("no atxs")).AnyTimes() + require.ErrorIs(t, tester.syncer.DownloadLoop(ctx), context.Canceled) + }) + t.Run("retries on no peers", func(t *testing.T) { + tester := newTester(t, DefaultConfig()) + ctx, cancel := context.WithCancel(context.Background()) + ch := make(chan []p2p.Peer) + tester.fetcher.EXPECT().SelectBestShuffled(tester.cfg.MalfeasanceIDPeers). 
+ DoAndReturn(func(int) []p2p.Peer { + return <-ch + }).AnyTimes() + var eg errgroup.Group + eg.Go(func() error { + require.ErrorIs(t, tester.syncer.DownloadLoop(ctx), context.Canceled) + return nil + }) + tester.clock.BlockUntil(1) + tester.clock.Advance(tester.cfg.IDRequestInterval) + ch <- nil + tester.clock.BlockUntil(1) + tester.clock.Advance(tester.cfg.IDRequestInterval) + + tester.expectGetMaliciousIDs() + tester.expectGetProofs(nil) + ch <- tester.peers + tester.clock.BlockUntil(1) + cancel() + eg.Wait() + }) + t.Run("gettings ids from MinSyncPeers peers is enough", func(t *testing.T) { + cfg := DefaultConfig() + cfg.MinSyncPeers = 2 + tester := newTester(t, cfg) + tester.expectPeers(tester.peers) + tester.fetcher.EXPECT(). + GetMaliciousIDs(gomock.Any(), tester.peers[0]). + Return(nil, errors.New("fail")) + for _, p := range tester.peers[1:] { + tester.fetcher.EXPECT(). + GetMaliciousIDs(gomock.Any(), p). + Return(malData("4", "1", "3", "2"), nil) + } + tester.expectGetProofs(nil) + epochStart := tester.clock.Now().Truncate(time.Second) + epochEnd := epochStart.Add(10 * time.Minute) + require.NoError(t, + tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.ElementsMatch(t, []types.NodeID{ + nid("1"), nid("2"), nid("3"), nid("4"), + }, maps.Keys(tester.received)) + require.Equal(t, map[types.NodeID]int{ + nid("1"): 1, + nid("2"): 1, + nid("3"): 1, + nid("4"): 1, + }, tester.attempts) + tester.clock.Advance(1 * time.Minute) + // second call does nothing after recent sync + require.NoError(t, + tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.Equal(t, 1, tester.peerErrCount.n) + }) + t.Run("skip hashes after max retries", func(t *testing.T) { + cfg := DefaultConfig() + cfg.RequestsLimit = 3 + tester := newTester(t, cfg) + tester.expectPeers(tester.peers) + tester.expectGetMaliciousIDs() + tester.expectGetProofs(map[types.NodeID]error{ + nid("2"): fetch.ErrExceedMaxRetries, + }) + epochStart := tester.clock.Now().Truncate(time.Second) + epochEnd := epochStart.Add(10 * time.Minute) + require.NoError(t, + tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.ElementsMatch(t, []types.NodeID{ + nid("1"), nid("3"), nid("4"), + }, maps.Keys(tester.received)) + require.Equal(t, map[types.NodeID]int{ + nid("1"): 1, + nid("2"): tester.cfg.RequestsLimit, + nid("3"): 1, + nid("4"): 1, + }, tester.attempts) + tester.clock.Advance(1 * time.Minute) + // second call does nothing after recent sync + require.NoError(t, + tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + }) + t.Run("skip hashes after validation reject", func(t *testing.T) { + tester := newTester(t, DefaultConfig()) + tester.expectPeers(tester.peers) + tester.expectGetMaliciousIDs() + tester.expectGetProofs(map[types.NodeID]error{ + // note that "2" comes just from a single peer + // (see expectGetMaliciousIDs) + nid("2"): pubsub.ErrValidationReject, + }) + epochStart := tester.clock.Now().Truncate(time.Second) + epochEnd := epochStart.Add(10 * time.Minute) + require.NoError(t, + tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + require.ElementsMatch(t, []types.NodeID{ + nid("1"), nid("3"), nid("4"), + }, maps.Keys(tester.received)) + require.Equal(t, map[types.NodeID]int{ + nid("1"): 1, + nid("2"): 1, + nid("3"): 1, + nid("4"): 1, + }, tester.attempts) + tester.clock.Advance(1 * time.Minute) + // second call does nothing after recent sync + require.NoError(t, + 
tester.syncer.EnsureInSync(context.Background(), epochStart, epochEnd)) + }) +} diff --git a/syncer/metrics.go b/syncer/metrics.go index 87e6d680a2..eaebc37da1 100644 --- a/syncer/metrics.go +++ b/syncer/metrics.go @@ -87,7 +87,7 @@ var ( layerPeerError = peerError.WithLabelValues("layer") opnsPeerError = peerError.WithLabelValues("opns") certPeerError = peerError.WithLabelValues("cert") - malPeerError = peerError.WithLabelValues("mal") + MalPeerError = peerError.WithLabelValues("mal") v2OpnPoll = metrics.NewCounter( "opn_poll", diff --git a/syncer/mocks/mocks.go b/syncer/mocks/mocks.go index aa5f1d50ba..61ed631811 100644 --- a/syncer/mocks/mocks.go +++ b/syncer/mocks/mocks.go @@ -477,10 +477,10 @@ func (c *MockfetchLogicGetMalfeasanceProofsCall) DoAndReturn(f func(context.Cont } // GetMaliciousIDs mocks base method. -func (m *MockfetchLogic) GetMaliciousIDs(arg0 context.Context, arg1 p2p.Peer) ([]byte, error) { +func (m *MockfetchLogic) GetMaliciousIDs(arg0 context.Context, arg1 p2p.Peer) ([]types.NodeID, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetMaliciousIDs", arg0, arg1) - ret0, _ := ret[0].([]byte) + ret0, _ := ret[0].([]types.NodeID) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -498,19 +498,19 @@ type MockfetchLogicGetMaliciousIDsCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockfetchLogicGetMaliciousIDsCall) Return(arg0 []byte, arg1 error) *MockfetchLogicGetMaliciousIDsCall { +func (c *MockfetchLogicGetMaliciousIDsCall) Return(arg0 []types.NodeID, arg1 error) *MockfetchLogicGetMaliciousIDsCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockfetchLogicGetMaliciousIDsCall) Do(f func(context.Context, p2p.Peer) ([]byte, error)) *MockfetchLogicGetMaliciousIDsCall { +func (c *MockfetchLogicGetMaliciousIDsCall) Do(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetchLogicGetMaliciousIDsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockfetchLogicGetMaliciousIDsCall) DoAndReturn(f func(context.Context, p2p.Peer) ([]byte, error)) *MockfetchLogicGetMaliciousIDsCall { +func (c *MockfetchLogicGetMaliciousIDsCall) DoAndReturn(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetchLogicGetMaliciousIDsCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -676,44 +676,6 @@ func (c *MockfetchLogicPollLayerOpinionsCall) DoAndReturn(f func(context.Context return c } -// PollMaliciousProofs mocks base method. -func (m *MockfetchLogic) PollMaliciousProofs(ctx context.Context) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PollMaliciousProofs", ctx) - ret0, _ := ret[0].(error) - return ret0 -} - -// PollMaliciousProofs indicates an expected call of PollMaliciousProofs. 
-func (mr *MockfetchLogicMockRecorder) PollMaliciousProofs(ctx any) *MockfetchLogicPollMaliciousProofsCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PollMaliciousProofs", reflect.TypeOf((*MockfetchLogic)(nil).PollMaliciousProofs), ctx) - return &MockfetchLogicPollMaliciousProofsCall{Call: call} -} - -// MockfetchLogicPollMaliciousProofsCall wrap *gomock.Call -type MockfetchLogicPollMaliciousProofsCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockfetchLogicPollMaliciousProofsCall) Return(arg0 error) *MockfetchLogicPollMaliciousProofsCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockfetchLogicPollMaliciousProofsCall) Do(f func(context.Context) error) *MockfetchLogicPollMaliciousProofsCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockfetchLogicPollMaliciousProofsCall) DoAndReturn(f func(context.Context) error) *MockfetchLogicPollMaliciousProofsCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // RegisterPeerHashes mocks base method. func (m *MockfetchLogic) RegisterPeerHashes(peer p2p.Peer, hashes []types.Hash32) { m.ctrl.T.Helper() @@ -849,6 +811,105 @@ func (c *MockatxSyncerDownloadCall) DoAndReturn(f func(context.Context, types.Ep return c } +// MockmalSyncer is a mock of malSyncer interface. +type MockmalSyncer struct { + ctrl *gomock.Controller + recorder *MockmalSyncerMockRecorder +} + +// MockmalSyncerMockRecorder is the mock recorder for MockmalSyncer. +type MockmalSyncerMockRecorder struct { + mock *MockmalSyncer +} + +// NewMockmalSyncer creates a new mock instance. +func NewMockmalSyncer(ctrl *gomock.Controller) *MockmalSyncer { + mock := &MockmalSyncer{ctrl: ctrl} + mock.recorder = &MockmalSyncerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockmalSyncer) EXPECT() *MockmalSyncerMockRecorder { + return m.recorder +} + +// DownloadLoop mocks base method. +func (m *MockmalSyncer) DownloadLoop(parent context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DownloadLoop", parent) + ret0, _ := ret[0].(error) + return ret0 +} + +// DownloadLoop indicates an expected call of DownloadLoop. +func (mr *MockmalSyncerMockRecorder) DownloadLoop(parent any) *MockmalSyncerDownloadLoopCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DownloadLoop", reflect.TypeOf((*MockmalSyncer)(nil).DownloadLoop), parent) + return &MockmalSyncerDownloadLoopCall{Call: call} +} + +// MockmalSyncerDownloadLoopCall wrap *gomock.Call +type MockmalSyncerDownloadLoopCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockmalSyncerDownloadLoopCall) Return(arg0 error) *MockmalSyncerDownloadLoopCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockmalSyncerDownloadLoopCall) Do(f func(context.Context) error) *MockmalSyncerDownloadLoopCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockmalSyncerDownloadLoopCall) DoAndReturn(f func(context.Context) error) *MockmalSyncerDownloadLoopCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// EnsureInSync mocks base method. 
+func (m *MockmalSyncer) EnsureInSync(parent context.Context, epochStart, epochEnd time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EnsureInSync", parent, epochStart, epochEnd) + ret0, _ := ret[0].(error) + return ret0 +} + +// EnsureInSync indicates an expected call of EnsureInSync. +func (mr *MockmalSyncerMockRecorder) EnsureInSync(parent, epochStart, epochEnd any) *MockmalSyncerEnsureInSyncCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EnsureInSync", reflect.TypeOf((*MockmalSyncer)(nil).EnsureInSync), parent, epochStart, epochEnd) + return &MockmalSyncerEnsureInSyncCall{Call: call} +} + +// MockmalSyncerEnsureInSyncCall wrap *gomock.Call +type MockmalSyncerEnsureInSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockmalSyncerEnsureInSyncCall) Return(arg0 error) *MockmalSyncerEnsureInSyncCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockmalSyncerEnsureInSyncCall) Do(f func(context.Context, time.Time, time.Time) error) *MockmalSyncerEnsureInSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockmalSyncerEnsureInSyncCall) DoAndReturn(f func(context.Context, time.Time, time.Time) error) *MockmalSyncerEnsureInSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Mockfetcher is a mock of fetcher interface. type Mockfetcher struct { ctrl *gomock.Controller @@ -1147,10 +1208,10 @@ func (c *MockfetcherGetMalfeasanceProofsCall) DoAndReturn(f func(context.Context } // GetMaliciousIDs mocks base method. -func (m *Mockfetcher) GetMaliciousIDs(arg0 context.Context, arg1 p2p.Peer) ([]byte, error) { +func (m *Mockfetcher) GetMaliciousIDs(arg0 context.Context, arg1 p2p.Peer) ([]types.NodeID, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetMaliciousIDs", arg0, arg1) - ret0, _ := ret[0].([]byte) + ret0, _ := ret[0].([]types.NodeID) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1168,19 +1229,19 @@ type MockfetcherGetMaliciousIDsCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockfetcherGetMaliciousIDsCall) Return(arg0 []byte, arg1 error) *MockfetcherGetMaliciousIDsCall { +func (c *MockfetcherGetMaliciousIDsCall) Return(arg0 []types.NodeID, arg1 error) *MockfetcherGetMaliciousIDsCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockfetcherGetMaliciousIDsCall) Do(f func(context.Context, p2p.Peer) ([]byte, error)) *MockfetcherGetMaliciousIDsCall { +func (c *MockfetcherGetMaliciousIDsCall) Do(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetcherGetMaliciousIDsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockfetcherGetMaliciousIDsCall) DoAndReturn(f func(context.Context, p2p.Peer) ([]byte, error)) *MockfetcherGetMaliciousIDsCall { +func (c *MockfetcherGetMaliciousIDsCall) DoAndReturn(f func(context.Context, p2p.Peer) ([]types.NodeID, error)) *MockfetcherGetMaliciousIDsCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/syncer/syncer.go b/syncer/syncer.go index 3a8f136786..daf67ea16c 100644 --- a/syncer/syncer.go +++ b/syncer/syncer.go @@ -18,6 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/mesh" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/syncer/atxsync" + "github.com/spacemeshos/go-spacemesh/syncer/malsync" "github.com/spacemeshos/go-spacemesh/system" ) @@ -36,6 +37,7 @@ type Config struct { DisableMeshAgreement 
 	OutOfSyncThresholdLayers uint32         `mapstructure:"out-of-sync-threshold"`
 	AtxSync                  atxsync.Config `mapstructure:"atx-sync"`
+	MalSync                  malsync.Config `mapstructure:"malfeasance-sync"`
 }
 
 // DefaultConfig for the syncer.
@@ -50,6 +52,7 @@ func DefaultConfig() Config {
 		GossipDuration:           15 * time.Second,
 		OutOfSyncThresholdLayers: 3,
 		AtxSync:                  atxsync.DefaultConfig(),
+		MalSync:                  malsync.DefaultConfig(),
 	}
 }
 
@@ -122,6 +125,7 @@ type Syncer struct {
 	cfg       Config
 	cdb       *datastore.CachedDB
 	atxsyncer atxSyncer
+	malsyncer malSyncer
 	ticker    layerTicker
 	beacon    system.BeaconGetter
 	mesh      *mesh.Mesh
@@ -147,6 +151,12 @@ type Syncer struct {
 		cancel context.CancelFunc
 	}
 
+	// malSync runs malfeasant identity sync in the background
+	malSync struct {
+		started bool
+		eg      errgroup.Group
+	}
+
 	// awaitATXSyncedCh is the list of subscribers' channels to notify when this node enters ATX synced state
 	awaitATXSyncedCh chan struct{}
 
@@ -165,6 +175,7 @@ func NewSyncer(
 	patrol layerPatrol,
 	ch certHandler,
 	atxSyncer atxSyncer,
+	malSyncer malSyncer,
 	opts ...Option,
 ) *Syncer {
 	s := &Syncer{
@@ -172,6 +183,7 @@ func NewSyncer(
 		cfg:       DefaultConfig(),
 		cdb:       cdb,
 		atxsyncer: atxSyncer,
+		malsyncer: malSyncer,
 		ticker:    ticker,
 		beacon:    beacon,
 		mesh:      mesh,
@@ -456,7 +468,7 @@ func (s *Syncer) syncAtx(ctx context.Context) error {
 		// FIXME https://github.com/spacemeshos/go-spacemesh/issues/3987
 		s.logger.With().Info("syncing malicious proofs", log.Context(ctx))
-		if err := s.syncMalfeasance(ctx); err != nil {
+		if err := s.syncMalfeasance(ctx, current.GetEpoch()); err != nil {
 			return err
 		}
 		s.logger.With().Info("malicious IDs synced", log.Context(ctx))
@@ -465,6 +477,7 @@ func (s *Syncer) syncAtx(ctx context.Context) error {
 
 	publish := current.GetEpoch()
 	if publish == 0 {
+		s.logger.With().Info("nothing to sync in epoch 0", log.Context(ctx))
 		return nil // nothing to sync in epoch 0
 	}
 
@@ -497,6 +510,21 @@ func (s *Syncer) syncAtx(ctx context.Context) error {
 			return err
 		})
 	}
+	if !s.malSync.started {
+		s.malSync.started = true
+		s.malSync.eg.Go(func() error {
+			select {
+			case <-ctx.Done():
+				return nil
+			case <-s.awaitATXSyncedCh:
+				err := s.malsyncer.DownloadLoop(ctx)
+				if err != nil && !errors.Is(err, context.Canceled) {
+					s.logger.WithContext(ctx).Error("malfeasance sync failed", log.Err(err))
+				}
+				return nil
+			}
+		})
+	}
 	return nil
 }
 
@@ -585,9 +613,11 @@ func (s *Syncer) setStateAfterSync(ctx context.Context, success bool) {
 	}
 }
 
-func (s *Syncer) syncMalfeasance(ctx context.Context) error {
-	if err := s.dataFetcher.PollMaliciousProofs(ctx); err != nil {
-		return fmt.Errorf("PollMaliciousProofs: %w", err)
+func (s *Syncer) syncMalfeasance(ctx context.Context, epoch types.EpochID) error {
+	epochStart := s.ticker.LayerToTime(epoch.FirstLayer())
+	epochEnd := s.ticker.LayerToTime(epoch.Add(1).FirstLayer())
+	if err := s.malsyncer.EnsureInSync(ctx, epochStart, epochEnd); err != nil {
+		return fmt.Errorf("syncing malfeasance proof: %w", err)
 	}
 	return nil
 }
diff --git a/syncer/syncer_test.go b/syncer/syncer_test.go
index e24f3e458a..e753a8f951 100644
--- a/syncer/syncer_test.go
+++ b/syncer/syncer_test.go
@@ -65,6 +65,7 @@ func (mlt *mockLayerTicker) LayerToTime(layerID types.LayerID) time.Time {
 }
 
 type testSyncer struct {
+	t      testing.TB
 	syncer *Syncer
 	cdb    *datastore.CachedDB
 	msh    *mesh.Mesh
@@ -72,6 +73,7 @@ type testSyncer struct {
 
 	mDataFetcher *mocks.MockfetchLogic
 	mAtxSyncer   *mocks.MockatxSyncer
+	mMalSyncer   *mocks.MockmalSyncer
 	mBeacon      *smocks.MockBeaconGetter
 	mLyrPatrol   *mocks.MocklayerPatrol
 	mVm          *mmocks.MockvmState
@@ -81,15 +83,42 @@ type testSyncer struct {
 	mForkFinder  *mocks.MockforkFinder
 }
 
+func (ts *testSyncer) expectMalEnsureInSync(current types.LayerID) {
+	ts.mMalSyncer.EXPECT().EnsureInSync(
+		gomock.Any(),
+		ts.mTicker.LayerToTime(current.GetEpoch().FirstLayer()),
+		ts.mTicker.LayerToTime(current.GetEpoch().Add(1).FirstLayer()),
+	)
+}
+
+func (ts *testSyncer) expectDownloadLoop() chan struct{} {
+	ch := make(chan struct{})
+	ts.mMalSyncer.EXPECT().DownloadLoop(gomock.Any()).
+		DoAndReturn(func(context.Context) error {
+			close(ch)
+			return nil
+		})
+	ts.t.Cleanup(func() {
+		select {
+		case <-ch:
+		case <-time.After(10 * time.Second):
+			require.FailNow(ts.t, "timed out waiting for malsync loop start")
+		}
+	})
+	return ch
+}
+
 func newTestSyncer(t *testing.T, interval time.Duration) *testSyncer {
 	lg := logtest.New(t)
 	mt := newMockLayerTicker()
 	ctrl := gomock.NewController(t)
 
 	ts := &testSyncer{
+		t:            t,
 		mTicker:      mt,
 		mDataFetcher: mocks.NewMockfetchLogic(ctrl),
 		mAtxSyncer:   mocks.NewMockatxSyncer(ctrl),
+		mMalSyncer:   mocks.NewMockmalSyncer(ctrl),
 		mBeacon:      smocks.NewMockBeaconGetter(ctrl),
 		mLyrPatrol:   mocks.NewMocklayerPatrol(ctrl),
 		mVm:          mmocks.NewMockvmState(ctrl),
@@ -124,6 +153,7 @@ func newTestSyncer(t *testing.T, interval time.Duration) *testSyncer {
 		ts.mLyrPatrol,
 		ts.mCertHdr,
 		ts.mAtxSyncer,
+		ts.mMalSyncer,
 		WithConfig(cfg),
 		WithLogger(lg),
 		withDataFetcher(ts.mDataFetcher),
@@ -173,10 +203,11 @@ func TestSynchronize_OnlyOneSynchronize(t *testing.T) {
 	ts.mTicker.advanceToLayer(current)
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
+	dlCh := ts.expectDownloadLoop()
 	ts.syncer.Start()
 
 	ts.mAtxSyncer.EXPECT().Download(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
-	ts.mDataFetcher.EXPECT().PollMaliciousProofs(gomock.Any())
+	ts.expectMalEnsureInSync(current)
 	gLayer := types.GetEffectiveGenesis()
 	started := make(chan struct{}, 1)
@@ -203,6 +234,7 @@ func TestSynchronize_OnlyOneSynchronize(t *testing.T) {
 	// allow synchronize to finish
 	close(done)
 	require.NoError(t, eg.Wait())
+	<-dlCh
 
 	cancel()
 	ts.syncer.Close()
@@ -232,6 +264,7 @@ func advanceState(t testing.TB, ts *testSyncer, from, to types.LayerID) {
 
 func TestSynchronize_AllGood(t *testing.T) {
 	ts := newSyncerWithoutPeriodicRuns(t)
+	ts.expectDownloadLoop()
 	gLayer := types.GetEffectiveGenesis()
 	current1 := gLayer.Add(10)
 	ts.mTicker.advanceToLayer(current1)
@@ -257,7 +290,7 @@ func TestSynchronize_AllGood(t *testing.T) {
 		})
 	}
 
-	ts.mDataFetcher.EXPECT().PollMaliciousProofs(gomock.Any())
+	ts.expectMalEnsureInSync(current1)
 	for lid := gLayer.Add(1); lid.Before(current2); lid = lid.Add(1) {
 		ts.mDataFetcher.EXPECT().PollLayerData(gomock.Any(), lid)
 	}
@@ -306,13 +339,14 @@ func TestSynchronize_AllGood(t *testing.T) {
 
 func TestSynchronize_FetchLayerDataFailed(t *testing.T) {
 	ts := newSyncerWithoutPeriodicRuns(t)
+	ts.expectDownloadLoop()
 	gLayer := types.GetEffectiveGenesis()
 	current := gLayer.Add(2)
 	ts.mTicker.advanceToLayer(current)
 	lyr := current.Sub(1)
 
 	// times 2 as we will also spinup background worker
 	ts.mAtxSyncer.EXPECT().Download(gomock.Any(), gLayer.GetEpoch(), gomock.Any()).Times(2)
-	ts.mDataFetcher.EXPECT().PollMaliciousProofs(gomock.Any())
+	ts.expectMalEnsureInSync(current)
 	ts.mDataFetcher.EXPECT().PollLayerData(gomock.Any(), lyr).Return(errors.New("meh"))
 
 	require.False(t, ts.syncer.synchronize(context.Background()))
@@ -332,7 +366,7 @@ func TestSynchronize_FetchMalfeasanceFailed(t *testing.T) {
 	ts.mTicker.advanceToLayer(current)
 	lyr := current.Sub(1)
 
 	ts.mAtxSyncer.EXPECT().Download(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
-	ts.mDataFetcher.EXPECT().PollMaliciousProofs(gomock.Any()).Return(errors.New("meh"))
+	ts.mMalSyncer.EXPECT().EnsureInSync(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("meh"))
 	require.False(t, ts.syncer.synchronize(context.Background()))
 	require.EqualValues(t, current.GetEpoch()-1, ts.syncer.lastAtxEpoch())
@@ -388,7 +422,7 @@ func startWithSyncedState(t *testing.T, ts *testSyncer) types.LayerID {
 	gLayer := types.GetEffectiveGenesis()
 	ts.mTicker.advanceToLayer(gLayer)
-	ts.mDataFetcher.EXPECT().PollMaliciousProofs(gomock.Any())
+	ts.expectMalEnsureInSync(gLayer)
 	ts.mAtxSyncer.EXPECT().Download(gomock.Any(), gLayer.GetEpoch(), gomock.Any())
 	require.True(t, ts.syncer.synchronize(context.Background()))
 	ts.syncer.waitBackgroundSync()
@@ -418,8 +452,10 @@ func TestSyncAtxs_Genesis(t *testing.T) {
 	})
 	t.Run("first atx epoch", func(t *testing.T) {
 		ts := newSyncerWithoutPeriodicRuns(t)
+		ts.expectDownloadLoop()
 		epoch := types.EpochID(1)
-		ts.mTicker.advanceToLayer(epoch.FirstLayer() + 2) // to pass epoch end fraction threshold
+		current := epoch.FirstLayer() + 2
+		ts.mTicker.advanceToLayer(current) // to pass epoch end fraction threshold
 		require.False(t, ts.syncer.ListenToATXGossip())
 		wait := make(chan types.EpochID, 1)
 		ts.mAtxSyncer.EXPECT().
@@ -432,7 +468,7 @@ func TestSyncAtxs_Genesis(t *testing.T) {
 				}
 				return nil
 			})
-		ts.mDataFetcher.EXPECT().PollMaliciousProofs(gomock.Any())
+		ts.expectMalEnsureInSync(current)
 		require.True(t, ts.syncer.synchronize(context.Background()))
 		require.True(t, ts.syncer.ListenToATXGossip())
 		select {
@@ -465,6 +501,7 @@ func TestSyncAtxs(t *testing.T) {
 		tc := tc
 		t.Run(tc.desc, func(t *testing.T) {
 			ts := newSyncerWithoutPeriodicRuns(t)
+			ts.expectDownloadLoop()
 			lyr := startWithSyncedState(t, ts)
 			require.LessOrEqual(t, lyr, tc.current)
 			ts.mTicker.advanceToLayer(tc.current)
@@ -482,6 +519,7 @@ func TestSyncAtxs(t *testing.T) {
 
 func TestSynchronize_StaySyncedUponFailure(t *testing.T) {
 	ts := newSyncerWithoutPeriodicRuns(t)
+	ts.expectDownloadLoop()
 	lyr := startWithSyncedState(t, ts)
 	current := lyr.Add(1)
 	ts.mTicker.advanceToLayer(current)
@@ -498,6 +536,7 @@ func TestSynchronize_StaySyncedUponFailure(t *testing.T) {
 
 func TestSynchronize_BecomeNotSyncedUponFailureIfNoGossip(t *testing.T) {
 	ts := newSyncerWithoutPeriodicRuns(t)
+	ts.expectDownloadLoop()
 	lyr := startWithSyncedState(t, ts)
 	current := lyr.Add(outOfSyncThreshold)
 	ts.mTicker.advanceToLayer(current)
@@ -516,12 +555,13 @@ func TestSynchronize_BecomeNotSyncedUponFailureIfNoGossip(t *testing.T) {
 
 // test the case where the node originally starts from notSynced and eventually becomes synced.
 func TestFromNotSyncedToSynced(t *testing.T) {
 	ts := newSyncerWithoutPeriodicRuns(t)
+	ts.expectDownloadLoop()
 	ts.mAtxSyncer.EXPECT().Download(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
-	ts.mDataFetcher.EXPECT().PollMaliciousProofs(gomock.Any())
 	lyr := types.GetEffectiveGenesis().Add(1)
 	current := lyr.Add(5)
 	ts.mTicker.advanceToLayer(current)
 	ts.mDataFetcher.EXPECT().PollLayerData(gomock.Any(), lyr).Return(errors.New("baa-ram-ewe"))
+	ts.expectMalEnsureInSync(current)
 
 	require.False(t, ts.syncer.synchronize(context.Background()))
 	require.False(t, ts.syncer.dataSynced())
@@ -550,11 +590,12 @@ func TestFromNotSyncedToSynced(t *testing.T) {
 // to notSynced.
 func TestFromGossipSyncToNotSynced(t *testing.T) {
 	ts := newSyncerWithoutPeriodicRuns(t)
+	ts.expectDownloadLoop()
 	ts.mAtxSyncer.EXPECT().Download(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
 	lyr := types.GetEffectiveGenesis().Add(1)
 	current := lyr.Add(1)
 	ts.mTicker.advanceToLayer(current)
-	ts.mDataFetcher.EXPECT().PollMaliciousProofs(gomock.Any())
+	ts.expectMalEnsureInSync(current)
 	ts.mDataFetcher.EXPECT().PollLayerData(gomock.Any(), lyr)
 
 	require.True(t, ts.syncer.synchronize(context.Background()))
@@ -581,6 +622,7 @@ func TestFromGossipSyncToNotSynced(t *testing.T) {
 
 func TestNetworkHasNoData(t *testing.T) {
 	ts := newSyncerWithoutPeriodicRuns(t)
+	ts.expectDownloadLoop()
 	lyr := startWithSyncedState(t, ts)
 	require.True(t, ts.syncer.IsSynced(context.Background()))
 
@@ -606,8 +648,8 @@ func TestNetworkHasNoData(t *testing.T) {
 // eventually become synced again.
 func TestFromSyncedToNotSynced(t *testing.T) {
 	ts := newSyncerWithoutPeriodicRuns(t)
+	ts.expectDownloadLoop()
 	ts.mAtxSyncer.EXPECT().Download(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
-	ts.mDataFetcher.EXPECT().PollMaliciousProofs(gomock.Any()).AnyTimes()
 
 	require.True(t, ts.syncer.synchronize(context.Background()))
 	require.True(t, ts.syncer.IsSynced(context.Background()))
@@ -657,11 +699,12 @@ func waitOutGossipSync(t *testing.T, ts *testSyncer) {
 
 func TestSync_AlsoSyncProcessedLayer(t *testing.T) {
 	ts := newSyncerWithoutPeriodicRuns(t)
+	ts.expectDownloadLoop()
 	ts.mAtxSyncer.EXPECT().Download(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
-	ts.mDataFetcher.EXPECT().PollMaliciousProofs(gomock.Any())
 	lyr := types.GetEffectiveGenesis().Add(1)
 	current := lyr.Add(1)
 	ts.mTicker.advanceToLayer(current)
+	ts.expectMalEnsureInSync(current)
 
 	// simulate hare advancing the mesh forward
 	ts.mTortoise.EXPECT().TallyVotes(gomock.Any(), lyr)
@@ -717,6 +760,7 @@ func TestSyncer_IsBeaconSynced(t *testing.T) {
 
 func TestSynchronize_RecoverFromCheckpoint(t *testing.T) {
 	ts := newSyncerWithoutPeriodicRuns(t)
+	ts.expectDownloadLoop()
 	current := types.GetEffectiveGenesis().Add(types.GetLayersPerEpoch() * 5)
 	// recover from a checkpoint
 	types.SetEffectiveGenesis(current.Uint32())
@@ -731,6 +775,7 @@ func TestSynchronize_RecoverFromCheckpoint(t *testing.T) {
 		ts.mLyrPatrol,
 		ts.mCertHdr,
 		ts.mAtxSyncer,
+		ts.mMalSyncer,
 		WithConfig(ts.syncer.cfg),
 		WithLogger(ts.syncer.logger),
 		withDataFetcher(ts.mDataFetcher),
@@ -739,7 +784,7 @@ func TestSynchronize_RecoverFromCheckpoint(t *testing.T) {
 	// should not sync any atxs before current epoch
 	ts.mAtxSyncer.EXPECT().Download(gomock.Any(), current.GetEpoch(), gomock.Any())
-	ts.mDataFetcher.EXPECT().PollMaliciousProofs(gomock.Any())
+	ts.expectMalEnsureInSync(current)
 	require.True(t, ts.syncer.synchronize(context.Background()))
 	ts.syncer.waitBackgroundSync()
 	require.Equal(t, current.GetEpoch(), ts.syncer.lastAtxEpoch())
diff --git a/system/fetcher.go b/system/fetcher.go
index 24341bad27..dec2655835 100644
--- a/system/fetcher.go
+++ b/system/fetcher.go
@@ -69,6 +69,11 @@ type ActiveSetFetcher interface {
 	GetActiveSet(context.Context, types.Hash32) error
 }
 
+// MalfeasanceProofFetcher defines an interface for fetching malfeasance proofs.
+type MalfeasanceProofFetcher interface {
+	GetMalfeasanceProofs(context.Context, []types.NodeID) error
+}
+
 // PeerTracker defines an interface to track peer hashes.
 type PeerTracker interface {
 	RegisterPeerHashes(peer p2p.Peer, hashes []types.Hash32)
diff --git a/system/mocks/fetcher.go b/system/mocks/fetcher.go
index fc231df720..6b6ed6755d 100644
--- a/system/mocks/fetcher.go
+++ b/system/mocks/fetcher.go
@@ -857,6 +857,67 @@ func (c *MockActiveSetFetcherGetActiveSetCall) DoAndReturn(f func(context.Contex
 	return c
 }
 
+// MockMalfeasanceProofFetcher is a mock of MalfeasanceProofFetcher interface.
+type MockMalfeasanceProofFetcher struct {
+	ctrl     *gomock.Controller
+	recorder *MockMalfeasanceProofFetcherMockRecorder
+}
+
+// MockMalfeasanceProofFetcherMockRecorder is the mock recorder for MockMalfeasanceProofFetcher.
+type MockMalfeasanceProofFetcherMockRecorder struct {
+	mock *MockMalfeasanceProofFetcher
+}
+
+// NewMockMalfeasanceProofFetcher creates a new mock instance.
+func NewMockMalfeasanceProofFetcher(ctrl *gomock.Controller) *MockMalfeasanceProofFetcher {
+	mock := &MockMalfeasanceProofFetcher{ctrl: ctrl}
+	mock.recorder = &MockMalfeasanceProofFetcherMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockMalfeasanceProofFetcher) EXPECT() *MockMalfeasanceProofFetcherMockRecorder {
+	return m.recorder
+}
+
+// GetMalfeasanceProofs mocks base method.
+func (m *MockMalfeasanceProofFetcher) GetMalfeasanceProofs(arg0 context.Context, arg1 []types.NodeID) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetMalfeasanceProofs", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// GetMalfeasanceProofs indicates an expected call of GetMalfeasanceProofs.
+func (mr *MockMalfeasanceProofFetcherMockRecorder) GetMalfeasanceProofs(arg0, arg1 any) *MockMalfeasanceProofFetcherGetMalfeasanceProofsCall {
+	mr.mock.ctrl.T.Helper()
+	call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMalfeasanceProofs", reflect.TypeOf((*MockMalfeasanceProofFetcher)(nil).GetMalfeasanceProofs), arg0, arg1)
+	return &MockMalfeasanceProofFetcherGetMalfeasanceProofsCall{Call: call}
+}
+
+// MockMalfeasanceProofFetcherGetMalfeasanceProofsCall wrap *gomock.Call
+type MockMalfeasanceProofFetcherGetMalfeasanceProofsCall struct {
+	*gomock.Call
+}
+
+// Return rewrite *gomock.Call.Return
+func (c *MockMalfeasanceProofFetcherGetMalfeasanceProofsCall) Return(arg0 error) *MockMalfeasanceProofFetcherGetMalfeasanceProofsCall {
+	c.Call = c.Call.Return(arg0)
+	return c
+}
+
+// Do rewrite *gomock.Call.Do
+func (c *MockMalfeasanceProofFetcherGetMalfeasanceProofsCall) Do(f func(context.Context, []types.NodeID) error) *MockMalfeasanceProofFetcherGetMalfeasanceProofsCall {
+	c.Call = c.Call.Do(f)
+	return c
+}
+
+// DoAndReturn rewrite *gomock.Call.DoAndReturn
+func (c *MockMalfeasanceProofFetcherGetMalfeasanceProofsCall) DoAndReturn(f func(context.Context, []types.NodeID) error) *MockMalfeasanceProofFetcherGetMalfeasanceProofsCall {
+	c.Call = c.Call.DoAndReturn(f)
+	return c
+}
+
 // MockPeerTracker is a mock of PeerTracker interface.
 type MockPeerTracker struct {
 	ctrl *gomock.Controller
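
The malSyncer dependency wired into the Syncer above is only visible in this diff through its mock and its call sites. A minimal sketch of the interface those imply, assuming the usual location in package syncer; the method set is taken from MockmalSyncer and from syncAtx/syncMalfeasance, while the doc comments are illustrative rather than quoted from the source:

package syncer

import (
	"context"
	"time"
)

// malSyncer drives malfeasance proof synchronization (sketch, not verbatim source).
type malSyncer interface {
	// EnsureInSync blocks until malicious identities are synced for the epoch
	// bounded by [epochStart, epochEnd), as computed in Syncer.syncMalfeasance.
	EnsureInSync(ctx context.Context, epochStart, epochEnd time.Time) error
	// DownloadLoop keeps polling peers for newly published malicious identities
	// until the context is canceled; the Syncer starts it once after ATX sync.
	DownloadLoop(ctx context.Context) error
}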
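With GetMaliciousIDs now returning decoded []types.NodeID and the new system.MalfeasanceProofFetcher exposing GetMalfeasanceProofs, a caller can go straight from a peer's ID list to proof fetching. The sketch below is illustrative only; proofSource and syncMaliciousFromPeer are names invented for the example, not part of this change:

package example

import (
	"context"
	"fmt"

	"github.com/spacemeshos/go-spacemesh/common/types"
	"github.com/spacemeshos/go-spacemesh/p2p"
)

// proofSource combines the two fetcher methods touched by this change (hypothetical helper interface).
type proofSource interface {
	GetMaliciousIDs(context.Context, p2p.Peer) ([]types.NodeID, error)
	GetMalfeasanceProofs(context.Context, []types.NodeID) error
}

// syncMaliciousFromPeer sketches one pass of the flow: ask a peer for the node IDs
// it considers malicious, then fetch the corresponding malfeasance proofs.
func syncMaliciousFromPeer(ctx context.Context, src proofSource, peer p2p.Peer) error {
	ids, err := src.GetMaliciousIDs(ctx, peer) // already decoded; no manual codec handling needed
	if err != nil {
		return fmt.Errorf("get malicious IDs from %s: %w", peer, err)
	}
	if len(ids) == 0 {
		return nil // peer reports no malicious identities
	}
	return src.GetMalfeasanceProofs(ctx, ids)
}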
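A hypothetical test exercising the generated MockMalfeasanceProofFetcher could look like the following; the test name, expectation, and import layout are assumptions for the example, not code from this change:

package example

import (
	"context"
	"testing"

	"go.uber.org/mock/gomock"

	"github.com/spacemeshos/go-spacemesh/common/types"
	"github.com/spacemeshos/go-spacemesh/system/mocks"
)

// TestGetMalfeasanceProofsExpectation checks that the code under test requests
// proofs for exactly the IDs it was handed.
func TestGetMalfeasanceProofsExpectation(t *testing.T) {
	ctrl := gomock.NewController(t)
	fetcher := mocks.NewMockMalfeasanceProofFetcher(ctrl)

	ids := []types.NodeID{types.RandomNodeID(), types.RandomNodeID()}
	fetcher.EXPECT().GetMalfeasanceProofs(gomock.Any(), ids).Return(nil)

	// Stand-in for the code under test, which would receive the fetcher as a dependency.
	if err := fetcher.GetMalfeasanceProofs(context.Background(), ids); err != nil {
		t.Fatal(err)
	}
}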