diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 3f9745e081..c78d1ab7b0 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -1,178 +1,70 @@ package app2 import ( - "encoding/binary" - "net" + "net/rpc" - "github.com/hashicorp/yamux" - - "github.com/pkg/errors" - - "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/dmsg/cipher" -) - -var ( - ErrWrongHSFrameTypeReceived = errors.New("received wrong HS frame type") + "github.com/skycoin/skycoin/src/util/logging" ) // Client is used by skywire apps. type Client struct { - PK cipher.PubKey - pid ProcID - sockAddr string - conn net.Conn - session *yamux.Session - logger *logging.Logger - lm *listenersManager - isListening int32 + PK cipher.PubKey + pid ProcID + rpc ServerRPCClient + logger *logging.Logger } // NewClient creates a new Client. The Client needs to be provided with: // - localPK: The local public key of the parent skywire visor. // - pid: The procID assigned for the process that Client is being used by. // - sockAddr: The socket address to connect to Server. -func NewClient(localPK cipher.PubKey, pid ProcID, sockAddr string, l *logging.Logger) (*Client, error) { - conn, err := net.Dial("unix", sockAddr) - if err != nil { - return nil, errors.Wrap(err, "error connecting app server") - } - - session, err := yamux.Client(conn, nil) - if err != nil { - return nil, errors.Wrap(err, "error opening yamux session") - } - - lm := newListenersManager(l, pid, localPK) - +func NewClient(localPK cipher.PubKey, pid ProcID, rpc *rpc.Client, l *logging.Logger) *Client { return &Client{ - PK: localPK, - pid: pid, - sockAddr: sockAddr, - conn: conn, - session: session, - lm: lm, - }, nil + PK: localPK, + pid: pid, + rpc: newServerRPCClient(rpc), + logger: l, + } } -func (c *Client) Dial(addr routing.Addr) (net.Conn, error) { - stream, err := c.session.Open() +func (c *Client) Dial(remote routing.Addr) (*Conn, error) { + connID, err := c.rpc.Dial(remote) if err != nil { - return nil, errors.Wrap(err, "error opening stream") + return nil, err } - err = dialHS(stream, c.pid, routing.Loop{ - Local: routing.Addr{ + conn := &Conn{ + id: connID, + rpc: c.rpc, + // TODO: port? + local: routing.Addr{ PubKey: c.PK, }, - Remote: addr, - }) - if err != nil { - return nil, errors.Wrap(err, "error performing Dial HS") + remote: remote, } - return stream, nil + return conn, nil } -func (c *Client) Listen(port routing.Port) (net.Listener, error) { - if err := c.lm.reserveListener(port); err != nil { - return nil, errors.Wrap(err, "error reserving listener") - } - - stream, err := c.session.Open() - if err != nil { - return nil, errors.Wrap(err, "error opening stream") - } - +func (c *Client) Listen(port routing.Port) (*Listener, error) { local := routing.Addr{ PubKey: c.PK, Port: port, } - err = listenHS(stream, c.pid, local) - if err != nil { - return nil, errors.Wrap(err, "error performing Listen HS") - } - - c.lm.listen(c.session) - - l := newListener(local, c.lm, c.pid, c.stopListening, c.logger) - if err := c.lm.set(port, l); err != nil { - return nil, errors.Wrap(err, "error setting listener") - } - - return l, nil -} - -func (c *Client) listen() error { - for { - stream, err := c.session.Accept() - if err != nil { - return errors.Wrap(err, "error accepting stream") - } - - hsFrame, err := readHSFrame(stream) - if err != nil { - c.logger.WithError(err).Error("error reading HS frame") - continue - } - - if hsFrame.FrameType() != HSFrameTypeDMSGDial { - c.logger.WithError(ErrWrongHSFrameTypeReceived).Error("on listening for Dial") - continue - } - - // TODO: handle field get gracefully - remotePort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen*2+HSFramePortLen:])) - if err := c.lm.addConn(remotePort, stream); err != nil { - c.logger.WithError(err).Error("failed to accept") - continue - } - - localPort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) - - var localPK cipher.PubKey - copy(localPK[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) - - respHSFrame := NewHSFrameDMSGAccept(c.pid, routing.Loop{ - Local: routing.Addr{ - PubKey: c.PK, - Port: remotePort, - }, - Remote: routing.Addr{ - PubKey: localPK, - Port: localPort, - }, - }) - - if _, err := stream.Write(respHSFrame); err != nil { - c.logger.WithError(err).Error("error responding with DmsgAccept") - continue - } - } -} - -func (c *Client) stopListening(port routing.Port) error { - stream, err := c.session.Open() + lisID, err := c.rpc.Listen(local) if err != nil { - return errors.Wrap(err, "error opening stream") - } - - addr := routing.Addr{ - PubKey: c.PK, - Port: port, - } - - hsFrame := NewHSFrameDMSGStopListening(c.pid, addr) - if _, err := stream.Write(hsFrame); err != nil { - return errors.Wrap(err, "error writing HS frame") + return nil, err } - if err := stream.Close(); err != nil { - return errors.Wrap(err, "error closing stream") + listener := &Listener{ + id: lisID, + rpc: c.rpc, + addr: local, } - return nil + return listener, nil } diff --git a/pkg/app2/client_conn.go b/pkg/app2/client_conn.go deleted file mode 100644 index 5617d30709..0000000000 --- a/pkg/app2/client_conn.go +++ /dev/null @@ -1,23 +0,0 @@ -package app2 - -import ( - "net" - - "github.com/skycoin/skywire/pkg/routing" -) - -// clientConn serves as a wrapper for `net.Conn` being returned to the -// app client side from `Accept` func -type clientConn struct { - remote routing.Addr - local routing.Addr - net.Conn -} - -func (c *clientConn) RemoteAddr() net.Addr { - return c.remote -} - -func (c *clientConn) LocalAddr() net.Addr { - return c.local -} diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go new file mode 100644 index 0000000000..c5a7e2ed91 --- /dev/null +++ b/pkg/app2/conn.go @@ -0,0 +1,34 @@ +package app2 + +import ( + "net" + + "github.com/skycoin/skywire/pkg/routing" +) + +type Conn struct { + id uint16 + rpc ConnRPCClient + local routing.Addr + remote routing.Addr +} + +func (c *Conn) Read(b []byte) (int, error) { + return c.rpc.Read(c.id, b) +} + +func (c *Conn) Write(b []byte) (int, error) { + return c.rpc.Write(c.id, b) +} + +func (c *Conn) Close() error { + return c.rpc.CloseConn(c.id) +} + +func (c *Conn) LocalAddr() net.Addr { + return c.local +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.remote +} diff --git a/pkg/app2/conns_manager.go b/pkg/app2/conns_manager.go new file mode 100644 index 0000000000..d09b61c7cd --- /dev/null +++ b/pkg/app2/conns_manager.go @@ -0,0 +1,83 @@ +package app2 + +import ( + "fmt" + "net" + "sync" + + "github.com/pkg/errors" +) + +type connsManager struct { + conns map[uint16]net.Conn + mx sync.RWMutex + lstID uint16 +} + +func newConnsManager() *connsManager { + return &connsManager{ + conns: make(map[uint16]net.Conn), + } +} + +func (m *connsManager) nextID() (*uint16, error) { + m.mx.Lock() + + connID := m.lstID + 1 + for ; connID < m.lstID; connID++ { + if _, ok := m.conns[connID]; !ok { + break + } + } + + if connID == m.lstID { + m.mx.Unlock() + return nil, errors.New("no more available conns") + } + + m.conns[connID] = nil + m.lstID = connID + + m.mx.Unlock() + return &connID, nil +} + +func (m *connsManager) getAndRemove(connID uint16) (net.Conn, error) { + m.mx.Lock() + conn, ok := m.conns[connID] + if !ok { + m.mx.Unlock() + return nil, fmt.Errorf("no conn with id %d", connID) + } + + if conn == nil { + m.mx.Unlock() + return nil, fmt.Errorf("conn with id %d is not set", connID) + } + + delete(m.conns, connID) + + m.mx.Unlock() + return conn, nil +} + +func (m *connsManager) set(connID uint16, conn net.Conn) error { + m.mx.Lock() + + if c, ok := m.conns[connID]; ok && c != nil { + m.mx.Unlock() + return errors.New("conn already exists") + } + + m.conns[connID] = conn + + m.mx.Unlock() + return nil +} + +func (m *connsManager) get(connID uint16) (net.Conn, bool) { + m.mx.RLock() + conn, ok := m.conns[connID] + m.mx.RUnlock() + return conn, ok +} diff --git a/pkg/app2/hsframe.go b/pkg/app2/hsframe.go deleted file mode 100644 index d2ffc16544..0000000000 --- a/pkg/app2/hsframe.go +++ /dev/null @@ -1,206 +0,0 @@ -package app2 - -import ( - "encoding/binary" - "io" - "net" - - "github.com/pkg/errors" - "github.com/skycoin/skywire/pkg/routing" -) - -const ( - HSFrameHeaderLen = 3 - HSFrameProcIDLen = 2 - HSFrameTypeLen = 1 - HSFramePKLen = 33 - HSFramePortLen = 2 -) - -// HSFrameType identifies the type of a handshake frame. -type HSFrameType byte - -const ( - HSFrameTypeDMSGListen HSFrameType = 10 + iota - HSFrameTypeDMSGListening - HSFrameTypeDMSGDial - HSFrameTypeDMSGAccept - HSFrameTypeStopListening - HSFrameTypeError -) - -// HSFrame is the data unit for socket connection handshakes between Server and Client. -// It consists of header and body. -// -// Header is a big-endian encoded 3 bytes and is constructed as follows: -// | ProcID (2 bytes) | HSFrameType (1 byte) | -type HSFrame []byte - -func newHSFrame(procID ProcID, frameType HSFrameType, bodyLen int) HSFrame { - hsFrame := make(HSFrame, HSFrameHeaderLen+bodyLen) - - hsFrame.SetProcID(procID) - hsFrame.SetFrameType(frameType) - - return hsFrame -} - -func NewHSFrameDMSGListen(procID ProcID, local routing.Addr) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeDMSGListen, HSFramePKLen+HSFramePortLen) - - copy(hsFrame[HSFrameHeaderLen:], local.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(local.Port)) - - return hsFrame -} - -func NewHSFrameDMSGListening(procID ProcID, local routing.Addr) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeDMSGListening, HSFramePKLen+HSFramePortLen) - - copy(hsFrame[HSFrameHeaderLen:], local.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(local.Port)) - - return hsFrame -} - -func NewHSFrameDSMGDial(procID ProcID, loop routing.Loop) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeDMSGDial, 2*HSFramePKLen+2*HSFramePortLen) - - copy(hsFrame[HSFrameHeaderLen:], loop.Local.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(loop.Local.Port)) - - copy(hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen:], loop.Remote.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+2*HSFramePKLen+HSFramePortLen:], uint16(loop.Remote.Port)) - - return hsFrame -} - -func NewHSFrameDMSGAccept(procID ProcID, loop routing.Loop) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeDMSGAccept, 2*HSFramePKLen+2*HSFramePortLen) - - copy(hsFrame[HSFrameHeaderLen:], loop.Local.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(loop.Local.Port)) - - copy(hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen:], loop.Remote.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+2*HSFramePKLen+HSFramePortLen:], uint16(loop.Remote.Port)) - - return hsFrame -} - -func NewHSFrameDMSGStopListening(procID ProcID, local routing.Addr) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeDMSGListen, HSFramePKLen+HSFramePortLen) - - copy(hsFrame[HSFrameHeaderLen:], local.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(local.Port)) - - return hsFrame -} - -func NewHSFrameError(procID ProcID) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeError, 0) - - return hsFrame -} - -// ProcID gets ProcID from the HSFrame. -func (f HSFrame) ProcID() ProcID { - return ProcID(binary.BigEndian.Uint16(f)) -} - -// SetProcID sets ProcID for the HSFrame. -func (f HSFrame) SetProcID(procID ProcID) { - binary.BigEndian.PutUint16(f, uint16(procID)) -} - -// FrameType gets FrameType from the HSFrame. -func (f HSFrame) FrameType() HSFrameType { - _ = f[HSFrameProcIDLen] // bounds check hint to compiler; see golang.org/issue/14808 - return HSFrameType(f[HSFrameProcIDLen]) -} - -// SetFrameType sets FrameType for the HSFrame. -func (f HSFrame) SetFrameType(frameType HSFrameType) { - _ = f[HSFrameProcIDLen] // bounds check hint to compiler; see golang.org/issue/14808 - f[HSFrameProcIDLen] = byte(frameType) -} - -func readHSFrame(r io.Reader) (HSFrame, error) { - hsFrame := make(HSFrame, HSFrameHeaderLen) - if _, err := io.ReadFull(r, hsFrame); err != nil { - return nil, errors.Wrap(err, "error reading HS frame header") - } - - hsFrame, err := readHSFrameBody(hsFrame, r) - if err != nil { - return nil, errors.Wrap(err, "error reading HS frame body") - } - - return hsFrame, nil -} - -func writeHSFrame(w io.Writer, hsFrame HSFrame) (int, error) { - n, err := w.Write(hsFrame) - if err != nil { - return 0, errors.Wrap(err, "error writing HS frame") - } - - return n, nil -} - -func readHSFrameBody(hsFrame HSFrame, r io.Reader) (HSFrame, error) { - switch hsFrame.FrameType() { - case HSFrameTypeDMSGListen, HSFrameTypeDMSGListening: - hsFrame = append(hsFrame, make([]byte, HSFramePKLen+HSFramePortLen)...) - case HSFrameTypeDMSGDial, HSFrameTypeDMSGAccept: - hsFrame = append(hsFrame, make([]byte, 2*HSFramePKLen+2*HSFramePortLen)...) - } - - _, err := io.ReadFull(r, hsFrame[HSFrameHeaderLen:]) - return hsFrame, err -} - -func dialHS(conn net.Conn, pid ProcID, loop routing.Loop) error { - hsFrame := NewHSFrameDSMGDial(pid, loop) - - if _, err := writeHSFrame(conn, hsFrame); err != nil { - return err - } - - hsFrame, err := readHSFrame(conn) - if err != nil { - return err - } - - if hsFrame.FrameType() != HSFrameTypeDMSGAccept { - return ErrWrongHSFrameTypeReceived - } - - if hsFrame.ProcID() != pid { - return ErrWrongPID - } - - return nil -} - -func listenHS(conn net.Conn, pid ProcID, local routing.Addr) error { - hsFrame := NewHSFrameDMSGListen(pid, local) - - if _, err := writeHSFrame(conn, hsFrame); err != nil { - return err - } - - hsFrame, err := readHSFrame(conn) - if err != nil { - return err - } - - if hsFrame.FrameType() != HSFrameTypeDMSGListening { - return ErrWrongHSFrameTypeReceived - } - - if hsFrame.ProcID() != pid { - return ErrWrongPID - } - - return nil -} diff --git a/pkg/app2/hsframe_test.go b/pkg/app2/hsframe_test.go deleted file mode 100644 index ab66e376f9..0000000000 --- a/pkg/app2/hsframe_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package app2 - -import ( - "encoding/binary" - "encoding/json" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestHSFrame(t *testing.T) { - t.Run("ok", func(t *testing.T) { - body := struct { - Test string `json:"test"` - }{ - Test: "some string", - } - - bodyBytes, err := json.Marshal(body) - require.NoError(t, err) - - procID := ProcID(1) - frameType := HSFrameTypeDMSGListen - bodyLen := len(bodyBytes) - - hsFrame, err := NewHSFrame(procID, frameType, body) - require.NoError(t, err) - - require.Equal(t, len(hsFrame), HSFrameHeaderLen+len(bodyBytes)) - - gotProcID := ProcID(binary.BigEndian.Uint16(hsFrame)) - require.Equal(t, gotProcID, procID) - - gotFrameType := HSFrameType(hsFrame[HSFrameProcIDLen]) - require.Equal(t, gotFrameType, frameType) - - gotBodyLen := int(binary.BigEndian.Uint16(hsFrame[HSFrameProcIDLen+HSFrameTypeLen:])) - require.Equal(t, gotBodyLen, bodyLen) - - require.Equal(t, bodyBytes, []byte(hsFrame[HSFrameProcIDLen+HSFrameTypeLen+HSFrameBodyLenLen:])) - - gotProcID = hsFrame.ProcID() - require.Equal(t, gotProcID, procID) - - gotFrameType = hsFrame.FrameType() - require.Equal(t, gotFrameType, frameType) - - gotBodyLen = hsFrame.BodyLen() - require.Equal(t, gotBodyLen, bodyLen) - }) - - t.Run("fail - too large body", func(t *testing.T) { - body := struct { - Test string `json:"test"` - }{ - Test: "some string", - } - - for len(body.Test) <= HSFrameMaxBodyLen { - body.Test += body.Test - } - - procID := ProcID(1) - frameType := HSFrameTypeDMSGListen - - _, err := NewHSFrame(procID, frameType, body) - require.Equal(t, err, ErrHSFrameBodyTooLarge) - }) -} diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index 33c7c2e296..fb423fbfdc 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -1,78 +1,38 @@ package app2 import ( - "errors" "net" - "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/routing" ) -const ( - listenerBufSize = 1000 -) - -var ( - ErrListenerClosed = errors.New("listener closed") -) - -type listener struct { - addr routing.Addr - conns chan *clientConn - stopListening func(port routing.Port) error - logger *logging.Logger - lm *listenersManager - procID ProcID +type Listener struct { + id uint16 + rpc ListenerRPCClient + addr routing.Addr } -func newListener(addr routing.Addr, lm *listenersManager, procID ProcID, - stopListening func(port routing.Port) error, l *logging.Logger) *listener { - return &listener{ - addr: addr, - conns: make(chan *clientConn, listenerBufSize), - lm: lm, - stopListening: stopListening, - logger: l, - procID: procID, +func (l *Listener) Accept() (*Conn, error) { + connID, err := l.rpc.Accept(l.id) + if err != nil { + return nil, err } -} -func (l *listener) Accept() (net.Conn, error) { - conn, ok := <-l.conns - if !ok { - return nil, ErrListenerClosed - } - - hsFrame := NewHSFrameDMSGAccept(l.procID, routing.Loop{ - Local: l.addr, - Remote: conn.remote, - }) - - if _, err := conn.Write(hsFrame); err != nil { - return nil, err + conn := &Conn{ + id: connID, + rpc: l.rpc, + local: l.addr, + // TODO: probably pass with response + remote: routing.Addr{}, } return conn, nil } -func (l *listener) Close() error { - if err := l.stopListening(l.addr.Port); err != nil { - l.logger.WithError(err).Error("error sending DmsgStopListening") - } - - if err := l.lm.remove(l.addr.Port); err != nil { - return err - } - - close(l.conns) - - return nil +func (l *Listener) Close() error { + return l.rpc.CloseListener(l.id) } -func (l *listener) Addr() net.Addr { +func (l *Listener) Addr() net.Addr { return l.addr } - -func (l *listener) addConn(conn *clientConn) { - l.conns <- conn -} diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go index 5619a439b9..f8c605e34c 100644 --- a/pkg/app2/listeners_manager.go +++ b/pkg/app2/listeners_manager.go @@ -1,172 +1,84 @@ package app2 import ( - "encoding/binary" - "net" + "fmt" "sync" - "sync/atomic" - "github.com/hashicorp/yamux" "github.com/pkg/errors" - "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/skycoin/src/util/logging" - - "github.com/skycoin/skywire/pkg/routing" -) - -var ( - ErrPortAlreadyBound = errors.New("port is already bound") - ErrNoListenerOnPort = errors.New("no listener on port") - ErrWrongPID = errors.New("wrong ProcID specified in the HS frame") + "github.com/skycoin/dmsg" ) // listenersManager contains and manages all the instantiated listeners type listenersManager struct { - pid ProcID - pk cipher.PubKey - listeners map[routing.Port]*listener - mx sync.RWMutex - isListening int32 - logger *logging.Logger - doneCh chan struct{} - doneWg sync.WaitGroup + listeners map[uint16]*dmsg.Listener + mx sync.RWMutex + lstID uint16 } -func newListenersManager(l *logging.Logger, pid ProcID, pk cipher.PubKey) *listenersManager { +func newListenersManager() *listenersManager { return &listenersManager{ - pid: pid, - pk: pk, - listeners: make(map[routing.Port]*listener), - logger: l, - doneCh: make(chan struct{}), + listeners: make(map[uint16]*dmsg.Listener), } } -func (lm *listenersManager) close() { - close(lm.doneCh) - lm.doneWg.Wait() -} +func (m *listenersManager) nextID() (*uint16, error) { + m.mx.Lock() -func (lm *listenersManager) set(port routing.Port, l *listener) error { - lm.mx.Lock() - if v, ok := lm.listeners[port]; !ok || v != nil { - lm.mx.Unlock() - return ErrPortAlreadyBound + lisID := m.lstID + 1 + for ; lisID < m.lstID; lisID++ { + if _, ok := m.listeners[lisID]; !ok { + break + } } - lm.listeners[port] = l - lm.mx.Unlock() - return nil -} -func (lm *listenersManager) reserveListener(port routing.Port) error { - lm.mx.Lock() - if _, ok := lm.listeners[port]; ok { - lm.mx.Unlock() - return ErrPortAlreadyBound + if lisID == m.lstID { + m.mx.Unlock() + return nil, errors.New("no more available listeners") } - lm.listeners[port] = nil - lm.mx.Unlock() - return nil + + m.listeners[lisID] = nil + m.lstID = lisID + + m.mx.Unlock() + return &lisID, nil } -func (lm *listenersManager) remove(port routing.Port) error { - lm.mx.Lock() - if _, ok := lm.listeners[port]; !ok { - lm.mx.Unlock() - return ErrNoListenerOnPort +func (m *listenersManager) getAndRemove(lisID uint16) (*dmsg.Listener, error) { + m.mx.Lock() + lis, ok := m.listeners[lisID] + if !ok { + m.mx.Unlock() + return nil, fmt.Errorf("no listener with id %d", lisID) } - delete(lm.listeners, port) - lm.mx.Unlock() - return nil -} -// addConn passes connection to the corresponding listener -func (lm *listenersManager) addConn(localPort routing.Port, remote routing.Addr, conn net.Conn) error { - lm.mx.RLock() - if _, ok := lm.listeners[localPort]; !ok { - lm.mx.RUnlock() - return ErrNoListenerOnPort + if lis == nil { + m.mx.Unlock() + return nil, fmt.Errorf("listener with id %d is not set", lisID) } - lm.listeners[localPort].addConn(&clientConn{ - remote: remote, - Conn: conn, - }) - lm.mx.RUnlock() - return nil + + delete(m.listeners, lisID) + + m.mx.Unlock() + return lis, nil } -// listen accepts all new yamux streams from the server. We want to accept only -// `DmsgDial` frames here, thus all the other frames get rejected. `DmsgDial` frames -// are being distributed between the corresponding listeners with regards to their port -func (lm *listenersManager) listen(session *yamux.Session) { - // this one should only start once - if !atomic.CompareAndSwapInt32(&lm.isListening, 0, 1) { - return +func (m *listenersManager) set(lisID uint16, lis *dmsg.Listener) error { + m.mx.Lock() + + if l, ok := m.listeners[lisID]; ok && l != nil { + m.mx.Unlock() + return errors.New("listener already exists") } - lm.doneWg.Add(1) - - go func() { - defer lm.doneWg.Done() - - for { - select { - case <-lm.doneCh: - return - default: - stream, err := session.Accept() - if err != nil { - lm.logger.WithError(err).Error("error accepting stream") - return - } - - hsFrame, err := readHSFrame(stream) - if err != nil { - lm.logger.WithError(err).Error("error reading HS frame") - continue - } - - if hsFrame.ProcID() != lm.pid { - lm.logger.WithError(ErrWrongPID).Error("error listening for Dial") - } - - if hsFrame.FrameType() != HSFrameTypeDMSGDial { - lm.logger.WithError(ErrWrongHSFrameTypeReceived).Error("error listening for Dial") - continue - } - - // TODO: handle field get gracefully - remotePort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen*2+HSFramePortLen:])) - localPort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) - - var localPK cipher.PubKey - copy(localPK[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) - - err = lm.addConn(remotePort, routing.Addr{ - PubKey: localPK, - Port: localPort, - }, stream) - if err != nil { - lm.logger.WithError(err).Error("failed to accept") - continue - } - - respHSFrame := NewHSFrameDMSGAccept(hsFrame.ProcID(), routing.Loop{ - Local: routing.Addr{ - PubKey: lm.pk, - Port: remotePort, - }, - Remote: routing.Addr{ - PubKey: localPK, - Port: localPort, - }, - }) - - if _, err := stream.Write(respHSFrame); err != nil { - lm.logger.WithError(err).Error("error responding with DmsgAccept") - continue - } - } - } - }() + m.listeners[lisID] = lis + + m.mx.Unlock() + return nil +} + +func (m *listenersManager) get(lisID uint16) (*dmsg.Listener, bool) { + m.mx.RLock() + lis, ok := m.listeners[lisID] + m.mx.RUnlock() + return lis, ok } diff --git a/pkg/app2/server.go b/pkg/app2/server.go deleted file mode 100644 index 3953266c09..0000000000 --- a/pkg/app2/server.go +++ /dev/null @@ -1,266 +0,0 @@ -package app2 - -import ( - "context" - "encoding/binary" - "fmt" - "io" - "net" - "sync" - - "github.com/skycoin/skycoin/src/util/logging" - - "github.com/hashicorp/yamux" - - "github.com/skycoin/dmsg" - - "github.com/skycoin/skywire/pkg/routing" - - "github.com/pkg/errors" - - "github.com/skycoin/dmsg/cipher" -) - -type serverConn struct { - procID ProcID - conn net.Conn - session *yamux.Session - lm *listenersManager - dmsgListeners map[routing.Port]*dmsg.Listener - dmsgListenersMx sync.RWMutex -} - -// Server is used by skywire visor. -type Server struct { - PK cipher.PubKey - dmsgC *dmsg.Client - apps map[string]*serverConn - appsMx sync.RWMutex - logger *logging.Logger -} - -func NewServer(localPK cipher.PubKey, dmsgC *dmsg.Client, l *logging.Logger) *Server { - return &Server{ - PK: localPK, - dmsgC: dmsgC, - apps: make(map[string]*serverConn), - logger: l, - } -} - -func (s *Server) Serve(sockAddr string) error { - l, err := net.Listen("unix", sockAddr) - if err != nil { - return errors.Wrap(err, "error listening unix socket") - } - - for { - conn, err := l.Accept() - if err != nil { - return errors.Wrap(err, "error accepting client connection") - } - - s.appsMx.Lock() - if _, ok := s.apps[conn.RemoteAddr().String()]; ok { - s.logger.WithError(ErrPortAlreadyBound).Error("error storing session") - } - - session, err := yamux.Server(conn, nil) - if err != nil { - return errors.Wrap(err, "error creating yamux session") - } - - s.apps[conn.RemoteAddr().String()] = &serverConn{ - session: session, - conn: conn, - lm: newListenersManager(), - dmsgListeners: make(map[routing.Port]*dmsg.Listener), - } - s.appsMx.Unlock() - - // TODO: handle error - go s.serveClient(session) - } -} - -func (s *Server) serveClient(conn *serverConn) error { - for { - stream, err := conn.session.Accept() - if err != nil { - return errors.Wrap(err, "error opening stream") - } - - go s.serveStream(stream, conn) -/////////////////////////////// - hsFrame, err := readHSFrame(conn) - if err != nil { - return errors.Wrap(err, "error reading HS frame") - } - - switch hsFrame.FrameType() { - case HSFrameTypeDMSGListen: - if s. - pk := make(cipher.PubKey, 33) - copy(pk, hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) - port := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]) - dmsgL, err := s.dmsgC.Listen(port) - if err != nil { - return fmt.Errorf("error listening on port %d: %v", port, err) - } - - respHSFrame := NewHSFrameDMSGListening(hsFrame.ProcID(), routing.Addr{ - PubKey: cipher.PubKey(pk), - Port: 0, - }) - } - } -} - -func (s *Server) serveStream(stream net.Conn, conn *serverConn) error { - for { - hsFrame, err := readHSFrame(stream) - if err != nil { - return errors.Wrap(err, "error reading HS frame") - } - - // TODO: ensure thread-safety - conn.procID = hsFrame.ProcID() - - var respHSFrame HSFrame - switch hsFrame.FrameType() { - case HSFrameTypeDMSGListen: - port := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]) - if err := conn.reserveListener(routing.Port(port)); err != nil { - respHSFrame = NewHSFrameError(hsFrame.ProcID()) - } else { - dmsgL, err := s.dmsgC.Listen(port) - if err != nil { - respHSFrame = NewHSFrameError(hsFrame.ProcID()) - } else { - if err := conn.addListener(routing.Port(port), dmsgL); err != nil { - respHSFrame = NewHSFrameError(hsFrame.ProcID()) - } else { - var pk cipher.PubKey - copy(pk[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) - - respHSFrame = NewHSFrameDMSGListening(hsFrame.ProcID(), routing.Addr{ - PubKey: pk, - Port: routing.Port(port), - }) - } - } - } - case HSFrameTypeDMSGDial: - localPort := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]) - var localPK cipher.PubKey - copy(localPK[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) - - var remotePK cipher.PubKey - copy(remotePK[:], hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen:HSFrameHeaderLen+HSFramePKLen+HSFramePortLen+HSFramePKLen]) - remotePort := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen+HSFramePKLen:]) - - // TODO: context - tp, err := s.dmsgC.Dial(context.Background(), localPK, localPort) - if err != nil { - respHSFrame = NewHSFrameError(hsFrame.ProcID()) - } else { - respHSFrame = NewHSFrameDMSGAccept(hsFrame.ProcID(), routing.Loop{ - Local: routing.Addr{ - PubKey: localPK, - Port: routing.Port(localPort), - }, - Remote: routing.Addr{ - PubKey: remotePK, - Port: routing.Port(remotePort), - }, - }) - - go func() { - if err := s.forwardOverDMSG(stream, tp); err != nil { - s.logger.WithError(err).Error("error forwarding over DMSG") - } - }() - } - } - - if _, err := stream.Write(respHSFrame); err != nil { - return errors.Wrap(err, "error writing response") - } - } -} - -func (s *Server) forwardOverDMSG(stream net.Conn, tp *dmsg.Transport) error { - toStreamErrCh := make(chan error) - defer close(toStreamErrCh) - go func() { - _, err := io.Copy(stream, tp) - toStreamErrCh <- err - }() - - _, err := io.Copy(stream, tp) - if err != nil { - return err - } - - if err := <-toStreamErrCh; err != nil { - return err - } - - return nil -} - -func (c *serverConn) reserveListener(port routing.Port) error { - c.dmsgListenersMx.Lock() - if _, ok := c.dmsgListeners[port]; ok { - c.dmsgListenersMx.Unlock() - return ErrPortAlreadyBound - } - c.dmsgListeners[port] = nil - c.dmsgListenersMx.Unlock() - return nil -} - -func (c *serverConn) addListener(port routing.Port, l *dmsg.Listener) error { - c.dmsgListenersMx.Lock() - if lis, ok := c.dmsgListeners[port]; ok && lis != nil { - c.dmsgListenersMx.Unlock() - return ErrPortAlreadyBound - } - c.dmsgListeners[port] = l - go c.acceptDMSG(l) - c.dmsgListenersMx.Unlock() - return nil -} - -func (c *serverConn) acceptDMSG(l *dmsg.Listener) error { - for { - stream, err := c.session.Open() - if err != nil { - return errors.Wrap(err, "error opening yamux stream") - } - - remoteAddr, ok := l.Addr().(dmsg.Addr) - if !ok { - // shouldn't happen, but still - return errors.Wrap(err, "wrong type for DMSG addr") - } - - hsFrame := NewHSFrameDSMGDial(c.procID, routing.Loop{ - Local: routing.Addr{ - PubKey: remoteAddr.PK, - Port: routing.Port(remoteAddr.Port), - }, - // TODO: get local addr - Remote: routing.Addr{ - PubKey: - }, - }) - - conn, err := l.Accept() - if err != nil { - return errors.Wrap(err, "error accepting DMSG conn") - } - - - } -} diff --git a/pkg/app2/server_rpc.go b/pkg/app2/server_rpc.go index ba73f4fe88..ace37eb6ae 100644 --- a/pkg/app2/server_rpc.go +++ b/pkg/app2/server_rpc.go @@ -2,10 +2,7 @@ package app2 import ( "context" - "errors" "fmt" - "net" - "sync" "github.com/skycoin/dmsg" @@ -13,139 +10,49 @@ import ( ) type ServerRPC struct { - dmsgC *dmsg.Client - conns map[uint16]net.Conn - connsMx sync.RWMutex - lstConnID uint16 - listeners map[uint16]*dmsg.Listener - listenersMx sync.RWMutex - lstLisID uint16 + dmsgC *dmsg.Client + lm *listenersManager + cm *connsManager } -func (r *ServerRPC) nextConnID() (*uint16, error) { - r.connsMx.Lock() - - connID := r.lstConnID + 1 - for ; connID < r.lstConnID; connID++ { - if _, ok := r.conns[connID]; !ok { - break - } - } - - if connID == r.lstConnID { - r.connsMx.Unlock() - return nil, errors.New("no more available conns") - } - - r.conns[connID] = nil - r.lstConnID = connID - - r.connsMx.Unlock() - return &connID, nil -} - -func (r *ServerRPC) nextLisID() (*uint16, error) { - r.listenersMx.Lock() - - lisID := r.lstLisID + 1 - for ; lisID < r.lstLisID; lisID++ { - if _, ok := r.listeners[lisID]; !ok { - break - } - } - - if lisID == r.lstLisID { - r.listenersMx.Unlock() - return nil, errors.New("no more available listeners") - } - - r.listeners[lisID] = nil - r.lstLisID = lisID - - r.listenersMx.Unlock() - return &lisID, nil -} - -func (r *ServerRPC) setConn(connID uint16, conn net.Conn) error { - r.connsMx.Lock() - - if c, ok := r.conns[connID]; ok && c != nil { - r.connsMx.Unlock() - return errors.New("conn already exists") - } - - r.conns[connID] = conn - - r.connsMx.Unlock() - return nil -} - -func (r *ServerRPC) setListener(lisID uint16, lis *dmsg.Listener) error { - r.listenersMx.Lock() - - if l, ok := r.listeners[lisID]; ok && l != nil { - r.listenersMx.Unlock() - return errors.New("listener already exists") +func newServerRPC(dmsgC *dmsg.Client) *ServerRPC { + return &ServerRPC{ + dmsgC: dmsgC, + lm: newListenersManager(), + cm: newConnsManager(), } - - r.listeners[lisID] = lis - - r.listenersMx.Unlock() - return nil -} - -func (r *ServerRPC) getConn(connID uint16) (net.Conn, bool) { - r.connsMx.RLock() - conn, ok := r.conns[connID] - r.connsMx.RUnlock() - return conn, ok -} - -func (r *ServerRPC) getListener(lisID uint16) (*dmsg.Listener, bool) { - r.listenersMx.RLock() - lis, ok := r.listeners[lisID] - r.listenersMx.RUnlock() - return lis, ok -} - -type DialReq struct { - Remote routing.Addr } -func (r *ServerRPC) Dial(req *DialReq, connID *uint16) error { - connID, err := r.nextConnID() +func (r *ServerRPC) Dial(remote *routing.Addr, connID *uint16) error { + connID, err := r.cm.nextID() if err != nil { return err } - tp, err := r.dmsgC.Dial(context.TODO(), req.Remote.PubKey, uint16(req.Remote.Port)) + tp, err := r.dmsgC.Dial(context.TODO(), remote.PubKey, uint16(remote.Port)) if err != nil { return err } - if err := r.setConn(*connID, tp); err != nil { + if err := r.cm.set(*connID, tp); err != nil { return err } return nil } -type ListenReq struct { - Local routing.Addr -} - -func (r *ServerRPC) Listen(req *ListenReq, lisID *uint16) error { - lisID, err := r.nextLisID() +func (r *ServerRPC) Listen(local *routing.Addr, lisID *uint16) error { + lisID, err := r.lm.nextID() if err != nil { return err } - dmsgL, err := r.dmsgC.Listen(uint16(req.Local.Port)) + dmsgL, err := r.dmsgC.Listen(uint16(local.Port)) if err != nil { return err } - if err := r.setListener(*lisID, dmsgL); err != nil { + if err := r.lm.set(*lisID, dmsgL); err != nil { // TODO: close listener return err } @@ -154,12 +61,12 @@ func (r *ServerRPC) Listen(req *ListenReq, lisID *uint16) error { } func (r *ServerRPC) Accept(lisID *uint16, connID *uint16) error { - lis, ok := r.getListener(*lisID) + lis, ok := r.lm.get(*lisID) if !ok { return fmt.Errorf("not listener with id %d", *lisID) } - connID, err := r.nextConnID() + connID, err := r.cm.nextID() if err != nil { return err } @@ -169,7 +76,7 @@ func (r *ServerRPC) Accept(lisID *uint16, connID *uint16) error { return err } - if err := r.setConn(*connID, tp); err != nil { + if err := r.cm.set(*connID, tp); err != nil { // TODO: close conn return err } @@ -183,9 +90,54 @@ type WriteReq struct { } func (r *ServerRPC) Write(req *WriteReq, n *int) error { - conn, ok := r.getConn(req.ConnID) + conn, ok := r.cm.get(req.ConnID) + if !ok { + return fmt.Errorf("no conn with id %d", req.ConnID) + } + + var err error + *n, err = conn.Write(req.B) + if err != nil { + return err + } + + return nil +} + +type ReadResp struct { + B []byte + N int +} + +func (r *ServerRPC) Read(connID *uint16, resp *ReadResp) error { + conn, ok := r.cm.get(*connID) if !ok { - return fmt.Errorf("not conn with id %d", req.ConnID) + return fmt.Errorf("no conn with id %d", *connID) + } + + var err error + resp.N, err = conn.Read(resp.B) + if err != nil { + return err + } + + return nil +} + +func (r *ServerRPC) CloseConn(connID *uint16, _ *struct{}) error { + conn, err := r.cm.getAndRemove(*connID) + if err != nil { + return err + } + + return conn.Close() +} + +func (r *ServerRPC) CloseListener(lisID *uint16, _ *struct{}) error { + lis, err := r.lm.getAndRemove(*lisID) + if err != nil { + return err } + return lis.Close() } diff --git a/pkg/app2/server_rpc_client.go b/pkg/app2/server_rpc_client.go new file mode 100644 index 0000000000..f9e8ed3970 --- /dev/null +++ b/pkg/app2/server_rpc_client.go @@ -0,0 +1,101 @@ +package app2 + +import ( + "net/rpc" + + "github.com/skycoin/skywire/pkg/routing" +) + +type ServerRPCClient interface { + Dial(remote routing.Addr) (uint16, error) + Listen(local routing.Addr) (uint16, error) + Accept(lisID uint16) (uint16, error) + Write(connID uint16, b []byte) (int, error) + Read(connID uint16, b []byte) (int, error) + CloseConn(id uint16) error + CloseListener(id uint16) error +} + +type ListenerRPCClient interface { + Accept(id uint16) (uint16, error) + CloseListener(id uint16) error + Write(connID uint16, b []byte) (int, error) + Read(connID uint16, b []byte) (int, error) + CloseConn(id uint16) error +} + +type ConnRPCClient interface { + Write(id uint16, b []byte) (int, error) + Read(id uint16, b []byte) (int, error) + CloseConn(id uint16) error +} + +type serverRPCCLient struct { + rpc *rpc.Client +} + +func newServerRPCClient(rpc *rpc.Client) ServerRPCClient { + return &serverRPCCLient{ + rpc: rpc, + } +} + +func (c *serverRPCCLient) Dial(remote routing.Addr) (uint16, error) { + var connID uint16 + if err := c.rpc.Call("Dial", &remote, &connID); err != nil { + return 0, err + } + + return connID, nil +} + +func (c *serverRPCCLient) Listen(local routing.Addr) (uint16, error) { + var lisID uint16 + if err := c.rpc.Call("Listen", &local, &lisID); err != nil { + return 0, err + } + + return lisID, nil +} + +func (c *serverRPCCLient) Accept(lisID uint16) (uint16, error) { + var connID uint16 + if err := c.rpc.Call("Accept", &lisID, &connID); err != nil { + return 0, err + } + + return connID, nil +} + +func (c *serverRPCCLient) Write(connID uint16, b []byte) (int, error) { + req := WriteReq{ + ConnID: connID, + B: b, + } + + var n int + if err := c.rpc.Call("Write", &req, &n); err != nil { + return n, err + } + + return n, nil +} + +func (c *serverRPCCLient) Read(connID uint16, b []byte) (int, error) { + var resp ReadResp + if err := c.rpc.Call("Read", &connID, &resp); err != nil { + return 0, err + } + + copy(b[:resp.N], resp.B[:resp.N]) + + return resp.N, nil +} + +func (c *serverRPCCLient) CloseConn(id uint16) error { + return c.rpc.Call("CloseConn", &id, nil) +} + +func (c *serverRPCCLient) CloseListener(id uint16) error { + return c.rpc.Call("CloseListener", &id, nil) +}