diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 0ba012dcc1..d17cb923db 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -1,10 +1,14 @@ package app2 import ( + "encoding/binary" "net" + "github.com/hashicorp/yamux" + "github.com/pkg/errors" + "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/dmsg/cipher" @@ -14,29 +18,15 @@ var ( ErrWrongHSFrameTypeReceived = errors.New("received wrong HS frame type") ) -type Listener struct { - conn net.Conn -} - -func (l *Listener) Accept() (net.Conn, error) { - hsFrame, err := readHSFrame(l.conn) - if err != nil { - return nil, errors.Wrap(err, "error reading HS frame") - } - - if hsFrame.FrameType() != HSFrameTypeDMSGAccept { - return nil, ErrWrongHSFrameTypeReceived - } - - return l.conn, nil -} - // Client is used by skywire apps. type Client struct { PK cipher.PubKey pid ProcID sockAddr string conn net.Conn + session *yamux.Session + logger *logging.Logger + lm *listenersManager } // NewClient creates a new Client. The Client needs to be provided with: @@ -44,17 +34,30 @@ type Client struct { // - pid: The procID assigned for the process that Client is being used by. // - sockAddr: The socket address to connect to Server. func NewClient(localPK cipher.PubKey, pid ProcID, sockAddr string) (*Client, error) { + conn, err := net.Dial("unix", sockAddr) + if err != nil { + return nil, errors.Wrap(err, "error connecting app server") + } + + session, err := yamux.Client(conn, nil) + if err != nil { + return nil, errors.Wrap(err, "error opening yamux session") + } + return &Client{ PK: localPK, pid: pid, sockAddr: sockAddr, + conn: conn, + session: session, + lm: newListenersManager(), }, nil } func (c *Client) Dial(addr routing.Addr) (net.Conn, error) { - conn, err := net.Dial("unix", c.sockAddr) + stream, err := c.session.Open() if err != nil { - return nil, errors.Wrap(err, "error connecting app server") + return nil, errors.Wrap(err, "error opening stream") } hsFrame := NewHSFrameDSMGDial(c.pid, routing.Loop{ @@ -63,11 +66,12 @@ func (c *Client) Dial(addr routing.Addr) (net.Conn, error) { }, Remote: addr, }) - if _, err := conn.Write(hsFrame); err != nil { + + if _, err := stream.Write(hsFrame); err != nil { return nil, errors.Wrap(err, "error writing HS frame") } - hsFrame, err = readHSFrame(conn) + hsFrame, err = readHSFrame(stream) if err != nil { return nil, errors.Wrap(err, "error reading HS frame") } @@ -76,21 +80,30 @@ func (c *Client) Dial(addr routing.Addr) (net.Conn, error) { return nil, ErrWrongHSFrameTypeReceived } - return conn, nil + return stream, nil } -func (c *Client) Listen(addr routing.Addr) (*Listener, error) { - conn, err := net.Dial("unix", c.sockAddr) +func (c *Client) Listen(port routing.Port) (*Listener, error) { + if c.lm.portIsBound(port) { + return nil, ErrPortAlreadyBound + } + + stream, err := c.session.Open() if err != nil { - return nil, errors.Wrap(err, "error connecting app server") + return nil, errors.Wrap(err, "error opening stream") + } + + addr := routing.Addr{ + PubKey: c.PK, + Port: port, } hsFrame := NewHSFrameDMSGListen(c.pid, addr) - if _, err := conn.Write(hsFrame); err != nil { + if _, err := stream.Write(hsFrame); err != nil { return nil, errors.Wrap(err, "error writing HS frame") } - hsFrame, err = readHSFrame(conn) + hsFrame, err = readHSFrame(stream) if err != nil { return nil, errors.Wrap(err, "error reading HS frame") } @@ -99,7 +112,37 @@ func (c *Client) Listen(addr routing.Addr) (*Listener, error) { return nil, ErrWrongHSFrameTypeReceived } - return &Listener{ - conn: conn, - }, nil + l := NewListener(addr, c.lm) + if err := c.lm.add(port, l); err != nil { + return nil, err + } + + return l, nil +} + +func (c *Client) listen() error { + for { + stream, err := c.session.Accept() + if err != nil { + return errors.Wrap(err, "error accepting stream") + } + + hsFrame, err := readHSFrame(stream) + if err != nil { + c.logger.WithError(err).Error("error reading HS frame") + continue + } + + if hsFrame.FrameType() != HSFrameTypeDMSGDial { + c.logger.WithError(ErrWrongHSFrameTypeReceived).Error("on listening for Dial") + continue + } + + // TODO: handle field get gracefully + port := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) + if err := c.lm.addConn(port, stream); err != nil { + c.logger.WithError(err).Error("failed to accept") + continue + } + } } diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go new file mode 100644 index 0000000000..069fcdb0a9 --- /dev/null +++ b/pkg/app2/listener.go @@ -0,0 +1,58 @@ +package app2 + +import ( + "errors" + "net" + + "github.com/skycoin/skywire/pkg/routing" +) + +const ( + listenerBufSize = 1000 +) + +var ( + ErrListenerClosed = errors.New("listener closed") +) + +type Listener struct { + addr routing.Addr + conns chan net.Conn + lm *listenersManager +} + +func NewListener(addr routing.Addr, lm *listenersManager) *Listener { + return &Listener{ + addr: addr, + conns: make(chan net.Conn, listenerBufSize), + lm: lm, + } +} + +func (l *Listener) Accept() (net.Conn, error) { + conn, ok := <-l.conns + if !ok { + return nil, ErrListenerClosed + } + + return conn, nil +} + +func (l *Listener) Close() error { + if err := l.lm.remove(l.addr.Port); err != nil { + return err + } + + // TODO: send ListenEnd frame + close(l.conns) + + return nil +} + +func (l *Listener) Addr() net.Addr { + return l.addr +} + +func (l *Listener) addConn(conn net.Conn) { + l.conns <- conn +} diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go new file mode 100644 index 0000000000..b5bc8e3b83 --- /dev/null +++ b/pkg/app2/listeners_manager.go @@ -0,0 +1,65 @@ +package app2 + +import ( + "net" + "sync" + + "github.com/pkg/errors" + "github.com/skycoin/skywire/pkg/routing" +) + +var ( + ErrPortAlreadyBound = errors.New("port is already bound") + ErrNoListenerOnPort = errors.New("no listener on port") +) + +type listenersManager struct { + listeners map[routing.Port]*Listener + mx sync.RWMutex +} + +func newListenersManager() *listenersManager { + return &listenersManager{ + listeners: make(map[routing.Port]*Listener), + } +} + +func (lm *listenersManager) portIsBound(port routing.Port) bool { + lm.mx.RLock() + _, ok := lm.listeners[port] + lm.mx.RUnlock() + return ok +} + +func (lm *listenersManager) add(port routing.Port, l *Listener) error { + lm.mx.Lock() + if _, ok := lm.listeners[port]; ok { + lm.mx.Unlock() + return ErrPortAlreadyBound + } + lm.listeners[port] = l + lm.mx.Unlock() + return nil +} + +func (lm *listenersManager) remove(port routing.Port) error { + lm.mx.Lock() + if _, ok := lm.listeners[port]; !ok { + lm.mx.Unlock() + return ErrNoListenerOnPort + } + delete(lm.listeners, port) + lm.mx.Unlock() + return nil +} + +func (lm *listenersManager) addConn(port routing.Port, conn net.Conn) error { + lm.mx.RLock() + if _, ok := lm.listeners[port]; !ok { + lm.mx.RUnlock() + return ErrNoListenerOnPort + } + lm.listeners[port].addConn(conn) + lm.mx.RUnlock() + return nil +}