diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a0429ad0dd..10aaa53a58 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -86,6 +86,8 @@ jobs: make test-fmt make test-tidy make test-generate + - name: Check for vulnerabilities + run: make vulncheck lint: runs-on: ubuntu-22.04 @@ -203,7 +205,7 @@ jobs: - name: test coverage run: make cover - name: Upload to codecov.io - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 5606ed5664..18623abce6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,6 @@ See [RELEASE](./RELEASE.md) for workflow instructions. ### Improvements -* [#6431](https://github.com/spacemeshos/go-spacemesh/pull/6431) Fix db-allow-schema-drift handling * [#6408](https://github.com/spacemeshos/go-spacemesh/pull/6408) Prevent empty DB connection pool by freeing connections upon errors during DB operations. This mostly fixes issues when a node is under heavy load from the API. @@ -15,6 +14,12 @@ See [RELEASE](./RELEASE.md) for workflow instructions. * [#6422](https://github.com/spacemeshos/go-spacemesh/pull/6422) Further improved performance of the proposal building process to avoid late proposals. +* [#6443](https://github.com/spacemeshos/go-spacemesh/pull/6443) Improve eviction of ineffectual transactions in the database; + they now show up as ineffectual when queried via the API. + +* [#6431](https://github.com/spacemeshos/go-spacemesh/pull/6431) Fix db-allow-schema-drift handling + +* [#6451](https://github.com/spacemeshos/go-spacemesh/pull/6451) Fix a possible infinite loop in the beacon protocol. ## v1.7.6 diff --git a/Makefile b/Makefile index 641b6074de..d26852c722 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,7 @@ BRANCH ?= $(shell git rev-parse --abbrev-ref HEAD) GOLANGCI_LINT_VERSION := v1.61.0 GOTESTSUM_VERSION := v1.12.0 +GOVULNCHECK_VERSION := v1.1.3 GOSCALE_VERSION := v1.2.0 MOCKGEN_VERSION := v0.5.0 @@ -68,6 +69,7 @@ install: go install github.com/spacemeshos/go-scale/scalegen@$(GOSCALE_VERSION) go install go.uber.org/mock/mockgen@$(MOCKGEN_VERSION) go install gotest.tools/gotestsum@$(GOTESTSUM_VERSION) + go install golang.org/x/vuln/cmd/govulncheck@$(GOVULNCHECK_VERSION) .PHONY: install build: go-spacemesh get-profiler get-postrs-service @@ -146,6 +148,10 @@ cover: get-libs @$(ULIMIT) CGO_LDFLAGS="$(CGO_TEST_LDFLAGS)" go test -coverprofile=cover.out -p 1 -timeout 30m -coverpkg=./... $(UNIT_TESTS) .PHONY: cover +vulncheck: get-libs + govulncheck ./...
+.PHONY: vulncheck + list-versions: @echo "Latest 5 tagged versions:\n" @git for-each-ref --sort=-creatordate --count=5 --format '%(creatordate:short): %(refname:short)' refs/tags diff --git a/activation/handler.go b/activation/handler.go index da99dd999d..3960e484a7 100644 --- a/activation/handler.go +++ b/activation/handler.go @@ -146,7 +146,7 @@ func NewHandler( fetcher: fetcher, beacon: beacon, tortoise: tortoise, - malPublisher: &MalfeasancePublisher{}, + malPublisher: &MalfeasancePublisher{}, // TODO(mafa): pass real publisher when available }, } diff --git a/activation/handler_test.go b/activation/handler_test.go index c080012133..7c30c86fd2 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "sort" "testing" "testing/quick" @@ -128,7 +129,7 @@ type handlerMocks struct { mValidator *MocknipostValidator mbeacon *MockAtxReceiver mtortoise *mocks.MockTortoise - mMalPublish *MockmalfeasancePublisher + mMalPublish *MockatxMalfeasancePublisher } type testHandler struct { @@ -159,6 +160,7 @@ func (h *handlerMocks) expectAtxV1(atx *wire.ActivationTxV1, nodeId types.NodeID } h.mockFetch.EXPECT().RegisterPeerHashes(gomock.Any(), gomock.Any()) h.mockFetch.EXPECT().GetPoetProof(gomock.Any(), types.BytesToHash(atx.NIPost.PostMetadata.Challenge)) + deps := []types.ATXID{atx.PrevATXID, atx.PositioningATXID} if atx.PrevATXID == types.EmptyATXID { h.mValidator.EXPECT().InitialNIPostChallengeV1(gomock.Any(), gomock.Any(), h.goldenATXID) h.mValidator.EXPECT(). @@ -170,9 +172,17 @@ func (h *handlerMocks) expectAtxV1(atx *wire.ActivationTxV1, nodeId types.NodeID time.Sleep(settings.postVerificationDuration) return nil }) + deps = append(deps, *atx.CommitmentATXID) } else { h.mValidator.EXPECT().NIPostChallengeV1(gomock.Any(), gomock.Any(), nodeId) } + deps = slices.Compact(deps) + deps = slices.DeleteFunc(deps, func(dep types.ATXID) bool { + return dep == types.EmptyATXID || dep == h.goldenATXID + }) + if len(deps) > 0 { + h.mockFetch.EXPECT().GetAtxs(gomock.Any(), deps, gomock.Any()) + } h.mValidator.EXPECT().PositioningAtx(atx.PositioningATXID, gomock.Any(), h.goldenATXID, atx.PublishEpoch) h.mValidator.EXPECT(). NIPost(gomock.Any(), nodeId, h.goldenATXID, gomock.Any(), gomock.Any(), atx.NumUnits, gomock.Any()). 
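The updated expectAtxV1 helper above derives which ATX dependencies the handler must fetch before it can validate an incoming ATX. A minimal, self-contained sketch of that filtering pattern follows; the ATXID type and the sentinel values are local stand-ins for types.ATXID, types.EmptyATXID and the golden ATX, not part of the diff:

```go
package main

import (
	"fmt"
	"slices"
)

// ATXID is a 32-byte stand-in for types.ATXID (an assumption of this sketch).
type ATXID [32]byte

func main() {
	var (
		emptyATX  ATXID      // zero value, models types.EmptyATXID
		goldenATX = ATXID{1} // hypothetical golden ATX
		prevATX   = ATXID{2}
		posATX    = ATXID{2} // same as prevATX: an adjacent duplicate
	)

	// Collect the ATXs the handler would have to fetch before validation.
	deps := []ATXID{prevATX, posATX}
	// slices.Compact drops *adjacent* duplicates, which is sufficient here
	// because equal entries sit next to each other in this short list.
	deps = slices.Compact(deps)
	// Never fetch the sentinel values: the empty and the golden ATX.
	deps = slices.DeleteFunc(deps, func(dep ATXID) bool {
		return dep == emptyATX || dep == goldenATX
	})
	if len(deps) > 0 {
		fmt.Printf("would expect a GetAtxs call for %d dependencies\n", len(deps))
	}
}
```

This is why several tests below drop their explicit GetAtxs expectations: the helper now registers them itself whenever the filtered dependency list is non-empty.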
@@ -194,7 +204,7 @@ func newTestHandlerMocks(tb testing.TB, golden types.ATXID) handlerMocks { mValidator: NewMocknipostValidator(ctrl), mbeacon: NewMockAtxReceiver(ctrl), mtortoise: mocks.NewMockTortoise(ctrl), - mMalPublish: NewMockmalfeasancePublisher(ctrl), + mMalPublish: NewMockatxMalfeasancePublisher(ctrl), } } @@ -205,6 +215,8 @@ func newTestHandler(tb testing.TB, goldenATXID types.ATXID, opts ...HandlerOptio edVerifier := signing.NewEdVerifier() mocks := newTestHandlerMocks(tb, goldenATXID) + // TODO(mafa): make mandatory parameter when real publisher is available + opts = append(opts, func(h *Handler) { h.v2.malPublisher = mocks.mMalPublish }) atxHdlr := NewHandler( "localID", cdb, @@ -341,7 +353,6 @@ func TestHandler_ProcessAtxStoresNewVRFNonce(t *testing.T) { atx2.VRFNonce = (*uint64)(&nonce2) atx2.Sign(sig) atxHdlr.expectAtxV1(atx2, sig.NodeID()) - atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), gomock.Any(), gomock.Any()) require.NoError(t, atxHdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(atx2))) got, err = atxs.VRFNonce(atxHdlr.cdb, sig.NodeID(), atx2.PublishEpoch+1) @@ -391,7 +402,6 @@ func TestHandler_HandleGossipAtx(t *testing.T) { // second is now valid (deps are in) atxHdlr.expectAtxV1(second, sig.NodeID()) - atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), []types.ATXID{second.PrevATXID}, gomock.Any()) require.NoError(t, atxHdlr.HandleGossipAtx(context.Background(), "", codec.MustEncode(second))) } @@ -695,7 +705,6 @@ func TestHandler_AtxWeight(t *testing.T) { buf = codec.MustEncode(atx2) atxHdlr.expectAtxV1(atx2, sig.NodeID(), func(o *atxHandleOpts) { o.poetLeaves = leaves }) - atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), []types.ATXID{atx1.ID()}, gomock.Any()) require.NoError(t, atxHdlr.HandleSyncedAtx(context.Background(), atx2.ID().Hash32(), peer, buf)) stored2, err := atxHdlr.cdb.GetAtx(atx2.ID()) diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 3ec7abcca0..b7386a4dd3 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -69,7 +69,7 @@ type HandlerV2 struct { tortoise system.Tortoise logger *zap.Logger fetcher system.Fetcher - malPublisher malfeasancePublisher + malPublisher atxMalfeasancePublisher } func (h *HandlerV2) processATX( @@ -338,7 +338,7 @@ func (h *HandlerV2) validateCommitmentAtx(golden, commitmentAtxId types.ATXID, p if commitmentAtxId != golden { commitment, err := atxs.Get(h.cdb, commitmentAtxId) if err != nil { - return &ErrAtxNotFound{Id: commitmentAtxId, source: err} + return fmt.Errorf("ATX (%s) not found: %w", commitmentAtxId.ShortString(), err) } if publish <= commitment.PublishEpoch { return fmt.Errorf( @@ -359,7 +359,7 @@ func (h *HandlerV2) validatePositioningAtx(publish types.EpochID, golden, positi posAtx, err := atxs.Get(h.cdb, positioning) if err != nil { - return 0, &ErrAtxNotFound{Id: positioning, source: err} + return 0, fmt.Errorf("positioning ATX (%s) not found: %w", positioning.ShortString(), err) } if posAtx.PublishEpoch >= publish { return 0, fmt.Errorf("positioning atx epoch (%v) must be before %v", posAtx.PublishEpoch, publish) @@ -607,104 +607,124 @@ func (h *HandlerV2) syntacticallyValidateDeps( } // validate all niposts + if atx.Initial != nil { + commitment := atx.Initial.CommitmentATX + nipostIdx := 0 + challenge := atx.NIPosts[nipostIdx].Challenge + post := atx.NIPosts[nipostIdx].Posts[0] + if err := h.validatePost(ctx, atx.SmesherID, atx, commitment, challenge, post, nipostIdx); err != nil { + return nil, err + } + result.ids[atx.SmesherID] = idData{ + 
previous: types.EmptyATXID, + previousIndex: 0, + units: post.NumUnits, + } + result.ticks = nipostSizes.minTicks() + return &result, nil + } + var smesherCommitment *types.ATXID for idx, niposts := range atx.NIPosts { for _, post := range niposts.Posts { id := equivocationSet[post.MarriageIndex] - var commitment types.ATXID - var previous types.ATXID - if atx.Initial != nil { - commitment = atx.Initial.CommitmentATX - } else { - var err error - commitment, err = atxs.CommitmentATX(h.cdb, id) - if err != nil { - return nil, fmt.Errorf("commitment atx not found for ID %s: %w", id, err) - } - if id == atx.SmesherID { - smesherCommitment = &commitment - } - previous = previousAtxs[post.PrevATXIndex].ID() + commitment, err := atxs.CommitmentATX(h.cdb, id) + if err != nil { + return nil, fmt.Errorf("commitment atx not found for ID %s: %w", id, err) } - - err := h.nipostValidator.PostV2( - ctx, - id, - commitment, - wire.PostFromWireV1(&post.Post), - niposts.Challenge.Bytes(), - post.NumUnits, - PostSubset([]byte(h.local)), - ) - invalidIdx := &verifying.ErrInvalidIndex{} - switch { - case errors.As(err, invalidIdx): - if err := h.publishInvalidPostProof(ctx, atx, id, idx, uint32(invalidIdx.Index)); err != nil { - return nil, fmt.Errorf("publishing invalid post proof: %w", err) - } - return nil, fmt.Errorf("invalid post for ID %s: %w", id.ShortString(), err) - case err != nil: - return nil, fmt.Errorf("validating post for ID %s: %w", id.ShortString(), err) + if id == atx.SmesherID { + smesherCommitment = &commitment + } + if err := h.validatePost(ctx, id, atx, commitment, niposts.Challenge, post, idx); err != nil { + return nil, err } result.ids[id] = idData{ - previous: previous, + previous: previousAtxs[post.PrevATXIndex].ID(), previousIndex: int(post.PrevATXIndex), units: post.NumUnits, } } } - if atx.Initial == nil { - if smesherCommitment == nil { - return nil, errors.New("ATX signer not present in merged ATX") - } - err := h.nipostValidator.VRFNonceV2(atx.SmesherID, *smesherCommitment, atx.VRFNonce, atx.TotalNumUnits()) - if err != nil { - return nil, fmt.Errorf("validating VRF nonce: %w", err) - } + if smesherCommitment == nil { + return nil, errors.New("ATX signer not present in merged ATX") + } + err = h.nipostValidator.VRFNonceV2(atx.SmesherID, *smesherCommitment, atx.VRFNonce, atx.TotalNumUnits()) + if err != nil { + return nil, fmt.Errorf("validating VRF nonce: %w", err) } result.ticks = nipostSizes.minTicks() return &result, nil } -func (h *HandlerV2) publishInvalidPostProof( +func (h *HandlerV2) validatePost( ctx context.Context, - atx *wire.ActivationTxV2, nodeID types.NodeID, + atx *wire.ActivationTxV2, + commitment types.ATXID, + challenge types.Hash32, + post wire.SubPostV2, nipostIndex int, - invalidPostIndex uint32, ) error { - initialAtx := atx - if initialAtx.Initial == nil { - initialID, err := atxs.GetFirstIDByNodeID(h.cdb, nodeID) - if err != nil { - return fmt.Errorf("fetch initial ATX for ID %s: %w", nodeID.ShortString(), err) + err := h.nipostValidator.PostV2( + ctx, + nodeID, + commitment, + wire.PostFromWireV1(&post.Post), + challenge.Bytes(), + post.NumUnits, + PostSubset([]byte(h.local)), + ) + if err == nil { + return nil + } + errInvalid := &verifying.ErrInvalidIndex{} + if !errors.As(err, &errInvalid) { + return fmt.Errorf("validating post for ID %s: %w", nodeID.ShortString(), err) + } + + // check if post contains at least one valid label + validIdx := 0 + for { + err := h.nipostValidator.PostV2( + ctx, + nodeID, + commitment, + 
wire.PostFromWireV1(&post.Post), + challenge.Bytes(), + post.NumUnits, + PostIndex(validIdx), + ) + if err == nil { + break } - - // TODO(mafa): implement for v1 initial ATXs: https://github.com/spacemeshos/go-spacemesh/issues/6433 - initialAtx, err = h.fetchWireAtx(ctx, h.cdb, initialID) - if err != nil { - return fmt.Errorf("fetch initial ATX blob for ID %s: %w", nodeID.ShortString(), err) + if errors.Is(err, ErrPostIndexOutOfRange) { + return fmt.Errorf("invalid post for ID %s: %w", nodeID.ShortString(), err) } + validIdx++ } - // TODO(mafa): checkpoints need to include all initial ATXs in full to be able to create this malfeasance proof: - // - // see https://github.com/spacemeshos/go-spacemesh/issues/6436 - // // TODO(mafa): checkpoints need to include all marriage ATXs in full to be able to create malfeasance proofs // like this one (but also others) // // see https://github.com/spacemeshos/go-spacemesh/issues/6435 - proof, err := wire.NewInvalidPostProof(h.cdb, atx, initialAtx, nodeID, nipostIndex, invalidPostIndex) + proof, err := wire.NewInvalidPostProof( + h.cdb, + atx, + commitment, + nodeID, + nipostIndex, + uint32(errInvalid.Index), + uint32(validIdx), + ) if err != nil { return fmt.Errorf("creating invalid post proof: %w", err) } if err := h.malPublisher.Publish(ctx, nodeID, proof); err != nil { return fmt.Errorf("publishing malfeasance proof for invalid post: %w", err) } - return nil + return fmt.Errorf("invalid post for ID %s: %w", nodeID.ShortString(), errInvalid) } func (h *HandlerV2) checkMalicious(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { @@ -724,14 +744,6 @@ func (h *HandlerV2) checkMalicious(ctx context.Context, tx sql.Transaction, atx return true, nil } - malicious, err = h.checkDoublePost(ctx, tx, atx) - if err != nil { - return malicious, fmt.Errorf("checking double post: %w", err) - } - if malicious { - return true, nil - } - malicious, err = h.checkDoubleMerge(ctx, tx, atx) if err != nil { return malicious, fmt.Errorf("checking double merge: %w", err) @@ -795,31 +807,6 @@ func (h *HandlerV2) checkDoubleMarry(ctx context.Context, tx sql.Transaction, at return false, nil } -func (h *HandlerV2) checkDoublePost(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { - for id := range atx.ids { - atxIDs, err := atxs.FindDoublePublish(tx, id, atx.PublishEpoch) - switch { - case errors.Is(err, sql.ErrNotFound): - continue - case err != nil: - return false, fmt.Errorf("searching for double publish: %w", err) - } - otherAtxId := slices.IndexFunc(atxIDs, func(other types.ATXID) bool { return other != atx.ID() }) - otherAtx := atxIDs[otherAtxId] - h.logger.Debug( - "found ID that has already contributed its PoST in this epoch", - zap.Stringer("node_id", id), - zap.Stringer("atx_id", atx.ID()), - zap.Stringer("other_atx_id", otherAtx), - zap.Uint32("epoch", atx.PublishEpoch.Uint32()), - ) - // TODO(mafa): finish proof - var proof wire.Proof - return true, h.malPublisher.Publish(ctx, id, proof) - } - return false, nil -} - func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx sql.Transaction, atx *activationTx) (bool, error) { if atx.MarriageATX == nil { return false, nil @@ -841,8 +828,25 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx sql.Transaction, at zap.Stringer("smesher_id", atx.SmesherID), ) - // TODO(mafa): finish proof - var proof wire.Proof + // TODO(mafa): during syntactical validation we should check if a merged ATX is targeting a checkpointed epoch + // merged ATXs need to be 
checkpointed with their marriage ATXs + // if there is a collision (i.e. the new ATX references the same marriage ATX as a golden ATX) it should be + // considered syntactically invalid + // + // see https://github.com/spacemeshos/go-spacemesh/issues/6434 + otherAtx, err := h.fetchWireAtx(ctx, tx, other) + if err != nil { + return false, fmt.Errorf("fetching other ATX: %w", err) + } + + // TODO(mafa): checkpoints need to include all marriage ATXs in full to be able to create malfeasance proofs + // like this one (but also others) + // + // see https://github.com/spacemeshos/go-spacemesh/issues/6435 + proof, err := wire.NewDoubleMergeProof(tx, atx.ActivationTxV2, otherAtx) + if err != nil { + return true, fmt.Errorf("creating double merge proof: %w", err) + } return true, h.malPublisher.Publish(ctx, atx.SmesherID, proof) } @@ -862,22 +866,63 @@ func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx sql.Transaction, atx *a log.ZShortStringer("expected", expectedPrevID), ) - atx1, atx2, err := atxs.PrevATXCollision(tx, data.previous, id) + collisions, err := atxs.PrevATXCollisions(tx, data.previous, id) switch { case errors.Is(err, sql.ErrNotFound): continue case err != nil: - return false, fmt.Errorf("checking for previous ATX collision: %w", err) + return true, fmt.Errorf("checking for previous ATX collision: %w", err) } + var wireAtxV1 *wire.ActivationTxV1 + for _, collision := range collisions { + if collision == atx.ID() { + continue + } + var blob sql.Blob + v, err := atxs.LoadBlob(ctx, tx, collision.Bytes(), &blob) + if err != nil { + return true, fmt.Errorf("get atx blob %s: %w", id.ShortString(), err) + } + switch v { + case types.AtxV1: + if wireAtxV1 == nil { + // we have at least one v2 ATX (the one we are validating right now) so we only need one + // v1 ATX to create the proof if no other v2 ATXs are found + wireAtxV1 = &wire.ActivationTxV1{} + codec.MustDecode(blob.Bytes, wireAtxV1) + } + case types.AtxV2: + wireAtx := &wire.ActivationTxV2{} + codec.MustDecode(blob.Bytes, wireAtx) + // prefer creating a proof with 2 ATXs of version 2 + h.logger.Debug("creating a malfeasance proof for invalid previous ATX", + log.ZShortStringer("smesherID", id), + log.ZShortStringer("atx1", wireAtx.ID()), + log.ZShortStringer("atx2", atx.ActivationTxV2.ID()), + ) + proof, err := wire.NewInvalidPrevAtxProofV2(tx, atx.ActivationTxV2, wireAtx, id) + if err != nil { + return true, fmt.Errorf("creating invalid previous ATX proof: %w", err) + } + return true, h.malPublisher.Publish(ctx, id, proof) + default: + h.logger.Fatal("Failed to create invalid previous ATX proof: unknown ATX version", + zap.Stringer("atx_id", collision), + ) + } + } + + // no ATXv2 found, create a proof with an ATXv1 h.logger.Debug("creating a malfeasance proof for invalid previous ATX", log.ZShortStringer("smesherID", id), - log.ZShortStringer("atx1", atx1), - log.ZShortStringer("atx2", atx2), + log.ZShortStringer("atx1", wireAtxV1.ID()), + log.ZShortStringer("atx2", atx.ActivationTxV2.ID()), ) - - // TODO(mafa): finish proof - var proof wire.Proof + proof, err := wire.NewInvalidPrevAtxProofV1(tx, atx.ActivationTxV2, wireAtxV1, id) + if err != nil { + return true, fmt.Errorf("creating invalid previous ATX proof: %w", err) + } return true, h.malPublisher.Publish(ctx, id, proof) } return false, nil diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index a03058ab3d..19e2981233 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -135,7 +135,7 @@ func (h *handlerMocks) 
expectStoreAtxV2(atx *wire.ActivationTxV2) { } func (h *handlerMocks) expectInitialAtxV2(atx *wire.ActivationTxV2) { - h.mclock.EXPECT().CurrentLayer().Return(postGenesisEpoch.FirstLayer()) + h.mclock.EXPECT().CurrentLayer().Return(atx.PublishEpoch.FirstLayer()) h.mValidator.EXPECT().VRFNonceV2( atx.SmesherID, atx.Initial.CommitmentATX, @@ -175,7 +175,7 @@ func (h *handlerMocks) expectMergedAtxV2( equivocationSet []types.NodeID, poetLeaves []uint64, ) { - h.mclock.EXPECT().CurrentLayer().Return(postGenesisEpoch.FirstLayer()) + h.mclock.EXPECT().CurrentLayer().Return(atx.PublishEpoch.FirstLayer()) h.expectFetchDeps(atx) h.mValidator.EXPECT().VRFNonceV2( atx.SmesherID, @@ -984,12 +984,29 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { merged.PreviousATXs = []types.ATXID{otherATXs[1].ID(), otherATXs[2].ID()} merged.Sign(signers[2]) + verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHandler.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{100}) - atxHandler.mMalPublish.EXPECT().Publish(gomock.Any(), merged.SmesherID, gomock.Any()) + atxHandler.mMalPublish.EXPECT().Publish( + gomock.Any(), + merged.SmesherID, + gomock.AssignableToTypeOf(&wire.ProofDoubleMerge{}), + ).DoAndReturn(func(ctx context.Context, id types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofDoubleMerge) + nId, err := malProof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, merged.SmesherID, nId) + return nil + }) err = atxHandler.processATX(context.Background(), "", merged, time.Now()) require.NoError(t, err) }) t.Run("publishing two merged ATXs (one checkpointed)", func(t *testing.T) { + t.Skip("syntactically validating double merge where one ATX is checkpointed isn't implemented yet") atxHandler := newV2TestHandler(t, golden) mATX, otherATXs := marryIDs(t, atxHandler, signers, golden) @@ -1019,12 +1036,12 @@ func TestHandlerV2_ProcessMergedATX(t *testing.T) { merged.MarriageATX = &mATXID merged.PreviousATXs = []types.ATXID{otherATXs[1].ID(), otherATXs[2].ID(), otherATXs[3].ID()} merged.Sign(signers[2]) - atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{100}) - // TODO: this could be syntactically validated as all nodes in the network + + // This is syntactically invalid as all nodes in the network // should already have the checkpointed merged ATX. 
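The invalid-PoST-index tests in the hunks below now expect eleven PostV2 calls: one subset verification that fails with ErrInvalidIndex, ten probes at indices 0 through 9 that fail, and a final probe at index 10 that succeeds. That count follows from the probing loop in HandlerV2.validatePost above, which scans for the first valid label so the malfeasance proof can name both an invalid and a valid index, and from validation.go (later in this diff) returning ErrPostIndexOutOfRange once the index reaches K2. A self-contained sketch of that loop; the K2 value of 16 and the verify helper are assumptions of this sketch:

```go
package main

import (
	"errors"
	"fmt"
)

// errPostIndexOutOfRange mirrors activation.ErrPostIndexOutOfRange.
var errPostIndexOutOfRange = errors.New("post index out of range")

// verify stands in for nipostValidator.PostV2 called with PostIndex(idx).
// In this sketch labels 0..9 are invalid, label 10 is the first valid one,
// and 16 plays the role of the K2 bound from the validator config.
func verify(idx int) error {
	switch {
	case idx >= 16:
		return errPostIndexOutOfRange
	case idx < 10:
		return fmt.Errorf("invalid label at index %d", idx)
	default:
		return nil
	}
}

func main() {
	validIdx := 0
	for {
		err := verify(validIdx)
		if err == nil {
			break // found a provably valid label to pair with the invalid one
		}
		if errors.Is(err, errPostIndexOutOfRange) {
			fmt.Println("no valid label found:", err)
			return
		}
		validIdx++
	}
	// The real handler then builds an invalid-post proof naming both the
	// invalid index (from ErrInvalidIndex) and this first valid index.
	fmt.Println("first valid index:", validIdx) // prints 10
}
```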
- atxHandler.mMalPublish.EXPECT().Publish(gomock.Any(), merged.SmesherID, gomock.Any()) + atxHandler.expectMergedAtxV2(merged, equivocationSet, []uint64{100}) err := atxHandler.processATX(context.Background(), "", merged, time.Now()) - require.NoError(t, err) + require.Error(t, err) }) } @@ -1536,8 +1553,42 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), atx.NIPosts[0].Challenge.Bytes(), atx.TotalNumUnits(), - gomock.Any(), - ).Return(verifying.ErrInvalidIndex{Index: 7}) + gomock.Cond(func(opt validatorOption) bool { + opts := &validatorOptions{} + opt(opts) + return opts.postSubsetSeed != nil + }), + ).Return(&verifying.ErrInvalidIndex{Index: 7}) + + for invalidPostIdx := 0; invalidPostIdx < 10; invalidPostIdx++ { + atxHandler.mValidator.EXPECT().PostV2( + context.Background(), + atx.SmesherID, + atx.Initial.CommitmentATX, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.TotalNumUnits(), + gomock.Cond(func(opt validatorOption) bool { + opts := &validatorOptions{} + opt(opts) + return opts.postIdx != nil && *opts.postIdx == invalidPostIdx + }), + ).Return(&verifying.ErrInvalidIndex{Index: invalidPostIdx}) + } + + atxHandler.mValidator.EXPECT().PostV2( + context.Background(), + atx.SmesherID, + atx.Initial.CommitmentATX, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.TotalNumUnits(), + gomock.Cond(func(opt validatorOption) bool { + opts := &validatorOptions{} + opt(opts) + return opts.postIdx != nil && *opts.postIdx == 10 + }), + ).Return(nil) verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). @@ -1545,6 +1596,16 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { return atxHandler.edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + verifier.EXPECT().PostIndex( + context.Background(), + atx.SmesherID, + atx.Initial.CommitmentATX, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.TotalNumUnits(), + 10, + ).Return(nil) + verifier.EXPECT().PostIndex( context.Background(), atx.SmesherID, @@ -1558,10 +1619,7 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { atxHandler.mMalPublish.EXPECT().Publish( gomock.Any(), sig.NodeID(), - gomock.Cond(func(data wire.Proof) bool { - _, ok := data.(*wire.ProofInvalidPost) - return ok - }), + gomock.AssignableToTypeOf(&wire.ProofInvalidPost{}), ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { malProof := proof.(*wire.ProofInvalidPost) nId, err := malProof.Valid(ctx, verifier) @@ -1571,7 +1629,7 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { }) _, err := atxHandler.syntacticallyValidateDeps(context.Background(), atx) vErr := &verifying.ErrInvalidIndex{} - require.ErrorAs(t, err, vErr) + require.ErrorAs(t, err, &vErr) require.Equal(t, 7, vErr.Index) }) t.Run("invalid PoST index solo ATX - generates a malfeasance proof", func(t *testing.T) { @@ -1590,8 +1648,42 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), atx.NIPosts[0].Challenge.Bytes(), atx.TotalNumUnits(), - gomock.Any(), - ).Return(verifying.ErrInvalidIndex{Index: 7}) + gomock.Cond(func(opt validatorOption) bool { + opts := &validatorOptions{} + opt(opts) + return opts.postSubsetSeed != nil + }), + ).Return(&verifying.ErrInvalidIndex{Index: 7}) + + for 
invalidPostIdx := 0; invalidPostIdx < 10; invalidPostIdx++ { + atxHandler.mValidator.EXPECT().PostV2( + context.Background(), + atx.SmesherID, + initialAtx.Initial.CommitmentATX, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.TotalNumUnits(), + gomock.Cond(func(opt validatorOption) bool { + opts := &validatorOptions{} + opt(opts) + return opts.postIdx != nil && *opts.postIdx == invalidPostIdx + }), + ).Return(&verifying.ErrInvalidIndex{Index: invalidPostIdx}) + } + + atxHandler.mValidator.EXPECT().PostV2( + context.Background(), + atx.SmesherID, + initialAtx.Initial.CommitmentATX, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.TotalNumUnits(), + gomock.Cond(func(opt validatorOption) bool { + opts := &validatorOptions{} + opt(opts) + return opts.postIdx != nil && *opts.postIdx == 10 + }), + ).Return(nil) verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). @@ -1599,6 +1691,16 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { return atxHandler.edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + verifier.EXPECT().PostIndex( + context.Background(), + atx.SmesherID, + initialAtx.Initial.CommitmentATX, + wire.PostFromWireV1(&atx.NIPosts[0].Posts[0].Post), + atx.NIPosts[0].Challenge.Bytes(), + atx.TotalNumUnits(), + 10, + ).Return(nil) + verifier.EXPECT().PostIndex( context.Background(), atx.SmesherID, @@ -1612,10 +1714,7 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { atxHandler.mMalPublish.EXPECT().Publish( gomock.Any(), sig.NodeID(), - gomock.Cond(func(data wire.Proof) bool { - _, ok := data.(*wire.ProofInvalidPost) - return ok - }), + gomock.AssignableToTypeOf(&wire.ProofInvalidPost{}), ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { malProof := proof.(*wire.ProofInvalidPost) nId, err := malProof.Valid(ctx, verifier) @@ -1625,7 +1724,7 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { }) _, err := atxHandler.syntacticallyValidateDeps(context.Background(), atx) vErr := &verifying.ErrInvalidIndex{} - require.ErrorAs(t, err, vErr) + require.ErrorAs(t, err, &vErr) require.Equal(t, 7, vErr.Index) }) t.Run("invalid PoST index merged ATX - generates a malfeasance proof", func(t *testing.T) { @@ -1672,18 +1771,58 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { gomock.Any(), ) if equivocationSet[post.MarriageIndex] == sig.NodeID() { - call.Return(verifying.ErrInvalidIndex{Index: 7}) + call.Return(&verifying.ErrInvalidIndex{Index: 7}) } else { call.AnyTimes() } } + for invalidPostIdx := 0; invalidPostIdx < 10; invalidPostIdx++ { + atxHandler.mValidator.EXPECT().PostV2( + context.Background(), + sig.NodeID(), + gomock.Any(), + gomock.Any(), + merged.NIPosts[0].Challenge.Bytes(), + gomock.Any(), + gomock.Cond(func(opt validatorOption) bool { + opts := &validatorOptions{} + opt(opts) + return opts.postIdx != nil && *opts.postIdx == invalidPostIdx + }), + ).Return(&verifying.ErrInvalidIndex{Index: invalidPostIdx}) + } + + atxHandler.mValidator.EXPECT().PostV2( + context.Background(), + sig.NodeID(), + gomock.Any(), + gomock.Any(), + merged.NIPosts[0].Challenge.Bytes(), + gomock.Any(), + gomock.Cond(func(opt validatorOption) bool { + opts := &validatorOptions{} + opt(opts) + return opts.postIdx != nil && *opts.postIdx == 10 + }), + ).Return(nil) + verifier := wire.NewMockMalfeasanceValidator(atxHandler.ctrl) 
verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { return atxHandler.edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + verifier.EXPECT().PostIndex( + context.Background(), + sig.NodeID(), + gomock.Any(), + gomock.Any(), + merged.NIPosts[0].Challenge.Bytes(), + gomock.Any(), + 10, + ).Return(nil) + verifier.EXPECT().PostIndex( context.Background(), sig.NodeID(), @@ -1697,10 +1836,7 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { atxHandler.mMalPublish.EXPECT().Publish( gomock.Any(), sig.NodeID(), - gomock.Cond(func(data wire.Proof) bool { - _, ok := data.(*wire.ProofInvalidPost) - return ok - }), + gomock.AssignableToTypeOf(&wire.ProofInvalidPost{}), ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { malProof := proof.(*wire.ProofInvalidPost) nId, err := malProof.Valid(ctx, verifier) @@ -1710,7 +1846,7 @@ func TestHandlerV2_SyntacticallyValidateDeps(t *testing.T) { }) _, err = atxHandler.syntacticallyValidateDeps(context.Background(), merged) vErr := &verifying.ErrInvalidIndex{} - require.ErrorAs(t, err, vErr) + require.ErrorAs(t, err, &vErr) require.Equal(t, 7, vErr.Index) }) t.Run("invalid PoET membership proof", func(t *testing.T) { @@ -1832,10 +1968,7 @@ func Test_Marriages(t *testing.T) { atxHandler.mMalPublish.EXPECT().Publish( gomock.Any(), sig.NodeID(), - gomock.Cond(func(data wire.Proof) bool { - _, ok := data.(*wire.ProofDoubleMarry) - return ok - }), + gomock.AssignableToTypeOf(&wire.ProofDoubleMarry{}), ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { malProof := proof.(*wire.ProofDoubleMarry) nId, err := malProof.Valid(ctx, verifier) @@ -1966,64 +2099,6 @@ func Test_MarryingMalicious(t *testing.T) { t.Run("other is malicious", tc(otherSig.NodeID())) } -func TestContextualValidation_DoublePost(t *testing.T) { - t.Parallel() - golden := types.RandomATXID() - sig, err := signing.NewEdSigner() - require.NoError(t, err) - - atxHandler := newV2TestHandler(t, golden) - - // marry - otherSig, err := signing.NewEdSigner() - require.NoError(t, err) - othersAtx := atxHandler.createAndProcessInitial(otherSig) - - mATX := newInitialATXv2(t, golden) - mATX.Marriages = []wire.MarriageCertificate{ - { - Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), - }, - { - ReferenceAtx: othersAtx.ID(), - Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), - }, - } - mATX.Sign(sig) - - atxHandler.expectInitialAtxV2(mATX) - err = atxHandler.processATX(context.Background(), "", mATX, time.Now()) - require.NoError(t, err) - - // publish merged - merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) - post := wire.SubPostV2{ - MarriageIndex: 1, - NumUnits: othersAtx.TotalNumUnits(), - PrevATXIndex: 1, - } - merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) - - mATXID := mATX.ID() - merged.MarriageATX = &mATXID - - merged.PreviousATXs = []types.ATXID{mATX.ID(), othersAtx.ID()} - merged.Sign(sig) - - atxHandler.expectMergedAtxV2(merged, []types.NodeID{sig.NodeID(), otherSig.NodeID()}, []uint64{poetLeaves}) - err = atxHandler.processATX(context.Background(), "", merged, time.Now()) - require.NoError(t, err) - - // The otherSig tries to publish alone in the same epoch. - // This is malfeasance as it tries include his PoST twice. 
- doubled := newSoloATXv2(t, merged.PublishEpoch, othersAtx.ID(), othersAtx.ID()) - doubled.Sign(otherSig) - atxHandler.expectAtxV2(doubled) - atxHandler.mMalPublish.EXPECT().Publish(gomock.Any(), otherSig.NodeID(), gomock.Any()) - err = atxHandler.processATX(context.Background(), "", doubled, time.Now()) - require.NoError(t, err) -} - func Test_CalculatingUnits(t *testing.T) { t.Parallel() t.Run("units on 1 nipost must not overflow", func(t *testing.T) { @@ -2051,47 +2126,224 @@ func Test_CalculatingUnits(t *testing.T) { } func TestContextual_PreviousATX(t *testing.T) { - golden := types.RandomATXID() - atxHdlr := newV2TestHandler(t, golden) - var ( - signers []*signing.EdSigner - eqSet []types.NodeID - ) - for range 3 { + t.Run("invalid previous ATX, both v2", func(t *testing.T) { + golden := types.RandomATXID() + atxHdlr := newV2TestHandler(t, golden) + var ( + signers []*signing.EdSigner + eqSet []types.NodeID + ) + for range 3 { + sig, err := signing.NewEdSigner() + require.NoError(t, err) + signers = append(signers, sig) + eqSet = append(eqSet, sig.NodeID()) + } + + mATX, otherAtxs := marryIDs(t, atxHdlr, signers, golden) + + // signer 1 creates a solo ATX + soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID()) + soloAtx.Sign(signers[1]) + atxHdlr.expectAtxV2(soloAtx) + err := atxHdlr.processATX(context.Background(), "", soloAtx, time.Now()) + require.NoError(t, err) + + // create a MergedATX for all IDs + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 1, + NumUnits: soloAtx.TotalNumUnits(), + } + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) + // Pass a wrong previous ATX for signer 1. It's already been used for soloATX + // (which should be used for the previous ATX for signer 1). + merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID()) + matxID := mATX.ID() + merged.MarriageATX = &matxID + merged.Sign(signers[0]) + + atxHdlr.expectMergedAtxV2(merged, eqSet, []uint64{100}) + + verifier := wire.NewMockMalfeasanceValidator(atxHdlr.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHdlr.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + atxHdlr.mMalPublish.EXPECT().Publish( + gomock.Any(), + signers[1].NodeID(), + gomock.AssignableToTypeOf(&wire.ProofInvalidPrevAtxV2{}), + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofInvalidPrevAtxV2) + nId, err := malProof.Valid(ctx, verifier) + require.NoError(t, err) + require.Equal(t, signers[1].NodeID(), nId) + return nil + }) + + err = atxHdlr.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + }) + + t.Run("invalid previous ATX, v1 and v2", func(t *testing.T) { + golden := types.RandomATXID() + atxHdlr := newTestHandler(t, golden) + + sig1, err := signing.NewEdSigner() + require.NoError(t, err) + + // signer 1 creates a solo ATX + prevATX := newInitialATXv1(t, golden) + prevATX.Sign(sig1) + atxHdlr.expectAtxV1(prevATX, prevATX.SmesherID) + _, err = atxHdlr.v1.processATX(context.Background(), "", prevATX, time.Now()) + require.NoError(t, err) + atxv1 := newChainedActivationTxV1(t, prevATX, prevATX.ID()) + atxv1.Sign(sig1) + atxHdlr.expectAtxV1(atxv1, atxv1.SmesherID) + _, err = atxHdlr.v1.processATX(context.Background(), "", atxv1, time.Now()) + require.NoError(t, err) + + soloAtx := newSoloATXv2(t, atxv1.PublishEpoch+1, atxv1.ID(), atxv1.ID()) + soloAtx.Sign(sig1) + atxHdlr.expectAtxV2(soloAtx) + err = atxHdlr.v2.processATX(context.Background(), "", soloAtx, time.Now()) + require.NoError(t, err) + + sig2, err := signing.NewEdSigner() + require.NoError(t, err) + mATX := newInitialATXv2(t, golden) + mATX.Marriages = []wire.MarriageCertificate{ + { + ReferenceAtx: types.EmptyATXID, + Signature: sig2.Sign(signing.MARRIAGE, sig2.NodeID().Bytes()), + }, + { + ReferenceAtx: soloAtx.ID(), + Signature: sig1.Sign(signing.MARRIAGE, sig2.NodeID().Bytes()), + }, + } + mATX.PublishEpoch = soloAtx.PublishEpoch + mATX.Sign(sig2) + atxHdlr.expectInitialAtxV2(mATX) + err = atxHdlr.v2.processATX(context.Background(), "", mATX, time.Now()) + require.NoError(t, err) + + // create a MergedATX for all IDs + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 1, + NumUnits: soloAtx.TotalNumUnits(), + } + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) + merged.PreviousATXs = append(merged.PreviousATXs, prevATX.ID()) + merged.MarriageATX = new(types.ATXID) + *merged.MarriageATX = mATX.ID() + merged.Sign(sig2) + + atxHdlr.expectMergedAtxV2(merged, []types.NodeID{sig1.NodeID(), sig2.NodeID()}, []uint64{100}) + + verifier := wire.NewMockMalfeasanceValidator(atxHdlr.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHdlr.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + atxHdlr.mMalPublish.EXPECT().Publish( + gomock.Any(), + sig1.NodeID(), + gomock.AssignableToTypeOf(&wire.ProofInvalidPrevAtxV1{}), + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofInvalidPrevAtxV1) + nId, err := malProof.Valid(ctx, verifier) + require.NoError(t, err) + require.Equal(t, sig1.NodeID(), nId) + return nil + }) + + err = atxHdlr.v2.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + }) + + t.Run("double publish", func(t *testing.T) { + t.Parallel() + golden := types.RandomATXID() sig, err := signing.NewEdSigner() require.NoError(t, err) - signers = append(signers, sig) - eqSet = append(eqSet, sig.NodeID()) - } - mATX, otherAtxs := marryIDs(t, atxHdlr, signers, golden) + atxHdlr := newV2TestHandler(t, golden) - // signer 1 creates a solo ATX - soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID()) - soloAtx.Sign(signers[1]) - atxHdlr.expectAtxV2(soloAtx) - err := atxHdlr.processATX(context.Background(), "", soloAtx, time.Now()) - require.NoError(t, err) + // marry + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + othersAtx := atxHdlr.createAndProcessInitial(otherSig) - // create a MergedATX for all IDs - merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) - post := wire.SubPostV2{ - MarriageIndex: 1, - PrevATXIndex: 1, - NumUnits: soloAtx.TotalNumUnits(), - } - merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) - // Pass a wrong previous ATX for signer 1. It's already been used for soloATX - // (which should be used for the previous ATX for signer 1). - merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID()) - matxID := mATX.ID() - merged.MarriageATX = &matxID - merged.Sign(signers[0]) - - atxHdlr.expectMergedAtxV2(merged, eqSet, []uint64{100}) - atxHdlr.mMalPublish.EXPECT().Publish(gomock.Any(), signers[1].NodeID(), gomock.Any()) - err = atxHdlr.processATX(context.Background(), "", merged, time.Now()) - require.NoError(t, err) + mATX := newInitialATXv2(t, golden) + mATX.Marriages = []wire.MarriageCertificate{ + { + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + { + ReferenceAtx: othersAtx.ID(), + Signature: otherSig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }, + } + mATX.Sign(sig) + + atxHdlr.expectInitialAtxV2(mATX) + err = atxHdlr.processATX(context.Background(), "", mATX, time.Now()) + require.NoError(t, err) + + // publish merged + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + NumUnits: othersAtx.TotalNumUnits(), + PrevATXIndex: 1, + } + merged.NIPosts[0].Posts = append(merged.NIPosts[0].Posts, post) + + mATXID := mATX.ID() + merged.MarriageATX = &mATXID + + merged.PreviousATXs = []types.ATXID{mATX.ID(), othersAtx.ID()} + merged.Sign(sig) + + atxHdlr.expectMergedAtxV2(merged, []types.NodeID{sig.NodeID(), otherSig.NodeID()}, []uint64{poetLeaves}) + err = atxHdlr.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + + // The otherSig tries to publish alone in the same epoch. + // This is malfeasance as it tries to include its PoST twice (see the sketch below).
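This new "double publish" subtest replaces the removed TestContextualValidation_DoublePost, matching the removal of checkDoublePost from handler_v2.go earlier in this diff: a second publication in the same epoch now surfaces as a previous-ATX collision, which is why a ProofInvalidPrevAtxV2 is expected just below. A rough, self-contained sketch of that detection idea; the map-based index is a stand-in for the atxs.PrevATXCollisions database lookup:

```go
package main

import "fmt"

type atxID string

// prevIndex maps (smesher, previous ATX) to the ATXs that declared it; a
// stand-in for what atxs.PrevATXCollisions answers from the database.
var prevIndex = map[string][]atxID{}

// declare records that an ATX names prev as the previous ATX for smesher and
// returns every ATX that has made the same claim so far.
func declare(smesher string, prev, id atxID) []atxID {
	key := smesher + "/" + string(prev)
	prevIndex[key] = append(prevIndex[key], id)
	return prevIndex[key]
}

func main() {
	// The merged ATX already consumed othersAtx as otherSig's previous ATX.
	declare("otherSig", "othersAtx", "mergedATX")
	// The solo ATX published in the same epoch claims the same previous ATX.
	collisions := declare("otherSig", "othersAtx", "doubledATX")
	if len(collisions) > 1 {
		fmt.Println("collision, publish an invalid-previous-ATX proof for:", collisions)
	}
}
```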
+ doubled := newSoloATXv2(t, merged.PublishEpoch, othersAtx.ID(), othersAtx.ID()) + doubled.Sign(otherSig) + atxHdlr.expectAtxV2(doubled) + + verifier := wire.NewMockMalfeasanceValidator(atxHdlr.ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return atxHdlr.edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + atxHdlr.mMalPublish.EXPECT().Publish( + gomock.Any(), + otherSig.NodeID(), + gomock.AssignableToTypeOf(&wire.ProofInvalidPrevAtxV2{}), + ).DoAndReturn(func(ctx context.Context, _ types.NodeID, proof wire.Proof) error { + malProof := proof.(*wire.ProofInvalidPrevAtxV2) + nId, err := malProof.Valid(ctx, verifier) + require.NoError(t, err) + require.Equal(t, otherSig.NodeID(), nId) + return nil + }) + + err = atxHdlr.processATX(context.Background(), "", doubled, time.Now()) + require.NoError(t, err) + }) } func Test_CalculatingWeight(t *testing.T) { diff --git a/activation/interface.go b/activation/interface.go index c9c3359091..38c8cf1332 100644 --- a/activation/interface.go +++ b/activation/interface.go @@ -92,7 +92,7 @@ type syncer interface { RegisterForATXSynced() <-chan struct{} } -// malfeasancePublisher is an interface for publishing malfeasance proofs. +// atxMalfeasancePublisher is an interface for publishing malfeasance proofs. // This interface is used to publish proofs in V2. // // The provider of that interface ensures that only valid proofs are published (invalid ones return an error). @@ -100,7 +100,7 @@ type syncer interface { // // Additionally the publisher will only gossip proofs when the node is in sync, otherwise it will only store them // and mark the associated identity as malfeasant. -type malfeasancePublisher interface { +type atxMalfeasancePublisher interface { Publish(ctx context.Context, id types.NodeID, proof wire.Proof) error } diff --git a/activation/mocks.go b/activation/mocks.go index 985f6a05f3..0ab71a7524 100644 --- a/activation/mocks.go +++ b/activation/mocks.go @@ -1092,32 +1092,32 @@ func (c *MocksyncerRegisterForATXSyncedCall) DoAndReturn(f func() <-chan struct{ return c } -// MockmalfeasancePublisher is a mock of malfeasancePublisher interface. -type MockmalfeasancePublisher struct { +// MockatxMalfeasancePublisher is a mock of atxMalfeasancePublisher interface. +type MockatxMalfeasancePublisher struct { ctrl *gomock.Controller - recorder *MockmalfeasancePublisherMockRecorder + recorder *MockatxMalfeasancePublisherMockRecorder isgomock struct{} } -// MockmalfeasancePublisherMockRecorder is the mock recorder for MockmalfeasancePublisher. -type MockmalfeasancePublisherMockRecorder struct { - mock *MockmalfeasancePublisher +// MockatxMalfeasancePublisherMockRecorder is the mock recorder for MockatxMalfeasancePublisher. +type MockatxMalfeasancePublisherMockRecorder struct { + mock *MockatxMalfeasancePublisher } -// NewMockmalfeasancePublisher creates a new mock instance. -func NewMockmalfeasancePublisher(ctrl *gomock.Controller) *MockmalfeasancePublisher { - mock := &MockmalfeasancePublisher{ctrl: ctrl} - mock.recorder = &MockmalfeasancePublisherMockRecorder{mock} +// NewMockatxMalfeasancePublisher creates a new mock instance. 
+func NewMockatxMalfeasancePublisher(ctrl *gomock.Controller) *MockatxMalfeasancePublisher { + mock := &MockatxMalfeasancePublisher{ctrl: ctrl} + mock.recorder = &MockatxMalfeasancePublisherMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockmalfeasancePublisher) EXPECT() *MockmalfeasancePublisherMockRecorder { +func (m *MockatxMalfeasancePublisher) EXPECT() *MockatxMalfeasancePublisherMockRecorder { return m.recorder } // Publish mocks base method. -func (m *MockmalfeasancePublisher) Publish(ctx context.Context, id types.NodeID, proof wire.Proof) error { +func (m *MockatxMalfeasancePublisher) Publish(ctx context.Context, id types.NodeID, proof wire.Proof) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Publish", ctx, id, proof) ret0, _ := ret[0].(error) @@ -1125,31 +1125,31 @@ func (m *MockmalfeasancePublisher) Publish(ctx context.Context, id types.NodeID, } // Publish indicates an expected call of Publish. -func (mr *MockmalfeasancePublisherMockRecorder) Publish(ctx, id, proof any) *MockmalfeasancePublisherPublishCall { +func (mr *MockatxMalfeasancePublisherMockRecorder) Publish(ctx, id, proof any) *MockatxMalfeasancePublisherPublishCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Publish", reflect.TypeOf((*MockmalfeasancePublisher)(nil).Publish), ctx, id, proof) - return &MockmalfeasancePublisherPublishCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Publish", reflect.TypeOf((*MockatxMalfeasancePublisher)(nil).Publish), ctx, id, proof) + return &MockatxMalfeasancePublisherPublishCall{Call: call} } -// MockmalfeasancePublisherPublishCall wrap *gomock.Call -type MockmalfeasancePublisherPublishCall struct { +// MockatxMalfeasancePublisherPublishCall wrap *gomock.Call +type MockatxMalfeasancePublisherPublishCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockmalfeasancePublisherPublishCall) Return(arg0 error) *MockmalfeasancePublisherPublishCall { +func (c *MockatxMalfeasancePublisherPublishCall) Return(arg0 error) *MockatxMalfeasancePublisherPublishCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockmalfeasancePublisherPublishCall) Do(f func(context.Context, types.NodeID, wire.Proof) error) *MockmalfeasancePublisherPublishCall { +func (c *MockatxMalfeasancePublisherPublishCall) Do(f func(context.Context, types.NodeID, wire.Proof) error) *MockatxMalfeasancePublisherPublishCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockmalfeasancePublisherPublishCall) DoAndReturn(f func(context.Context, types.NodeID, wire.Proof) error) *MockmalfeasancePublisherPublishCall { +func (c *MockatxMalfeasancePublisherPublishCall) DoAndReturn(f func(context.Context, types.NodeID, wire.Proof) error) *MockatxMalfeasancePublisherPublishCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/activation/validation.go b/activation/validation.go index cb9a9a8885..c4a1c1829a 100644 --- a/activation/validation.go +++ b/activation/validation.go @@ -21,24 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" ) -type ErrAtxNotFound struct { - Id types.ATXID - // the source (if any) that caused the error - source error -} - -func (e *ErrAtxNotFound) Error() string { - return fmt.Sprintf("ATX ID (%v) not found (%v)", e.Id.String(), e.source) -} - -func (e *ErrAtxNotFound) Unwrap() error { return e.source } - -func (e *ErrAtxNotFound) Is(target error) bool 
{ - if err, ok := target.(*ErrAtxNotFound); ok { - return err.Id == e.Id - } - return false -} +var ErrPostIndexOutOfRange = errors.New("post index out of range") type validatorOptions struct { postIdx *int @@ -214,6 +197,9 @@ func (v *Validator) Post( verifyOpts := []verifying.OptionFunc{verifying.WithLabelScryptParams(v.scrypt)} if options.postIdx != nil { + if *options.postIdx >= int(v.cfg.K2) { + return ErrPostIndexOutOfRange + } verifyOpts = append(verifyOpts, verifying.SelectedIndex(*options.postIdx)) } if options.postSubsetSeed != nil { @@ -309,7 +295,7 @@ func (v *Validator) InitialNIPostChallengeV1( if commitmentATXId != goldenATXID { commitmentAtx, err := atxs.GetAtx(commitmentATXId) if err != nil { - return &ErrAtxNotFound{Id: commitmentATXId, source: err} + return fmt.Errorf("ATX (%s) not found: %w", commitmentATXId.ShortString(), err) } if challenge.PublishEpoch <= commitmentAtx.PublishEpoch { return fmt.Errorf( @@ -362,7 +348,7 @@ func (v *Validator) PositioningAtx( } posAtx, err := atxs.GetAtx(id) if err != nil { - return &ErrAtxNotFound{Id: id, source: err} + return fmt.Errorf("positioning atx (%s) not found: %w", id.ShortString(), err) } if posAtx.PublishEpoch >= pubepoch { return fmt.Errorf("positioning atx epoch (%v) must be before %v", posAtx.PublishEpoch, pubepoch) diff --git a/activation/validation_test.go b/activation/validation_test.go index 4ba72911f3..8c0d88bdb8 100644 --- a/activation/validation_test.go +++ b/activation/validation_test.go @@ -332,8 +332,7 @@ func Test_Validation_PositioningAtx(t *testing.T) { atxProvider.EXPECT().GetAtx(posAtxId).Return(nil, errors.New("db error")) err := v.PositioningAtx(posAtxId, atxProvider, goldenAtxId, types.LayerID(1012).GetEpoch()) - require.ErrorIs(t, err, &ErrAtxNotFound{Id: posAtxId}) - require.ErrorContains(t, err, "db error") + require.EqualError(t, err, fmt.Sprintf("positioning atx (%s) not found: db error", posAtxId.ShortString())) }) t.Run("positioning atx published in higher epoch than expected", func(t *testing.T) { diff --git a/activation/wire/malfeasance_double_marry.go b/activation/wire/malfeasance_double_marry.go index ac7a760c1f..fc2a98e545 100644 --- a/activation/wire/malfeasance_double_marry.go +++ b/activation/wire/malfeasance_double_marry.go @@ -15,12 +15,10 @@ import ( // ProofDoubleMarry is a proof that two distinct ATXs contain a marriage certificate signed by the same identity. // // We are proving the following: -// 1. The ATXs have different IDs. -// 2. Both ATXs have a valid signature. -// 3. Both ATXs contain a marriage certificate created by the same identity. -// 4. Both marriage certificates have valid signatures. -// -// HINT: this works if the identity that publishes the marriage ATX marries themselves. +// 1. The ATXs have different IDs. +// 2. Both ATXs have a valid signature. +// 3. Both ATXs contain a marriage certificate created by the same identity. +// 4. Both marriage certificates have valid signatures. type ProofDoubleMarry struct { // NodeID is the node ID that married twice. 
NodeID types.NodeID diff --git a/activation/wire/malfeasance_double_marry_test.go b/activation/wire/malfeasance_double_marry_test.go index f9f686503a..f52f8c8559 100644 --- a/activation/wire/malfeasance_double_marry_test.go +++ b/activation/wire/malfeasance_double_marry_test.go @@ -3,7 +3,6 @@ package wire import ( "context" "fmt" - "slices" "testing" "github.com/stretchr/testify/require" @@ -16,6 +15,8 @@ import ( ) func Test_DoubleMarryProof(t *testing.T) { + t.Parallel() + sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -25,7 +26,9 @@ func Test_DoubleMarryProof(t *testing.T) { edVerifier := signing.NewEdVerifier() t.Run("valid", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} otherAtx.SetID(types.RandomATXID()) otherAtx.SmesherID = otherSig.NodeID() @@ -59,37 +62,55 @@ func Test_DoubleMarryProof(t *testing.T) { require.Equal(t, otherSig.NodeID(), id) }) - t.Run("does not contain same certificate owner", func(t *testing.T) { + t.Run("identity is not included in both ATXs", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + atx1 := newActivationTxV2( withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), ) atx1.Sign(sig) atx2 := newActivationTxV2( withMarriageCertificate(otherSig, types.EmptyATXID, otherSig.NodeID()), + withMarriageCertificate(sig, atx1.ID(), otherSig.NodeID()), ) atx2.Sign(otherSig) + marriages := make([]MarriageCertificate, len(atx1.Marriages)) + copy(marriages, atx1.Marriages) + atx1.Marriages = marriages[:1] proof, err := NewDoubleMarryProof(db, atx1, atx2, otherSig.NodeID()) - require.ErrorContains(t, err, fmt.Sprintf( + require.EqualError(t, err, fmt.Sprintf( "proof for atx1: does not contain a marriage certificate signed by %s", otherSig.NodeID().ShortString(), )) require.Nil(t, proof) + atx1.Marriages = marriages + marriages = make([]MarriageCertificate, len(atx2.Marriages)) + copy(marriages, atx2.Marriages) + atx2.Marriages = marriages[:1] proof, err = NewDoubleMarryProof(db, atx1, atx2, sig.NodeID()) - require.ErrorContains(t, err, fmt.Sprintf( + require.EqualError(t, err, fmt.Sprintf( "proof for atx2: does not contain a marriage certificate signed by %s", sig.NodeID().ShortString(), )) require.Nil(t, proof) + atx2.Marriages = marriages }) t.Run("same ATX ID", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + atx1 := newActivationTxV2() atx1.Sign(sig) - db := statesql.InMemoryTest(t) proof, err := NewDoubleMarryProof(db, atx1, atx1, sig.NodeID()) require.ErrorContains(t, err, "ATXs have the same ID") require.Nil(t, proof) @@ -108,128 +129,10 @@ func Test_DoubleMarryProof(t *testing.T) { require.Equal(t, types.EmptyNodeID, id) }) - t.Run("invalid marriage proof", func(t *testing.T) { + t.Run("invalid proof", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) - otherAtx := &types.ActivationTx{} - otherAtx.SetID(types.RandomATXID()) - otherAtx.SmesherID = otherSig.NodeID() - require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) - - atx1 := newActivationTxV2( - withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), - withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), - ) - atx1.Sign(sig) - - atx2 := newActivationTxV2( - withMarriageCertificate(otherSig, types.EmptyATXID, 
otherSig.NodeID()), - withMarriageCertificate(sig, atx1.ID(), otherSig.NodeID()), - ) - atx2.Sign(otherSig) - - // manually construct an invalid proof - proof1, err := createMarryProof(db, atx1, otherSig.NodeID()) - require.NoError(t, err) - proof2, err := createMarryProof(db, atx2, otherSig.NodeID()) - require.NoError(t, err) - - proof := &ProofDoubleMarry{ - NodeID: otherSig.NodeID(), - - ATX1: atx1.ID(), - SmesherID1: atx1.SmesherID, - Signature1: atx1.Signature, - Proof1: proof1, - - ATX2: atx2.ID(), - SmesherID2: atx2.SmesherID, - Signature2: atx2.Signature, - Proof2: proof2, - } - - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() - proof.Proof1.MarriageCertificatesProof = slices.Clone(proof1.MarriageCertificatesProof) - proof.Proof1.MarriageCertificatesProof[0] = types.RandomHash() - id, err := proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "proof 1 is invalid: invalid marriage proof") - require.Equal(t, types.EmptyNodeID, id) - - proof.Proof1.MarriageCertificatesProof[0] = proof1.MarriageCertificatesProof[0] - proof.Proof2.MarriageCertificatesProof = slices.Clone(proof2.MarriageCertificatesProof) - proof.Proof2.MarriageCertificatesProof[0] = types.RandomHash() - id, err = proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "proof 2 is invalid: invalid marriage proof") - require.Equal(t, types.EmptyNodeID, id) - }) - - t.Run("invalid certificate proof", func(t *testing.T) { - db := statesql.InMemoryTest(t) - otherAtx := &types.ActivationTx{} - otherAtx.SetID(types.RandomATXID()) - otherAtx.SmesherID = otherSig.NodeID() - require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) - - atx1 := newActivationTxV2( - withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), - withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), - ) - atx1.Sign(sig) - - atx2 := newActivationTxV2( - withMarriageCertificate(otherSig, types.EmptyATXID, otherSig.NodeID()), - withMarriageCertificate(sig, atx1.ID(), otherSig.NodeID()), - ) - atx2.Sign(otherSig) - - // manually construct an invalid proof - proof1, err := createMarryProof(db, atx1, otherSig.NodeID()) - require.NoError(t, err) - proof2, err := createMarryProof(db, atx2, otherSig.NodeID()) - require.NoError(t, err) - - proof := &ProofDoubleMarry{ - NodeID: otherSig.NodeID(), - - ATX1: atx1.ID(), - SmesherID1: atx1.SmesherID, - Signature1: atx1.Signature, - Proof1: proof1, - - ATX2: atx2.ID(), - SmesherID2: atx2.SmesherID, - Signature2: atx2.Signature, - Proof2: proof2, - } - - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() - - proof.Proof1.CertificateProof = slices.Clone(proof1.CertificateProof) - proof.Proof1.CertificateProof[0] = types.RandomHash() - id, err := proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "proof 1 is invalid: invalid certificate proof") - require.Equal(t, types.EmptyNodeID, id) - - proof.Proof1.CertificateProof[0] = proof1.CertificateProof[0] - proof.Proof2.CertificateProof = slices.Clone(proof2.CertificateProof) - proof.Proof2.CertificateProof[0] = types.RandomHash() - id, err = proof.Valid(context.Background(), verifier) - require.ErrorContains(t, err, "proof 2 is invalid: invalid certificate proof") - require.Equal(t, types.EmptyNodeID, id) - }) - - t.Run("invalid atx signature", func(t *testing.T) { - db := statesql.InMemoryTest(t) otherAtx := &types.ActivationTx{} otherAtx.SetID(types.RandomATXID()) otherAtx.SmesherID = otherSig.NodeID() @@ -257,76 +160,46 @@ func Test_DoubleMarryProof(t *testing.T) { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() + // invalid signature for ATX1 proof.Signature1 = types.RandomEdSignature() id, err := proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "invalid signature for ATX1") require.Equal(t, types.EmptyNodeID, id) - proof.Signature1 = atx1.Signature + + // invalid signature for ATX2 proof.Signature2 = types.RandomEdSignature() id, err = proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "invalid signature for ATX2") require.Equal(t, types.EmptyNodeID, id) - }) - - t.Run("invalid certificate signature", func(t *testing.T) { - db := statesql.InMemoryTest(t) - otherAtx := &types.ActivationTx{} - otherAtx.SetID(types.RandomATXID()) - otherAtx.SmesherID = otherSig.NodeID() - require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) - - atx1 := newActivationTxV2( - withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), - withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), - ) - atx1.Sign(sig) + proof.Signature2 = atx2.Signature - atx2 := newActivationTxV2( - withMarriageCertificate(otherSig, types.EmptyATXID, sig.NodeID()), - withMarriageCertificate(sig, atx1.ID(), sig.NodeID()), - ) - atx2.Sign(otherSig) - - proof, err := NewDoubleMarryProof(db, atx1, atx2, otherSig.NodeID()) - require.NoError(t, err) - - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
-		DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool {
-			return edVerifier.Verify(d, nodeID, m, sig)
-		}).AnyTimes()
-
-		proof.Proof1.Certificate.Signature = types.RandomEdSignature()
-		id, err := proof.Valid(context.Background(), verifier)
-		require.ErrorContains(t, err, "proof 1 is invalid: invalid certificate signature")
+		// invalid smesher ID for ATX1
+		proof.SmesherID1 = types.RandomNodeID()
+		id, err = proof.Valid(context.Background(), verifier)
+		require.ErrorContains(t, err, "invalid signature for ATX1")
 		require.Equal(t, types.EmptyNodeID, id)
+		proof.SmesherID1 = atx1.SmesherID
 
-		proof.Proof1.Certificate.Signature = atx1.Marriages[1].Signature
-		proof.Proof2.Certificate.Signature = types.RandomEdSignature()
+		// invalid smesher ID for ATX2
+		proof.SmesherID2 = types.RandomNodeID()
 		id, err = proof.Valid(context.Background(), verifier)
-		require.ErrorContains(t, err, "proof 2 is invalid: invalid certificate signature")
+		require.ErrorContains(t, err, "invalid signature for ATX2")
 		require.Equal(t, types.EmptyNodeID, id)
-	})
-
-	t.Run("unknown reference ATX", func(t *testing.T) {
-		db := statesql.InMemoryTest(t)
+		proof.SmesherID2 = atx2.SmesherID
 
-		atx1 := newActivationTxV2(
-			withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()),
-			withMarriageCertificate(otherSig, types.RandomATXID(), sig.NodeID()), // unknown reference ATX
-		)
-		atx1.Sign(sig)
-
-		atx2 := newActivationTxV2(
-			withMarriageCertificate(otherSig, types.EmptyATXID, sig.NodeID()),
-			withMarriageCertificate(sig, atx1.ID(), sig.NodeID()),
-		)
-		atx2.Sign(otherSig)
+		// invalid ATX ID for ATX1
+		proof.ATX1 = types.RandomATXID()
+		id, err = proof.Valid(context.Background(), verifier)
+		require.ErrorContains(t, err, "invalid signature for ATX1")
+		require.Equal(t, types.EmptyNodeID, id)
+		proof.ATX1 = atx1.ID()
 
-		proof, err := NewDoubleMarryProof(db, atx1, atx2, otherSig.NodeID())
-		require.Error(t, err)
-		require.Nil(t, proof)
+		// invalid ATX ID for ATX2
+		proof.ATX2 = types.RandomATXID()
+		id, err = proof.Valid(context.Background(), verifier)
+		require.ErrorContains(t, err, "invalid signature for ATX2")
+		require.Equal(t, types.EmptyNodeID, id)
+		proof.ATX2 = atx2.ID()
 	})
 }
diff --git a/activation/wire/malfeasance_double_merge.go b/activation/wire/malfeasance_double_merge.go
new file mode 100644
index 0000000000..3b3f73194a
--- /dev/null
+++ b/activation/wire/malfeasance_double_merge.go
@@ -0,0 +1,168 @@
+package wire
+
+import (
+	"context"
+	"errors"
+	"fmt"
+
+	"github.com/spacemeshos/go-spacemesh/common/types"
+	"github.com/spacemeshos/go-spacemesh/signing"
+	"github.com/spacemeshos/go-spacemesh/sql"
+	"github.com/spacemeshos/go-spacemesh/sql/atxs"
+)
+
+//go:generate scalegen
+
+// ProofDoubleMerge is a proof that two distinct ATXs published in the same epoch
+// contain the same marriage ATX.
+//
+// We are proving the following:
+// 1. The ATXs have different IDs.
+// 2. Both ATXs have a valid signature.
+// 3. Both ATXs contain the same marriage ATX.
+// 4. Both ATXs were published in the same epoch.
+// 5. Signers of both ATXs are married - to prevent banning others by
+//    publishing an ATX with the same marriage ATX.
+type ProofDoubleMerge struct {
+	// PublishEpoch is the epoch in which both ATXs were published.
+	PublishEpoch types.EpochID
+
+	// MarriageATX is the ID of the marriage ATX.
+	MarriageATX types.ATXID
+	// MarriageATXSmesherID is the ID of the smesher that published the marriage ATX.
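+	// It is needed to validate the MarryProofs of both signers against the marriage ATX.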
+	MarriageATXSmesherID types.NodeID
+
+	// ATXID1 is the ID of the first ATX being proven.
+	ATXID1 types.ATXID
+	// SmesherID1 is the ID of the smesher that published the first ATX.
+	SmesherID1 types.NodeID
+	// Signature1 is the signature of ATXID1 by the smesher.
+	Signature1 types.EdSignature
+	// PublishEpochProof1 is the proof that the publish epoch is contained in the first ATX.
+	PublishEpochProof1 PublishEpochProof `scale:"max=32"`
+	// MarriageATXProof1 is the proof that MarriageATX is contained in the first ATX.
+	MarriageATXProof1 MarriageATXProof `scale:"max=32"`
+	// SmesherID1MarryProof is the proof that SmesherID1 married in MarriageATX.
+	SmesherID1MarryProof MarryProof
+
+	// ATXID2 is the ID of the second ATX being proven.
+	ATXID2 types.ATXID
+	// SmesherID2 is the ID of the smesher that published the second ATX.
+	SmesherID2 types.NodeID
+	// Signature2 is the signature of ATXID2 by the smesher.
+	Signature2 types.EdSignature
+	// PublishEpochProof2 is the proof that the publish epoch is contained in the second ATX.
+	PublishEpochProof2 PublishEpochProof `scale:"max=32"`
+	// MarriageATXProof2 is the proof that MarriageATX is contained in the second ATX.
+	MarriageATXProof2 MarriageATXProof `scale:"max=32"`
+	// SmesherID2MarryProof is the proof that SmesherID2 married in MarriageATX.
+	SmesherID2MarryProof MarryProof
+}
+
+var _ Proof = &ProofDoubleMerge{}
+
+func NewDoubleMergeProof(db sql.Executor, atx1, atx2 *ActivationTxV2) (*ProofDoubleMerge, error) {
+	if atx1.ID() == atx2.ID() {
+		return nil, errors.New("ATXs have the same ID")
+	}
+	if atx1.SmesherID == atx2.SmesherID {
+		return nil, errors.New("ATXs have the same smesher ID")
+	}
+	if atx1.PublishEpoch != atx2.PublishEpoch {
+		return nil, fmt.Errorf("ATXs have different publish epoch (%v != %v)", atx1.PublishEpoch, atx2.PublishEpoch)
+	}
+	if atx1.MarriageATX == nil {
+		return nil, errors.New("ATX 1 has no marriage ATX")
+	}
+	if atx2.MarriageATX == nil {
+		return nil, errors.New("ATX 2 has no marriage ATX")
+	}
+	if *atx1.MarriageATX != *atx2.MarriageATX {
+		return nil, errors.New("ATXs have different marriage ATXs")
+	}
+
+	var blob sql.Blob
+	v, err := atxs.LoadBlob(context.Background(), db, atx1.MarriageATX.Bytes(), &blob)
+	if err != nil {
+		return nil, fmt.Errorf("get marriage ATX: %w", err)
+	}
+	if v != types.AtxV2 {
+		return nil, errors.New("invalid ATX version for marriage ATX")
+	}
+	marriageATX, err := DecodeAtxV2(blob.Bytes)
+	if err != nil {
+		return nil, fmt.Errorf("decode marriage ATX: %w", err)
+	}
+
+	marriageProof1, err := createMarryProof(db, marriageATX, atx1.SmesherID)
+	if err != nil {
+		return nil, fmt.Errorf("marriage proof for ATX 1: %w", err)
+	}
+	marriageProof2, err := createMarryProof(db, marriageATX, atx2.SmesherID)
+	if err != nil {
+		return nil, fmt.Errorf("marriage proof for ATX 2: %w", err)
+	}
+
+	proof := ProofDoubleMerge{
+		PublishEpoch:         atx1.PublishEpoch,
+		MarriageATX:          marriageATX.ID(),
+		MarriageATXSmesherID: marriageATX.SmesherID,
+
+		ATXID1:               atx1.ID(),
+		SmesherID1:           atx1.SmesherID,
+		Signature1:           atx1.Signature,
+		PublishEpochProof1:   atx1.PublishEpochProof(),
+		MarriageATXProof1:    atx1.MarriageATXProof(),
+		SmesherID1MarryProof: marriageProof1,
+
+		ATXID2:               atx2.ID(),
+		SmesherID2:           atx2.SmesherID,
+		Signature2:           atx2.Signature,
+		PublishEpochProof2:   atx2.PublishEpochProof(),
+		MarriageATXProof2:    atx2.MarriageATXProof(),
+		SmesherID2MarryProof: marriageProof2,
+	}
+
+	return &proof, nil
+}
+
+func (p *ProofDoubleMerge) Valid(_ context.Context, edVerifier MalfeasanceValidator) (types.NodeID, error) {
+	// 1. The ATXs have different IDs.
+	if p.ATXID1 == p.ATXID2 {
+		return types.EmptyNodeID, errors.New("ATXs have the same ID")
+	}
+
+	// 2. Both ATXs have a valid signature.
+	if !edVerifier.Signature(signing.ATX, p.SmesherID1, p.ATXID1.Bytes(), p.Signature1) {
+		return types.EmptyNodeID, errors.New("ATX 1 invalid signature")
+	}
+	if !edVerifier.Signature(signing.ATX, p.SmesherID2, p.ATXID2.Bytes(), p.Signature2) {
+		return types.EmptyNodeID, errors.New("ATX 2 invalid signature")
+	}
+
+	// 4. Both ATXs were published in the same epoch: the same PublishEpoch is
+	// proven to be contained in each of them.
+	if !p.PublishEpochProof1.Valid(p.ATXID1, p.PublishEpoch) {
+		return types.EmptyNodeID, errors.New("ATX 1 invalid publish epoch proof")
+	}
+	if !p.PublishEpochProof2.Valid(p.ATXID2, p.PublishEpoch) {
+		return types.EmptyNodeID, errors.New("ATX 2 invalid publish epoch proof")
+	}
+
+	// 3. and 5. Both ATXs contain the same marriage ATX and their signers are
+	// married in it.
+	if !p.MarriageATXProof1.Valid(p.ATXID1, p.MarriageATX) {
+		return types.EmptyNodeID, errors.New("ATX 1 invalid marriage ATX proof")
+	}
+	err := p.SmesherID1MarryProof.Valid(edVerifier, p.MarriageATX, p.MarriageATXSmesherID, p.SmesherID1)
+	if err != nil {
+		return types.EmptyNodeID, errors.New("ATX 1 invalid marriage ATX proof")
+	}
+	if !p.MarriageATXProof2.Valid(p.ATXID2, p.MarriageATX) {
+		return types.EmptyNodeID, errors.New("ATX 2 invalid marriage ATX proof")
+	}
+	err = p.SmesherID2MarryProof.Valid(edVerifier, p.MarriageATX, p.MarriageATXSmesherID, p.SmesherID2)
+	if err != nil {
+		return types.EmptyNodeID, errors.New("ATX 2 invalid marriage ATX proof")
+	}
+
+	return p.SmesherID1, nil
+}
diff --git a/activation/wire/malfeasance_double_merge_scale.go b/activation/wire/malfeasance_double_merge_scale.go
new file mode 100644
index 0000000000..d4d38b16f3
--- /dev/null
+++ b/activation/wire/malfeasance_double_merge_scale.go
@@ -0,0 +1,232 @@
+// Code generated by github.com/spacemeshos/go-scale/scalegen. DO NOT EDIT.
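To make the scheme above concrete, here is a minimal sketch (illustrative only, not part of this change set) of how a double-merge proof is built and checked. It condenses the happy path of `Test_DoubleMergeProof` further down in this diff: `newActivationTxV2`, `withMarriageATX`, and `withPublishEpoch` are the package's test helpers, `setupMarriage` is the closure from the test file (treated here as a free helper), and the mocked `MalfeasanceValidator` simply delegates to a real `signing.EdVerifier`.

```go
func sketchDoubleMergeFlow(t *testing.T) {
	db := statesql.InMemoryTest(t)

	sig, err := signing.NewEdSigner()
	require.NoError(t, err)
	otherSig, err := signing.NewEdSigner()
	require.NoError(t, err)

	// Mocked MalfeasanceValidator that verifies Ed25519 signatures for real.
	edVerifier := signing.NewEdVerifier()
	ctrl := gomock.NewController(t)
	verifier := NewMockMalfeasanceValidator(ctrl)
	verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
		DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool {
			return edVerifier.Verify(d, nodeID, m, sig)
		}).AnyTimes()

	// Marriage ATX that marries sig and otherSig (helper from the test file).
	marriageAtx := setupMarriage(db)

	// The malfeasance: both identities publish distinct ATXs in the same
	// epoch that reference the same marriage ATX.
	atx1 := newActivationTxV2(
		withMarriageATX(marriageAtx.ID()),
		withPublishEpoch(marriageAtx.PublishEpoch+1),
	)
	atx1.Sign(sig)
	atx2 := newActivationTxV2(
		withMarriageATX(marriageAtx.ID()),
		withPublishEpoch(marriageAtx.PublishEpoch+1),
	)
	atx2.Sign(otherSig)

	// The database is only needed while building the proof (to load and
	// decode the marriage ATX); Valid is self-contained apart from
	// signature verification.
	proof, err := NewDoubleMergeProof(db, atx1, atx2)
	require.NoError(t, err)
	id, err := proof.Valid(context.Background(), verifier)
	require.NoError(t, err)
	require.Equal(t, sig.NodeID(), id) // NodeID to be marked malfeasant (SmesherID1)
}
```

The returned identity is `SmesherID1`, matching the test's expectation that `sig.NodeID()` is the one marked malfeasant.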
+ +// nolint +package wire + +import ( + "github.com/spacemeshos/go-scale" + "github.com/spacemeshos/go-spacemesh/common/types" +) + +func (t *ProofDoubleMerge) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeCompact32(enc, uint32(t.PublishEpoch)) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.MarriageATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.MarriageATXSmesherID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.ATXID1[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.SmesherID1[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.Signature1[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.PublishEpochProof1, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.MarriageATXProof1, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.SmesherID1MarryProof.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.ATXID2[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.SmesherID2[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.Signature2[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.PublishEpochProof2, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.MarriageATXProof2, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.SmesherID2MarryProof.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *ProofDoubleMerge) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + field, n, err := scale.DecodeCompact32(dec) + if err != nil { + return total, err + } + total += n + t.PublishEpoch = types.EpochID(field) + } + { + n, err := scale.DecodeByteArray(dec, t.MarriageATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.MarriageATXSmesherID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.ATXID1[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.SmesherID1[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.Signature1[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.PublishEpochProof1 = field + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.MarriageATXProof1 = field + } + { + n, err := t.SmesherID1MarryProof.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.ATXID2[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, 
t.SmesherID2[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.Signature2[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.PublishEpochProof2 = field + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.MarriageATXProof2 = field + } + { + n, err := t.SmesherID2MarryProof.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + return total, nil +} diff --git a/activation/wire/malfeasance_double_merge_test.go b/activation/wire/malfeasance_double_merge_test.go new file mode 100644 index 0000000000..83ba3a2b4a --- /dev/null +++ b/activation/wire/malfeasance_double_merge_test.go @@ -0,0 +1,325 @@ +package wire + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +func Test_DoubleMergeProof(t *testing.T) { + t.Parallel() + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + + marrySig, err := signing.NewEdSigner() + require.NoError(t, err) + + edVerifier := signing.NewEdVerifier() + + setupMarriage := func(db sql.Executor) *ActivationTxV2 { + wInitialAtx1 := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wInitialAtx1.Sign(sig) + initialAtx1 := &types.ActivationTx{ + CommitmentATX: &wInitialAtx1.Initial.CommitmentATX, + } + initialAtx1.SetID(wInitialAtx1.ID()) + initialAtx1.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, initialAtx1, wInitialAtx1.Blob())) + + wInitialAtx2 := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wInitialAtx2.Sign(otherSig) + initialAtx2 := &types.ActivationTx{} + initialAtx2.SetID(wInitialAtx2.ID()) + initialAtx2.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, initialAtx2, wInitialAtx2.Blob())) + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(marrySig, types.EmptyATXID, marrySig.NodeID()), + withMarriageCertificate(sig, wInitialAtx1.ID(), marrySig.NodeID()), + withMarriageCertificate(otherSig, wInitialAtx2.ID(), marrySig.NodeID()), + ) + wMarriageAtx.Sign(marrySig) + + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = marrySig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + return wMarriageAtx + } + + t.Run("valid", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+			DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool {
+				return edVerifier.Verify(d, nodeID, m, sig)
+			}).AnyTimes()
+
+		marriageAtx := setupMarriage(db)
+
+		atx1 := newActivationTxV2(
+			withMarriageATX(marriageAtx.ID()),
+			withPublishEpoch(marriageAtx.PublishEpoch+1),
+		)
+		atx1.Sign(sig)
+
+		atx2 := newActivationTxV2(
+			withMarriageATX(marriageAtx.ID()),
+			withPublishEpoch(marriageAtx.PublishEpoch+1),
+		)
+		atx2.Sign(otherSig)
+
+		proof, err := NewDoubleMergeProof(db, atx1, atx2)
+		require.NoError(t, err)
+		id, err := proof.Valid(context.Background(), verifier)
+		require.NoError(t, err)
+		require.Equal(t, sig.NodeID(), id)
+	})
+
+	t.Run("same ATX ID", func(t *testing.T) {
+		t.Parallel()
+		db := statesql.InMemoryTest(t)
+
+		ctrl := gomock.NewController(t)
+		verifier := NewMockMalfeasanceValidator(ctrl)
+		verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
+			DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool {
+				return edVerifier.Verify(d, nodeID, m, sig)
+			}).AnyTimes()
+
+		marriageAtx := setupMarriage(db)
+
+		atx1 := newActivationTxV2(
+			withMarriageATX(marriageAtx.ID()),
+			withPublishEpoch(marriageAtx.PublishEpoch+1),
+		)
+		atx1.Sign(sig)
+
+		proof, err := NewDoubleMergeProof(db, atx1, atx1)
+		require.EqualError(t, err, "ATXs have the same ID")
+		require.Nil(t, proof)
+
+		proof = &ProofDoubleMerge{
+			ATXID1: atx1.ID(),
+			ATXID2: atx1.ID(),
+		}
+		id, err := proof.Valid(context.Background(), verifier)
+		require.EqualError(t, err, "ATXs have the same ID")
+		require.Equal(t, types.EmptyNodeID, id)
+	})
+
+	t.Run("ATXs must have different signers", func(t *testing.T) {
+		t.Parallel()
+		db := statesql.InMemoryTest(t)
+		atx1 := newActivationTxV2()
+		atx1.Sign(sig)
+
+		atx2 := newActivationTxV2()
+		atx2.Sign(sig)
+
+		proof, err := NewDoubleMergeProof(db, atx1, atx2)
+		require.ErrorContains(t, err, "ATXs have the same smesher")
+		require.Nil(t, proof)
+	})
+
+	t.Run("ATXs must be published in the same epoch", func(t *testing.T) {
+		t.Parallel()
+		db := statesql.InMemoryTest(t)
+		atx := newActivationTxV2(
+			withPublishEpoch(1),
+		)
+		atx.Sign(sig)
+
+		atx2 := newActivationTxV2(
+			withPublishEpoch(2),
+		)
+		atx2.Sign(otherSig)
+		proof, err := NewDoubleMergeProof(db, atx, atx2)
+		require.ErrorContains(t, err, "ATXs have different publish epoch")
+		require.Nil(t, proof)
+	})
+
+	t.Run("ATXs must have valid marriage ATX", func(t *testing.T) {
+		t.Parallel()
+		db := statesql.InMemoryTest(t)
+
+		atx := newActivationTxV2(
+			withPublishEpoch(1),
+		)
+		atx.Sign(sig)
+
+		atx2 := newActivationTxV2(
+			withPublishEpoch(1),
+		)
+		atx2.Sign(otherSig)
+
+		// ATX 1 has no marriage
+		proof, err := NewDoubleMergeProof(db, atx, atx2)
+		require.ErrorContains(t, err, "ATX 1 has no marriage ATX")
+		require.Nil(t, proof)
+
+		// ATX 2 has no marriage
+		atx.MarriageATX = new(types.ATXID)
+		*atx.MarriageATX = types.RandomATXID()
+		proof, err = NewDoubleMergeProof(db, atx, atx2)
+		require.ErrorContains(t, err, "ATX 2 has no marriage ATX")
+		require.Nil(t, proof)
+
+		// ATX 1 and 2 must have the same marriage ATX
+		atx2.MarriageATX = new(types.ATXID)
+		*atx2.MarriageATX = types.RandomATXID()
+		proof, err = NewDoubleMergeProof(db, atx, atx2)
+		require.ErrorContains(t, err, "ATXs have different marriage ATXs")
+		require.Nil(t, proof)
+
+		// Marriage ATX must be valid
+		atx2.MarriageATX = atx.MarriageATX
+		proof, err = NewDoubleMergeProof(db, atx, atx2)
+		require.ErrorIs(t, err, sql.ErrNotFound)
+
require.Nil(t, proof) + }) + + t.Run("invalid proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + marriageAtx := setupMarriage(db) + + atx1 := newActivationTxV2( + withMarriageATX(marriageAtx.ID()), + withPublishEpoch(marriageAtx.PublishEpoch+1), + ) + atx1.Sign(sig) + + atx2 := newActivationTxV2( + withMarriageATX(marriageAtx.ID()), + withPublishEpoch(marriageAtx.PublishEpoch+1), + ) + atx2.Sign(otherSig) + + proof, err := NewDoubleMergeProof(db, atx1, atx2) + require.NoError(t, err) + + // invalid marriage ATX ID + marriageAtxID := proof.MarriageATX + proof.MarriageATX = types.RandomATXID() + id, err := proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid marriage ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.MarriageATX = marriageAtxID + + // invalid marriage ATX smesher ID + smesherID := proof.MarriageATXSmesherID + proof.MarriageATXSmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid marriage ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.MarriageATXSmesherID = smesherID + + // invalid ATX1 ID + id1 := proof.ATXID1 + proof.ATXID1 = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXID1 = id1 + + // invalid ATX2 ID + id2 := proof.ATXID2 + proof.ATXID2 = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 2 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXID2 = id2 + + // invalid ATX1 smesher ID + smesherID1 := proof.SmesherID1 + proof.SmesherID1 = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.SmesherID1 = smesherID1 + + // invalid ATX2 smesher ID + smesherID2 := proof.SmesherID2 + proof.SmesherID2 = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 2 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.SmesherID2 = smesherID2 + + // invalid ATX1 signature + signature1 := proof.Signature1 + proof.Signature1 = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Signature1 = signature1 + + // invalid ATX2 signature + signature2 := proof.Signature2 + proof.Signature2 = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 2 invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Signature2 = signature2 + + // invalid publish epoch proof 1 + hash := proof.PublishEpochProof1[0] + proof.PublishEpochProof1[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid publish epoch proof") + require.Equal(t, types.EmptyNodeID, id) + proof.PublishEpochProof1[0] = hash + + // invalid publish epoch proof 2 + 
hash = proof.PublishEpochProof2[0] + proof.PublishEpochProof2[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 2 invalid publish epoch proof") + require.Equal(t, types.EmptyNodeID, id) + proof.PublishEpochProof2[0] = hash + + // invalid marriage ATX proof 1 + hash = proof.MarriageATXProof1[0] + proof.MarriageATXProof1[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 1 invalid marriage ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.MarriageATXProof1[0] = hash + + // invalid marriage ATX proof 2 + hash = proof.MarriageATXProof2[0] + proof.MarriageATXProof2[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "ATX 2 invalid marriage ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.MarriageATXProof2[0] = hash + }) +} diff --git a/activation/wire/malfeasance_invalid_post.go b/activation/wire/malfeasance_invalid_post.go index 076b678bcb..ddb792c9f0 100644 --- a/activation/wire/malfeasance_invalid_post.go +++ b/activation/wire/malfeasance_invalid_post.go @@ -16,10 +16,10 @@ import ( // ProofInvalidPost is a proof that a merged ATX with an invalid Post was published by a smesher. // // We are proofing the following: -// 1. The ATX has a valid signature. -// 2. If NodeID is different from SmesherID, we prove that NodeID and SmesherID are married. -// 3. The commitment ATX of NodeID used for the invalid PoST based on their initial ATX. -// 4. The provided Post is invalid for the given NodeID. +// 1. The ATX has a valid signature. +// 2. If NodeID is different from SmesherID, we prove that NodeID and SmesherID are married. +// 3. The commitment ATX of NodeID used for the invalid PoST based on their initial ATX. +// 4. The provided Post is invalid for the given NodeID. type ProofInvalidPost struct { // ATXID is the ID of the ATX containing the invalid PoST. ATXID types.ATXID @@ -33,8 +33,6 @@ type ProofInvalidPost struct { // MarriageProof is the proof that NodeID and SmesherID are married. It is nil if NodeID == SmesherID. MarriageProof *MarriageProof - // CommitmentProof is the proof for the commitment ATX of the smesher. Generated from the initial ATX of NodeID. - CommitmentProof CommitmentProof // InvalidPostProof is the proof for the invalid PoST of the ATX. It contains the PoST and the merkle proofs to // verify the PoST. 
InvalidPostProof InvalidPostProof @@ -44,10 +42,12 @@ var _ Proof = &ProofInvalidPost{} func NewInvalidPostProof( db sql.Executor, - atx, initialATX *ActivationTxV2, + atx *ActivationTxV2, + commitmentATX types.ATXID, nodeID types.NodeID, nipostIndex int, invalidPostIndex uint32, + validPostIndex uint32, ) (*ProofInvalidPost, error) { if atx.SmesherID != nodeID && atx.MarriageATX == nil { return nil, errors.New("ATX is not a merged ATX, but NodeID is different from SmesherID") @@ -73,11 +73,14 @@ func NewInvalidPostProof( } } - commitmentProof, err := createCommitmentProof(initialATX, nodeID) - if err != nil { - return nil, fmt.Errorf("commitment proof: %w", err) - } - invalidPostProof, err := createInvalidPostProof(atx, nipostIndex, postIndex, invalidPostIndex) + invalidPostProof, err := createInvalidPostProof( + atx, + commitmentATX, + nipostIndex, + postIndex, + invalidPostIndex, + validPostIndex, + ) if err != nil { return nil, fmt.Errorf("invalid post proof: %w", err) } @@ -91,7 +94,6 @@ func NewInvalidPostProof( MarriageProof: marriageProof, - CommitmentProof: commitmentProof, InvalidPostProof: invalidPostProof, }, nil } @@ -113,16 +115,11 @@ func (p ProofInvalidPost) Valid(ctx context.Context, malValidator MalfeasanceVal marriageIndex = &p.MarriageProof.NodeIDMarryProof.CertificateIndex } - if err := p.CommitmentProof.Valid(malValidator, p.NodeID); err != nil { - return types.EmptyNodeID, fmt.Errorf("invalid commitment proof: %w", err) - } - if err := p.InvalidPostProof.Valid( ctx, malValidator, p.ATXID, p.NodeID, - p.CommitmentProof.CommitmentATX, marriageIndex, ); err != nil { return types.EmptyNodeID, fmt.Errorf("invalid invalid post proof: %w", err) @@ -131,65 +128,11 @@ func (p ProofInvalidPost) Valid(ctx context.Context, malValidator MalfeasanceVal return p.NodeID, nil } -// CommitmentProof is a proof for the commitment ATX of a smesher. It is generated from the initial ATX. -type CommitmentProof struct { - // InitialATXID is the ID of the initial ATX of the smesher. - InitialATXID types.ATXID - - // InitialPostRoot and its proof that it is contained in the InitialATX. - InitialPostRoot InitialPostRoot - InitialPostProof InitialPostRootProof `scale:"max=32"` - - // CommitmentATX and its proof that it is contained in the InitialPostRoot. - CommitmentATX types.ATXID - CommitmentATXProof CommitmentATXProof `scale:"max=32"` - - // Signature is the signature of the ATXID by the smesher. 
-	Signature types.EdSignature
-}
-
-func createCommitmentProof(initialAtx *ActivationTxV2, nodeID types.NodeID) (CommitmentProof, error) {
-	if initialAtx.SmesherID != nodeID {
-		return CommitmentProof{}, errors.New("node ID does not match smesher ID of initial ATX")
-	}
-	if initialAtx.Initial == nil {
-		return CommitmentProof{}, errors.New("initial ATX does not contain initial PoST")
-	}
-
-	return CommitmentProof{
-		InitialATXID: initialAtx.ID(),
-
-		InitialPostRoot:  initialAtx.Initial.Root(),
-		InitialPostProof: initialAtx.InitialPostRootProof(),
-
-		CommitmentATX:      initialAtx.Initial.CommitmentATX,
-		CommitmentATXProof: initialAtx.Initial.CommitmentATXProof(),
-
-		Signature: initialAtx.Signature,
-	}, nil
-}
-
-func (p CommitmentProof) Valid(malValidator MalfeasanceValidator, nodeID types.NodeID) error {
-	if !malValidator.Signature(signing.ATX, nodeID, p.InitialATXID.Bytes(), p.Signature) {
-		return errors.New("invalid signature")
-	}
-
-	if types.Hash32(p.InitialPostRoot) == types.EmptyHash32 {
-		return errors.New("invalid empty initial PoST root") // initial PoST root is empty for non-initial ATXs
-	}
-
-	if !p.InitialPostProof.Valid(p.InitialATXID, p.InitialPostRoot) {
-		return errors.New("invalid initial PoST proof")
-	}
-	if !p.CommitmentATXProof.Valid(p.InitialPostRoot, p.CommitmentATX) {
-		return errors.New("invalid commitment ATX proof")
-	}
-
-	return nil
-}
-
 // InvalidPostProof is a proof for an invalid PoST in an ATX. It contains the PoST and the merkle proofs to verify the
 // PoST.
+//
+// It contains both a valid and an invalid PoST index. This is required to prove that the commitment ATX was used to
+// initialize the data for the invalid PoST. If a PoST contains no valid indices, then the ATX is syntactically invalid.
 type InvalidPostProof struct {
 	// NIPostsRoot and its proof that it is contained in the ATX.
 	NIPostsRoot NIPostsRoot
@@ -213,8 +156,8 @@ type InvalidPostProof struct {
 	SubPostRootProof SubPostRootProof `scale:"max=32"`
 	SubPostRootIndex uint16
 
-	// MarriageIndexProof is the proof that the MarriageIndex (CertificateIndex from MarryProof) is contained in the
-	// SubPostRoot.
+	// MarriageIndexProof is the proof that the MarriageIndex (CertificateIndex from NodeIDMarryProof) is contained in
+	// the SubPostRoot.
 	MarriageIndexProof MarriageIndexProof `scale:"max=32"`
 
 	// Post is the invalid PoST and its proof that it is contained in the SubPostRoot.
@@ -225,15 +168,23 @@ type InvalidPostProof struct {
 	NumUnits uint32
 	NumUnitsProof NumUnitsProof `scale:"max=32"`
 
+	// CommitmentATX is the ATX that was used to initialize data for the invalid PoST.
+	CommitmentATX types.ATXID
+
 	// InvalidPostIndex is the index of the leaf that was identified to be invalid.
 	InvalidPostIndex uint32
+
+	// ValidPostIndex is the index of a leaf that was identified to be valid.
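+	// Together with InvalidPostIndex this ties CommitmentATX to the PoST data: during
+	// Valid the verifier must accept the PoST at ValidPostIndex and reject it at
+	// InvalidPostIndex.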
+ ValidPostIndex uint32 } func createInvalidPostProof( atx *ActivationTxV2, + commitmentATX types.ATXID, nipostIndex, postIndex int, invalidPostIndex uint32, + validPostIndex uint32, ) (InvalidPostProof, error) { if nipostIndex < 0 || nipostIndex >= len(atx.NIPosts) { return InvalidPostProof{}, errors.New("invalid NIPoST index") @@ -268,7 +219,10 @@ func createInvalidPostProof( NumUnits: atx.NIPosts[nipostIndex].Posts[postIndex].NumUnits, NumUnitsProof: atx.NIPosts[nipostIndex].Posts[postIndex].NumUnitsProof(atx.PreviousATXs), + CommitmentATX: commitmentATX, + InvalidPostIndex: invalidPostIndex, + ValidPostIndex: validPostIndex, }, nil } @@ -279,7 +233,6 @@ func (p InvalidPostProof) Valid( malValidator MalfeasanceValidator, atxID types.ATXID, nodeID types.NodeID, - commitmentATX types.ATXID, marriageIndex *uint32, ) error { if !p.NIPostsRootProof.Valid(atxID, p.NIPostsRoot) { @@ -312,7 +265,19 @@ func (p InvalidPostProof) Valid( if err := malValidator.PostIndex( ctx, nodeID, - commitmentATX, + p.CommitmentATX, + PostFromWireV1(&p.Post), + p.Challenge.Bytes(), + p.NumUnits, + int(p.ValidPostIndex), + ); err != nil { + return errors.New("Commitment ATX is not valid") + } + + if err := malValidator.PostIndex( + ctx, + nodeID, + p.CommitmentATX, PostFromWireV1(&p.Post), p.Challenge.Bytes(), p.NumUnits, diff --git a/activation/wire/malfeasance_invalid_post_scale.go b/activation/wire/malfeasance_invalid_post_scale.go index 64f2f630cc..ffcef75f16 100644 --- a/activation/wire/malfeasance_invalid_post_scale.go +++ b/activation/wire/malfeasance_invalid_post_scale.go @@ -44,13 +44,6 @@ func (t *ProofInvalidPost) EncodeScale(enc *scale.Encoder) (total int, err error } total += n } - { - n, err := t.CommitmentProof.EncodeScale(enc) - if err != nil { - return total, err - } - total += n - } { n, err := t.InvalidPostProof.EncodeScale(enc) if err != nil { @@ -98,13 +91,6 @@ func (t *ProofInvalidPost) DecodeScale(dec *scale.Decoder) (total int, err error total += n t.MarriageProof = field } - { - n, err := t.CommitmentProof.DecodeScale(dec) - if err != nil { - return total, err - } - total += n - } { n, err := t.InvalidPostProof.DecodeScale(dec) if err != nil { @@ -115,100 +101,6 @@ func (t *ProofInvalidPost) DecodeScale(dec *scale.Decoder) (total int, err error return total, nil } -func (t *CommitmentProof) EncodeScale(enc *scale.Encoder) (total int, err error) { - { - n, err := scale.EncodeByteArray(enc, t.InitialATXID[:]) - if err != nil { - return total, err - } - total += n - } - { - n, err := scale.EncodeByteArray(enc, t.InitialPostRoot[:]) - if err != nil { - return total, err - } - total += n - } - { - n, err := scale.EncodeStructSliceWithLimit(enc, t.InitialPostProof, 32) - if err != nil { - return total, err - } - total += n - } - { - n, err := scale.EncodeByteArray(enc, t.CommitmentATX[:]) - if err != nil { - return total, err - } - total += n - } - { - n, err := scale.EncodeStructSliceWithLimit(enc, t.CommitmentATXProof, 32) - if err != nil { - return total, err - } - total += n - } - { - n, err := scale.EncodeByteArray(enc, t.Signature[:]) - if err != nil { - return total, err - } - total += n - } - return total, nil -} - -func (t *CommitmentProof) DecodeScale(dec *scale.Decoder) (total int, err error) { - { - n, err := scale.DecodeByteArray(dec, t.InitialATXID[:]) - if err != nil { - return total, err - } - total += n - } - { - n, err := scale.DecodeByteArray(dec, t.InitialPostRoot[:]) - if err != nil { - return total, err - } - total += n - } - { - field, n, err := 
scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) - if err != nil { - return total, err - } - total += n - t.InitialPostProof = field - } - { - n, err := scale.DecodeByteArray(dec, t.CommitmentATX[:]) - if err != nil { - return total, err - } - total += n - } - { - field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) - if err != nil { - return total, err - } - total += n - t.CommitmentATXProof = field - } - { - n, err := scale.DecodeByteArray(dec, t.Signature[:]) - if err != nil { - return total, err - } - total += n - } - return total, nil -} - func (t *InvalidPostProof) EncodeScale(enc *scale.Encoder) (total int, err error) { { n, err := scale.EncodeByteArray(enc, t.NIPostsRoot[:]) @@ -329,6 +221,13 @@ func (t *InvalidPostProof) EncodeScale(enc *scale.Encoder) (total int, err error } total += n } + { + n, err := scale.EncodeByteArray(enc, t.CommitmentATX[:]) + if err != nil { + return total, err + } + total += n + } { n, err := scale.EncodeCompact32(enc, uint32(t.InvalidPostIndex)) if err != nil { @@ -336,6 +235,13 @@ func (t *InvalidPostProof) EncodeScale(enc *scale.Encoder) (total int, err error } total += n } + { + n, err := scale.EncodeCompact32(enc, uint32(t.ValidPostIndex)) + if err != nil { + return total, err + } + total += n + } return total, nil } @@ -470,6 +376,13 @@ func (t *InvalidPostProof) DecodeScale(dec *scale.Decoder) (total int, err error total += n t.NumUnitsProof = field } + { + n, err := scale.DecodeByteArray(dec, t.CommitmentATX[:]) + if err != nil { + return total, err + } + total += n + } { field, n, err := scale.DecodeCompact32(dec) if err != nil { @@ -478,5 +391,13 @@ func (t *InvalidPostProof) DecodeScale(dec *scale.Decoder) (total int, err error total += n t.InvalidPostIndex = uint32(field) } + { + field, n, err := scale.DecodeCompact32(dec) + if err != nil { + return total, err + } + total += n + t.ValidPostIndex = uint32(field) + } return total, nil } diff --git a/activation/wire/malfeasance_invalid_post_test.go b/activation/wire/malfeasance_invalid_post_test.go index 9144420297..c0bea896a8 100644 --- a/activation/wire/malfeasance_invalid_post_test.go +++ b/activation/wire/malfeasance_invalid_post_test.go @@ -20,6 +20,8 @@ import ( ) func Test_InvalidPostProof(t *testing.T) { + t.Parallel() + // sig is the identity that creates the invalid PoST sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -39,20 +41,8 @@ func Test_InvalidPostProof(t *testing.T) { nipostChallenge types.Hash32, post PostV1, numUnits uint32, - ) (*ActivationTxV2, *ActivationTxV2) { - wInitialAtx := newActivationTxV2( - withInitial(types.RandomATXID(), PostV1{}), - ) - wInitialAtx.Sign(sig) - initialAtx := &types.ActivationTx{ - CommitmentATX: &wInitialAtx.Initial.CommitmentATX, - } - initialAtx.SetID(wInitialAtx.ID()) - initialAtx.SmesherID = sig.NodeID() - require.NoError(t, atxs.Add(db, initialAtx, wInitialAtx.Blob())) - + ) *ActivationTxV2 { atx := newActivationTxV2( - withPreviousATXs(wInitialAtx.ID()), withNIPost( withNIPostChallenge(nipostChallenge), withNIPostSubPost(SubPostV2{ @@ -62,7 +52,7 @@ func Test_InvalidPostProof(t *testing.T) { ), ) atx.Sign(sig) - return atx, wInitialAtx + return atx } newMergedATXv2 := func( @@ -70,7 +60,7 @@ func Test_InvalidPostProof(t *testing.T) { nipostChallenge types.Hash32, post PostV1, numUnits uint32, - ) (*ActivationTxV2, *ActivationTxV2) { + ) *ActivationTxV2 { wInitialAtx := newActivationTxV2( withInitial(types.RandomATXID(), PostV1{}), ) @@ -130,10 +120,11 @@ func Test_InvalidPostProof(t *testing.T) { ), 
) atx.Sign(pubSig) - return atx, wInitialAtx + return atx } t.Run("valid", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -143,10 +134,12 @@ func Test_InvalidPostProof(t *testing.T) { Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + atx := newSoloATXv2(db, nipostChallenge, post, numUnits) + commitmentATX := types.RandomATXID() const invalidPostIndex = 7 - proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + const validPostIndex = 15 + proof, err := NewInvalidPostProof(db, atx, commitmentATX, sig.NodeID(), 0, invalidPostIndex, validPostIndex) require.NoError(t, err) ctrl := gomock.NewController(t) @@ -159,7 +152,17 @@ func Test_InvalidPostProof(t *testing.T) { verifier.EXPECT().PostIndex( context.Background(), sig.NodeID(), - initialAtx.Initial.CommitmentATX, + commitmentATX, + PostFromWireV1(&post), + nipostChallenge.Bytes(), + numUnits, + validPostIndex, + ).Return(nil) + + verifier.EXPECT().PostIndex( + context.Background(), + sig.NodeID(), + commitmentATX, PostFromWireV1(&post), nipostChallenge.Bytes(), numUnits, @@ -172,6 +175,7 @@ func Test_InvalidPostProof(t *testing.T) { }) t.Run("valid merged atx", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -181,10 +185,12 @@ func Test_InvalidPostProof(t *testing.T) { Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) + atx := newMergedATXv2(db, nipostChallenge, post, numUnits) + commitmentATX := types.RandomATXID() const invalidPostIndex = 7 - proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + const validPostIndex = 15 + proof, err := NewInvalidPostProof(db, atx, commitmentATX, sig.NodeID(), 0, invalidPostIndex, validPostIndex) require.NoError(t, err) ctrl := gomock.NewController(t) @@ -197,19 +203,30 @@ func Test_InvalidPostProof(t *testing.T) { verifier.EXPECT().PostIndex( context.Background(), sig.NodeID(), - initialAtx.Initial.CommitmentATX, + commitmentATX, PostFromWireV1(&post), nipostChallenge.Bytes(), numUnits, invalidPostIndex, ).Return(errors.New("invalid post")) + verifier.EXPECT().PostIndex( + context.Background(), + sig.NodeID(), + commitmentATX, + PostFromWireV1(&post), + nipostChallenge.Bytes(), + numUnits, + validPostIndex, + ).Return(nil) + id, err := proof.Valid(context.Background(), verifier) require.NoError(t, err) require.Equal(t, sig.NodeID(), id) }) t.Run("post is valid", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -219,10 +236,12 @@ func Test_InvalidPostProof(t *testing.T) { Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + atx := newSoloATXv2(db, nipostChallenge, post, numUnits) + commitmentAtx := types.RandomATXID() const invalidPostIndex = 7 - proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + const validPostIndex = 15 + proof, err := NewInvalidPostProof(db, atx, commitmentAtx, sig.NodeID(), 0, invalidPostIndex, validPostIndex) require.NoError(t, err) ctrl := gomock.NewController(t) @@ -235,7 +254,17 @@ func Test_InvalidPostProof(t *testing.T) { verifier.EXPECT().PostIndex( context.Background(), sig.NodeID(), - initialAtx.Initial.CommitmentATX, + commitmentAtx, + PostFromWireV1(&post), + 
nipostChallenge.Bytes(), + numUnits, + validPostIndex, + ).Return(nil) + + verifier.EXPECT().PostIndex( + context.Background(), + sig.NodeID(), + commitmentAtx, PostFromWireV1(&post), nipostChallenge.Bytes(), numUnits, @@ -247,7 +276,7 @@ func Test_InvalidPostProof(t *testing.T) { require.Equal(t, types.EmptyNodeID, id) }) - t.Run("differing node ID without marriage ATX", func(t *testing.T) { + t.Run("commitment ATX is not valid", func(t *testing.T) { db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -257,16 +286,13 @@ func Test_InvalidPostProof(t *testing.T) { Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + atx := newSoloATXv2(db, nipostChallenge, post, numUnits) + commitmentAtx := types.RandomATXID() const invalidPostIndex = 7 - proof, err := NewInvalidPostProof(db, atx, initialAtx, types.RandomNodeID(), 0, invalidPostIndex) - require.EqualError(t, err, "ATX is not a merged ATX, but NodeID is different from SmesherID") - require.Nil(t, proof) - - proof, err = NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + const validPostIndex = 15 + proof, err := NewInvalidPostProof(db, atx, commitmentAtx, sig.NodeID(), 0, invalidPostIndex, validPostIndex) require.NoError(t, err) - require.NotNil(t, proof) ctrl := gomock.NewController(t) verifier := NewMockMalfeasanceValidator(ctrl) @@ -275,14 +301,23 @@ func Test_InvalidPostProof(t *testing.T) { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() - proof.NodeID = types.RandomNodeID() // invalid node ID + verifier.EXPECT().PostIndex( + context.Background(), + sig.NodeID(), + commitmentAtx, + PostFromWireV1(&post), + nipostChallenge.Bytes(), + numUnits, + validPostIndex, + ).Return(errors.New("invalid post")) id, err := proof.Valid(context.Background(), verifier) - require.EqualError(t, err, "missing marriage proof") + require.EqualError(t, err, "invalid invalid post proof: Commitment ATX is not valid") require.Equal(t, types.EmptyNodeID, id) }) - t.Run("node ID not in marriage ATX", func(t *testing.T) { + t.Run("differing node ID without marriage ATX", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -292,52 +327,19 @@ func Test_InvalidPostProof(t *testing.T) { Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) + atx := newSoloATXv2(db, nipostChallenge, post, numUnits) + commitmentAtx := types.RandomATXID() const invalidPostIndex = 7 + const validPostIndex = 15 nodeID := types.RandomNodeID() - proof, err := NewInvalidPostProof(db, atx, initialAtx, nodeID, 0, invalidPostIndex) - require.ErrorContains(t, err, - fmt.Sprintf("does not contain a marriage certificate signed by %s", nodeID.ShortString()), - ) + proof, err := NewInvalidPostProof(db, atx, commitmentAtx, nodeID, 0, invalidPostIndex, validPostIndex) + require.EqualError(t, err, "ATX is not a merged ATX, but NodeID is different from SmesherID") require.Nil(t, proof) }) - t.Run("invalid marriage proof", func(t *testing.T) { - db := statesql.InMemoryTest(t) - - nipostChallenge := types.RandomHash() - const numUnits = uint32(11) - post := PostV1{ - Nonce: rand.Uint32(), - Indices: types.RandomBytes(11), - Pow: rand.Uint64(), - } - atx, _ := newMergedATXv2(db, nipostChallenge, post, numUnits) - - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), 
gomock.Any(), gomock.Any()). - DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() - - // manually construct an invalid proof - proof, err := createMarriageProof(db, atx, sig.NodeID()) - require.NoError(t, err) - - marriageATX := proof.MarriageATX - proof.MarriageATX = types.RandomATXID() // invalid ATX - err = proof.Valid(verifier, atx.ID(), sig.NodeID(), pubSig.NodeID()) - require.ErrorContains(t, err, "invalid marriage ATX proof") - - proof.MarriageATX = marriageATX - proof.MarriageATXProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(verifier, atx.ID(), sig.NodeID(), pubSig.NodeID()) - require.ErrorContains(t, err, "invalid marriage ATX proof") - }) - - t.Run("node ID did not include post in merged ATX", func(t *testing.T) { + t.Run("nipost index is invalid", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -347,18 +349,18 @@ func Test_InvalidPostProof(t *testing.T) { Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) - atx.NIPosts[0].Posts = slices.DeleteFunc(atx.NIPosts[0].Posts, func(subPost SubPostV2) bool { - return cmp.Equal(subPost.Post, post) - }) + atx := newSoloATXv2(db, nipostChallenge, post, numUnits) + commitmentAtx := types.RandomATXID() const invalidPostIndex = 7 - proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) - require.EqualError(t, err, fmt.Sprintf("no PoST from %s in ATX", sig)) + const validPostIndex = 15 + proof, err := NewInvalidPostProof(db, atx, commitmentAtx, sig.NodeID(), 1, invalidPostIndex, validPostIndex) + require.EqualError(t, err, "invalid NIPoST index") require.Nil(t, proof) }) - t.Run("initial ATX is invalid", func(t *testing.T) { + t.Run("node ID not in marriage ATX", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -368,23 +370,21 @@ func Test_InvalidPostProof(t *testing.T) { Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) - initialAtx.SmesherID = types.RandomNodeID() // initial ATX published by different identity + atx := newMergedATXv2(db, nipostChallenge, post, numUnits) + commitmentAtx := types.RandomATXID() const invalidPostIndex = 7 - proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) - require.ErrorContains(t, err, "node ID does not match smesher ID of initial ATX") - require.Nil(t, proof) - - atx, initialAtx = newMergedATXv2(db, nipostChallenge, post, numUnits) - initialAtx.Initial = nil // not an initial ATX - - proof, err = NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) - require.ErrorContains(t, err, "initial ATX does not contain initial PoST") + const validPostIndex = 15 + nodeID := types.RandomNodeID() + proof, err := NewInvalidPostProof(db, atx, commitmentAtx, nodeID, 0, invalidPostIndex, validPostIndex) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", nodeID.ShortString()), + ) require.Nil(t, proof) }) - t.Run("invalid nipost index", func(t *testing.T) { + t.Run("node ID did not include post in merged ATX", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -394,15 +394,21 @@ func Test_InvalidPostProof(t *testing.T) { Indices: 
types.RandomBytes(11), Pow: rand.Uint64(), } - atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + atx := newMergedATXv2(db, nipostChallenge, post, numUnits) + commitmentAtx := types.RandomATXID() + atx.NIPosts[0].Posts = slices.DeleteFunc(atx.NIPosts[0].Posts, func(subPost SubPostV2) bool { + return cmp.Equal(subPost.Post, post) + }) const invalidPostIndex = 7 - proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 1, invalidPostIndex) // 1 is invalid - require.EqualError(t, err, "invalid NIPoST index") + const validPostIndex = 15 + proof, err := NewInvalidPostProof(db, atx, commitmentAtx, sig.NodeID(), 0, invalidPostIndex, validPostIndex) + require.EqualError(t, err, fmt.Sprintf("no PoST from %s in ATX", sig)) require.Nil(t, proof) }) - t.Run("invalid ATX signature", func(t *testing.T) { + t.Run("invalid solo proof", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -412,10 +418,12 @@ func Test_InvalidPostProof(t *testing.T) { Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + atx := newSoloATXv2(db, nipostChallenge, post, numUnits) + commitmentAtx := types.RandomATXID() const invalidPostIndex = 7 - proof, err := NewInvalidPostProof(db, atx, initialAtx, sig.NodeID(), 0, invalidPostIndex) + const validPostIndex = 15 + proof, err := NewInvalidPostProof(db, atx, commitmentAtx, sig.NodeID(), 0, invalidPostIndex, validPostIndex) require.NoError(t, err) ctrl := gomock.NewController(t) @@ -425,179 +433,166 @@ func Test_InvalidPostProof(t *testing.T) { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() - proof.Signature = types.RandomEdSignature() // invalid signature - + // invalid ATXID + proof.ATXID = types.RandomATXID() id, err := proof.Valid(context.Background(), verifier) require.EqualError(t, err, "invalid signature") require.Equal(t, types.EmptyNodeID, id) - }) + proof.ATXID = atx.ID() - t.Run("commitment proof is invalid", func(t *testing.T) { - db := statesql.InMemoryTest(t) + // invalid smesher ID + proof.SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.SmesherID = atx.SmesherID - nipostChallenge := types.RandomHash() - const numUnits = uint32(11) - post := PostV1{ - Nonce: rand.Uint32(), - Indices: types.RandomBytes(11), - Pow: rand.Uint64(), - } - _, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) + // invalid signature + proof.Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Signature = atx.Signature - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() + // invalid node ID + proof.NodeID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "missing marriage proof") + require.Equal(t, types.EmptyNodeID, id) + proof.NodeID = sig.NodeID() - // manually construct an invalid proof - proof, err := createCommitmentProof(initialAtx, sig.NodeID()) - require.NoError(t, err) + // invalid niposts root + nipostsRoot := proof.InvalidPostProof.NIPostsRoot + proof.InvalidPostProof.NIPostsRoot = NIPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid NIPosts root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NIPostsRoot = nipostsRoot - signature := proof.Signature - proof.Signature = types.RandomEdSignature() // invalid signature - err = proof.Valid(verifier, sig.NodeID()) - require.ErrorContains(t, err, "invalid signature") - proof.Signature = signature - - proof.InitialATXID = types.RandomATXID() // invalid ATX - err = proof.Valid(verifier, sig.NodeID()) - require.ErrorContains(t, err, "invalid signature") - proof.InitialATXID = initialAtx.ID() - - proofHash := proof.InitialPostProof[0] - proof.InitialPostProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(verifier, sig.NodeID()) - require.ErrorContains(t, err, "invalid initial PoST proof") - proof.InitialPostProof[0] = proofHash - - initialPostRoot := proof.InitialPostRoot - proof.InitialPostRoot = InitialPostRoot(types.EmptyHash32) // invalid initial post root - err = proof.Valid(verifier, sig.NodeID()) - require.ErrorContains(t, err, "invalid empty initial PoST root") - proof.InitialPostRoot = initialPostRoot - - commitmentATX := proof.CommitmentATX - proof.CommitmentATX = types.RandomATXID() // invalid ATX - err = proof.Valid(verifier, sig.NodeID()) - require.ErrorContains(t, err, "invalid commitment ATX proof") - proof.CommitmentATX = commitmentATX - - proofHash = proof.CommitmentATXProof[0] - proof.CommitmentATXProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(verifier, sig.NodeID()) - require.ErrorContains(t, err, "invalid commitment ATX proof") - proof.CommitmentATXProof[0] = proofHash - }) + // invalid niposts root proof + hash := proof.InvalidPostProof.NIPostsRootProof[0] + proof.InvalidPostProof.NIPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid NIPosts root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NIPostsRootProof[0] = hash - t.Run("solo invalid post proof is not valid", func(t *testing.T) { - db := statesql.InMemoryTest(t) + // invalid nipost root + nipostRoot := proof.InvalidPostProof.NIPostRoot + proof.InvalidPostProof.NIPostRoot = NIPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NIPostRoot = nipostRoot - nipostChallenge := types.RandomHash() - const numUnits = uint32(11) - post := PostV1{ + // invalid nipost root proof + hash = proof.InvalidPostProof.NIPostRootProof[0] + proof.InvalidPostProof.NIPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid NIPoST root proof") + 
require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NIPostRootProof[0] = hash + + // invalid nipost index + proof.InvalidPostProof.NIPostIndex = 1 + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NIPostIndex = 0 + + // invalid challenge + challenge := proof.InvalidPostProof.Challenge + proof.InvalidPostProof.Challenge = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid challenge proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.Challenge = challenge + + // invalid challenge proof + hash = proof.InvalidPostProof.ChallengeProof[0] + proof.InvalidPostProof.ChallengeProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid challenge proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.ChallengeProof[0] = hash + + // invalid subposts root + subPostsRoot := proof.InvalidPostProof.SubPostsRoot + proof.InvalidPostProof.SubPostsRoot = SubPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.SubPostsRoot = subPostsRoot + + // invalid subposts root proof + hash = proof.InvalidPostProof.SubPostsRootProof[0] + proof.InvalidPostProof.SubPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.SubPostsRootProof[0] = hash + + // invalid subpost root + subPostRoot := proof.InvalidPostProof.SubPostRoot + proof.InvalidPostProof.SubPostRoot = SubPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.SubPostRoot = subPostRoot + + // invalid subpost root proof + hash = proof.InvalidPostProof.SubPostRootProof[0] + proof.InvalidPostProof.SubPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.SubPostRootProof[0] = hash + + // invalid subpost root index + proof.InvalidPostProof.SubPostRootIndex++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.SubPostRootIndex-- + + // invalid post + post = proof.InvalidPostProof.Post + proof.InvalidPostProof.Post = PostV1{ Nonce: rand.Uint32(), Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx, initialAtx := newSoloATXv2(db, nipostChallenge, post, numUnits) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid post proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.Post = post - ctrl := gomock.NewController(t) - verifier := NewMockMalfeasanceValidator(ctrl) - verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { - return edVerifier.Verify(d, nodeID, m, sig) - }).AnyTimes() + // invalid post proof + hash = proof.InvalidPostProof.PostProof[0] + proof.InvalidPostProof.PostProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid post proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.PostProof[0] = hash - // manually construct an invalid proof - const invalidPostIndex = 7 - proof, err := createInvalidPostProof(atx, 0, 0, invalidPostIndex) - require.NoError(t, err) - require.NotNil(t, proof) - - nipostsRoot := proof.NIPostsRoot - proof.NIPostsRoot = NIPostsRoot(types.RandomHash()) // invalid root - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid NIPosts root proof") - proof.NIPostsRoot = nipostsRoot - - proofHash := proof.NIPostsRootProof[0] - proof.NIPostsRootProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid NIPosts root proof") - proof.NIPostsRootProof[0] = proofHash - - proof.NIPostIndex = 1 // invalid index - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid NIPoST root proof") - proof.NIPostIndex = 0 - - nipostRoot := proof.NIPostRoot - proof.NIPostRoot = NIPostRoot(types.RandomHash()) // invalid root - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid NIPoST root proof") - proof.NIPostRoot = nipostRoot - - proofHash = proof.NIPostRootProof[0] - proof.NIPostRootProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid NIPoST root proof") - proof.NIPostRootProof[0] = proofHash - - challenge := proof.Challenge - proof.Challenge = types.RandomHash() // invalid challenge - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid challenge proof") - proof.Challenge = challenge - - proofHash = proof.ChallengeProof[0] - proof.ChallengeProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid challenge proof") - proof.ChallengeProof[0] = proofHash - - subPostsRoot := proof.SubPostsRoot - proof.SubPostsRoot = SubPostsRoot(types.RandomHash()) // invalid root - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid sub PoSTs root proof") - proof.SubPostsRoot = subPostsRoot - - proofHash = proof.SubPostsRootProof[0] - proof.SubPostsRootProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid sub PoSTs root proof") - proof.SubPostsRootProof[0] = proofHash - - proof.SubPostRootIndex = 1 // invalid index - err = proof.Valid(context.Background(), verifier, atx.ID(), 
sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid sub PoST root proof") - proof.SubPostRootIndex = 0 - - subPost := proof.SubPostRoot - proof.SubPostRoot = SubPostRoot(types.RandomHash()) // invalid root - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid sub PoST root proof") - proof.SubPostRoot = subPost - - proofHash = proof.SubPostRootProof[0] - proof.SubPostRootProof[0] = types.RandomHash() // invalid proof - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid sub PoST root proof") - proof.SubPostRootProof[0] = proofHash - - proof.Post = PostV1{} // invalid post - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid PoST proof") - proof.Post = post - - proof.NumUnits++ // invalid number of units - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), initialAtx.Initial.CommitmentATX, nil) - require.EqualError(t, err, "invalid num units proof") - proof.NumUnits-- + // invalid numunits + proof.InvalidPostProof.NumUnits++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid post proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NumUnits-- + + // invalid numunits proof + hash = proof.InvalidPostProof.NumUnitsProof[0] + proof.InvalidPostProof.NumUnitsProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid post proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NumUnitsProof[0] = hash }) - t.Run("merged invalid post proof is not valid", func(t *testing.T) { + t.Run("invalid merged proof", func(t *testing.T) { + t.Parallel() db := statesql.InMemoryTest(t) nipostChallenge := types.RandomHash() @@ -607,7 +602,13 @@ func Test_InvalidPostProof(t *testing.T) { Indices: types.RandomBytes(11), Pow: rand.Uint64(), } - atx, initialAtx := newMergedATXv2(db, nipostChallenge, post, numUnits) + atx := newMergedATXv2(db, nipostChallenge, post, numUnits) + commitmentAtx := types.RandomATXID() + + const invalidPostIndex = 7 + const validPostIndex = 15 + proof, err := NewInvalidPostProof(db, atx, commitmentAtx, sig.NodeID(), 0, invalidPostIndex, validPostIndex) + require.NoError(t, err) ctrl := gomock.NewController(t) verifier := NewMockMalfeasanceValidator(ctrl) @@ -616,17 +617,48 @@ func Test_InvalidPostProof(t *testing.T) { return edVerifier.Verify(d, nodeID, m, sig) }).AnyTimes() - // manually construct an invalid proof - marriageIndex := uint32(1) - commitmentAtx := initialAtx.Initial.CommitmentATX - const invalidPostIndex = 7 - proof, err := createInvalidPostProof(atx, 0, 1, invalidPostIndex) - require.NoError(t, err) - require.NotNil(t, proof) + // invalid ATXID + proof.ATXID = types.RandomATXID() + id, err := proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXID = atx.ID() + + // invalid smesher ID + proof.SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.SmesherID = atx.SmesherID - invalidMarriageIndex := marriageIndex + 1 + // invalid signature + 
proof.Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.EqualError(t, err, "invalid signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Signature = atx.Signature - err = proof.Valid(context.Background(), verifier, atx.ID(), sig.NodeID(), commitmentAtx, &invalidMarriageIndex) - require.EqualError(t, err, "invalid marriage index proof") + // invalid node ID + proof.NodeID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid marriage proof for NodeID") + require.Equal(t, types.EmptyNodeID, id) + proof.NodeID = sig.NodeID() + + // invalid marriage index proof + hash := proof.InvalidPostProof.MarriageIndexProof[0] + proof.InvalidPostProof.MarriageIndexProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid marriage index proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.MarriageIndexProof[0] = hash + + // invalid numunits proof + hash = proof.InvalidPostProof.NumUnitsProof[0] + proof.InvalidPostProof.NumUnitsProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid post proof") + require.Equal(t, types.EmptyNodeID, id) + proof.InvalidPostProof.NumUnitsProof[0] = hash }) } diff --git a/activation/wire/malfeasance_invalid_prev_atx.go b/activation/wire/malfeasance_invalid_prev_atx.go new file mode 100644 index 0000000000..abb5acb5f6 --- /dev/null +++ b/activation/wire/malfeasance_invalid_prev_atx.go @@ -0,0 +1,343 @@ +package wire + +import ( + "context" + "errors" + "fmt" + "slices" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" +) + +//go:generate scalegen + +// ProofInvalidPrevAtxV2 is a proof that two distinct ATXs reference the same previous ATX for one of the included +// identities. +// +// We are proving the following: +// 1. The ATXs have different IDs. +// 2. Both ATXs have a valid signature. +// 3. Both ATXs reference the same previous ATX for the same identity. +// 4. If the signer of one of the two ATXs is not the identity that referenced the same previous ATX, then the identity +// that did is married to the signer via a valid marriage certificate in the referenced marriage ATX. +type ProofInvalidPrevAtxV2 struct { + // NodeID is the node ID that referenced the same previous ATX twice. + NodeID types.NodeID + + // PrevATX is the ATX that was referenced twice. 
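+	// Valid checks both contained proofs against this previous ATX.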
+ PrevATX types.ATXID + + Proofs [2]InvalidPrevAtxProof +} + +var _ Proof = &ProofInvalidPrevAtxV2{} + +func NewInvalidPrevAtxProofV2( + db sql.Executor, + atx1, atx2 *ActivationTxV2, + nodeID types.NodeID, +) (*ProofInvalidPrevAtxV2, error) { + if atx1.ID() == atx2.ID() { + return nil, errors.New("ATXs have the same ID") + } + + if atx1.SmesherID != nodeID && atx1.MarriageATX == nil { + return nil, errors.New("ATX1 is not a merged ATX, but NodeID is different from SmesherID") + } + + if atx2.SmesherID != nodeID && atx2.MarriageATX == nil { + return nil, errors.New("ATX2 is not a merged ATX, but NodeID is different from SmesherID") + } + + var marriageProof1 *MarriageProof + nipostIndex1 := 0 + postIndex1 := 0 + if atx1.SmesherID != nodeID { + proof, err := createMarriageProof(db, atx1, nodeID) + if err != nil { + return nil, fmt.Errorf("marriage proof: %w", err) + } + marriageProof1 = &proof + for i, nipost := range atx1.NIPosts { + postIndex1 = slices.IndexFunc(nipost.Posts, func(post SubPostV2) bool { + return post.MarriageIndex == proof.NodeIDMarryProof.CertificateIndex + }) + if postIndex1 != -1 { + nipostIndex1 = i + break + } + } + if postIndex1 == -1 { + return nil, fmt.Errorf("no PoST from %s in ATX", nodeID.ShortString()) + } + } + + var marriageProof2 *MarriageProof + nipostIndex2 := 0 + postIndex2 := 0 + if atx2.SmesherID != nodeID { + proof, err := createMarriageProof(db, atx2, nodeID) + if err != nil { + return nil, fmt.Errorf("marriage proof: %w", err) + } + marriageProof2 = &proof + for i, nipost := range atx2.NIPosts { + postIndex2 = slices.IndexFunc(nipost.Posts, func(post SubPostV2) bool { + return post.MarriageIndex == proof.NodeIDMarryProof.CertificateIndex + }) + if postIndex2 != -1 { + nipostIndex2 = i + break + } + } + if postIndex2 == -1 { + return nil, fmt.Errorf("no PoST from %s in ATX", nodeID.ShortString()) + } + } + + prevATX1 := atx1.PreviousATXs[atx1.NIPosts[nipostIndex1].Posts[postIndex1].PrevATXIndex] + prevATX2 := atx2.PreviousATXs[atx2.NIPosts[nipostIndex2].Posts[postIndex2].PrevATXIndex] + if prevATX1 != prevATX2 { + return nil, errors.New("ATXs reference different previous ATXs") + } + + proof1, err := createInvalidPrevAtxProof(atx1, prevATX1, nipostIndex1, postIndex1, marriageProof1) + if err != nil { + return nil, fmt.Errorf("proof for atx1: %w", err) + } + + proof2, err := createInvalidPrevAtxProof(atx2, prevATX2, nipostIndex2, postIndex2, marriageProof2) + if err != nil { + return nil, fmt.Errorf("proof for atx2: %w", err) + } + + proof := &ProofInvalidPrevAtxV2{ + NodeID: nodeID, + PrevATX: prevATX1, + Proofs: [2]InvalidPrevAtxProof{proof1, proof2}, + } + return proof, nil +} + +func createInvalidPrevAtxProof( + atx *ActivationTxV2, + prevATX types.ATXID, + nipostIndex, + postIndex int, + marriageProof *MarriageProof, +) (InvalidPrevAtxProof, error) { + proof := InvalidPrevAtxProof{ + ATXID: atx.ID(), + + NIPostsRoot: atx.NIPosts.Root(atx.PreviousATXs), + NIPostsRootProof: atx.NIPostsRootProof(), + + NIPostRoot: atx.NIPosts[nipostIndex].Root(atx.PreviousATXs), + NIPostRootProof: atx.NIPosts.Proof(int(nipostIndex), atx.PreviousATXs), + NIPostIndex: uint16(nipostIndex), + + SubPostsRoot: atx.NIPosts[nipostIndex].Posts.Root(atx.PreviousATXs), + SubPostsRootProof: atx.NIPosts[nipostIndex].PostsRootProof(atx.PreviousATXs), + + SubPostRoot: atx.NIPosts[nipostIndex].Posts[postIndex].Root(atx.PreviousATXs), + SubPostRootProof: atx.NIPosts[nipostIndex].Posts.Proof(postIndex, atx.PreviousATXs), + SubPostRootIndex: uint16(postIndex), + + MarriageIndexProof: 
atx.NIPosts[nipostIndex].Posts[postIndex].MarriageIndexProof(atx.PreviousATXs),
+		MarriageProof:      marriageProof,
+
+		PrevATXProof: atx.NIPosts[nipostIndex].Posts[postIndex].PrevATXProof(prevATX),
+
+		SmesherID: atx.SmesherID,
+		Signature: atx.Signature,
+	}
+
+	return proof, nil
+}
+
+func (p ProofInvalidPrevAtxV2) Valid(_ context.Context, malValidator MalfeasanceValidator) (types.NodeID, error) {
+	if p.Proofs[0].ATXID == p.Proofs[1].ATXID {
+		return types.EmptyNodeID, errors.New("proofs have the same ATX ID")
+	}
+	if err := p.Proofs[0].Valid(p.PrevATX, p.NodeID, malValidator); err != nil {
+		return types.EmptyNodeID, fmt.Errorf("proof 1 is invalid: %w", err)
+	}
+	if err := p.Proofs[1].Valid(p.PrevATX, p.NodeID, malValidator); err != nil {
+		return types.EmptyNodeID, fmt.Errorf("proof 2 is invalid: %w", err)
+	}
+	return p.NodeID, nil
+}
+
+// ProofInvalidPrevAtxV1 is a proof that an ATXv2 and an ATXv1 reference the same previous ATX for the same identity.
+//
+// We are proving the following:
+// 1. Both ATXs have a valid signature.
+// 2. Both ATXs reference the same previous ATX for the same identity.
+// 3. If the signer of the ATXv2 is not the identity that referenced the same previous ATX, then the included marriage
+//    proof is valid.
+// 4. The ATXv1 has been signed by the identity that referenced the same previous ATX.
+type ProofInvalidPrevAtxV1 struct {
+	// NodeID is the node ID that referenced the same previous ATX twice.
+	NodeID types.NodeID
+
+	// PrevATX is the ATX that was referenced twice.
+	PrevATX types.ATXID
+
+	Proof InvalidPrevAtxProof
+	ATXv1 ActivationTxV1
+}
+
+var _ Proof = &ProofInvalidPrevAtxV1{}
+
+func NewInvalidPrevAtxProofV1(
+	db sql.Executor,
+	atx1 *ActivationTxV2,
+	atx2 *ActivationTxV1,
+	nodeID types.NodeID,
+) (*ProofInvalidPrevAtxV1, error) {
+	if atx1.SmesherID != nodeID && atx1.MarriageATX == nil {
+		return nil, errors.New("ATX1 is not a merged ATX, but NodeID is different from SmesherID")
+	}
+
+	if atx2.SmesherID != nodeID {
+		return nil, errors.New("ATX2 is not signed by NodeID")
+	}
+
+	var marriageProof *MarriageProof
+	nipostIndex := 0
+	postIndex := 0
+	if atx1.SmesherID != nodeID {
+		proof, err := createMarriageProof(db, atx1, nodeID)
+		if err != nil {
+			return nil, fmt.Errorf("marriage proof: %w", err)
+		}
+		marriageProof = &proof
+		for i, nipost := range atx1.NIPosts {
+			postIndex = slices.IndexFunc(nipost.Posts, func(post SubPostV2) bool {
+				return post.MarriageIndex == proof.NodeIDMarryProof.CertificateIndex
+			})
+			if postIndex != -1 {
+				nipostIndex = i
+				break
+			}
+		}
+		if postIndex == -1 {
+			return nil, fmt.Errorf("no PoST from %s in ATX", nodeID.ShortString())
+		}
+	}
+	prevATX1 := atx1.PreviousATXs[atx1.NIPosts[nipostIndex].Posts[postIndex].PrevATXIndex]
+	prevATX2 := atx2.PrevATXID
+	if prevATX1 != prevATX2 {
+		return nil, errors.New("ATXs reference different previous ATXs")
+	}
+
+	proof, err := createInvalidPrevAtxProof(atx1, prevATX1, nipostIndex, postIndex, marriageProof)
+	if err != nil {
+		return nil, fmt.Errorf("proof for atx1: %w", err)
+	}
+
+	return &ProofInvalidPrevAtxV1{
+		NodeID:  nodeID,
+		PrevATX: prevATX1,
+		Proof:   proof,
+		ATXv1:   *atx2,
+	}, nil
+}
+
+func (p ProofInvalidPrevAtxV1) Valid(_ context.Context, malValidator MalfeasanceValidator) (types.NodeID, error) {
+	if err := p.Proof.Valid(p.PrevATX, p.NodeID, malValidator); err != nil {
+		return types.EmptyNodeID, fmt.Errorf("proof is invalid: %w", err)
+	}
+	if !malValidator.Signature(signing.ATX, p.ATXv1.SmesherID, p.ATXv1.SignedBytes(),
+		p.ATXv1.Signature) {
+		return types.EmptyNodeID, errors.New("invalid ATX signature")
+	}
+	if p.NodeID != p.ATXv1.SmesherID {
+		return types.EmptyNodeID, errors.New("ATXv1 has not been signed by the same identity")
+	}
+	if p.ATXv1.PrevATXID != p.PrevATX {
+		return types.EmptyNodeID, errors.New("ATXv1 references a different previous ATX")
+	}
+	return p.NodeID, nil
+}
+
+type InvalidPrevAtxProof struct {
+	// ATXID is the ID of the ATX being proven.
+	ATXID types.ATXID
+	// SmesherID is the ID of the smesher that published the ATX.
+	SmesherID types.NodeID
+	// Signature is the signature of the ATXID by the smesher.
+	Signature types.EdSignature
+
+	// NIPostsRoot and its proof that it is contained in the ATX.
+	NIPostsRoot      NIPostsRoot
+	NIPostsRootProof NIPostsRootProof `scale:"max=32"`
+
+	// NIPostRoot and its proof that it is contained at the given index in the NIPostsRoot.
+	NIPostRoot      NIPostRoot
+	NIPostRootProof NIPostRootProof `scale:"max=32"`
+	NIPostIndex     uint16
+
+	// SubPostsRoot and its proof that it is contained in the NIPostRoot.
+	SubPostsRoot      SubPostsRoot
+	SubPostsRootProof SubPostsRootProof `scale:"max=32"`
+
+	// SubPostRoot and its proof that it is contained at the given index in the SubPostsRoot.
+	SubPostRoot      SubPostRoot
+	SubPostRootProof SubPostRootProof `scale:"max=32"`
+	SubPostRootIndex uint16
+
+	// MarriageProof is the proof that NodeID and SmesherID are married. It is nil if NodeID == SmesherID.
+	MarriageProof *MarriageProof
+
+	// MarriageIndexProof is the proof that the MarriageIndex (CertificateIndex from NodeIDMarryProof) is contained in
+	// the SubPostRoot.
+	MarriageIndexProof MarriageIndexProof `scale:"max=32"`
+
+	// PrevATXProof is the proof that the previous ATX is contained in the SubPostRoot.
+	PrevATXProof PrevATXProof `scale:"max=32"`
+}
+
+func (p InvalidPrevAtxProof) Valid(prevATX types.ATXID, nodeID types.NodeID, malValidator MalfeasanceValidator) error {
+	if !malValidator.Signature(signing.ATX, p.SmesherID, p.ATXID.Bytes(), p.Signature) {
+		return errors.New("invalid ATX signature")
+	}
+
+	if nodeID != p.SmesherID && p.MarriageProof == nil {
+		return errors.New("missing marriage proof")
+	}
+
+	if !p.NIPostsRootProof.Valid(p.ATXID, p.NIPostsRoot) {
+		return errors.New("invalid NIPosts root proof")
+	}
+	if !p.NIPostRootProof.Valid(p.NIPostsRoot, int(p.NIPostIndex), p.NIPostRoot) {
+		return errors.New("invalid NIPoST root proof")
+	}
+	if !p.SubPostsRootProof.Valid(p.NIPostRoot, p.SubPostsRoot) {
+		return errors.New("invalid sub PoSTs root proof")
+	}
+	if !p.SubPostRootProof.Valid(p.SubPostsRoot, int(p.SubPostRootIndex), p.SubPostRoot) {
+		return errors.New("invalid sub PoST root proof")
+	}
+
+	var marriageIndex *uint32
+	if p.MarriageProof != nil {
+		if err := p.MarriageProof.Valid(malValidator, p.ATXID, nodeID, p.SmesherID); err != nil {
+			return fmt.Errorf("invalid marriage proof: %w", err)
+		}
+		marriageIndex = &p.MarriageProof.NodeIDMarryProof.CertificateIndex
+	}
+	if marriageIndex != nil {
+		if !p.MarriageIndexProof.Valid(p.SubPostRoot, *marriageIndex) {
+			return errors.New("invalid marriage index proof")
+		}
+	}
+
+	if !p.PrevATXProof.Valid(p.SubPostRoot, prevATX) {
+		return errors.New("invalid previous ATX proof")
+	}
+
+	return nil
+}
diff --git a/activation/wire/malfeasance_invalid_prev_atx_scale.go b/activation/wire/malfeasance_invalid_prev_atx_scale.go
new file mode 100644
index 0000000000..15b06acc28
--- /dev/null
+++ b/activation/wire/malfeasance_invalid_prev_atx_scale.go
@@ -0,0 +1,364 @@
+// Code generated by
github.com/spacemeshos/go-scale/scalegen. DO NOT EDIT. + +// nolint +package wire + +import ( + "github.com/spacemeshos/go-scale" + "github.com/spacemeshos/go-spacemesh/common/types" +) + +func (t *ProofInvalidPrevAtxV2) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.NodeID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.PrevATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructArray(enc, t.Proofs[:]) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *ProofInvalidPrevAtxV2) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.NodeID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.PrevATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeStructArray(dec, t.Proofs[:]) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *ProofInvalidPrevAtxV1) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.NodeID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.PrevATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.Proof.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.ATXv1.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *ProofInvalidPrevAtxV1) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.NodeID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.PrevATX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.Proof.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + { + n, err := t.ATXv1.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *InvalidPrevAtxProof) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.ATXID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.SmesherID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.Signature[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.NIPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.NIPostsRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.NIPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.NIPostRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact16(enc, uint16(t.NIPostIndex)) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.SubPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.SubPostsRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err 
:= scale.EncodeByteArray(enc, t.SubPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.SubPostRootProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact16(enc, uint16(t.SubPostRootIndex)) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeOption(enc, t.MarriageProof) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.MarriageIndexProof, 32) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.PrevATXProof, 32) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *InvalidPrevAtxProof) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.ATXID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.SmesherID[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.Signature[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.NIPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.NIPostsRootProof = field + } + { + n, err := scale.DecodeByteArray(dec, t.NIPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.NIPostRootProof = field + } + { + field, n, err := scale.DecodeCompact16(dec) + if err != nil { + return total, err + } + total += n + t.NIPostIndex = uint16(field) + } + { + n, err := scale.DecodeByteArray(dec, t.SubPostsRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.SubPostsRootProof = field + } + { + n, err := scale.DecodeByteArray(dec, t.SubPostRoot[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.SubPostRootProof = field + } + { + field, n, err := scale.DecodeCompact16(dec) + if err != nil { + return total, err + } + total += n + t.SubPostRootIndex = uint16(field) + } + { + field, n, err := scale.DecodeOption[MarriageProof](dec) + if err != nil { + return total, err + } + total += n + t.MarriageProof = field + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.MarriageIndexProof = field + } + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 32) + if err != nil { + return total, err + } + total += n + t.PrevATXProof = field + } + return total, nil +} diff --git a/activation/wire/malfeasance_invalid_prev_atx_test.go b/activation/wire/malfeasance_invalid_prev_atx_test.go new file mode 100644 index 0000000000..1c829e74eb --- /dev/null +++ b/activation/wire/malfeasance_invalid_prev_atx_test.go @@ -0,0 +1,974 @@ +package wire + +import ( + "context" + "fmt" + "slices" + "testing" + + "github.com/stretchr/testify/require" + 
"go.uber.org/mock/gomock" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +func Test_InvalidPrevAtxProofV2(t *testing.T) { + t.Parallel() + + // sig is the identity that creates the ATXs referencing the same prevATX + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + // pubSig is the identity that publishes a merged ATX with the same prevATX + pubSig, err := signing.NewEdSigner() + require.NoError(t, err) + + // marrySig is the identity that publishes the marriage ATX + marrySig, err := signing.NewEdSigner() + require.NoError(t, err) + + edVerifier := signing.NewEdVerifier() + + newMergedATXv2 := func( + db sql.Executor, + prevATX types.ATXID, + ) *ActivationTxV2 { + wInitialAtx := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wInitialAtx.Sign(sig) + initialAtx := &types.ActivationTx{ + CommitmentATX: &wInitialAtx.Initial.CommitmentATX, + } + initialAtx.SetID(wInitialAtx.ID()) + initialAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, initialAtx, wInitialAtx.Blob())) + + wPubInitialAtx := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wPubInitialAtx.Sign(pubSig) + pubInitialAtx := &types.ActivationTx{} + pubInitialAtx.SetID(wPubInitialAtx.ID()) + pubInitialAtx.SmesherID = pubSig.NodeID() + require.NoError(t, atxs.Add(db, pubInitialAtx, wPubInitialAtx.Blob())) + + marryInitialAtx := types.RandomATXID() + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(marrySig, types.EmptyATXID, marrySig.NodeID()), + withMarriageCertificate(sig, wInitialAtx.ID(), marrySig.NodeID()), + withMarriageCertificate(pubSig, wPubInitialAtx.ID(), marrySig.NodeID()), + ) + wMarriageAtx.Sign(marrySig) + + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = marrySig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withPreviousATXs(marryInitialAtx, wPubInitialAtx.ID(), prevATX), + withMarriageATX(wMarriageAtx.ID()), + withNIPost( + withNIPostMembershipProof(MerkleProofV2{}), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 0, + PrevATXIndex: 0, + }), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 2, + }), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 2, + PrevATXIndex: 1, + }), + ), + ) + atx.Sign(pubSig) + return atx + } + + t.Run("valid", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATXID := types.RandomATXID() + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(7), + ) + atx2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // verify the proof + id, err := proof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), id) + }) + + t.Run("valid merged & solo atx", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATXID := types.RandomATXID() + prevAtx := &types.ActivationTx{} + prevAtx.SetID(prevATXID) + prevAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, prevAtx, types.AtxBlob{})) + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newMergedATXv2(db, prevATXID) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // verify the proof + id, err := proof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), id) + }) + + // valid merged & merged is covered by either double marry or double merge proofs + + t.Run("same ATX ID", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + atx1 := newActivationTxV2( + withPreviousATXs(types.RandomATXID()), + ) + atx1.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx1, sig.NodeID()) + require.ErrorContains(t, err, "ATXs have the same ID") + require.Nil(t, proof) + }) + + t.Run("smesher ID mismatch", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atx1 := newActivationTxV2( + withPreviousATXs(prevATX), + ) + atx1.Sign(sig) + atx2 := newActivationTxV2( + withPreviousATXs(prevATX), + ) + atx2.Sign(pubSig) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.EqualError(t, err, "ATX2 is not a merged ATX, but NodeID is different from SmesherID") + require.Nil(t, proof) + + proof, err = NewInvalidPrevAtxProofV2(db, atx1, atx2, pubSig.NodeID()) + require.EqualError(t, err, "ATX1 is not a merged ATX, but NodeID is different from SmesherID") + require.Nil(t, proof) + }) + + t.Run("id not married to smesher", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + invalidSig, err := signing.NewEdSigner() + require.NoError(t, err) + + prevATXID := types.RandomATXID() + prevAtx := &types.ActivationTx{} + prevAtx.SetID(prevATXID) + prevAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, prevAtx, types.AtxBlob{})) + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(invalidSig) + atx2 := newMergedATXv2(db, prevATXID) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, invalidSig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", invalidSig.NodeID().ShortString()), + ) + require.Nil(t, proof) + + proof, err = NewInvalidPrevAtxProofV2(db, atx2, atx1, invalidSig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", invalidSig.NodeID().ShortString()), + ) + require.Nil(t, proof) + }) + + t.Run("merged ATX does not contain post from identity", func(t 
*testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATXID := types.RandomATXID() + prevAtx := &types.ActivationTx{} + prevAtx.SetID(prevATXID) + prevAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, prevAtx, types.AtxBlob{})) + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newMergedATXv2(db, prevATXID) + + // remove the post from sig in the merged ATX + atx2.NIPosts[0].Posts = slices.DeleteFunc(atx2.NIPosts[0].Posts, func(subPost SubPostV2) bool { + return subPost.MarriageIndex == 1 + }) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("no PoST from %s in ATX", sig.NodeID().ShortString()), + ) + require.Nil(t, proof) + + proof, err = NewInvalidPrevAtxProofV2(db, atx2, atx1, sig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("no PoST from %s in ATX", sig.NodeID().ShortString()), + ) + require.Nil(t, proof) + }) + + t.Run("prev ATX differs between ATXs", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + atx1 := newActivationTxV2( + withPreviousATXs(types.RandomATXID()), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newActivationTxV2( + withPreviousATXs(types.RandomATXID()), + withPublishEpoch(7), + ) + atx2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.ErrorContains(t, err, "ATXs reference different previous ATXs") + require.Nil(t, proof) + }) + + t.Run("invalid solo proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATXID := types.RandomATXID() + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(7), + ) + atx2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // same ATX ID + proof.Proofs[0].ATXID = atx2.ID() + id, err := proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proofs have the same ATX ID") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].ATXID = atx1.ID() + + // invalid prev ATX + proof.PrevATX = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.PrevATX = prevATXID + + // invalid node ID + proof.NodeID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "missing marriage proof") + require.Equal(t, types.EmptyNodeID, id) + proof.NodeID = sig.NodeID() + + // invalid ATX ID + proof.Proofs[0].ATXID = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].ATXID = atx1.ID() + + proof.Proofs[1].ATXID = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].ATXID = atx2.ID() + + // invalid SmesherID + proof.Proofs[0].SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SmesherID = sig.NodeID() + + proof.Proofs[1].SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SmesherID = sig.NodeID() + + // invalid signature + proof.Proofs[0].Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].Signature = atx1.Signature + + proof.Proofs[1].Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].Signature = atx2.Signature + + // invalid NIPosts root + nipostsRoot := proof.Proofs[0].NIPostsRoot + proof.Proofs[0].NIPostsRoot = NIPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid NIPosts root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].NIPostsRoot = nipostsRoot + + nipostsRoot = proof.Proofs[1].NIPostsRoot + proof.Proofs[1].NIPostsRoot = NIPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid NIPosts root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].NIPostsRoot = nipostsRoot + + // invalid NIPosts root proof + hash := proof.Proofs[0].NIPostsRootProof[0] + proof.Proofs[0].NIPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid NIPosts 
root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].NIPostsRootProof[0] = hash + + hash = proof.Proofs[1].NIPostsRootProof[0] + proof.Proofs[1].NIPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid NIPosts root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].NIPostsRootProof[0] = hash + + // invalid NIPost root + nipostRoot := proof.Proofs[0].NIPostRoot + proof.Proofs[0].NIPostRoot = NIPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].NIPostRoot = nipostRoot + + nipostRoot = proof.Proofs[1].NIPostRoot + proof.Proofs[1].NIPostRoot = NIPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].NIPostRoot = nipostRoot + + // invalid NIPost root proof + hash = proof.Proofs[0].NIPostRootProof[0] + proof.Proofs[0].NIPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].NIPostRootProof[0] = hash + + hash = proof.Proofs[1].NIPostRootProof[0] + proof.Proofs[1].NIPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].NIPostRootProof[0] = hash + + // invalid NIPost index + proof.Proofs[0].NIPostIndex++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].NIPostIndex-- + + proof.Proofs[1].NIPostIndex++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid NIPoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].NIPostIndex-- + + // invalid sub posts root + subPostsRoot := proof.Proofs[0].SubPostsRoot + proof.Proofs[0].SubPostsRoot = SubPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SubPostsRoot = subPostsRoot + + subPostsRoot = proof.Proofs[1].SubPostsRoot + proof.Proofs[1].SubPostsRoot = SubPostsRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SubPostsRoot = subPostsRoot + + // invalid sub posts root proof + hash = proof.Proofs[0].SubPostsRootProof[0] + proof.Proofs[0].SubPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SubPostsRootProof[0] = hash + + hash = proof.Proofs[1].SubPostsRootProof[0] + proof.Proofs[1].SubPostsRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), 
verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid sub PoSTs root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SubPostsRootProof[0] = hash + + // invalid sub post root + subPostRoot := proof.Proofs[0].SubPostRoot + proof.Proofs[0].SubPostRoot = SubPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SubPostRoot = subPostRoot + + subPostRoot = proof.Proofs[1].SubPostRoot + proof.Proofs[1].SubPostRoot = SubPostRoot(types.RandomHash()) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SubPostRoot = subPostRoot + + // invalid sub post root proof + hash = proof.Proofs[0].SubPostRootProof[0] + proof.Proofs[0].SubPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SubPostRootProof[0] = hash + + hash = proof.Proofs[1].SubPostRootProof[0] + proof.Proofs[1].SubPostRootProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SubPostRootProof[0] = hash + + // invalid sub post index + proof.Proofs[0].SubPostRootIndex++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SubPostRootIndex-- + + proof.Proofs[1].SubPostRootIndex++ + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid sub PoST root proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SubPostRootIndex-- + + // invalid prev atx proof + hash = proof.Proofs[0].PrevATXProof[0] + proof.Proofs[0].PrevATXProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].PrevATXProof[0] = hash + + hash = proof.Proofs[1].PrevATXProof[0] + proof.Proofs[1].PrevATXProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].PrevATXProof[0] = hash + }) + + t.Run("invalid merged proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATXID := types.RandomATXID() + prevAtx := &types.ActivationTx{} + prevAtx.SetID(prevATXID) + prevAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, prevAtx, types.AtxBlob{})) + atx1 := newActivationTxV2( + withPreviousATXs(prevATXID), + withPublishEpoch(5), + ) + atx1.Sign(sig) + atx2 := newMergedATXv2(db, prevATXID) + + proof, err := NewInvalidPrevAtxProofV2(db, atx1, atx2, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // invalid node ID + proof.NodeID = types.RandomNodeID() + id, err := proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "missing marriage proof") + require.Equal(t, types.EmptyNodeID, id) + proof.NodeID = sig.NodeID() + + // invalid ATX ID + proof.Proofs[0].ATXID = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].ATXID = atx1.ID() + + proof.Proofs[1].ATXID = types.RandomATXID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].ATXID = atx2.ID() + + // invalid SmesherID + proof.Proofs[0].SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].SmesherID = sig.NodeID() + + proof.Proofs[1].SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].SmesherID = pubSig.NodeID() + + // invalid signature + proof.Proofs[0].Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].Signature = atx1.Signature + + proof.Proofs[1].Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].Signature = atx2.Signature + + // missing marriage proof + marriageProof := proof.Proofs[1].MarriageProof + proof.Proofs[1].MarriageProof = nil + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "missing marriage proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].MarriageProof = marriageProof + + // invalid marriage index proof + hash := proof.Proofs[1].MarriageIndexProof[0] + proof.Proofs[1].MarriageIndexProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid marriage index proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].MarriageIndexProof[0] = hash + + // invalid prev atx proof + hash = proof.Proofs[0].PrevATXProof[0] + proof.Proofs[0].PrevATXProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 1 is invalid: invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[0].PrevATXProof[0] = hash + + hash = proof.Proofs[1].PrevATXProof[0] + proof.Proofs[1].PrevATXProof[0] = types.RandomHash() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "proof 2 is invalid: invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.Proofs[1].PrevATXProof[0] = hash + }) +} + +func Test_InvalidPrevAtxProofV1(t *testing.T) { + t.Parallel() + + // sig is the identity that creates the 
ATXs referencing the same prevATX + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + // pubSig is the identity that publishes a merged ATX with the same prevATX + pubSig, err := signing.NewEdSigner() + require.NoError(t, err) + + // marrySig is the identity that publishes the marriage ATX + marrySig, err := signing.NewEdSigner() + require.NoError(t, err) + + edVerifier := signing.NewEdVerifier() + + newMergedATXv2 := func( + db sql.Executor, + prevATX types.ATXID, + ) *ActivationTxV2 { + wInitialAtx := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wInitialAtx.Sign(sig) + initialAtx := &types.ActivationTx{ + CommitmentATX: &wInitialAtx.Initial.CommitmentATX, + } + initialAtx.SetID(wInitialAtx.ID()) + initialAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, initialAtx, wInitialAtx.Blob())) + + wPubInitialAtx := newActivationTxV2( + withInitial(types.RandomATXID(), PostV1{}), + ) + wPubInitialAtx.Sign(pubSig) + pubInitialAtx := &types.ActivationTx{} + pubInitialAtx.SetID(wPubInitialAtx.ID()) + pubInitialAtx.SmesherID = pubSig.NodeID() + require.NoError(t, atxs.Add(db, pubInitialAtx, wPubInitialAtx.Blob())) + + marryInitialAtx := types.RandomATXID() + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(marrySig, types.EmptyATXID, marrySig.NodeID()), + withMarriageCertificate(sig, wInitialAtx.ID(), marrySig.NodeID()), + withMarriageCertificate(pubSig, wPubInitialAtx.ID(), marrySig.NodeID()), + ) + wMarriageAtx.Sign(marrySig) + + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = marrySig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withPreviousATXs(marryInitialAtx, wPubInitialAtx.ID(), prevATX), + withMarriageATX(wMarriageAtx.ID()), + withNIPost( + withNIPostMembershipProof(MerkleProofV2{}), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 0, + PrevATXIndex: 0, + }), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 2, + }), + withNIPostSubPost(SubPostV2{ + MarriageIndex: 2, + PrevATXIndex: 1, + }), + ), + ) + atx.Sign(pubSig) + return atx + } + + t.Run("valid", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newActivationTxV2( + withPreviousATXs(prevATX), + withPublishEpoch(7), + ) + atxv2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // verify the proof + id, err := proof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), id) + }) + + t.Run("valid merged", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newMergedATXv2(db, prevATX) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // verify the proof + id, err := proof.Valid(context.Background(), verifier) + require.NoError(t, err) + require.Equal(t, sig.NodeID(), id) + }) + + t.Run("smesher ID mismatch", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newActivationTxV2( + withPreviousATXs(prevATX), + withPublishEpoch(7), + ) + atxv2.Sign(pubSig) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, pubSig.NodeID()) + require.EqualError(t, err, "ATX2 is not signed by NodeID") + require.Nil(t, proof) + + proof, err = NewInvalidPrevAtxProofV1(db, atxv2, atxv1, sig.NodeID()) + require.EqualError(t, err, "ATX1 is not a merged ATX, but NodeID is different from SmesherID") + require.Nil(t, proof) + }) + + t.Run("id not married to smesher", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + invalidSig, err := signing.NewEdSigner() + require.NoError(t, err) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(invalidSig) + + atxv2 := newMergedATXv2(db, prevATX) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, invalidSig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", invalidSig.NodeID().ShortString()), + ) + require.Nil(t, proof) + }) + + t.Run("merged ATX does not contain post from identity", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newMergedATXv2(db, prevATX) + + // remove the post from sig in the merged ATX + atxv2.NIPosts[0].Posts = slices.DeleteFunc(atxv2.NIPosts[0].Posts, func(subPost SubPostV2) bool { + return subPost.MarriageIndex == 1 + }) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, 
sig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("no PoST from %s in ATX", sig.NodeID().ShortString()), + ) + require.Nil(t, proof) + }) + + t.Run("prev ATX differs between ATXs", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newActivationTxV2( + withPreviousATXs(types.RandomATXID()), + withPublishEpoch(7), + ) + atxv2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, sig.NodeID()) + require.ErrorContains(t, err, "ATXs reference different previous ATXs") + require.Nil(t, proof) + }) + + t.Run("invalid proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + prevATX := types.RandomATXID() + atxv1 := &ActivationTxV1{ + InnerActivationTxV1: InnerActivationTxV1{ + NIPostChallengeV1: NIPostChallengeV1{ + PublishEpoch: 5, + PrevATXID: prevATX, + PositioningATXID: types.RandomATXID(), + }, + }, + } + atxv1.Sign(sig) + + atxv2 := newActivationTxV2( + withPreviousATXs(prevATX), + withPublishEpoch(7), + ) + atxv2.Sign(sig) + + proof, err := NewInvalidPrevAtxProofV1(db, atxv2, atxv1, sig.NodeID()) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // invalid PrevATX + proof.PrevATX = types.RandomATXID() + id, err := proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid previous ATX proof") + require.Equal(t, types.EmptyNodeID, id) + proof.PrevATX = prevATX + + // invalid SmesherID for atxv1 + proof.ATXv1.SmesherID = types.RandomNodeID() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXv1.SmesherID = sig.NodeID() + + // invalid signature for atxv1 + proof.ATXv1.Signature = types.RandomEdSignature() + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "invalid ATX signature") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXv1.Signature = atxv1.Signature + + // signer of atxv1 does not match + proof.ATXv1.Sign(pubSig) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "ATXv1 has not been signed by the same identity") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXv1.Sign(sig) + + // prevATX of atxv1 does not match + proof.ATXv1.PrevATXID = types.RandomATXID() + proof.ATXv1.Sign(sig) + id, err = proof.Valid(context.Background(), verifier) + require.ErrorContains(t, err, "ATXv1 references a different previous ATX") + require.Equal(t, types.EmptyNodeID, id) + proof.ATXv1.PrevATXID = prevATX + proof.ATXv1.Sign(sig) + }) +} diff --git a/activation/wire/malfeasance_shared_test.go b/activation/wire/malfeasance_shared_test.go new file mode 100644 index 0000000000..46fbccea11 --- /dev/null +++ b/activation/wire/malfeasance_shared_test.go @@ -0,0 +1,362 @@ +package wire + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/spacemeshos/go-spacemesh/common/types" + 
"github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +func Test_MarryProof(t *testing.T) { + t.Parallel() + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + + edVerifier := signing.NewEdVerifier() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + atx1 := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + atx1.Sign(sig) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + // valid for otherSig + proof, err := createMarryProof(db, atx1, otherSig.NodeID()) + require.NoError(t, err) + require.NotEmpty(t, proof) + + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.NoError(t, err) + + // valid for sig + proof, err = createMarryProof(db, atx1, sig.NodeID()) + require.NoError(t, err) + require.NotEmpty(t, proof) + + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), sig.NodeID()) + require.NoError(t, err) + }) + + t.Run("identity not included in certificates", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + atx1 := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + atx1.Sign(sig) + + nodeID := types.RandomNodeID() + proof, err := createMarryProof(db, atx1, nodeID) + require.EqualError(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", nodeID.ShortString()), + ) + require.Empty(t, proof) + }) + + t.Run("invalid proof", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + atx1 := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + atx1.Sign(sig) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + proof, err := createMarryProof(db, atx1, otherSig.NodeID()) + require.NoError(t, err) + require.NotEmpty(t, proof) + + // not valid for random NodeID + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), types.RandomNodeID()) + require.EqualError(t, err, "invalid certificate signature") + + // not valid for another ATX + err = proof.Valid(verifier, types.RandomATXID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid marriage proof") + + // not valid if certificate signature is invalid + certSig := proof.Certificate.Signature + proof.Certificate.Signature = types.RandomEdSignature() + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid certificate signature") + proof.Certificate.Signature = certSig + + // not valid if marriage root is invalid + marriageRoot := proof.MarriageCertificatesRoot + proof.MarriageCertificatesRoot = MarriageCertificatesRoot(types.RandomHash()) + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid marriage proof") + proof.MarriageCertificatesRoot = marriageRoot + + // not valid if marriage root proof is invalid + hash := proof.MarriageCertificatesProof[0] + proof.MarriageCertificatesProof[0] = types.RandomHash() + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid marriage proof") + proof.MarriageCertificatesProof[0] = hash + + // not valid if certificate proof is invalid + index := proof.CertificateIndex + proof.CertificateIndex = 100 + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid certificate proof") + proof.CertificateIndex = index + + certProof := proof.CertificateProof + proof.CertificateProof = MarriageCertificateProof{types.RandomHash()} + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid certificate proof") + proof.CertificateProof = certProof + + hash = proof.CertificateProof[0] + proof.CertificateProof[0] = types.RandomHash() + err = proof.Valid(verifier, atx1.ID(), sig.NodeID(), otherSig.NodeID()) + require.EqualError(t, err, "invalid certificate proof") + proof.CertificateProof[0] = hash + }) +} + +func Test_MarriageProof(t *testing.T) { + t.Parallel() + + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + otherSig, err := signing.NewEdSigner() + require.NoError(t, err) + + edVerifier := signing.NewEdVerifier() + + t.Run("valid", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + wMarriageAtx.Sign(sig) + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withMarriageATX(wMarriageAtx.ID()), + ) + atx.Sign(sig) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), 
gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + proof, err := createMarriageProof(db, atx, otherSig.NodeID()) + require.NoError(t, err) + require.NotEmpty(t, proof) + + err = proof.Valid(verifier, atx.ID(), otherSig.NodeID(), sig.NodeID()) + require.NoError(t, err) + }) + + t.Run("node ID is the same as smesher ID", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + wMarriageAtx.Sign(sig) + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withMarriageATX(wMarriageAtx.ID()), + ) + atx.Sign(sig) + + proof, err := createMarriageProof(db, atx, sig.NodeID()) + require.EqualError(t, err, "node ID is the same as smesher ID") + require.Empty(t, proof) + }) + + t.Run("marriage ATX is not available", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + + atx := newActivationTxV2( + withMarriageATX(types.RandomATXID()), + ) + atx.Sign(sig) + + proof, err := createMarriageProof(db, atx, otherSig.NodeID()) + require.ErrorIs(t, err, sql.ErrNotFound) + require.Empty(t, proof) + }) + + t.Run("node ID isn't married in marriage ATX", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + wMarriageAtx.Sign(sig) + marriageAtx := &types.ActivationTx{} + marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withMarriageATX(wMarriageAtx.ID()), + ) + atx.Sign(sig) + + invalidSig, err := signing.NewEdSigner() + require.NoError(t, err) + + proof, err := createMarriageProof(db, atx, invalidSig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", invalidSig.NodeID().ShortString()), + ) + require.Empty(t, proof) + + atx.Sign(invalidSig) + proof, err = createMarriageProof(db, atx, otherSig.NodeID()) + require.ErrorContains(t, err, + fmt.Sprintf("does not contain a marriage certificate signed by %s", invalidSig.NodeID().ShortString()), + ) + require.Empty(t, proof) + }) + + t.Run("invalid proof", func(t *testing.T) { + t.Parallel() + + db := statesql.InMemoryTest(t) + otherAtx := &types.ActivationTx{} + otherAtx.SetID(types.RandomATXID()) + otherAtx.SmesherID = otherSig.NodeID() + require.NoError(t, atxs.Add(db, otherAtx, types.AtxBlob{})) + + wMarriageAtx := newActivationTxV2( + withMarriageCertificate(sig, types.EmptyATXID, sig.NodeID()), + withMarriageCertificate(otherSig, otherAtx.ID(), sig.NodeID()), + ) + wMarriageAtx.Sign(sig) + marriageAtx := &types.ActivationTx{} + 
marriageAtx.SetID(wMarriageAtx.ID()) + marriageAtx.SmesherID = sig.NodeID() + require.NoError(t, atxs.Add(db, marriageAtx, wMarriageAtx.Blob())) + + atx := newActivationTxV2( + withMarriageATX(wMarriageAtx.ID()), + ) + atx.Sign(sig) + + ctrl := gomock.NewController(t) + verifier := NewMockMalfeasanceValidator(ctrl) + verifier.EXPECT().Signature(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool { + return edVerifier.Verify(d, nodeID, m, sig) + }).AnyTimes() + + proof, err := createMarriageProof(db, atx, otherSig.NodeID()) + require.NoError(t, err) + require.NotEmpty(t, proof) + + // not valid for random ATX + err = proof.Valid(verifier, types.RandomATXID(), otherSig.NodeID(), sig.NodeID()) + require.EqualError(t, err, "invalid marriage ATX proof") + + // not valid for another smesher + err = proof.Valid(verifier, atx.ID(), otherSig.NodeID(), types.RandomNodeID()) + require.ErrorContains(t, err, "invalid certificate signature") + + // not valid for another nodeID + err = proof.Valid(verifier, atx.ID(), types.RandomNodeID(), sig.NodeID()) + require.ErrorContains(t, err, "invalid certificate signature") + + // not valid for incorrect marriage ATX + marriageATX := proof.MarriageATX + proof.MarriageATX = types.RandomATXID() + err = proof.Valid(verifier, atx.ID(), otherSig.NodeID(), sig.NodeID()) + require.EqualError(t, err, "invalid marriage ATX proof") + proof.MarriageATX = marriageATX + + // not valid for incorrect marriage ATX smesher ID + marriageATXSmesherID := proof.MarriageATXSmesherID + proof.MarriageATXSmesherID = types.RandomNodeID() + err = proof.Valid(verifier, atx.ID(), otherSig.NodeID(), sig.NodeID()) + require.ErrorContains(t, err, "invalid certificate signature") + proof.MarriageATXSmesherID = marriageATXSmesherID + }) +} diff --git a/activation/wire/wire_v1.go b/activation/wire/wire_v1.go index 200ae83fa1..14ea40aa60 100644 --- a/activation/wire/wire_v1.go +++ b/activation/wire/wire_v1.go @@ -1,10 +1,8 @@ package wire import ( - "encoding/binary" "encoding/hex" - "github.com/spacemeshos/merkle-tree" "go.uber.org/zap/zapcore" "github.com/spacemeshos/go-spacemesh/codec" @@ -44,24 +42,14 @@ type PostV1 struct { Pow uint64 } -func (p *PostV1) merkleTree(tree *merkle.Tree) { - var nonce types.Hash32 - binary.LittleEndian.PutUint32(nonce[:], p.Nonce) - tree.AddLeaf(nonce.Bytes()) - - hasher := hash.GetHasher() - defer hash.PutHasher(hasher) - tree.AddLeaf(hasher.Sum(p.Indices)) - - var pow types.Hash32 - binary.LittleEndian.PutUint64(pow[:], p.Pow) - tree.AddLeaf(pow.Bytes()) -} - type PostRoot types.Hash32 -func (p *PostV1) Root() PostRoot { - return PostRoot(createRoot(p.merkleTree)) +func (p *PostV1) Root() (result PostRoot) { + h := hash.GetHasher() + defer hash.PutHasher(h) + codec.MustEncodeTo(h, p) + h.Sum(result[:0]) + return result } type MerkleProofV1 struct { diff --git a/activation/wire/wire_v1_test.go b/activation/wire/wire_v1_test.go index 94563790d7..adca8f7f55 100644 --- a/activation/wire/wire_v1_test.go +++ b/activation/wire/wire_v1_test.go @@ -6,13 +6,13 @@ import ( fuzz "github.com/google/gofuzz" "github.com/stretchr/testify/require" + "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" ) func Benchmark_ATXv1ID(b *testing.B) { f := fuzz.New() b.ResetTimer() - for i := 0; i < b.N; i++ { b.StopTimer() atx := &ActivationTxV1{} @@ -34,3 +34,16 @@ func Test_NoATXv1IDCollisions(t *testing.T) { atxIDs = append(atxIDs, id) } } + 
+func Fuzz_ATXv1IDConsistency(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + fuzzer := fuzz.NewFromGoFuzz(data) + atx := &ActivationTxV1{} + fuzzer.Fuzz(atx) + id := atx.ID() + encoded := codec.MustEncode(atx) + decoded := &ActivationTxV1{} + codec.MustDecode(encoded, decoded) + require.Equal(t, id, decoded.ID(), "ID should be consistent") + }) +} diff --git a/activation/wire/wire_v2.go b/activation/wire/wire_v2.go index 08a1390f5d..696c4b332d 100644 --- a/activation/wire/wire_v2.go +++ b/activation/wire/wire_v2.go @@ -129,7 +129,7 @@ func (atx *ActivationTxV2) merkleTree(tree *merkle.Tree) { tree.AddLeaf(types.EmptyHash32.Bytes()) } - tree.AddLeaf(atx.PreviousATXs.Root().Bytes()) + tree.AddLeaf(types.Hash32(atx.PreviousATXs.Root()).Bytes()) tree.AddLeaf(types.Hash32(atx.NIPosts.Root(atx.PreviousATXs)).Bytes()) var vrfNonce types.Hash32 @@ -158,10 +158,18 @@ func (atx *ActivationTxV2) ID() types.ATXID { return atx.id } -func (atx *ActivationTxV2) PublishEpochProof() []types.Hash32 { +func (atx *ActivationTxV2) PublishEpochProof() PublishEpochProof { return atx.merkleProof(PublishEpochIndex) } +type PublishEpochProof []types.Hash32 + +func (p PublishEpochProof) Valid(atxID types.ATXID, publishEpoch types.EpochID) bool { + var publishEpochBytes types.Hash32 + binary.LittleEndian.PutUint32(publishEpochBytes[:], publishEpoch.Uint32()) + return validateProof(types.Hash32(atxID), publishEpochBytes, p, uint64(PublishEpochIndex)) +} + func (atx *ActivationTxV2) PositioningATXProof() []types.Hash32 { return atx.merkleProof(PositioningATXIndex) } @@ -180,10 +188,16 @@ func (p InitialPostRootProof) Valid(atxID types.ATXID, initialPostRoot InitialPo return validateProof(types.Hash32(atxID), types.Hash32(initialPostRoot), p, uint64(InitialPostRootIndex)) } -func (atx *ActivationTxV2) PreviousATXsRootProof() []types.Hash32 { +func (atx *ActivationTxV2) PreviousATXsRootProof() PrevATXsRootProof { return atx.merkleProof(PreviousATXsRootIndex) } +type PrevATXsRootProof []types.Hash32 + +func (p PrevATXsRootProof) Valid(atxID types.ATXID, prevATXsRoot PrevATXsRoot) bool { + return validateProof(types.Hash32(atxID), types.Hash32(prevATXsRoot), p, uint64(PreviousATXsRootIndex)) +} + func (atx *ActivationTxV2) NIPostsRootProof() NIPostsRootProof { return atx.merkleProof(NIPostsRootIndex) } @@ -272,8 +286,23 @@ func (prevATXs PrevATXs) merkleTree(tree *merkle.Tree) { } } -func (prevATXs PrevATXs) Root() types.Hash32 { - return createRoot(prevATXs.merkleTree) +type PrevATXsRoot types.Hash32 + +func (prevATXs PrevATXs) Root() PrevATXsRoot { + return PrevATXsRoot(createRoot(prevATXs.merkleTree)) +} + +func (prevATXs PrevATXs) Proof(index int) PrevATXsProof { + if index < 0 || index >= len(prevATXs) { + panic("index out of range") + } + return createProof(uint64(index), prevATXs.merkleTree) +} + +type PrevATXsProof []types.Hash32 + +func (p PrevATXsProof) Valid(prevATXsRoot PrevATXsRoot, index int, prevATX types.ATXID) bool { + return validateProof(types.Hash32(prevATXsRoot), types.Hash32(prevATX), p, uint64(index)) } type NIPosts []NIPostV2 @@ -385,14 +414,12 @@ type MerkleProofV2 struct { Nodes []types.Hash32 `scale:"max=32"` } -func (mp MerkleProofV2) Root() types.Hash32 { - hasher := hash.GetHasher() - defer hash.PutHasher(hasher) - hasher.Write([]byte{0x01}) - for _, node := range mp.Nodes { - hasher.Write(node.Bytes()) - } - return types.Hash32(hasher.Sum(nil)) +func (mp *MerkleProofV2) Root() (result types.Hash32) { + h := hash.GetHasher() + defer hash.PutHasher(h) + codec.MustEncodeTo(h, mp) +
h.Sum(result[:0]) + return result } type SubPostsV2 []SubPostV2 @@ -459,21 +486,12 @@ func (post *SubPostV2) MarshalLogObject(encoder zapcore.ObjectEncoder) error { return nil } -func (sp *SubPostV2) merkleTree(tree *merkle.Tree, prevATXs []types.ATXID) { +func (sp *SubPostV2) merkleTree(tree *merkle.Tree, prevATX types.ATXID) { var marriageIndex types.Hash32 binary.LittleEndian.PutUint32(marriageIndex[:], sp.MarriageIndex) tree.AddLeaf(marriageIndex.Bytes()) - switch { - case len(prevATXs) == 0: // special case for initial ATX: prevATXs is empty - tree.AddLeaf(types.EmptyATXID.Bytes()) - case int(sp.PrevATXIndex) < len(prevATXs): - tree.AddLeaf(prevATXs[sp.PrevATXIndex].Bytes()) - default: - // prevATXIndex is out of range, don't fail ATXID generation - // will be detected by syntactical validation - tree.AddLeaf(types.EmptyATXID.Bytes()) - } + tree.AddLeaf(prevATX.Bytes()) var leafIndex types.Hash32 binary.LittleEndian.PutUint64(leafIndex[:], sp.MembershipLeafIndex) @@ -488,7 +506,17 @@ func (sp *SubPostV2) merkleTree(tree *merkle.Tree, prevATXs []types.ATXID) { func (sp *SubPostV2) merkleProof(leafIndex SubPostTreeIndex, prevATXs []types.ATXID) []types.Hash32 { return createProof(uint64(leafIndex), func(tree *merkle.Tree) { - sp.merkleTree(tree, prevATXs) + var prevATX types.ATXID + switch { + case len(prevATXs) == 0: // special case for initial ATX: prevATXs is empty + prevATX = types.EmptyATXID + case int(sp.PrevATXIndex) < len(prevATXs): + prevATX = prevATXs[sp.PrevATXIndex] + default: + // not the full set of prevATXs is provided, proof cannot be generated + panic("prevATXIndex out of range or prevATXs incomplete") + } + sp.merkleTree(tree, prevATX) }) } @@ -496,7 +524,18 @@ type SubPostRoot types.Hash32 func (sp *SubPostV2) Root(prevATXs []types.ATXID) SubPostRoot { return SubPostRoot(createRoot(func(tree *merkle.Tree) { - sp.merkleTree(tree, prevATXs) + var prevATX types.ATXID + switch { + case len(prevATXs) == 0: // special case for initial ATX: prevATXs is empty + prevATX = types.EmptyATXID + case int(sp.PrevATXIndex) < len(prevATXs): + prevATX = prevATXs[sp.PrevATXIndex] + default: + // prevATXIndex is out of range, don't fail ATXID generation + // will be detected by syntactical validation + prevATX = types.EmptyATXID + } + sp.merkleTree(tree, prevATX) })) } @@ -516,6 +555,18 @@ func (sp *SubPostV2) PrevATXIndexProof(prevATXs []types.ATXID) []types.Hash32 { return sp.merkleProof(PrevATXIndex, prevATXs) } +func (sp *SubPostV2) PrevATXProof(prevATX types.ATXID) PrevATXProof { + return createProof(uint64(SubPostTreeIndex(PrevATXIndex)), func(tree *merkle.Tree) { + sp.merkleTree(tree, prevATX) + }) +} + +type PrevATXProof []types.Hash32 + +func (p PrevATXProof) Valid(subPostRoot SubPostRoot, prevATX types.ATXID) bool { + return validateProof(types.Hash32(subPostRoot), types.Hash32(prevATX), p, uint64(PrevATXIndex)) +} + func (sp *SubPostV2) MembershipLeafIndexProof(prevATXs []types.ATXID) []types.Hash32 { return sp.merkleProof(MembershipLeafIndex, prevATXs) } @@ -595,34 +646,21 @@ func (mc *MarriageCertificate) MarshalLogObject(encoder zapcore.ObjectEncoder) e return nil } -func (mc *MarriageCertificate) merkleTree(tree *merkle.Tree) { - tree.AddLeaf(mc.ReferenceAtx.Bytes()) - tree.AddLeaf(mc.Signature.Bytes()) -} - -func (mc *MarriageCertificate) merkleProof(leafIndex MarriageCertificateIndex) []types.Hash32 { - return createProof(uint64(leafIndex), mc.merkleTree) -} - -func (mc *MarriageCertificate) Root() types.Hash32 { - return createRoot(mc.merkleTree) -} - -func (mc 
*MarriageCertificate) ReferenceATXProof() []types.Hash32 { - return mc.merkleProof(ReferenceATXIndex) -} - -func (mc *MarriageCertificate) SignatureProof() []types.Hash32 { - return mc.merkleProof(SignatureIndex) +func (mc *MarriageCertificate) Root() (result types.Hash32) { + h := hash.GetHasher() + defer hash.PutHasher(h) + codec.MustEncodeTo(h, mc) + h.Sum(result[:0]) + return result } func atxTreeHash(buf, lChild, rChild []byte) []byte { - hasher := hash.GetHasher() - defer hash.PutHasher(hasher) - hasher.Write([]byte{0x01}) - hasher.Write(lChild) - hasher.Write(rChild) - return hasher.Sum(buf) + h := hash.GetHasher() + defer hash.PutHasher(h) + h.Write([]byte{0x01}) + h.Write(lChild) + h.Write(rChild) + return h.Sum(buf) } func createRoot(addLeaves func(tree *merkle.Tree)) types.Hash32 { diff --git a/activation/wire/wire_v2_test.go b/activation/wire/wire_v2_test.go index 188e19a130..f71a6fba96 100644 --- a/activation/wire/wire_v2_test.go +++ b/activation/wire/wire_v2_test.go @@ -31,6 +31,12 @@ func withMarriageATX(id types.ATXID) testAtxV2Opt { } } +func withPublishEpoch(epoch types.EpochID) testAtxV2Opt { + return func(atx *ActivationTxV2) { + atx.PublishEpoch = epoch + } +} + func withInitial(commitAtx types.ATXID, post PostV1) testAtxV2Opt { return func(atx *ActivationTxV2) { atx.Initial = &InitialAtxPartsV2{ @@ -114,7 +120,6 @@ func newActivationTxV2(opts ...testAtxV2Opt) *ActivationTxV2 { func Benchmark_ATXv2ID(b *testing.B) { f := fuzz.New() b.ResetTimer() - for i := 0; i < b.N; i++ { b.StopTimer() atx := &ActivationTxV2{} @@ -125,49 +130,59 @@ func Benchmark_ATXv2ID(b *testing.B) { } func Benchmark_ATXv2ID_WorstScenario(b *testing.B) { - b.ResetTimer() - - for i := 0; i < b.N; i++ { - b.StopTimer() - atx := &ActivationTxV2{ - PublishEpoch: 0, - PositioningATX: types.RandomATXID(), - PreviousATXs: make([]types.ATXID, 256), - NIPosts: []NIPostV2{ - { - Membership: MerkleProofV2{ - Nodes: make([]types.Hash32, 32), - }, - Challenge: types.RandomHash(), - Posts: make([]SubPostV2, 256), + atx := &ActivationTxV2{ + PublishEpoch: 0, + PositioningATX: types.RandomATXID(), + PreviousATXs: make([]types.ATXID, 256), + NIPosts: []NIPostV2{ + { + Membership: MerkleProofV2{ + Nodes: make([]types.Hash32, 32), }, - { - Membership: MerkleProofV2{ - Nodes: make([]types.Hash32, 32), - }, - Challenge: types.RandomHash(), - Posts: make([]SubPostV2, 256), // actually the sum of all posts in `NiPosts` should be 256 + Challenge: types.RandomHash(), + Posts: make([]SubPostV2, 256), + }, + { + Membership: MerkleProofV2{ + Nodes: make([]types.Hash32, 32), }, + Challenge: types.RandomHash(), + Posts: make([]SubPostV2, 256), // actually the sum of all posts in `NiPosts` should be 256 }, - } - for i := range atx.NIPosts[0].Posts { - atx.NIPosts[0].Posts[i].Post = PostV1{ - Nonce: 0, - Indices: make([]byte, 800), - Pow: 0, - } - } - for i := range atx.NIPosts[1].Posts { - atx.NIPosts[1].Posts[i].Post = PostV1{ + { + Membership: MerkleProofV2{ + Nodes: make([]types.Hash32, 32), + }, + Challenge: types.RandomHash(), + Posts: make([]SubPostV2, 256), // actually the sum of all posts in `NiPosts` should be 256 + }, + { + Membership: MerkleProofV2{ + Nodes: make([]types.Hash32, 32), + }, + Challenge: types.RandomHash(), + Posts: make([]SubPostV2, 256), // actually the sum of all posts in `NiPosts` should be 256 + }, + }, + } + for j := range atx.NIPosts { + for i := range atx.NIPosts[j].Posts { + atx.NIPosts[j].Posts[i].Post = PostV1{ Nonce: 0, Indices: make([]byte, 800), Pow: 0, } } - atx.MarriageATX = 
new(types.ATXID) - b.StartTimer() - atx.ID() } + atx.MarriageATX = new(types.ATXID) + + var id types.ATXID + b.ResetTimer() + for i := 0; i < b.N; i++ { + atx.id = types.EmptyATXID + id = atx.ID() + } + require.Equal(b, id, atx.ID()) } func Test_NoATXv2IDCollisions(t *testing.T) { @@ -183,6 +198,26 @@ func Test_NoATXv2IDCollisions(t *testing.T) { } } +func Fuzz_ATXv2IDConsistency(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + fuzzer := fuzz.NewFromGoFuzz(data). + // Ensure that `NIPosts` is at most 4 elements long + Funcs(func(niposts *NIPosts, c fuzz.Continue) { + *niposts = make([]NIPostV2, c.Intn(5)) + for i := range *niposts { + c.Fuzz(&(*niposts)[i]) + } + }) + atx := &ActivationTxV2{} + fuzzer.Fuzz(atx) + id := atx.ID() + encoded := codec.MustEncode(atx) + decoded := &ActivationTxV2{} + codec.MustDecode(encoded, decoded) + require.Equal(t, id, decoded.ID(), "ID should be consistent") + }) +} + func Test_ATXv2_SupportUpTo4Niposts(t *testing.T) { f := fuzz.New() atx := &ActivationTxV2{} diff --git a/api/grpcserver/grpcserver_test.go b/api/grpcserver/grpcserver_test.go index 53a16f2662..4b9e15c0bd 100644 --- a/api/grpcserver/grpcserver_test.go +++ b/api/grpcserver/grpcserver_test.go @@ -315,6 +315,10 @@ func (t *ConStateAPIMock) GetStateRoot() (types.Hash32, error) { return stateRoot, nil } +func (t *ConStateAPIMock) HasEvicted(id types.TransactionID) (bool, error) { + panic("not implemented") +} + func (t *ConStateAPIMock) GetMeshTransaction(id types.TransactionID) (*types.MeshTransaction, error) { tx, ok := t.returnTx[id] if ok { diff --git a/api/grpcserver/interface.go b/api/grpcserver/interface.go index 7b513b9c7e..bfa88836b9 100644 --- a/api/grpcserver/interface.go +++ b/api/grpcserver/interface.go @@ -40,6 +40,7 @@ type conservativeState interface { GetMeshTransactions([]types.TransactionID) ([]*types.MeshTransaction, map[types.TransactionID]struct{}) GetTransactionsByAddress(types.LayerID, types.LayerID, types.Address) ([]*types.MeshTransaction, error) Validation(raw types.RawTx) system.ValidationRequest + HasEvicted(tid types.TransactionID) (bool, error) } // syncer is the API to get sync status. diff --git a/api/grpcserver/mocks.go b/api/grpcserver/mocks.go index 920c5e8bfc..3867be40c6 100644 --- a/api/grpcserver/mocks.go +++ b/api/grpcserver/mocks.go @@ -691,6 +691,45 @@ func (c *MockconservativeStateGetTransactionsByAddressCall) DoAndReturn(f func(t return c } +// HasEvicted mocks base method. +func (m *MockconservativeState) HasEvicted(tid types.TransactionID) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasEvicted", tid) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HasEvicted indicates an expected call of HasEvicted.
+func (mr *MockconservativeStateMockRecorder) HasEvicted(tid any) *MockconservativeStateHasEvictedCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasEvicted", reflect.TypeOf((*MockconservativeState)(nil).HasEvicted), tid) + return &MockconservativeStateHasEvictedCall{Call: call} +} + +// MockconservativeStateHasEvictedCall wrap *gomock.Call +type MockconservativeStateHasEvictedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockconservativeStateHasEvictedCall) Return(arg0 bool, arg1 error) *MockconservativeStateHasEvictedCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockconservativeStateHasEvictedCall) Do(f func(types.TransactionID) (bool, error)) *MockconservativeStateHasEvictedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockconservativeStateHasEvictedCall) DoAndReturn(f func(types.TransactionID) (bool, error)) *MockconservativeStateHasEvictedCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Validation mocks base method. func (m *MockconservativeState) Validation(raw types.RawTx) system.ValidationRequest { m.ctrl.T.Helper() diff --git a/api/grpcserver/transaction_service.go b/api/grpcserver/transaction_service.go index 16084e81f1..c56f96af8b 100644 --- a/api/grpcserver/transaction_service.go +++ b/api/grpcserver/transaction_service.go @@ -143,7 +143,15 @@ func (s *TransactionService) getTransactionAndStatus( case types.APPLIED: state = pb.TransactionState_TRANSACTION_STATE_PROCESSED default: - state = pb.TransactionState_TRANSACTION_STATE_UNSPECIFIED + evicted, err := s.conState.HasEvicted(txID) + if err != nil { + return nil, state + } + if evicted { + state = pb.TransactionState_TRANSACTION_STATE_INEFFECTUAL + } else { + state = pb.TransactionState_TRANSACTION_STATE_UNSPECIFIED + } } return &tx.Transaction, state } diff --git a/api/grpcserver/v2alpha1/node_test.go b/api/grpcserver/v2alpha1/node_test.go index 4ff5a85c27..7e8c14f436 100644 --- a/api/grpcserver/v2alpha1/node_test.go +++ b/api/grpcserver/v2alpha1/node_test.go @@ -23,7 +23,8 @@ func TestNodeService_Status(t *testing.T) { timesync.WithLayerDuration(layerDuration), timesync.WithTickInterval(1*time.Second), timesync.WithGenesisTime(time.Now()), - timesync.WithLogger(zaptest.NewLogger(t))) + timesync.WithLogger(zaptest.NewLogger(t)), + ) require.NoError(t, err) defer clock.Close() diff --git a/api/grpcserver/v2alpha1/transaction.go b/api/grpcserver/v2alpha1/transaction.go index 7ceed2c543..9804006286 100644 --- a/api/grpcserver/v2alpha1/transaction.go +++ b/api/grpcserver/v2alpha1/transaction.go @@ -6,9 +6,11 @@ import ( "errors" "fmt" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" spacemeshv2alpha1 "github.com/spacemeshos/api/release/go/spacemesh/v2alpha1" "github.com/spacemeshos/go-scale" + "go.uber.org/zap" "google.golang.org/genproto/googleapis/rpc/code" rpcstatus "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc" @@ -207,6 +209,9 @@ func (s *TransactionService) SubmitTransaction( } raw := types.NewRawTx(request.Transaction) + ctxzap.Info(ctx, "successfully submitted transaction", + zap.Stringer("tx_id", raw.ID), + ) return &spacemeshv2alpha1.SubmitTransactionResponse{ Status: &rpcstatus.Status{Code: int32(code.Code_OK)}, TxId: raw.ID[:], diff --git a/beacon/beacon.go b/beacon/beacon.go index 6a71b2064c..692a92b8dd 100644 --- 
a/beacon/beacon.go +++ b/beacon/beacon.go @@ -729,16 +729,16 @@ func (pd *ProtocolDriver) listenEpochs(ctx context.Context) { pd.logger.Debug("time sync detected, realigning Beacon") continue } + epoch := current.GetEpoch() + layer = epoch.Add(1).FirstLayer() if !current.FirstInEpoch() { continue } - epoch := current.GetEpoch() - layer = epoch.Add(1).FirstLayer() pd.setProposalTimeForNextEpoch() pd.logger.Info("processing epoch", zap.Uint32("epoch", epoch.Uint32())) pd.eg.Go(func() error { - _ = pd.onNewEpoch(ctx, epoch) + pd.onNewEpoch(ctx, epoch) return nil }) } diff --git a/dev-docs/fptree-agg-from-root.png b/dev-docs/fptree-agg-from-root.png new file mode 100644 index 0000000000..024d0b8418 Binary files /dev/null and b/dev-docs/fptree-agg-from-root.png differ diff --git a/dev-docs/fptree-agg-lca.png b/dev-docs/fptree-agg-lca.png new file mode 100644 index 0000000000..e88fed677e Binary files /dev/null and b/dev-docs/fptree-agg-lca.png differ diff --git a/dev-docs/fptree-agg-limit-wraparound.png b/dev-docs/fptree-agg-limit-wraparound.png new file mode 100644 index 0000000000..73fa84762a Binary files /dev/null and b/dev-docs/fptree-agg-limit-wraparound.png differ diff --git a/dev-docs/fptree-agg-limit.png b/dev-docs/fptree-agg-limit.png new file mode 100644 index 0000000000..9700d69fe0 Binary files /dev/null and b/dev-docs/fptree-agg-limit.png differ diff --git a/dev-docs/fptree-agg-wraparound.png b/dev-docs/fptree-agg-wraparound.png new file mode 100644 index 0000000000..cc536d3f06 Binary files /dev/null and b/dev-docs/fptree-agg-wraparound.png differ diff --git a/dev-docs/fptree-with-values.png b/dev-docs/fptree-with-values.png new file mode 100644 index 0000000000..40ee6cb72d Binary files /dev/null and b/dev-docs/fptree-with-values.png differ diff --git a/dev-docs/fptree.excalidraw.gz b/dev-docs/fptree.excalidraw.gz new file mode 100644 index 0000000000..1d018b5d87 Binary files /dev/null and b/dev-docs/fptree.excalidraw.gz differ diff --git a/dev-docs/fptree.png b/dev-docs/fptree.png new file mode 100644 index 0000000000..93ef265ddf Binary files /dev/null and b/dev-docs/fptree.png differ diff --git a/dev-docs/sync2-set-reconciliation.md b/dev-docs/sync2-set-reconciliation.md index 0ffdbe2cfe..b66248116f 100644 --- a/dev-docs/sync2-set-reconciliation.md +++ b/dev-docs/sync2-set-reconciliation.md @@ -27,6 +27,14 @@ - [Redundant ItemChunk messages](#redundant-itemchunk-messages) - [Range checksums](#range-checksums) - [Bloom filters for recent sync](#bloom-filters-for-recent-sync) +- [FPTree Data Structure](#fptree-data-structure) + - [Tree structure](#tree-structure) + - [Aggregation](#aggregation) + - [Aggregation of normal ranges](#aggregation-of-normal-ranges) + - [Aggregation of wraparound ranges](#aggregation-of-wraparound-ranges) + - [Splitting ranges and limited aggregation](#splitting-ranges-and-limited-aggregation) + - [Tree node representation](#tree-node-representation) + - [Accessing the database](#accessing-the-database) - [Multi-peer Reconciliation](#multi-peer-reconciliation) - [Deciding on the sync strategy](#deciding-on-the-sync-strategy) - [Split sync](#split-sync) @@ -775,11 +783,291 @@ just want to bring them closer to each other. That being said, a sufficient size of the Bloom filter needs to be chosen to minimize the number of missed elements. +# FPTree Data Structure + +FPTree (fingerprint tree) is a data structure intended to facilitate +synchronization of objects stored in an SQLite database, with +hash-based IDs. It stores fingerprints (IDs XORed together) and item +counts for ID ranges.
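To make the fingerprinting idea concrete before the tree itself: a range fingerprint is just the XOR of all IDs in the range, kept alongside an item count. XOR is order-independent and self-inverse, so the fingerprints of disjoint subranges combine into the fingerprint of their union, which is what lets interior tree nodes cache them. The following is a minimal, self-contained Go sketch of that invariant; the `ID` and `Fingerprint` types here are illustrative stand-ins, not the actual sync2 types.

```go
package main

import (
	"crypto/sha256"
	"fmt"
)

// ID stands in for a 32-byte hash-based object ID.
type ID [32]byte

// Fingerprint of an ID range: all IDs in the range XORed together.
type Fingerprint [32]byte

// Update folds one ID into the fingerprint.
func (fp *Fingerprint) Update(id ID) {
	for i := range fp {
		fp[i] ^= id[i]
	}
}

// Combine merges the fingerprint of a disjoint subrange; interior
// FPTree nodes rely on exactly this to cache per-subtree values.
func (fp *Fingerprint) Combine(other Fingerprint) {
	for i := range fp {
		fp[i] ^= other[i]
	}
}

func main() {
	var left, right, whole Fingerprint
	for i, s := range []string{"a", "b", "c", "d"} {
		id := ID(sha256.Sum256([]byte(s)))
		if i < 2 {
			left.Update(id) // first subrange
		} else {
			right.Update(id) // second subrange
		}
		whole.Update(id)
	}
	left.Combine(right)
	fmt.Println("combined == whole:", left == whole) // true
}
```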
+## Tree structure + +FPTree has the following properties: + +1. FPTree is an in-memory structure that provides efficient item count + and fingerprints for ID (item/key) ranges, trying to do its best to + avoid database queries. The queries may be entirely avoided + if ranges are aligned on the node boundaries. +1. FPTree is a binary trie (prefix tree), following the bits in the + IDs starting from the highest one. The intent is to convert it to a + proper radix tree instead, but that's not implemented yet. +1. FPTree relies on IDs being hashes and thus being uniformly + distributed to ensure that the tree remains balanced, instead of + using a balancing mechanism such as a red-black tree. +1. FPTree provides a range split mechanism (needed for pairwise sync) + which tries to ensure that the ranges are aligned on node + boundaries up to a certain subdivision depth. +1. The full FPTree copy operation is `O(1)` in terms of time and + memory. The copies are safe for concurrent use. +1. FPTree can also store the actual IDs without the use of an + underlying table. +1. FPTrees can be "stacked" together. The FPTree-based `OrderedSet` + implementation uses two FPTrees: one database-bound and another one + fully in-memory. The in-memory FPTree is used to store fresh items + received via the [Recent sync](#recent-sync) mechanism. +1. FPTree performs queries on ranges `[x,y)`, supporting normal `x < + y` ranges, as well as wraparound `x > y` ranges and the full set range + `[x,x)` (see [Range representation](#range-representation)). +1. Each FPTree node has a corresponding bit prefix by which it can be + reached. + +The tree structure is shown in the diagram below. The leaf nodes +correspond to the rows in the database table whose IDs have the bit +prefix corresponding to the leaf node. + +![FPTree structure](fptree.png) + +As mentioned above, FPTree itself can also store the actual IDs, +without using an underlying database table. + +![FPTree with values](fptree-with-values.png) + +## Aggregation + +Aggregation means calculation of the fingerprint and item count for a +range. The aggregation is done using different methods depending on +whether the `[x,y)` range is normal (`x<y`), wraparound (`x>y`), or +indicates the whole set (`x=y`). Aggregation may also be bounded by +the maximum number of items to include. The easiest case is full set +aggregation, in which we just take the fingerprint and count values +from the root node of the FPTree. + +### Aggregation of normal ranges + +In case of a normal range `[x,y)` with `x<y`, aggregation is done +starting from the lowest common ancestor (LCA) of the nodes +corresponding to `x` and `y`: `aggregateLeft` descends from the LCA +towards `x`, including the full right subtrees hanging off that path, +while `aggregateRight` descends towards `y`, including the full left +subtrees hanging off its path. + +![Aggregation of a normal range](fptree-agg-lca.png) + +If `x` and `y` diverge already at the first bit, the LCA of the range +is the root of the tree, and aggregation starts directly from the root. + +![Aggregation starting from the root](fptree-agg-from-root.png) + +### Aggregation of wraparound ranges + +In case of a wraparound range `[x,y)` with `x>y`, `aggregateLeft` and +`aggregateRight` are used, too. Somewhat unintuitively, in this case +`aggregateLeft` is used on the right side of the tree, because that's +where the beginning ("left side") of the wrapped-around `[x,y)` range +lies, whereas `aggregateRight` is applied to the left side of the tree +corresponding to the end ("right side") of the range. + +The subtree on which `aggregateLeft` is done is rooted at the node +reachable by following the longest prefix of `x` consisting entirely +of `1`s. Conversely, the subtree on which `aggregateRight` is done is +rooted at the node reachable by following the longest prefix of `y` +consisting entirely of `0`s. + +The figure below shows aggregation of the `[x,y)` range with +`x=0xD1..` and `y=0x29..`. + +![Aggregation of a wrapped-around range](fptree-agg-wraparound.png)
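To make the two traversals concrete, here is a runnable toy version over a 4-bit trie. The helper names `aggregateLeft` and `aggregateRight` follow the document, but everything else (tiny `uint16` fingerprints, plain pointers, no `O(1)` copies, no database-backed leaves) is a deliberate simplification for illustration, not the actual sync2 implementation.

```go
package main

import "fmt"

const bits = 4

// node is a toy FPTree node over 4-bit keys: it caches the XOR
// fingerprint (a uint16 here for brevity) and the item count of all
// keys stored beneath it.
type node struct {
	fp          uint16
	count       int
	left, right *node
}

func insert(n **node, key uint16, depth int) {
	if *n == nil {
		*n = &node{}
	}
	(*n).fp ^= key
	(*n).count++
	if depth == bits {
		return // leaf reached
	}
	if (key>>(bits-1-depth))&1 == 0 { // follow bits from the highest one
		insert(&(*n).left, key, depth+1)
	} else {
		insert(&(*n).right, key, depth+1)
	}
}

// aggregateLeft descends towards x, including every full right
// subtree hanging off the path: it covers all keys >= x under n.
func aggregateLeft(n *node, x uint16, depth int, fp *uint16, count *int) {
	if n == nil {
		return
	}
	if depth == bits { // x's own leaf; the lower bound is inclusive
		*fp ^= n.fp
		*count += n.count
		return
	}
	if (x>>(bits-1-depth))&1 == 0 {
		if n.right != nil {
			*fp ^= n.right.fp
			*count += n.right.count
		}
		aggregateLeft(n.left, x, depth+1, fp, count)
	} else {
		aggregateLeft(n.right, x, depth+1, fp, count)
	}
}

// aggregateRight is the mirror image: descending towards y, it
// includes every full left subtree and covers all keys < y under n
// (the upper bound is exclusive, so y's own leaf is skipped).
func aggregateRight(n *node, y uint16, depth int, fp *uint16, count *int) {
	if n == nil || depth == bits {
		return
	}
	if (y>>(bits-1-depth))&1 == 1 {
		if n.left != nil {
			*fp ^= n.left.fp
			*count += n.left.count
		}
		aggregateRight(n.right, y, depth+1, fp, count)
	} else {
		aggregateRight(n.left, y, depth+1, fp, count)
	}
}

func main() {
	var root *node
	for _, k := range []uint16{0x1, 0x3, 0x7, 0x9, 0xC, 0xE} {
		insert(&root, k, 0)
	}
	// Normal range [0x3, 0xC): x and y diverge at the first bit, so
	// the LCA is the root. aggregateLeft covers [0x3, 0x8) in the
	// "0" subtree, aggregateRight covers [0x8, 0xC) in the "1" one.
	var fp uint16
	var count int
	aggregateLeft(root.left, 0x3, 1, &fp, &count)
	aggregateRight(root.right, 0xC, 1, &fp, &count)
	fmt.Printf("[0x3,0xC): count=%d fp=%#x\n", count, fp) // count=3 fp=0xd
}
```

Capping the number of items these traversals may include, and descending into a subtree instead of taking it whole when its count would exceed the cap, yields the limited aggregation used for range splitting, described next. In the real tree the recursion bottoms out at database-backed leaves, which is where the SQL queries listed under "Accessing the database" below come in.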
+## Splitting ranges and limited aggregation + +During recursive set reconciliation, a range split operation often needs +to be performed. This involves partitioning the range roughly in half +with respect to the number of items in each new subrange, and +calculating the item count and fingerprint for each part resulting from +the split. FPTree will try to perform such an operation on a node +boundary, but if the range is too small or not aligned to a node +boundary, the following is done: + +1. The number of items in the range is obtained (`N`). +2. The items in the range are aggregated with a cap on the maximum + aggregated count equal to `N/2`, and the non-inclusive upper bound + of the aggregated subrange is noted (`m`). The aggregated items + can be said to lie in the range `[x,m)`. +3. The second half of the range is aggregated starting at `m`. This + part of the range is `[m,y)`. + +In both cases, the operation is based upon imposing a limit on the +number of items aggregated. In the easy, node-aligned case, the +aggregation continues after exhausting the limit on the total item +count, but using separate places for accumulating the remaining nodes' +fingerprints and counts. The initial accumulated fingerprint and count +are returned for the first resulting subrange, and the second +accumulated fingerprint and count are returned for the second subrange +resulting from the partition. If the node-aligned "easy split" +cannot be done, aggregation simply stops after exhausting the limit. + +When limited aggregation is done, full right subtrees are no longer +included unconditionally during `aggregateLeft`, nor full left subtrees +during `aggregateRight`, nor the whole tree during `[x,x)` (full set) +range aggregation: whenever a subtree's count exceeds the limit +remaining after processing the nodes visited so far, that subtree is +descended into to find the cutoff point. + +Below, limited aggregation is shown for a normal `x<y` range: + +![Limited aggregation of a normal range](fptree-agg-limit.png) + +...and for a wraparound `x>y` range: + +![Limited aggregation of a wraparound range](fptree-agg-limit-wraparound.png) + +## Accessing the database + +When a range is not aligned on node boundaries, the edges of the range +are resolved by querying the underlying table. Select IDs in a range: +```sql +SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "id" >= ? AND + "rowid" <= ? ORDER BY "id" LIMIT ? +``` + +Select the number of recently received items for recent sync +(which is not done using FPTree): +```sql +SELECT count("id") FROM "atxs" WHERE "epoch" = ? AND + "rowid" <= ? AND "received" >= ? +``` + +Select recently received IDs: +```sql +SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "id" >= ? AND + "rowid" <= ? AND "received" >= ? ORDER BY "id" LIMIT ? +``` + # Multi-peer Reconciliation -The multi-peer reconciliation approach is loosely based on -[SREP: Out-Of-Band Sync of Transaction Pools for Large-Scale Blockchains](https://people.bu.edu/staro/2023-ICBC-Novak.pdf) -paper by Novak Boškov, Sevval Simsek, Ari Trachtenberg, and David Starobinski. +The multi-peer reconciliation approach is loosely based on [SREP: +Out-Of-Band Sync of Transaction Pools for Large-Scale +Blockchains](https://people.bu.edu/staro/2023-ICBC-Novak.pdf) paper by +Novak Boškov, Sevval Simsek, Ari Trachtenberg, and David Starobinski.
![Multi-peer set reconciliation](multipeer.png) diff --git a/fetch/handler.go b/fetch/handler.go index f04488019a..c2b92d430e 100644 --- a/fetch/handler.go +++ b/fetch/handler.go @@ -291,7 +291,8 @@ func (h *handler) doHandleHashReq(ctx context.Context, data []byte, hint datasto h.logger.Debug("remote peer requested nonexistent hash", log.ZContext(ctx), zap.Stringer("hash", r.Hash), - zap.String("hint", string(r.Hint))) + zap.String("hint", string(r.Hint)), + ) hashMissing.WithLabelValues(string(r.Hint)).Add(1) continue } else if len(blob.Bytes) == 0 { @@ -302,7 +303,8 @@ func (h *handler) doHandleHashReq(ctx context.Context, data []byte, hint datasto h.logger.Debug("responded to hash request", log.ZContext(ctx), zap.Stringer("hash", r.Hash), - zap.Int("dataSize", len(blob.Bytes))) + zap.Int("dataSize", len(blob.Bytes)), + ) } // add response to batch m := ResponseMessage{ diff --git a/genvm/vm.go b/genvm/vm.go index 432a623ac6..5e8671513d 100644 --- a/genvm/vm.go +++ b/genvm/vm.go @@ -231,7 +231,7 @@ func (v *VM) Apply( for _, reward := range rewardsResult { if err := rewards.Add(tx, &reward); err != nil { - return nil, nil, fmt.Errorf("%w: %w", core.ErrInternal, err) + return nil, nil, fmt.Errorf("add reward %w: %w", core.ErrInternal, err) } } @@ -247,17 +247,17 @@ func (v *VM) Apply( return true }) if err != nil { - return nil, nil, fmt.Errorf("%w: %w", core.ErrInternal, err) + return nil, nil, fmt.Errorf("iterate changed %w: %w", core.ErrInternal, err) } writesPerBlock.Observe(float64(total)) var hashSum types.Hash32 hasher.Sum(hashSum[:0]) if err := layers.UpdateStateHash(tx, layer, hashSum); err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("update state hash: %w", err) } if err := tx.Commit(); err != nil { - return nil, nil, fmt.Errorf("%w: %w", core.ErrInternal, err) + return nil, nil, fmt.Errorf("commit %w: %w", core.ErrInternal, err) } ss.IterateChanged(func(account *core.Account) bool { if err := events.ReportAccountUpdate(account.Address); err != nil { diff --git a/go.mod b/go.mod index 5f83e371f4..13fff3a813 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/spacemeshos/go-spacemesh go 1.23.2 require ( - cloud.google.com/go/storage v1.46.0 + cloud.google.com/go/storage v1.47.0 github.com/ALTree/bigfloat v0.2.0 github.com/chaos-mesh/chaos-mesh/api v0.0.0-20241021021428-64a7a81821a0 github.com/cosmos/btcutil v1.0.5 @@ -40,12 +40,12 @@ require ( github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/seehuhn/mt19937 v1.0.0 github.com/slok/go-http-metrics v0.13.0 - github.com/spacemeshos/api/release/go v1.55.0 + github.com/spacemeshos/api/release/go v1.56.0 github.com/spacemeshos/economics v0.1.4 github.com/spacemeshos/fixed v0.1.2 github.com/spacemeshos/go-scale v1.2.1 - github.com/spacemeshos/merkle-tree v0.2.5 - github.com/spacemeshos/poet v0.10.9 + github.com/spacemeshos/merkle-tree v0.2.6 + github.com/spacemeshos/poet v0.10.10 github.com/spacemeshos/post v0.12.10 github.com/spf13/afero v1.11.0 github.com/spf13/cobra v1.8.1 @@ -56,12 +56,12 @@ require ( github.com/zeebo/blake3 v0.2.4 go.uber.org/mock v0.5.0 go.uber.org/zap v1.27.0 - golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c - golang.org/x/sync v0.8.0 - golang.org/x/time v0.7.0 - google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28 - google.golang.org/grpc v1.67.1 - google.golang.org/protobuf v1.35.1 + golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f + golang.org/x/sync v0.9.0 + golang.org/x/time v0.8.0 + google.golang.org/genproto/googleapis/rpc 
v0.0.0-20241113202542-65e8d215514f + google.golang.org/grpc v1.68.0 + google.golang.org/protobuf v1.35.2 k8s.io/api v0.31.2 k8s.io/apimachinery v0.31.2 k8s.io/client-go v0.31.2 @@ -71,7 +71,7 @@ require ( require ( cel.dev/expr v0.16.1 // indirect cloud.google.com/go v0.116.0 // indirect - cloud.google.com/go/auth v0.10.0 // indirect + cloud.google.com/go/auth v0.10.2 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.5 // indirect cloud.google.com/go/compute/metadata v0.5.2 // indirect cloud.google.com/go/iam v1.2.1 // indirect @@ -233,14 +233,14 @@ require ( go.uber.org/dig v1.18.0 // indirect go.uber.org/fx v1.23.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/crypto v0.28.0 // indirect - golang.org/x/mod v0.21.0 // indirect - golang.org/x/net v0.30.0 // indirect + golang.org/x/crypto v0.29.0 // indirect + golang.org/x/mod v0.22.0 // indirect + golang.org/x/net v0.31.0 // indirect golang.org/x/oauth2 v0.23.0 // indirect - golang.org/x/sys v0.26.0 // indirect - golang.org/x/term v0.25.0 // indirect - golang.org/x/text v0.19.0 // indirect - golang.org/x/tools v0.26.0 // indirect + golang.org/x/sys v0.27.0 // indirect + golang.org/x/term v0.26.0 // indirect + golang.org/x/text v0.20.0 // indirect + golang.org/x/tools v0.27.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect gonum.org/v1/gonum v0.15.0 // indirect google.golang.org/api v0.203.0 // indirect diff --git a/go.sum b/go.sum index 5eda5547ad..8f1925e66f 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= -cloud.google.com/go/auth v0.10.0 h1:tWlkvFAh+wwTOzXIjrwM64karR1iTBZ/GRr0S/DULYo= -cloud.google.com/go/auth v0.10.0/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= +cloud.google.com/go/auth v0.10.2 h1:oKF7rgBfSHdp/kuhXtqU/tNDr0mZqhYbEh+6SiqzkKo= +cloud.google.com/go/auth v0.10.2/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= cloud.google.com/go/auth/oauth2adapt v0.2.5 h1:2p29+dePqsCHPP1bqDJcKj4qxRyYCcbzKpFyKGt3MTk= cloud.google.com/go/auth/oauth2adapt v0.2.5/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8= cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo= @@ -20,8 +20,8 @@ cloud.google.com/go/longrunning v0.6.1 h1:lOLTFxYpr8hcRtcwWir5ITh1PAKUD/sG2lKrTS cloud.google.com/go/longrunning v0.6.1/go.mod h1:nHISoOZpBcmlwbJmiVk5oDRz0qG/ZxPynEGs1iZ79s0= cloud.google.com/go/monitoring v1.21.1 h1:zWtbIoBMnU5LP9A/fz8LmWMGHpk4skdfeiaa66QdFGc= cloud.google.com/go/monitoring v1.21.1/go.mod h1:Rj++LKrlht9uBi8+Eb530dIrzG/cU/lB8mt+lbeFK1c= -cloud.google.com/go/storage v1.46.0 h1:OTXISBpFd8KaA2ClT3K3oRk8UGOcTHtrZ1bW88xKiic= -cloud.google.com/go/storage v1.46.0/go.mod h1:lM+gMAW91EfXIeMTBmixRsKL/XCxysytoAgduVikjMk= +cloud.google.com/go/storage v1.47.0 h1:ajqgt30fnOMmLfWfu1PWcb+V9Dxz6n+9WKjdNg5R4HM= +cloud.google.com/go/storage v1.47.0/go.mod h1:Ks0vP374w0PW6jOUameJbapbQKXqkjGd/OJRp2fb9IQ= cloud.google.com/go/trace v1.11.1 h1:UNqdP+HYYtnm6lb91aNA5JQ0X14GnxkABGlfz2PzPew= cloud.google.com/go/trace v1.11.1/go.mod h1:IQKNQuBzH72EGaXEodKlNJrWykGZxet2zgjtS60OtjA= dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= @@ -629,18 +629,18 @@ github.com/sourcegraph/annotate 
v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:Udh github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= -github.com/spacemeshos/api/release/go v1.55.0 h1:IQ8PmQ1d7CwUiM1r3NH8uZ+JkEyNjSltiAuqEY6dn6o= -github.com/spacemeshos/api/release/go v1.55.0/go.mod h1:qM6GTS2QtUvxPNIJf+2ObH63bGXYrJnapgOd6l6pbpQ= +github.com/spacemeshos/api/release/go v1.56.0 h1:llBVijoO4I3mhHk0OtGJdTT/11I7ajo0CZp3x8h1EjA= +github.com/spacemeshos/api/release/go v1.56.0/go.mod h1:6o17nhNyXpbVeijAQqkZfL8Pe/IkMGAWMLSLZni0DOU= github.com/spacemeshos/economics v0.1.4 h1:twlawrcQhYNqPgyDv08+24EL/OgUKz3d7q+PvJIAND0= github.com/spacemeshos/economics v0.1.4/go.mod h1:6HKWKiKdxjVQcGa2z/wA0LR4M/DzKib856bP16yqNmQ= github.com/spacemeshos/fixed v0.1.2 h1:pENQ8pXFAqin3f15ZLoOVVeSgcmcFJ0IFdFm4+9u4SM= github.com/spacemeshos/fixed v0.1.2/go.mod h1:OekUZD7FA9Ji8H/WEf5VuGYxPB+mWfXjbUI7I3qcT48= github.com/spacemeshos/go-scale v1.2.1 h1:+IJ6KmFl9tF1Om8B1NvEwilGpBG1ebr4Se8A0Fe4puE= github.com/spacemeshos/go-scale v1.2.1/go.mod h1:fpO6tCoKdUmvF6o9zkUtq2erSOH5t4ik02Zwdm31qOs= -github.com/spacemeshos/merkle-tree v0.2.5 h1:4iWiW4SvDEBGYRUvFUjArHeTHjvOa52JQ/iLW6wBzUs= -github.com/spacemeshos/merkle-tree v0.2.5/go.mod h1:lxMuC/C2qhN6wdH6iSXW0HM8FS6fnKnyLWjCAKsCtr8= -github.com/spacemeshos/poet v0.10.9 h1:lJizp95P/yoh/cVulFFfIcVZZTmMXqtNyrHCUZvUGAk= -github.com/spacemeshos/poet v0.10.9/go.mod h1:irrgk9xbwNnv0Tq3YcMpC8eia8O1uFhDP5nULY3HjT4= +github.com/spacemeshos/merkle-tree v0.2.6 h1:PJ4LBx0vBbYVIHwApyjLy/yqUGEK35ggGTo05oiPhwg= +github.com/spacemeshos/merkle-tree v0.2.6/go.mod h1:lxMuC/C2qhN6wdH6iSXW0HM8FS6fnKnyLWjCAKsCtr8= +github.com/spacemeshos/poet v0.10.10 h1:LgCQUjKvvuaoibH7nnNYehSF0yNR67eufwwqPBWb9Ts= +github.com/spacemeshos/poet v0.10.10/go.mod h1:6p+jNwqZOIWXvqdHEENfWZvwElIrw0HoxHIM1m4uDrk= github.com/spacemeshos/post v0.12.10 h1:S4THKvy/uGdNzoZkTI5qqIo2H8/W4xktKtYzxKsYNVU= github.com/spacemeshos/post v0.12.10/go.mod h1:oMoQ2oU5EXU1GsxK/kvhnc1/pRh8VYeRo/p8mMSsdHc= github.com/spacemeshos/sha256-simd v0.1.0 h1:G7Mfu5RYdQiuE+wu4ZyJ7I0TI74uqLhFnKblEnSpjYI= @@ -717,6 +717,8 @@ go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+n go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= +go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.29.0 h1:WDdP9acbMYjbKIyJUhTvtzj601sVJOqgWdUxSdR/Ysc= +go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.29.0/go.mod h1:BLbf7zbNIONBLPwvFnwNHGj4zge8uTCM/UPIVW1Mq2I= go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= go.opentelemetry.io/otel/sdk v1.29.0 h1:vkqKjk7gwhS8VaWb0POZKmIEDimRCMsopNYnriHyryo= @@ -761,11 +763,11 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto 
v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY= -golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= +golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo= +golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -779,8 +781,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= -golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= +golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -806,8 +808,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= -golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= +golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -825,8 +827,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 
h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180810173357-98c5dad5d1a0/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -860,8 +862,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -869,8 +871,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= -golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= +golang.org/x/term v0.26.0 h1:WEQa6V3Gja/BhNxg540hBip/kkaYtRg3cxg4oXSw4AU= +golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -880,12 +882,12 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= -golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.8.0 
h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= +golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -906,8 +908,8 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= -golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= +golang.org/x/tools v0.27.0 h1:qEKojBykQkQ4EynWy4S8Weg69NumxKdn40Fce3uc/8o= +golang.org/x/tools v0.27.0/go.mod h1:sUi0ZgbwW9ZPAq26Ekut+weQPR5eIM6GQLQ1Yjm1H0Q= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -937,8 +939,8 @@ google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53 h1:Df6WuGvthPzc+Ji google.golang.org/genproto v0.0.0-20241015192408-796eee8c2d53/go.mod h1:fheguH3Am2dGp1LfXkrvwqC/KlFq8F0nLq3LryOMrrE= google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 h1:M0KvPgPmDZHPlbRbaNU1APr28TvwvvdUPlSv7PUvy8g= google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:dguCy7UOdZhTvLzDyt15+rOrawrpM4q7DD9dQ1P11P4= -google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28 h1:XVhgTWWV3kGQlwJHR3upFWZeTsei6Oks1apkZSeonIE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241113202542-65e8d215514f h1:C1QccEa9kUwvMgEUORqQD9S17QesQijxjZ84sO82mfo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241113202542-65e8d215514f/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= @@ -948,8 +950,8 @@ google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQ google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= -google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= +google.golang.org/grpc v1.68.0 h1:aHQeeJbo8zAkAa3pRzrVjZlbz6uSfeOXlJNQM0RAbz0= +google.golang.org/grpc v1.68.0/go.mod h1:fmSPC5AsjSBCK54MyHRx48kpOti1/jRfOlwEWywNjWA= google.golang.org/grpc/stats/opentelemetry v0.0.0-20240907200651-3ffb98b2c93a h1:UIpYSuWdWHSzjwcAFRLjKcPXFZVVLXGEM23W+NWqipw= 
google.golang.org/grpc/stats/opentelemetry v0.0.0-20240907200651-3ffb98b2c93a/go.mod h1:9i1T9n4ZinTUZGgzENMi8MDDgbGC5mqTS75JAv6xN3A= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= @@ -961,8 +963,8 @@ google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= -google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.35.2 h1:8Ar7bF+apOIoThw1EdZl0p1oWvMqTHmpA2fRTyZO8io= +google.golang.org/protobuf v1.35.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/malfeasance/wire/malfeasance.go b/malfeasance/wire/malfeasance.go index 32509721a8..f4bb39f972 100644 --- a/malfeasance/wire/malfeasance.go +++ b/malfeasance/wire/malfeasance.go @@ -252,6 +252,12 @@ type InvalidPostIndexProof struct { InvalidIdx uint32 } +func (p *InvalidPostIndexProof) MarshalLogObject(encoder zapcore.ObjectEncoder) error { + encoder.AddObject("atx", &p.Atx) + encoder.AddUint32("invalid_index", p.InvalidIdx) + return nil +} + type BallotProofMsg struct { InnerMsg types.BallotMetadata diff --git a/mesh/ballotwriter/ballotwriter_test.go b/mesh/ballotwriter/ballotwriter_test.go index 1948bdee29..47841bc4f7 100644 --- a/mesh/ballotwriter/ballotwriter_test.go +++ b/mesh/ballotwriter/ballotwriter_test.go @@ -26,7 +26,7 @@ import ( var testLayer = types.LayerID(5) func TestMain(m *testing.M) { - types.SetLayersPerEpoch(10) + types.SetLayersPerEpoch(1) res := m.Run() os.Exit(res) } diff --git a/node/node.go b/node/node.go index d23ec8dd71..002a127bb8 100644 --- a/node/node.go +++ b/node/node.go @@ -735,7 +735,24 @@ func (app *App) initServices(ctx context.Context) error { return nil }) - fetcherWrapped := &layerFetcher{} + proposalsStore := store.New( + store.WithEvictedLayer(app.clock.CurrentLayer()), + store.WithLogger(app.addLogger(ProposalStoreLogger, lg).Zap()), + store.WithCapacity(app.Config.Tortoise.Zdist+1), + ) + + flog := app.addLogger(Fetcher, lg) + fetcher, err := fetch.NewFetch(app.cachedDB, proposalsStore, app.host, + fetch.WithContext(ctx), + fetch.WithConfig(app.Config.FETCH), + fetch.WithLogger(flog.Zap()), + ) + if err != nil { + return fmt.Errorf("create fetcher: %w", err) + } + app.eg.Go(func() error { + return blockssync.Sync(ctx, flog.Zap(), msh.MissingBlocks(), fetcher) + }) atxHandler := activation.NewHandler( app.host.ID(), @@ -744,7 +761,7 @@ func (app *App) initServices(ctx context.Context) error { app.edVerifier, app.clock, app.host, - fetcherWrapped, + fetcher, goldenATXID, validator, beaconProtocol, @@ -768,8 +785,9 @@ func (app *App) initServices(ctx context.Context) error { ) } - blockHandler := blocks.NewHandler(fetcherWrapped, app.db, trtl, msh, - blocks.WithLogger(app.addLogger(BlockHandlerLogger, lg).Zap())) + blockHandler := 
blocks.NewHandler(fetcher, app.db, trtl, msh, + blocks.WithLogger(app.addLogger(BlockHandlerLogger, lg).Zap()), + ) app.txHandler = txs.NewTxHandler( app.conState, @@ -819,26 +837,6 @@ func (app *App) initServices(ctx context.Context) error { app.certifier.Register(sig) } - proposalsStore := store.New( - store.WithEvictedLayer(app.clock.CurrentLayer()), - store.WithLogger(app.addLogger(ProposalStoreLogger, lg).Zap()), - store.WithCapacity(app.Config.Tortoise.Zdist+1), - ) - - flog := app.addLogger(Fetcher, lg) - fetcher, err := fetch.NewFetch(app.cachedDB, proposalsStore, app.host, - fetch.WithContext(ctx), - fetch.WithConfig(app.Config.FETCH), - fetch.WithLogger(flog.Zap()), - ) - if err != nil { - return fmt.Errorf("create fetcher: %w", err) - } - fetcherWrapped.Fetcher = fetcher - app.eg.Go(func() error { - return blockssync.Sync(ctx, flog.Zap(), msh.MissingBlocks(), fetcher) - }) - patrol := layerpatrol.New() syncerConf := app.Config.Sync syncerConf.HareDelayLayers = app.Config.Tortoise.Zdist @@ -852,7 +850,6 @@ func (app *App) initServices(ctx context.Context) error { newSyncer := syncer.NewSyncer( app.cachedDB, app.clock, - beaconProtocol, msh, trtl, fetcher, @@ -955,7 +952,7 @@ func (app *App) initServices(ctx context.Context) error { propHare, app.edVerifier, app.host, - fetcherWrapped, + fetcher, beaconProtocol, msh, trtl, @@ -978,7 +975,7 @@ func (app *App) initServices(ctx context.Context) error { proposalsStore, executor, msh, - fetcherWrapped, + fetcher, app.certifier, patrol, blocks.WithConfig(blocks.Config{ @@ -2284,10 +2281,6 @@ func (app *App) Host() *p2p.Host { return app.host } -type layerFetcher struct { - system.Fetcher -} - func decodeLoggerLevel(cfg *config.Config, name string) (zap.AtomicLevel, error) { lvl := zap.NewAtomicLevel() loggers := map[string]string{} diff --git a/p2p/pubsub/pubsub_test.go b/p2p/pubsub/pubsub_test.go index 3ca8f9a9d9..ba1cb4029f 100644 --- a/p2p/pubsub/pubsub_test.go +++ b/p2p/pubsub/pubsub_test.go @@ -3,6 +3,7 @@ package pubsub import ( "context" "fmt" + "sync/atomic" "testing" "time" @@ -14,14 +15,16 @@ import ( func TestGossip(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - n := 10 + defer cancel() + + n := 5 mesh, err := mocknet.FullMeshLinked(n) require.NoError(t, err) + topic := "test" - pubsubs := []PubSub{} + pubsubs := make([]PubSub, 0, n) count := n * n - received := make(chan []byte, count) + var received atomic.Int32 logger := zaptest.NewLogger(t) for i, h := range mesh.Hosts() { @@ -30,11 +33,11 @@ func TestGossip(t *testing.T) { require.NoError(t, err) pubsubs = append(pubsubs, ps) ps.Register(topic, func(ctx context.Context, pid peer.ID, msg []byte) error { - received <- msg + received.Add(1) return nil }) } - // connect after initializng gossip sub protocol for every peer. otherwise stream initialize + // connect after initializing gossip sub protocol for every peer. otherwise stream initialize // maybe fail if other side wasn't able to initialize gossipsub on time. 
require.NoError(t, mesh.ConnectAllButSelf()) require.Eventually(t, func() bool { @@ -44,9 +47,9 @@ func TestGossip(t *testing.T) { } } return true - }, 5*time.Second, 10*time.Millisecond) + }, 10*time.Second, 100*time.Millisecond) for i, ps := range pubsubs { require.NoError(t, ps.Publish(ctx, topic, []byte(mesh.Hosts()[i].ID()))) } - require.Eventually(t, func() bool { return len(received) == count }, 5*time.Second, 10*time.Millisecond) + require.Eventually(t, func() bool { return received.Load() == int32(count) }, 5*time.Second, 10*time.Millisecond) } diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 7b4c549614..4b022e1803 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -245,8 +245,7 @@ func GetLastIDByNodeID(db sql.Executor, nodeID types.NodeID) (id types.ATXID, er } // PrevIDByNodeID returns the previous ATX ID for a given node ID and public epoch. -// It returns the newest ATX ID containing PoST of the given node ID -// that was published before the given public epoch. +// It returns the newest ATX ID containing PoST of the given node ID that was published in or before the given epoch. func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID) (id types.ATXID, err error) { enc := func(stmt *sql.Statement) { stmt.BindBytes(1, nodeID.Bytes()) @@ -259,7 +258,7 @@ func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID if rows, err := db.Exec(` SELECT atxid FROM posts - WHERE pubkey = ?1 AND publish_epoch < ?2 + WHERE pubkey = ?1 AND publish_epoch <= ?2 ORDER BY publish_epoch DESC LIMIT 1;`, enc, dec); err != nil { return types.EmptyATXID, fmt.Errorf("exec nodeID %v, epoch %d: %w", nodeID, pubEpoch, err) @@ -863,7 +862,10 @@ func IterateAtxIdsWithMalfeasance( return err } -func PrevATXCollision(db sql.Executor, prev types.ATXID, id types.NodeID) (types.ATXID, types.ATXID, error) { +// PrevATXCollisions returns all ATXs with the same prevATX as the given ATX ID from the same node ID. +// It is used to detect double-publishing and double poet registrations. +// The ATXs returned are ordered by received time so that the first one is the one that was seen first by the node. 
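+//
+// Illustrative use (a hedged sketch, not part of this change; the atx fields
+// are assumed from the wire types used elsewhere in this PR):
+//
+//	ids, err := atxs.PrevATXCollisions(db, atx.PrevATXID, atx.SmesherID)
+//	if errors.Is(err, sql.ErrNotFound) {
+//		return nil // fewer than two ATXs share this prevATX: no collision
+//	}
+//	candidates := ids[:2] // ids[0] is the ATX the node saw first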
+func PrevATXCollisions(db sql.Executor, prev types.ATXID, id types.NodeID) ([]types.ATXID, error) { var atxs []types.ATXID enc := func(stmt *sql.Statement) { stmt.BindBytes(1, prev[:]) @@ -873,16 +875,22 @@ func PrevATXCollision(db sql.Executor, prev types.ATXID, id types.NodeID) (types var id types.ATXID stmt.ColumnBytes(0, id[:]) atxs = append(atxs, id) - return len(atxs) < 2 + return true } - _, err := db.Exec("SELECT atxid FROM posts WHERE prev_atxid = ?1 AND pubkey = ?2;", enc, dec) + query := `SELECT atxid FROM posts + WHERE prev_atxid = ?1 AND pubkey = ?2 + ORDER BY ( + SELECT received FROM atxs + WHERE id = atxid + );` + _, err := db.Exec(query, enc, dec) if err != nil { - return types.EmptyATXID, types.EmptyATXID, fmt.Errorf("error getting ATXs with same prevATX: %w", err) + return nil, fmt.Errorf("error getting ATXs with same prevATX: %w", err) } - if len(atxs) != 2 { - return types.EmptyATXID, types.EmptyATXID, sql.ErrNotFound + if len(atxs) < 2 { + return nil, sql.ErrNotFound } - return atxs[0], atxs[1], nil + return atxs, nil } func Units(db sql.Executor, atxID types.ATXID, nodeID types.NodeID) (uint32, error) { diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 124b914d6f..c82e4378ba 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -1070,15 +1070,15 @@ func Test_PrevATXCollision(t *testing.T) { require.NoError(t, atxs.SetPost(db, atx2.ID(), prevATXID, 0, atx2.SmesherID, 10, atx2.PublishEpoch)) } - collision1, collision2, err := atxs.PrevATXCollision(db, prevATXID, sig.NodeID()) + collisions, err := atxs.PrevATXCollisions(db, prevATXID, sig.NodeID()) require.NoError(t, err) - require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{collision1, collision2}) + require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, collisions) - _, _, err = atxs.PrevATXCollision(db, types.RandomATXID(), sig.NodeID()) + _, err = atxs.PrevATXCollisions(db, types.RandomATXID(), sig.NodeID()) require.ErrorIs(t, err, sql.ErrNotFound) for _, id := range append(otherIds, types.RandomNodeID()) { - _, _, err := atxs.PrevATXCollision(db, prevATXID, id) + _, err := atxs.PrevATXCollisions(db, prevATXID, id) require.ErrorIs(t, err, sql.ErrNotFound) } } @@ -1392,13 +1392,17 @@ func TestPrevIDByNodeID(t *testing.T) { require.NoError(t, atxs.Add(db, atx2, types.AtxBlob{})) require.NoError(t, atxs.SetPost(db, atx2.ID(), types.EmptyATXID, 0, sig.NodeID(), 4, atx2.PublishEpoch)) - _, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 1) + _, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 0) require.ErrorIs(t, err, sql.ErrNotFound) - prevID, err := atxs.PrevIDByNodeID(db, sig.NodeID(), 2) + prevID, err := atxs.PrevIDByNodeID(db, sig.NodeID(), 1) require.NoError(t, err) require.Equal(t, atx1.ID(), prevID) + prevID, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 2) + require.NoError(t, err) + require.Equal(t, atx2.ID(), prevID) + prevID, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 3) require.NoError(t, err) require.Equal(t, atx2.ID(), prevID) diff --git a/sql/statesql/schema/migrations/0026_pruned_txs.sql b/sql/statesql/schema/migrations/0026_pruned_txs.sql new file mode 100644 index 0000000000..cb8ae8cc0f --- /dev/null +++ b/sql/statesql/schema/migrations/0026_pruned_txs.sql @@ -0,0 +1,5 @@ +CREATE TABLE evicted_mempool ( + id CHAR(32) NOT NULL, + time INT NOT NULL, + PRIMARY KEY (id) +); diff --git a/sql/statesql/schema/schema.sql b/sql/statesql/schema/schema.sql index 913b1f287f..000d87b80a 100755 --- a/sql/statesql/schema/schema.sql +++ b/sql/statesql/schema/schema.sql 
@@ -1,4 +1,4 @@
-PRAGMA user_version = 25;
+PRAGMA user_version = 26;
 CREATE TABLE accounts
 (
     address CHAR(24),
@@ -69,6 +69,11 @@ CREATE TABLE certificates
     valid bool NOT NULL,
     PRIMARY KEY (layer, block)
 );
+CREATE TABLE evicted_mempool
+(
+    id CHAR(32) NOT NULL,
+    time INT NOT NULL,
+    PRIMARY KEY (id)
+);
 CREATE TABLE identities
 (
     pubkey CHAR(32) PRIMARY KEY,
diff --git a/sql/transactions/transactions.go b/sql/transactions/transactions.go
index 215c3369ef..cfab44aadf 100644
--- a/sql/transactions/transactions.go
+++ b/sql/transactions/transactions.go
@@ -233,6 +233,17 @@ func Has(db sql.Executor, id types.TransactionID) (bool, error) {
 	return rows > 0, nil
 }
 
+func HasEvicted(db sql.Executor, id types.TransactionID) (bool, error) {
+	rows, err := db.Exec("select 1 from evicted_mempool where id = ?1",
+		func(stmt *sql.Statement) {
+			stmt.BindBytes(1, id.Bytes())
+		}, nil)
+	if err != nil {
+		return false, fmt.Errorf("has evicted %s: %w", id, err)
+	}
+	return rows > 0, nil
+}
+
 // GetByAddress finds all transactions for an address.
 func GetByAddress(db sql.Executor, from, to types.LayerID, address types.Address) ([]*types.MeshTransaction, error) {
 	var txs []*types.MeshTransaction
@@ -295,6 +306,58 @@ func GetAcctPendingFromNonce(db sql.Executor, address types.Address, from uint64
 	}, "get acct pending from nonce")
 }
 
+// GetAcctPendingToNonce gets all pending transactions with nonce before `to` for the given address.
+func GetAcctPendingToNonce(db sql.Executor, address types.Address, to uint64) ([]types.TransactionID, error) {
+	ids := make([]types.TransactionID, 0)
+	_, err := db.Exec(`select id from transactions
+		where principal = ?1 and nonce < ?2 and result is null
+		order by nonce asc, timestamp asc;`,
+		func(stmt *sql.Statement) {
+			stmt.BindBytes(1, address.Bytes())
+			stmt.BindBytes(2, util.Uint64ToBytesBigEndian(to))
+		}, func(stmt *sql.Statement) bool {
+			id := types.TransactionID{}
+			stmt.ColumnBytes(0, id[:])
+			ids = append(ids, id)
+			return true
+		})
+	if err != nil {
+		return nil, fmt.Errorf("get acct pending to nonce %s: %w", address, err)
+	}
+	return ids, nil
+}
+
+func SetEvicted(db sql.Executor, id types.TransactionID) error {
+	if _, err := db.Exec("insert into evicted_mempool (id, time) values (?1, ?2) on conflict do nothing;",
+		func(stmt *sql.Statement) {
+			stmt.BindBytes(1, id.Bytes())
+			stmt.BindInt64(2, time.Now().UnixNano())
+		}, nil); err != nil {
+		return fmt.Errorf("set evicted %s: %w", id, err)
+	}
+	return nil
+}
+
+func Delete(db sql.Executor, id types.TransactionID) error {
+	if _, err := db.Exec("delete from transactions where id = ?1;",
+		func(stmt *sql.Statement) {
+			stmt.BindBytes(1, id.Bytes())
+		}, nil); err != nil {
+		return fmt.Errorf("delete %s: %w", id, err)
+	}
+	return nil
+}
+
+func PruneEvicted(db sql.Executor, before time.Time) error {
+	if _, err := db.Exec("delete from evicted_mempool where time < ?1;",
+		func(stmt *sql.Statement) {
+			stmt.BindInt64(1, before.UnixNano())
+		}, nil); err != nil {
+		return fmt.Errorf("prune evicted: %w", err)
+	}
+	return nil
+}
+
 // query MUST ensure that this order of fields tx, header, layer, block, timestamp, id.
func queryPending( db sql.Executor, diff --git a/sql/transactions/transactions_test.go b/sql/transactions/transactions_test.go index 0bdac033b3..8818973a5c 100644 --- a/sql/transactions/transactions_test.go +++ b/sql/transactions/transactions_test.go @@ -562,3 +562,73 @@ func TestTransactionInBlock(t *testing.T) { _, _, err = transactions.TransactionInBlock(db, tid, lids[2]) require.ErrorIs(t, err, sql.ErrNotFound) } + +func TestTransactionEvictMempool(t *testing.T) { + principals := []types.Address{ + {1}, + {2}, + {3}, + } + txs := []types.Transaction{ + { + RawTx: types.RawTx{ID: types.TransactionID{1}}, + TxHeader: &types.TxHeader{Principal: principals[0], Nonce: 0}, + }, + { + RawTx: types.RawTx{ID: types.TransactionID{2}}, + TxHeader: &types.TxHeader{Principal: principals[0], Nonce: 1}, + }, + { + RawTx: types.RawTx{ID: types.TransactionID{3}}, + TxHeader: &types.TxHeader{Principal: principals[1], Nonce: 0}, + }, + } + db := statesql.InMemoryTest(t) + for _, tx := range txs { + require.NoError(t, transactions.Add(db, &tx, time.Time{})) + } + err := transactions.SetEvicted(db, types.TransactionID{1}) + require.NoError(t, err) + + err = transactions.Delete(db, types.TransactionID{1}) + require.NoError(t, err) + + pending, err := transactions.GetAcctPendingFromNonce(db, principals[0], 1) + require.NoError(t, err) + require.Len(t, pending, 1) + require.Equal(t, pending[0].ID, txs[1].ID) + + pending, err = transactions.GetAcctPendingFromNonce(db, principals[1], 0) + require.NoError(t, err) + require.Len(t, pending, 1) + require.Equal(t, pending[0].ID, txs[2].ID) + + has, err := transactions.Has(db, txs[0].ID) + require.False(t, has) + require.NoError(t, err) + + has, err = transactions.HasEvicted(db, txs[0].ID) + require.True(t, has) + require.NoError(t, err) +} + +func TestPruneEvicted(t *testing.T) { + txId := types.TransactionID{1} + db := statesql.InMemoryTest(t) + db.Exec(`insert into evicted_mempool (id, time) values (?1,?2);`, + func(stmt *sql.Statement) { + stmt.BindBytes(1, txId.Bytes()) + stmt.BindInt64(2, time.Now().Add(-13*time.Hour).UnixNano()) + }, nil) + + has, err := transactions.HasEvicted(db, txId) + require.True(t, has) + require.NoError(t, err) + + err = transactions.PruneEvicted(db, time.Now().Add(-12*time.Hour)) + require.NoError(t, err) + + has, err = transactions.HasEvicted(db, txId) + require.False(t, has) + require.NoError(t, err) +} diff --git a/sync2/dbset/dbset.go b/sync2/dbset/dbset.go new file mode 100644 index 0000000000..25c2101909 --- /dev/null +++ b/sync2/dbset/dbset.go @@ -0,0 +1,274 @@ +package dbset + +import ( + "fmt" + "maps" + "sync" + "time" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/fptree" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +// DBSet is an implementation of rangesync.OrderedSet that uses an SQL database +// as its backing store. It uses an FPTree to perform efficient range queries. +type DBSet struct { + loadMtx sync.Mutex + db sql.Executor + ft *fptree.FPTree + st *sqlstore.SyncedTable + snapshot *sqlstore.SyncedTableSnapshot + dbStore *fptree.DBBackedStore + keyLen int + maxDepth int + received map[string]struct{} +} + +var _ rangesync.OrderedSet = &DBSet{} + +// NewDBSet creates a new DBSet. 
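+//
+// A minimal construction sketch (hedged; "foo"/"id" and the 32/24 parameters
+// mirror the tests below and are not a prescribed schema):
+//
+//	st := &sqlstore.SyncedTable{TableName: "foo", IDColumn: "id"}
+//	set := dbset.NewDBSet(db, st, 32, 24) // 32-byte keys, FPTree depth 24
+//	defer set.Release()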
+func NewDBSet(
+	db sql.Executor,
+	st *sqlstore.SyncedTable,
+	keyLen, maxDepth int,
+) *DBSet {
+	return &DBSet{
+		db:       db,
+		st:       st,
+		keyLen:   keyLen,
+		maxDepth: maxDepth,
+	}
+}
+
+func (d *DBSet) handleIDfromDB(stmt *sql.Statement) bool {
+	id := make(rangesync.KeyBytes, d.keyLen)
+	stmt.ColumnBytes(0, id[:])
+	d.ft.AddStoredKey(id)
+	return true
+}
+
+// EnsureLoaded ensures that the DBSet is loaded and ready to be used.
+func (d *DBSet) EnsureLoaded() error {
+	d.loadMtx.Lock()
+	defer d.loadMtx.Unlock()
+	if d.ft != nil {
+		return nil
+	}
+	var err error
+	d.snapshot, err = d.st.Snapshot(d.db)
+	if err != nil {
+		return fmt.Errorf("error taking snapshot: %w", err)
+	}
+	count, err := d.snapshot.LoadCount(d.db)
+	if err != nil {
+		return fmt.Errorf("error loading count: %w", err)
+	}
+	d.dbStore = fptree.NewDBBackedStore(d.db, d.snapshot, count, d.keyLen)
+	d.ft = fptree.NewFPTree(count, d.dbStore, d.keyLen, d.maxDepth)
+	return d.snapshot.Load(d.db, d.handleIDfromDB)
+}
+
+// Received returns a sequence of all items that have been received.
+// Implements rangesync.OrderedSet.
+func (d *DBSet) Received() rangesync.SeqResult {
+	return rangesync.SeqResult{
+		Seq: func(yield func(k rangesync.KeyBytes) bool) {
+			for k := range d.received {
+				if !yield(rangesync.KeyBytes(k)) {
+					return
+				}
+			}
+		},
+		Error: rangesync.NoSeqError,
+	}
+}
+
+// Add adds an item to the DBSet.
+// Implements rangesync.OrderedSet.
+func (d *DBSet) Add(k rangesync.KeyBytes) error {
+	if has, err := d.Has(k); err != nil {
+		return fmt.Errorf("checking if item exists: %w", err)
+	} else if has {
+		return nil
+	}
+	d.ft.RegisterKey(k)
+	return nil
+}
+
+// Receive handles a newly received item, arranging for it to be returned as part of the
+// sequence returned by Received.
+// Implements rangesync.OrderedSet.
+func (d *DBSet) Receive(k rangesync.KeyBytes) error {
+	if d.received == nil {
+		d.received = make(map[string]struct{})
+	}
+	d.received[string(k)] = struct{}{}
+	return nil
+}
+
+func (d *DBSet) firstItem() (rangesync.KeyBytes, error) {
+	if err := d.EnsureLoaded(); err != nil {
+		return nil, err
+	}
+	return d.ft.All().First()
+}
+
+// GetRangeInfo returns information about the range of items in the DBSet.
+// Implements rangesync.OrderedSet.
+func (d *DBSet) GetRangeInfo(x, y rangesync.KeyBytes) (rangesync.RangeInfo, error) {
+	if err := d.EnsureLoaded(); err != nil {
+		return rangesync.RangeInfo{}, err
+	}
+	if d.ft.Count() == 0 {
+		return rangesync.RangeInfo{
+			Items: rangesync.EmptySeqResult(),
+		}, nil
+	}
+	if x == nil || y == nil {
+		if x != nil || y != nil {
+			panic("BUG: GetRangeInfo called with one of x/y nil but not both")
+		}
+		var err error
+		x, err = d.firstItem()
+		if err != nil {
+			return rangesync.RangeInfo{}, fmt.Errorf("getting first item: %w", err)
+		}
+		y = x
+	}
+	fpr, err := d.ft.FingerprintInterval(x, y, -1)
+	if err != nil {
+		return rangesync.RangeInfo{}, err
+	}
+	return rangesync.RangeInfo{
+		Fingerprint: fpr.FP,
+		Count:       int(fpr.Count),
+		Items:       fpr.Items,
+	}, nil
+}
+
+// SplitRange splits the range of items in the DBSet into two parts,
+// returning information about each part and the middle item.
+// Implements rangesync.OrderedSet.
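+//
+// Intended call pattern (a sketch with hypothetical x, y and count):
+//
+//	sr, err := set.SplitRange(x, y, count)
+//	// sr.Parts[0] describes the first count items of [x, y),
+//	// sr.Parts[1] the remainder, and sr.Middle is the key separating them.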
+func (d *DBSet) SplitRange(x, y rangesync.KeyBytes, count int) (rangesync.SplitInfo, error) { + if count <= 0 { + panic("BUG: bad split count") + } + + if err := d.EnsureLoaded(); err != nil { + return rangesync.SplitInfo{}, err + } + + sr, err := d.ft.Split(x, y, count) + if err != nil { + return rangesync.SplitInfo{}, err + } + + return rangesync.SplitInfo{ + Parts: [2]rangesync.RangeInfo{ + { + Fingerprint: sr.Part0.FP, + Count: int(sr.Part0.Count), + Items: sr.Part0.Items, + }, + { + Fingerprint: sr.Part1.FP, + Count: int(sr.Part1.Count), + Items: sr.Part1.Items, + }, + }, + Middle: sr.Middle, + }, nil +} + +// Items returns a sequence of all items in the DBSet. +// Implements rangesync.OrderedSet. +func (d *DBSet) Items() rangesync.SeqResult { + if err := d.EnsureLoaded(); err != nil { + return rangesync.ErrorSeqResult(err) + } + return d.ft.All() +} + +// Empty returns true if the DBSet is empty. +// Implements rangesync.OrderedSet. +func (d *DBSet) Empty() (bool, error) { + if err := d.EnsureLoaded(); err != nil { + return false, err + } + return d.ft.Count() == 0, nil +} + +// Advance advances the DBSet to the latest state of the underlying database table. +func (d *DBSet) Advance() error { + if err := d.EnsureLoaded(); err != nil { + return fmt.Errorf("loading DBSet: %w", err) + } + d.loadMtx.Lock() + defer d.loadMtx.Unlock() + oldSnapshot := d.snapshot + var err error + d.snapshot, err = d.st.Snapshot(d.db) + if err != nil { + return fmt.Errorf("error taking snapshot: %w", err) + } + d.dbStore.SetSnapshot(d.snapshot) + return d.snapshot.LoadSinceSnapshot(d.db, oldSnapshot, d.handleIDfromDB) +} + +// Copy creates a copy of the DBSet. +// Implements rangesync.OrderedSet. +func (d *DBSet) Copy(syncScope bool) rangesync.OrderedSet { + d.loadMtx.Lock() + defer d.loadMtx.Unlock() + if d.ft == nil { + // FIXME + panic("BUG: can't copy the DBItemStore before it's loaded") + } + ft := d.ft.Clone().(*fptree.FPTree) + return &DBSet{ + db: d.db, + ft: ft, + st: d.st, + keyLen: d.keyLen, + maxDepth: d.maxDepth, + dbStore: d.dbStore, + received: maps.Clone(d.received), + } +} + +// Has returns true if the DBSet contains the given item. +func (d *DBSet) Has(k rangesync.KeyBytes) (bool, error) { + if err := d.EnsureLoaded(); err != nil { + return false, err + } + + // checkKey may have false positives, but not false negatives, and it's much + // faster than querying the database + if !d.ft.CheckKey(k) { + return false, nil + } + + first, err := d.dbStore.From(k, 1).First() + if err != nil { + return false, err + } + return first.Compare(k) == 0, nil +} + +// Recent returns a sequence of items that have been added to the DBSet since the given time. +func (d *DBSet) Recent(since time.Time) (rangesync.SeqResult, int) { + return d.dbStore.Since(make(rangesync.KeyBytes, d.keyLen), since.UnixNano()) +} + +// Release releases resources associated with the DBSet. 
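+//
+// Release is safe to call more than once: the FPTree reference is cleared on
+// the first call, so repeated calls (or releasing a never-loaded set) are no-ops.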
+func (d *DBSet) Release() error { + d.loadMtx.Lock() + defer d.loadMtx.Unlock() + if d.ft != nil { + d.ft.Release() + d.ft = nil + } + return nil +} diff --git a/sync2/dbset/dbset_test.go b/sync2/dbset/dbset_test.go new file mode 100644 index 0000000000..23625d6444 --- /dev/null +++ b/sync2/dbset/dbset_test.go @@ -0,0 +1,359 @@ +package dbset_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sync2/dbset" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +const ( + testKeyLen = 32 + testDepth = 24 +) + +func requireEmpty(t *testing.T, sr rangesync.SeqResult) { + for range sr.Seq { + require.Fail(t, "expected an empty sequence") + } + require.NoError(t, sr.Error()) +} + +func firstKey(t *testing.T, sr rangesync.SeqResult) rangesync.KeyBytes { + k, err := sr.First() + require.NoError(t, err) + return k +} + +func TestDBSet_Empty(t *testing.T) { + db := sqlstore.PopulateDB(t, testKeyLen, nil) + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := dbset.NewDBSet(db, st, testKeyLen, testDepth) + defer s.Release() + empty, err := s.Empty() + require.NoError(t, err) + require.True(t, empty) + requireEmpty(t, s.Items()) + requireEmpty(t, s.Received()) + + info, err := s.GetRangeInfo(nil, nil) + require.NoError(t, err) + require.Equal(t, 0, info.Count) + require.Equal(t, "000000000000000000000000", info.Fingerprint.String()) + requireEmpty(t, info.Items) + + info, err = s.GetRangeInfo( + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000")) + require.NoError(t, err) + require.Equal(t, 0, info.Count) + require.Equal(t, "000000000000000000000000", info.Fingerprint.String()) + requireEmpty(t, info.Items) + + info, err = s.GetRangeInfo( + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("9999000000000000000000000000000000000000000000000000000000000000")) + require.NoError(t, err) + require.Equal(t, 0, info.Count) + require.Equal(t, "000000000000000000000000", info.Fingerprint.String()) + requireEmpty(t, info.Items) +} + +func TestDBSet(t *testing.T) { + ids := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + rangesync.MustParseHexKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000"), + } + db := sqlstore.PopulateDB(t, testKeyLen, ids) + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := dbset.NewDBSet(db, st, testKeyLen, testDepth) + defer s.Release() + require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", + firstKey(t, s.Items()).String()) + has, err := s.Has( + rangesync.MustParseHexKeyBytes("9876000000000000000000000000000000000000000000000000000000000000")) + require.NoError(t, err) + require.False(t, has) + + for _, tc := range []struct { + xIdx, yIdx int + limit int + fp string + count int + startIdx, endIdx int + }{ + { + xIdx: 1, + yIdx: 1, + limit: -1, 
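+			// x == y requests a wraparound interval covering the whole set (all 5 ids)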
+			fp:       "642464b773377bbddddddddd",
+			count:    5,
+			startIdx: 1,
+			endIdx:   1,
+		},
+		{
+			xIdx:     -1,
+			yIdx:     -1,
+			limit:    -1,
+			fp:       "642464b773377bbddddddddd",
+			count:    5,
+			startIdx: 0,
+			endIdx:   0,
+		},
+		{
+			xIdx:     0,
+			yIdx:     3,
+			limit:    -1,
+			fp:       "4761032dcfe98ba555555555",
+			count:    3,
+			startIdx: 0,
+			endIdx:   3,
+		},
+		{
+			xIdx:     2,
+			yIdx:     0,
+			limit:    -1,
+			fp:       "761032cfe98ba54ddddddddd",
+			count:    3,
+			startIdx: 2,
+			endIdx:   0,
+		},
+		{
+			xIdx:     3,
+			yIdx:     2,
+			limit:    3,
+			fp:       "2345679abcdef01888888888",
+			count:    3,
+			startIdx: 3,
+			endIdx:   1,
+		},
+	} {
+		name := fmt.Sprintf("%d-%d_%d", tc.xIdx, tc.yIdx, tc.limit)
+		t.Run(name, func(t *testing.T) {
+			var x, y rangesync.KeyBytes
+			if tc.xIdx >= 0 {
+				x = ids[tc.xIdx]
+				y = ids[tc.yIdx]
+			}
+			t.Logf("x %v y %v limit %d", x, y, tc.limit)
+			var info rangesync.RangeInfo
+			if tc.limit < 0 {
+				info, err = s.GetRangeInfo(x, y)
+				require.NoError(t, err)
+			} else {
+				sr, err := s.SplitRange(x, y, tc.limit)
+				require.NoError(t, err)
+				info = sr.Parts[0]
+			}
+			require.Equal(t, tc.count, info.Count)
+			require.Equal(t, tc.fp, info.Fingerprint.String())
+			require.Equal(t, ids[tc.startIdx], firstKey(t, info.Items))
+			has, err := s.Has(ids[tc.startIdx])
+			require.NoError(t, err)
+			require.True(t, has)
+			has, err = s.Has(ids[tc.endIdx])
+			require.NoError(t, err)
+			require.True(t, has)
+		})
+	}
+}
+
+func TestDBSet_Receive(t *testing.T) {
+	ids := []rangesync.KeyBytes{
+		rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"),
+		rangesync.MustParseHexKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"),
+		rangesync.MustParseHexKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"),
+		rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"),
+	}
+	db := sqlstore.PopulateDB(t, testKeyLen, ids)
+	st := &sqlstore.SyncedTable{
+		TableName: "foo",
+		IDColumn:  "id",
+	}
+	s := dbset.NewDBSet(db, st, testKeyLen, testDepth)
+	defer s.Release()
+	require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000",
+		firstKey(t, s.Items()).String())
+
+	newID := rangesync.MustParseHexKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000")
+	require.NoError(t, s.Receive(newID))
+
+	recvd := s.Received()
+	items, err := recvd.FirstN(1)
+	require.NoError(t, err)
+	require.Equal(t, []rangesync.KeyBytes{newID}, items)
+
+	info, err := s.GetRangeInfo(ids[2], ids[0])
+	require.NoError(t, err)
+	require.Equal(t, 2, info.Count)
+	require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.String())
+}
+
+func TestDBSet_Copy(t *testing.T) {
+	ids := []rangesync.KeyBytes{
+		rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"),
+		rangesync.MustParseHexKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"),
+		rangesync.MustParseHexKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"),
+		rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"),
+	}
+	db := sqlstore.PopulateDB(t, testKeyLen, ids)
+	st := &sqlstore.SyncedTable{
+		TableName: "foo",
+		IDColumn:  "id",
+	}
+	s := dbset.NewDBSet(db, st, testKeyLen, testDepth)
+	defer s.Release()
+	require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000",
+		firstKey(t, s.Items()).String())
+
+	copy := s.Copy(false)
+
+	info, err := copy.GetRangeInfo(ids[2], ids[0])
+	require.NoError(t, err)
+
require.Equal(t, 2, info.Count) + require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[2], firstKey(t, info.Items)) + + newID := rangesync.MustParseHexKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000") + require.NoError(t, copy.Receive(newID)) + + info, err = s.GetRangeInfo(ids[2], ids[0]) + require.NoError(t, err) + require.Equal(t, 2, info.Count) + require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[2], firstKey(t, info.Items)) + + items, err := s.Received().FirstN(100) + require.NoError(t, err) + require.Empty(t, items) + + info, err = s.GetRangeInfo(ids[2], ids[0]) + require.NoError(t, err) + require.Equal(t, 2, info.Count) + require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[2], firstKey(t, info.Items)) + + items, err = copy.(*dbset.DBSet).Received().FirstN(100) + require.NoError(t, err) + require.Equal(t, []rangesync.KeyBytes{newID}, items) +} + +func TestDBItemStore_Advance(t *testing.T) { + ids := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + } + db := sqlstore.PopulateDB(t, testKeyLen, ids) + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := dbset.NewDBSet(db, st, testKeyLen, testDepth) + defer s.Release() + require.NoError(t, s.EnsureLoaded()) + + copy := s.Copy(false) + + info, err := s.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = copy.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + sqlstore.InsertDBItems(t, db, []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000"), + }) + + info, err = s.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = copy.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + require.NoError(t, s.Advance()) + + info, err = s.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 5, info.Count) + require.Equal(t, "642464b773377bbddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = copy.GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = s.Copy(false).GetRangeInfo(ids[0], ids[0]) + require.NoError(t, err) + require.Equal(t, 5, info.Count) + require.Equal(t, "642464b773377bbddddddddd", 
info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) +} + +func TestDBSet_Added(t *testing.T) { + ids := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + rangesync.MustParseHexKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + rangesync.MustParseHexKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000"), + } + db := sqlstore.PopulateDB(t, testKeyLen, ids) + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := dbset.NewDBSet(db, st, testKeyLen, testDepth) + defer s.Release() + requireEmpty(t, s.Received()) + + add := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("3333333333333333333333333333333333333333333333333333333333333333"), + rangesync.MustParseHexKeyBytes("4444444444444444444444444444444444444444444444444444444444444444"), + } + for _, item := range add { + require.NoError(t, s.Receive(item)) + } + + require.NoError(t, s.EnsureLoaded()) + + added, err := s.Received().FirstN(3) + require.NoError(t, err) + require.ElementsMatch(t, []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("3333333333333333333333333333333333333333333333333333333333333333"), + rangesync.MustParseHexKeyBytes("4444444444444444444444444444444444444444444444444444444444444444"), + }, added) + + added1, err := s.Copy(false).(*dbset.DBSet).Received().FirstN(3) + require.NoError(t, err) + require.ElementsMatch(t, added, added1) +} diff --git a/sync2/dbset/p2p_test.go b/sync2/dbset/p2p_test.go new file mode 100644 index 0000000000..9321f0fcc4 --- /dev/null +++ b/sync2/dbset/p2p_test.go @@ -0,0 +1,567 @@ +package dbset_test + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "slices" + "testing" + "time" + + "github.com/jonboulle/clockwork" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/p2p/server" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/dbset" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +var startDate = time.Date(2024, 8, 29, 18, 0, 0, 0, time.UTC) + +type fooRow struct { + id rangesync.KeyBytes + ts int64 +} + +func insertRow(t testing.TB, db sql.Executor, row fooRow) { + _, err := db.Exec( + "insert into foo (id, received) values (?, ?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, row.id) + stmt.BindInt64(2, row.ts) + }, nil) + require.NoError(t, err) +} + +func populateFoo(t testing.TB, rows []fooRow) (db sql.Database, dir string) { + // Use file-based database for more accurate benchmarks + dir = t.TempDir() + db, err := sql.Open("file:"+filepath.Join(dir, "temp.db"), + sql.WithNoCheckSchemaDrift()) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + require.NoError(t, db.WithTx(context.Background(), func(tx sql.Transaction) error { + _, err := tx.Exec( + "create table foo(id char(32) not null primary key, received int)", + nil, nil) + require.NoError(t, err) + for _, row := range rows { + 
insertRow(t, tx, row)
+		}
+		return nil
+	}))
+	return db, dir
+}
+
+type syncTracer struct {
+	receivedItems int
+	sentItems     int
+}
+
+var _ rangesync.Tracer = &syncTracer{}
+
+func (tr *syncTracer) OnDumbSync() {}
+
+func (tr *syncTracer) OnRecent(receivedItems, sentItems int) {
+	tr.receivedItems += receivedItems
+	tr.sentItems += sentItems
+}
+
+func addReceived(t testing.TB, db sql.Executor, to, from *dbset.DBSet) {
+	sr := from.Received()
+	for k := range sr.Seq {
+		has, err := to.Has(k)
+		require.NoError(t, err)
+		if !has {
+			insertRow(t, db, fooRow{id: k, ts: time.Now().UnixNano()})
+		}
+	}
+	require.NoError(t, sr.Error())
+	require.NoError(t, to.Advance())
+}
+
+type startStopTimer interface {
+	StartTimer()
+	StopTimer()
+}
+
+func startTimer(tb testing.TB) {
+	if st, ok := tb.(startStopTimer); ok {
+		st.StartTimer()
+	}
+}
+
+func stopTimer(tb testing.TB) {
+	if st, ok := tb.(startStopTimer); ok {
+		st.StopTimer()
+	}
+}
+
+func dbFromRows(t testing.TB, rows []fooRow) sql.Transaction {
+	db, _ := populateFoo(t, rows)
+	tx, err := db.Tx(context.Background())
+	require.NoError(t, err)
+	t.Cleanup(func() { tx.Release() })
+	return tx
+}
+
+func verifyP2P(
+	t testing.TB,
+	rowsA, rowsB []fooRow,
+	combined []rangesync.KeyBytes,
+	clockAt time.Time,
+	receivedRecent, sentRecent bool,
+	maxDepth int,
+	cfg rangesync.RangeSetReconcilerConfig,
+) {
+	stopTimer(t)
+	dbA := dbFromRows(t, rowsA)
+	dbB := dbFromRows(t, rowsB)
+	runSync(t, dbA, dbB, combined, clockAt, receivedRecent, sentRecent, true, maxDepth, cfg)
+}
+
+func runSync(
+	t testing.TB,
+	dbA, dbB sql.Executor,
+	combined []rangesync.KeyBytes,
+	clockAt time.Time,
+	receivedRecent, sentRecent, verify bool,
+	maxDepth int,
+	cfg rangesync.RangeSetReconcilerConfig,
+) {
+	log := zaptest.NewLogger(t)
+	mesh, err := mocknet.FullMeshConnected(2)
+	require.NoError(t, err)
+	proto := "itest"
+	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
+	st := &sqlstore.SyncedTable{
+		TableName:       "foo",
+		IDColumn:        "id",
+		TimestampColumn: "received",
+	}
+
+	t.Logf("using maxDepth %d", maxDepth)
+
+	setA := dbset.NewDBSet(dbA, st, testKeyLen, maxDepth)
+	loadStart := time.Now()
+	require.NoError(t, setA.EnsureLoaded())
+	t.Logf("loaded setA in %v", time.Since(loadStart))
+
+	setB := dbset.NewDBSet(dbB, st, testKeyLen, maxDepth)
+	loadStart = time.Now()
+	require.NoError(t, setB.EnsureLoaded())
+	t.Logf("loaded setB in %v", time.Since(loadStart))
+
+	empty, err := setB.Empty()
+	require.NoError(t, err)
+	var x rangesync.KeyBytes
+	if !empty {
+		k, err := setB.Items().First()
+		require.NoError(t, err)
+		// assign to the outer x (a `:=` here would shadow it and leave it nil)
+		x = k.Clone()
+		x.Trim(maxDepth)
+	}
+
+	var tr syncTracer
+
+	srvPeerID := mesh.Hosts()[0].ID()
+	clock := clockwork.NewFakeClockAt(clockAt)
+	// Use the following to enable verbose logging which may slow down the tests
+	// syncLogger := log
+	syncLogger := zap.NewNop()
+	cfg.MaxReconcDiff = 1 // always reconcile
+	pssA := rangesync.NewPairwiseSetSyncerInternal(syncLogger.Named("sideA"), nil, "test", cfg, &tr, clock)
+	d := rangesync.NewDispatcher(log)
+	syncSetA := setA.Copy(false).(*dbset.DBSet)
+	pssA.Register(d, syncSetA)
+	srv := server.New(mesh.Hosts()[0], proto,
+		d.Dispatch,
+		server.WithTimeout(time.Minute),
+		server.WithLog(log))
+
+	var eg errgroup.Group
+
+	client := server.New(mesh.Hosts()[1], proto,
+		func(_ context.Context, _ p2p.Peer, _ []byte, _ io.ReadWriter) error {
+			return errors.New("client should not receive requests")
+		},
+		server.WithTimeout(time.Minute),
+		server.WithLog(log))
+
+	defer func() 
{ + cancel() + eg.Wait() + }() + eg.Go(func() error { + return srv.Run(ctx) + }) + + // Wait for the server to activate + require.Eventually(t, func() bool { + for _, h := range mesh.Hosts() { + if len(h.Mux().Protocols()) == 0 { + return false + } + } + return true + }, time.Second, 10*time.Millisecond) + + startTimer(t) + pssB := rangesync.NewPairwiseSetSyncerInternal(syncLogger.Named("sideB"), client, "test", cfg, &tr, clock) + + tStart := time.Now() + syncSetB := setB.Copy(false).(*dbset.DBSet) + require.NoError(t, pssB.Sync(ctx, srvPeerID, syncSetB, x, x)) + stopTimer(t) + t.Logf("synced in %v, sent %d, recv %d", time.Since(tStart), pssB.Sent(), pssB.Received()) + + if verify { + // Check that the sets are equal after we add the received items + addReceived(t, dbA, setA, syncSetA) + addReceived(t, dbB, setB, syncSetB) + + require.Equal(t, receivedRecent, tr.receivedItems > 0) + require.Equal(t, sentRecent, tr.sentItems > 0) + + if len(combined) == 0 { + return + } + + actItemsA, err := setA.Items().Collect() + require.NoError(t, err) + + actItemsB, err := setB.Items().Collect() + require.NoError(t, err) + + assert.Equal(t, combined, actItemsA) + assert.Equal(t, actItemsA, actItemsB) + } +} + +func fooR(id string, seconds int) fooRow { + return fooRow{ + rangesync.MustParseHexKeyBytes(id), + startDate.Add(time.Duration(seconds) * time.Second).UnixNano(), + } +} + +func genRandomRows(nShared, nUniqueA, nUniqueB int) (rowsA, rowsB []fooRow, combined []rangesync.KeyBytes) { + combined = make([]rangesync.KeyBytes, 0, nShared+nUniqueA+nUniqueB) + rowsA = make([]fooRow, nShared+nUniqueA) + for i := range rowsA { + k := rangesync.RandomKeyBytes(testKeyLen) + rowsA[i] = fooRow{ + id: k, + ts: startDate.Add(time.Duration(i) * time.Second).UnixNano(), + } + combined = append(combined, k) + } + rowsB = make([]fooRow, nShared+nUniqueB) + for i := range rowsB { + if i < nShared { + rowsB[i] = fooRow{ + id: slices.Clone(rowsA[i].id), + ts: rowsA[i].ts, + } + } else { + k := rangesync.RandomKeyBytes(testKeyLen) + rowsB[i] = fooRow{ + id: k, + ts: startDate.Add(time.Duration(i) * time.Second).UnixNano(), + } + combined = append(combined, k) + } + } + slices.SortFunc(combined, func(a, b rangesync.KeyBytes) int { + return a.Compare(b) + }) + return rowsA, rowsB, combined +} + +func TestP2P(t *testing.T) { + // In this test, we synchronize two sets of items, A and B, and verify that they + // are equal. The sets are represented by two SQLite databases, each containing a + // table `foo` with columns `id` and `received`. The `id` column is a 32-byte id, + // and the `received` column is a timestamp in nanoseconds which is used to test + // recent sync mechanism. 
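+	//
+	// For reference, both sides use the table created by populateFoo above:
+	//
+	//	create table foo(id char(32) not null primary key, received int)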
+ const maxDepth = 24 + hexID := rangesync.MustParseHexKeyBytes + t.Run("predefined items", func(t *testing.T) { + verifyP2P( + t, []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 10), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 20), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 30), + fooR("abcdef1234567890000000000000000000000000000000000000000000000000", 40), + }, + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), + }, + []rangesync.KeyBytes{ + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + hexID("abcdef1234567890000000000000000000000000000000000000000000000000"), + }, + startDate, + false, + false, + maxDepth, + rangesync.DefaultConfig(), + ) + }) + t.Run("predefined items 2", func(t *testing.T) { + verifyP2P( + t, []fooRow{ + fooR("0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0", 10), + fooR("3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187", 20), + fooR("66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3", 30), + fooR("72e1adaaf140d809a5da325a197341a453b00807ef8d8995fd3c8079b917c9d7", 40), + fooR("782c24553b0a8cf1d95f632054b7215be192facfb177cfd1312901dd4c9e0bfd", 50), + fooR("9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", 60), + }, + []fooRow{ + fooR("0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0", 11), + fooR("3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187", 12), + fooR("66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3", 13), + fooR("90b25f2d1ee9c9e2d20df5f2226d14ee4223ea27ba565a49aa66a9c44a51c241", 14), + fooR("9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", 15), + fooR("c1690e47798295cca02392cbfc0a86cb5204878c04a29b3ae7701b6b51681128", 16), + }, + []rangesync.KeyBytes{ + hexID("0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0"), + hexID("3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187"), + hexID("66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3"), + hexID("72e1adaaf140d809a5da325a197341a453b00807ef8d8995fd3c8079b917c9d7"), + hexID("782c24553b0a8cf1d95f632054b7215be192facfb177cfd1312901dd4c9e0bfd"), + hexID("90b25f2d1ee9c9e2d20df5f2226d14ee4223ea27ba565a49aa66a9c44a51c241"), + hexID("9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5"), + hexID("c1690e47798295cca02392cbfc0a86cb5204878c04a29b3ae7701b6b51681128"), + }, + startDate, + false, + false, + maxDepth, + rangesync.DefaultConfig(), + ) + }) + t.Run("predefined items 3", func(t *testing.T) { + verifyP2P( + t, []fooRow{ + fooR("08addda193ce5c8dfa56d58efaaaa51ccb534738027c4c73631f76811702e54f", 5), + fooR("112d34ac1724faa17502e9f1654808daa43d8e99c384c42faeccc6c713993079", 3), + fooR("8599b0264623ede5d198fd2caa537720e011ce17bd9f34c140de269f472a1126", 4), + fooR("9e8dc977998b3cbc30071202cb8ebb0c8bfa2c400fd28067f6d43c7e92acd077", 2), + fooR("a67249d334bd0c68e92b4c6d8716cdc218130c0e765838e52890133d07d35d48", 
0), + fooR("e7f3c0ecf1410711cf16d8188dc0075f10d17c95208bdbf7a5c910a0ecb68085", 1), + }, + []fooRow{ + fooR("112d34ac1724faa17502e9f1654808daa43d8e99c384c42faeccc6c713993079", 3), + fooR("8599b0264623ede5d198fd2caa537720e011ce17bd9f34c140de269f472a1126", 4), + fooR("9e8dc977998b3cbc30071202cb8ebb0c8bfa2c400fd28067f6d43c7e92acd077", 2), + fooR("a67249d334bd0c68e92b4c6d8716cdc218130c0e765838e52890133d07d35d48", 0), + fooR("dc5938b62a49a31e947d48d85cf358a77dbbed0f3ad5d06e2df63da3cbe7c80a", 5), + fooR("e7f3c0ecf1410711cf16d8188dc0075f10d17c95208bdbf7a5c910a0ecb68085", 1), + }, + []rangesync.KeyBytes{ + hexID("08addda193ce5c8dfa56d58efaaaa51ccb534738027c4c73631f76811702e54f"), + hexID("112d34ac1724faa17502e9f1654808daa43d8e99c384c42faeccc6c713993079"), + hexID("8599b0264623ede5d198fd2caa537720e011ce17bd9f34c140de269f472a1126"), + hexID("9e8dc977998b3cbc30071202cb8ebb0c8bfa2c400fd28067f6d43c7e92acd077"), + hexID("a67249d334bd0c68e92b4c6d8716cdc218130c0e765838e52890133d07d35d48"), + hexID("dc5938b62a49a31e947d48d85cf358a77dbbed0f3ad5d06e2df63da3cbe7c80a"), + hexID("e7f3c0ecf1410711cf16d8188dc0075f10d17c95208bdbf7a5c910a0ecb68085"), + }, + startDate, + false, + false, + maxDepth, + rangesync.DefaultConfig(), + ) + }) + t.Run("predefined items with recent", func(t *testing.T) { + cfg := rangesync.DefaultConfig() + cfg.RecentTimeSpan = 48 * time.Second + verifyP2P( + t, []fooRow{ + fooR("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236", 10), + fooR("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7", 20), + fooR("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90", 30), + fooR("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567", 40), + }, + []fooRow{ + fooR("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7", 11), + fooR("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701", 12), + fooR("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567", 13), + fooR("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f", 14), + }, + []rangesync.KeyBytes{ + hexID("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236"), + hexID("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), + hexID("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701"), + hexID("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90"), + hexID("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567"), + hexID("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f"), + }, + startDate.Add(time.Minute), + true, + true, + maxDepth, + cfg, + ) + }) + t.Run("empty to non-empty", func(t *testing.T) { + verifyP2P( + t, nil, + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), + }, + []rangesync.KeyBytes{ + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + }, + startDate, + false, + false, + maxDepth, + rangesync.DefaultConfig(), + ) + }) + t.Run("empty to non-empty with recent", func(t *testing.T) { + cfg := rangesync.DefaultConfig() + 
cfg.RecentTimeSpan = 48 * time.Second + verifyP2P( + t, nil, + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), + }, + []rangesync.KeyBytes{ + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + }, + startDate.Add(time.Minute), + true, + true, + maxDepth, + cfg, + ) + }) + t.Run("non-empty to empty with recent", func(t *testing.T) { + cfg := rangesync.DefaultConfig() + cfg.RecentTimeSpan = 48 * time.Second + verifyP2P( + t, + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), + }, + nil, + []rangesync.KeyBytes{ + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + }, + startDate.Add(time.Minute), + // no actual recent exchange happens due to the initial EmptySet message + false, + false, + maxDepth, + cfg, + ) + }) + t.Run("empty to empty", func(t *testing.T) { + verifyP2P(t, nil, nil, nil, startDate, false, false, maxDepth, rangesync.DefaultConfig()) + }) + t.Run("random test", func(t *testing.T) { + rowsA, rowsB, combined := genRandomRows(80000, 400, 800) + verifyP2P(t, rowsA, rowsB, combined, startDate, false, false, maxDepth, rangesync.DefaultConfig()) + }) +} + +func setupDBRandom( + t testing.TB, + nShared, nUniqueA, nUniqueB int, +) (dirA, dirB string) { + rowsA, rowsB, _ := genRandomRows(nShared, nUniqueA, nUniqueB) + dbA, dirA := populateFoo(t, rowsA) + dbA.Close() + dbB, dirB := populateFoo(t, rowsB) + dbB.Close() + return dirA, dirB +} + +func copyDB(t testing.TB, srcDir, dstDir string) sql.Transaction { + require.NoError(t, os.CopyFS(dstDir, os.DirFS(srcDir))) + db, err := sql.Open("file:"+filepath.Join(dstDir, "temp.db"), + sql.WithNoCheckSchemaDrift()) + require.NoError(t, err) + tx, err := db.Tx(context.Background()) + require.NoError(t, err) + t.Cleanup(func() { tx.Release() }) + return tx +} + +func verifyP2PRandom(t testing.TB, maxDepth int, dirA, dirB string) { + dbA := copyDB(t, dirA, t.TempDir()) + dbB := copyDB(t, dirB, t.TempDir()) + runSync(t, dbA, dbB, nil, startDate, false, false, false, maxDepth, rangesync.DefaultConfig()) +} + +func BenchmarkSyncSmallSet(b *testing.B) { + dirA, dirB := setupDBRandom(b, 800, 40, 80) + for i := 0; i < b.N; i++ { + verifyP2PRandom(b, 24, dirA, dirB) + } +} + +func BenchmarkSyncBigDiff(b *testing.B) { + dirA, dirB := setupDBRandom(b, 8_000_000, 100, 80_000) + for maxDepth := 16; maxDepth <= 24; maxDepth++ { + b.Run(fmt.Sprintf("maxDepth=%d", maxDepth), func(b *testing.B) { + for i := 0; i < b.N; i++ { + verifyP2PRandom(b, maxDepth, dirA, dirB) + } + }) + } +} + +func 
BenchmarkSyncSmallDiff(b *testing.B) {
+	dirA, dirB := setupDBRandom(b, 8_000_000, 10, 1000)
+	for maxDepth := 16; maxDepth <= 24; maxDepth++ {
+		b.Run(fmt.Sprintf("maxDepth=%d", maxDepth), func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				verifyP2PRandom(b, maxDepth, dirA, dirB)
+			}
+		})
+	}
+}
diff --git a/sync2/fptree/dbbackedstore.go b/sync2/fptree/dbbackedstore.go
new file mode 100644
index 0000000000..d0a3e6e72f
--- /dev/null
+++ b/sync2/fptree/dbbackedstore.go
@@ -0,0 +1,75 @@
+package fptree
+
+import (
+	"github.com/spacemeshos/go-spacemesh/sql"
+	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
+	"github.com/spacemeshos/go-spacemesh/sync2/sqlstore"
+)
+
+// DBBackedStore is an implementation of IDStore that keeps track of the rows in a
+// database table, using an FPTree to store items that have arrived from a sync peer.
+type DBBackedStore struct {
+	*sqlstore.SQLIDStore
+	*FPTree
+}
+
+var _ sqlstore.IDStore = &DBBackedStore{}
+
+// NewDBBackedStore creates a new DB-backed store.
+// sizeHint is the expected number of items added to the store via RegisterKey _after_
+// the store is created.
+func NewDBBackedStore(
+	db sql.Executor,
+	sts *sqlstore.SyncedTableSnapshot,
+	sizeHint int,
+	keyLen int,
+) *DBBackedStore {
+	return &DBBackedStore{
+		SQLIDStore: sqlstore.NewSQLIDStore(db, sts, keyLen),
+		FPTree:     NewFPTreeWithValues(sizeHint, keyLen),
+	}
+}
+
+// Clone creates a copy of the store.
+// Implements IDStore.Clone.
+func (s *DBBackedStore) Clone() sqlstore.IDStore {
+	return &DBBackedStore{
+		SQLIDStore: s.SQLIDStore.Clone().(*sqlstore.SQLIDStore),
+		FPTree:     s.FPTree.Clone().(*FPTree),
+	}
+}
+
+// RegisterKey adds a key to the store, using the FPTree so that the underlying database
+// table is unchanged.
+// Implements IDStore.
+func (s *DBBackedStore) RegisterKey(k rangesync.KeyBytes) error {
+	return s.FPTree.RegisterKey(k)
+}
+
+// All returns all the items currently in the store.
+// Implements IDStore.
+func (s *DBBackedStore) All() rangesync.SeqResult {
+	return rangesync.CombineSeqs(nil, s.SQLIDStore.All(), s.FPTree.All())
+}
+
+// From returns all the items in the store that are greater than or equal to the given key.
+// Implements IDStore.
+func (s *DBBackedStore) From(from rangesync.KeyBytes, sizeHint int) rangesync.SeqResult {
+	return rangesync.CombineSeqs(
+		from,
+		// There may be fewer than sizeHint items to be loaded from the database as
+		// some may be in the FPTree, but for most cases that will do.
+		s.SQLIDStore.From(from, sizeHint),
+		s.FPTree.From(from, sizeHint))
+}
+
+// SetSnapshot sets the table snapshot to be used by the store.
+func (s *DBBackedStore) SetSnapshot(sts *sqlstore.SyncedTableSnapshot) {
+	s.SQLIDStore.SetSnapshot(sts)
+	s.FPTree.Clear()
+}
+
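+// A minimal usage sketch (assumptions: a sql.Executor db and a snapshot sts
+// obtained via SyncedTable.Snapshot, as in TestDBBackedStore below; k is a
+// key of keyLen bytes):
+//
+//	store := NewDBBackedStore(db, sts, 0, keyLen)
+//	_ = store.RegisterKey(k) // kept in the FPTree; the table is unchanged
+//	sr := store.From(k, 5)   // combines snapshotted rows and registered keys
+
+// Release releases resources used by the store.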
+func (s *DBBackedStore) Release() { + s.FPTree.Release() +} diff --git a/sync2/fptree/dbbackedstore_test.go b/sync2/fptree/dbbackedstore_test.go new file mode 100644 index 0000000000..5c76c40e69 --- /dev/null +++ b/sync2/fptree/dbbackedstore_test.go @@ -0,0 +1,96 @@ +package fptree + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +func TestDBBackedStore(t *testing.T) { + const keyLen = 12 + db := sql.InMemoryTest(t) + _, err := db.Exec( + fmt.Sprintf("create table foo(id char(%d) not null primary key, received int)", keyLen), + nil, nil) + require.NoError(t, err) + for _, row := range []struct { + id rangesync.KeyBytes + ts int64 + }{ + { + id: rangesync.KeyBytes{0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 100, + }, + { + id: rangesync.KeyBytes{0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 200, + }, + { + id: rangesync.KeyBytes{0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 300, + }, + { + id: rangesync.KeyBytes{0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0}, + ts: 400, + }, + } { + _, err := db.Exec( + "insert into foo (id, received) values (?, ?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, row.id) + stmt.BindInt64(2, row.ts) + }, nil) + require.NoError(t, err) + } + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + TimestampColumn: "received", + } + sts, err := st.Snapshot(db) + require.NoError(t, err) + + store := NewDBBackedStore(db, sts, 0, keyLen) + actualIDs, err := store.From(rangesync.KeyBytes{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 5).FirstN(5) + require.NoError(t, err) + require.Equal(t, []rangesync.KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, // wrapped around + }, actualIDs) + + actualIDs1, err := store.All().FirstN(5) + require.NoError(t, err) + require.Equal(t, actualIDs, actualIDs1) + + sr, count := store.Since(rangesync.KeyBytes{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 300) + require.Equal(t, 2, count) + actualIDs, err = sr.FirstN(3) + require.NoError(t, err) + require.Equal(t, []rangesync.KeyBytes{ + {0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, // wrapped around + }, actualIDs) + + require.NoError(t, store.RegisterKey(rangesync.KeyBytes{0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0})) + require.NoError(t, store.RegisterKey(rangesync.KeyBytes{0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0})) + sr = store.From(rangesync.KeyBytes{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1) + actualIDs, err = sr.FirstN(6) + require.NoError(t, err) + require.Equal(t, []rangesync.KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0}, + }, actualIDs) +} diff --git a/sync2/fptree/export_test.go b/sync2/fptree/export_test.go new file mode 100644 index 0000000000..cd43cf160c --- /dev/null +++ b/sync2/fptree/export_test.go @@ -0,0 +1,20 @@ +package fptree + +import ( + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +var ErrEasySplitFailed = errEasySplitFailed + +func (ft *FPTree) EasySplit(x, y rangesync.KeyBytes, limit int) (sr SplitResult, err 
error) {
+	return ft.easySplit(x, y, limit)
+}
+
+func (ft *FPTree) PoolNodeCount() int {
+	return ft.np.nodeCount()
+}
+
+func (ft *FPTree) IDStore() sqlstore.IDStore {
+	return ft.idStore
+}
diff --git a/sync2/fptree/fptree.go b/sync2/fptree/fptree.go
new file mode 100644
index 0000000000..a712d6d52b
--- /dev/null
+++ b/sync2/fptree/fptree.go
@@ -0,0 +1,1295 @@
+package fptree
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"runtime"
+	"strconv"
+	"strings"
+
+	"github.com/spacemeshos/go-spacemesh/sync2/rangesync"
+	"github.com/spacemeshos/go-spacemesh/sync2/sqlstore"
+)
+
+var errEasySplitFailed = errors.New("easy split failed")
+
+const (
+	FingerprintSize = rangesync.FingerprintSize
+	// sizeHintCoef is used to calculate the number of pool entries to preallocate for
+	// an FPTree based on the expected number of items which this tree may contain.
+	sizeHintCoef = 2.1
+)
+
+// FPResult represents the result of a range fingerprint query against FPTree, as returned
+// by FingerprintInterval.
+type FPResult struct {
+	// Range fingerprint
+	FP rangesync.Fingerprint
+	// Number of items in the range
+	Count uint32
+	// Interval type: -1 for normal, 0 for the whole set, 1 for wrapped around ("inverse")
+	IType int
+	// Items in the range
+	Items rangesync.SeqResult
+	// The item following the range
+	Next rangesync.KeyBytes
+}
+
+// SplitResult represents the result of a split operation.
+type SplitResult struct {
+	// The two parts of the interval
+	Part0, Part1 FPResult
+	// Middle point value
+	Middle rangesync.KeyBytes
+}
+
+// aggContext is the context used for aggregation operations.
+type aggContext struct {
+	// nodePool used by the tree
+	np *nodePool
+	// Bounds of the interval being aggregated
+	x, y rangesync.KeyBytes
+	// The current fingerprint of the items aggregated so far, since the beginning or
+	// after the split ("easy split")
+	fp rangesync.Fingerprint
+	// The fingerprint of the items aggregated in the first part of the split
+	fp0 rangesync.Fingerprint
+	// Number of items aggregated so far, since the beginning or after the split
+	// ("easy split")
+	count uint32
+	// Number of items aggregated in the first part of the split
+	count0 uint32
+	// Interval type: -1 for normal, 0 for the whole set, 1 for wrapped around ("inverse")
+	itype int
+	// Maximum remaining number of items to aggregate.
+	limit int
+	// The number of items aggregated so far.
+	total uint32
+	// The resulting item sequence.
+	items rangesync.SeqResult
+	// The item immediately following the aggregated items.
+	next rangesync.KeyBytes
+	// The prefix corresponding to the last aggregated node.
+	lastPrefix *prefix
+	// The prefix corresponding to the last aggregated node in the first part of the split.
+	lastPrefix0 *prefix
+	// Whether the aggregation is being done for an "easy split" (split operation
+	// without querying the underlying IDStore).
+	easySplit bool
+}
+
+// prefixAtOrAfterX verifies that any key with the prefix p is at or after x.
+// It can be used for the whole interval in case of a normal interval.
+// With inverse intervals, it should only be used when processing the [x, max) part of the
+// interval.
+func (ac *aggContext) prefixAtOrAfterX(p prefix) bool {
+	b := make(rangesync.KeyBytes, len(ac.x))
+	p.minID(b)
+	return b.Compare(ac.x) >= 0
+}
+
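+// A worked example of the prefix bound checks (using hypothetical 1-byte keys
+// for brevity; real keys are keyLen bytes long): the 3-bit prefix 101 covers
+// the keys 0xa0..0xbf, so p.minID yields 0xa0 and p.idAfter yields 0xc0.
+// Thus prefixAtOrAfterX(p) holds for any x <= 0xa0, and prefixBelowY(p) holds
+// for any y >= 0xc0. For the all-ones prefix 111, idAfter wraps around to
+// zero, so prefixBelowY can never hold.
+
+// prefixBelowY verifies that any key with the prefix p is below y.
+// It can be used for the whole interval in case of a normal interval.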
+// With inverse intervals, it should only be used when processing the [0, y) part of the
+// interval.
+func (ac *aggContext) prefixBelowY(p prefix) bool {
+	b := make(rangesync.KeyBytes, len(ac.y))
+	// If p.idAfter(b) is true, this means there's wraparound and
+	// b is zero whereas all the possible keys beginning with prefix p
+	// are non-zero. In this case, there can be no key y such that
+	// all the keys beginning with prefix p are below y.
+	return !p.idAfter(b) && b.Compare(ac.y) <= 0
+}
+
+// fingerprintAtOrAfterX verifies that the specified fingerprint, which should be derived
+// from a single key, is at or after the x bound of the interval.
+func (ac *aggContext) fingerprintAtOrAfterX(fp rangesync.Fingerprint) bool {
+	k := make(rangesync.KeyBytes, len(ac.x))
+	copy(k, fp[:])
+	return k.Compare(ac.x) >= 0
+}
+
+// fingerprintBelowY verifies that the specified fingerprint, which should be derived from a
+// single key, is below the y bound of the interval.
+func (ac *aggContext) fingerprintBelowY(fp rangesync.Fingerprint) bool {
+	k := make(rangesync.KeyBytes, len(ac.x))
+	copy(k, fp[:])
+	k[:FingerprintSize].Inc() // 1 after max key derived from the fingerprint
+	return k.Compare(ac.y) <= 0
+}
+
+// nodeAtOrAfterX verifies that the node with the given index is at or after the x bound
+// of the interval.
+func (ac *aggContext) nodeAtOrAfterX(idx nodeIndex, p prefix) bool {
+	count, fp, _ := ac.np.info(idx)
+	if count == 1 {
+		v := ac.np.value(idx)
+		if v != nil {
+			return v.Compare(ac.x) >= 0
+		}
+		return ac.fingerprintAtOrAfterX(fp)
+	}
+	return ac.prefixAtOrAfterX(p)
+}
+
+// nodeBelowY verifies that the node with the given index is below the y bound of the interval.
+func (ac *aggContext) nodeBelowY(idx nodeIndex, p prefix) bool {
+	count, fp, _ := ac.np.info(idx)
+	if count == 1 {
+		v := ac.np.value(idx)
+		if v != nil {
+			return v.Compare(ac.y) < 0
+		}
+		return ac.fingerprintBelowY(fp)
+	}
+	return ac.prefixBelowY(p)
+}
+
+// pruneX returns true if the specified node can be pruned during left-aggregation because
+// all of its keys are below the x bound of the interval.
+func (ac *aggContext) pruneX(idx nodeIndex, p prefix) bool {
+	b := make(rangesync.KeyBytes, len(ac.x))
+	if !p.idAfter(b) && b.Compare(ac.x) <= 0 {
+		// idAfter derived from the prefix is at or below x => prune
+		return true
+	}
+	count, fp, _ := ac.np.info(idx)
+	if count > 1 {
+		// node has count > 1, so we can't use its fingerprint or value to
+		// determine if it's at or after X
+		return false
+	}
+	k := ac.np.value(idx)
+	if k != nil {
+		return k.Compare(ac.x) < 0
+	}
+
+	k = make(rangesync.KeyBytes, len(ac.x))
+	copy(k, fp[:])
+	k[:FingerprintSize].Inc() // 1 after max key derived from the fingerprint
+	return k.Compare(ac.x) <= 0
+}
+
+// pruneY returns true if the specified node can be pruned during right-aggregation
+// because all of its keys are at or after the y bound of the interval.
+func (ac *aggContext) pruneY(idx nodeIndex, p prefix) bool {
+	b := make(rangesync.KeyBytes, len(ac.y))
+	p.minID(b)
+	if b.Compare(ac.y) >= 0 {
+		// min ID derived from the prefix is at or after y => prune
+		return true
+	}
+
+	count, fp, _ := ac.np.info(idx)
+	if count > 1 {
+		// node has count > 1, so we can't use its fingerprint or value to
+		// determine if it's below y
+		return false
+	}
+	k := ac.np.value(idx)
+	if k == nil {
+		k = make(rangesync.KeyBytes, len(ac.y))
+		copy(k, fp[:])
+	}
+	return k.Compare(ac.y) >= 0
+}
+
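+// How the split accumulators interact (an illustration, assuming an easy
+// split with limit 3 over 5 single-item leaves): maybeIncludeNode charges
+// each fully included node against the limit; once the first 3 items are in,
+// the first part's fingerprint, count, and last prefix are saved into
+// fp0/count0/lastPrefix0 and the live accumulators are reset, so the
+// remaining 2 items aggregate into the second part.
+
+// switchToSecondPart switches aggregation to the second part of the "easy split".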
+func (ac *aggContext) switchToSecondPart() {
+	ac.limit = -1
+	ac.fp0 = ac.fp
+	ac.count0 = ac.count
+	ac.lastPrefix0 = ac.lastPrefix
+	clear(ac.fp[:])
+	ac.count = 0
+	ac.lastPrefix = nil
+}
+
+// maybeIncludeNode tries to include the full contents of the specified node in
+// the aggregation and returns whether it succeeded, based on the remaining limit and the
+// number of items in the node.
+// It also handles an "easy split" happening at the node.
+func (ac *aggContext) maybeIncludeNode(idx nodeIndex, p prefix) bool {
+	count, fp, leaf := ac.np.info(idx)
+	switch {
+	case ac.limit < 0:
+	case uint32(ac.limit) >= count:
+		ac.limit -= int(count)
+	case !ac.easySplit || !leaf:
+		return false
+	case ac.count == 0:
+		// We're doing a split and this node is over the limit, but the first part
+		// is still empty so we include this node in the first part and
+		// then switch to the second part
+		ac.limit = 0
+	default:
+		// We're doing a split and this node is over the limit, so store count and
+		// fingerprint for the first part and include the current node in the
+		// second part
+		ac.limit = -1
+		ac.fp0 = ac.fp
+		ac.count0 = ac.count
+		ac.lastPrefix0 = ac.lastPrefix
+		copy(ac.fp[:], fp[:])
+		ac.count = count
+		ac.lastPrefix = &p
+		return true
+	}
+	ac.fp.Update(fp[:])
+	ac.count += count
+	ac.lastPrefix = &p
+	if ac.easySplit && ac.limit == 0 {
+		// We're doing a split and this node is exactly at the limit, or it was
+		// above the limit but the first part was still empty, so store count and
+		// fingerprint for the first part which includes the current node and zero
+		// out count and fingerprint for the second part
+		ac.switchToSecondPart()
+	}
+	return true
+}
+
+// FPTree is a binary tree data structure designed to perform range fingerprint queries
+// efficiently.
+// FPTree can work on its own, with fingerprint query complexity being O(log n).
+// It can also be backed by an IDStore with a depth limit on the binary tree, in which
+// case query efficiency degrades as the number of items grows.
+// O(log n) query efficiency can be retained in this case for queries in which the number
+// of non-zero bits, starting from the high bit, is below maxDepth.
+// FPTree does not do any special balancing and relies on the IDs added to it being
+// uniformly distributed, which is the case for IDs based on cryptographic hashes.
+type FPTree struct {
+	trace
+	idStore  sqlstore.IDStore
+	np       *nodePool
+	root     nodeIndex
+	keyLen   int
+	maxDepth int
+}
+
+var _ sqlstore.IDStore = &FPTree{}
+
+// NewFPTreeWithValues creates an FPTree which also stores the items themselves and does
+// not make use of a backing IDStore.
+// sizeHint specifies the approximate expected number of items.
+// keyLen specifies the number of bytes in keys used.
+func NewFPTreeWithValues(sizeHint, keyLen int) *FPTree {
+	return NewFPTree(sizeHint, nil, keyLen, 0)
+}
+
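+// A minimal construction sketch (mirroring the test helpers in
+// fptree_test.go): a value-storing tree can itself serve as the IDStore
+// backing a depth-limited tree:
+//
+//	store := fptree.NewFPTreeWithValues(0, 32) // stores full 32-byte keys
+//	ft := fptree.NewFPTree(0, store, 32, 24)   // depth-limited to 24 bits
+//	_ = ft.RegisterKey(k)                      // k lands in both trees
+
+// NewFPTree creates an FPTree of limited depth backed by an IDStore.
+// sizeHint specifies the approximate expected number of items.
+// keyLen specifies the number of bytes in keys used.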
+func NewFPTree(sizeHint int, idStore sqlstore.IDStore, keyLen, maxDepth int) *FPTree {
+	var np nodePool
+	if sizeHint > 0 {
+		size := int(float64(sizeHint) * sizeHintCoef)
+		if maxDepth > 0 {
+			size = min(size, 1<<(maxDepth+1))
+		}
+		np.init(size)
+	}
+	if idStore == nil && maxDepth != 0 {
+		panic("BUG: NewFPTree: no idStore, but maxDepth specified")
+	}
+	ft := &FPTree{
+		np:       &np,
+		idStore:  idStore,
+		root:     noIndex,
+		keyLen:   keyLen,
+		maxDepth: maxDepth,
+	}
+	runtime.SetFinalizer(ft, (*FPTree).Release)
+	return ft
+}
+
+// traverse traverses the subtree rooted in idx in order and calls the given function for
+// each item.
+func (ft *FPTree) traverse(idx nodeIndex, yield func(rangesync.KeyBytes) bool) (res bool) {
+	ft.enter("traverse: idx %d", idx)
+	defer func() {
+		ft.leave(res)
+	}()
+	if idx == noIndex {
+		ft.log("no index")
+		return true
+	}
+	l := ft.np.left(idx)
+	r := ft.np.right(idx)
+	if l == noIndex && r == noIndex {
+		v := ft.np.value(idx)
+		if v != nil {
+			ft.log("yield value %s", shortened(v))
+		}
+		if v != nil && !yield(v) {
+			return false
+		}
+		return true
+	}
+	return ft.traverse(l, yield) && ft.traverse(r, yield)
+}
+
+// traverseFrom traverses the subtree rooted in idx in order and calls the given function
+// for each item starting from the given key.
+func (ft *FPTree) traverseFrom(
+	idx nodeIndex,
+	p prefix,
+	from rangesync.KeyBytes,
+	yield func(rangesync.KeyBytes) bool,
+) (res bool) {
+	ft.enter("traverseFrom: idx %d p %s from %s", idx, p, from)
+	defer func() {
+		ft.leave(res)
+	}()
+	if idx == noIndex {
+		return true
+	}
+	if p == emptyPrefix || ft.np.leaf(idx) {
+		v := ft.np.value(idx)
+		if v != nil && v.Compare(from) >= 0 {
+			ft.log("yield value %s", shortened(v))
+			if !yield(v) {
+				return false
+			}
+		}
+		return true
+	}
+	if !p.highBit() {
+		return ft.traverseFrom(ft.np.left(idx), p.shift(), from, yield) &&
+			ft.traverse(ft.np.right(idx), yield)
+	} else {
+		return ft.traverseFrom(ft.np.right(idx), p.shift(), from, yield)
+	}
+}
+
+// All returns all the items currently in the tree (including those in the IDStore).
+// Implements sqlstore.IDStore.
+func (ft *FPTree) All() rangesync.SeqResult {
+	ft.np.lockRead()
+	defer ft.np.unlockRead()
+	switch {
+	case ft.root == noIndex:
+		return rangesync.EmptySeqResult()
+	case ft.storeValues():
+		return rangesync.SeqResult{
+			Seq: func(yield func(rangesync.KeyBytes) bool) {
+				for {
+					if !ft.traverse(ft.root, yield) {
+						break
+					}
+				}
+			},
+			Error: rangesync.NoSeqError,
+		}
+	}
+	return ft.idStore.All()
+}
+
+// From returns all the items in the tree that are greater than or equal to the given key.
+// Implements sqlstore.IDStore.
+func (ft *FPTree) From(from rangesync.KeyBytes, sizeHint int) rangesync.SeqResult {
+	ft.np.lockRead()
+	defer ft.np.unlockRead()
+	switch {
+	case ft.root == noIndex:
+		return rangesync.EmptySeqResult()
+	case ft.storeValues():
+		return rangesync.SeqResult{
+			Seq: func(yield func(rangesync.KeyBytes) bool) {
+				p := prefixFromKeyBytes(from)
+				if !ft.traverseFrom(ft.root, p, from, yield) {
+					return
+				}
+				for {
+					if !ft.traverse(ft.root, yield) {
+						break
+					}
+				}
+			},
+			Error: rangesync.NoSeqError,
+		}
+	}
+	return ft.idStore.From(from, sizeHint)
+}
+
+// Release releases resources used by the tree.
+// Implements sqlstore.IDStore.
+func (ft *FPTree) Release() {
+	ft.np.lockWrite()
+	defer ft.np.unlockWrite()
+	ft.np.release(ft.root)
+	ft.root = noIndex
+	if ft.idStore != nil {
+		ft.idStore.Release()
+	}
+}
+
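+// Note that the sequences returned by All and From are cyclic: after yielding
+// the last item they wrap around to the smallest key, as exercised by
+// TestDBBackedStore. A sketch, assuming a tree holding keys k1 < k2 < k3:
+//
+//	ids, err := ft.From(k2, 1).FirstN(3) // -> [k2, k3, k1], wrapped around
+
+// Clear removes all items from the tree.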
+// It should only be used with trees that were created using NewFPTreeWithValues.
+func (ft *FPTree) Clear() {
+	if !ft.storeValues() {
+		// if we have an idStore, it can't be cleared and thus the tree can't be
+		// cleared either
+		panic("BUG: can only clear fpTree with values")
+	}
+	ft.Release()
+}
+
+// Clone makes a copy of the tree.
+// The copy operation is thread-safe and has complexity of O(1).
+func (ft *FPTree) Clone() sqlstore.IDStore {
+	ft.np.lockWrite()
+	defer ft.np.unlockWrite()
+	if ft.root != noIndex {
+		ft.np.ref(ft.root)
+	}
+	var idStore sqlstore.IDStore
+	if !ft.storeValues() {
+		idStore = ft.idStore.Clone()
+	}
+	return &FPTree{
+		np:       ft.np,
+		idStore:  idStore,
+		root:     ft.root,
+		keyLen:   ft.keyLen,
+		maxDepth: ft.maxDepth,
+	}
+}
+
+// pushLeafDown pushes a leaf node down the tree when the node's path matches that of the
+// new key being added, splitting it if necessary.
+func (ft *FPTree) pushLeafDown(
+	idx nodeIndex,
+	replace bool,
+	singleFP, prevFP rangesync.Fingerprint,
+	depth int,
+	curCount uint32,
+	value, prevValue rangesync.KeyBytes,
+) (newIdx nodeIndex) {
+	if idx == noIndex {
+		panic("BUG: pushLeafDown on a nonexistent node")
+	}
+	// Once we stumble upon a node with refCount > 1, we can no longer replace nodes
+	// as they're also referenced by another tree.
+	if replace && ft.np.refCount(idx) > 1 {
+		ft.np.releaseOne(idx)
+		replace = false
+	}
+	replace = replace && ft.np.refCount(idx) == 1
+	replaceIdx := noIndex
+	if replace {
+		replaceIdx = idx
+	}
+	fpCombined := rangesync.CombineFingerprints(singleFP, prevFP)
+	if ft.maxDepth != 0 && depth == ft.maxDepth {
+		newIdx = ft.np.add(fpCombined, curCount+1, noIndex, noIndex, nil, replaceIdx)
+		return newIdx
+	}
+	if curCount != 1 {
+		panic("BUG: pushDown of non-1-leaf below maxDepth")
+	}
+	dirA := singleFP.BitFromLeft(depth)
+	dirB := prevFP.BitFromLeft(depth)
+	if dirA == dirB {
+		// TODO: in the proper radix tree, these 1-child nodes should never be
+		// created, accumulating the prefix instead
+		childIdx := ft.pushLeafDown(idx, replace, singleFP, prevFP, depth+1, 1, value, prevValue)
+		if dirA {
+			newIdx = ft.np.add(fpCombined, 2, noIndex, childIdx, nil, noIndex)
+		} else {
+			newIdx = ft.np.add(fpCombined, 2, childIdx, noIndex, nil, noIndex)
+		}
+	} else {
+		idxA := ft.np.add(singleFP, 1, noIndex, noIndex, value, noIndex)
+		idxB := ft.np.add(prevFP, curCount, noIndex, noIndex, prevValue, replaceIdx)
+		if dirA {
+			newIdx = ft.np.add(fpCombined, 2, idxB, idxA, nil, noIndex)
+		} else {
+			newIdx = ft.np.add(fpCombined, 2, idxA, idxB, nil, noIndex)
+		}
+	}
+	return newIdx
+}
+
+// addValue adds a value to the subtree rooted in idx.
+func (ft *FPTree) addValue(
+	idx nodeIndex,
+	replace bool,
+	fp rangesync.Fingerprint,
+	depth int,
+	value rangesync.KeyBytes,
+) (newIdx nodeIndex) {
+	if idx == noIndex {
+		newIdx = ft.np.add(fp, 1, noIndex, noIndex, value, noIndex)
+		return newIdx
+	}
+	// Once we stumble upon a node with refCount > 1, we can no longer replace nodes
+	// as they're also referenced by another tree.
+	if replace && ft.np.refCount(idx) > 1 {
+		ft.np.releaseOne(idx)
+		replace = false
+	}
+	count, nodeFP, leaf := ft.np.info(idx)
+	left := ft.np.left(idx)
+	right := ft.np.right(idx)
+	nodeValue := ft.np.value(idx)
+	if leaf {
+		if count != 1 && (ft.maxDepth == 0 || depth != ft.maxDepth) {
+			panic("BUG: unexpected leaf node")
+		}
+		// we're at a leaf node, need to push down the old fingerprint, or,
+		// if we've reached the max depth, just update the current node
+		return ft.pushLeafDown(idx, replace, fp, nodeFP, depth, count, value, nodeValue)
+	}
+	replaceIdx := noIndex
+	if replace {
+		replaceIdx = idx
+	}
+	fpCombined := rangesync.CombineFingerprints(fp, nodeFP)
+	if fp.BitFromLeft(depth) {
+		newRight := ft.addValue(right, replace, fp, depth+1, value)
+		newIdx := ft.np.add(fpCombined, count+1, left, newRight, nil, replaceIdx)
+		if !replace && left != noIndex {
+			// the original node is not being replaced, so the reused left
+			// node has acquired another reference
+			ft.np.ref(left)
+		}
+		return newIdx
+	} else {
+		newLeft := ft.addValue(left, replace, fp, depth+1, value)
+		newIdx := ft.np.add(fpCombined, count+1, newLeft, right, nil, replaceIdx)
+		if !replace && right != noIndex {
+			// the original node is not being replaced, so the reused right
+			// node has acquired another reference
+			ft.np.ref(right)
+		}
+		return newIdx
+	}
+}
+
+// AddStoredKey adds a key to the tree, assuming that either the tree doesn't have an
+// IDStore or the IDStore already contains the key.
+func (ft *FPTree) AddStoredKey(k rangesync.KeyBytes) {
+	var fp rangesync.Fingerprint
+	fp.Update(k)
+	ft.log("addStoredHash: h %s fp %s", k, fp)
+	var v rangesync.KeyBytes
+	if ft.storeValues() {
+		v = k
+	}
+	ft.np.lockWrite()
+	defer ft.np.unlockWrite()
+	ft.root = ft.addValue(ft.root, true, fp, 0, v)
+}
+
+// RegisterKey registers a key in the tree.
+// If the tree has an IDStore, the key is also registered with the IDStore.
+func (ft *FPTree) RegisterKey(k rangesync.KeyBytes) error {
+	ft.log("addHash: k %s", k)
+	if !ft.storeValues() {
+		if err := ft.idStore.RegisterKey(k); err != nil {
+			return err
+		}
+	}
+	ft.AddStoredKey(k)
+	return nil
+}
+
+// storeValues returns true if the tree stores the values (has no IDStore).
+func (ft *FPTree) storeValues() bool {
+	return ft.idStore == nil
+}
+
+// CheckKey returns true if the tree contains or may contain the given key.
+// If this function returns false, the tree definitely doesn't contain the key.
+// If this function returns true and the tree stores the values, the key is definitely
+// contained in the tree.
+// If this function returns true and the tree doesn't store the values, the key may be
+// contained in the tree.
+func (ft *FPTree) CheckKey(k rangesync.KeyBytes) bool {
+	// We're unlikely to be able to find a node with the full prefix, but if we can
+	// find a leaf node with matching partial prefix, that's good enough except
+	// that we also need to check the node's fingerprint.
+	idx, _, _ := ft.followPrefix(ft.root, prefixFromKeyBytes(k), emptyPrefix)
+	if idx == noIndex {
+		return false
+	}
+	count, fp, _ := ft.np.info(idx)
+	if count != 1 {
+		return true
+	}
+	var kFP rangesync.Fingerprint
+	kFP.Update(k)
+	return fp == kFP
+}
+
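+// A membership-check sketch (with keys registered as above; RandomKeyBytes is
+// from the rangesync package, as used in fptree_test.go): CheckKey gives a
+// definite "no" but only a probable "yes", since distinct keys can share a
+// partial prefix path:
+//
+//	_ = ft.RegisterKey(k)
+//	ft.CheckKey(k)                            // true
+//	ft.CheckKey(rangesync.RandomKeyBytes(32)) // false with high probability
+
+// followPrefix follows the bit prefix p from the node idx.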
+func (ft *FPTree) followPrefix(from nodeIndex, p, followed prefix) (idx nodeIndex, rp prefix, found bool) {
+	ft.enter("followPrefix: from %d p %s highBit %v", from, p, p.highBit())
+	defer func() { ft.leave(idx, rp, found) }()
+
+	for from != noIndex {
+		switch {
+		case p.len() == 0:
+			return from, followed, true
+		case ft.np.leaf(from):
+			return from, followed, false
+		case p.highBit():
+			from = ft.np.right(from)
+			p = p.shift()
+			followed = followed.right()
+		default:
+			from = ft.np.left(from)
+			p = p.shift()
+			followed = followed.left()
+		}
+	}
+
+	return noIndex, followed, false
+}
+
+// aggregateEdge aggregates an edge of the interval. The edge may be bounded by x, by y,
+// by both x and y, or by neither; its items share a common prefix, and the aggregation
+// may optionally be bounded by a limit on the number of aggregated items.
+// It returns a boolean indicating whether the limit or the right edge (y) was reached and
+// an error, if any.
+func (ft *FPTree) aggregateEdge(
+	x, y rangesync.KeyBytes,
+	idx nodeIndex,
+	p prefix,
+	ac *aggContext,
+) (cont bool, err error) {
+	ft.enter("aggregateEdge: x %s y %s p %s limit %d count %d", x, y, p, ac.limit, ac.count)
+	defer func() {
+		ft.leave(ac.limit, ac.count, cont, err)
+	}()
+	if ft.storeValues() {
+		panic("BUG: aggregateEdge should not be used for tree with values")
+	}
+	if ac.easySplit {
+		// easySplit means we should not be querying the database,
+		// so we'll have to retry using a slower strategy
+		return false, errEasySplitFailed
+	}
+	if ac.limit == 0 && ac.next != nil {
+		ft.log("aggregateEdge: limit is 0 and end already set")
+		return false, nil
+	}
+	var startFrom rangesync.KeyBytes
+	if x == nil {
+		startFrom = make(rangesync.KeyBytes, ft.keyLen)
+		p.minID(startFrom)
+	} else {
+		startFrom = x
+	}
+	ft.log("aggregateEdge: startFrom %s", startFrom)
+	sizeHint := int(ft.np.count(idx))
+	switch {
+	case ac.limit == 0:
+		sizeHint = 1
+	case ac.limit > 0:
+		sizeHint = min(ac.limit, sizeHint)
+	}
+	sr := ft.From(startFrom, sizeHint)
+	if ac.limit == 0 {
+		next, err := sr.First()
+		if err != nil {
+			return false, err
+		}
+		ac.next = next.Clone()
+		if x != nil {
+			ft.log("aggregateEdge: limit 0: x is not nil, setting start to %s", ac.next.String())
+			ac.items = sr
+		}
+		ft.log("aggregateEdge: limit is 0 at %s", ac.next.String())
+		return false, nil
+	}
+	if x != nil {
+		ac.items = sr
+		ft.log("aggregateEdge: x is not nil, setting start to %v", sr)
+	}
+
+	n := ft.np.count(ft.root)
+	for id := range sr.Seq {
+		if ac.limit == 0 && !ac.easySplit {
+			ac.next = id.Clone()
+			ft.log("aggregateEdge: limit exhausted")
+			return false, nil
+		}
+		if n == 0 {
+			break
+		}
+		ft.log("aggregateEdge: ID %s", id)
+		if y != nil && id.Compare(y) >= 0 {
+			ac.next = id.Clone()
+			ft.log("aggregateEdge: ID is over Y: %s", id)
+			return false, nil
+		}
+		if !p.match(id) {
+			ft.log("aggregateEdge: ID doesn't match the prefix: %s", id)
+			ac.lastPrefix = &p
+			return true, nil
+		}
+		if ac.limit == 0 {
+			ft.log("aggregateEdge: switching to second part of easySplit")
+			ac.switchToSecondPart()
+		}
+		ac.fp.Update(id)
+		ac.count++
+		if ac.limit > 0 {
+			ac.limit--
+		}
+		n--
+	}
+	if err := sr.Error(); err != nil {
+		return false, err
+	}
+
+	return true, nil
+}
+
+// aggregateUpToLimit aggregates the subtree rooted in idx, up to a limit of N items.
+func (ft *FPTree) aggregateUpToLimit(idx nodeIndex, p prefix, ac *aggContext) (cont bool, err error) {
+	ft.enter("aggregateUpToLimit: idx %d p %s limit %d cur_fp %s cur_count0 %d cur_count %d", idx, p, ac.limit,
+		ac.fp, ac.count0, ac.count)
+	defer func() {
+		ft.leave(ac.fp, ac.count0, ac.count, err)
+	}()
+	switch {
+	case idx == noIndex:
+		ft.log("stop: no node")
+		return true, nil
+	case ac.limit == 0:
+		return false, nil
+	case ac.maybeIncludeNode(idx, p):
+		// node is fully included
+		ft.log("included fully, lastPrefix = %s", ac.lastPrefix)
+		return true, nil
+	case ft.np.leaf(idx):
+		// the limit was reached within this node, so there's no need to
+		// continue after we're done with it
+		cont, err := ft.aggregateEdge(nil, nil, idx, p, ac)
+		if err != nil {
+			return false, err
+		}
+		if cont {
+			panic("BUG: expected limit not reached")
+		}
+		return false, nil
+	default:
+		pLeft := p.left()
+		left := ft.np.left(idx)
+		if left != noIndex {
+			if ac.maybeIncludeNode(left, pLeft) {
+				// left node is fully included, after which
+				// we need to stop somewhere in the right subtree
+				ft.log("include left in full")
+			} else {
+				// we must stop somewhere in the left subtree,
+				// and the right subtree is irrelevant unless
+				// easySplit is being done and we must restart
+				// after the limit is exhausted
+				ft.log("descend to the left")
+				if cont, err := ft.aggregateUpToLimit(left, pLeft, ac); !cont || err != nil {
+					return cont, err
+				}
+				if !ac.easySplit {
+					return false, nil
+				}
+			}
+		}
+		ft.log("descend to the right")
+		return ft.aggregateUpToLimit(ft.np.right(idx), p.right(), ac)
+	}
+}
+
+// aggregateLeft aggregates the subtree that covers the left subtree of the LCA in case of
+// normal intervals, and the subtree that covers the [x, MAX] part for the inverse (wrapped
+// around) intervals.
+func (ft *FPTree) aggregateLeft(
+	idx nodeIndex,
+	k rangesync.KeyBytes,
+	p prefix,
+	ac *aggContext,
+) (cont bool, err error) {
+	ft.enter("aggregateLeft: idx %d k %s p %s limit %d", idx, shortened(k), p, ac.limit)
+	defer func() {
+		ft.leave(ac.fp, ac.count0, ac.count, err)
+	}()
+	switch {
+	case idx == noIndex:
+		// for ac.limit == 0, it's important that we still visit the node
+		// so that we can get the item immediately following the included items
+		ft.log("stop: no node")
+		return true, nil
+	case ac.limit == 0:
+		return false, nil
+	case ac.nodeAtOrAfterX(idx, p) && ac.maybeIncludeNode(idx, p):
+		ft.log("including node in full: %s limit %d", p, ac.limit)
+		return true, nil
+	case (ft.maxDepth != 0 && p.len() == ft.maxDepth) || ft.np.leaf(idx):
+		if ac.pruneX(idx, p) {
+			ft.log("node %d p %s pruned", idx, p)
+			// we've not reached X yet so we should not stop, thus true
+			return true, nil
+		}
+		return ft.aggregateEdge(ac.x, nil, idx, p, ac)
+	case !k.BitFromLeft(p.len()):
+		left := ft.np.left(idx)
+		right := ft.np.right(idx)
+		ft.log("incl right node %d + go left to node %d", right, left)
+		cont, err := ft.aggregateLeft(left, k, p.left(), ac)
+		if !cont || err != nil {
+			return false, err
+		}
+		if right != noIndex {
+			return ft.aggregateUpToLimit(right, p.right(), ac)
+		}
+		return true, nil
+	default:
+		right := ft.np.right(idx)
+		ft.log("go right to node %d", right)
+		return ft.aggregateLeft(right, k, p.right(), ac)
+	}
+}
+
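+// How a normal interval [x, y) is decomposed (a worked sketch with
+// hypothetical 1-byte keys): for x = 0x28 and y = 0x68, commonPrefix(x, y) is
+// the single bit 0, so the LCA is the node reached by following that bit.
+// aggregateLeft then collects the keys >= x within the LCA's left subtree
+// (keys of the form 0b00xxxxxx), while aggregateRight collects the keys < y
+// within its right subtree (keys of the form 0b01xxxxxx).
+
+// aggregateRight aggregates the subtree that covers the right subtree of the LCA in case
+// of normal intervals, and the subtree that covers the [0, y) part for the inverse (wrapped
+// around) intervals.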
+func (ft *FPTree) aggregateRight(
+	idx nodeIndex,
+	k rangesync.KeyBytes,
+	p prefix,
+	ac *aggContext,
+) (cont bool, err error) {
+	ft.enter("aggregateRight: idx %d k %s p %s limit %d", idx, shortened(k), p, ac.limit)
+	defer func() {
+		ft.leave(ac.fp, ac.count0, ac.count, err)
+	}()
+	switch {
+	case idx == noIndex:
+		ft.log("stop: no node")
+		return true, nil
+	case ac.limit == 0:
+		return false, nil
+	case ac.nodeBelowY(idx, p) && ac.maybeIncludeNode(idx, p):
+		ft.log("including node in full: %s limit %d", p, ac.limit)
+		return ac.limit != 0, nil
+	case (ft.maxDepth != 0 && p.len() == ft.maxDepth) || ft.np.leaf(idx):
+		if ac.pruneY(idx, p) {
+			ft.log("node %d p %s pruned", idx, p)
+			return false, nil
+		}
+		return ft.aggregateEdge(nil, ac.y, idx, p, ac)
+	case !k.BitFromLeft(p.len()):
+		left := ft.np.left(idx)
+		ft.log("go left to node %d", left)
+		return ft.aggregateRight(left, k, p.left(), ac)
+	default:
+		left := ft.np.left(idx)
+		right := ft.np.right(idx)
+		ft.log("incl left node %d + go right to node %d", left, right)
+		if left != noIndex {
+			cont, err := ft.aggregateUpToLimit(left, p.left(), ac)
+			if !cont || err != nil {
+				return false, err
+			}
+		}
+		return ft.aggregateRight(right, k, p.right(), ac)
+	}
+}
+
+// aggregateXX aggregates intervals of the form [x, x), which denote the whole set.
+func (ft *FPTree) aggregateXX(ac *aggContext) (err error) {
+	// [x, x) interval which denotes the whole set unless
+	// the limit is specified, in which case we need to start aggregating
+	// with x and wrap around if necessary
+	ft.enter("aggregateXX: x %s limit %d", ac.x, ac.limit)
+	defer func() {
+		ft.leave(ac, err)
+	}()
+	if ft.root == noIndex {
+		ft.log("empty set (no root)")
+	} else if ac.maybeIncludeNode(ft.root, emptyPrefix) {
+		ft.log("whole set")
+	} else {
+		// We need to aggregate up to ac.limit number of items starting
+		// from x and wrapping around if necessary
+		return ft.aggregateInverse(ac)
+	}
+	return nil
+}
+
+// aggregateSimple aggregates simple (normal) intervals of the form [x, y) where x < y.
+func (ft *FPTree) aggregateSimple(ac *aggContext) (err error) {
+	// "proper" interval: [x, lca); (lca, y)
+	ft.enter("aggregateSimple: x %s y %s limit %d", ac.x, ac.y, ac.limit)
+	defer func() {
+		ft.leave(ac, err)
+	}()
+	p := commonPrefix(ac.x, ac.y)
+	lcaIdx, lcaPrefix, fullPrefixFound := ft.followPrefix(ft.root, p, emptyPrefix)
+	ft.log("commonPrefix %s lcaPrefix %s lca %d found %v", p, lcaPrefix, lcaIdx, fullPrefixFound)
+	switch {
+	case fullPrefixFound && !ft.np.leaf(lcaIdx):
+		if lcaPrefix != p {
+			panic("BUG: bad followedPrefix")
+		}
+		if _, err := ft.aggregateLeft(ft.np.left(lcaIdx), ac.x, p.left(), ac); err != nil {
+			return err
+		}
+		if ac.limit != 0 {
+			if _, err := ft.aggregateRight(ft.np.right(lcaIdx), ac.y, p.right(), ac); err != nil {
+				return err
+			}
+		}
+	case lcaIdx == noIndex || !ft.np.leaf(lcaIdx):
+		ft.log("commonPrefix %s NOT found b/c no items have it", p)
+	case ac.nodeAtOrAfterX(lcaIdx, lcaPrefix) && ac.nodeBelowY(lcaIdx, lcaPrefix) &&
+		ac.maybeIncludeNode(lcaIdx, lcaPrefix):
+		ft.log("commonPrefix %s -- lca node %d included in full", p, lcaIdx)
+	case ft.np.leaf(lcaIdx) && ft.np.value(lcaIdx) != nil:
+		// a leaf 1-node with a value that could not be included should be skipped
+		return nil
+	default:
+		ft.log("commonPrefix %s -- lca %d", p, lcaIdx)
+		_, err := ft.aggregateEdge(ac.x, ac.y, lcaIdx, lcaPrefix, ac)
+		return err
+	}
+	return nil
+}
+
+// aggregateInverse aggregates inverse intervals of the form [x, y) where x > y.
+func (ft *FPTree) aggregateInverse(ac *aggContext) (err error) {
+	// inverse interval: [min, y); [x, max]
+
+	// First, we handle the [x, max] part.
+	// For this, we process the subtree rooted in the LCA of 0x000000... (all 0s) and x
+	ft.enter("aggregateInverse: x %s y %s limit %d", ac.x, ac.y, ac.limit)
+	defer func() {
+		ft.leave(ac, err)
+	}()
+	pf0 := preFirst0(ac.x)
+	idx0, followedPrefix, found := ft.followPrefix(ft.root, pf0, emptyPrefix)
+	ft.log("pf0 %s idx0 %d found %v followedPrefix %s", pf0, idx0, found, followedPrefix)
+	switch {
+	case found && !ft.np.leaf(idx0):
+		if followedPrefix != pf0 {
+			panic("BUG: bad followedPrefix")
+		}
+		cont, err := ft.aggregateLeft(idx0, ac.x, pf0, ac)
+		if err != nil {
+			return err
+		}
+		if !cont {
+			return nil
+		}
+	case idx0 == noIndex || !ft.np.leaf(idx0):
+		// nothing to do
+	case ac.nodeAtOrAfterX(idx0, followedPrefix) && ac.maybeIncludeNode(idx0, followedPrefix):
+		// node is fully included
+	case ac.pruneX(idx0, followedPrefix):
+		// the node is below X
+		ft.log("node %d p %s pruned", idx0, followedPrefix)
+	default:
+		_, err := ft.aggregateEdge(ac.x, nil, idx0, followedPrefix, ac)
+		if err != nil {
+			return err
+		}
+	}
+
+	if ac.limit == 0 && !ac.easySplit {
+		return nil
+	}
+
+	// Then we handle the [min, y) part.
+	// For this, we process the subtree rooted in the LCA of y and 0xffffff... (all 1s)
+	pf1 := preFirst1(ac.y)
+	idx1, followedPrefix, found := ft.followPrefix(ft.root, pf1, emptyPrefix)
+	ft.log("pf1 %s idx1 %d found %v", pf1, idx1, found)
+	switch {
+	case found && !ft.np.leaf(idx1):
+		if followedPrefix != pf1 {
+			panic("BUG: bad followedPrefix")
+		}
+		if _, err := ft.aggregateRight(idx1, ac.y, pf1, ac); err != nil {
+			return err
+		}
+	case idx1 == noIndex || !ft.np.leaf(idx1):
+		// nothing to do
+	case ac.nodeBelowY(idx1, followedPrefix) && ac.maybeIncludeNode(idx1, followedPrefix):
+		// node is fully included
+	case ac.pruneY(idx1, followedPrefix):
+		// the node is at or after Y
+		ft.log("node %d p %s pruned", idx1, followedPrefix)
+		return nil
+	default:
+		_, err := ft.aggregateEdge(nil, ac.y, idx1, followedPrefix, ac)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// aggregateInterval aggregates an interval, updating the aggContext accordingly.
+func (ft *FPTree) aggregateInterval(ac *aggContext) (err error) {
+	ft.enter("aggregateInterval: x %s y %s limit %d", ac.x, ac.y, ac.limit)
+	defer func() {
+		ft.leave(ac, err)
+	}()
+	ac.itype = ac.x.Compare(ac.y)
+	if ft.root == noIndex {
+		return nil
+	}
+	ac.total = ft.np.count(ft.root)
+	switch ac.itype {
+	case 0:
+		return ft.aggregateXX(ac)
+	case -1:
+		return ft.aggregateSimple(ac)
+	default:
+		return ft.aggregateInverse(ac)
+	}
+}
+
+// startFromPrefix returns a SeqResult that begins with the first item following the keys
+// that have the prefix p.
+func (ft *FPTree) startFromPrefix(ac *aggContext, p prefix) rangesync.SeqResult {
+	k := make(rangesync.KeyBytes, ft.keyLen)
+	p.idAfter(k)
+	ft.log("startFromPrefix: p: %s idAfter: %s", p, k)
+	return ft.From(k, 1)
+}
+
+// nextFromPrefix returns the first item following the keys that have the prefix p.
+func (ft *FPTree) nextFromPrefix(ac *aggContext, p prefix) (rangesync.KeyBytes, error) {
+	id, err := ft.startFromPrefix(ac, p).First()
+	if err != nil {
+		return nil, err
+	}
+	if id == nil {
+		return nil, nil
+	}
+	return id.Clone(), nil
+}
+
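+// A query sketch (mirroring the assertions in fptree_test.go): for bounds
+// x < y, a fingerprint query returns the fingerprint, count, and interval
+// type together with the matching items:
+//
+//	fpr, err := ft.FingerprintInterval(x, y, -1) // -1: no limit
+//	// fpr.FP    - fingerprint of the items in [x, y)
+//	// fpr.Count - number of items in [x, y)
+//	// fpr.IType - -1 here since x < y (0: whole set, 1: wrapped around)
+//	// fpr.Items - sequence starting at the first item of the range
+//	// fpr.Next  - the item immediately following the range
+
+// FingerprintInterval performs a range fingerprint query with specified bounds and limit.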
+func (ft *FPTree) FingerprintInterval(x, y rangesync.KeyBytes, limit int) (fpr FPResult, err error) {
+	ft.np.lockRead()
+	defer ft.np.unlockRead()
+	return ft.fingerprintInterval(x, y, limit)
+}
+
+func (ft *FPTree) fingerprintInterval(x, y rangesync.KeyBytes, limit int) (fpr FPResult, err error) {
+	ft.enter("fingerprintInterval: x %s y %s limit %d", x, y, limit)
+	defer func() {
+		ft.leave(fpr.FP, fpr.Count, fpr.IType, fpr.Items, fpr.Next, err)
+	}()
+	ac := aggContext{np: ft.np, x: x, y: y, limit: limit}
+	if err := ft.aggregateInterval(&ac); err != nil {
+		return FPResult{}, err
+	}
+	fpr = FPResult{
+		FP:    ac.fp,
+		Count: ac.count,
+		IType: ac.itype,
+		Items: rangesync.EmptySeqResult(),
+	}
+
+	if ac.total == 0 {
+		return fpr, nil
+	}
+
+	if ac.items.Seq != nil {
+		ft.log("fingerprintInterval: items %v", ac.items)
+		fpr.Items = ac.items
+	} else {
+		fpr.Items = ft.From(x, 1)
+		ft.log("fingerprintInterval: start from x: %v", fpr.Items)
+	}
+
+	if ac.next != nil {
+		ft.log("fingerprintInterval: next %s", ac.next)
+		fpr.Next = ac.next
+	} else if (fpr.IType == 0 && limit < 0) || fpr.Count == 0 {
+		next, err := fpr.Items.First()
+		if err != nil {
+			return FPResult{}, err
+		}
+		if next != nil {
+			fpr.Next = next.Clone()
+		}
+		ft.log("fingerprintInterval: next at start %s", fpr.Next)
+	} else if ac.lastPrefix != nil {
+		fpr.Next, err = ft.nextFromPrefix(&ac, *ac.lastPrefix)
+		if err != nil {
+			return FPResult{}, err
+		}
+		ft.log("fingerprintInterval: next at lastPrefix %s -> %s", *ac.lastPrefix, fpr.Next)
+	} else {
+		next, err := ft.From(y, 1).First()
+		if err != nil {
+			return FPResult{}, err
+		}
+		fpr.Next = next.Clone()
+		ft.log("fingerprintInterval: next at y: %s", fpr.Next)
+	}
+
+	return fpr, nil
+}
+
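+// A split sketch (see Split below, which falls back to two fingerprint
+// queries when the fast path fails): dividing a range into two roughly equal
+// halves is the building block of recursive set reconciliation:
+//
+//	sr, err := ft.Split(x, y, limit) // limit: desired size of the first part
+//	// sr.Part0  - fingerprint/count/items of [x, sr.Middle)
+//	// sr.Part1  - fingerprint/count/items of [sr.Middle, y)
+//	// sr.Middle - the split point
+
+// easySplit splits an interval in two parts, trying to do it in such a way that the first
+// part has close to limit items while not making any idStore queries, so that the database
+// is not accessed. If the split can't be done, which includes the situation where one of
+// the sides has 0 items, easySplit returns errEasySplitFailed.
+// easySplit never fails for a tree with values.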
+func (ft *FPTree) easySplit(x, y rangesync.KeyBytes, limit int) (sr SplitResult, err error) {
+	ft.enter("easySplit: x %s y %s limit %d", x, y, limit)
+	defer func() {
+		ft.leave(sr.Part0.FP, sr.Part0.Count, sr.Part0.IType, sr.Part0.Items, sr.Part0.Next,
+			sr.Part1.FP, sr.Part1.Count, sr.Part1.IType, sr.Part1.Items, sr.Part1.Next, err)
+	}()
+	if limit < 0 {
+		panic("BUG: easySplit with limit < 0")
+	}
+	ac := aggContext{np: ft.np, x: x, y: y, limit: limit, easySplit: true}
+	if err := ft.aggregateInterval(&ac); err != nil {
+		return SplitResult{}, err
+	}
+
+	if ac.total == 0 {
+		return SplitResult{}, nil
+	}
+
+	if ac.count0 == 0 || ac.count == 0 {
+		// need to get some items on both sides for the easy split to succeed
+		ft.log("easySplit failed: one side missing: count0 %d count %d", ac.count0, ac.count)
+		return SplitResult{}, errEasySplitFailed
+	}
+
+	// It should not be possible to have ac.lastPrefix0 == nil or ac.lastPrefix == nil
+	// if both ac.count0 and ac.count are non-zero, b/c of how
+	// aggContext.maybeIncludeNode works
+	if ac.lastPrefix0 == nil || ac.lastPrefix == nil {
+		panic("BUG: easySplit lastPrefix or lastPrefix0 not set")
+	}
+
+	// ac.items / ac.next are only set in aggregateEdge, which fails with
+	// errEasySplitFailed if easySplit is enabled, so we can ignore them here
+	middle := make(rangesync.KeyBytes, ft.keyLen)
+	ac.lastPrefix0.idAfter(middle)
+	ft.log("easySplit: lastPrefix0 %s middle %s", ac.lastPrefix0, middle)
+	items := ft.From(x, 1)
+	part0 := FPResult{
+		FP:    ac.fp0,
+		Count: ac.count0,
+		IType: ac.itype,
+		Items: items,
+		// Next is only used during splitting itself, and thus not included
+	}
+	items = ft.startFromPrefix(&ac, *ac.lastPrefix0)
+	part1 := FPResult{
+		FP:    ac.fp,
+		Count: ac.count,
+		IType: ac.itype,
+		Items: items,
+		// Next is only used during splitting itself, and thus not included
+	}
+	return SplitResult{
+		Part0:  part0,
+		Part1:  part1,
+		Middle: middle,
+	}, nil
+}
+
+// Split splits an interval in two parts.
+func (ft *FPTree) Split(x, y rangesync.KeyBytes, limit int) (sr SplitResult, err error) {
+	ft.np.lockRead()
+	defer ft.np.unlockRead()
+	sr, err = ft.easySplit(x, y, limit)
+	if err == nil {
+		return sr, nil
+	}
+	if err != errEasySplitFailed {
+		return SplitResult{}, err
+	}
+
+	fpr0, err := ft.fingerprintInterval(x, y, limit)
+	if err != nil {
+		return SplitResult{}, err
+	}
+
+	if fpr0.Count == 0 {
+		return SplitResult{}, errors.New("can't split empty range")
+	}
+
+	fpr1, err := ft.fingerprintInterval(fpr0.Next, y, -1)
+	if err != nil {
+		return SplitResult{}, err
+	}
+
+	if fpr1.Count == 0 {
+		return SplitResult{}, errors.New("split produced empty 2nd range")
+	}
+
+	return SplitResult{
+		Part0:  fpr0,
+		Part1:  fpr1,
+		Middle: fpr0.Next,
+	}, nil
+}
+
+// dumpNode prints the node structure to the writer.
+func (ft *FPTree) dumpNode(w io.Writer, idx nodeIndex, indent, dir string) {
+	if idx == noIndex {
+		return
+	}
+
+	count, fp, leaf := ft.np.info(idx)
+	countStr := strconv.Itoa(int(count))
+	if leaf {
+		countStr = "LEAF:" + countStr
+	}
+	var valStr string
+	if v := ft.np.value(idx); v != nil {
+		valStr = fmt.Sprintf(" <%s>", shortened(v))
+	}
+	fmt.Fprintf(w, "%s%sidx=%d %s %s [%d]%s\n", indent, dir, idx, fp, countStr, ft.np.refCount(idx), valStr)
+	if !leaf {
+		indent += "  "
+		ft.dumpNode(w, ft.np.left(idx), indent, "l: ")
+		ft.dumpNode(w, ft.np.right(idx), indent, "r: ")
+	}
+}
+
+// Dump prints the tree structure to the writer.
+func (ft *FPTree) Dump(w io.Writer) { + ft.np.lockRead() + defer ft.np.unlockRead() + if ft.root == noIndex { + fmt.Fprintln(w, "empty tree") + } else { + ft.dumpNode(w, ft.root, "", "") + } +} + +// DumpToString returns the tree structure as a string. +func (ft *FPTree) DumpToString() string { + var sb strings.Builder + ft.Dump(&sb) + return sb.String() +} + +// Count returns the number of items in the tree. +func (ft *FPTree) Count() int { + ft.np.lockRead() + defer ft.np.unlockRead() + if ft.root == noIndex { + return 0 + } + return int(ft.np.count(ft.root)) +} + +// EnableTrace enables or disables tracing for the tree. +func (ft *FPTree) EnableTrace(enable bool) { + ft.traceEnabled = enable +} diff --git a/sync2/fptree/fptree_test.go b/sync2/fptree/fptree_test.go new file mode 100644 index 0000000000..6136715d49 --- /dev/null +++ b/sync2/fptree/fptree_test.go @@ -0,0 +1,1247 @@ +package fptree_test + +import ( + "context" + "fmt" + "math/rand/v2" + "slices" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sync2/fptree" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/sqlstore" +) + +const ( + testKeyLen = 32 + testDepth = 24 +) + +func requireEmpty(t *testing.T, sr rangesync.SeqResult) { + for range sr.Seq { + require.Fail(t, "expected an empty sequence") + } + require.NoError(t, sr.Error()) +} + +func firstKey(t *testing.T, sr rangesync.SeqResult) rangesync.KeyBytes { + k, err := sr.First() + require.NoError(t, err) + return k +} + +func testFPTree(t *testing.T, makeFPTrees mkFPTreesFunc) { + type rangeTestCase struct { + xIdx, yIdx int + x, y string + limit int + fp string + count uint32 + itype int + startIdx, endIdx int + } + for _, tc := range []struct { + name string + ids []string + ranges []rangeTestCase + x, y string + }{ + { + name: "empty", + ids: nil, + ranges: []rangeTestCase{ + { + x: "123456789abcdef0000000000000000000000000000000000000000000000000", + y: "123456789abcdef0000000000000000000000000000000000000000000000000", + limit: -1, + fp: "000000000000000000000000", + count: 0, + itype: 0, + startIdx: -1, + endIdx: -1, + }, + { + x: "123456789abcdef0000000000000000000000000000000000000000000000000", + y: "123456789abcdef0000000000000000000000000000000000000000000000000", + limit: 1, + fp: "000000000000000000000000", + count: 0, + itype: 0, + startIdx: -1, + endIdx: -1, + }, + { + x: "123456789abcdef0000000000000000000000000000000000000000000000000", + y: "223456789abcdef0000000000000000000000000000000000000000000000000", + limit: 1, + fp: "000000000000000000000000", + count: 0, + itype: -1, + startIdx: -1, + endIdx: -1, + }, + { + x: "223456789abcdef0000000000000000000000000000000000000000000000000", + y: "123456789abcdef0000000000000000000000000000000000000000000000000", + limit: 1, + fp: "000000000000000000000000", + count: 0, + itype: 1, + startIdx: -1, + endIdx: -1, + }, + }, + }, + { + name: "ids1", + ids: []string{ + "0000000000000000000000000000000000000000000000000000000000000000", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "abcdef1234567890000000000000000000000000000000000000000000000000", + }, + ranges: []rangeTestCase{ + { + xIdx: 0, + yIdx: 0, + limit: -1, + fp: "642464b773377bbddddddddd", + count: 5, + itype: 0, + startIdx: 0, + endIdx: 0, + 
}, + { + xIdx: 0, + yIdx: 0, + limit: 0, + fp: "000000000000000000000000", + count: 0, + itype: 0, + startIdx: 0, + endIdx: 0, + }, + { + xIdx: 0, + yIdx: 0, + limit: 3, + fp: "4761032dcfe98ba555555555", + count: 3, + itype: 0, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 4, + yIdx: 4, + limit: -1, + fp: "642464b773377bbddddddddd", + count: 5, + itype: 0, + startIdx: 4, + endIdx: 4, + }, + { + xIdx: 4, + yIdx: 4, + limit: 1, + fp: "abcdef123456789000000000", + count: 1, + itype: 0, + startIdx: 4, + endIdx: 0, + }, + { + xIdx: 0, + yIdx: 1, + limit: -1, + fp: "000000000000000000000000", + count: 1, + itype: -1, + startIdx: 0, + endIdx: 1, + }, + { + xIdx: 0, + yIdx: 3, + limit: -1, + fp: "4761032dcfe98ba555555555", + count: 3, + itype: -1, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 0, + yIdx: 4, + limit: 3, + fp: "4761032dcfe98ba555555555", + count: 3, + itype: -1, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 0, + yIdx: 4, + limit: 0, + fp: "000000000000000000000000", + count: 0, + itype: -1, + startIdx: 0, + endIdx: 0, + }, + { + xIdx: 1, + yIdx: 4, + limit: -1, + fp: "cfe98ba54761032ddddddddd", + count: 3, + itype: -1, + startIdx: 1, + endIdx: 4, + }, + { + xIdx: 1, + yIdx: 0, + limit: -1, + fp: "642464b773377bbddddddddd", + count: 4, + itype: 1, + startIdx: 1, + endIdx: 0, + }, + { + xIdx: 2, + yIdx: 0, + limit: -1, + fp: "761032cfe98ba54ddddddddd", + count: 3, + itype: 1, + startIdx: 2, + endIdx: 0, + }, + { + xIdx: 2, + yIdx: 0, + limit: 0, + fp: "000000000000000000000000", + count: 0, + itype: 1, + startIdx: 2, + endIdx: 2, + }, + { + xIdx: 3, + yIdx: 1, + limit: -1, + fp: "2345679abcdef01888888888", + count: 3, + itype: 1, + startIdx: 3, + endIdx: 1, + }, + { + xIdx: 3, + yIdx: 2, + limit: -1, + fp: "317131e226622ee888888888", + count: 4, + itype: 1, + startIdx: 3, + endIdx: 2, + }, + { + xIdx: 3, + yIdx: 2, + limit: 3, + fp: "2345679abcdef01888888888", + count: 3, + itype: 1, + startIdx: 3, + endIdx: 1, + }, + { + x: "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0", + y: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + limit: -1, + fp: "000000000000000000000000", + count: 0, + itype: -1, + startIdx: 0, + endIdx: 0, + }, + }, + }, + { + name: "ids2", + ids: []string{ + "6e476ca729c3840d0118785496e488124ee7dade1aef0c87c6edc78f72e4904f", + "829977b444c8408dcddc1210536f3b3bdc7fd97777426264b9ac8f70b97a7fd1", + "a280bcb8123393e0d4a15e5c9850aab5dddffa03d5efa92e59bc96202e8992bc", + "e93163f908630280c2a8bffd9930aa684be7a3085432035f5c641b0786590d1d", + }, + ranges: []rangeTestCase{ + { + xIdx: 0, + yIdx: 0, + limit: -1, + fp: "a76fc452775b55e0dacd8be5", + count: 4, + itype: 0, + startIdx: 0, + endIdx: 0, + }, + { + xIdx: 0, + yIdx: 0, + limit: 3, + fp: "4e5ea7ab7f38576018653418", + count: 3, + itype: 0, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 0, + yIdx: 3, + limit: -1, + fp: "4e5ea7ab7f38576018653418", + count: 3, + itype: -1, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 3, + yIdx: 1, + limit: -1, + fp: "87760f5e21a0868dc3b0c7a9", + count: 2, + itype: 1, + startIdx: 3, + endIdx: 1, + }, + { + xIdx: 3, + yIdx: 2, + limit: -1, + fp: "05ef78ea6568c6000e6cd5b9", + count: 3, + itype: 1, + startIdx: 3, + endIdx: 2, + }, + }, + }, + { + name: "ids3", + ids: []string{ + "0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", + "0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + 
"1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + "33940245f4aace670c84f471ff4e862d1d82ce0ada9b98a753038b4f9e60e330", + "366d9e7adb3932e52e0a92a0afc75a2875995e7de8e0c4159e22eb97526a3547", + "66883aa35d2c8d293f07c5c5c40c63416317423418fe5c7fd17b5fb68b3e976e", + "80fce3e9654459cff3441e1a96413f0872e0b6f093879609696042fcfe1c8115", + "8b2025fbe0bbebea4baee48bac9a63a4013a2ec898d7b0a518eccdb99bdb368e", + "8e3e609653adfddcdcb6ddda7461db3a2fc822c3f96874a002f715b80865e575", + "9b25e39d6cc3beac3ecc12140f46a699880ac8303555c694fd40ba8e61bb8b47", + "a3c8628a1b28d1ba6f3d8beb4a29315c02789c5b53a095fa7865c9b3041502d6", + "a98fdcab5e351a1bfd25ddcf9973e9c56a4b688d78743a8a03fa3b1d53da4949", + "ac9c015dd51defacfc14bd4c9c8eedb89aad884bef493553a189a2915c828e95", + "ba745196493a8368ef091860f2692978b381f67566d3413e85167672d672c8ac", + "c26353d8bc9a1eea8e79fd693c1a1e58dacded75ceda84ed6c356bcf02b6d0f1", + "c3f126a37c2e33b6258c87fd043026dacf0b8dd4df7a9afd7cdc293b075e1878", + "cefd0cc8b32929df07b6ebb5b6e433f28d5460f143814f3f651330ea15e5d6e7", + "d9390718256e71edfe671334edbfcbed8b4de3221db55805ebf606c73fe969f1", + "db7ee147da05a5cbec3f59b020cbdba88e40ab6b212ae93c98d5a210d83a4a7b", + "deab906f979a647eff85f3a54e5edd665f2536e0005812aee2e5e411ae71855e", + "e0b6ab7f483527771faadbee8b4ed99ae96167d054ae5c513faf00c78aa36bdd", + "e4ed6f5dcf179a4f10521d58d65d423098af5f6f18c42f3125a5917d338b7477", + "e53de3ec53ba88029a2a0459a3ab82cdb3726c8aeccabf38a04e048b9add92ef", + "f2aff99498615c44d94266060e948c11bb275ec37d0d3c651bb3ba0039a11a64", + "f7f81332b63b79718f0321660a5cd8f6970474ff873afcdebb0d3436a2ad12ac", + "fb42c36089a4883bc7ceaae9a57924d78557edb63ede3d5a2cf2d1f08db799d0", + "fe494ce48f5826c00f6bc6af74258ec6e47b92365850deed95b5bfcaeccc6be8", + }, + ranges: []rangeTestCase{ + { + x: "582485793d71c3e8429b9b2c8df360c2ea7bf90080d5bf375fe4618b00f59c0b", + y: "7eff517d2f11ed32f935be3001499ac779160a4891a496f88da0ceb33e3496cc", + limit: -1, + fp: "66883aa35d2c8d293f07c5c5", + count: 1, + itype: -1, + startIdx: 10, + endIdx: 11, + }, + }, + }, + { + name: "ids4", + ids: []string{ + "06a1f93f0dd88b60473d73127196631134382d59b7cd9b3e6bd6b4f25dd1c782", + "488da52a035df8674aa658d30ff58de82c9dc2ae9c474e004d585c52979eacbb", + "b5527010e990254702f77ffc8a6d6b499040bc3dc61b169a56fbc690e970c046", + "e10fc3141c5e3a00861a4dddb495a33736f845bff62fd295985b7dfa6bcbfc91", + }, + ranges: []rangeTestCase{ + { + xIdx: 2, + yIdx: 0, + limit: 1, + fp: "b5527010e990254702f77ffc", + count: 1, + itype: 1, + startIdx: 2, + endIdx: 3, + }, + }, + }, + { + name: "ids6", + ids: []string{ + "2727d39a2150ef91ef09fa0b60950a189d73e53fd73c1fc7a74e0a393582e51e", + "96a3a7cfdc9ec9101fd4a8bdf831c54053c2cd0b06a6914772edb68a0153fdec", + "b80318c43da5e4b56aa3b7f408a8f86c98418e5b364ef67a37db6017097c2ebc", + "b899092149e332f9686e02e2878e63b7ac85694eeadfe02c94f4f15627f41bcc", + }, + ranges: []rangeTestCase{ + { + xIdx: 3, + yIdx: 3, + limit: 2, + fp: "9fbedabb68b3dd688767f8e9", + count: 2, + itype: 0, + startIdx: 3, + endIdx: 1, + }, + }, + }, + { + name: "ids7", + ids: []string{ + "3595ec355452c94143c6bdae281b162e5b0997e6392dd1a345146861b8fb4586", + "68d02e8f0c69b0b16dc73dda147a231a09b32d709b9b4028f13ee7ffa2e820c8", + "7079bb2d00f961b4dc42911e2009411ceb7b8c950492a627111b60773a31c2ce", + "ad69fbf959a0b0ba1042a2b13d1b2c9a17f8507c642e55dd93277fe8dab378a6", + }, 
+ ranges: []rangeTestCase{ + { + x: "4844a20cd5a83c101cc522fa37539412d0aac4c76a48b940e1845c3f2fe79c85", + y: "cb93566c2037bc8353162e9988974e4585c14f656bf6aed8fa51d00e1ae594de", + limit: -1, + fp: "b5c06e5b553061bfa1c70e75", + count: 3, + itype: -1, + startIdx: 1, + endIdx: 0, + }, + }, + }, + { + name: "ids8", + ids: []string{ + "0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0", + "3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187", + "66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3", + "90b25f2d1ee9c9e2d20df5f2226d14ee4223ea27ba565a49aa66a9c44a51c241", + "9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", + "c1690e47798295cca02392cbfc0a86cb5204878c04a29b3ae7701b6b51681128", + }, + ranges: []rangeTestCase{ + { + x: "9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", + y: "0e69880000000000000000000000000000000000000000000000000000000000", + limit: -1, + fp: "5f78f3f7e073844de4501d50", + count: 2, + itype: 1, + startIdx: 4, + endIdx: 0, + }, + }, + }, + { + name: "ids9", + ids: []string{ + "03744b955a21408f78eb4c1e51f897ed90c22cf561b8ecef25a4b6ec68f3e895", + "691a62eb05d21ee9407fd48d252b5e80a525fd017e941fba383ceabe2ce0c0ee", + "73e10ac8b36bc20195c5d1b162d05402eaef6622accf648399cb60874ac22165", + "845c0a945137ed6b52fbb96a57909869cf34f41100a3a60e5d385d28c42621e1", + "bc1ffc4d9fddbd9f3cd17c0fe53c6b86a2e36256f37e1e73c11e4c9effa911bf", + }, + ranges: []rangeTestCase{ + { + x: "1a4f33388cab82533de99d9370fe367f654c76cd7e71a28334d993a31aa3e87a", + y: "6c5fe0023abc90d0a9327083ebc73c442cec8854f99e378551b502448f2ce000", + limit: -1, + fp: "691a62eb05d21ee9407fd48d", + count: 1, + itype: -1, + startIdx: 1, + endIdx: 2, + }, + }, + }, + { + name: "ids10", + ids: []string{ + "0aea5e19b9f53af915110ba1e05494666e8a1f4bb597d6ca0193c34b525f3480", + "219d9f504af986492356061a68cd2355fd423768c70e511cd7802cd4fdbde1c5", + "277a6bbc173628948456cbeb90309ae70ab837296f504640b53a891a3ddefb65", + "2ff6f89a1f0655255a74ff0dc4eda3a67ff69bc9667261763536917db15d9fe2", + "46b9e5fb278225f28885717512a4b2e5fbbc79b61bde8417cc2e5caf0ad86b17", + "a732516bf7198a3c3cb4edc1c3b1ec11a2545844c45464df44e31135ad84fee0", + "ea238facb9e3b3b6b9ca66bd9472b505e982ed937b22eb127269723124bb9ce8", + "ff90f791d2678d09d12f1a672de85c5127ef1f8a47ae5e8f3b61de06fd803db7", + }, + ranges: []rangeTestCase{ + { + x: "64015400af6cc54ce62fe1b478b38abfef5ab609182d6df0fd46f16c880263b2", + y: "0fcc4ed4c932e1f6ba53418a0116d20ab119c1152644abe5ee1ab30599cd3780", + limit: -1, + fp: "b86b774f25688e7a41409aba", + count: 4, + itype: 1, + startIdx: 5, + endIdx: 1, + }, + }, + }, + { + name: "ids11", + ids: []string{ + "05ce2ac65bf22e2d196814d881125ce5e4f93078ab357e151c7bfccd9ef24f1d", + "81f9f4becc8f91f1c37075ec810828b13d4e8d98b8207c467537043a1bb5d72c", + "a15ecd17ec6674a14faf67649e0058366bf852bd51a0c41c15542861eaf55bac", + "baeaf7d94cc800d38215396e46ba9e1293107a7e5c5d1cd5771f341e570b9f95", + "bd666290c1e339e8cc9d4d1aaf3ce68169dfffbfbe112e22818c72eb373160fd", + "d598253954cbf6719829dd4dca89106622cfb87666991214fece997855478a1c", + "d9e7a5bfa187a248e894e5e72874b3bf40b0863f707c72ae70e2042ba497d3ec", + "e58ededd4c54788c451ede2a3b92e62e1148fcd4184262dab28056f03b639ef5", + }, + ranges: []rangeTestCase{ + { + xIdx: 7, + yIdx: 2, + limit: -1, + fp: "61b900a5db29c7509f06bf1e", + count: 3, + itype: 1, + startIdx: 7, + endIdx: 2, + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + trees := makeFPTrees(t) + ft := trees[0] + var hs []rangesync.KeyBytes + for _, hex := range tc.ids { + h := 
rangesync.MustParseHexKeyBytes(hex) + hs = append(hs, h) + require.NoError(t, ft.RegisterKey(h)) + fptree.AnalyzeTreeNodeRefs(t, trees...) + } + + var sb strings.Builder + ft.Dump(&sb) + t.Logf("tree:\n%s", sb.String()) + + fptree.CheckTree(t, ft) + for _, k := range hs { + require.True(t, ft.CheckKey(k), "checkKey(%s)", k.ShortString()) + } + require.False(t, ft.CheckKey(rangesync.RandomKeyBytes(testKeyLen)), "checkKey(random)") + + for _, rtc := range tc.ranges { + var x, y rangesync.KeyBytes + var name string + if rtc.x != "" { + x = rangesync.MustParseHexKeyBytes(rtc.x) + y = rangesync.MustParseHexKeyBytes(rtc.y) + name = fmt.Sprintf("%s-%s_%d", rtc.x, rtc.y, rtc.limit) + } else { + x = hs[rtc.xIdx] + y = hs[rtc.yIdx] + name = fmt.Sprintf("%d-%d_%d", rtc.xIdx, rtc.yIdx, rtc.limit) + } + t.Run(name, func(t *testing.T) { + fpr, err := ft.FingerprintInterval(x, y, rtc.limit) + require.NoError(t, err) + assert.Equal(t, rtc.fp, fpr.FP.String(), "fp") + assert.Equal(t, rtc.count, fpr.Count, "count") + assert.Equal(t, rtc.itype, fpr.IType, "itype") + + if rtc.startIdx == -1 { + requireEmpty(t, fpr.Items) + } else { + require.NotNil(t, fpr.Items, "items") + expK := rangesync.KeyBytes(hs[rtc.startIdx]) + assert.Equal(t, expK, firstKey(t, fpr.Items), "items") + } + + if rtc.endIdx == -1 { + require.Nil(t, fpr.Next, "next") + } else { + require.NotNil(t, fpr.Next, "next") + expK := rangesync.KeyBytes(hs[rtc.endIdx]) + assert.Equal(t, expK, fpr.Next, "next") + } + }) + } + + ft.Release() + require.Zero(t, ft.PoolNodeCount()) + }) + } +} + +type mkFPTreesFunc func(t *testing.T) []*fptree.FPTree + +func makeFPTreeWithValues(t *testing.T) []*fptree.FPTree { + ft := fptree.NewFPTreeWithValues(0, testKeyLen) + return []*fptree.FPTree{ft} +} + +func makeInMemoryFPTree(t *testing.T) []*fptree.FPTree { + store := fptree.NewFPTreeWithValues(0, testKeyLen) + ft := fptree.NewFPTree(0, store, testKeyLen, testDepth) + return []*fptree.FPTree{ft, store} +} + +func makeDBBackedFPTree(t *testing.T) []*fptree.FPTree { + db := sqlstore.CreateDB(t, testKeyLen) + st := &sqlstore.SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + tx, err := db.Tx(context.Background()) + require.NoError(t, err) + t.Cleanup(func() { tx.Release() }) + sts, err := st.Snapshot(tx) + require.NoError(t, err) + store := fptree.NewDBBackedStore(tx, sts, 0, testKeyLen) + ft := fptree.NewFPTree(0, store, testKeyLen, testDepth) + return []*fptree.FPTree{ft, store.FPTree} +} + +func TestFPTree(t *testing.T) { + t.Run("values in fpTree", func(t *testing.T) { + testFPTree(t, makeFPTreeWithValues) + }) + t.Run("in-memory fptree-based id store", func(t *testing.T) { + testFPTree(t, makeInMemoryFPTree) + }) + t.Run("db-backed store", func(t *testing.T) { + testFPTree(t, makeDBBackedFPTree) + }) +} + +func TestFPTreeAsStore(t *testing.T) { + s := fptree.NewFPTreeWithValues(0, testKeyLen) + + sr := s.All() + for range sr.Seq { + require.Fail(t, "sequence not empty") + } + require.NoError(t, sr.Error()) + + sr = s.From(rangesync.MustParseHexKeyBytes( + "0000000000000000000000000000000000000000000000000000000000000000"), + 1) + for range sr.Seq { + require.Fail(t, "sequence not empty") + } + require.NoError(t, sr.Error()) + + for _, h := range []string{ + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + 
"8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + } { + s.RegisterKey(rangesync.MustParseHexKeyBytes(h)) + } + + sr = s.All() + for range 3 { // make sure seq is reusable + var r []string + n := 15 + for k := range sr.Seq { + r = append(r, k.String()) + n-- + if n == 0 { + break + } + } + require.NoError(t, sr.Error()) + require.Equal(t, []string{ + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + }, r) + } + + sr = s.From(rangesync.MustParseHexKeyBytes( + "5555555555555555555555555555555555555555555555555555555555555555"), + 1) + for range 3 { // make sure seq is reusable + var r []string + n := 15 + for k := range sr.Seq { + r = append(r, k.String()) + n-- + if n == 0 { + break + } + } + require.NoError(t, sr.Error()) + require.Equal(t, []string{ + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + }, r) + } +} + +type noIDStore struct{} + +var _ sqlstore.IDStore = noIDStore{} + +func (noIDStore) Clone() sqlstore.IDStore { return &noIDStore{} } +func (noIDStore) RegisterKey(h rangesync.KeyBytes) error { return nil } +func (noIDStore) All() rangesync.SeqResult { panic("no ID store") } +func (noIDStore) Release() {} + +func (noIDStore) From(from rangesync.KeyBytes, sizeHint int) rangesync.SeqResult { + return rangesync.EmptySeqResult() +} + +// TestFPTreeNoIDStoreCalls tests that an fpTree can avoid using an idStore if X has only +// 0 bits below max-depth and Y has only 
1 bits below max-depth. It also checks that an fpTree +// can avoid using an idStore in "relaxed count" mode for splitting ranges. +func TestFPTreeNoIDStoreCalls(t *testing.T) { + ft := fptree.NewFPTree(0, &noIDStore{}, testKeyLen, testDepth) + hashes := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("1111111111111111111111111111111111111111111111111111111111111111"), + rangesync.MustParseHexKeyBytes("2222222222222222222222222222222222222222222222222222222222222222"), + rangesync.MustParseHexKeyBytes("4444444444444444444444444444444444444444444444444444444444444444"), + rangesync.MustParseHexKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + } + for _, h := range hashes { + ft.RegisterKey(h) + } + + for _, tc := range []struct { + x, y rangesync.KeyBytes + limit int + fp string + count uint32 + }{ + { + x: hashes[0], + y: hashes[0], + limit: -1, + fp: "ffffffffffffffffffffffff", + count: 4, + }, + { + x: rangesync.MustParseHexKeyBytes( + "1111110000000000000000000000000000000000000000000000000000000000"), + y: rangesync.MustParseHexKeyBytes( + "1111120000000000000000000000000000000000000000000000000000000000"), + limit: -1, + fp: "111111111111111111111111", + count: 1, + }, + { + x: rangesync.MustParseHexKeyBytes( + "0000000000000000000000000000000000000000000000000000000000000000"), + y: rangesync.MustParseHexKeyBytes( + "9000000000000000000000000000000000000000000000000000000000000000"), + limit: -1, + fp: "ffffffffffffffffffffffff", + count: 4, + }, + } { + fpr, err := ft.FingerprintInterval(tc.x, tc.y, tc.limit) + require.NoError(t, err) + require.Equal(t, tc.fp, fpr.FP.String(), "fp") + require.Equal(t, tc.count, fpr.Count, "count") + } +} + +func TestFPTreeClone(t *testing.T) { + store := fptree.NewFPTreeWithValues(10, testKeyLen) + ft1 := fptree.NewFPTree(10, store, testKeyLen, testDepth) + hashes := []rangesync.KeyBytes{ + rangesync.MustParseHexKeyBytes("1111111111111111111111111111111111111111111111111111111111111111"), + rangesync.MustParseHexKeyBytes("3333333333333333333333333333333333333333333333333333333333333333"), + rangesync.MustParseHexKeyBytes("4444444444444444444444444444444444444444444444444444444444444444"), + } + + ft1.RegisterKey(hashes[0]) + fptree.AnalyzeTreeNodeRefs(t, ft1, store) + + ft1.RegisterKey(hashes[1]) + + fpr, err := ft1.FingerprintInterval(hashes[0], hashes[0], -1) + require.NoError(t, err) + require.Equal(t, "222222222222222222222222", fpr.FP.String(), "fp") + require.Equal(t, uint32(2), fpr.Count, "count") + require.Equal(t, 0, fpr.IType, "itype") + + fptree.AnalyzeTreeNodeRefs(t, ft1, store) + + ft2 := ft1.Clone().(*fptree.FPTree) + + fpr, err = ft1.FingerprintInterval(hashes[0], hashes[0], -1) + require.NoError(t, err) + require.Equal(t, "222222222222222222222222", fpr.FP.String(), "fp") + require.Equal(t, uint32(2), fpr.Count, "count") + require.Equal(t, 0, fpr.IType, "itype") + + fptree.AnalyzeTreeNodeRefs(t, ft1, ft2, store, ft2.IDStore().(*fptree.FPTree)) + + t.Logf("add hash to copy") + ft2.RegisterKey(hashes[2]) + + fpr, err = ft2.FingerprintInterval(hashes[0], hashes[0], -1) + require.NoError(t, err) + require.Equal(t, "666666666666666666666666", fpr.FP.String(), "fp") + require.Equal(t, uint32(3), fpr.Count, "count") + require.Equal(t, 0, fpr.IType, "itype") + + // original tree unchanged + fpr, err = ft1.FingerprintInterval(hashes[0], hashes[0], -1) + require.NoError(t, err) + require.Equal(t, "222222222222222222222222", fpr.FP.String(), "fp") + require.Equal(t, uint32(2), fpr.Count, "count") + 
require.Equal(t, 0, fpr.IType, "itype") + + fptree.AnalyzeTreeNodeRefs(t, ft1, ft2, store, ft2.IDStore().(*fptree.FPTree)) + + ft1.Release() + ft2.Release() + fptree.AnalyzeTreeNodeRefs(t, ft1, ft2, store, ft2.IDStore().(*fptree.FPTree)) + + require.Zero(t, ft1.PoolNodeCount()) + require.Zero(t, ft2.PoolNodeCount()) +} + +func TestRandomClone(t *testing.T) { + trees := []*fptree.FPTree{ + fptree.NewFPTree(1000, fptree.NewFPTreeWithValues(1000, testKeyLen), testKeyLen, testDepth), + } + for range 100 { + n := len(trees) + for range rand.IntN(20) { + trees = append(trees, trees[rand.IntN(n)].Clone().(*fptree.FPTree)) + } + for range rand.IntN(100) { + trees[rand.IntN(len(trees))].RegisterKey(rangesync.RandomKeyBytes(testKeyLen)) + } + + trees = slices.DeleteFunc(trees, func(ft *fptree.FPTree) bool { + if n == 1 { + return false + } + n-- + if rand.IntN(3) == 0 { + ft.Release() + return true + } + return false + }) + allTrees := slices.Clone(trees) + for _, ft := range trees { + allTrees = append(allTrees, ft.IDStore().(*fptree.FPTree)) + } + fptree.AnalyzeTreeNodeRefs(t, allTrees...) + for _, ft := range trees { + fptree.CheckTree(t, ft) + fptree.CheckTree(t, ft.IDStore().(*fptree.FPTree)) + } + if t.Failed() { + break + } + } + for _, ft := range trees { + ft.Release() + } + for _, ft := range trees { + require.Zero(t, ft.PoolNodeCount()) + } +} + +type hashList []rangesync.KeyBytes + +func (l hashList) findGTE(h rangesync.KeyBytes) int { + p, _ := slices.BinarySearchFunc(l, h, func(a, b rangesync.KeyBytes) int { + return a.Compare(b) + }) + return p +} + +func (l hashList) keyAt(p int) rangesync.KeyBytes { + if p == len(l) { + p = 0 + } + return rangesync.KeyBytes(l[p]) +} + +type fpResultWithBounds struct { + fp rangesync.Fingerprint + //nolint:unused + count uint32 + itype int + start rangesync.KeyBytes + //nolint:unused + next rangesync.KeyBytes +} + +func toFPResultWithBounds(t *testing.T, fpr fptree.FPResult) fpResultWithBounds { + return fpResultWithBounds{ + fp: fpr.FP, + count: fpr.Count, + itype: fpr.IType, + next: fpr.Next, + start: firstKey(t, fpr.Items), + } +} + +func dumbFP(hs hashList, x, y rangesync.KeyBytes, limit int) fpResultWithBounds { + var fpr fpResultWithBounds + l := len(hs) + if l == 0 { + return fpr + } + fpr.itype = x.Compare(y) + switch fpr.itype { + case -1: + p := hs.findGTE(x) + pY := hs.findGTE(y) + fpr.start = hs.keyAt(p) + for { + if p >= pY || limit == 0 { + fpr.next = hs.keyAt(p) + break + } + fpr.fp.Update(hs.keyAt(p)) + limit-- + fpr.count++ + p++ + } + case 1: + p := hs.findGTE(x) + fpr.start = hs.keyAt(p) + for { + if p >= len(hs) || limit == 0 { + fpr.next = hs.keyAt(p) + break + } + fpr.fp.Update(hs.keyAt(p)) + limit-- + fpr.count++ + p++ + } + if limit == 0 { + return fpr + } + pY := hs.findGTE(y) + p = 0 + for { + if p == pY || limit == 0 { + fpr.next = hs.keyAt(p) + break + } + fpr.fp.Update(hs.keyAt(p)) + limit-- + fpr.count++ + p++ + } + default: + pX := hs.findGTE(x) + p := pX + fpr.start = hs.keyAt(p) + fpr.next = fpr.start + for { + if limit == 0 { + fpr.next = hs.keyAt(p) + break + } + fpr.fp.Update(hs.keyAt(p)) + limit-- + fpr.count++ + p = (p + 1) % l + if p == pX { + break + } + } + } + return fpr +} + +func verifyInterval(t *testing.T, hs hashList, ft *fptree.FPTree, x, y rangesync.KeyBytes, limit int) fptree.FPResult { + expFPR := dumbFP(hs, x, y, limit) + fpr, err := ft.FingerprintInterval(x, y, limit) + require.NoError(t, err) + require.Equal(t, expFPR, toFPResultWithBounds(t, fpr), + "x=%s y=%s limit=%d", x.String(), y.String(), 
limit) + + return fpr +} + +func verifySubIntervals( + t *testing.T, + hs hashList, + ft *fptree.FPTree, + x, y rangesync.KeyBytes, + limit, d int, +) fptree.FPResult { + fpr := verifyInterval(t, hs, ft, x, y, limit) + if fpr.Count > 1 { + c := int((fpr.Count + 1) / 2) + if limit >= 0 { + require.Less(t, c, limit) + } + part := verifyInterval(t, hs, ft, x, y, c) + m := make(rangesync.KeyBytes, len(x)) + copy(m, part.Next) + verifySubIntervals(t, hs, ft, x, m, -1, d+1) + verifySubIntervals(t, hs, ft, m, y, -1, d+1) + } + return fpr +} + +func testFPTreeManyItems(t *testing.T, trees []*fptree.FPTree, randomXY bool, numItems, maxDepth, repeat int) { + ft := trees[0] + hs := make(hashList, numItems) + var fp rangesync.Fingerprint + for i := range hs { + h := rangesync.RandomKeyBytes(testKeyLen) + hs[i] = h + ft.RegisterKey(h) + fp.Update(h) + } + fptree.AnalyzeTreeNodeRefs(t, trees...) + slices.SortFunc(hs, func(a, b rangesync.KeyBytes) int { + return a.Compare(b) + }) + + fptree.CheckTree(t, ft) + for _, k := range hs { + require.True(t, ft.CheckKey(k), "checkKey(%s)", k.ShortString()) + } + + fpr, err := ft.FingerprintInterval(hs[0], hs[0], -1) + require.NoError(t, err) + require.Equal(t, fp, fpr.FP, "fp") + require.Equal(t, uint32(numItems), fpr.Count, "count") + require.Equal(t, 0, fpr.IType, "itype") + for i := 0; i < repeat; i++ { + var x, y rangesync.KeyBytes + if randomXY { + x = rangesync.RandomKeyBytes(testKeyLen) + y = rangesync.RandomKeyBytes(testKeyLen) + } else { + x = hs[rand.IntN(numItems)] + y = hs[rand.IntN(numItems)] + } + verifySubIntervals(t, hs, ft, x, y, -1, 0) + } +} + +func repeatTestFPTreeManyItems( + t *testing.T, + makeFPTrees mkFPTreesFunc, +) { + const ( + repeatOuter = 2 + repeatInner = 3 + numItems = 1 << 9 + maxDepth = 12 + ) + for _, tc := range []struct { + name string + randomXY bool + }{ + { + name: "bounds from the set", + randomXY: false, + }, + { + name: "random bounds", + randomXY: true, + }, + } { + for i := 0; i < repeatOuter; i++ { + testFPTreeManyItems(t, makeFPTrees(t), tc.randomXY, numItems, maxDepth, repeatInner) + } + } +} + +func TestFPTreeManyItems(t *testing.T) { + t.Run("values in fpTree", func(t *testing.T) { + repeatTestFPTreeManyItems(t, makeFPTreeWithValues) + }) + t.Run("in-memory fptree-based id store", func(t *testing.T) { + repeatTestFPTreeManyItems(t, makeInMemoryFPTree) + }) + t.Run("db-backed store", func(t *testing.T) { + repeatTestFPTreeManyItems(t, makeDBBackedFPTree) + }) +} + +func verifyEasySplit( + t *testing.T, + ft *fptree.FPTree, + x, y rangesync.KeyBytes, + depth, + maxDepth int, +) ( + succeeded, failed int, +) { + fpr, err := ft.FingerprintInterval(x, y, -1) + require.NoError(t, err) + if fpr.Count <= 1 { + return 0, 0 + } + a := firstKey(t, fpr.Items) + require.NoError(t, err) + b := fpr.Next + require.NotNil(t, b) + + m := fpr.Count / 2 + sr, err := ft.EasySplit(x, y, int(m)) + if err != nil { + require.ErrorIs(t, err, fptree.ErrEasySplitFailed) + failed++ + sr, err = ft.Split(x, y, int(m)) + require.NoError(t, err) + } + require.NoError(t, err) + require.NotZero(t, sr.Part0.Count) + require.NotZero(t, sr.Part1.Count) + require.Equal(t, fpr.Count, sr.Part0.Count+sr.Part1.Count) + require.Equal(t, fpr.IType, sr.Part0.IType) + require.Equal(t, fpr.IType, sr.Part1.IType) + fp := sr.Part0.FP + fp.Update(sr.Part1.FP[:]) + require.Equal(t, fpr.FP, fp) + require.Equal(t, a, firstKey(t, sr.Part0.Items)) + 
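// The first item of the second part is the "precise middle" of the split: the checks below re-fingerprint [x, precMiddle) and [precMiddle, y), as well as [x, sr.Middle) and [sr.Middle, y), and expect each to reproduce its half of the original range. + 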
precMiddle := firstKey(t, sr.Part1.Items) + + fpr11, err := ft.FingerprintInterval(x, precMiddle, -1) + require.NoError(t, err) + require.Equal(t, sr.Part0.Count, fpr11.Count) + require.Equal(t, sr.Part0.FP, fpr11.FP) + require.Equal(t, a, firstKey(t, fpr11.Items)) + + fpr12, err := ft.FingerprintInterval(precMiddle, y, -1) + require.NoError(t, err) + require.Equal(t, sr.Part1.Count, fpr12.Count) + require.Equal(t, sr.Part1.FP, fpr12.FP) + require.Equal(t, precMiddle, firstKey(t, fpr12.Items)) + + fpr11, err = ft.FingerprintInterval(x, sr.Middle, -1) + require.NoError(t, err) + require.Equal(t, sr.Part0.Count, fpr11.Count) + require.Equal(t, sr.Part0.FP, fpr11.FP) + require.Equal(t, a, firstKey(t, fpr11.Items)) + + fpr12, err = ft.FingerprintInterval(sr.Middle, y, -1) + require.NoError(t, err) + require.Equal(t, sr.Part1.Count, fpr12.Count) + require.Equal(t, sr.Part1.FP, fpr12.FP) + require.Equal(t, precMiddle, firstKey(t, fpr12.Items)) + + if maxDepth > 0 && depth >= maxDepth { + return 1, 0 + } + s1, f1 := verifyEasySplit(t, ft, x, sr.Middle, depth+1, maxDepth) + s2, f2 := verifyEasySplit(t, ft, sr.Middle, y, depth+1, maxDepth) + return succeeded + s1 + s2 + 1, failed + f1 + f2 +} + +func TestEasySplit(t *testing.T) { + maxDepth := 17 + count := 10000 + for range 5 { + store := fptree.NewFPTreeWithValues(10000, testKeyLen) + ft := fptree.NewFPTree(10000, store, testKeyLen, maxDepth) + for range count { + h := rangesync.RandomKeyBytes(testKeyLen) + ft.RegisterKey(h) + } + x := firstKey(t, ft.All()).Clone() + x.Trim(maxDepth) + + succeeded, failed := verifyEasySplit(t, ft, x, x, 0, maxDepth-2) + successRate := float64(succeeded) * 100 / float64(succeeded+failed) + t.Logf("succeeded %d, failed %d, success rate %.2f%%", + succeeded, failed, successRate) + require.GreaterOrEqual(t, successRate, 95.0) + } +} + +func TestEasySplitFPTreeWithValues(t *testing.T) { + count := 10000 + + for range 5 { + ft := fptree.NewFPTreeWithValues(10000, testKeyLen) + for range count { + h := rangesync.RandomKeyBytes(testKeyLen) + ft.RegisterKey(h) + } + + x := firstKey(t, ft.All()).Clone() + _, failed := verifyEasySplit(t, ft, x, x, 0, -1) + require.Zero(t, failed) + } +} diff --git a/sync2/fptree/nodepool.go b/sync2/fptree/nodepool.go new file mode 100644 index 0000000000..b04ac77cb5 --- /dev/null +++ b/sync2/fptree/nodepool.go @@ -0,0 +1,213 @@ +package fptree + +import ( + "slices" + "sync" + + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" +) + +// nodeIndex represents an index of a node in the node pool. +type nodeIndex uint32 + +const ( + // noIndex represents an invalid node index. + noIndex = ^nodeIndex(0) + // leafFlag is a flag that indicates that a node is a leaf node. + leafFlag = uint32(1 << 31) +) + +// node represents an fpTree node. +type node struct { + // Fingerprint + fp rangesync.Fingerprint + // Item count + c uint32 + // Left child, noIndex if not present. + l nodeIndex + // Right child, noIndex if not present. + r nodeIndex +} + +// nodePool represents a pool of tree nodes. +// The pool is shared between the original tree and its clones. +type nodePool struct { + mtx sync.RWMutex + rcPool rcPool[node, uint32] + leafMap map[uint32]rangesync.KeyBytes +} + +// init pre-allocates the node pool with n nodes. +func (np *nodePool) init(n int) { + np.rcPool.init(n) +} + +// lockWrite locks the node pool for writing. +// There can only be one writer at a time. +// This blocks until all other reader and writer locks are released. 
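+// +// An illustrative sketch (not part of this change, assuming fp and key were computed earlier) of how a mutating call pairs with the write lock, while lookups on any tree sharing the pool pair lockRead/unlockRead: +// +// np.lockWrite() +// idx := np.add(fp, 1, noIndex, noIndex, key, noIndex) +// np.unlockWrite() 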
+func (np *nodePool) lockWrite() { np.mtx.Lock() } + +// unlockWrite unlocks the node pool for writing. +func (np *nodePool) unlockWrite() { np.mtx.Unlock() } + +// lockRead locks the node pool for reading. +// There can be multiple reader locks held at a time. +// This blocks until the writer lock is released, if it's held. +func (np *nodePool) lockRead() { np.mtx.RLock() } + +// unlockRead unlocks the node pool for reading. +func (np *nodePool) unlockRead() { np.mtx.RUnlock() } + +// add adds a new node to the pool. +func (np *nodePool) add( + fp rangesync.Fingerprint, + c uint32, + left, right nodeIndex, + v rangesync.KeyBytes, + replaceIdx nodeIndex, +) nodeIndex { + if c == 1 || left == noIndex && right == noIndex { + c |= leafFlag + } + newNode := node{fp: fp, c: c, l: noIndex, r: noIndex} + if left != noIndex { + newNode.l = left + } + if right != noIndex { + newNode.r = right + } + var idx uint32 + if replaceIdx != noIndex { + np.rcPool.replace(uint32(replaceIdx), newNode) + idx = uint32(replaceIdx) + } else { + idx = np.rcPool.add(newNode) + } + if v != nil { + if c != 1|leafFlag { + panic("BUG: non-leaf node with a value") + } + if np.leafMap == nil { + np.leafMap = make(map[uint32]rangesync.KeyBytes) + } + np.leafMap[idx] = slices.Clone(v) + } else if replaceIdx != noIndex { + delete(np.leafMap, idx) + } + return nodeIndex(idx) +} + +// value returns the value of the node at the given index. +func (np *nodePool) value(idx nodeIndex) rangesync.KeyBytes { + if idx == noIndex { + return nil + } + return np.leafMap[uint32(idx)] +} + +// left returns the left child of the node at the given index. +func (np *nodePool) left(idx nodeIndex) nodeIndex { + if idx == noIndex { + return noIndex + } + node := np.rcPool.item(uint32(idx)) + if node.c&leafFlag != 0 || node.l == noIndex { + return noIndex + } + return node.l +} + +// right returns the right child of the node at the given index. +func (np *nodePool) right(idx nodeIndex) nodeIndex { + if idx == noIndex { + return noIndex + } + node := np.rcPool.item(uint32(idx)) + if node.c&leafFlag != 0 || node.r == noIndex { + return noIndex + } + return node.r +} + +// leaf returns true if this is a leaf node. +func (np *nodePool) leaf(idx nodeIndex) bool { + if idx == noIndex { + panic("BUG: bad node index") + } + node := np.rcPool.item(uint32(idx)) + return node.c&leafFlag != 0 +} + +// count returns number of set items to which the node at the given index corresponds. +func (np *nodePool) count(idx nodeIndex) uint32 { + if idx == noIndex { + return 0 + } + node := np.rcPool.item(uint32(idx)) + if node.c == 1 { + panic("BUG: single-count node w/o the leaf flag") + } + return node.c &^ leafFlag +} + +// info returns the count, fingerprint, and leaf flag of the node at the given index. +func (np *nodePool) info(idx nodeIndex) (count uint32, fp rangesync.Fingerprint, leaf bool) { + if idx == noIndex { + panic("BUG: bad node index") + } + node := np.rcPool.item(uint32(idx)) + if node.c == 1 { + panic("BUG: single-count node w/o the leaf flag") + } + return node.c &^ leafFlag, node.fp, node.c&leafFlag != 0 +} + +// releaseOne releases the node at the given index, returning it to the pool. +func (np *nodePool) releaseOne(idx nodeIndex) bool { + if idx == noIndex { + return false + } + if np.rcPool.release(uint32(idx)) { + delete(np.leafMap, uint32(idx)) + return true + } + return false +} + +// release releases the node at the given index, returning it to the pool, and recursively +// releases its children. 
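+// Children are only visited once the node itself is actually freed, so a subtree that is still referenced by a clone keeps its nodes alive until the last tree referencing them is released. 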
+func (np *nodePool) release(idx nodeIndex) bool { + if idx == noIndex { + return false + } + node := np.rcPool.item(uint32(idx)) + if !np.rcPool.release(uint32(idx)) { + return false + } + if node.c&leafFlag == 0 { + if node.l != noIndex { + np.release(node.l) + } + if node.r != noIndex { + np.release(node.r) + } + } else { + delete(np.leafMap, uint32(idx)) + } + return true +} + +// ref adds a reference to the given node. +func (np *nodePool) ref(idx nodeIndex) { + np.rcPool.ref(uint32(idx)) +} + +// refCount returns the reference count for the node at the given index. +func (np *nodePool) refCount(idx nodeIndex) uint32 { + return np.rcPool.refCount(uint32(idx)) +} + +// nodeCount returns the number of nodes in the pool. +func (np *nodePool) nodeCount() int { + return np.rcPool.count() +} diff --git a/sync2/fptree/nodepool_test.go b/sync2/fptree/nodepool_test.go new file mode 100644 index 0000000000..32dcb2bbd8 --- /dev/null +++ b/sync2/fptree/nodepool_test.go @@ -0,0 +1,80 @@ +package fptree + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" +) + +func TestNodePool(t *testing.T) { + var np nodePool + require.Zero(t, np.nodeCount()) + idx1 := np.add(rangesync.MustParseHexFingerprint("000000000000000000000001"), 1, noIndex, noIndex, + rangesync.KeyBytes("foo"), noIndex) + idx2 := np.add(rangesync.MustParseHexFingerprint("000000000000000000000002"), 1, noIndex, noIndex, + rangesync.KeyBytes("bar"), noIndex) + idx3 := np.add(rangesync.MustParseHexFingerprint("000000000000000000000003"), 2, idx1, idx2, nil, noIndex) + + require.Equal(t, nodeIndex(0), idx1) + require.Equal(t, rangesync.KeyBytes("foo"), np.value(idx1)) + require.Equal(t, noIndex, np.left(idx1)) + require.Equal(t, noIndex, np.right(idx1)) + require.True(t, np.leaf(idx1)) + require.Equal(t, uint32(1), np.count(idx1)) + count, fp, leaf := np.info(idx1) + require.Equal(t, uint32(1), count) + require.Equal(t, rangesync.MustParseHexFingerprint("000000000000000000000001"), fp) + require.True(t, leaf) + require.Equal(t, uint32(1), np.refCount(idx1)) + + require.Equal(t, nodeIndex(1), idx2) + require.Equal(t, rangesync.KeyBytes("bar"), np.value(idx2)) + require.Equal(t, noIndex, np.left(idx2)) + require.Equal(t, noIndex, np.right(idx2)) + require.True(t, np.leaf(idx2)) + require.Equal(t, uint32(1), np.count(idx2)) + count, fp, leaf = np.info(idx2) + require.Equal(t, uint32(1), count) + require.Equal(t, rangesync.MustParseHexFingerprint("000000000000000000000002"), fp) + require.True(t, leaf) + require.Equal(t, uint32(1), np.refCount(idx2)) + + require.Equal(t, nodeIndex(2), idx3) + require.Nil(t, np.value(idx3)) + require.Equal(t, idx1, np.left(idx3)) + require.Equal(t, idx2, np.right(idx3)) + require.False(t, np.leaf(idx3)) + require.Equal(t, uint32(2), np.count(idx3)) + count, fp, leaf = np.info(idx3) + require.Equal(t, uint32(2), count) + require.Equal(t, rangesync.MustParseHexFingerprint("000000000000000000000003"), fp) + require.False(t, leaf) + require.Equal(t, uint32(1), np.refCount(idx3)) + + require.Equal(t, 3, np.nodeCount()) + + np.ref(idx2) + require.Equal(t, uint32(2), np.refCount(idx2)) + + np.release(idx3) + require.Equal(t, 1, np.nodeCount()) + require.Equal(t, uint32(1), np.refCount(idx2)) + count, fp, leaf = np.info(idx2) + require.Equal(t, uint32(1), count) + require.Equal(t, rangesync.MustParseHexFingerprint("000000000000000000000002"), fp) + require.True(t, leaf) + + require.Equal(t, idx2, np.add( 
rangesync.MustParseHexFingerprint("000000000000000000000004"), 1, noIndex, noIndex, + rangesync.KeyBytes("bar2"), idx2)) + count, fp, leaf = np.info(idx2) + require.Equal(t, uint32(1), count) + require.Equal(t, rangesync.MustParseHexFingerprint("000000000000000000000004"), fp) + require.True(t, leaf) + require.Equal(t, rangesync.KeyBytes("bar2"), np.value(idx2)) + + np.release(idx2) + require.Zero(t, np.nodeCount()) +} diff --git a/sync2/fptree/prefix.go b/sync2/fptree/prefix.go new file mode 100644 index 0000000000..6fd9ffea76 --- /dev/null +++ b/sync2/fptree/prefix.go @@ -0,0 +1,198 @@ +package fptree + +import ( + "fmt" + "math/bits" + "strings" + + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" +) + +const ( + // prefixBytes is the number of bytes in a prefix. + prefixBytes = rangesync.FingerprintSize + // maxPrefixLen is the maximum length of a prefix in bits. + maxPrefixLen = prefixBytes * 8 +) + +// prefix is a prefix of a key, represented as a bit string. +type prefix struct { + // the bytes of the prefix, starting from the highest byte. + b [prefixBytes]byte + // length of the prefix in bits. + l uint16 +} + +// emptyPrefix is the empty prefix (length 0). +var emptyPrefix = prefix{} + +// prefixFromKeyBytes returns a prefix made from a key by using the maximum possible +// number of its bytes. +func prefixFromKeyBytes(k rangesync.KeyBytes) (p prefix) { + p.l = uint16(copy(p.b[:], k) * 8) + return p +} + +// len returns the length of the prefix. +func (p prefix) len() int { + return int(p.l) +} + +// left returns the prefix with one more 0 bit. +func (p prefix) left() prefix { + if p.l == maxPrefixLen { + panic("BUG: max prefix len reached") + } + p.b[p.l/8] &^= 1 << (7 - p.l%8) + p.l++ + return p +} + +// right returns the prefix with one more 1 bit. +func (p prefix) right() prefix { + if p.l == maxPrefixLen { + panic("BUG: max prefix len reached") + } + p.b[p.l/8] |= 1 << (7 - p.l%8) + p.l++ + return p +} + +// String implements fmt.Stringer. +func (p prefix) String() string { + if p.len() == 0 { + return "<0>" + } + var sb strings.Builder + for _, b := range p.b[:(p.l+7)/8] { + sb.WriteString(fmt.Sprintf("%08b", b)) + } + return fmt.Sprintf("<%d:%s>", p.l, sb.String()[:p.l]) +} + +// highBit returns the highest bit of the prefix as bool (false=0, true=1). +// If the prefix is empty, it returns false. +func (p prefix) highBit() bool { + return p.l != 0 && p.b[0]&0x80 != 0 +} + +// minID sets the key to the smallest key with the prefix. +func (p prefix) minID(k rangesync.KeyBytes) { + nb := (p.l + 7) / 8 + if len(k) < int(nb) { + panic("BUG: id slice too small") + } + copy(k[:nb], p.b[:nb]) + clear(k[nb:]) +} + +// idAfter sets the key to the key immediately after the largest key with the prefix. +// idAfter returns true if the resulting id is zero, meaning wraparound. +func (p prefix) idAfter(k rangesync.KeyBytes) bool { + nb := (p.l + 7) / 8 + if len(k) < int(nb) { + panic("BUG: id slice too small") + } + // Copy prefix bits to the key, set all the bits after the prefix to 1, then + // increment the key. + copy(k[:nb], p.b[:nb]) + if p.l%8 != 0 { + k[nb-1] |= (1<<(8-p.l%8) - 1) + } + for i := int(nb); i < len(k); i++ { + k[i] = 0xff + } + return k.Inc() +} + +// shift removes the highest bit from the prefix. 
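+// For example, shifting <3:101> yields <2:01>. 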
+func (p prefix) shift() prefix { + switch l := p.len(); l { + case 0: + panic("BUG: can't shift zero prefix") + case 1: + return emptyPrefix + default: + var c byte + for nb := int((p.l+7)/8) - 1; nb >= 0; nb-- { + c, p.b[nb] = (p.b[nb]&0x80)>>7, (p.b[nb]<<1)|c + } + p.l-- + return p + } +} + +// match returns true if the prefix matches the key, that is, +// all the prefix bits are equal to the corresponding bits of the key. +func (p prefix) match(b rangesync.KeyBytes) bool { + if int(p.l) > len(b)*8 { + panic("BUG: id slice too small") + } + if p.l == 0 { + return true + } + bi := p.l / 8 + for i, v := range p.b[:bi] { + if b[i] != v { + return false + } + } + s := p.l % 8 + return s == 0 || p.b[bi]>>(8-s) == b[bi]>>(8-s) +} + +// preFirst0 returns the longest prefix of the key that consists entirely of binary 1s. +func preFirst0(k rangesync.KeyBytes) prefix { + var p prefix + nb := min(prefixBytes, len(k)) + for n, b := range k[:nb] { + if b != 0xff { + nOnes := bits.LeadingZeros8(^b) + if nOnes != 0 { + p.b[n] = 0xff << (8 - nOnes) + p.l += uint16(nOnes) + } + break + } + p.b[n] = 0xff + p.l += 8 + } + return p +} + +// preFirst1 returns the longest prefix of the key that consists entirely of binary 0s. +func preFirst1(k rangesync.KeyBytes) prefix { + var p prefix + nb := min(prefixBytes, len(k)) + for _, b := range k[:nb] { + if b != 0 { + p.l += uint16(bits.LeadingZeros8(b)) + break + } + p.l += 8 + } + return p +} + +// commonPrefix returns common prefix between two keys. +func commonPrefix(a, b rangesync.KeyBytes) prefix { + var p prefix + nb := min(prefixBytes, len(a), len(b)) + for n, v1 := range a[:nb] { + v2 := b[n] + p.b[n] = v1 + if v1 != v2 { + nEqBits := bits.LeadingZeros8(v1 ^ v2) + if nEqBits != 0 { + // Clear unused bits in the last used prefix byte + p.b[n] &^= 1<<(8-nEqBits) - 1 + p.l += uint16(nEqBits) + } else { + p.b[n] = 0 + } + break + } + p.l += 8 + } + return p +} diff --git a/sync2/fptree/prefix_test.go b/sync2/fptree/prefix_test.go new file mode 100644 index 0000000000..3394948191 --- /dev/null +++ b/sync2/fptree/prefix_test.go @@ -0,0 +1,328 @@ +package fptree + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" +) + +func verifyPrefix(t *testing.T, p prefix) { + for i := (p.len() + 7) / 8; i < prefixBytes; i++ { + require.Zero(t, p.b[i], "p.bs[%d]", i) + } +} + +func TestPrefix(t *testing.T) { + for _, tc := range []struct { + p prefix + s string + left prefix + right prefix + shift prefix + minID string + idAfter string + }{ + { + p: emptyPrefix, + s: "<0>", + left: prefix{b: [prefixBytes]byte{0}, l: 1}, + right: prefix{b: [prefixBytes]byte{0x80}, l: 1}, + minID: "0000000000000000000000000000000000000000000000000000000000000000", + idAfter: "0000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0}, l: 1}, + s: "<1:0>", + left: prefix{b: [prefixBytes]byte{0}, l: 2}, + right: prefix{b: [prefixBytes]byte{0x40}, l: 2}, + shift: emptyPrefix, + minID: "0000000000000000000000000000000000000000000000000000000000000000", + idAfter: "8000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0x80}, l: 1}, + s: "<1:1>", + left: prefix{b: [prefixBytes]byte{0x80}, l: 2}, + right: prefix{b: [prefixBytes]byte{0xc0}, l: 2}, + shift: emptyPrefix, + minID: "8000000000000000000000000000000000000000000000000000000000000000", + idAfter: 
"0000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0}, l: 2}, + s: "<2:00>", + left: prefix{b: [prefixBytes]byte{0}, l: 3}, + right: prefix{b: [prefixBytes]byte{0x20}, l: 3}, + shift: prefix{b: [prefixBytes]byte{0}, l: 1}, + minID: "0000000000000000000000000000000000000000000000000000000000000000", + idAfter: "4000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0x40}, l: 2}, + s: "<2:01>", + left: prefix{b: [prefixBytes]byte{0x40}, l: 3}, + right: prefix{b: [prefixBytes]byte{0x60}, l: 3}, + shift: prefix{b: [prefixBytes]byte{0x80}, l: 1}, + minID: "4000000000000000000000000000000000000000000000000000000000000000", + idAfter: "8000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0x80}, l: 2}, + s: "<2:10>", + left: prefix{b: [prefixBytes]byte{0x80}, l: 3}, + right: prefix{b: [prefixBytes]byte{0xa0}, l: 3}, + shift: prefix{b: [prefixBytes]byte{0}, l: 1}, + minID: "8000000000000000000000000000000000000000000000000000000000000000", + idAfter: "c000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0xc0}, l: 2}, + s: "<2:11>", + left: prefix{b: [prefixBytes]byte{0xc0}, l: 3}, + right: prefix{b: [prefixBytes]byte{0xe0}, l: 3}, + shift: prefix{b: [prefixBytes]byte{0x80}, l: 1}, + minID: "c000000000000000000000000000000000000000000000000000000000000000", + idAfter: "0000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff}, l: 24}, + s: "<24:111111111111111111111111>", + left: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff, 0}, l: 25}, + right: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff, 0x80}, l: 25}, + shift: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xfe}, l: 23}, + minID: "ffffff0000000000000000000000000000000000000000000000000000000000", + idAfter: "0000000000000000000000000000000000000000000000000000000000000000", + }, + { + p: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff, 0}, l: 25}, + s: "<25:1111111111111111111111110>", + left: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff, 0}, l: 26}, + right: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xff, 0x40}, l: 26}, + shift: prefix{b: [prefixBytes]byte{0xff, 0xff, 0xfe}, l: 24}, + minID: "ffffff0000000000000000000000000000000000000000000000000000000000", + idAfter: "ffffff8000000000000000000000000000000000000000000000000000000000", + }, + } { + t.Run(fmt.Sprint(tc.p), func(t *testing.T) { + require.Equal(t, tc.s, tc.p.String()) + require.Equal(t, tc.left, tc.p.left()) + verifyPrefix(t, tc.p.left()) + require.Equal(t, tc.right, tc.p.right()) + verifyPrefix(t, tc.p.right()) + if tc.p != emptyPrefix { + require.Equal(t, tc.shift, tc.p.shift()) + verifyPrefix(t, tc.p.shift()) + } + + minID := make(rangesync.KeyBytes, 32) + tc.p.minID(minID) + require.Equal(t, tc.minID, minID.String()) + + idAfter := make(rangesync.KeyBytes, 32) + tc.p.idAfter(idAfter) + require.Equal(t, tc.idAfter, idAfter.String()) + }) + } +} + +func TestCommonPrefix(t *testing.T) { + for _, tc := range []struct { + a, b, p string + }{ + { + a: "0000000000000000000000000000000000000000000000000000000000000000", + b: "8000000000000000000000000000000000000000000000000000000000000000", + p: "<0>", + }, + { + a: "A000000000000000000000000000000000000000000000000000000000000000", + b: "8000000000000000000000000000000000000000000000000000000000000000", + p: "<2:10>", + }, + { + a: 
"A000000000000000000000000000000000000000000000000000000000000000", + b: "A800000000000000000000000000000000000000000000000000000000000000", + p: "<4:1010>", + }, + { + a: "ABCDEF1234567890000000000000000000000000000000000000000000000000", + b: "ABCDEF1234567800000000000000000000000000000000000000000000000000", + p: "<56:10101011110011011110111100010010001101000101011001111000>", + }, + { + a: "ABCDEF1234567890123456789ABCDEF000000000000000000000000000000000", + b: "ABCDEF1234567890123456789ABCDEF000000000000000000000000000000000", + p: "<96:1010101111001101111011110001001000110100010101100111100010010000" + + "00010010001101000101011001111000>", + }, + } { + a := rangesync.MustParseHexKeyBytes(tc.a) + b := rangesync.MustParseHexKeyBytes(tc.b) + require.Equal(t, tc.p, commonPrefix(a, b).String()) + verifyPrefix(t, commonPrefix(a, b)) + } +} + +func TestPreFirst0(t *testing.T) { + for _, tc := range []struct { + k, exp string + }{ + { + k: "00000000", + exp: "<0>", + }, + { + k: "10000000", + exp: "<0>", + }, + { + k: "40000000", + exp: "<0>", + }, + { + k: "00040000", + exp: "<0>", + }, + { + k: "80000000", + exp: "<1:1>", + }, + { + k: "c0000000", + exp: "<2:11>", + }, + { + k: "cc000000", + exp: "<2:11>", + }, + { + k: "ffc00000", + exp: "<10:1111111111>", + }, + { + k: "ffffffff", + exp: "<32:11111111111111111111111111111111>", + }, + } { + k := rangesync.MustParseHexKeyBytes(tc.k) + require.Equal(t, tc.exp, preFirst0(k).String(), "k=%s", tc.k) + verifyPrefix(t, preFirst0(k)) + } +} + +func TestPreFirst1(t *testing.T) { + for _, tc := range []struct { + k, exp string + }{ + { + k: "ffffffff", + exp: "<0>", + }, + { + k: "80000000", + exp: "<0>", + }, + { + k: "c0000000", + exp: "<0>", + }, + { + k: "ffffffc0", + exp: "<0>", + }, + { + k: "70000000", + exp: "<1:0>", + }, + { + k: "30000000", + exp: "<2:00>", + }, + { + k: "00300000", + exp: "<10:0000000000>", + }, + { + k: "00000000", + exp: "<32:00000000000000000000000000000000>", + }, + } { + k := rangesync.MustParseHexKeyBytes(tc.k) + require.Equal(t, tc.exp, preFirst1(k).String(), "k=%s", tc.k) + verifyPrefix(t, preFirst1(k)) + } +} + +func TestMatch(t *testing.T) { + for _, tc := range []struct { + k string + p prefix + match bool + }{ + { + k: "12345678", + p: emptyPrefix, + match: true, + }, + { + k: "12345678", + p: prefix{l: 1}, + match: true, + }, + { + k: "12345678", + p: prefix{l: 3}, + match: true, + }, + { + k: "12345678", + p: prefix{l: 4}, + match: false, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x80}, l: 1}, + match: false, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x80}, l: 2}, + match: false, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x10}, l: 4}, + match: true, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x12, 0x34, 0x50}, l: 20}, + match: true, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x12, 0x34, 0x50}, l: 24}, + match: false, + }, + { + k: "12345678", + p: prefix{b: [prefixBytes]byte{0x12, 0x34, 0x56, 0x78}, l: 32}, + match: true, + }, + } { + k := rangesync.MustParseHexKeyBytes(tc.k) + require.Equal(t, tc.match, tc.p.match(k), "k=%s p=%s", tc.k, tc.p) + } +} + +func TestPrefixFromKeyBytes(t *testing.T) { + p := prefixFromKeyBytes(rangesync.MustParseHexKeyBytes( + "123456789abcdef0123456789abcdef111111111111111111111111111111111")) + require.Equal(t, + "<96:000100100011010001010110011110001001101010111100"+ + "110111101111000000010010001101000101011001111000>", + p.String()) +} diff --git a/sync2/fptree/refcountpool.go 
b/sync2/fptree/refcountpool.go new file mode 100644 index 0000000000..cc0c33b1d1 --- /dev/null +++ b/sync2/fptree/refcountpool.go @@ -0,0 +1,120 @@ +package fptree + +import ( + "strconv" + "sync/atomic" +) + +// freeBit is a bit that indicates that an entry is free. +const freeBit = 1 << 31 + +// freeListMask is a mask that extracts the free list index from a refCount. +const freeListMask = freeBit - 1 + +// poolEntry is an entry in the rcPool. +type poolEntry[T any, I ~uint32] struct { + refCount uint32 + content T +} + +// rcPool is a reference-counted pool of items. +// The zero value is a valid, empty rcPool. +// Unlike sync.Pool, rcPool does not shrink, but uint32 indices can be used +// to reference items instead of larger 64-bit pointers, and the items +// can be shared between the original tree and its clones. +type rcPool[T any, I ~uint32] struct { + entries []poolEntry[T, I] + // freeList is 1-based so that rcPool doesn't need a constructor + freeList uint32 + allocCount atomic.Int64 +} + +// init pre-allocates the rcPool with n items. +func (rc *rcPool[T, I]) init(n int) { + rc.entries = make([]poolEntry[T, I], 0, n) + rc.freeList = 0 + rc.allocCount.Store(0) +} + +// count returns the number of items in the rcPool. +func (rc *rcPool[T, I]) count() int { + return int(rc.allocCount.Load()) +} + +// item returns the item at the given index. +func (rc *rcPool[T, I]) item(idx I) T { + return rc.entry(idx).content +} + +// entry returns the pool entry at the given index. +func (rc *rcPool[T, I]) entry(idx I) *poolEntry[T, I] { + entry := &rc.entries[idx] + if entry.refCount&freeBit != 0 { + panic("BUG: referencing a free nodePool entry " + strconv.Itoa(int(idx))) + } + return entry +} + +// replace replaces the item at the given index. +func (rc *rcPool[T, I]) replace(idx I, item T) { + entry := &rc.entries[idx] + if entry.refCount&freeBit != 0 { + panic("BUG: replace of a free rcPool[T, I] entry") + } + if entry.refCount != 1 { + panic("BUG: bad rcPool[T, I] entry refcount for replace") + } + entry.content = item +} + +// add adds an item to the rcPool and returns its index. +func (rc *rcPool[T, I]) add(item T) I { + var idx I + if rc.freeList != 0 { + idx = I(rc.freeList - 1) + rc.freeList = rc.entries[idx].refCount & freeListMask + if rc.freeList > uint32(len(rc.entries)) { + panic("BUG: bad freeList linkage") + } + rc.entries[idx].refCount = 1 + } else { + idx = I(len(rc.entries)) + rc.entries = append(rc.entries, poolEntry[T, I]{refCount: 1}) + } + rc.entries[idx].content = item + rc.allocCount.Add(1) + return idx +} + +// release releases the item at the given index. +func (rc *rcPool[T, I]) release(idx I) bool { + entry := &rc.entries[idx] + if entry.refCount&freeBit != 0 { + panic("BUG: release of a free rcPool[T, I] entry") + } + if entry.refCount <= 0 { + panic("BUG: bad rcPool[T, I] entry refcount") + } + entry.refCount-- + if entry.refCount == 0 { + if rc.freeList > uint32(len(rc.entries)) { + panic("BUG: bad freeList") + } + entry.refCount = rc.freeList | freeBit + rc.freeList = uint32(idx + 1) + rc.allocCount.Add(-1) + return true + } + + return false +} + +// ref adds a reference to the item at the given index. +func (rc *rcPool[T, I]) ref(idx I) { + rc.entries[idx].refCount++ +} + +// refCount returns the reference count for the item at the given index. 
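+// For entries currently on the free list, the stored value has freeBit set and its low bits hold the 1-based index of the next free entry rather than a reference count: e.g., releasing entry 2 while the free list is empty stores exactly freeBit (next = 0, i.e. none) and sets freeList to 3. 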
+func (rc *rcPool[T, I]) refCount(idx I) uint32 { + return rc.entries[idx].refCount +} diff --git a/sync2/fptree/refcountpool_test.go b/sync2/fptree/refcountpool_test.go new file mode 100644 index 0000000000..6bde075141 --- /dev/null +++ b/sync2/fptree/refcountpool_test.go @@ -0,0 +1,65 @@ +package fptree + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRCPool(t *testing.T) { + type foo struct { + //nolint:unused + x int + } + type fooIndex uint32 + + var pool rcPool[foo, fooIndex] + idx1 := pool.add(foo{x: 1}) + foo1 := pool.item(idx1) + require.Equal(t, 1, pool.count()) + idx2 := pool.add(foo{x: 2}) + foo2 := pool.item(idx2) + require.Equal(t, 2, pool.count()) + require.Equal(t, foo{x: 1}, foo1) + require.Equal(t, foo{x: 2}, foo2) + idx3 := pool.add(foo{x: 3}) + idx4 := pool.add(foo{x: 4}) + require.Equal(t, fooIndex(3), idx4) + pool.ref(idx4) + require.Equal(t, 4, pool.count()) + + require.False(t, pool.release(idx4)) + // not yet released due to an extra ref + require.Equal(t, fooIndex(4), pool.add(foo{x: 5})) + require.Equal(t, 5, pool.count()) + + require.True(t, pool.release(idx4)) + // idx4 was freed + require.Equal(t, idx4, pool.add(foo{x: 6})) + require.Equal(t, 5, pool.count()) + + // free item used just once + require.Equal(t, fooIndex(5), pool.add(foo{x: 7})) + require.Equal(t, 6, pool.count()) + + // form a free list containing several items + require.True(t, pool.release(idx3)) + require.True(t, pool.release(idx2)) + require.True(t, pool.release(idx1)) + require.Equal(t, 3, pool.count()) + + // the free list is LIFO + require.Equal(t, idx1, pool.add(foo{x: 8})) + require.Equal(t, idx2, pool.add(foo{x: 9})) + require.Equal(t, idx3, pool.add(foo{x: 10})) + require.Equal(t, 6, pool.count()) + + // the free list is exhausted + idx5 := pool.add(foo{x: 11}) + require.Equal(t, fooIndex(6), idx5) + require.Equal(t, 7, pool.count()) + + // replace the item + pool.replace(idx5, foo{x: 12}) + require.Equal(t, foo{x: 12}, pool.item(idx5)) +} diff --git a/sync2/fptree/testtree.go b/sync2/fptree/testtree.go new file mode 100644 index 0000000000..bfce519657 --- /dev/null +++ b/sync2/fptree/testtree.go @@ -0,0 +1,97 @@ +package fptree + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" +) + +// checkNode checks that the tree node at the given index is correct and also recursively +// checks its children. +func checkNode(t *testing.T, ft *FPTree, idx nodeIndex, depth int) { + left := ft.np.left(idx) + right := ft.np.right(idx) + if left == noIndex && right == noIndex { + if ft.np.count(idx) != 1 { + assert.Equal(t, depth, ft.maxDepth) + } else if ft.maxDepth == 0 && ft.idStore == nil { + assert.NotNil(t, ft.np.value(idx), "leaf node must have a value if there's no idStore") + } + } else { + if ft.maxDepth != 0 { + assert.Less(t, depth, ft.maxDepth) + } + var expFP rangesync.Fingerprint + var expCount uint32 + if left != noIndex { + checkNode(t, ft, left, depth+1) + count, fp, _ := ft.np.info(left) + expFP.Update(fp[:]) + expCount += count + } + if right != noIndex { + checkNode(t, ft, right, depth+1) + count, fp, _ := ft.np.info(right) + expFP.Update(fp[:]) + expCount += count + } + count, fp, _ := ft.np.info(idx) + assert.Equal(t, expFP, fp, "node fp at depth %d", depth) + assert.Equal(t, expCount, count, "node count at depth %d", depth) + } +} + +// CheckTree checks that the tree has correct structure. 
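+// Specifically, checkNode verifies that every interior node's fingerprint and count match the combination of its children's, and that leaf nodes with a count other than 1 occur only at maxDepth. 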
+func CheckTree(t *testing.T, ft *FPTree) { + if ft.root != noIndex { + checkNode(t, ft, ft.root, 0) + } +} + +// analyzeTreeNodeRefs checks that the reference counts in the node pool are correct. +func analyzeTreeNodeRefs(t *testing.T, np *nodePool, trees ...*FPTree) { + m := make(map[nodeIndex]map[nodeIndex]bool) + var rec func(*FPTree, nodeIndex, nodeIndex) + rec = func(ft *FPTree, idx, from nodeIndex) { + if idx == noIndex { + return + } + if _, ok := m[idx]; !ok { + m[idx] = make(map[nodeIndex]bool) + } + m[idx][from] = true + rec(ft, np.left(idx), idx) + rec(ft, np.right(idx), idx) + } + for n, ft := range trees { + treeRef := nodeIndex(-n - 1) + rec(ft, ft.root, treeRef) + } + for n, entry := range np.rcPool.entries { + if entry.refCount&freeBit != 0 { + continue + } + numTreeRefs := len(m[nodeIndex(n)]) + if numTreeRefs == 0 { + assert.Fail(t, "analyzeUnref: NOT REACHABLE", "idx: %d", n) + } else { + assert.Equal(t, numTreeRefs, int(entry.refCount), "analyzeRef: refCount for %d", n) + } + } +} + +// AnalyzeTreeNodeRefs checks that the reference counts are correct for the given trees in +// their respective node pools. +func AnalyzeTreeNodeRefs(t *testing.T, trees ...*FPTree) { + t.Helper() + // group trees by node pool they use + nodePools := make(map[*nodePool][]*FPTree) + for _, ft := range trees { + nodePools[ft.np] = append(nodePools[ft.np], ft) + } + for np, trees := range nodePools { + analyzeTreeNodeRefs(t, np, trees...) + } +} diff --git a/sync2/fptree/trace.go b/sync2/fptree/trace.go new file mode 100644 index 0000000000..e166293be9 --- /dev/null +++ b/sync2/fptree/trace.go @@ -0,0 +1,100 @@ +package fptree + +import ( + "fmt" + "os" + "strings" + + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" +) + +type shortened rangesync.KeyBytes + +// trace represents a logging facility for tracing FPTree operations, using indentation to +// show their nested structure. +type trace struct { + traceEnabled bool + traceStack []string +} + +func (t *trace) out(msg string) { + fmt.Fprintf(os.Stderr, "TRACE: %s%s\n", strings.Repeat(" ", len(t.traceStack)), msg) +} + +// enter marks the entry to a function, printing the log message with the given format +// string and arguments. +func (t *trace) enter(format string, args ...any) { + if !t.traceEnabled { + return + } + msg := fmt.Sprintf(format, preprocessTraceArgs(args)...) + t.out("ENTER: " + msg) + t.traceStack = append(t.traceStack, msg) +} + +// leave marks the exit from a function, printing the results of the function call +// together with the same log message contents which was used in the corresponding enter +// call. +func (t *trace) leave(results ...any) { + if !t.traceEnabled { + return + } + if len(t.traceStack) == 0 { + panic("BUG: trace stack underflow") + } + msg := t.traceStack[len(t.traceStack)-1] + results = preprocessTraceArgs(results) + if len(results) != 0 { + var r []string + for _, res := range results { + r = append(r, fmt.Sprint(res)) + } + msg += " => " + strings.Join(r, ", ") + } + t.traceStack = t.traceStack[:len(t.traceStack)-1] + t.out("LEAVE: " + msg) +} + +// log prints a log message with the given format string and arguments. +func (t *trace) log(format string, args ...any) { + if t.traceEnabled { + msg := fmt.Sprintf(format, preprocessTraceArgs(args)...) + t.out(msg) + } +} + +// seqFormatter is a lazy formatter for SeqResult. +type seqFormatter struct { + sr rangesync.SeqResult +} + +// String implements fmt.Stringer. 
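+// Only the first key of the sequence (or its error, if any) is rendered, which keeps trace formatting cheap even for long sequences. 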
+func (f seqFormatter) String() string { + for k := range f.sr.Seq { + return k.String() + } + if err := f.sr.Error(); err != nil { + return fmt.Sprintf("<error: %v>", err) + } + return "<empty>" +} + +// formatSeqResult returns a fmt.Stringer for the SeqResult that +// formats the sequence result lazily. +func formatSeqResult(sr rangesync.SeqResult) fmt.Stringer { + return seqFormatter{sr: sr} +} + +func preprocessTraceArgs(args []any) []any { + for n, arg := range args { + switch arg := arg.(type) { + case error: + return []any{fmt.Sprintf("<error: %v>", arg)} + case rangesync.SeqResult: + args[n] = formatSeqResult(arg) + case shortened: + args[n] = rangesync.KeyBytes(arg).ShortString() + } + } + return args +} diff --git a/sync2/rangesync/export_test.go b/sync2/rangesync/export_test.go index 4bd445d5ea..4f539e580c 100644 --- a/sync2/rangesync/export_test.go +++ b/sync2/rangesync/export_test.go @@ -1,12 +1,10 @@ package rangesync var ( - StartWireConduit = startWireConduit - StringToFP = stringToFP - CHash = chash - NaiveFPFunc = naiveFPFunc - NewRangeSetReconcilerInternal = newRangeSetReconciler - NewPairwiseSetSyncerInternal = newPairwiseSetSyncer + StartWireConduit = startWireConduit + StringToFP = stringToFP + CHash = chash + NaiveFPFunc = naiveFPFunc ) type ( diff --git a/sync2/rangesync/p2p.go b/sync2/rangesync/p2p.go index 23746260b8..4234e1bf0a 100644 --- a/sync2/rangesync/p2p.go +++ b/sync2/rangesync/p2p.go @@ -23,7 +23,7 @@ type PairwiseSetSyncer struct { clock clockwork.Clock } -func newPairwiseSetSyncer( +func NewPairwiseSetSyncerInternal( logger *zap.Logger, r Requester, name string, @@ -47,7 +47,7 @@ func NewPairwiseSetSyncer( name string, cfg RangeSetReconcilerConfig, ) *PairwiseSetSyncer { - return newPairwiseSetSyncer(logger, r, name, cfg, nullTracer{}, clockwork.NewRealClock()) + return NewPairwiseSetSyncerInternal(logger, r, name, cfg, nullTracer{}, clockwork.NewRealClock()) } func (pss *PairwiseSetSyncer) updateCounts(c *wireConduit) { @@ -56,7 +56,7 @@ func (pss *PairwiseSetSyncer) updateCounts(c *wireConduit) { } func (pss *PairwiseSetSyncer) createReconciler(os OrderedSet) *RangeSetReconciler { - return newRangeSetReconciler(pss.logger, pss.cfg, os, pss.tracer, pss.clock) + return NewRangeSetReconcilerInternal(pss.logger, pss.cfg, os, pss.tracer, pss.clock) } func (pss *PairwiseSetSyncer) Probe( diff --git a/sync2/rangesync/rangesync.go b/sync2/rangesync/rangesync.go index e07d0a4aa4..d03b758951 100644 --- a/sync2/rangesync/rangesync.go +++ b/sync2/rangesync/rangesync.go @@ -12,7 +12,7 @@ import ( ) const ( - DefaultMaxSendRange = 16 + DefaultMaxSendRange = 1 DefaultItemChunkSize = 1024 DefaultSampleSize = 200 maxSampleSize = 1000 @@ -85,7 +85,10 @@ type RangeSetReconciler struct { logger *zap.Logger } -func newRangeSetReconciler( +// NewRangeSetReconcilerInternal creates a new RangeSetReconciler. +// It is only directly called by the tests. +// It accepts extra tracer and clock parameters. +func NewRangeSetReconcilerInternal( logger *zap.Logger, cfg RangeSetReconcilerConfig, os OrderedSet, @@ -107,7 +110,7 @@ func newRangeSetReconciler( // NewRangeSetReconciler creates a new RangeSetReconciler. 
@@ -107,7 +110,7 @@
 // NewRangeSetReconciler creates a new RangeSetReconciler.
 func NewRangeSetReconciler(logger *zap.Logger, cfg RangeSetReconcilerConfig, os OrderedSet) *RangeSetReconciler {
-	return newRangeSetReconciler(logger, cfg, os, nullTracer{}, clockwork.NewRealClock())
+	return NewRangeSetReconcilerInternal(logger, cfg, os, nullTracer{}, clockwork.NewRealClock())
 }

 func (rsr *RangeSetReconciler) defaultRange() (x, y KeyBytes, err error) {
diff --git a/syncer/syncer.go b/syncer/syncer.go
index ff52123950..5cd0002d81 100644
--- a/syncer/syncer.go
+++ b/syncer/syncer.go
@@ -127,7 +127,6 @@ type Syncer struct {
 	atxsyncer   atxSyncer
 	malsyncer   malSyncer
 	ticker      layerTicker
-	beacon      system.BeaconGetter
 	mesh        *mesh.Mesh
 	tortoise    system.Tortoise
 	certHandler certHandler
@@ -168,7 +167,6 @@ type Syncer struct {
 func NewSyncer(
 	cdb *datastore.CachedDB,
 	ticker layerTicker,
-	beacon system.BeaconGetter,
 	mesh *mesh.Mesh,
 	tortoise system.Tortoise,
 	fetcher fetcher,
@@ -185,7 +183,6 @@ func NewSyncer(
 		atxsyncer:   atxSyncer,
 		malsyncer:   malSyncer,
 		ticker:      ticker,
-		beacon:      beacon,
 		mesh:        mesh,
 		tortoise:    tortoise,
 		certHandler: ch,
@@ -241,11 +238,6 @@ func (s *Syncer) IsSynced(ctx context.Context) bool {
 	return s.getSyncState() == synced
 }

-func (s *Syncer) IsBeaconSynced(epoch types.EpochID) bool {
-	_, err := s.beacon.GetBeacon(epoch)
-	return err == nil
-}
-
 // Start starts the main sync loop that tries to sync data for every SyncInterval.
 func (s *Syncer) Start() {
 	s.syncOnce.Do(func() {
diff --git a/syncer/syncer_test.go b/syncer/syncer_test.go
index cb74d04344..7a3f4cad50 100644
--- a/syncer/syncer_test.go
+++ b/syncer/syncer_test.go
@@ -75,7 +75,6 @@ type testSyncer struct {
 	mDataFetcher *mocks.MockfetchLogic
 	mAtxSyncer   *mocks.MockatxSyncer
 	mMalSyncer   *mocks.MockmalSyncer
-	mBeacon      *smocks.MockBeaconGetter
 	mLyrPatrol   *mocks.MocklayerPatrol
 	mVm          *mmocks.MockvmState
 	mConState    *mmocks.MockconservativeState
@@ -120,7 +119,6 @@ func newTestSyncer(tb testing.TB, interval time.Duration) *testSyncer {
 		mDataFetcher: mocks.NewMockfetchLogic(ctrl),
 		mAtxSyncer:   mocks.NewMockatxSyncer(ctrl),
 		mMalSyncer:   mocks.NewMockmalSyncer(ctrl),
-		mBeacon:      smocks.NewMockBeaconGetter(ctrl),
 		mLyrPatrol:   mocks.NewMocklayerPatrol(ctrl),
 		mVm:          mmocks.NewMockvmState(ctrl),
 		mConState:    mmocks.NewMockconservativeState(ctrl),
@@ -148,7 +146,6 @@ func newTestSyncer(tb testing.TB, interval time.Duration) *testSyncer {
 	ts.syncer = NewSyncer(
 		ts.cdb,
 		ts.mTicker,
-		ts.mBeacon,
 		ts.msh,
 		ts.mTortoise,
 		nil,
@@ -755,15 +752,6 @@ func TestSyncer_setATXSyncedTwice_NoError(t *testing.T) {
 	require.NotPanics(t, func() { ts.syncer.setATXSynced() })
 }

-func TestSyncer_IsBeaconSynced(t *testing.T) {
-	ts := newSyncerWithoutPeriodicRuns(t)
-	epoch := types.EpochID(11)
-	ts.mBeacon.EXPECT().GetBeacon(epoch).Return(types.EmptyBeacon, errors.New("unknown"))
-	require.False(t, ts.syncer.IsBeaconSynced(epoch))
-	ts.mBeacon.EXPECT().GetBeacon(epoch).Return(types.RandomBeacon(), nil)
-	require.True(t, ts.syncer.IsBeaconSynced(epoch))
-}
-
 func TestSynchronize_RecoverFromCheckpoint(t *testing.T) {
 	ts := newSyncerWithoutPeriodicRuns(t)
 	ts.expectDownloadLoop()
@@ -774,7 +762,6 @@ func TestSynchronize_RecoverFromCheckpoint(t *testing.T) {
 	ts.syncer = NewSyncer(
 		ts.cdb,
 		ts.mTicker,
-		ts.mBeacon,
 		ts.msh,
 		ts.mTortoise,
 		nil,
diff --git a/system/mocks/sync.go b/system/mocks/sync.go
index e713826ebc..3680d09a65 100644
--- a/system/mocks/sync.go
+++ b/system/mocks/sync.go
@@ -13,7 +13,6 @@ import (
 	context "context"
 	reflect "reflect"

-	types "github.com/spacemeshos/go-spacemesh/common/types"
 	gomock "go.uber.org/mock/gomock"
 )

@@ -41,44 +40,6 @@
func (m *MockSyncStateProvider) EXPECT() *MockSyncStateProviderMockRecorder { return m.recorder } -// IsBeaconSynced mocks base method. -func (m *MockSyncStateProvider) IsBeaconSynced(arg0 types.EpochID) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsBeaconSynced", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsBeaconSynced indicates an expected call of IsBeaconSynced. -func (mr *MockSyncStateProviderMockRecorder) IsBeaconSynced(arg0 any) *MockSyncStateProviderIsBeaconSyncedCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsBeaconSynced", reflect.TypeOf((*MockSyncStateProvider)(nil).IsBeaconSynced), arg0) - return &MockSyncStateProviderIsBeaconSyncedCall{Call: call} -} - -// MockSyncStateProviderIsBeaconSyncedCall wrap *gomock.Call -type MockSyncStateProviderIsBeaconSyncedCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockSyncStateProviderIsBeaconSyncedCall) Return(arg0 bool) *MockSyncStateProviderIsBeaconSyncedCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockSyncStateProviderIsBeaconSyncedCall) Do(f func(types.EpochID) bool) *MockSyncStateProviderIsBeaconSyncedCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSyncStateProviderIsBeaconSyncedCall) DoAndReturn(f func(types.EpochID) bool) *MockSyncStateProviderIsBeaconSyncedCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // IsSynced mocks base method. func (m *MockSyncStateProvider) IsSynced(arg0 context.Context) bool { m.ctrl.T.Helper() diff --git a/system/sync.go b/system/sync.go index 3ae3e58716..497c0b69e4 100644 --- a/system/sync.go +++ b/system/sync.go @@ -2,8 +2,6 @@ package system import ( "context" - - "github.com/spacemeshos/go-spacemesh/common/types" ) //go:generate mockgen -typed -package=mocks -destination=./mocks/sync.go -source=./sync.go @@ -11,5 +9,4 @@ import ( // SyncStateProvider defines the interface that provides the node's sync state. 
type SyncStateProvider interface { IsSynced(context.Context) bool - IsBeaconSynced(types.EpochID) bool } diff --git a/systest/Makefile b/systest/Makefile index 3add184f17..1df62cc1c9 100644 --- a/systest/Makefile +++ b/systest/Makefile @@ -6,11 +6,11 @@ test_name ?= TestSmeshing org ?= spacemeshos image_name ?= $(org)/systest:$(version_info) certifier_image ?= $(org)/certifier-service:v0.7.13 -poet_image ?= $(org)/poet:v0.10.3 +poet_image ?= $(org)/poet:v0.10.10 post_service_image ?= $(org)/post-service:v0.7.13 post_init_image ?= $(org)/postcli:v0.12.5 smesher_image ?= $(org)/go-spacemesh-dev:$(version_info) -old_smesher_image ?= $(org)/go-spacemesh-dev:e46c154 # Update this when new version is released +old_smesher_image ?= $(org)/go-spacemesh-dev:7b9337a # Update this when new version is released bs_image ?= $(org)/go-spacemesh-dev-bs:$(version_info) test_id ?= systest-$(version_info) diff --git a/systest/tests/distributed_post_verification_test.go b/systest/tests/distributed_post_verification_test.go index db10973ed8..b9437d03c9 100644 --- a/systest/tests/distributed_post_verification_test.go +++ b/systest/tests/distributed_post_verification_test.go @@ -25,10 +25,12 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" + "github.com/spacemeshos/go-spacemesh/fetch" mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/handshake" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" + "github.com/spacemeshos/go-spacemesh/proposals/store" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" @@ -95,10 +97,50 @@ func TestPostMalfeasanceProof(t *testing.T) { require.NoError(t, err) logger.Info("p2p host created", zap.Stringer("id", host.ID())) host.Register(pubsub.AtxProtocol, func(context.Context, peer.ID, []byte) error { return nil }) - require.NoError(t, host.Start()) t.Cleanup(func() { assert.NoError(t, host.Stop()) }) + db := statesql.InMemoryTest(t) + cdb := datastore.NewCachedDB(db, zap.NewNop()) + t.Cleanup(func() { assert.NoError(t, cdb.Close()) }) + + clock, err := timesync.NewClock( + timesync.WithLayerDuration(cfg.LayerDuration), + timesync.WithTickInterval(1*time.Second), + timesync.WithGenesisTime(cl.Genesis()), + timesync.WithLogger(logger.Named("clock")), + ) + require.NoError(t, err) + t.Cleanup(clock.Close) + + proposalsStore := store.New( + store.WithEvictedLayer(clock.CurrentLayer()), + store.WithLogger(logger.Named("proposals-store")), + store.WithCapacity(cfg.Tortoise.Zdist+1), + ) + + fetcher, err := fetch.NewFetch(cdb, proposalsStore, host, + fetch.WithContext(ctx), + fetch.WithConfig(cfg.FETCH), + fetch.WithLogger(logger.Named("fetcher")), + ) + require.NoError(t, err) + + fetcher.SetValidators( + fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }), + fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }), + fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }), + fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }), + fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }), + fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }), + 
	fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }),
+		fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }),
+		fetch.ValidatorFunc(func(context.Context, types.Hash32, peer.ID, []byte) error { return nil }),
+	)
+
+	require.NoError(t, fetcher.Start())
+	t.Cleanup(fetcher.Stop)
+
 	ctrl := gomock.NewController(t)
 	syncer := activation.NewMocksyncer(ctrl)
 	syncer.EXPECT().RegisterForATXSynced().DoAndReturn(func() <-chan struct{} {
@@ -108,9 +150,6 @@ func TestPostMalfeasanceProof(t *testing.T) {
 	}).AnyTimes()

 	// 1. Initialize
-	db := statesql.InMemoryTest(t)
-	cdb := datastore.NewCachedDB(db, zap.NewNop())
-	t.Cleanup(func() { assert.NoError(t, cdb.Close()) })
 	postSetupMgr, err := activation.NewPostSetupManager(
 		cfg.POST,
 		logger.Named("post"),
@@ -135,15 +174,6 @@ func TestPostMalfeasanceProof(t *testing.T) {
 	t.Cleanup(func() { assert.NoError(t, postSupervisor.Stop(false)) })

 	// 2. create ATX with invalid POST labels
-	clock, err := timesync.NewClock(
-		timesync.WithLayerDuration(cfg.LayerDuration),
-		timesync.WithTickInterval(1*time.Second),
-		timesync.WithGenesisTime(cl.Genesis()),
-		timesync.WithLogger(logger.Named("clock")),
-	)
-	require.NoError(t, err)
-	t.Cleanup(clock.Close)
-
 	grpcPostService := grpcserver.NewPostService(
 		logger.Named("grpc-post-service"),
 		grpcserver.PostServiceQueryInterval(500*time.Millisecond),
@@ -167,9 +197,8 @@ func TestPostMalfeasanceProof(t *testing.T) {
 	require.NoError(t, err)
 	poetService, err := activation.NewPoetService(
 		poetDb,
-		types.PoetServer{
-			Address: cluster.MakePoetGlobalEndpoint(ctx.Namespace, 0),
-		}, cfg.POET,
+		types.PoetServer{Address: cluster.MakePoetGlobalEndpoint(ctx.Namespace, 0)},
+		cfg.POET,
 		logger,
 		1,
 		activation.WithCertifier(certifier),
@@ -248,6 +277,8 @@ func TestPostMalfeasanceProof(t *testing.T) {
 			Pow:   challenge.InitialPost.Pow,
 		},
 	}
+	err = nipost.AddChallenge(localDb, signer.NodeID(), nipostChallenge)
+	require.NoError(t, err)
 	nipost, err := nipostBuilder.BuildNIPost(ctx, signer, challenge.Hash(), nipostChallenge)
 	require.NoError(t, err)

@@ -297,7 +328,7 @@ func TestPostMalfeasanceProof(t *testing.T) {
 	t.Cleanup(func() { assert.NoError(t, eg.Wait()) })
 	eg.Go(func() error {
 		for {
-			logger.Sugar().Infow("publishing ATX", "atx", atx)
+			logger.Info("publishing ATX", zap.Object("atx", &atx))
 			buf := codec.MustEncode(&atx)
 			err = host.Publish(ctx, pubsub.AtxProtocol, buf)
 			require.NoError(t, err)
@@ -326,7 +357,7 @@ func TestPostMalfeasanceProof(t *testing.T) {
 	require.NoError(t, codec.Decode(malf.Proof.Proof, &proof))
 	require.Equal(t, mwire.InvalidPostIndex, proof.Proof.Type)
 	invalidPostProof := proof.Proof.Data.(*mwire.InvalidPostIndexProof)
-	logger.Sugar().Infow("malfeasance post proof", "proof", invalidPostProof)
+	logger.Info("malfeasance post proof", zap.Object("proof", invalidPostProof))
 	invalidAtx := invalidPostProof.Atx
 	require.Equal(t, atx.PublishEpoch, invalidAtx.PublishEpoch)
 	require.Equal(t, atx.SmesherID, invalidAtx.SmesherID)
diff --git a/systest/tests/smeshing_test.go b/systest/tests/smeshing_test.go
index 16267bbd55..4531ead859 100644
--- a/systest/tests/smeshing_test.go
+++ b/systest/tests/smeshing_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"fmt"
+	"os"
 	"sort"
 	"testing"
 	"time"
@@ -28,6 +29,14 @@ import (
 	"github.com/spacemeshos/go-spacemesh/systest/testcontext"
 )

+func TestMain(m *testing.M) {
+	// systest runs with the `fastnet` preset, so this init needs to generate
+	// addresses with the same HRP network prefix as fastnet.
+	types.SetNetworkHRP("stest")
+	res := m.Run()
+	os.Exit(res)
+}
+
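For context on the comment above: wallet addresses are rendered in bech32 with the configured network HRP, so fastnet addresses are textually distinct from mainnet ones. A rough illustration, assuming the `types.GenerateAddress` helper; the printed value is indicative only:

```go
package main

import (
	"fmt"

	"github.com/spacemeshos/go-spacemesh/common/types"
)

func main() {
	// With the fastnet HRP set, rendered addresses carry the "stest" prefix.
	types.SetNetworkHRP("stest")
	addr := types.GenerateAddress([]byte("example key"))
	fmt.Println(addr.String()) // something like "stest1..."
}
```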
 // TestSmeshing tests the network is healthy, smeshers are creating proposals, transactions are processed, and vesting
 // is working.
 func TestSmeshing(t *testing.T) {
diff --git a/txs/cache.go b/txs/cache.go
index ee49cedfca..d5aa8516f4 100644
--- a/txs/cache.go
+++ b/txs/cache.go
@@ -396,9 +396,34 @@ func (ac *accountCache) resetAfterApply(
 	ac.txsByNonce = list.New()
 	ac.startNonce = nextNonce
 	ac.startBalance = newBalance
+
+	err := ac.evictPendingNonce(db)
+	if err != nil {
+		return fmt.Errorf("evict pending: %w", err)
+	}
 	return ac.addPendingFromNonce(logger, db, ac.startNonce, applied)
 }

+// evictPendingNonce marks all pending transactions of the account below its start nonce
+// as evicted and removes them from the transactions table.
+func (ac *accountCache) evictPendingNonce(db sql.StateDatabase) error {
+	return db.WithTxImmediate(context.Background(), func(tx sql.Transaction) error {
+		txIds, err := transactions.GetAcctPendingToNonce(tx, ac.addr, ac.startNonce)
+		if err != nil {
+			return fmt.Errorf("get pending to nonce: %w", err)
+		}
+		for _, tid := range txIds {
+			if err := transactions.SetEvicted(tx, tid); err != nil {
+				return fmt.Errorf("set evicted for %s: %w", tid, err)
+			}
+			if err := transactions.Delete(tx, tid); err != nil {
+				return fmt.Errorf("delete tx %s: %w", tid, err)
+			}
+		}
+		return nil
+	})
+}
+
 func (ac *accountCache) shouldEvict() bool {
 	return ac.txsByNonce.Len() == 0 && !ac.moreInDB
 }
@@ -514,7 +539,7 @@
 			acctsAdded++
 		}
 	}
-	c.logger.Sugar().Debug("added pending tx for %d accounts", acctsAdded)
+	c.logger.Sugar().Debugf("added pending tx for %d accounts", acctsAdded)
 	return nil
 }
@@ -776,6 +801,11 @@ func (c *Cache) ApplyLayer(
 		}
 		acctResetDuration.Observe(float64(time.Since(t2)))
 	}
+
+	err := transactions.PruneEvicted(db, time.Now().Add(-12*time.Hour))
+	if err != nil {
+		logger.Warn("failed to prune evicted", zap.Error(err))
+	}
 	return nil
 }
diff --git a/txs/cache_test.go b/txs/cache_test.go
index 6d83fdf0e3..e4c66c00eb 100644
--- a/txs/cache_test.go
+++ b/txs/cache_test.go
@@ -334,10 +334,10 @@ func TestCache_Account_HappyFlow(t *testing.T) {
 	checkProjection(t, tc.Cache, ta.principal, newNextNonce, newBalance+income)
 	// mempool is unchanged
 	checkMempool(t, tc.Cache, expectedMempool)
+
+	// pruning has removed old and ineffectual txs
 	for _, mtx := range append(oldNonces, sameNonces...) {
-		got, err := transactions.Get(tc.db, mtx.ID)
-		require.NoError(t, err)
-		require.Equal(t, types.MEMPOOL, got.State)
+		checkTXNotInDB(t, tc.db, mtx.ID)
 	}

 	// revert to one layer before lid
@@ -357,8 +357,6 @@ func TestCache_Account_HappyFlow(t *testing.T) {
 	}
 	checkProjection(t, tc.Cache, ta.principal, newNextNonce, newBalance)
 	checkTXStateFromDB(t, tc.db, mtxs, types.MEMPOOL)
-	checkTXStateFromDB(t, tc.db, oldNonces, types.MEMPOOL)
-	checkTXStateFromDB(t, tc.db, sameNonces, types.MEMPOOL)
 }

 func TestCache_Account_TXInMultipleLayers(t *testing.T) {
diff --git a/txs/conservative_state.go b/txs/conservative_state.go
index edfa66afa8..7be6f38f07 100644
--- a/txs/conservative_state.go
+++ b/txs/conservative_state.go
@@ -283,3 +283,8 @@ func ShuffleWithNonceOrder(
 	})))
 	return result
 }
+
+// HasEvicted reports whether the transaction with the given ID was evicted from the cache.
+func (cs *ConservativeState) HasEvicted(tid types.TransactionID) (bool, error) {
+	return transactions.HasEvicted(cs.db, tid)
+}
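Taken together, the txs changes above give ineffectual transactions a defined lifecycle: on layer application `evictPendingNonce` marks and deletes pending transactions below an account's new nonce, `PruneEvicted` later drops eviction records older than twelve hours, and `HasEvicted` lets callers tell an evicted transaction apart from one the node never saw. Below is a hypothetical API-side check built on the new method; the function and status strings are illustrative, not part of this diff:

```go
// Hypothetical status lookup; only ConservativeState.HasEvicted comes from
// this diff.
func txStatus(cs *txs.ConservativeState, tid types.TransactionID) (string, error) {
	evicted, err := cs.HasEvicted(tid)
	if err != nil {
		return "", fmt.Errorf("check evicted: %w", err)
	}
	if evicted {
		// The tx was pending at some point but became ineffectual.
		return "ineffectual (evicted)", nil
	}
	// Not evicted: fall back to the regular transaction lookup.
	return "unknown or still pending", nil
}
```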