Skip to content

Commit

Permalink
Add yamux conn multiplexing
Browse files Browse the repository at this point in the history
  • Loading branch information
Darkren committed Sep 6, 2019
1 parent 804b770 commit 5029589
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 30 deletions.
103 changes: 73 additions & 30 deletions pkg/app2/client.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package app2

import (
"encoding/binary"
"net"

"github.com/hashicorp/yamux"

"github.com/pkg/errors"

"github.com/skycoin/skycoin/src/util/logging"
"github.com/skycoin/skywire/pkg/routing"

"github.com/skycoin/dmsg/cipher"
Expand All @@ -14,47 +18,46 @@ var (
ErrWrongHSFrameTypeReceived = errors.New("received wrong HS frame type")
)

type Listener struct {
conn net.Conn
}

func (l *Listener) Accept() (net.Conn, error) {
hsFrame, err := readHSFrame(l.conn)
if err != nil {
return nil, errors.Wrap(err, "error reading HS frame")
}

if hsFrame.FrameType() != HSFrameTypeDMSGAccept {
return nil, ErrWrongHSFrameTypeReceived
}

return l.conn, nil
}

// Client is used by skywire apps.
type Client struct {
PK cipher.PubKey
pid ProcID
sockAddr string
conn net.Conn
session *yamux.Session
logger *logging.Logger
lm *listenersManager
}

// NewClient creates a new Client. The Client needs to be provided with:
// - localPK: The local public key of the parent skywire visor.
// - pid: The procID assigned for the process that Client is being used by.
// - sockAddr: The socket address to connect to Server.
func NewClient(localPK cipher.PubKey, pid ProcID, sockAddr string) (*Client, error) {
conn, err := net.Dial("unix", sockAddr)
if err != nil {
return nil, errors.Wrap(err, "error connecting app server")
}

session, err := yamux.Client(conn, nil)
if err != nil {
return nil, errors.Wrap(err, "error opening yamux session")
}

return &Client{
PK: localPK,
pid: pid,
sockAddr: sockAddr,
conn: conn,
session: session,
lm: newListenersManager(),
}, nil
}

func (c *Client) Dial(addr routing.Addr) (net.Conn, error) {
conn, err := net.Dial("unix", c.sockAddr)
stream, err := c.session.Open()
if err != nil {
return nil, errors.Wrap(err, "error connecting app server")
return nil, errors.Wrap(err, "error opening stream")
}

hsFrame := NewHSFrameDSMGDial(c.pid, routing.Loop{
Expand All @@ -63,11 +66,12 @@ func (c *Client) Dial(addr routing.Addr) (net.Conn, error) {
},
Remote: addr,
})
if _, err := conn.Write(hsFrame); err != nil {

if _, err := stream.Write(hsFrame); err != nil {
return nil, errors.Wrap(err, "error writing HS frame")
}

hsFrame, err = readHSFrame(conn)
hsFrame, err = readHSFrame(stream)
if err != nil {
return nil, errors.Wrap(err, "error reading HS frame")
}
Expand All @@ -76,21 +80,30 @@ func (c *Client) Dial(addr routing.Addr) (net.Conn, error) {
return nil, ErrWrongHSFrameTypeReceived
}

return conn, nil
return stream, nil
}

func (c *Client) Listen(addr routing.Addr) (*Listener, error) {
conn, err := net.Dial("unix", c.sockAddr)
func (c *Client) Listen(port routing.Port) (*Listener, error) {
if c.lm.portIsBound(port) {
return nil, ErrPortAlreadyBound
}

stream, err := c.session.Open()
if err != nil {
return nil, errors.Wrap(err, "error connecting app server")
return nil, errors.Wrap(err, "error opening stream")
}

addr := routing.Addr{
PubKey: c.PK,
Port: port,
}

hsFrame := NewHSFrameDMSGListen(c.pid, addr)
if _, err := conn.Write(hsFrame); err != nil {
if _, err := stream.Write(hsFrame); err != nil {
return nil, errors.Wrap(err, "error writing HS frame")
}

hsFrame, err = readHSFrame(conn)
hsFrame, err = readHSFrame(stream)
if err != nil {
return nil, errors.Wrap(err, "error reading HS frame")
}
Expand All @@ -99,7 +112,37 @@ func (c *Client) Listen(addr routing.Addr) (*Listener, error) {
return nil, ErrWrongHSFrameTypeReceived
}

return &Listener{
conn: conn,
}, nil
l := NewListener(addr, c.lm)
if err := c.lm.add(port, l); err != nil {
return nil, err
}

return l, nil
}

func (c *Client) listen() error {
for {
stream, err := c.session.Accept()
if err != nil {
return errors.Wrap(err, "error accepting stream")
}

hsFrame, err := readHSFrame(stream)
if err != nil {
c.logger.WithError(err).Error("error reading HS frame")
continue
}

if hsFrame.FrameType() != HSFrameTypeDMSGDial {
c.logger.WithError(ErrWrongHSFrameTypeReceived).Error("on listening for Dial")
continue
}

// TODO: handle field get gracefully
port := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]))
if err := c.lm.addConn(port, stream); err != nil {
c.logger.WithError(err).Error("failed to accept")
continue
}
}
}
58 changes: 58 additions & 0 deletions pkg/app2/listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package app2

import (
"errors"
"net"

"github.com/skycoin/skywire/pkg/routing"
)

const (
listenerBufSize = 1000
)

var (
ErrListenerClosed = errors.New("listener closed")
)

type Listener struct {
addr routing.Addr
conns chan net.Conn
lm *listenersManager
}

func NewListener(addr routing.Addr, lm *listenersManager) *Listener {
return &Listener{
addr: addr,
conns: make(chan net.Conn, listenerBufSize),
lm: lm,
}
}

func (l *Listener) Accept() (net.Conn, error) {
conn, ok := <-l.conns
if !ok {
return nil, ErrListenerClosed
}

return conn, nil
}

func (l *Listener) Close() error {
if err := l.lm.remove(l.addr.Port); err != nil {
return err
}

// TODO: send ListenEnd frame
close(l.conns)

return nil
}

func (l *Listener) Addr() net.Addr {
return l.addr
}

func (l *Listener) addConn(conn net.Conn) {
l.conns <- conn
}
65 changes: 65 additions & 0 deletions pkg/app2/listeners_manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package app2

import (
"net"
"sync"

"github.com/pkg/errors"
"github.com/skycoin/skywire/pkg/routing"
)

var (
ErrPortAlreadyBound = errors.New("port is already bound")
ErrNoListenerOnPort = errors.New("no listener on port")
)

type listenersManager struct {
listeners map[routing.Port]*Listener
mx sync.RWMutex
}

func newListenersManager() *listenersManager {
return &listenersManager{
listeners: make(map[routing.Port]*Listener),
}
}

func (lm *listenersManager) portIsBound(port routing.Port) bool {
lm.mx.RLock()
_, ok := lm.listeners[port]
lm.mx.RUnlock()
return ok
}

func (lm *listenersManager) add(port routing.Port, l *Listener) error {
lm.mx.Lock()
if _, ok := lm.listeners[port]; ok {
lm.mx.Unlock()
return ErrPortAlreadyBound
}
lm.listeners[port] = l
lm.mx.Unlock()
return nil
}

func (lm *listenersManager) remove(port routing.Port) error {
lm.mx.Lock()
if _, ok := lm.listeners[port]; !ok {
lm.mx.Unlock()
return ErrNoListenerOnPort
}
delete(lm.listeners, port)
lm.mx.Unlock()
return nil
}

func (lm *listenersManager) addConn(port routing.Port, conn net.Conn) error {
lm.mx.RLock()
if _, ok := lm.listeners[port]; !ok {
lm.mx.RUnlock()
return ErrNoListenerOnPort
}
lm.listeners[port].addConn(conn)
lm.mx.RUnlock()
return nil
}

0 comments on commit 5029589

Please sign in to comment.