diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go index 251c8b2af0..a1c9925357 100644 --- a/pkg/app2/listeners_manager.go +++ b/pkg/app2/listeners_manager.go @@ -1,9 +1,14 @@ package app2 import ( + "encoding/binary" "net" "sync" + "sync/atomic" + "github.com/hashicorp/yamux" + + "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" "github.com/pkg/errors" @@ -14,20 +19,35 @@ var ( ErrPortAlreadyBound = errors.New("port is already bound") ErrNoListenerOnPort = errors.New("no listener on port") ErrListenersManagerAlreadyServing = errors.New("listeners manager already serving") + ErrWrongPID = errors.New("wrong ProcID specified in the HS frame") ) type listenersManager struct { - listeners map[routing.Port]*Listener - mx sync.RWMutex - isServing int32 + pid ProcID + pk cipher.PubKey + listeners map[routing.Port]*Listener + mx sync.RWMutex + isListening int32 + logger *logging.Logger + doneCh chan struct{} + doneWg sync.WaitGroup } -func newListenersManager() *listenersManager { +func newListenersManager(l *logging.Logger, pid ProcID, pk cipher.PubKey) *listenersManager { return &listenersManager{ + pid: pid, + pk: pk, listeners: make(map[routing.Port]*Listener), + logger: l, + doneCh: make(chan struct{}), } } +func (lm *listenersManager) close() { + close(lm.doneCh) + lm.doneWg.Wait() +} + func (lm *listenersManager) portIsBound(port routing.Port) bool { lm.mx.RLock() _, ok := lm.listeners[port] @@ -41,7 +61,7 @@ func (lm *listenersManager) add(addr routing.Addr, stopListening func(port routi lm.mx.Unlock() return nil, ErrPortAlreadyBound } - l := NewListener(addr, lm, stopListening, logger) + l := NewListener(addr, lm, lm.pid, stopListening, logger) lm.listeners[addr.Port] = l lm.mx.Unlock() return l, nil @@ -58,13 +78,89 @@ func (lm *listenersManager) remove(port routing.Port) error { return nil } -func (lm *listenersManager) addConn(port routing.Port, conn net.Conn) error { +func (lm *listenersManager) addConn(localPort routing.Port, remote routing.Addr, conn net.Conn) error { lm.mx.RLock() - if _, ok := lm.listeners[port]; !ok { + if _, ok := lm.listeners[localPort]; !ok { lm.mx.RUnlock() return ErrNoListenerOnPort } - lm.listeners[port].addConn(conn) + lm.listeners[localPort].addConn(&acceptedConn{ + remote: remote, + Conn: conn, + }) lm.mx.RUnlock() return nil } + +func (lm *listenersManager) listen(session *yamux.Session) { + // this one should only start once + if !atomic.CompareAndSwapInt32(&lm.isListening, 0, 1) { + return + } + + lm.doneWg.Add(1) + + go func() { + defer lm.doneWg.Done() + + for { + select { + case <-lm.doneCh: + return + default: + stream, err := session.Accept() + if err != nil { + lm.logger.WithError(err).Error("error accepting stream") + return + } + + hsFrame, err := readHSFrame(stream) + if err != nil { + lm.logger.WithError(err).Error("error reading HS frame") + continue + } + + if hsFrame.ProcID() != lm.pid { + lm.logger.WithError(ErrWrongPID).Error("error listening for Dial") + } + + if hsFrame.FrameType() != HSFrameTypeDMSGDial { + lm.logger.WithError(ErrWrongHSFrameTypeReceived).Error("error listening for Dial") + continue + } + + // TODO: handle field get gracefully + remotePort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen*2+HSFramePortLen:])) + localPort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) + + var localPK cipher.PubKey + copy(localPK[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) + + err = lm.addConn(remotePort, routing.Addr{ + PubKey: localPK, + Port: localPort, + }, stream) + if err != nil { + lm.logger.WithError(err).Error("failed to accept") + continue + } + + respHSFrame := NewHSFrameDMSGAccept(hsFrame.ProcID(), routing.Loop{ + Local: routing.Addr{ + PubKey: lm.pk, + Port: remotePort, + }, + Remote: routing.Addr{ + PubKey: localPK, + Port: localPort, + }, + }) + + if _, err := stream.Write(respHSFrame); err != nil { + lm.logger.WithError(err).Error("error responding with DmsgAccept") + continue + } + } + } + }() +}