diff --git a/bitswap/client/bitswap_with_sessions_test.go b/bitswap/client/bitswap_with_sessions_test.go index e28379113..2191a5a90 100644 --- a/bitswap/client/bitswap_with_sessions_test.go +++ b/bitswap/client/bitswap_with_sessions_test.go @@ -8,6 +8,7 @@ import ( "github.com/ipfs/boxo/bitswap" "github.com/ipfs/boxo/bitswap/client/internal/session" + "github.com/ipfs/boxo/bitswap/client/traceability" testinstance "github.com/ipfs/boxo/bitswap/testinstance" tn "github.com/ipfs/boxo/bitswap/testnet" "github.com/ipfs/boxo/internal/test" @@ -17,6 +18,7 @@ import ( blocksutil "github.com/ipfs/go-ipfs-blocksutil" delay "github.com/ipfs/go-ipfs-delay" tu "github.com/libp2p/go-libp2p-testing/etc" + "github.com/libp2p/go-libp2p/core/peer" ) func getVirtualNetwork() tn.Network { @@ -71,9 +73,18 @@ func TestBasicSessions(t *testing.T) { if !blkout.Cid().Equals(block.Cid()) { t.Fatal("got wrong block") } + + traceBlock, ok := blkout.(traceability.Block) + if !ok { + t.Fatal("did not get tracable block") + } + + if traceBlock.From != b.Peer { + t.Fatal("should have received block from peer B, did not") + } } -func assertBlockLists(got, exp []blocks.Block) error { +func assertBlockListsFrom(from peer.ID, got, exp []blocks.Block) error { if len(got) != len(exp) { return fmt.Errorf("got wrong number of blocks, %d != %d", len(got), len(exp)) } @@ -81,6 +92,13 @@ func assertBlockLists(got, exp []blocks.Block) error { h := cid.NewSet() for _, b := range got { h.Add(b.Cid()) + traceableBlock, ok := b.(traceability.Block) + if !ok { + return fmt.Errorf("not a traceable block: %s", b.Cid()) + } + if traceableBlock.From != from { + return fmt.Errorf("incorrect peer sent block, expect %s, got %s", from, traceableBlock.From) + } } for _, b := range exp { if !h.Has(b.Cid()) { @@ -133,7 +151,7 @@ func TestSessionBetweenPeers(t *testing.T) { for b := range ch { got = append(got, b) } - if err := assertBlockLists(got, blks[i*10:(i+1)*10]); err != nil { + if err := assertBlockListsFrom(inst[0].Peer, got, blks[i*10:(i+1)*10]); err != nil { t.Fatal(err) } } @@ -192,7 +210,7 @@ func TestSessionSplitFetch(t *testing.T) { for b := range ch { got = append(got, b) } - if err := assertBlockLists(got, blks[i*10:(i+1)*10]); err != nil { + if err := assertBlockListsFrom(inst[i].Peer, got, blks[i*10:(i+1)*10]); err != nil { t.Fatal(err) } } @@ -238,7 +256,7 @@ func TestFetchNotConnected(t *testing.T) { for b := range ch { got = append(got, b) } - if err := assertBlockLists(got, blks); err != nil { + if err := assertBlockListsFrom(other.Peer, got, blks); err != nil { t.Fatal(err) } } @@ -289,7 +307,7 @@ func TestFetchAfterDisconnect(t *testing.T) { got = append(got, b) } - if err := assertBlockLists(got, blks[:5]); err != nil { + if err := assertBlockListsFrom(peerA.Peer, got, blks[:5]); err != nil { t.Fatal(err) } @@ -318,7 +336,7 @@ func TestFetchAfterDisconnect(t *testing.T) { } } - if err := assertBlockLists(got, blks); err != nil { + if err := assertBlockListsFrom(peerA.Peer, got, blks); err != nil { t.Fatal(err) } }