From 1ef8abdb1132753e3a8b3f0d687c0c436b96fff6 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 17 Oct 2024 16:34:15 +0200 Subject: [PATCH] Apply new logic --- client/internal/engine.go | 25 ++- client/internal/peer/conn.go | 156 +++---------- client/internal/peer/conn_monitor.go | 212 ------------------ client/internal/peer/guard/guard.go | 194 ++++++++++++++++ client/internal/peer/guard/ice_monitor.go | 146 ++++++++++++ client/internal/peer/guard/sr_watcher.go | 84 +++++++ client/internal/peer/guard/stdnet.go | 11 + .../peer/{ => guard}/stdnet_android.go | 0 client/internal/peer/ice/agent.go | 89 ++++++++ client/internal/peer/ice/config.go | 22 ++ .../peer/{env_config.go => ice/env.go} | 14 +- client/internal/peer/{ => ice}/stdnet.go | 2 +- client/internal/peer/ice/stdnet_android.go | 7 + client/internal/peer/worker_ice.go | 107 +-------- relay/client/client.go | 26 ++- relay/client/manager.go | 21 ++ signal/client/client.go | 1 + signal/client/grpc.go | 14 ++ 18 files changed, 677 insertions(+), 454 deletions(-) delete mode 100644 client/internal/peer/conn_monitor.go create mode 100644 client/internal/peer/guard/guard.go create mode 100644 client/internal/peer/guard/ice_monitor.go create mode 100644 client/internal/peer/guard/sr_watcher.go create mode 100644 client/internal/peer/guard/stdnet.go rename client/internal/peer/{ => guard}/stdnet_android.go (100%) create mode 100644 client/internal/peer/ice/agent.go create mode 100644 client/internal/peer/ice/config.go rename client/internal/peer/{env_config.go => ice/env.go} (99%) rename client/internal/peer/{ => ice}/stdnet.go (94%) create mode 100644 client/internal/peer/ice/stdnet_android.go diff --git a/client/internal/engine.go b/client/internal/engine.go index 459518de136..c2af6962f97 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/netbirdio/netbird/client/internal/peer/guard" "maps" "math/rand" "net" @@ -23,14 +24,14 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" - - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -166,6 +167,8 @@ type Engine struct { checks []*mgmProto.Checks relayManager *relayClient.Manager + + srWatcher *guard.SRWatcher } // Peer is an instance of the Connection Peer @@ -369,6 +372,18 @@ func (e *Engine) Start() error { return fmt.Errorf("initialize dns server: %w", err) } + iceCfg := icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + UDPMux: e.udpMux.UDPMuxDefault, + UDPMuxSrflx: e.udpMux, + NATExternalIPs: e.parseNATExternalIPMappings(), + } + // todo: review the cancel event handling + e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) + e.srWatcher.Start(e.ctx) + e.receiveSignalEvents() e.receiveManagementEvents() e.receiveProbeEvents() @@ -951,7 +966,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e LocalWgPort: e.config.WgPort, RosenpassPubKey: e.getRosenpassPubKey(), RosenpassAddr: e.getRosenpassAddr(), - ICEConfig: peer.ICEConfig{ + ICEConfig: icemaker.Config{ StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, @@ -961,7 +976,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e }, } - peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) + peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher) if err != nil { return nil, err } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 99acfde314e..98247d741a8 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -10,7 +10,6 @@ import ( "sync" "time" - "github.com/cenkalti/backoff/v4" "github.com/pion/ice/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -18,6 +17,8 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/wgproxy" + "github.com/netbirdio/netbird/client/internal/peer/guard" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" @@ -32,8 +33,6 @@ const ( connPriorityRelay ConnPriority = 1 connPriorityICETurn ConnPriority = 1 connPriorityICEP2P ConnPriority = 2 - - reconnectMaxElapsedTime = 30 * time.Minute ) type WgConfig struct { @@ -63,7 +62,7 @@ type ConnConfig struct { RosenpassAddr string // ICEConfig ICE protocol configuration - ICEConfig ICEConfig + ICEConfig icemaker.Config } type WorkerCallbacks struct { @@ -109,13 +108,12 @@ type Conn struct { // for reconnection operations iCEDisconnected chan bool relayDisconnected chan bool - connMonitor *ConnMonitor - reconnectCh <-chan struct{} + guard *guard.Guard } // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open -func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) { +func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) @@ -124,6 +122,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu ctx, ctxCancel := context.WithCancel(engineCtx) connLog := log.WithField("peer", config.Key) + iCEDisconnected := make(chan bool, 1) + relayDisconnected := make(chan bool, 1) var conn = &Conn{ log: connLog, @@ -137,18 +137,10 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu allowedNet: allowedNet.String(), statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), - iCEDisconnected: make(chan bool, 1), - relayDisconnected: make(chan bool, 1), + iCEDisconnected: iCEDisconnected, + relayDisconnected: relayDisconnected, } - conn.connMonitor, conn.reconnectCh = NewConnMonitor( - signaler, - iFaceDiscover, - config, - conn.relayDisconnected, - conn.iCEDisconnected, - ) - rFns := WorkerRelayCallbacks{ OnConnReady: conn.relayConnectionIsReady, OnDisconnected: conn.onWorkerRelayStateDisconnected, @@ -174,6 +166,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) } + conn.guard = guard.NewGuard(connLog, true, conn.isConnected, conn.handshaker, config.Timeout, srWatcher, relayDisconnected, iCEDisconnected) + go conn.handshaker.Listen() return conn, nil @@ -200,24 +194,18 @@ func (conn *Conn) Open() { conn.log.Warnf("error while updating the state err: %v", err) } - go conn.startHandshakeAndReconnect() + go conn.startHandshakeAndReconnect(conn.ctx) } -func (conn *Conn) startHandshakeAndReconnect() { - conn.waitInitialRandomSleepTime() +func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { + conn.waitInitialRandomSleepTime(ctx) err := conn.handshaker.sendOffer() if err != nil { conn.log.Errorf("failed to send initial offer: %v", err) } - go conn.connMonitor.Start(conn.ctx) - - if conn.workerRelay.IsController() { - conn.reconnectLoopWithRetry() - } else { - conn.reconnectLoopForOnDisconnectedEvent() - } + conn.guard.Start(ctx) } // Close closes this peer Conn issuing a close event to the Conn closeCh @@ -316,104 +304,6 @@ func (conn *Conn) GetKey() string { return conn.config.Key } -func (conn *Conn) reconnectLoopWithRetry() { - // Give chance to the peer to establish the initial connection. - // With it, we can decrease to send necessary offer - select { - case <-conn.ctx.Done(): - return - case <-time.After(3 * time.Second): - } - - ticker := conn.prepareExponentTicker() - defer ticker.Stop() - time.Sleep(1 * time.Second) - - for { - select { - case t := <-ticker.C: - if t.IsZero() { - // in case if the ticker has been canceled by context then avoid the temporary loop - return - } - - if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { - if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected { - conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) - } - } else { - if conn.statusICE.Get() == StatusDisconnected { - conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE) - } - } - - // checks if there is peer connection is established via relay or ice - if conn.isConnected() { - continue - } - - err := conn.handshaker.sendOffer() - if err != nil { - conn.log.Errorf("failed to do handshake: %v", err) - } - - case <-conn.reconnectCh: - ticker.Stop() - ticker = conn.prepareExponentTicker() - - case <-conn.ctx.Done(): - conn.log.Debugf("context is done, stop reconnect loop") - return - } - } -} - -func (conn *Conn) prepareExponentTicker() *backoff.Ticker { - bo := backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 0.1, - Multiplier: 2, - MaxInterval: conn.config.Timeout, - MaxElapsedTime: reconnectMaxElapsedTime, - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, conn.ctx) - - ticker := backoff.NewTicker(bo) - <-ticker.C // consume the initial tick what is happening right after the ticker has been created - - return ticker -} - -// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and it should reconnect to the peer -// when the connection is lost. It will try to establish a connection only once time if before the connection was established -// It track separately the ice and relay connection status. Just because a lover priority connection reestablished it does not -// mean that to switch to it. We always force to use the higher priority connection. -func (conn *Conn) reconnectLoopForOnDisconnectedEvent() { - for { - select { - case changed := <-conn.relayDisconnected: - if !changed { - continue - } - conn.log.Debugf("Relay state changed, try to send new offer") - case changed := <-conn.iCEDisconnected: - if !changed { - continue - } - conn.log.Debugf("ICE state changed, try to send new offer") - case <-conn.ctx.Done(): - conn.log.Debugf("context is done, stop reconnect loop") - return - } - - err := conn.handshaker.SendOffer() - if err != nil { - conn.log.Errorf("failed to do handshake: %v", err) - } - } -} - // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { conn.mu.Lock() @@ -693,7 +583,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } } -func (conn *Conn) waitInitialRandomSleepTime() { +func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) { minWait := 100 maxWait := 800 duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond @@ -702,7 +592,7 @@ func (conn *Conn) waitInitialRandomSleepTime() { defer timeout.Stop() select { - case <-conn.ctx.Done(): + case <-ctx.Done(): case <-timeout.C: } } @@ -831,6 +721,18 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { } } +func (conn *Conn) logTraceConnState() { + if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { + if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected { + conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) + } + } else { + if conn.statusICE.Get() == StatusDisconnected { + conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE) + } + } +} + func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil } diff --git a/client/internal/peer/conn_monitor.go b/client/internal/peer/conn_monitor.go deleted file mode 100644 index 75722c99011..00000000000 --- a/client/internal/peer/conn_monitor.go +++ /dev/null @@ -1,212 +0,0 @@ -package peer - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/pion/ice/v3" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/stdnet" -) - -const ( - signalerMonitorPeriod = 5 * time.Second - candidatesMonitorPeriod = 5 * time.Minute - candidateGatheringTimeout = 5 * time.Second -) - -type ConnMonitor struct { - signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover - config ConnConfig - relayDisconnected chan bool - iCEDisconnected chan bool - reconnectCh chan struct{} - currentCandidates []ice.Candidate - candidatesMu sync.Mutex -} - -func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) { - reconnectCh := make(chan struct{}, 1) - cm := &ConnMonitor{ - signaler: signaler, - iFaceDiscover: iFaceDiscover, - config: config, - relayDisconnected: relayDisconnected, - iCEDisconnected: iCEDisconnected, - reconnectCh: reconnectCh, - } - return cm, reconnectCh -} - -func (cm *ConnMonitor) Start(ctx context.Context) { - signalerReady := make(chan struct{}, 1) - go cm.monitorSignalerReady(ctx, signalerReady) - - localCandidatesChanged := make(chan struct{}, 1) - go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged) - - for { - select { - case changed := <-cm.relayDisconnected: - if !changed { - continue - } - log.Debugf("Relay state changed, triggering reconnect") - cm.triggerReconnect() - - case changed := <-cm.iCEDisconnected: - if !changed { - continue - } - log.Debugf("ICE state changed, triggering reconnect") - cm.triggerReconnect() - - case <-signalerReady: - log.Debugf("Signaler became ready, triggering reconnect") - cm.triggerReconnect() - - case <-localCandidatesChanged: - log.Debugf("Local candidates changed, triggering reconnect") - cm.triggerReconnect() - - case <-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) { - if cm.signaler == nil { - return - } - - ticker := time.NewTicker(signalerMonitorPeriod) - defer ticker.Stop() - - lastReady := true - for { - select { - case <-ticker.C: - currentReady := cm.signaler.Ready() - if !lastReady && currentReady { - select { - case signalerReady <- struct{}{}: - default: - } - } - lastReady = currentReady - case <-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) { - ufrag, pwd, err := generateICECredentials() - if err != nil { - log.Warnf("Failed to generate ICE credentials: %v", err) - return - } - - ticker := time.NewTicker(candidatesMonitorPeriod) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil { - log.Warnf("Failed to handle candidate tick: %v", err) - } - case <-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error { - log.Debugf("Gathering ICE candidates") - - transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList) - if err != nil { - log.Errorf("failed to create pion's stdnet: %s", err) - } - - agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd) - if err != nil { - return fmt.Errorf("create ICE agent: %w", err) - } - defer func() { - if err := agent.Close(); err != nil { - log.Warnf("Failed to close ICE agent: %v", err) - } - }() - - gatherDone := make(chan struct{}) - err = agent.OnCandidate(func(c ice.Candidate) { - log.Tracef("Got candidate: %v", c) - if c == nil { - close(gatherDone) - } - }) - if err != nil { - return fmt.Errorf("set ICE candidate handler: %w", err) - } - - if err := agent.GatherCandidates(); err != nil { - return fmt.Errorf("gather ICE candidates: %w", err) - } - - ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) - defer cancel() - - select { - case <-ctx.Done(): - return fmt.Errorf("wait for gathering: %w", ctx.Err()) - case <-gatherDone: - } - - candidates, err := agent.GetLocalCandidates() - if err != nil { - return fmt.Errorf("get local candidates: %w", err) - } - log.Tracef("Got candidates: %v", candidates) - - if changed := cm.updateCandidates(candidates); changed { - select { - case localCandidatesChanged <- struct{}{}: - default: - } - } - - return nil -} - -func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool { - cm.candidatesMu.Lock() - defer cm.candidatesMu.Unlock() - - if len(cm.currentCandidates) != len(newCandidates) { - cm.currentCandidates = newCandidates - return true - } - - for i, candidate := range cm.currentCandidates { - if candidate.Address() != newCandidates[i].Address() { - cm.currentCandidates = newCandidates - return true - } - } - - return false -} - -func (cm *ConnMonitor) triggerReconnect() { - select { - case cm.reconnectCh <- struct{}{}: - default: - } -} diff --git a/client/internal/peer/guard/guard.go b/client/internal/peer/guard/guard.go new file mode 100644 index 00000000000..f51c5802794 --- /dev/null +++ b/client/internal/peer/guard/guard.go @@ -0,0 +1,194 @@ +package guard + +import ( + "context" + "math/rand" + "time" + + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" +) + +const ( + reconnectMaxElapsedTime = 30 * time.Minute +) + +type handshake interface { + SendOffer() error // todo review to call sendOffer or SendOffer +} +type isConnectedFunc func() bool + +type Guard struct { + log *log.Entry + isController bool + isConnectedFn isConnectedFunc + timeout time.Duration + handshaker handshake + srWatcher *SRWatcher + relayedConnDisconnected chan bool + iCEConnDisconnected chan bool + iceMonitor *ICEMonitor +} + +func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, handshaker handshake, timeout time.Duration, srWatcher *SRWatcher, relayedConnDisconnected, iCEDisconnected chan bool) *Guard { + return &Guard{ + log: log, + isController: isController, + isConnectedFn: isConnectedFn, + timeout: timeout, + handshaker: handshaker, + srWatcher: srWatcher, + relayedConnDisconnected: relayedConnDisconnected, + iCEConnDisconnected: iCEDisconnected, + } +} + +func (g *Guard) Start(ctx context.Context) { + if g.isController { + g.reconnectLoopWithRetry(ctx) + } else { + g.listenForDisconnectEvents(ctx) + } +} + +// reconnectLoopWithRetry periodically check (max 30 min) the connection status with peer and try to reconnect if necessary +// If the Relay is connected but the ICE P2P not then it will trigger ICE connection offer +func (g *Guard) reconnectLoopWithRetry(ctx context.Context) { + // Give chance to the peer to establish the initial connection. + // With it, we can decrease to send necessary offer + select { + case <-ctx.Done(): + return + case <-time.After(3 * time.Second): + } + + srReconnectedChan := g.srWatcher.NewListener() + defer g.srWatcher.RemoveListener(srReconnectedChan) + + ticker := g.prepareExponentTicker(ctx) + defer ticker.Stop() + time.Sleep(1 * time.Second) + + for { + select { + case t := <-ticker.C: + if t.IsZero() { + // in case if the ticker has been canceled by context then avoid the temporary loop + return + } + g.logTraceConnState() + + if g.isConnectedFn() { + continue + } + + if err := g.handshaker.SendOffer(); err != nil { + g.log.Errorf("failed to do handshake: %v", err) + } + + case changed := <-g.relayedConnDisconnected: + if !changed { + continue + } + g.log.Debugf("Relay connection changed, triggering reconnect") + ticker.Stop() + ticker = g.prepareExponentTicker(ctx) + + case changed := <-g.iCEConnDisconnected: + if !changed { + continue + } + g.log.Debugf("ICE connection changed, triggering reconnect") + ticker.Stop() + ticker = g.prepareExponentTicker(ctx) + + case <-srReconnectedChan: + ticker.Stop() + ticker = g.prepareExponentTicker(ctx) + + case <-ctx.Done(): + g.log.Debugf("context is done, stop reconnect loop") + return + } + } +} + +// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and it should reconnect to the peer +// when the connection is lost. It will try to establish a connection only once time if before the connection was established +// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not +// mean that to switch to it. We always force to use the higher priority connection. +func (g *Guard) listenForDisconnectEvents(ctx context.Context) { + srReconnectedChan := g.srWatcher.NewListener() + defer g.srWatcher.RemoveListener(srReconnectedChan) + + for { + select { + case changed := <-g.relayedConnDisconnected: + if !changed { + continue + } + g.log.Debugf("Relay connection changed, triggering reconnect") + case changed := <-g.iCEConnDisconnected: + if !changed { + continue + } + g.log.Debugf("ICE state changed, try to send new offer") + case <-srReconnectedChan: + case <-ctx.Done(): + g.log.Debugf("context is done, stop reconnect loop") + return + } + + err := g.handshaker.SendOffer() + if err != nil { + g.log.Errorf("failed to do handshake: %v", err) + } + } +} + +func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { + bo := backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 800 * time.Millisecond, + RandomizationFactor: 0.1, + Multiplier: 2, + MaxInterval: g.timeout, + MaxElapsedTime: reconnectMaxElapsedTime, + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) + + ticker := backoff.NewTicker(bo) + <-ticker.C // consume the initial tick what is happening right after the ticker has been created + + return ticker +} + +func (g *Guard) waitInitialRandomSleepTime(ctx context.Context) { + minWait := 100 + maxWait := 800 + duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond + + timeout := time.NewTimer(duration) + defer timeout.Stop() + + select { + case <-ctx.Done(): + case <-timeout.C: + } +} + +func (g *Guard) logTraceConnState() { + //todo: implement me + /* + if g.workerRelay.IsRelayConnectionSupportedWithPeer() { + if g.statusRelay.Get() == StatusDisconnected || g.statusICE.Get() == StatusDisconnected { + g.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", g.statusRelay, g.statusICE) + } + } else { + if g.statusICE.Get() == StatusDisconnected { + g.log.Tracef("connectivity guard timedout, ice state: %s", g.statusICE) + } + } + + */ +} diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go new file mode 100644 index 00000000000..50bdfe4e53b --- /dev/null +++ b/client/internal/peer/guard/ice_monitor.go @@ -0,0 +1,146 @@ +package guard + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/pion/ice/v3" + log "github.com/sirupsen/logrus" + + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +const ( + candidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second +) + +type ICEMonitor struct { + ReconnectCh chan struct{} + + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig icemaker.Config + + currentCandidates []ice.Candidate + candidatesMu sync.Mutex +} + +func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { + cm := &ICEMonitor{ + ReconnectCh: make(chan struct{}, 1), + iFaceDiscover: iFaceDiscover, + iceConfig: config, + } + return cm +} + +func (cm *ICEMonitor) Start(ctx context.Context) { + go cm.monitorLocalCandidatesChanged(ctx) +} + +func (cm *ICEMonitor) monitorLocalCandidatesChanged(ctx context.Context) { + ufrag, pwd, err := icemaker.GenerateICECredentials() + if err != nil { + log.Warnf("Failed to generate ICE credentials: %v", err) + return + } + + ticker := time.NewTicker(candidatesMonitorPeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + changed, err := cm.handleCandidateTick(ctx, ufrag, pwd) + if err != nil { + log.Warnf("Failed to handle candidate tick: %v", err) + continue + } + + if changed { + cm.triggerReconnect() + } + case <-ctx.Done(): + return + } + } +} + +func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) { + log.Debugf("Gathering ICE candidates") + + agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) + if err != nil { + return false, fmt.Errorf("create ICE agent: %w", err) + } + defer func() { + if err := agent.Close(); err != nil { + log.Warnf("Failed to close ICE agent: %v", err) + } + }() + + gatherDone := make(chan struct{}) + err = agent.OnCandidate(func(c ice.Candidate) { + log.Tracef("Got candidate: %v", c) + if c == nil { + close(gatherDone) + } + }) + if err != nil { + return false, fmt.Errorf("set ICE candidate handler: %w", err) + } + + if err := agent.GatherCandidates(); err != nil { + return false, fmt.Errorf("gather ICE candidates: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) + defer cancel() + + select { + case <-ctx.Done(): + return false, fmt.Errorf("wait for gathering: %w", ctx.Err()) + case <-gatherDone: + } + + candidates, err := agent.GetLocalCandidates() + if err != nil { + return false, fmt.Errorf("get local candidates: %w", err) + } + log.Tracef("Got candidates: %v", candidates) + + return cm.updateCandidates(candidates), nil +} + +func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool { + cm.candidatesMu.Lock() + defer cm.candidatesMu.Unlock() + + if len(cm.currentCandidates) != len(newCandidates) { + cm.currentCandidates = newCandidates + return true + } + + for i, candidate := range cm.currentCandidates { + if candidate.Address() != newCandidates[i].Address() { + cm.currentCandidates = newCandidates + return true + } + } + + return false +} + +func (cm *ICEMonitor) triggerReconnect() { + select { + case cm.ReconnectCh <- struct{}{}: + default: + } +} + +func candidateTypesP2P() []ice.CandidateType { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} +} diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go new file mode 100644 index 00000000000..0d5e7e2db7b --- /dev/null +++ b/client/internal/peer/guard/sr_watcher.go @@ -0,0 +1,84 @@ +package guard + +import ( + "context" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +type chNotifier interface { + SetOnReconnectedListener(func()) + Ready() bool +} + +type SRWatcher struct { + signalClient chNotifier + relayManager chNotifier + + listeners map[chan struct{}]struct{} + mu sync.Mutex + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig ice.Config +} + +// NewSRWatcher todo: implement cancel function in thread safe way. The context cancle is dangerous because during an +// engine restart maybe we overwrite the new listeners in signal and relayManager +func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscover stdnet.ExternalIFaceDiscover, iceConfig ice.Config) *SRWatcher { + srw := &SRWatcher{ + signalClient: signalClient, + relayManager: relayManager, + iFaceDiscover: iFaceDiscover, + iceConfig: iceConfig, + } + return srw +} + +func (w *SRWatcher) Start(ctx context.Context) { + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig) + go iceMonitor.Start(ctx) + // todo read iceMonitor.ReconnectCh + + w.signalClient.SetOnReconnectedListener(w.onReconnected) + w.relayManager.SetOnReconnectedListener(w.onReconnected) +} + +func (w *SRWatcher) onReconnected() { + if !w.signalClient.Ready() { + return + } + if w.relayManager.Ready() { + return + } + w.notify() +} + +func (w *SRWatcher) NewListener() chan struct{} { + w.mu.Lock() + defer w.mu.Unlock() + + listenerChan := make(chan struct{}, 1) + w.listeners[listenerChan] = struct{}{} + return listenerChan +} + +func (w *SRWatcher) RemoveListener(listenerChan chan struct{}) { + w.mu.Lock() + defer w.mu.Unlock() + delete(w.listeners, listenerChan) + close(listenerChan) +} + +func (w *SRWatcher) notify() { + log.Infof("------ Siganl or relay reconnected!") + w.mu.Lock() + defer w.mu.Unlock() + for listener := range w.listeners { + select { + case listener <- struct{}{}: + } + } +} diff --git a/client/internal/peer/guard/stdnet.go b/client/internal/peer/guard/stdnet.go new file mode 100644 index 00000000000..a91730f42d3 --- /dev/null +++ b/client/internal/peer/guard/stdnet.go @@ -0,0 +1,11 @@ +//go:build !android + +package guard + +import ( + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNet(ifaceBlacklist) +} diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/guard/stdnet_android.go similarity index 100% rename from client/internal/peer/stdnet_android.go rename to client/internal/peer/guard/stdnet_android.go diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go new file mode 100644 index 00000000000..b2a9669367e --- /dev/null +++ b/client/internal/peer/ice/agent.go @@ -0,0 +1,89 @@ +package ice + +import ( + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/pion/ice/v3" + "github.com/pion/randutil" + "github.com/pion/stun/v2" + log "github.com/sirupsen/logrus" + "runtime" + "time" +) + +const ( + lenUFrag = 16 + lenPwd = 32 + runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + iceKeepAliveDefault = 4 * time.Second + iceDisconnectedTimeoutDefault = 6 * time.Second + // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package + iceRelayAcceptanceMinWaitDefault = 2 * time.Second +) + +var ( + failedTimeout = 6 * time.Second +) + +func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { + iceKeepAlive := iceKeepAlive() + iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() + + transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList) + if err != nil { + log.Errorf("failed to create pion's stdnet: %s", err) + } + + agentConfig := &ice.AgentConfig{ + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: config.StunTurn.Load().([]*stun.URI), + CandidateTypes: candidateTypes, + InterfaceFilter: stdnet.InterfaceFilter(config.InterfaceBlackList), + UDPMux: config.UDPMux, + UDPMuxSrflx: config.UDPMuxSrflx, + NAT1To1IPs: config.NATExternalIPs, + Net: transportNet, + FailedTimeout: &failedTimeout, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, + LocalUfrag: ufrag, + LocalPwd: pwd, + } + + if config.DisableIPv6Discovery { + agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} + } + + return ice.NewAgent(agentConfig) +} + +func GenerateICECredentials() (string, string, error) { + ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) + if err != nil { + return "", "", err + } + + pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha) + if err != nil { + return "", "", err + } + return ufrag, pwd, nil +} + +func CandidateTypes() []ice.CandidateType { + if hasICEForceRelayConn() { + return []ice.CandidateType{ice.CandidateTypeRelay} + } + // TODO: remove this once we have refactored userspace proxy into the bind package + if runtime.GOOS == "ios" { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} + } + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} +} + +func CandidateTypesP2P() []ice.CandidateType { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} +} diff --git a/client/internal/peer/ice/config.go b/client/internal/peer/ice/config.go new file mode 100644 index 00000000000..8abc842f0d2 --- /dev/null +++ b/client/internal/peer/ice/config.go @@ -0,0 +1,22 @@ +package ice + +import ( + "sync/atomic" + + "github.com/pion/ice/v3" +) + +type Config struct { + // StunTurn is a list of STUN and TURN URLs + StunTurn *atomic.Value // []*stun.URI + + // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering + // (e.g. if eth0 is in the list, host candidate of this interface won't be used) + InterfaceBlackList []string + DisableIPv6Discovery bool + + UDPMux ice.UDPMux + UDPMuxSrflx ice.UniversalUDPMux + + NATExternalIPs []string +} diff --git a/client/internal/peer/env_config.go b/client/internal/peer/ice/env.go similarity index 99% rename from client/internal/peer/env_config.go rename to client/internal/peer/ice/env.go index 87b626df763..5a3733d8c8b 100644 --- a/client/internal/peer/env_config.go +++ b/client/internal/peer/ice/env.go @@ -1,4 +1,4 @@ -package peer +package ice import ( "os" @@ -10,12 +10,17 @@ import ( ) const ( + envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC" - envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" ) +func hasICEForceRelayConn() bool { + disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn) + return strings.ToLower(disconnectedTimeoutEnv) == "true" +} + func iceKeepAlive() time.Duration { keepAliveEnv := os.Getenv(envICEKeepAliveIntervalSec) if keepAliveEnv == "" { @@ -63,8 +68,3 @@ func iceRelayAcceptanceMinWait() time.Duration { return time.Duration(disconnectedTimeoutSec) * time.Second } - -func hasICEForceRelayConn() bool { - disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn) - return strings.ToLower(disconnectedTimeoutEnv) == "true" -} diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/ice/stdnet.go similarity index 94% rename from client/internal/peer/stdnet.go rename to client/internal/peer/ice/stdnet.go index 96d211dbc77..3ce83727e6e 100644 --- a/client/internal/peer/stdnet.go +++ b/client/internal/peer/ice/stdnet.go @@ -1,6 +1,6 @@ //go:build !android -package peer +package ice import ( "github.com/netbirdio/netbird/client/internal/stdnet" diff --git a/client/internal/peer/ice/stdnet_android.go b/client/internal/peer/ice/stdnet_android.go new file mode 100644 index 00000000000..84c665e6f40 --- /dev/null +++ b/client/internal/peer/ice/stdnet_android.go @@ -0,0 +1,7 @@ +package ice + +import "github.com/netbirdio/netbird/client/internal/stdnet" + +func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist) +} diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index c86c1858fdc..55894218d73 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -5,52 +5,20 @@ import ( "fmt" "net" "net/netip" - "runtime" "sync" - "sync/atomic" "time" "github.com/pion/ice/v3" - "github.com/pion/randutil" "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/bind" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" ) -const ( - iceKeepAliveDefault = 4 * time.Second - iceDisconnectedTimeoutDefault = 6 * time.Second - // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package - iceRelayAcceptanceMinWaitDefault = 2 * time.Second - - lenUFrag = 16 - lenPwd = 32 - runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -) - -var ( - failedTimeout = 6 * time.Second -) - -type ICEConfig struct { - // StunTurn is a list of STUN and TURN URLs - StunTurn *atomic.Value // []*stun.URI - - // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering - // (e.g. if eth0 is in the list, host candidate of this interface won't be used) - InterfaceBlackList []string - DisableIPv6Discovery bool - - UDPMux ice.UDPMux - UDPMuxSrflx ice.UniversalUDPMux - - NATExternalIPs []string -} - type ICEConnInfo struct { RemoteConn net.Conn RosenpassPubKey []byte @@ -103,7 +71,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signal conn: callBacks, } - localUfrag, localPwd, err := generateICECredentials() + localUfrag, localPwd, err := icemaker.GenerateICECredentials() if err != nil { return nil, err } @@ -125,10 +93,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { var preferredCandidateTypes []ice.CandidateType if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { w.selectedPriority = connPriorityICEP2P - preferredCandidateTypes = candidateTypesP2P() + preferredCandidateTypes = icemaker.CandidateTypesP2P() } else { w.selectedPriority = connPriorityICETurn - preferredCandidateTypes = candidateTypes() + preferredCandidateTypes = icemaker.CandidateTypes() } w.log.Debugf("recreate ICE agent") @@ -232,15 +200,10 @@ func (w *WorkerICE) Close() { } } -func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { - transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) - if err != nil { - w.log.Errorf("failed to create pion's stdnet: %s", err) - } - +func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) { w.sentExtraSrflx = false - agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd) + agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) } @@ -365,36 +328,6 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA } } -func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { - iceKeepAlive := iceKeepAlive() - iceDisconnectedTimeout := iceDisconnectedTimeout() - iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - - agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI), - CandidateTypes: candidateTypes, - InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList), - UDPMux: config.ICEConfig.UDPMux, - UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx, - NAT1To1IPs: config.ICEConfig.NATExternalIPs, - Net: transportNet, - FailedTimeout: &failedTimeout, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, - RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, - LocalUfrag: ufrag, - LocalPwd: pwd, - } - - if config.ICEConfig.DisableIPv6Discovery { - agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} - } - - return ice.NewAgent(agentConfig) -} - func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ @@ -435,21 +368,6 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool return false } -func candidateTypes() []ice.CandidateType { - if hasICEForceRelayConn() { - return []ice.CandidateType{ice.CandidateTypeRelay} - } - // TODO: remove this once we have refactored userspace proxy into the bind package - if runtime.GOOS == "ios" { - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} - } - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} -} - -func candidateTypesP2P() []ice.CandidateType { - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} -} - func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } @@ -460,16 +378,3 @@ func isRelayed(pair *ice.CandidatePair) bool { } return false } - -func generateICECredentials() (string, string, error) { - ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) - if err != nil { - return "", "", err - } - - pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha) - if err != nil { - return "", "", err - } - return ufrag, pwd, nil -} diff --git a/relay/client/client.go b/relay/client/client.go index 90bc3ac418f..bcdf7a99504 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -142,6 +142,7 @@ type Client struct { muInstanceURL sync.Mutex onDisconnectListener func() + onConnectedListener func() listenerMutex sync.Mutex } @@ -190,6 +191,7 @@ func (c *Client) Connect() error { go c.readLoop(c.relayConn) c.log.Infof("relay connection established") + go c.notifyConnected() return nil } @@ -236,6 +238,12 @@ func (c *Client) SetOnDisconnectListener(fn func()) { c.onDisconnectListener = fn } +func (c *Client) SetOnConnectedListener(fn func()) { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + c.onConnectedListener = fn +} + // HasConns returns true if there are connections. func (c *Client) HasConns() bool { c.mu.Lock() @@ -243,6 +251,12 @@ func (c *Client) HasConns() bool { return len(c.conns) > 0 } +func (c *Client) Ready() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.serviceIsRunning +} + // Close closes the connection to the relay server and all connections to other peers. func (c *Client) Close() error { return c.close(true) @@ -361,9 +375,9 @@ func (c *Client) readLoop(relayConn net.Conn) { c.instanceURL = nil c.muInstanceURL.Unlock() - c.notifyDisconnected() c.wgReadLoop.Done() _ = c.close(false) + c.notifyDisconnected() } func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) { @@ -542,6 +556,16 @@ func (c *Client) notifyDisconnected() { go c.onDisconnectListener() } +func (c *Client) notifyConnected() { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + + if c.onConnectedListener == nil { + return + } + go c.onConnectedListener() +} + func (c *Client) writeCloseMsg() { msg := messages.MarshalCloseMsg() _, err := c.relayConn.Write(msg) diff --git a/relay/client/manager.go b/relay/client/manager.go index 4554c7c0f6e..3981415fcd4 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -65,6 +65,7 @@ type Manager struct { relayClientsMutex sync.RWMutex onDisconnectedListeners map[string]*list.List + onReconnectedListenerFn func() listenerLock sync.Mutex } @@ -101,6 +102,7 @@ func (m *Manager) Serve() error { m.relayClient = client m.reconnectGuard = NewGuard(m.ctx, m.relayClient) + m.relayClient.SetOnConnectedListener(m.onServerConnected) m.relayClient.SetOnDisconnectListener(func() { m.onServerDisconnected(client.connectionURL) }) @@ -138,6 +140,18 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { return netConn, err } +// Ready returns true if the home Relay client is connected to the relay server. +func (m *Manager) Ready() bool { + if m.relayClient == nil { + return false + } + return m.relayClient.Ready() +} + +func (m *Manager) SetOnReconnectedListener(f func()) { + m.onReconnectedListenerFn = f +} + // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // closed. func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { @@ -240,6 +254,13 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { return conn, nil } +func (m *Manager) onServerConnected() { + if m.onReconnectedListenerFn == nil { + return + } + go m.onReconnectedListenerFn() +} + func (m *Manager) onServerDisconnected(serverAddress string) { if serverAddress == m.relayClient.connectionURL { go m.reconnectGuard.OnDisconnected() diff --git a/signal/client/client.go b/signal/client/client.go index ced3fb7d0eb..eff1ccb8794 100644 --- a/signal/client/client.go +++ b/signal/client/client.go @@ -35,6 +35,7 @@ type Client interface { WaitStreamConnected() SendToStream(msg *proto.EncryptedMessage) error Send(msg *proto.Message) error + SetOnReconnectedListener(func()) } // UnMarshalCredential parses the credentials from the message and returns a Credential instance diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7a3b502ffc6..2ff84e46075 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -43,6 +43,8 @@ type GrpcClient struct { connStateCallback ConnStateNotifier connStateCallbackLock sync.RWMutex + + onReconnectedListenerFn func() } func (c *GrpcClient) StreamConnected() bool { @@ -181,12 +183,17 @@ func (c *GrpcClient) notifyStreamDisconnected() { func (c *GrpcClient) notifyStreamConnected() { c.mux.Lock() defer c.mux.Unlock() + c.status = StreamConnected if c.connectedCh != nil { // there are goroutines waiting on this channel -> release them close(c.connectedCh) c.connectedCh = nil } + + if c.onReconnectedListenerFn != nil { + c.onReconnectedListenerFn() + } } func (c *GrpcClient) getStreamStatusChan() <-chan struct{} { @@ -271,6 +278,13 @@ func (c *GrpcClient) WaitStreamConnected() { } } +func (c *GrpcClient) SetOnReconnectedListener(fn func()) { + c.mux.Lock() + defer c.mux.Unlock() + + c.onReconnectedListenerFn = fn +} + // SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server // The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange // GrpcClient.connWg can be used to wait