Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

don't use a context to shut down the circuitv2 #1185

Merged
merged 1 commit into from
Sep 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (cfg *Config) makeSwarm(ctx context.Context) (*swarm.Swarm, error) {
return swrm, nil
}

func (cfg *Config) addTransports(ctx context.Context, h host.Host) (err error) {
func (cfg *Config) addTransports(h host.Host) (err error) {
swrm, ok := h.Network().(transport.TransportNetwork)
if !ok {
// Should probably skip this if no transports.
Expand Down Expand Up @@ -165,15 +165,13 @@ func (cfg *Config) addTransports(ctx context.Context, h host.Host) (err error) {
return err
}
for _, t := range tpts {
err = swrm.AddTransport(t)
if err != nil {
if err := swrm.AddTransport(t); err != nil {
return err
}
}

if cfg.Relay {
err := circuitv2.AddTransport(ctx, h, upgrader)
if err != nil {
if err := circuitv2.AddTransport(h, upgrader); err != nil {
h.Close()
return err
}
Expand Down Expand Up @@ -225,8 +223,7 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) {
}
}

err = cfg.addTransports(ctx, h)
if err != nil {
if err := cfg.addTransports(h); err != nil {
h.Close()
return nil, err
}
Expand Down Expand Up @@ -314,8 +311,7 @@ func (cfg *Config) NewNode(ctx context.Context) (host.Host, error) {
return nil, err
}
dialerHost := blankhost.NewBlankHost(dialer)
err = autoNatCfg.addTransports(ctx, dialerHost)
if err != nil {
if err := autoNatCfg.addTransports(dialerHost); err != nil {
dialerHost.Close()
h.Close()
return nil, err
Expand Down
28 changes: 21 additions & 7 deletions p2p/protocol/circuitv2/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package client

import (
"context"
"io"
"sync"

"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/proto"

"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/transport"

logging "github.com/ipfs/go-log"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
Expand All @@ -24,9 +26,10 @@ var log = logging.Logger("p2p-circuit")
// This allows us to use the v2 code as drop in replacement for v1 in a host without breaking
// existing code and interoperability with older nodes.
type Client struct {
ctx context.Context
host host.Host
upgrader *tptu.Upgrader
ctx context.Context
ctxCancel context.CancelFunc
host host.Host
upgrader *tptu.Upgrader

incoming chan accept

Expand All @@ -35,6 +38,9 @@ type Client struct {
hopCount map[peer.ID]int
}

var _ io.Closer = &Client{}
var _ transport.Transport = &Client{}

type accept struct {
conn *Conn
writeResponse func() error
Expand All @@ -48,19 +54,27 @@ type completion struct {

// New constructs a new p2p-circuit/v2 client, attached to the given host and using the given
// upgrader to perform connection upgrades.
func New(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) (*Client, error) {
return &Client{
ctx: ctx,
func New(h host.Host, upgrader *tptu.Upgrader) (*Client, error) {
cl := &Client{
host: h,
upgrader: upgrader,
incoming: make(chan accept),
activeDials: make(map[peer.ID]*completion),
hopCount: make(map[peer.ID]int),
}, nil
}
cl.ctx, cl.ctxCancel = context.WithCancel(context.Background())
return cl, nil
}

// Start registers the circuit (client) protocol stream handlers
func (c *Client) Start() {
c.host.SetStreamHandler(proto.ProtoIDv1, c.handleStreamV1)
c.host.SetStreamHandler(proto.ProtoIDv2Stop, c.handleStreamV2)
}

func (c *Client) Close() error {
c.ctxCancel()
c.host.RemoveStreamHandler(proto.ProtoIDv1)
c.host.RemoveStreamHandler(proto.ProtoIDv2Stop)
return nil
}
6 changes: 3 additions & 3 deletions p2p/protocol/circuitv2/client/listen.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"errors"
"net"

ma "github.com/multiformats/go-multiaddr"
Expand Down Expand Up @@ -32,7 +33,7 @@ func (l *Listener) Accept() (manet.Conn, error) {
return evt.conn, nil

case <-l.ctx.Done():
return nil, l.ctx.Err()
return nil, errors.New("circuit v2 client closed")
}
}
}
Expand All @@ -49,6 +50,5 @@ func (l *Listener) Multiaddr() ma.Multiaddr {
}

func (l *Listener) Close() error {
// noop for now
return nil
return (*Client)(l).Close()
}
6 changes: 4 additions & 2 deletions p2p/protocol/circuitv2/client/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package client
import (
"context"
"fmt"
"io"

"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/peer"
Expand All @@ -17,13 +18,13 @@ var circuitAddr = ma.Cast(circuitProtocol.VCode)

// AddTransport constructs a new p2p-circuit/v2 client and adds it as a transport to the
// host network
func AddTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) error {
func AddTransport(h host.Host, upgrader *tptu.Upgrader) error {
n, ok := h.Network().(transport.TransportNetwork)
if !ok {
return fmt.Errorf("%v is not a transport network", h.Network())
}

c, err := New(ctx, h, upgrader)
c, err := New(h, upgrader)
if err != nil {
return fmt.Errorf("error constructing circuit client: %w", err)
}
Expand All @@ -45,6 +46,7 @@ func AddTransport(ctx context.Context, h host.Host, upgrader *tptu.Upgrader) err

// Transport interface
var _ transport.Transport = (*Client)(nil)
var _ io.Closer = (*Client)(nil)

func (c *Client) Dial(ctx context.Context, a ma.Multiaddr, p peer.ID) (transport.CapableConn, error) {
conn, err := c.dial(ctx, a, p)
Expand Down
4 changes: 2 additions & 2 deletions p2p/protocol/circuitv2/test/compat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestRelayCompatV2DialV1(t *testing.T) {

hosts, upgraders := getNetHosts(t, ctx, 3)
addTransportV1(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[2], upgraders[2])

rch := make(chan []byte, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
Expand Down Expand Up @@ -105,7 +105,7 @@ func TestRelayCompatV1DialV2(t *testing.T) {
defer cancel()

hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, hosts[0], upgraders[0])
addTransportV1(t, ctx, hosts[2], upgraders[2])

rch := make(chan []byte, 1)
Expand Down
23 changes: 11 additions & 12 deletions p2p/protocol/circuitv2/test/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ import (

logging "github.com/ipfs/go-log"
bhost "github.com/libp2p/go-libp2p-blankhost"
metrics "github.com/libp2p/go-libp2p-core/metrics"
pstoremem "github.com/libp2p/go-libp2p-peerstore/pstoremem"
"github.com/libp2p/go-libp2p-core/metrics"
"github.com/libp2p/go-libp2p-peerstore/pstoremem"
swarm "github.com/libp2p/go-libp2p-swarm"
swarmt "github.com/libp2p/go-libp2p-swarm/testing"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
tcp "github.com/libp2p/go-tcp-transport"
"github.com/libp2p/go-tcp-transport"
ma "github.com/multiformats/go-multiaddr"
)

Expand Down Expand Up @@ -85,9 +85,8 @@ func connect(t *testing.T, a, b host.Host) {
}
}

func addTransport(t *testing.T, ctx context.Context, h host.Host, upgrader *tptu.Upgrader) {
err := client.AddTransport(ctx, h, upgrader)
if err != nil {
func addTransport(t *testing.T, h host.Host, upgrader *tptu.Upgrader) {
if err := client.AddTransport(h, upgrader); err != nil {
t.Fatal(err)
}
}
Expand All @@ -97,8 +96,8 @@ func TestBasicRelay(t *testing.T) {
defer cancel()

hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])

rch := make(chan []byte, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
Expand Down Expand Up @@ -184,8 +183,8 @@ func TestRelayLimitTime(t *testing.T) {
defer cancel()

hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])

rch := make(chan error, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
Expand Down Expand Up @@ -258,8 +257,8 @@ func TestRelayLimitData(t *testing.T) {
defer cancel()

hosts, upgraders := getNetHosts(t, ctx, 3)
addTransport(t, ctx, hosts[0], upgraders[0])
addTransport(t, ctx, hosts[2], upgraders[2])
addTransport(t, hosts[0], upgraders[0])
addTransport(t, hosts[2], upgraders[2])

rch := make(chan int, 1)
hosts[0].SetStreamHandler("test", func(s network.Stream) {
Expand Down