diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8bec380212..60e9b2c45d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,5 @@
 # Table Of Contents
+- [v0.28.0](#v0280)
 - [v0.27.0](#v0270)
 - [v0.26.4](#v0264)
 - [v0.26.3](#v0263)
@@ -8,6 +9,17 @@
 - [v0.25.1](#v0251)
 - [v0.25.0](#v0250)
 
+# [v0.28.0]()
+
+## 🔦 Highlights
+
+### Smart Dialing
+* When connecting to a peer, we now use [Happy Eyeballs](https://www.rfc-editor.org/rfc/rfc8305)-like dial prioritisation to prefer QUIC addresses over TCP addresses: we dial the peer's QUIC address first and wait 250ms before dialing its TCP address.
+* In our experiments we've seen little impact on latencies up to the 80th percentile; the 90th and 95th percentile latencies are affected. For details, see the discussion on the [PR](https://github.com/libp2p/go-libp2p/pull/2260#issuecomment-1528848170).
+* For details of the address ranking logic, see the godoc for `swarm.DefaultDialRanker`.
+* To disable smart dialing and keep the old behaviour, use the
+`libp2p.NoDelayNetworkDialRanker` option.
+
 # [v0.27.0](https://github.com/libp2p/go-libp2p/releases/tag/v0.27.0)
 
 ### Breaking Changes
diff --git a/config/config.go b/config/config.go
index a3bd86f27b..7a6d7fa908 100644
--- a/config/config.go
+++ b/config/config.go
@@ -123,6 +123,8 @@ type Config struct {
 
 	DisableMetrics       bool
 	PrometheusRegisterer prometheus.Registerer
+
+	NoDelayNetworkDialRanker bool
 }
 
 func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) {
@@ -173,6 +175,9 @@ func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swa
 	if cfg.MultiaddrResolver != nil {
 		opts = append(opts, swarm.WithMultiaddrResolver(cfg.MultiaddrResolver))
 	}
+	if cfg.NoDelayNetworkDialRanker {
+		opts = append(opts, swarm.WithNoDialDelay())
+	}
 
 	if enableMetrics {
 		opts = append(opts, swarm.WithMetricsTracer(swarm.NewMetricsTracer(swarm.WithRegisterer(cfg.PrometheusRegisterer))))
diff --git a/core/network/network.go b/core/network/network.go
index 47908b8e31..4cedb75d37 100644
--- a/core/network/network.go
+++ b/core/network/network.go
@@ -187,6 +187,16 @@ type Dialer interface {
 	StopNotify(Notifiee)
 }
 
+// AddrDelay provides an address along with the delay after which the address
+// should be dialed.
+type AddrDelay struct {
+	Addr  ma.Multiaddr
+	Delay time.Duration
+}
+
+// DialRanker provides a schedule for dialing the provided addresses.
+type DialRanker func([]ma.Multiaddr) []AddrDelay
+
 // DedupAddrs deduplicates addresses in place, leave only unique addresses.
 // It doesn't allocate.
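// The following is an illustrative sketch, not part of the patch: it shows how an
// application could opt out of the smart dialing behaviour described in the
// v0.28.0 highlights above, using the libp2p.NoDelayNetworkDialRanker option
// added in options.go later in this diff.
package main

import "github.com/libp2p/go-libp2p"

func main() {
	// Build a host whose swarm dials all of a peer's addresses immediately,
	// i.e. the pre-v0.28.0 behaviour without the QUIC/TCP prioritisation delays.
	h, err := libp2p.New(libp2p.NoDelayNetworkDialRanker())
	if err != nil {
		panic(err)
	}
	defer h.Close()
}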
func DedupAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { diff --git a/dashboards/swarm/swarm.json b/dashboards/swarm/swarm.json index 84735c34cc..3a1d875059 100644 --- a/dashboards/swarm/swarm.json +++ b/dashboards/swarm/swarm.json @@ -1224,7 +1224,8 @@ "mode": "absolute", "steps": [ { - "color": "green" + "color": "green", + "value": null }, { "color": "red", @@ -1450,7 +1451,8 @@ "mode": "absolute", "steps": [ { - "color": "green" + "color": "green", + "value": null }, { "color": "red", @@ -2327,7 +2329,8 @@ "mode": "absolute", "steps": [ { - "color": "green" + "color": "green", + "value": null }, { "color": "red", @@ -2553,6 +2556,476 @@ ], "title": "libp2p key types", "type": "piechart" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 75 + }, + "id": 40, + "panels": [], + "title": "Dial Prioritisation", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + } + }, + "mappings": [], + "unit": "none" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "<=300ms" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "purple", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "<=500ms" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "red", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "<=750ms" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "light-red", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "<=50ms" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "dark-blue", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "<=10ms" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "blue", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 76 + }, + "id": 38, + "options": { + "displayLabels": [ + "percent" + ], + "legend": { + "displayMode": "table", + "placement": "right", + "showLegend": true, + "values": [ + "percent" + ] + }, + "pieType": "donut", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "exemplar": false, + "expr": "sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.001\"}[$__range]))", + "format": "time_series", + "instant": false, + "legendFormat": "No delay", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "exemplar": false, + "expr": "sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.01\"}[$__range])) - ignoring(le) sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.001\"}[$__range]))", + "hide": false, + "instant": false, + "legendFormat": "<=10ms", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": 
"sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.05\"}[$__range])) - ignoring(le) sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.01\"}[$__range]))", + "hide": false, + "legendFormat": "<=50ms", + "range": true, + "refId": "F" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.3\"}[$__range])) - ignoring(le) sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.05\"}[$__range]))", + "hide": false, + "legendFormat": "<=300ms", + "range": true, + "refId": "D" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.5\"}[$__range])) - ignoring(le) sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.3\"}[$__range]))", + "hide": false, + "legendFormat": "<=500ms", + "range": true, + "refId": "E" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.75\"}[$__range])) - ignoring(le) sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.5\"}[$__range]))", + "hide": false, + "legendFormat": "<=750ms", + "range": true, + "refId": "G" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"inf\"}[$__range])) - ignoring(le) sum(increase(libp2p_swarm_dial_ranking_delay_seconds_bucket{instance=~\"$instance\",le=\"0.75\"}[$__range]))", + "hide": false, + "legendFormat": ">750ms", + "range": true, + "refId": "H" + } + ], + "title": "Dial Ranking Delay", + "transformations": [], + "type": "piechart" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + } + }, + "mappings": [] + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": ">=6" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "light-red", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "5" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "dark-red", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "2" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "purple", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "1" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "green", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "3" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "orange", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "4" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "dark-orange", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 76 + }, + "id": 42, + 
"options": { + "displayLabels": [ + "percent", + "name" + ], + "legend": { + "displayMode": "table", + "placement": "right", + "showLegend": true, + "values": [ + "percent" + ] + }, + "pieType": "donut", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "tooltip": { + "mode": "single", + "sort": "none" + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "increase(libp2p_swarm_dials_per_peer_total{instance=~\"$instance\", outcome=\"success\", num_dials=\"0\"}[$__range])", + "legendFormat": "0", + "range": true, + "refId": "A" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "increase(libp2p_swarm_dials_per_peer_total{instance=~\"$instance\", outcome=\"success\", num_dials=\"1\"}[$__range])", + "hide": false, + "legendFormat": "1", + "range": true, + "refId": "B" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "increase(libp2p_swarm_dials_per_peer_total{instance=~\"$instance\", outcome=\"success\", num_dials=\"2\"}[$__range])", + "hide": false, + "legendFormat": "2", + "range": true, + "refId": "C" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "increase(libp2p_swarm_dials_per_peer_total{instance=~\"$instance\", outcome=\"success\", num_dials=\"3\"}[$__range])", + "hide": false, + "legendFormat": "3", + "range": true, + "refId": "D" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "increase(libp2p_swarm_dials_per_peer_total{instance=~\"$instance\", outcome=\"success\", num_dials=\"4\"}[$__range])", + "hide": false, + "legendFormat": "4", + "range": true, + "refId": "E" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "increase(libp2p_swarm_dials_per_peer_total{instance=~\"$instance\", outcome=\"success\", num_dials=\"5\"}[$__range])", + "hide": false, + "legendFormat": "5", + "range": true, + "refId": "F" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "editorMode": "code", + "expr": "increase(libp2p_swarm_dials_per_peer_total{instance=~\"$instance\", outcome=\"success\", num_dials=\">=6\"}[$__range])", + "hide": false, + "legendFormat": ">=6", + "range": true, + "refId": "G" + } + ], + "title": "Dials per connection", + "type": "piechart" } ], "schemaVersion": 37, @@ -2577,6 +3050,7 @@ "refId": "StandardVariableQuery" }, "refresh": 1, + "regex": "", "skipUrlSync": false, "sort": 0, "type": "query" @@ -2584,13 +3058,13 @@ ] }, "time": { - "from": "now-15m", + "from": "now-1h", "to": "now" }, "timepicker": {}, "timezone": "", "title": "libp2p Swarm", "uid": "a15PyhO4z", - "version": 12, + "version": 6, "weekStart": "" -} +} \ No newline at end of file diff --git a/options.go b/options.go index 1809ec44ba..41057a7180 100644 --- a/options.go +++ b/options.go @@ -574,3 +574,12 @@ func PrometheusRegisterer(reg prometheus.Registerer) Option { return nil } } + +// NoDelayNetworkDialRanker configures libp2p to disable dial prioritisation and dial +// all addresses of the peer without any delay +func NoDelayNetworkDialRanker() Option { + return func(cfg *Config) error { + cfg.NoDelayNetworkDialRanker = true + return nil + } +} diff --git a/p2p/net/swarm/clock.go b/p2p/net/swarm/clock.go new 
file mode 100644 index 0000000000..6b63ac9c87 --- /dev/null +++ b/p2p/net/swarm/clock.go @@ -0,0 +1,49 @@ +package swarm + +import "time" + +// InstantTimer is a timer that triggers at some instant rather than some duration +type InstantTimer interface { + Reset(d time.Time) bool + Stop() bool + Ch() <-chan time.Time +} + +// Clock is a clock that can create timers that trigger at some +// instant rather than some duration +type Clock interface { + Now() time.Time + Since(t time.Time) time.Duration + InstantTimer(when time.Time) InstantTimer +} + +type RealTimer struct{ t *time.Timer } + +var _ InstantTimer = (*RealTimer)(nil) + +func (t RealTimer) Ch() <-chan time.Time { + return t.t.C +} + +func (t RealTimer) Reset(d time.Time) bool { + return t.t.Reset(time.Until(d)) +} + +func (t RealTimer) Stop() bool { + return t.t.Stop() +} + +type RealClock struct{} + +var _ Clock = RealClock{} + +func (RealClock) Now() time.Time { + return time.Now() +} +func (RealClock) Since(t time.Time) time.Duration { + return time.Since(t) +} +func (RealClock) InstantTimer(when time.Time) InstantTimer { + t := time.NewTimer(time.Until(when)) + return &RealTimer{t} +} diff --git a/p2p/net/swarm/dial_ranker.go b/p2p/net/swarm/dial_ranker.go new file mode 100644 index 0000000000..d848bf8640 --- /dev/null +++ b/p2p/net/swarm/dial_ranker.go @@ -0,0 +1,219 @@ +package swarm + +import ( + "net/netip" + "sort" + "strconv" + "time" + + "github.com/libp2p/go-libp2p/core/network" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +// The 250ms value is from happy eyeballs RFC 8305. This is a rough estimate of 1 RTT +const ( + // duration by which TCP dials are delayed relative to QUIC dial + PublicTCPDelay = 250 * time.Millisecond + PrivateTCPDelay = 30 * time.Millisecond + + // duration by which QUIC dials are delayed relative to first QUIC dial + PublicQUICDelay = 250 * time.Millisecond + PrivateQUICDelay = 30 * time.Millisecond + + // RelayDelay is the duration by which relay dials are delayed relative to direct addresses + RelayDelay = 250 * time.Millisecond +) + +// noDelayRanker ranks addresses with no delay. This is useful for simultaneous connect requests. +func noDelayRanker(addrs []ma.Multiaddr) []network.AddrDelay { + return getAddrDelay(addrs, 0, 0, 0) +} + +// DefaultDialRanker is the default ranking logic. +// +// We rank private, public ip4, public ip6, relay addresses separately. +// We do not prefer IPv6 over IPv4 as recommended by Happy Eyeballs RFC 8305. Currently there is no +// mechanism to detect an IPv6 blackhole, so we dial both IPv4 and IPv6 addresses in parallel. +// If direct addresses are present we delay all relay addresses by 500 millisecond + +// In each group we apply the following logic: +// +// First we filter the addresses we don't want to dial. We are filtering these addresses because we +// have an address that we prefer more than that address and which has the same reachability +// +// If a QUIC-v1 address is present we don't dial QUIC or webtransport address on the same (ip,port) +// combination. If a QUICDraft29 or webtransport address is reachable, QUIC-v1 will definitely be +// reachable. QUICDraft29 is deprecated in favour of QUIC-v1 and QUIC-v1 is more performant than +// webtransport +// +// If a TCP address is present we don't dial ws or wss address on the same (ip, port) combination. 
+// If a ws address is reachable, TCP will definitely be reachable and it'll be more performant +// +// Then we rank the addresses: +// +// If two QUIC addresses are present, we dial the QUIC address with the lowest port first. This is more +// likely to be the listen port. After this we dial the rest of the QUIC addresses delayed by QUICDelay. +// +// If a QUIC or webtransport address is present, TCP address dials are delayed by TCPDelay relative to +// the last QUIC dial. +// +// TCPDelay for public ip4 and public ip6 is PublicTCPDelay +// TCPDelay for private addresses is PrivateTCPDelay +// QUICDelay for public addresses is PublicQUICDelay +// QUICDelay for private addresses is PrivateQUICDelay +func DefaultDialRanker(addrs []ma.Multiaddr) []network.AddrDelay { + relay, addrs := filterAddrs(addrs, isRelayAddr) + pvt, addrs := filterAddrs(addrs, manet.IsPrivateAddr) + ip4, addrs := filterAddrs(addrs, func(a ma.Multiaddr) bool { return isProtocolAddr(a, ma.P_IP4) }) + ip6, addrs := filterAddrs(addrs, func(a ma.Multiaddr) bool { return isProtocolAddr(a, ma.P_IP6) }) + + var relayOffset time.Duration = 0 + if len(ip4) > 0 || len(ip6) > 0 { + // if there is a public direct address available delay relay dials + relayOffset = RelayDelay + } + + res := make([]network.AddrDelay, 0, len(addrs)) + for i := 0; i < len(addrs); i++ { + res = append(res, network.AddrDelay{Addr: addrs[i], Delay: 0}) + } + res = append(res, getAddrDelay(pvt, PrivateTCPDelay, PrivateQUICDelay, 0)...) + res = append(res, getAddrDelay(ip4, PublicTCPDelay, PublicQUICDelay, 0)...) + res = append(res, getAddrDelay(ip6, PublicTCPDelay, PublicQUICDelay, 0)...) + res = append(res, getAddrDelay(relay, PublicTCPDelay, PublicQUICDelay, relayOffset)...) + return res +} + +// getAddrDelay ranks a group of addresses(private, ip4, ip6) according to the ranking logic +// explained in defaultDialRanker. +// offset is used to delay all addresses by a fixed duration. This is useful for delaying all relay +// addresses relative to direct addresses +func getAddrDelay(addrs []ma.Multiaddr, tcpDelay time.Duration, quicDelay time.Duration, + offset time.Duration) []network.AddrDelay { + + // First make a map of QUICV1 and TCP AddrPorts. + quicV1Addr := make(map[netip.AddrPort]struct{}) + tcpAddr := make(map[netip.AddrPort]struct{}) + for _, a := range addrs { + switch { + case isProtocolAddr(a, ma.P_WEBTRANSPORT): + case isProtocolAddr(a, ma.P_QUIC_V1): + quicV1Addr[addrPort(a, ma.P_UDP)] = struct{}{} + case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS): + case isProtocolAddr(a, ma.P_TCP): + tcpAddr[addrPort(a, ma.P_TCP)] = struct{}{} + } + } + + // Filter addresses we are sure we don't want to dial + selectedAddrs := addrs + i := 0 + for _, a := range addrs { + switch { + // If a QUICDraft29 or webtransport address is reachable, QUIC-v1 will also be reachable. So we + // drop the QUICDraft29 or webtransport address + // We prefer QUIC-v1 over the older QUIC-draft29 address. + // We prefer QUIC-v1 over webtransport as it is more performant. 
+ case isProtocolAddr(a, ma.P_WEBTRANSPORT) || isProtocolAddr(a, ma.P_QUIC): + if _, ok := quicV1Addr[addrPort(a, ma.P_UDP)]; ok { + continue + } + // If a ws address is reachable, TCP will also be reachable and it'll be more performant + case isProtocolAddr(a, ma.P_WS) || isProtocolAddr(a, ma.P_WSS): + if _, ok := tcpAddr[addrPort(a, ma.P_TCP)]; ok { + continue + } + } + selectedAddrs[i] = a + i++ + } + selectedAddrs = selectedAddrs[:i] + sort.Slice(selectedAddrs, func(i, j int) bool { return score(selectedAddrs[i]) < score(selectedAddrs[j]) }) + + res := make([]network.AddrDelay, 0, len(addrs)) + quicCount := 0 + for _, a := range selectedAddrs { + delay := offset + switch { + case isProtocolAddr(a, ma.P_QUIC) || isProtocolAddr(a, ma.P_QUIC_V1): + // For QUIC addresses we dial a single address first and then wait for QUICDelay + // After QUICDelay we dial rest of the QUIC addresses + if quicCount > 0 { + delay += quicDelay + } + quicCount++ + case isProtocolAddr(a, ma.P_TCP): + if quicCount >= 2 { + delay += 2 * quicDelay + } else if quicCount == 1 { + delay += tcpDelay + } + } + res = append(res, network.AddrDelay{Addr: a, Delay: delay}) + } + return res +} + +// score scores a multiaddress for dialing delay. lower is better +func score(a ma.Multiaddr) int { + // the lower 16 bits of the result are the relavant port + // the higher bits rank the protocol + // low ports are ranked higher because they're more likely to + // be listen addresses + if _, err := a.ValueForProtocol(ma.P_WEBTRANSPORT); err == nil { + p, _ := a.ValueForProtocol(ma.P_UDP) + pi, _ := strconv.Atoi(p) // cannot error + return pi + (1 << 18) + } + if _, err := a.ValueForProtocol(ma.P_QUIC); err == nil { + p, _ := a.ValueForProtocol(ma.P_UDP) + pi, _ := strconv.Atoi(p) // cannot error + return pi + (1 << 17) + } + if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil { + p, _ := a.ValueForProtocol(ma.P_UDP) + pi, _ := strconv.Atoi(p) // cannot error + return pi + } + + if p, err := a.ValueForProtocol(ma.P_TCP); err == nil { + pi, _ := strconv.Atoi(p) // cannot error + return pi + (1 << 19) + } + return (1 << 30) +} + +// addrPort returns the ip and port for a. p should be either ma.P_TCP or ma.P_UDP. +// a must be an (ip, TCP) or (ip, udp) address. 
+func addrPort(a ma.Multiaddr, p int) netip.AddrPort { + ip, _ := manet.ToIP(a) + port, _ := a.ValueForProtocol(p) + pi, _ := strconv.Atoi(port) + addr, _ := netip.AddrFromSlice(ip) + return netip.AddrPortFrom(addr, uint16(pi)) +} + +func isProtocolAddr(a ma.Multiaddr, p int) bool { + found := false + ma.ForEach(a, func(c ma.Component) bool { + if c.Protocol().Code == p { + found = true + return false + } + return true + }) + return found +} + +// filterAddrs filters an address slice in place +func filterAddrs(addrs []ma.Multiaddr, f func(a ma.Multiaddr) bool) (filtered, rest []ma.Multiaddr) { + j := 0 + for i := 0; i < len(addrs); i++ { + if f(addrs[i]) { + addrs[i], addrs[j] = addrs[j], addrs[i] + j++ + } + } + return addrs[:j], addrs[j:] +} diff --git a/p2p/net/swarm/dial_ranker_test.go b/p2p/net/swarm/dial_ranker_test.go new file mode 100644 index 0000000000..28cb6e8c1a --- /dev/null +++ b/p2p/net/swarm/dial_ranker_test.go @@ -0,0 +1,243 @@ +package swarm + +import ( + "fmt" + "sort" + "testing" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/test" + ma "github.com/multiformats/go-multiaddr" +) + +func sortAddrDelays(addrDelays []network.AddrDelay) { + sort.Slice(addrDelays, func(i, j int) bool { + if addrDelays[i].Delay == addrDelays[j].Delay { + return addrDelays[i].Addr.String() < addrDelays[j].Addr.String() + } + return addrDelays[i].Delay < addrDelays[j].Delay + }) +} + +func TestNoDelayRanker(t *testing.T) { + q1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + q1v1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + wt1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport/") + q2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic") + q2v1 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic-v1") + q3 := ma.StringCast("/ip4/1.2.3.4/udp/3/quic") + q3v1 := ma.StringCast("/ip4/1.2.3.4/udp/3/quic-v1") + q4 := ma.StringCast("/ip4/1.2.3.4/udp/4/quic") + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []network.AddrDelay + }{ + { + name: "quic+webtransport filtered when quicv1", + addrs: []ma.Multiaddr{q1, q2, q3, q4, q1v1, q2v1, q3v1, wt1}, + output: []network.AddrDelay{ + {Addr: q1v1, Delay: 0}, + {Addr: q2v1, Delay: 0}, + {Addr: q3v1, Delay: 0}, + {Addr: q4, Delay: 0}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := noDelayRanker(tc.addrs) + if len(res) != len(tc.output) { + log.Errorf("expected %s got %s", tc.output, res) + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sortAddrDelays(res) + sortAddrDelays(tc.output) + for i := 0; i < len(tc.output); i++ { + if !tc.output[i].Addr.Equal(res[i].Addr) || tc.output[i].Delay != res[i].Delay { + t.Fatalf("expected %+v got %+v", tc.output, res) + } + } + }) + } +} + +func TestDelayRankerQUICDelay(t *testing.T) { + q1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + q1v1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + wt1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/webtransport/") + q2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic") + q2v1 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic-v1") + q3 := ma.StringCast("/ip4/1.2.3.4/udp/3/quic") + q3v1 := ma.StringCast("/ip4/1.2.3.4/udp/3/quic-v1") + q4 := ma.StringCast("/ip4/1.2.3.4/udp/4/quic") + + q1v16 := ma.StringCast("/ip6/1::2/udp/1/quic-v1") + q2v16 := ma.StringCast("/ip6/1::2/udp/2/quic-v1") + q3v16 := ma.StringCast("/ip6/1::2/udp/3/quic-v1") + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []network.AddrDelay + }{ + { + name: "single quic dialed first", + addrs: 
[]ma.Multiaddr{q1, q2, q3, q4}, + output: []network.AddrDelay{ + {Addr: q1, Delay: 0}, + {Addr: q2, Delay: PublicQUICDelay}, + {Addr: q3, Delay: PublicQUICDelay}, + {Addr: q4, Delay: PublicQUICDelay}, + }, + }, + { + name: "quicv1 dialed before quic", + addrs: []ma.Multiaddr{q1, q2v1, q3, q4}, + output: []network.AddrDelay{ + {Addr: q2v1, Delay: 0}, + {Addr: q1, Delay: PublicQUICDelay}, + {Addr: q3, Delay: PublicQUICDelay}, + {Addr: q4, Delay: PublicQUICDelay}, + }, + }, + { + name: "quic+webtransport filtered when quicv1", + addrs: []ma.Multiaddr{q1, q2, q3, q4, q1v1, q2v1, q3v1, wt1}, + output: []network.AddrDelay{ + {Addr: q1v1, Delay: 0}, + {Addr: q2v1, Delay: PublicQUICDelay}, + {Addr: q3v1, Delay: PublicQUICDelay}, + {Addr: q4, Delay: PublicQUICDelay}, + }, + }, + { + name: "ipv6", + addrs: []ma.Multiaddr{q1v16, q2v16, q3v16, q1}, + output: []network.AddrDelay{ + {Addr: q1, Delay: 0}, + {Addr: q1v16, Delay: 0}, + {Addr: q2v16, Delay: PublicQUICDelay}, + {Addr: q3v16, Delay: PublicQUICDelay}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := DefaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + log.Errorf("expected %s got %s", tc.output, res) + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sortAddrDelays(res) + sortAddrDelays(tc.output) + for i := 0; i < len(tc.output); i++ { + if !tc.output[i].Addr.Equal(res[i].Addr) || tc.output[i].Delay != res[i].Delay { + t.Fatalf("expected %+v got %+v", tc.output, res) + } + } + }) + } +} + +func TestDelayRankerTCPDelay(t *testing.T) { + + q1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + q2v1 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic-v1") + + t1 := ma.StringCast("/ip4/1.2.3.5/tcp/1/") + t2 := ma.StringCast("/ip4/1.2.3.4/tcp/2") + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []network.AddrDelay + }{ + { + name: "2 quic with tcp", + addrs: []ma.Multiaddr{q1, q2v1, t1, t2}, + output: []network.AddrDelay{ + {Addr: q2v1, Delay: 0}, + {Addr: q1, Delay: PublicQUICDelay}, + {Addr: t1, Delay: PublicQUICDelay + PublicTCPDelay}, + {Addr: t2, Delay: PublicQUICDelay + PublicTCPDelay}, + }, + }, + { + name: "1 quic with tcp", + addrs: []ma.Multiaddr{q1, t1, t2}, + output: []network.AddrDelay{ + {Addr: q1, Delay: 0}, + {Addr: t1, Delay: PublicTCPDelay}, + {Addr: t2, Delay: PublicTCPDelay}, + }, + }, + { + name: "no quic", + addrs: []ma.Multiaddr{t1, t2}, + output: []network.AddrDelay{ + {Addr: t1, Delay: 0}, + {Addr: t2, Delay: 0}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := DefaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + log.Errorf("expected %s got %s", tc.output, res) + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sortAddrDelays(res) + sortAddrDelays(tc.output) + for i := 0; i < len(tc.output); i++ { + if !tc.output[i].Addr.Equal(res[i].Addr) || tc.output[i].Delay != res[i].Delay { + t.Fatalf("expected %+v got %+v", tc.output, res) + } + } + }) + } +} + +func TestDelayRankerRelay(t *testing.T) { + q1 := ma.StringCast("/ip4/1.2.3.4/udp/1/quic") + q2 := ma.StringCast("/ip4/1.2.3.4/udp/2/quic") + + pid := test.RandPeerIDFatal(t) + r1 := ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/1/p2p-circuit/p2p/%s", pid)) + r2 := ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/udp/1/quic/p2p-circuit/p2p/%s", pid)) + + testCase := []struct { + name string + addrs []ma.Multiaddr + output []network.AddrDelay + }{ + { + name: "relay address delayed", + addrs: []ma.Multiaddr{q1, q2, r1, r2}, + 
output: []network.AddrDelay{ + {Addr: q1, Delay: 0}, + {Addr: q2, Delay: PublicQUICDelay}, + {Addr: r2, Delay: RelayDelay}, + {Addr: r1, Delay: PublicTCPDelay + RelayDelay}, + }, + }, + } + for _, tc := range testCase { + t.Run(tc.name, func(t *testing.T) { + res := DefaultDialRanker(tc.addrs) + if len(res) != len(tc.output) { + log.Errorf("expected %s got %s", tc.output, res) + t.Errorf("expected elems: %d got: %d", len(tc.output), len(res)) + } + sortAddrDelays(res) + sortAddrDelays(tc.output) + for i := 0; i < len(tc.output); i++ { + if !tc.output[i].Addr.Equal(res[i].Addr) || tc.output[i].Delay != res[i].Delay { + t.Fatalf("expected %+v got %+v", tc.output, res) + } + } + }) + } +} diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index a319d00e5c..8d574eba7f 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -2,13 +2,14 @@ package swarm import ( "context" + "math" "sync" + "time" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" ma "github.com/multiformats/go-multiaddr" - manet "github.com/multiformats/go-multiaddr/net" ) // ///////////////////////////////////////////////////////////////////////////////// @@ -16,77 +17,159 @@ import ( // TODO explain how all this works // //////////////////////////////////////////////////////////////////////////////// +// dialRequest is structure used to request dials to the peer associated with a +// worker loop type dialRequest struct { - ctx context.Context + // ctx is the context that may be used for the request + // if another concurrent request is made, any of the concurrent request's ctx may be used for + // dials to the peer's addresses + // ctx for simultaneous connect requests have higher priority than normal requests + ctx context.Context + // resch is the channel used to send the response for this query resch chan dialResponse } +// dialResponse is the response sent to dialRequests on the request's resch channel type dialResponse struct { + // conn is the connection to the peer on success conn *Conn - err error + // err is the error in dialing the peer + // nil on connection success + err error } +// pendRequest is used to track progress on a dialRequest. type pendRequest struct { - req dialRequest // the original request - err *DialError // dial error accumulator - addrs map[string]struct{} // pending address to dial. The key is a multiaddr + // req is the original dialRequest + req dialRequest + // err comprises errors of all failed dials + err *DialError + // addrs are the addresses on which we are waiting for pending dials + // At the time of creation addrs is initialised to all the addresses of the peer. On a failed dial, + // the addr is removed from the map and err is updated. On a successful dial, the dialRequest is + // completed and response is sent with the connection + addrs map[string]struct{} } +// addrDial tracks dials to a particular multiaddress. 
type addrDial struct { - addr ma.Multiaddr - ctx context.Context - conn *Conn - err error + // addr is the address dialed + addr ma.Multiaddr + // ctx is the context used for dialing the address + ctx context.Context + // conn is the established connection on success + conn *Conn + // err is the err on dialing the address + err error + // requests is the list of pendRequests interested in this dial + // the value in the slice is the request number assigned to this request by the dialWorker requests []int + // dialed indicates whether we have triggered the dial to the address + dialed bool + // createdAt is the time this struct was created + createdAt time.Time + // dialRankingDelay is the delay in dialing this address introduced by the ranking logic + dialRankingDelay time.Duration } +// dialWorker synchronises concurrent dials to a peer. It ensures that we make at most one dial to a +// peer's address type dialWorker struct { - s *Swarm - peer peer.ID - reqch <-chan dialRequest - reqno int - requests map[int]*pendRequest - pending map[string]*addrDial // pending addresses to dial. The key is a multiaddr - resch chan dialResult + s *Swarm + peer peer.ID + // reqch is used to send dial requests to the worker. close reqch to end the worker loop + reqch <-chan dialRequest + // reqno is the request number used to track different dialRequests for a peer. + // Each incoming request is assigned a reqno. This reqno is used in pendingRequests and in + // addrDial objects in trackedDials to track this request + reqno int + // pendingRequests maps reqno to the pendRequest object for a dialRequest + pendingRequests map[int]*pendRequest + // trackedDials tracks dials to the peers addresses. An entry here is used to ensure that + // we dial an address at most once + trackedDials map[string]*addrDial + // resch is used to receive response for dials to the peers addresses. + resch chan dialResult connected bool // true when a connection has been successfully established - nextDial []ma.Multiaddr - - // ready when we have more addresses to dial (nextDial is not empty) - triggerDial <-chan struct{} - // for testing wg sync.WaitGroup + cl Clock } -func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest) *dialWorker { +func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest, cl Clock) *dialWorker { + if cl == nil { + cl = RealClock{} + } return &dialWorker{ - s: s, - peer: p, - reqch: reqch, - requests: make(map[int]*pendRequest), - pending: make(map[string]*addrDial), - resch: make(chan dialResult), + s: s, + peer: p, + reqch: reqch, + pendingRequests: make(map[int]*pendRequest), + trackedDials: make(map[string]*addrDial), + resch: make(chan dialResult), + cl: cl, } } +// loop implements the core dial worker loop. Requests are received on w.reqch. +// The loop exits when w.reqch is closed. func (w *dialWorker) loop() { w.wg.Add(1) defer w.wg.Done() defer w.s.limiter.clearAllPeerDials(w.peer) - // used to signal readiness to dial and completion of the dial - ready := make(chan struct{}) - close(ready) + // dq is used to pace dials to different addresses of the peer + dq := newDialQueue() + // dialsInFlight is the number of dials in flight. 
+ dialsInFlight := 0 + + startTime := w.cl.Now() + // dialTimer is the dialTimer used to trigger dials + dialTimer := w.cl.InstantTimer(startTime.Add(math.MaxInt64)) + timerRunning := true + // scheduleNextDial updates timer for triggering the next dial + scheduleNextDial := func() { + if timerRunning && !dialTimer.Stop() { + <-dialTimer.Ch() + } + timerRunning = false + if dq.len() > 0 { + if dialsInFlight == 0 && !w.connected { + // if there are no dials in flight, trigger the next dials immediately + dialTimer.Reset(startTime) + } else { + dialTimer.Reset(startTime.Add(dq.top().Delay)) + } + timerRunning = true + } + } + // totalDials is used to track number of dials made by this worker for metrics + totalDials := 0 loop: for { + // The loop has three parts + // 1. Input requests are received on w.reqch. If a suitable connection is not available we create + // a pendRequest object to track the dialRequest and add the addresses to dq. + // 2. Addresses from the dialQueue are dialed at appropriate time intervals depending on delay logic. + // We are notified of the completion of these dials on w.resch. + // 3. Responses for dials are received on w.resch. On receiving a response, we updated the pendRequests + // interested in dials on this address. + select { case req, ok := <-w.reqch: if !ok { + if w.s.metricsTracer != nil { + w.s.metricsTracer.DialCompleted(w.connected, totalDials) + } return } + // We have received a new request. If we do not have a suitable connection, + // track this dialRequest with a pendRequest. + // Enqueue the peer's addresses relevant to this request in dq and + // track dials to the addresses relevant to this request. c, err := w.s.bestAcceptableConnToPeer(req.ctx, w.peer) if c != nil || err != nil { @@ -100,29 +183,34 @@ loop: continue loop } - // at this point, len(addrs) > 0 or else it would be error from addrsForDial - // ranke them to process in order - addrs = w.rankAddrs(addrs) + // get the delays to dial these addrs from the swarms dialRanker + simConnect, _, _ := network.GetSimultaneousConnect(req.ctx) + addrRanking := w.rankAddrs(addrs, simConnect) + addrDelay := make(map[string]time.Duration, len(addrRanking)) // create the pending request object pr := &pendRequest{ req: req, err: &DialError{Peer: w.peer}, - addrs: make(map[string]struct{}), + addrs: make(map[string]struct{}, len(addrRanking)), } - for _, a := range addrs { - pr.addrs[string(a.Bytes())] = struct{}{} + for _, adelay := range addrRanking { + pr.addrs[string(adelay.Addr.Bytes())] = struct{}{} + addrDelay[string(adelay.Addr.Bytes())] = adelay.Delay } - // check if any of the addrs has been successfully dialed and accumulate - // errors from complete dials while collecting new addrs to dial/join + // Check if dials to any of the addrs have completed already + // If they have errored, record the error in pr. If they have succeeded, + // respond with the connection. + // If they are pending, add them to tojoin. + // If we haven't seen any of the addresses before, add them to todial. 
var todial []ma.Multiaddr var tojoin []*addrDial - for _, a := range addrs { - ad, ok := w.pending[string(a.Bytes())] + for _, adelay := range addrRanking { + ad, ok := w.trackedDials[string(adelay.Addr.Bytes())] if !ok { - todial = append(todial, a) + todial = append(todial, adelay.Addr) continue } @@ -134,8 +222,8 @@ loop: if ad.err != nil { // dial to this addr errored, accumulate the error - pr.err.recordErr(a, ad.err) - delete(pr.addrs, string(a.Bytes())) + pr.err.recordErr(ad.addr, ad.err) + delete(pr.addrs, string(ad.addr.Bytes())) continue } @@ -149,58 +237,89 @@ loop: continue loop } - // the request has some pending or new dials, track it and schedule new dials + // The request has some pending or new dials. We assign this request a request number. + // This value of w.reqno is used to track this request in all the structures w.reqno++ - w.requests[w.reqno] = pr + w.pendingRequests[w.reqno] = pr for _, ad := range tojoin { - if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { - if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { - ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) + if !ad.dialed { + // we haven't dialed this address. update the ad.ctx to have simultaneous connect values + // set correctly + if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { + if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { + ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) + // update the element in dq to use the simultaneous connect delay. + dq.Add(network.AddrDelay{ + Addr: ad.addr, + Delay: addrDelay[string(ad.addr.Bytes())], + }) + } } } + // add the request to the addrDial ad.requests = append(ad.requests, w.reqno) } if len(todial) > 0 { + now := time.Now() + // these are new addresses, track them and add them to dq for _, a := range todial { - w.pending[string(a.Bytes())] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}} + w.trackedDials[string(a.Bytes())] = &addrDial{ + addr: a, + ctx: req.ctx, + requests: []int{w.reqno}, + createdAt: now, + } + dq.Add(network.AddrDelay{Addr: a, Delay: addrDelay[string(a.Bytes())]}) } - - w.nextDial = append(w.nextDial, todial...) - w.nextDial = w.rankAddrs(w.nextDial) - - // trigger a new dial now to account for the new addrs we added - w.triggerDial = ready } - - case <-w.triggerDial: - for _, addr := range w.nextDial { + // setup dialTimer for updates to dq + scheduleNextDial() + + case <-dialTimer.Ch(): + // It's time to dial the next batch of addresses. + // We don't check the delay of the addresses received from the queue here + // because if the timer triggered before the delay, it means that all + // the inflight dials have errored and we should dial the next batch of + // addresses + now := time.Now() + for _, adelay := range dq.NextBatch() { // spawn the dial - ad, ok := w.pending[string(addr.Bytes())] + ad, ok := w.trackedDials[string(adelay.Addr.Bytes())] if !ok { - log.Warn("unexpectedly missing pending addrDial for addr") - // Assume nothing to dial here + log.Errorf("SWARM BUG: no entry for address %s in trackedDials", adelay.Addr) continue } - err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch) + ad.dialed = true + ad.dialRankingDelay = now.Sub(ad.createdAt) + err := w.s.dialNextAddr(ad.ctx, w.peer, ad.addr, w.resch) if err != nil { + // the actual dial happens in a different go routine. An err here + // only happens in case of backoff. handle that. 
w.dispatchError(ad, err) + } else { + // the dial was successful. update inflight dials + dialsInFlight++ + totalDials++ } } - - w.nextDial = nil - w.triggerDial = nil + timerRunning = false + // schedule more dials + scheduleNextDial() case res := <-w.resch: - if res.Conn != nil { - w.connected = true - } + // A dial to an address has completed. + // Update all requests waiting on this address. On success, complete the request. + // On error, record the error - ad, ok := w.pending[string(res.Addr.Bytes())] + dialsInFlight-- + ad, ok := w.trackedDials[string(res.Addr.Bytes())] if !ok { - log.Warn("unexpectedly missing pending addrDial res") - // Assume nothing to do here + log.Errorf("SWARM BUG: no entry for address %s in trackedDials", res.Addr) + if res.Conn != nil { + res.Conn.Close() + } continue } @@ -214,21 +333,27 @@ loop: continue loop } - // dispatch to still pending requests + // request succeeded, respond to all pending requests for _, reqno := range ad.requests { - pr, ok := w.requests[reqno] + pr, ok := w.pendingRequests[reqno] if !ok { - // it has already dispatched a connection + // some other dial for this request succeeded before this one continue } - pr.req.resch <- dialResponse{conn: conn} - delete(w.requests, reqno) + delete(w.pendingRequests, reqno) } ad.conn = conn ad.requests = nil + if !w.connected { + w.connected = true + if w.s.metricsTracer != nil { + w.s.metricsTracer.DialRankingDelay(ad.dialRankingDelay) + } + } + continue loop } @@ -238,8 +363,11 @@ loop: // for consistency with the old dialer behavior. w.s.backf.AddBackoff(w.peer, res.Addr) } - w.dispatchError(ad, res.Err) + // Only schedule next dial on error. + // If we scheduleNextDial on success, we will end up making one dial more than + // required because the final successful dial will spawn one more dial + scheduleNextDial() } } } @@ -248,9 +376,9 @@ loop: func (w *dialWorker) dispatchError(ad *addrDial, err error) { ad.err = err for _, reqno := range ad.requests { - pr, ok := w.requests[reqno] + pr, ok := w.pendingRequests[reqno] if !ok { - // has already been dispatched + // some other dial for this request succeeded before this one continue } @@ -268,7 +396,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) { } else { pr.req.resch <- dialResponse{err: pr.err} } - delete(w.requests, reqno) + delete(w.pendingRequests, reqno) } } @@ -278,46 +406,82 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) { // this is necessary to support active listen scenarios, where a new dial comes in while // another dial is in progress, and needs to do a direct connection without inhibitions from // dial backoff. - // it is also necessary to preserve consisent behaviour with the old dialer -- TestDialBackoff - // regresses without this. if err == ErrDialBackoff { - delete(w.pending, string(ad.addr.Bytes())) + delete(w.trackedDials, string(ad.addr.Bytes())) } } -// ranks addresses in descending order of preference for dialing, with the following rules: -// NonRelay > Relay -// NonWS > WS -// Private > Public -// UDP > TCP -func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - addrTier := func(a ma.Multiaddr) (tier int) { - if isRelayAddr(a) { - tier |= 0b1000 - } - if isExpensiveAddr(a) { - tier |= 0b0100 - } - if !manet.IsPrivateAddr(a) { - tier |= 0b0010 - } - if isFdConsumingAddr(a) { - tier |= 0b0001 +// rankAddrs ranks addresses for dialing. 
if it's a simConnect request we +// dial all addresses immediately without any delay +func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr, isSimConnect bool) []network.AddrDelay { + if isSimConnect { + return noDelayRanker(addrs) + } + return w.s.dialRanker(addrs) +} + +// dialQueue is a priority queue used to schedule dials +type dialQueue struct { + // q contains dials ordered by delay + q []network.AddrDelay +} + +// newDialQueue returns a new dialQueue +func newDialQueue() *dialQueue { + return &dialQueue{q: make([]network.AddrDelay, 0, 16)} +} + +// Add adds adelay to the queue. If another element exists in the queue with +// the same address, it replaces that element. +func (dq *dialQueue) Add(adelay network.AddrDelay) { + for i := 0; i < dq.len(); i++ { + if dq.q[i].Addr.Equal(adelay.Addr) { + if dq.q[i].Delay == adelay.Delay { + // existing element is the same. nothing to do + return + } + // remove the element + copy(dq.q[i:], dq.q[i+1:]) + dq.q = dq.q[:len(dq.q)-1] + break } + } - return tier + for i := 0; i < dq.len(); i++ { + if dq.q[i].Delay > adelay.Delay { + dq.q = append(dq.q, network.AddrDelay{}) // extend the slice + copy(dq.q[i+1:], dq.q[i:]) + dq.q[i] = adelay + return + } } + dq.q = append(dq.q, adelay) +} - tiers := make([][]ma.Multiaddr, 16) - for _, a := range addrs { - tier := addrTier(a) - tiers[tier] = append(tiers[tier], a) +// NextBatch returns all the elements in the queue with the highest priority +func (dq *dialQueue) NextBatch() []network.AddrDelay { + if dq.len() == 0 { + return nil } - result := make([]ma.Multiaddr, 0, len(addrs)) - for _, tier := range tiers { - result = append(result, tier...) + // i is the index of the second highest priority element + var i int + for i = 0; i < dq.len(); i++ { + if dq.q[i].Delay != dq.q[0].Delay { + break + } } + res := dq.q[:i] + dq.q = dq.q[i:] + return res +} + +// top returns the top element of the queue +func (dq *dialQueue) top() network.AddrDelay { + return dq.q[0] +} - return result +// len returns the number of elements in the queue +func (dq *dialQueue) len() int { + return len(dq.q) } diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index ebdaedb245..903a9e500c 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -5,15 +5,22 @@ import ( "crypto/rand" "errors" "fmt" + "math" + mrand "math/rand" + "reflect" + "sort" "sync" "testing" + "testing/quick" "time" "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/sec/insecure" + "github.com/libp2p/go-libp2p/core/test" "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" @@ -28,6 +35,18 @@ import ( "github.com/stretchr/testify/require" ) +type mockClock struct { + *test.MockClock +} + +func (m *mockClock) InstantTimer(when time.Time) InstantTimer { + return m.MockClock.InstantTimer(when) +} + +func newMockClock() *mockClock { + return &mockClock{test.NewMockClock()} +} + func newPeer(t *testing.T) (crypto.PrivKey, peer.ID) { priv, _, err := crypto.GenerateEd25519Key(rand.Reader) require.NoError(t, err) @@ -37,6 +56,19 @@ func newPeer(t *testing.T) (crypto.PrivKey, peer.ID) { } func makeSwarm(t *testing.T) *Swarm { + s := makeSwarmWithNoListenAddrs(t, WithDialTimeout(1*time.Second)) + if err := 
s.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")); err != nil { + t.Fatal(err) + } + + if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic")); err != nil { + t.Fatal(err) + } + + return s +} + +func makeSwarmWithNoListenAddrs(t *testing.T, opts ...Option) *Swarm { priv, id := newPeer(t) ps, err := pstoremem.NewPeerstore() @@ -45,11 +77,10 @@ func makeSwarm(t *testing.T) *Swarm { ps.AddPrivKey(id, priv) t.Cleanup(func() { ps.Close() }) - s, err := NewSwarm(id, ps, eventbus.NewBus(), WithDialTimeout(time.Second)) + s, err := NewSwarm(id, ps, eventbus.NewBus(), opts...) require.NoError(t, err) upgrader := makeUpgrader(t, s) - var tcpOpts []tcp.Option tcpOpts = append(tcpOpts, tcp.DisableReuseport()) tcpTransport, err := tcp.NewTCPTransport(upgrader, nil, tcpOpts...) @@ -57,10 +88,6 @@ func makeSwarm(t *testing.T) *Swarm { if err := s.AddTransport(tcpTransport); err != nil { t.Fatal(err) } - if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0")); err != nil { - t.Fatal(err) - } - reuse, err := quicreuse.NewConnManager([32]byte{}) if err != nil { t.Fatal(err) @@ -72,10 +99,6 @@ func makeSwarm(t *testing.T) *Swarm { if err := s.AddTransport(quicTransport); err != nil { t.Fatal(err) } - if err := s.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic")); err != nil { - t.Fatal(err) - } - return s } @@ -89,6 +112,33 @@ func makeUpgrader(t *testing.T, n *Swarm) transport.Upgrader { return u } +// makeTCPListener listens on tcp address a. On accepting a connection it notifies recvCh. Sending a message to +// channel ch will close an accepted connection +func makeTCPListener(t *testing.T, a ma.Multiaddr, recvCh chan struct{}) (list manet.Listener, ch chan struct{}) { + t.Helper() + list, err := manet.Listen(a) + if err != nil { + t.Fatal(err) + } + ch = make(chan struct{}) + go func() { + for { + c, err := list.Accept() + if err != nil { + break + } + recvCh <- struct{}{} + <-ch + err = c.Close() + if err != nil { + t.Error(err) + } + + } + }() + return list, ch +} + func TestDialWorkerLoopBasic(t *testing.T) { s1 := makeSwarm(t) s2 := makeSwarm(t) @@ -100,7 +150,7 @@ func TestDialWorkerLoopBasic(t *testing.T) { reqch := make(chan dialRequest) resch := make(chan dialResponse) - worker := newDialWorker(s1, s2.LocalPeer(), reqch) + worker := newDialWorker(s1, s2.LocalPeer(), reqch, nil) go worker.loop() var conn *Conn @@ -145,7 +195,7 @@ func TestDialWorkerLoopConcurrent(t *testing.T) { s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.PermanentAddrTTL) reqch := make(chan dialRequest) - worker := newDialWorker(s1, s2.LocalPeer(), reqch) + worker := newDialWorker(s1, s2.LocalPeer(), reqch, nil) go worker.loop() const dials = 100 @@ -188,7 +238,7 @@ func TestDialWorkerLoopFailure(t *testing.T) { reqch := make(chan dialRequest) resch := make(chan dialResponse) - worker := newDialWorker(s1, p2, reqch) + worker := newDialWorker(s1, p2, reqch, nil) go worker.loop() reqch <- dialRequest{ctx: context.Background(), resch: resch} @@ -212,7 +262,7 @@ func TestDialWorkerLoopConcurrentFailure(t *testing.T) { s1.Peerstore().AddAddrs(p2, []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL) reqch := make(chan dialRequest) - worker := newDialWorker(s1, p2, reqch) + worker := newDialWorker(s1, p2, reqch, nil) go worker.loop() const dials = 100 @@ -260,7 +310,7 @@ func TestDialWorkerLoopConcurrentMix(t *testing.T) { s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{ma.StringCast("/ip4/11.0.0.1/tcp/1234"), 
ma.StringCast("/ip4/11.0.0.1/udp/1234/quic")}, peerstore.PermanentAddrTTL) reqch := make(chan dialRequest) - worker := newDialWorker(s1, s2.LocalPeer(), reqch) + worker := newDialWorker(s1, s2.LocalPeer(), reqch, nil) go worker.loop() const dials = 100 @@ -306,7 +356,7 @@ func TestDialWorkerLoopConcurrentFailureStress(t *testing.T) { s1.Peerstore().AddAddrs(p2, addrs, peerstore.PermanentAddrTTL) reqch := make(chan dialRequest) - worker := newDialWorker(s1, p2, reqch) + worker := newDialWorker(s1, p2, reqch, nil) go worker.loop() const dials = 100 @@ -344,6 +394,629 @@ func TestDialWorkerLoopConcurrentFailureStress(t *testing.T) { worker.wg.Wait() } +func TestDialQueueNextBatch(t *testing.T) { + addrs := make([]ma.Multiaddr, 0) + for i := 0; i < 10; i++ { + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/1.2.3.4/tcp/%d", i))) + } + testcase := []struct { + name string + input []network.AddrDelay + output [][]ma.Multiaddr + }{ + { + name: "next batch", + input: []network.AddrDelay{ + {Addr: addrs[0], Delay: 3}, + {Addr: addrs[1], Delay: 2}, + {Addr: addrs[2], Delay: 1}, + {Addr: addrs[3], Delay: 1}, + }, + output: [][]ma.Multiaddr{ + {addrs[2], addrs[3]}, + {addrs[1]}, + {addrs[0]}, + }, + }, + { + name: "priority queue property 2", + input: []network.AddrDelay{ + {Addr: addrs[0], Delay: 5}, + {Addr: addrs[1], Delay: 3}, + {Addr: addrs[2], Delay: 2}, + {Addr: addrs[3], Delay: 1}, + {Addr: addrs[4], Delay: 1}, + }, + + output: [][]ma.Multiaddr{ + {addrs[3], addrs[4]}, + {addrs[2]}, + {addrs[1]}, + {addrs[0]}, + }, + }, + { + name: "updates", + input: []network.AddrDelay{ + {Addr: addrs[0], Delay: 3}, // decreasing order + {Addr: addrs[1], Delay: 3}, + {Addr: addrs[2], Delay: 2}, + {Addr: addrs[3], Delay: 2}, + {Addr: addrs[4], Delay: 1}, + {Addr: addrs[0], Delay: 1}, // increasing order + {Addr: addrs[1], Delay: 1}, + {Addr: addrs[2], Delay: 2}, + {Addr: addrs[3], Delay: 2}, + {Addr: addrs[4], Delay: 3}, + }, + output: [][]ma.Multiaddr{ + {addrs[0], addrs[1]}, + {addrs[2], addrs[3]}, + {addrs[4]}, + {}, + }, + }, + { + name: "null input", + input: []network.AddrDelay{}, + output: [][]ma.Multiaddr{ + {}, + {}, + }, + }, + } + for _, tc := range testcase { + t.Run(tc.name, func(t *testing.T) { + q := newDialQueue() + for i := 0; i < len(tc.input); i++ { + q.Add(tc.input[i]) + } + for _, batch := range tc.output { + b := q.NextBatch() + if len(batch) != len(b) { + t.Errorf("expected %d elements got %d", len(batch), len(b)) + } + sort.Slice(b, func(i, j int) bool { return b[i].Addr.String() < b[j].Addr.String() }) + sort.Slice(batch, func(i, j int) bool { return batch[i].String() < batch[j].String() }) + for i := 0; i < len(b); i++ { + if !b[i].Addr.Equal(batch[i]) { + log.Errorf("expected %s got %s", batch[i], b[i].Addr) + } + } + } + if q.len() != 0 { + t.Errorf("expected queue to be empty at end. 
got: %d", q.len()) + } + }) + } +} + +// timedDial is a dial to a single address of the peer +type timedDial struct { + // addr is the address to dial + addr ma.Multiaddr + // delay is the delay after which this address should be dialed + delay time.Duration + // success indicates whether the dial should succeed + success bool + // failAfter is how long this dial should take to fail after it is dialed + failAfter time.Duration +} + +// schedulingTestCase is used to test dialWorker loop scheduler logic +// a ranker is made according to `input` which provides the addresses to +// dial worker loop with the specified delays +// checkDialWorkerLoopScheduling then verifies that the different dial calls are +// made at the right moments +type schedulingTestCase struct { + name string + input []timedDial + maxDuration time.Duration +} + +// schedulingTestCase generates a random test case +func (s schedulingTestCase) Generate(rand *mrand.Rand, size int) reflect.Value { + if size > 20 { + size = 20 + } + input := make([]timedDial, size) + delays := make(map[time.Duration]struct{}) + for i := 0; i < size; i++ { + input[i] = timedDial{ + addr: ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", i+10550)), + delay: time.Duration(mrand.Intn(100)) * 10 * time.Millisecond, // max 1 second + success: false, + failAfter: time.Duration(mrand.Intn(100)) * 10 * time.Millisecond, // max 1 second + } + delays[input[i].delay] = struct{}{} + } + successIdx := rand.Intn(size) + for { + // set a unique delay for success. This is required to test the property that + // no extra dials are made after success + d := time.Duration(rand.Intn(100)) * 10 * time.Millisecond + if _, ok := delays[d]; !ok { + input[successIdx].delay = d + input[successIdx].success = true + break + } + } + return reflect.ValueOf(schedulingTestCase{ + name: "", + input: input, + maxDuration: 10 * time.Second, // not tested here + }) +} + +// dialState is used to track the dials for testing dialWorker ranking logic +type dialState struct { + // ch is the chan used to trigger dial failure. 
+ ch chan struct{} + // addr is the address of the dial + addr ma.Multiaddr + // delay is the delay after which this address should be dialed + delay time.Duration + // success indicates whether the dial should succeed + success bool + // failAfter is how long this dial should take to fail after it is dialed + failAfter time.Duration + // failAt is the instant at which this dial should fail if success is false + failAt time.Time +} + +// checkDialWorkerLoopScheduling verifies whether s1 dials s2 according to the +// schedule specified by the test case tc +func checkDialWorkerLoopScheduling(t *testing.T, s1, s2 *Swarm, tc schedulingTestCase) error { + t.Helper() + // failDials is used to track dials which should fail in the future + // at appropriate moment a message is sent to dialState.ch to trigger + // failure + failDials := make(map[ma.Multiaddr]dialState) + // recvCh is used to receive dial notifications for dials that will fail + recvCh := make(chan struct{}, 100) + // allDials tracks all pending dials + allDials := make(map[ma.Multiaddr]dialState) + // addrs are the peer addresses the swarm will use for dialing + addrs := make([]ma.Multiaddr, 0) + // create pending dials + // we add success cases as a listen address on swarm + // failed cases are created using makeTCPListener + for _, inp := range tc.input { + var failCh chan struct{} + if inp.success { + // add the address as a listen address if this dial should succeed + err := s2.AddListenAddr(inp.addr) + if err != nil { + return fmt.Errorf("failed to listen on addr: %s: err: %w", inp.addr, err) + } + } else { + // make a listener which will fail on sending a message to ch + l, ch := makeTCPListener(t, inp.addr, recvCh) + failCh = ch + f := func() { + err := l.Close() + if err != nil { + t.Error(err) + } + } + defer f() + } + addrs = append(addrs, inp.addr) + // add to pending dials + allDials[inp.addr] = dialState{ + ch: failCh, + addr: inp.addr, + delay: inp.delay, + success: inp.success, + failAfter: inp.failAfter, + } + } + // setup the peers addresses + s1.Peerstore().AddAddrs(s2.LocalPeer(), addrs, peerstore.PermanentAddrTTL) + + // create worker + reqch := make(chan dialRequest) + resch := make(chan dialResponse) + cl := newMockClock() + st := cl.Now() + worker1 := newDialWorker(s1, s2.LocalPeer(), reqch, cl) + go worker1.loop() + defer worker1.wg.Wait() + defer close(reqch) + + // trigger the request + reqch <- dialRequest{ctx: context.Background(), resch: resch} + + connected := false + + // Advance the clock by 10 ms every iteration + // At every iteration: + // Check if any dial should fail. 
if it should, trigger the failure by sending a message on the + // listener failCh + // If there are no dials in flight check the most urgent dials have been triggered + // If there are dials in flight check that the relevant dials have been triggered + // Before next iteration ensure that no unexpected dials are received +loop: + for { + // fail any dials that should fail at this instant + for a, p := range failDials { + if p.failAt.Before(cl.Now()) || p.failAt == cl.Now() { + p.ch <- struct{}{} + delete(failDials, a) + } + } + // if there are no pending dials, next dial should have been triggered + trigger := len(failDials) == 0 + + // mi is the minimum delay of pending dials + // if trigger is true, all dials with miDelay should have been triggered + mi := time.Duration(math.MaxInt64) + for _, ds := range allDials { + if ds.delay < mi { + mi = ds.delay + } + } + for a, ds := range allDials { + if (trigger && mi == ds.delay) || + cl.Now().After(st.Add(ds.delay)) || + cl.Now() == st.Add(ds.delay) { + if ds.success { + // check for success and exit + select { + case r := <-resch: + if r.conn == nil { + return errors.New("expected connection to succeed") + } + // High timeout here is okay. We will exit whenever the other branch + // is triggered + case <-time.After(10 * time.Second): + return errors.New("expected to receive a response") + } + connected = true + break loop + } else { + // ensure that a failing dial attempt happened but didn't succeed + select { + case <-recvCh: + case <-resch: + return errors.New("didn't expect a response") + // High timeout here is okay. We will exit whenever the other branch + // is triggered + case <-time.After(10 * time.Second): + return errors.New("didn't receive a dial attempt notification") + } + failDials[a] = dialState{ + ch: ds.ch, + failAt: cl.Now().Add(ds.failAfter), + addr: a, + delay: ds.delay, + } + } + delete(allDials, a) + } + } + // check for unexpected dials + select { + case <-recvCh: + return errors.New("no dial should have succeeded at this instant") + default: + } + + // advance the clock + cl.AdvanceBy(10 * time.Millisecond) + // nothing more to do. 
exit + if len(failDials) == 0 && len(allDials) == 0 { + break + } + } + + if connected { + // ensure we don't receive any extra connections + select { + case <-recvCh: + return errors.New("didn't expect a dial attempt") + case <-time.After(100 * time.Millisecond): + } + } else { + // ensure that we do receive the final error response + select { + case r := <-resch: + require.Error(t, r.err) + case <-time.After(100 * time.Millisecond): + return errors.New("expected to receive response") + } + } + // check if this test didn't take too much time + if cl.Now().Sub(st) > tc.maxDuration { + return fmt.Errorf("expected test to finish early: expected %d, took: %d", tc.maxDuration, cl.Now().Sub(st)) + } + return nil +} + +// makeRanker takes a slice of timedDial objects and returns a DialRanker +// which will trigger dials to addresses at the specified delays in the timedDials +func makeRanker(tc []timedDial) network.DialRanker { + return func(addrs []ma.Multiaddr) []network.AddrDelay { + res := make([]network.AddrDelay, len(tc)) + for i := 0; i < len(tc); i++ { + res[i] = network.AddrDelay{Addr: tc[i].addr, Delay: tc[i].delay} + } + return res + } +} + +// TestCheckDialWorkerLoopScheduling will check the checker +func TestCheckDialWorkerLoopScheduling(t *testing.T) { + addrs := make([]ma.Multiaddr, 0) + for i := 0; i < 10; i++ { + for { + p := 20000 + i + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", p))) + break + } + } + + tc := schedulingTestCase{ + input: []timedDial{ + { + addr: addrs[1], + delay: 0, + success: true, + }, + { + addr: addrs[0], + delay: 100 * time.Millisecond, + success: false, + failAfter: 50 * time.Millisecond, + }, + }, + maxDuration: 20 * time.Millisecond, + } + s1 := makeSwarmWithNoListenAddrs(t) + s2 := makeSwarmWithNoListenAddrs(t) + // valid ranking logic, so it shouldn't error + s1.dialRanker = makeRanker(tc.input) + err := checkDialWorkerLoopScheduling(t, s1, s2, tc) + require.NoError(t, err) + // close swarms to remove address binding + s1.Close() + s2.Close() + + s3 := makeSwarmWithNoListenAddrs(t) + defer s3.Close() + s4 := makeSwarmWithNoListenAddrs(t) + defer s4.Close() + // invalid ranking logic to trigger an error + s3.dialRanker = noDelayRanker + err = checkDialWorkerLoopScheduling(t, s3, s4, tc) + require.Error(t, err) +} + +func TestDialWorkerLoopRanking(t *testing.T) { + addrs := make([]ma.Multiaddr, 0) + for i := 0; i < 10; i++ { + for { + p := 20000 + i + addrs = append(addrs, ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", p))) + break + } + } + + testcases := []schedulingTestCase{ + { + name: "first success", + input: []timedDial{ + { + addr: addrs[1], + delay: 0, + success: true, + }, + { + addr: addrs[0], + delay: 100 * time.Millisecond, + success: false, + failAfter: 50 * time.Millisecond, + }, + }, + maxDuration: 20 * time.Millisecond, + }, + { + name: "delayed dials", + input: []timedDial{ + { + addr: addrs[0], + delay: 0, + success: false, + failAfter: 200 * time.Millisecond, + }, + { + addr: addrs[1], + delay: 100 * time.Millisecond, + success: false, + failAfter: 100 * time.Millisecond, + }, + { + addr: addrs[2], + delay: 300 * time.Millisecond, + success: false, + failAfter: 100 * time.Millisecond, + }, + { + addr: addrs[3], + delay: 2 * time.Second, + success: true, + }, + { + addr: addrs[4], + delay: 2*time.Second + 1*time.Millisecond, + success: false, // this call will never happened + failAfter: 100 * time.Millisecond, + }, + }, + maxDuration: 310 * time.Millisecond, + }, + { + name: "failed dials", + input: 
[]timedDial{ + { + addr: addrs[0], + delay: 0, + success: false, + failAfter: 105 * time.Millisecond, + }, + { + addr: addrs[1], + delay: 100 * time.Millisecond, + success: false, + failAfter: 20 * time.Millisecond, + }, + }, + maxDuration: 200 * time.Millisecond, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + s1 := makeSwarmWithNoListenAddrs(t) + defer s1.Close() + s2 := makeSwarmWithNoListenAddrs(t) + defer s2.Close() + // setup the ranker to trigger dials according to the test case + s1.dialRanker = makeRanker(tc.input) + err := checkDialWorkerLoopScheduling(t, s1, s2, tc) + if err != nil { + t.Error(err) + } + }) + } +} + +func TestDialWorkerLoopSchedulingProperty(t *testing.T) { + f := func(tc schedulingTestCase) bool { + s1 := makeSwarmWithNoListenAddrs(t) + defer s1.Close() + // ignore limiter delays just check scheduling + s1.limiter.perPeerLimit = 10000 + s2 := makeSwarmWithNoListenAddrs(t) + defer s2.Close() + // setup the ranker to trigger dials according to the test case + s1.dialRanker = makeRanker(tc.input) + err := checkDialWorkerLoopScheduling(t, s1, s2, tc) + if err != nil { + log.Error(err) + } + return err == nil + } + + if err := quick.Check(f, &quick.Config{MaxCount: 50}); err != nil { + t.Error(err) + } +} + +func TestDialWorkerLoopQuicOverTCP(t *testing.T) { + tc := schedulingTestCase{ + input: []timedDial{ + { + addr: ma.StringCast("/ip4/127.0.0.1/udp/20000/quic"), + delay: 0, + success: true, + }, + { + addr: ma.StringCast("/ip4/127.0.0.1/tcp/20000"), + delay: 30 * time.Millisecond, + success: true, + }, + }, + maxDuration: 20 * time.Millisecond, + } + s1 := makeSwarmWithNoListenAddrs(t) + defer s1.Close() + + s2 := makeSwarmWithNoListenAddrs(t) + defer s2.Close() + + // we use the default ranker here + + err := checkDialWorkerLoopScheduling(t, s1, s2, tc) + require.NoError(t, err) +} + +func TestDialWorkerLoopHolePunching(t *testing.T) { + s1 := makeSwarmWithNoListenAddrs(t) + defer s1.Close() + + s2 := makeSwarmWithNoListenAddrs(t) + defer s2.Close() + + // t1 will accept and keep the other end waiting + t1 := ma.StringCast("/ip4/127.0.0.1/tcp/10000") + recvCh := make(chan struct{}) + list, ch := makeTCPListener(t, t1, recvCh) // ignore ch because we want to hang forever + defer list.Close() + defer func() { ch <- struct{}{} }() // close listener + + // t2 will succeed + t2 := ma.StringCast("/ip4/127.0.0.1/tcp/10001") + + err := s2.AddListenAddr(t2) + if err != nil { + t.Error(err) + } + + s1.dialRanker = func(addrs []ma.Multiaddr) (res []network.AddrDelay) { + res = make([]network.AddrDelay, len(addrs)) + for i := 0; i < len(addrs); i++ { + delay := 10 * time.Second + if addrs[i].Equal(t1) { + //fire t1 immediately + delay = 0 + } else if addrs[i].Equal(t2) { + // delay t2 by 100ms + // without holepunch this call will not happen + delay = 100 * time.Millisecond + } + res[i] = network.AddrDelay{Addr: addrs[i], Delay: delay} + } + return + } + s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{t1, t2}, peerstore.PermanentAddrTTL) + + reqch := make(chan dialRequest) + resch := make(chan dialResponse, 2) + + cl := newMockClock() + worker := newDialWorker(s1, s2.LocalPeer(), reqch, cl) + go worker.loop() + defer worker.wg.Wait() + defer close(reqch) + + reqch <- dialRequest{ctx: context.Background(), resch: resch} + <-recvCh // received connection on t1 + + select { + case <-resch: + t.Errorf("didn't expect connection to succeed") + case <-time.After(100 * time.Millisecond): + } + + hpCtx := 
network.WithSimultaneousConnect(context.Background(), true, "testing") + // with holepunch request, t2 will be dialed immediately + reqch <- dialRequest{ctx: hpCtx, resch: resch} + select { + case r := <-resch: + require.NoError(t, r.err) + case <-time.After(5 * time.Second): + t.Errorf("expected conn to succeed") + } + + select { + case r := <-resch: + require.NoError(t, r.err) + case <-time.After(5 * time.Second): + t.Errorf("expected conn to succeed") + } +} + func TestDialWorkerLoopAddrDedup(t *testing.T) { s1 := makeSwarm(t) s2 := makeSwarm(t) @@ -384,7 +1057,7 @@ func TestDialWorkerLoopAddrDedup(t *testing.T) { reqch := make(chan dialRequest) resch := make(chan dialResponse, 2) - worker := newDialWorker(s1, s2.LocalPeer(), reqch) + worker := newDialWorker(s1, s2.LocalPeer(), reqch, nil) go worker.loop() defer worker.wg.Wait() defer close(reqch) diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index bd2e9a2659..eaae4bcfc8 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -100,6 +100,23 @@ func WithResourceManager(m network.ResourceManager) Option { } } +// WithNoDialDelay configures swarm to dial all addresses for a peer without +// any delay +func WithNoDialDelay() Option { + return func(s *Swarm) error { + s.dialRanker = noDelayRanker + return nil + } +} + +// WithDialRanker configures swarm to use d as the DialRanker +func WithDialRanker(d network.DialRanker) Option { + return func(s *Swarm) error { + s.dialRanker = d + return nil + } +} + // Swarm is a connection muxer, allowing connections to other peers to // be opened and closed, while still using the same Chan for all // communication. The Chan sends/receives Messages, which note the @@ -163,6 +180,8 @@ type Swarm struct { bwc metrics.Reporter metricsTracer MetricsTracer + + dialRanker network.DialRanker } // NewSwarm constructs a Swarm. 
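// Editor's sketch (not part of this diff): how a caller could plug a custom ranking
// policy into the swarm via the new WithDialRanker option. Only network.AddrDelay,
// network.DialRanker, swarm.WithDialRanker and swarm.WithNoDialDelay come from this
// change; the quicFirstRanker name, the 250ms figure and the package name are
// assumptions made purely for illustration.
package example

import (
	"time"

	"github.com/libp2p/go-libp2p/core/network"
	ma "github.com/multiformats/go-multiaddr"
)

// quicFirstRanker dials QUIC v1 addresses immediately and delays every other
// address by 250ms. Its signature satisfies network.DialRanker.
func quicFirstRanker(addrs []ma.Multiaddr) []network.AddrDelay {
	res := make([]network.AddrDelay, 0, len(addrs))
	for _, a := range addrs {
		delay := 250 * time.Millisecond
		if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil {
			delay = 0
		}
		res = append(res, network.AddrDelay{Addr: a, Delay: delay})
	}
	return res
}

// The ranker would then be passed as an option when constructing the swarm, e.g.
// swarm.NewSwarm(local, peers, eventBus, swarm.WithDialRanker(quicFirstRanker)),
// while swarm.WithNoDialDelay() keeps the old dial-everything-at-once behaviour.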
@@ -181,6 +200,7 @@ func NewSwarm(local peer.ID, peers peerstore.Peerstore, eventBus event.Bus, opts dialTimeout: defaultDialTimeout, dialTimeoutLocal: defaultDialTimeoutLocal, maResolver: madns.DefaultResolver, + dialRanker: DefaultDialRanker, } s.conns.m = make(map[peer.ID][]*Conn) diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index 49c0fc7fd9..7aa2befe7a 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -15,7 +15,6 @@ import ( ma "github.com/multiformats/go-multiaddr" madns "github.com/multiformats/go-multiaddr-dns" manet "github.com/multiformats/go-multiaddr/net" - "github.com/quic-go/quic-go" ) // The maximum number of address resolution steps we'll perform for a single @@ -295,7 +294,7 @@ func (s *Swarm) dialPeer(ctx context.Context, p peer.ID) (*Conn, error) { // dialWorkerLoop synchronizes and executes concurrent dials to a single peer func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) { - w := newDialWorker(s, p, reqch) + w := newDialWorker(s, p, reqch, nil) w.loop() } @@ -440,24 +439,36 @@ func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Mul lisAddrs, _ := s.InterfaceListenAddresses() var ourAddrs []ma.Multiaddr for _, addr := range lisAddrs { - protos := addr.Protocols() // we're only sure about filtering out /ip4 and /ip6 addresses, so far - if protos[0].Code == ma.P_IP4 || protos[0].Code == ma.P_IP6 { - ourAddrs = append(ourAddrs, addr) + ma.ForEach(addr, func(c ma.Component) bool { + if c.Protocol().Code == ma.P_IP4 || c.Protocol().Code == ma.P_IP6 { + ourAddrs = append(ourAddrs, addr) + } + return false + }) + } + + // Make a map of udp ports we are listening on to filter peers web transport addresses + ourLocalHostUDPPorts := make(map[string]bool, 2) + for _, a := range ourAddrs { + if !manet.IsIPLoopback(a) { + continue + } + if p, err := a.ValueForProtocol(ma.P_UDP); err == nil { + ourLocalHostUDPPorts[p] = true } } - return maybeRemoveWebTransportAddrs( - maybeRemoveQUICDraft29( - ma.FilterAddrs(addrs, - func(addr ma.Multiaddr) bool { return !ma.Contains(ourAddrs, addr) }, - s.canDial, - // TODO: Consider allowing link-local addresses - func(addr ma.Multiaddr) bool { return !manet.IsIP6LinkLocal(addr) }, - func(addr ma.Multiaddr) bool { - return s.gater == nil || s.gater.InterceptAddrDial(p, addr) - }, - ))) + return ma.FilterAddrs(addrs, + func(addr ma.Multiaddr) bool { return !ma.Contains(ourAddrs, addr) }, + func(addr ma.Multiaddr) bool { return checkLocalHostUDPAddrs(addr, ourLocalHostUDPPorts) }, + s.canDial, + // TODO: Consider allowing link-local addresses + func(addr ma.Multiaddr) bool { return !manet.IsIP6LinkLocal(addr) }, + func(addr ma.Multiaddr) bool { + return s.gater == nil || s.gater.InterceptAddrDial(p, addr) + }, + ) } // limitedDial will start a dial to the given peer when @@ -543,110 +554,20 @@ func isFdConsumingAddr(addr ma.Multiaddr) bool { return err1 == nil || err2 == nil } -func isExpensiveAddr(addr ma.Multiaddr) bool { - _, wsErr := addr.ValueForProtocol(ma.P_WS) - _, wssErr := addr.ValueForProtocol(ma.P_WSS) - _, wtErr := addr.ValueForProtocol(ma.P_WEBTRANSPORT) - return wsErr == nil || wssErr == nil || wtErr == nil -} - func isRelayAddr(addr ma.Multiaddr) bool { _, err := addr.ValueForProtocol(ma.P_CIRCUIT) return err == nil } -func isWebTransport(addr ma.Multiaddr) bool { - _, err := addr.ValueForProtocol(ma.P_WEBTRANSPORT) - return err == nil -} - -func quicVersion(addr ma.Multiaddr) (quic.VersionNumber, bool) { - found := false - 
foundWebTransport := false - var version quic.VersionNumber - ma.ForEach(addr, func(c ma.Component) bool { - switch c.Protocol().Code { - case ma.P_QUIC: - version = quic.VersionDraft29 - found = true - return true - case ma.P_QUIC_V1: - version = quic.Version1 - found = true - return true - case ma.P_WEBTRANSPORT: - version = quic.Version1 - foundWebTransport = true - return false - default: - return true - } - }) - if foundWebTransport { - return 0, false - } - return version, found -} - -// If we have QUIC addresses, we don't want to dial WebTransport addresses. -// It's better to have a native QUIC connection. -// Note that this is a hack. The correct solution would be a proper -// Happy-Eyeballs-style dialing. -func maybeRemoveWebTransportAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { - var hasQuic, hasWebTransport bool - for _, addr := range addrs { - if _, isQuic := quicVersion(addr); isQuic { - hasQuic = true - } - if isWebTransport(addr) { - hasWebTransport = true - } - } - if !hasWebTransport || !hasQuic { - return addrs - } - var c int - for _, addr := range addrs { - if isWebTransport(addr) { - continue - } - addrs[c] = addr - c++ - } - return addrs[:c] -} - -// If we have QUIC V1 addresses, we don't want to dial QUIC draft29 addresses. -// This is a similar hack to the above. If we add one more hack like this, let's -// define a `Filterer` interface like the `Resolver` interface that transports -// can optionally implement if they want to filter the multiaddrs. -// -// This mutates the input -func maybeRemoveQUICDraft29(addrs []ma.Multiaddr) []ma.Multiaddr { - var hasQuicV1, hasQuicDraft29 bool - for _, addr := range addrs { - v, isQuic := quicVersion(addr) - if !isQuic { - continue - } - - if v == quic.Version1 { - hasQuicV1 = true - } - if v == quic.VersionDraft29 { - hasQuicDraft29 = true - } - } - if !hasQuicDraft29 || !hasQuicV1 { - return addrs +// checkLocalHostUDPAddrs returns false for addresses that have the same localhost port +// as the one we are listening on +// This is useful for filtering out peer's localhost webtransport addresses. 
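// Illustrative example (editor's note, mirroring TestLocalHostWebTransportRemoved
// below): if the swarm listens on /ip4/127.0.0.1/udp/10000/quic-v1, then ourUDPPorts
// is {"10000": true}, so a peer address like
// /ip4/127.0.0.1/udp/10000/quic-v1/webtransport returns false here and is filtered
// out, while a non-loopback address such as
// /ip4/1.2.3.4/udp/10000/quic-v1/webtransport is unaffected by this check.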
+func checkLocalHostUDPAddrs(addr ma.Multiaddr, ourUDPPorts map[string]bool) bool { + if !manet.IsIPLoopback(addr) { + return true } - var c int - for _, addr := range addrs { - if v, isQuic := quicVersion(addr); isQuic && v == quic.VersionDraft29 { - continue - } - addrs[c] = addr - c++ + if p, err := addr.ValueForProtocol(ma.P_UDP); err == nil { + return !ourUDPPorts[p] } - return addrs[:c] + return true } diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index 215ee6df9f..ce60701875 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -14,8 +14,11 @@ import ( "github.com/libp2p/go-libp2p/core/test" "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" + quic "github.com/libp2p/go-libp2p/p2p/transport/quic" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" "github.com/libp2p/go-libp2p/p2p/transport/tcp" "github.com/libp2p/go-libp2p/p2p/transport/websocket" + webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" ma "github.com/multiformats/go-multiaddr" madns "github.com/multiformats/go-multiaddr-dns" @@ -239,26 +242,35 @@ func TestAddrResolutionRecursive(t *testing.T) { require.Contains(t, addrs2, addr1) } -func TestRemoveWebTransportAddrs(t *testing.T) { - tcpAddr := ma.StringCast("/ip4/9.5.6.4/tcp/1234") - quicAddr := ma.StringCast("/ip4/1.2.3.4/udp/443/quic") - webtransportAddr := ma.StringCast("/ip4/1.2.3.4/udp/443/quic-v1/webtransport") +func TestLocalHostWebTransportRemoved(t *testing.T) { + resolver, err := madns.NewResolver() + if err != nil { + t.Fatal(err) + } - require.Equal(t, []ma.Multiaddr{tcpAddr, quicAddr}, maybeRemoveWebTransportAddrs([]ma.Multiaddr{tcpAddr, quicAddr})) - require.Equal(t, []ma.Multiaddr{tcpAddr, webtransportAddr}, maybeRemoveWebTransportAddrs([]ma.Multiaddr{tcpAddr, webtransportAddr})) - require.Equal(t, []ma.Multiaddr{tcpAddr, quicAddr}, maybeRemoveWebTransportAddrs([]ma.Multiaddr{tcpAddr, webtransportAddr, quicAddr})) - require.Equal(t, []ma.Multiaddr{quicAddr}, maybeRemoveWebTransportAddrs([]ma.Multiaddr{quicAddr, webtransportAddr})) - require.Equal(t, []ma.Multiaddr{webtransportAddr}, maybeRemoveWebTransportAddrs([]ma.Multiaddr{webtransportAddr})) -} + s := newTestSwarmWithResolver(t, resolver) + p, err := test.RandPeerID() + if err != nil { + t.Error(err) + } + reuse, err := quicreuse.NewConnManager([32]byte{}) + require.NoError(t, err) + defer reuse.Close() + + quicTr, err := quic.NewTransport(s.Peerstore().PrivKey(s.LocalPeer()), reuse, nil, nil, nil) + require.NoError(t, err) + require.NoError(t, s.AddTransport(quicTr)) -func TestRemoveQuicDraft29(t *testing.T) { - tcpAddr := ma.StringCast("/ip4/9.5.6.4/tcp/1234") - quicDraft29Addr := ma.StringCast("/ip4/1.2.3.4/udp/443/quic") - quicV1Addr := ma.StringCast("/ip4/1.2.3.4/udp/443/quic-v1") + webtransportTr, err := webtransport.New(s.Peerstore().PrivKey(s.LocalPeer()), nil, reuse, nil, nil) + require.NoError(t, err) + s.AddTransport(webtransportTr) - require.Equal(t, []ma.Multiaddr{tcpAddr, quicV1Addr}, maybeRemoveQUICDraft29([]ma.Multiaddr{tcpAddr, quicV1Addr})) - require.Equal(t, []ma.Multiaddr{tcpAddr, quicDraft29Addr}, maybeRemoveQUICDraft29([]ma.Multiaddr{tcpAddr, quicDraft29Addr})) - require.Equal(t, []ma.Multiaddr{tcpAddr, quicV1Addr}, maybeRemoveQUICDraft29([]ma.Multiaddr{tcpAddr, quicDraft29Addr, quicV1Addr})) - require.Equal(t, []ma.Multiaddr{quicV1Addr}, maybeRemoveQUICDraft29([]ma.Multiaddr{quicV1Addr, quicDraft29Addr})) - require.Equal(t, 
[]ma.Multiaddr{quicDraft29Addr}, maybeRemoveQUICDraft29([]ma.Multiaddr{quicDraft29Addr})) + err = s.AddListenAddr(ma.StringCast("/ip4/127.0.0.1/udp/10000/quic-v1/")) + require.NoError(t, err) + + res := s.filterKnownUndialables(p, []ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/udp/10000/quic-v1/webtransport")}) + if len(res) != 0 { + t.Errorf("failed to filter localhost webtransport address") + } + s.Close() } diff --git a/p2p/net/swarm/swarm_metrics.go b/p2p/net/swarm/swarm_metrics.go index 95e4b78b88..992cc07bc0 100644 --- a/p2p/net/swarm/swarm_metrics.go +++ b/p2p/net/swarm/swarm_metrics.go @@ -69,6 +69,22 @@ var ( }, []string{"transport", "security", "muxer", "early_muxer", "ip_version"}, ) + dialsPerPeer = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: metricNamespace, + Name: "dials_per_peer_total", + Help: "Number of addresses dialed per peer", + }, + []string{"outcome", "num_dials"}, + ) + dialRankingDelay = prometheus.NewHistogram( + prometheus.HistogramOpts{ + Namespace: metricNamespace, + Name: "dial_ranking_delay_seconds", + Help: "delay introduced by the dial ranking logic", + Buckets: []float64{0.001, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.75, 1, 2}, + }, + ) collectors = []prometheus.Collector{ connsOpened, keyTypes, @@ -76,6 +92,8 @@ var ( dialError, connDuration, connHandshakeLatency, + dialsPerPeer, + dialRankingDelay, } ) @@ -84,6 +102,8 @@ type MetricsTracer interface { ClosedConnection(network.Direction, time.Duration, network.ConnectionState, ma.Multiaddr) CompletedHandshake(time.Duration, network.ConnectionState, ma.Multiaddr) FailedDialing(ma.Multiaddr, error) + DialCompleted(success bool, totalDials int) + DialRankingDelay(d time.Duration) } type metricsTracer struct{} @@ -213,3 +233,27 @@ func (m *metricsTracer) FailedDialing(addr ma.Multiaddr, err error) { *tags = append(*tags, getIPVersion(addr)) dialError.WithLabelValues(*tags...).Inc() } + +func (m *metricsTracer) DialCompleted(success bool, totalDials int) { + tags := metricshelper.GetStringSlice() + defer metricshelper.PutStringSlice(tags) + if success { + *tags = append(*tags, "success") + } else { + *tags = append(*tags, "failed") + } + + numDialLabels := [...]string{"0", "1", "2", "3", "4", "5", ">=6"} + var numDials string + if totalDials < len(numDialLabels) { + numDials = numDialLabels[totalDials] + } else { + numDials = numDialLabels[len(numDialLabels)-1] + } + *tags = append(*tags, numDials) + dialsPerPeer.WithLabelValues(*tags...).Inc() +} + +func (m *metricsTracer) DialRankingDelay(d time.Duration) { + dialRankingDelay.Observe(d.Seconds()) +} diff --git a/p2p/net/swarm/swarm_metrics_test.go b/p2p/net/swarm/swarm_metrics_test.go index 6b00da1a8a..0e13048a99 100644 --- a/p2p/net/swarm/swarm_metrics_test.go +++ b/p2p/net/swarm/swarm_metrics_test.go @@ -88,7 +88,9 @@ func TestMetricsNoAllocNoCover(t *testing.T) { "CompletedHandshake": func() { mt.CompletedHandshake(time.Duration(mrand.Intn(100))*time.Second, randItem(connections), randItem(addrs)) }, - "FailedDialing": func() { mt.FailedDialing(randItem(addrs), randItem(errors)) }, + "FailedDialing": func() { mt.FailedDialing(randItem(addrs), randItem(errors)) }, + "DialCompleted": func() { mt.DialCompleted(mrand.Intn(2) == 1, mrand.Intn(10)) }, + "DialRankingDelay": func() { mt.DialRankingDelay(time.Duration(mrand.Intn(1e10))) }, } for method, f := range tests {
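// Editor's note, not part of this diff: as a worked example of the bucketing above,
// DialCompleted(true, 9) increments the "dials_per_peer_total" counter with labels
// {outcome="success", num_dials=">=6"} because 9 exceeds the last explicit label,
// and a 25ms delay passed to DialRankingDelay falls into the le=0.05 bucket of the
// "dial_ranking_delay_seconds" histogram.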