diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 4e2e5d754d..8b70f07fdc 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -146,11 +146,32 @@ func (c *Client) listen() error { } // TODO: handle field get gracefully - port := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) - if err := c.lm.addConn(port, stream); err != nil { + remotePort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen*2+HSFramePortLen:])) + if err := c.lm.addConn(remotePort, stream); err != nil { c.logger.WithError(err).Error("failed to accept") continue } + + localPort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) + + var localPK cipher.PubKey + copy(localPK[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) + + respHSFrame := NewHSFrameDMSGAccept(c.pid, routing.Loop{ + Local: routing.Addr{ + PubKey: c.PK, + Port: remotePort, + }, + Remote: routing.Addr{ + PubKey: localPK, + Port: localPort, + }, + }) + + if _, err := stream.Write(respHSFrame); err != nil { + c.logger.WithError(err).Error("error responding with DmsgAccept") + continue + } } } diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index 1a5c14b6ae..98f997411e 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -17,21 +17,33 @@ var ( ErrListenerClosed = errors.New("listener closed") ) +type acceptedConn struct { + remote routing.Addr + net.Conn +} + +func (c *acceptedConn) Addr() net.Addr { + return c.remote +} + type Listener struct { addr routing.Addr - conns chan net.Conn + conns chan *acceptedConn stopListening func(port routing.Port) error logger *logging.Logger lm *listenersManager + procID ProcID } -func NewListener(addr routing.Addr, lm *listenersManager, stopListening func(port routing.Port) error, l *logging.Logger) *Listener { +func NewListener(addr routing.Addr, lm *listenersManager, procID ProcID, + stopListening func(port routing.Port) error, l *logging.Logger) *Listener { return &Listener{ addr: addr, - conns: make(chan net.Conn, listenerBufSize), + conns: make(chan *acceptedConn, listenerBufSize), lm: lm, stopListening: stopListening, logger: l, + procID: procID, } } @@ -41,6 +53,15 @@ func (l *Listener) Accept() (net.Conn, error) { return nil, ErrListenerClosed } + hsFrame := NewHSFrameDMSGAccept(l.procID, routing.Loop{ + Local: l.addr, + Remote: conn.remote, + }) + + if _, err := conn.Write(hsFrame); err != nil { + return nil, err + } + return conn, nil } @@ -62,6 +83,6 @@ func (l *Listener) Addr() net.Addr { return l.addr } -func (l *Listener) addConn(conn net.Conn) { +func (l *Listener) addConn(conn *acceptedConn) { l.conns <- conn } diff --git a/pkg/app2/server.go b/pkg/app2/server.go index 1c24fd22c4..6b46889ed6 100644 --- a/pkg/app2/server.go +++ b/pkg/app2/server.go @@ -22,6 +22,7 @@ import ( ) type clientConn struct { + procID ProcID conn net.Conn session *yamux.Session lm *listenersManager @@ -122,6 +123,9 @@ func (s *Server) serveStream(stream net.Conn, conn *clientConn) error { return errors.Wrap(err, "error reading HS frame") } + // TODO: ensure thread-safety + conn.procID = hsFrame.ProcID() + var respHSFrame HSFrame switch hsFrame.FrameType() { case HSFrameTypeDMSGListen: @@ -223,42 +227,40 @@ func (c *clientConn) addListener(port routing.Port, l *dmsg.Listener) error { return ErrPortAlreadyBound } c.dmsgListeners[port] = l + go c.acceptDMSG(l) c.dmsgListenersMx.Unlock() return nil } -func (s *Server) handleDMSGListen(frame HSFrame) error { - var local routing.Addr - if err := frame.UnmarshalBody(&local); err != nil { - return errors.Wrap(err, "invalid JSON body") - } +func (c *clientConn) acceptDMSG(l *dmsg.Listener) error { + for { + stream, err := c.session.Open() + if err != nil { + return errors.Wrap(err, "error opening yamux stream") + } - // TODO: check `local` for validity + remoteAddr, ok := l.Addr().(dmsg.Addr) + if !ok { + // shouldn't happen, but still + return errors.Wrap(err, "wrong type for DMSG addr") + } - dmsgL, err := s.dmsgC.Listen(uint16(local.Port)) - if err != nil { - return fmt.Errorf("error listening on port %d: %v", local.Port, err) - } + hsFrame := NewHSFrameDSMGDial(c.procID, routing.Loop{ + Local: routing.Addr{ + PubKey: remoteAddr.PK, + Port: routing.Port(remoteAddr.Port), + }, + // TODO: get local addr + Remote: routing.Addr{ + PubKey: + }, + }) -} - -func (s *Server) handleDMSGListening(frame HSFrame) error { - var local routing.Addr - if err := frame.UnmarshalBody(&local); err != nil { - return errors.Wrap(err, "invalid JSON body") - } -} + conn, err := l.Accept() + if err != nil { + return errors.Wrap(err, "error accepting DMSG conn") + } -func (s *Server) handleDMSGDial(frame HSFrame) error { - var loop routing.Loop - if err := frame.UnmarshalBody(&loop); err != nil { - return errors.Wrap(err, "invalid JSON body") - } -} -func (s *Server) handleDMSGAccept(frame HSFrame) error { - var loop routing.Loop - if err := frame.UnmarshalBody(&loop); err != nil { - return errors.Wrap(err, "invalid JSON body") } }