From 790571e48a4d61c7c40f3d54b96d2c6aebd723d2 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Fri, 30 Aug 2019 09:48:27 +0300 Subject: [PATCH 01/43] Add basic structures --- pkg/app2/client.go | 1 + pkg/app2/doc.go | 4 +++ pkg/app2/hsframe.go | 63 +++++++++++++++++++++++++++++++++++++++++++++ pkg/app2/procid.go | 6 +++++ pkg/app2/server.go | 1 + 5 files changed, 75 insertions(+) create mode 100644 pkg/app2/client.go create mode 100644 pkg/app2/doc.go create mode 100644 pkg/app2/hsframe.go create mode 100644 pkg/app2/procid.go create mode 100644 pkg/app2/server.go diff --git a/pkg/app2/client.go b/pkg/app2/client.go new file mode 100644 index 0000000000..658af578d0 --- /dev/null +++ b/pkg/app2/client.go @@ -0,0 +1 @@ +package app2 diff --git a/pkg/app2/doc.go b/pkg/app2/doc.go new file mode 100644 index 0000000000..0dca17f0d7 --- /dev/null +++ b/pkg/app2/doc.go @@ -0,0 +1,4 @@ +// Package app2 provides facilities to establish communication +// between a visor node and a skywire application. Intended to +// replace the original `app` module +package app2 diff --git a/pkg/app2/hsframe.go b/pkg/app2/hsframe.go new file mode 100644 index 0000000000..b70120fc51 --- /dev/null +++ b/pkg/app2/hsframe.go @@ -0,0 +1,63 @@ +package app2 + +import ( + "encoding/binary" + "encoding/json" + "errors" + "math" +) + +const ( + HSFrameHeaderLength = 5 + HSFrameProcIDLength = 2 + HSFrameTypeLength = 1 + HSFrameBodyLenLength = 2 + HSFrameMaxBodyLength = math.MaxUint16 +) + +var ( + // ErrHSFrameBodyTooLong is being returned when the body is too long to be + // fit in the HSFrame + ErrHSFrameBodyTooLong = errors.New("frame body is too long") +) + +// HSFrameType identifies the type of a handshake frame. +type HSFrameType byte + +const ( + HSFrameTypeDMSGListen HSFrameType = 10 + iota + HSFrameTypeDMSGListening + HSFrameTypeDMSGDial + HSFrameTypeDMSGAccept +) + +// HSFrame is the data unit for socket connection handshakes between Server and Client. +// It consists of header and body. +// +// Header is a big-endian encoded 5 bytes and is constructed as follows: +// | ProcID (2 bytes) | HSFrameType (1 byte) | BodyLen (2 bytes) | +// +// Body is a marshaled JSON structure +type HSFrame []byte + +// NewHSFrame constructs new HSFrame. +func NewHSFrame(procID ProcID, frameType HSFrameType, body interface{}) (HSFrame, error) { + bodyBytes, err := json.Marshal(body) + if err != nil { + return nil, err + } + + if len(bodyBytes) > HSFrameMaxBodyLength { + return nil, ErrHSFrameBodyTooLong + } + + hsFrame := make(HSFrame, HSFrameHeaderLength+len(bodyBytes)) + + binary.BigEndian.PutUint16(hsFrame, uint16(procID)) + hsFrame[HSFrameProcIDLength] = byte(frameType) + binary.BigEndian.PutUint16(hsFrame[HSFrameProcIDLength+HSFrameTypeLength:], uint16(len(bodyBytes))) + + copy(hsFrame[HSFrameProcIDLength+HSFrameTypeLength+HSFrameBodyLenLength:], bodyBytes) + + return hsFrame, nil +} diff --git a/pkg/app2/procid.go b/pkg/app2/procid.go new file mode 100644 index 0000000000..d49cdec5fd --- /dev/null +++ b/pkg/app2/procid.go @@ -0,0 +1,6 @@ +package app2 + +// ProcID identifies the current instance of an app (an app process). +// The visor node is responsible for starting apps, and the started process +// should be provided with a ProcID. +type ProcID uint16 diff --git a/pkg/app2/server.go b/pkg/app2/server.go new file mode 100644 index 0000000000..658af578d0 --- /dev/null +++ b/pkg/app2/server.go @@ -0,0 +1 @@ +package app2 From 52992839543102f72317aa521201cb5cf435de00 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Fri, 30 Aug 2019 10:27:16 +0300 Subject: [PATCH 02/43] Add methods for HSFrame --- pkg/app2/hsframe.go | 63 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/pkg/app2/hsframe.go b/pkg/app2/hsframe.go index b70120fc51..05fe44c9bd 100644 --- a/pkg/app2/hsframe.go +++ b/pkg/app2/hsframe.go @@ -42,22 +42,69 @@ type HSFrame []byte // NewHSFrame constructs new HSFrame. func NewHSFrame(procID ProcID, frameType HSFrameType, body interface{}) (HSFrame, error) { - bodyBytes, err := json.Marshal(body) + bodyBytes, err := marshalHSFrameBody(body) if err != nil { return nil, err } - if len(bodyBytes) > HSFrameMaxBodyLength { - return nil, ErrHSFrameBodyTooLong - } - hsFrame := make(HSFrame, HSFrameHeaderLength+len(bodyBytes)) - binary.BigEndian.PutUint16(hsFrame, uint16(procID)) - hsFrame[HSFrameProcIDLength] = byte(frameType) - binary.BigEndian.PutUint16(hsFrame[HSFrameProcIDLength+HSFrameTypeLength:], uint16(len(bodyBytes))) + hsFrame.SetProcID(procID) + hsFrame.SetFrameType(frameType) + _ = hsFrame.SetBodyLen(len(bodyBytes)) copy(hsFrame[HSFrameProcIDLength+HSFrameTypeLength+HSFrameBodyLenLength:], bodyBytes) return hsFrame, nil } + +// ProcID gets ProcID from the HSFrame. +func (f HSFrame) ProcID() ProcID { + return ProcID(binary.BigEndian.Uint16(f)) +} + +// SetProcID sets ProcID for the HSFrame. +func (f HSFrame) SetProcID(procID ProcID) { + binary.BigEndian.PutUint16(f, uint16(procID)) +} + +// FrameType gets FrameType from the HSFrame. +func (f HSFrame) FrameType() HSFrameType { + _ = f[HSFrameProcIDLength] // bounds check hint to compiler; see golang.org/issue/14808 + return HSFrameType(f[HSFrameProcIDLength]) +} + +// SetFrameType sets FrameType for the HSFrame. +func (f HSFrame) SetFrameType(frameType HSFrameType) { + _ = f[HSFrameProcIDLength] // bounds check hint to compiler; see golang.org/issue/14808 + f[HSFrameProcIDLength] = byte(frameType) +} + +// BodyLen gets BodyLen from the HSFrame. +func (f HSFrame) BodyLen() int { + return int(binary.BigEndian.Uint16(f[HSFrameProcIDLength+HSFrameTypeLength:])) +} + +// SetBodyLen sets BodyLen for the HSFrame. +func (f HSFrame) SetBodyLen(bodyLen int) error { + if bodyLen > HSFrameMaxBodyLength { + return ErrHSFrameBodyTooLong + } + + binary.BigEndian.PutUint16(f[HSFrameProcIDLength+HSFrameTypeLength:], uint16(bodyLen)) + + return nil +} + +func marshalHSFrameBody(body interface{}) ([]byte, error) { + bodyBytes, err := json.Marshal(body) + if err != nil { + return nil, err + } + + if len(bodyBytes) > HSFrameMaxBodyLength { + return nil, ErrHSFrameBodyTooLong + } + + return bodyBytes, nil +} From 7deb42b280ed99a4399a1e719eb5e8c7362ae16c Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sat, 31 Aug 2019 10:45:08 +0300 Subject: [PATCH 03/43] Fix PR queries --- pkg/app2/hsframe.go | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/pkg/app2/hsframe.go b/pkg/app2/hsframe.go index 05fe44c9bd..e79891082c 100644 --- a/pkg/app2/hsframe.go +++ b/pkg/app2/hsframe.go @@ -8,17 +8,17 @@ import ( ) const ( - HSFrameHeaderLength = 5 - HSFrameProcIDLength = 2 - HSFrameTypeLength = 1 - HSFrameBodyLenLength = 2 - HSFrameMaxBodyLength = math.MaxUint16 + HSFrameHeaderLen = 5 + HSFrameProcIDLen = 2 + HSFrameTypeLen = 1 + HSFrameBodyLenLen = 2 + HSFrameMaxBodyLen = math.MaxUint16 ) var ( - // ErrHSFrameBodyTooLong is being returned when the body is too long to be + // ErrHSFrameBodyTooLarge is being returned when the body is too long to be // fit in the HSFrame - ErrHSFrameBodyTooLong = errors.New("frame body is too long") + ErrHSFrameBodyTooLarge = errors.New("frame body is too long") ) // HSFrameType identifies the type of a handshake frame. @@ -47,13 +47,13 @@ func NewHSFrame(procID ProcID, frameType HSFrameType, body interface{}) (HSFrame return nil, err } - hsFrame := make(HSFrame, HSFrameHeaderLength+len(bodyBytes)) + hsFrame := make(HSFrame, HSFrameHeaderLen+len(bodyBytes)) hsFrame.SetProcID(procID) hsFrame.SetFrameType(frameType) _ = hsFrame.SetBodyLen(len(bodyBytes)) - copy(hsFrame[HSFrameProcIDLength+HSFrameTypeLength+HSFrameBodyLenLength:], bodyBytes) + copy(hsFrame[HSFrameProcIDLen+HSFrameTypeLen+HSFrameBodyLenLen:], bodyBytes) return hsFrame, nil } @@ -70,28 +70,28 @@ func (f HSFrame) SetProcID(procID ProcID) { // FrameType gets FrameType from the HSFrame. func (f HSFrame) FrameType() HSFrameType { - _ = f[HSFrameProcIDLength] // bounds check hint to compiler; see golang.org/issue/14808 - return HSFrameType(f[HSFrameProcIDLength]) + _ = f[HSFrameProcIDLen] // bounds check hint to compiler; see golang.org/issue/14808 + return HSFrameType(f[HSFrameProcIDLen]) } // SetFrameType sets FrameType for the HSFrame. func (f HSFrame) SetFrameType(frameType HSFrameType) { - _ = f[HSFrameProcIDLength] // bounds check hint to compiler; see golang.org/issue/14808 - f[HSFrameProcIDLength] = byte(frameType) + _ = f[HSFrameProcIDLen] // bounds check hint to compiler; see golang.org/issue/14808 + f[HSFrameProcIDLen] = byte(frameType) } // BodyLen gets BodyLen from the HSFrame. func (f HSFrame) BodyLen() int { - return int(binary.BigEndian.Uint16(f[HSFrameProcIDLength+HSFrameTypeLength:])) + return int(binary.BigEndian.Uint16(f[HSFrameProcIDLen+HSFrameTypeLen:])) } // SetBodyLen sets BodyLen for the HSFrame. func (f HSFrame) SetBodyLen(bodyLen int) error { - if bodyLen > HSFrameMaxBodyLength { - return ErrHSFrameBodyTooLong + if bodyLen > HSFrameMaxBodyLen { + return ErrHSFrameBodyTooLarge } - binary.BigEndian.PutUint16(f[HSFrameProcIDLength+HSFrameTypeLength:], uint16(bodyLen)) + binary.BigEndian.PutUint16(f[HSFrameProcIDLen+HSFrameTypeLen:], uint16(bodyLen)) return nil } @@ -102,8 +102,8 @@ func marshalHSFrameBody(body interface{}) ([]byte, error) { return nil, err } - if len(bodyBytes) > HSFrameMaxBodyLength { - return nil, ErrHSFrameBodyTooLong + if len(bodyBytes) > HSFrameMaxBodyLen { + return nil, ErrHSFrameBodyTooLarge } return bodyBytes, nil From 95a0d28436e235721395028de83fc42e46d9b855 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sat, 31 Aug 2019 11:05:44 +0300 Subject: [PATCH 04/43] Add HSFrame tests --- pkg/app2/hsframe_test.go | 69 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 pkg/app2/hsframe_test.go diff --git a/pkg/app2/hsframe_test.go b/pkg/app2/hsframe_test.go new file mode 100644 index 0000000000..ab66e376f9 --- /dev/null +++ b/pkg/app2/hsframe_test.go @@ -0,0 +1,69 @@ +package app2 + +import ( + "encoding/binary" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHSFrame(t *testing.T) { + t.Run("ok", func(t *testing.T) { + body := struct { + Test string `json:"test"` + }{ + Test: "some string", + } + + bodyBytes, err := json.Marshal(body) + require.NoError(t, err) + + procID := ProcID(1) + frameType := HSFrameTypeDMSGListen + bodyLen := len(bodyBytes) + + hsFrame, err := NewHSFrame(procID, frameType, body) + require.NoError(t, err) + + require.Equal(t, len(hsFrame), HSFrameHeaderLen+len(bodyBytes)) + + gotProcID := ProcID(binary.BigEndian.Uint16(hsFrame)) + require.Equal(t, gotProcID, procID) + + gotFrameType := HSFrameType(hsFrame[HSFrameProcIDLen]) + require.Equal(t, gotFrameType, frameType) + + gotBodyLen := int(binary.BigEndian.Uint16(hsFrame[HSFrameProcIDLen+HSFrameTypeLen:])) + require.Equal(t, gotBodyLen, bodyLen) + + require.Equal(t, bodyBytes, []byte(hsFrame[HSFrameProcIDLen+HSFrameTypeLen+HSFrameBodyLenLen:])) + + gotProcID = hsFrame.ProcID() + require.Equal(t, gotProcID, procID) + + gotFrameType = hsFrame.FrameType() + require.Equal(t, gotFrameType, frameType) + + gotBodyLen = hsFrame.BodyLen() + require.Equal(t, gotBodyLen, bodyLen) + }) + + t.Run("fail - too large body", func(t *testing.T) { + body := struct { + Test string `json:"test"` + }{ + Test: "some string", + } + + for len(body.Test) <= HSFrameMaxBodyLen { + body.Test += body.Test + } + + procID := ProcID(1) + frameType := HSFrameTypeDMSGListen + + _, err := NewHSFrame(procID, frameType, body) + require.Equal(t, err, ErrHSFrameBodyTooLarge) + }) +} From 804b770056a8ea81dbb6ec3b8ca27a6d231336c0 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sun, 1 Sep 2019 19:46:27 +0300 Subject: [PATCH 05/43] Add more work --- pkg/app2/client.go | 104 +++++++++++++++++++++++++++++++++++++ pkg/app2/hsframe.go | 121 ++++++++++++++++++++++++++------------------ pkg/app2/server.go | 114 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 290 insertions(+), 49 deletions(-) diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 658af578d0..0ba012dcc1 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -1 +1,105 @@ package app2 + +import ( + "net" + + "github.com/pkg/errors" + + "github.com/skycoin/skywire/pkg/routing" + + "github.com/skycoin/dmsg/cipher" +) + +var ( + ErrWrongHSFrameTypeReceived = errors.New("received wrong HS frame type") +) + +type Listener struct { + conn net.Conn +} + +func (l *Listener) Accept() (net.Conn, error) { + hsFrame, err := readHSFrame(l.conn) + if err != nil { + return nil, errors.Wrap(err, "error reading HS frame") + } + + if hsFrame.FrameType() != HSFrameTypeDMSGAccept { + return nil, ErrWrongHSFrameTypeReceived + } + + return l.conn, nil +} + +// Client is used by skywire apps. +type Client struct { + PK cipher.PubKey + pid ProcID + sockAddr string + conn net.Conn +} + +// NewClient creates a new Client. The Client needs to be provided with: +// - localPK: The local public key of the parent skywire visor. +// - pid: The procID assigned for the process that Client is being used by. +// - sockAddr: The socket address to connect to Server. +func NewClient(localPK cipher.PubKey, pid ProcID, sockAddr string) (*Client, error) { + return &Client{ + PK: localPK, + pid: pid, + sockAddr: sockAddr, + }, nil +} + +func (c *Client) Dial(addr routing.Addr) (net.Conn, error) { + conn, err := net.Dial("unix", c.sockAddr) + if err != nil { + return nil, errors.Wrap(err, "error connecting app server") + } + + hsFrame := NewHSFrameDSMGDial(c.pid, routing.Loop{ + Local: routing.Addr{ + PubKey: c.PK, + }, + Remote: addr, + }) + if _, err := conn.Write(hsFrame); err != nil { + return nil, errors.Wrap(err, "error writing HS frame") + } + + hsFrame, err = readHSFrame(conn) + if err != nil { + return nil, errors.Wrap(err, "error reading HS frame") + } + + if hsFrame.FrameType() != HSFrameTypeDMSGAccept { + return nil, ErrWrongHSFrameTypeReceived + } + + return conn, nil +} + +func (c *Client) Listen(addr routing.Addr) (*Listener, error) { + conn, err := net.Dial("unix", c.sockAddr) + if err != nil { + return nil, errors.Wrap(err, "error connecting app server") + } + + hsFrame := NewHSFrameDMSGListen(c.pid, addr) + if _, err := conn.Write(hsFrame); err != nil { + return nil, errors.Wrap(err, "error writing HS frame") + } + + hsFrame, err = readHSFrame(conn) + if err != nil { + return nil, errors.Wrap(err, "error reading HS frame") + } + + if hsFrame.FrameType() != HSFrameTypeDMSGListening { + return nil, ErrWrongHSFrameTypeReceived + } + + return &Listener{ + conn: conn, + }, nil +} diff --git a/pkg/app2/hsframe.go b/pkg/app2/hsframe.go index e79891082c..8b5f0277fb 100644 --- a/pkg/app2/hsframe.go +++ b/pkg/app2/hsframe.go @@ -2,23 +2,18 @@ package app2 import ( "encoding/binary" - "encoding/json" - "errors" - "math" -) + "io" -const ( - HSFrameHeaderLen = 5 - HSFrameProcIDLen = 2 - HSFrameTypeLen = 1 - HSFrameBodyLenLen = 2 - HSFrameMaxBodyLen = math.MaxUint16 + "github.com/pkg/errors" + "github.com/skycoin/skywire/pkg/routing" ) -var ( - // ErrHSFrameBodyTooLarge is being returned when the body is too long to be - // fit in the HSFrame - ErrHSFrameBodyTooLarge = errors.New("frame body is too long") +const ( + HSFrameHeaderLen = 3 + HSFrameProcIDLen = 2 + HSFrameTypeLen = 1 + HSFramePKLen = 33 + HSFramePortLen = 2 ) // HSFrameType identifies the type of a handshake frame. @@ -34,28 +29,59 @@ const ( // HSFrame is the data unit for socket connection handshakes between Server and Client. // It consists of header and body. // -// Header is a big-endian encoded 5 bytes and is constructed as follows: -// | ProcID (2 bytes) | HSFrameType (1 byte) | BodyLen (2 bytes) | -// -// Body is a marshaled JSON structure +// Header is a big-endian encoded 3 bytes and is constructed as follows: +// | ProcID (2 bytes) | HSFrameType (1 byte) | type HSFrame []byte -// NewHSFrame constructs new HSFrame. -func NewHSFrame(procID ProcID, frameType HSFrameType, body interface{}) (HSFrame, error) { - bodyBytes, err := marshalHSFrameBody(body) - if err != nil { - return nil, err - } - - hsFrame := make(HSFrame, HSFrameHeaderLen+len(bodyBytes)) +func newHSFrame(procID ProcID, frameType HSFrameType, bodyLen int) HSFrame { + hsFrame := make(HSFrame, HSFrameHeaderLen+bodyLen) hsFrame.SetProcID(procID) hsFrame.SetFrameType(frameType) - _ = hsFrame.SetBodyLen(len(bodyBytes)) - copy(hsFrame[HSFrameProcIDLen+HSFrameTypeLen+HSFrameBodyLenLen:], bodyBytes) + return hsFrame +} - return hsFrame, nil +func NewHSFrameDMSGListen(procID ProcID, local routing.Addr) HSFrame { + hsFrame := newHSFrame(procID, HSFrameTypeDMSGListen, HSFramePKLen+HSFramePortLen) + + copy(hsFrame[HSFrameHeaderLen:], local.PubKey[:]) + binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(local.Port)) + + return hsFrame +} + +func NewHSFrameDMSGListening(procID ProcID, local routing.Addr) HSFrame { + hsFrame := newHSFrame(procID, HSFrameTypeDMSGListening, HSFramePKLen+HSFramePortLen) + + copy(hsFrame[HSFrameHeaderLen:], local.PubKey[:]) + binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(local.Port)) + + return hsFrame +} + +func NewHSFrameDSMGDial(procID ProcID, loop routing.Loop) HSFrame { + hsFrame := newHSFrame(procID, HSFrameTypeDMSGDial, 2*HSFramePKLen+2*HSFramePortLen) + + copy(hsFrame[HSFrameHeaderLen:], loop.Local.PubKey[:]) + binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(loop.Local.Port)) + + copy(hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen:], loop.Remote.PubKey[:]) + binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+2*HSFramePKLen+HSFramePortLen:], uint16(loop.Remote.Port)) + + return hsFrame +} + +func NewHSFrameDMSGAccept(procID ProcID, loop routing.Loop) HSFrame { + hsFrame := newHSFrame(procID, HSFrameTypeDMSGAccept, 2*HSFramePKLen+2*HSFramePortLen) + + copy(hsFrame[HSFrameHeaderLen:], loop.Local.PubKey[:]) + binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(loop.Local.Port)) + + copy(hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen:], loop.Remote.PubKey[:]) + binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+2*HSFramePKLen+HSFramePortLen:], uint16(loop.Remote.Port)) + + return hsFrame } // ProcID gets ProcID from the HSFrame. @@ -80,31 +106,28 @@ func (f HSFrame) SetFrameType(frameType HSFrameType) { f[HSFrameProcIDLen] = byte(frameType) } -// BodyLen gets BodyLen from the HSFrame. -func (f HSFrame) BodyLen() int { - return int(binary.BigEndian.Uint16(f[HSFrameProcIDLen+HSFrameTypeLen:])) -} - -// SetBodyLen sets BodyLen for the HSFrame. -func (f HSFrame) SetBodyLen(bodyLen int) error { - if bodyLen > HSFrameMaxBodyLen { - return ErrHSFrameBodyTooLarge +func readHSFrame(r io.Reader) (HSFrame, error) { + hsFrame := make(HSFrame, HSFrameHeaderLen) + if _, err := io.ReadFull(r, hsFrame); err != nil { + return nil, errors.Wrap(err, "error reading HS frame header") } - binary.BigEndian.PutUint16(f[HSFrameProcIDLen+HSFrameTypeLen:], uint16(bodyLen)) - - return nil -} - -func marshalHSFrameBody(body interface{}) ([]byte, error) { - bodyBytes, err := json.Marshal(body) + hsFrame, err := readHSFrameBody(hsFrame, r) if err != nil { - return nil, err + return nil, errors.Wrap(err, "error reading HS frame body") } - if len(bodyBytes) > HSFrameMaxBodyLen { - return nil, ErrHSFrameBodyTooLarge + return hsFrame, nil +} + +func readHSFrameBody(hsFrame HSFrame, r io.Reader) (HSFrame, error) { + switch hsFrame.FrameType() { + case HSFrameTypeDMSGListen, HSFrameTypeDMSGListening: + hsFrame = append(hsFrame, make([]byte, HSFramePKLen+HSFramePortLen)...) + case HSFrameTypeDMSGDial, HSFrameTypeDMSGAccept: + hsFrame = append(hsFrame, make([]byte, 2*HSFramePKLen+2*HSFramePortLen)...) } - return bodyBytes, nil + _, err := io.ReadFull(r, hsFrame[HSFrameHeaderLen:]) + return hsFrame, err } diff --git a/pkg/app2/server.go b/pkg/app2/server.go index 658af578d0..dd0b7adb79 100644 --- a/pkg/app2/server.go +++ b/pkg/app2/server.go @@ -1 +1,115 @@ package app2 + +import ( + "encoding/binary" + "fmt" + "net" + "sync" + + "github.com/skycoin/dmsg" + + "github.com/skycoin/skywire/pkg/routing" + + "github.com/pkg/errors" + + "github.com/skycoin/dmsg/cipher" +) + +// Server is used by skywire visor. +type Server struct { + PK cipher.PubKey + dmsgC *dmsg.Client + apps map[string]net.Conn + appsMx sync.Mutex +} + +func NewServer(localPK cipher.PubKey, dmsgC *dmsg.Client) *Server { + return &Server{ + PK: localPK, + dmsgC: dmsgC, + } +} + +func (s *Server) Serve(sockAddr string) error { + l, err := net.Listen("unix", sockAddr) + if err != nil { + return errors.Wrap(err, "error listening unix socket") + } + + for { + conn, err := l.Accept() + if err != nil { + return err + } + + s.appsMx.Lock() + s.apps[conn.RemoteAddr().String()] = conn + s.appsMx.Unlock() + + // TODO: handle error + go s.serveConn(conn) + } +} + +func (s *Server) serveConn(conn net.Conn) error { + var hsFinished bool + + for { + hsFrame, err := readHSFrame(conn) + if err != nil { + return errors.Wrap(err, "error reading HS frame") + } + + switch hsFrame.FrameType() { + case HSFrameTypeDMSGListen: + pk := make(cipher.PubKey, 33) + copy(pk, hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) + port := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]) + dmsgL, err := s.dmsgC.Listen(port) + if err != nil { + return fmt.Errorf("error listening on port %d: %v", port, err) + } + + respHSFrame := NewHSFrameDMSGListening(hsFrame.ProcID(), routing.Addr{ + PubKey: cipher.PubKey(pk), + Port: 0, + }) + } + } +} + +func (s *Server) handleDMSGListen(frame HSFrame) error { + var local routing.Addr + if err := frame.UnmarshalBody(&local); err != nil { + return errors.Wrap(err, "invalid JSON body") + } + + // TODO: check `local` for validity + + dmsgL, err := s.dmsgC.Listen(uint16(local.Port)) + if err != nil { + return fmt.Errorf("error listening on port %d: %v", local.Port, err) + } + +} + +func (s *Server) handleDMSGListening(frame HSFrame) error { + var local routing.Addr + if err := frame.UnmarshalBody(&local); err != nil { + return errors.Wrap(err, "invalid JSON body") + } +} + +func (s *Server) handleDMSGDial(frame HSFrame) error { + var loop routing.Loop + if err := frame.UnmarshalBody(&loop); err != nil { + return errors.Wrap(err, "invalid JSON body") + } +} + +func (s *Server) handleDMSGAccept(frame HSFrame) error { + var loop routing.Loop + if err := frame.UnmarshalBody(&loop); err != nil { + return errors.Wrap(err, "invalid JSON body") + } +} From 5029589a010cb15de0ca6bc59959e2eb93d8af81 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Fri, 6 Sep 2019 14:14:52 +0300 Subject: [PATCH 06/43] Add yamux conn multiplexing --- pkg/app2/client.go | 103 ++++++++++++++++++++++++---------- pkg/app2/listener.go | 58 +++++++++++++++++++ pkg/app2/listeners_manager.go | 65 +++++++++++++++++++++ 3 files changed, 196 insertions(+), 30 deletions(-) create mode 100644 pkg/app2/listener.go create mode 100644 pkg/app2/listeners_manager.go diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 0ba012dcc1..d17cb923db 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -1,10 +1,14 @@ package app2 import ( + "encoding/binary" "net" + "github.com/hashicorp/yamux" + "github.com/pkg/errors" + "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/dmsg/cipher" @@ -14,29 +18,15 @@ var ( ErrWrongHSFrameTypeReceived = errors.New("received wrong HS frame type") ) -type Listener struct { - conn net.Conn -} - -func (l *Listener) Accept() (net.Conn, error) { - hsFrame, err := readHSFrame(l.conn) - if err != nil { - return nil, errors.Wrap(err, "error reading HS frame") - } - - if hsFrame.FrameType() != HSFrameTypeDMSGAccept { - return nil, ErrWrongHSFrameTypeReceived - } - - return l.conn, nil -} - // Client is used by skywire apps. type Client struct { PK cipher.PubKey pid ProcID sockAddr string conn net.Conn + session *yamux.Session + logger *logging.Logger + lm *listenersManager } // NewClient creates a new Client. The Client needs to be provided with: @@ -44,17 +34,30 @@ type Client struct { // - pid: The procID assigned for the process that Client is being used by. // - sockAddr: The socket address to connect to Server. func NewClient(localPK cipher.PubKey, pid ProcID, sockAddr string) (*Client, error) { + conn, err := net.Dial("unix", sockAddr) + if err != nil { + return nil, errors.Wrap(err, "error connecting app server") + } + + session, err := yamux.Client(conn, nil) + if err != nil { + return nil, errors.Wrap(err, "error opening yamux session") + } + return &Client{ PK: localPK, pid: pid, sockAddr: sockAddr, + conn: conn, + session: session, + lm: newListenersManager(), }, nil } func (c *Client) Dial(addr routing.Addr) (net.Conn, error) { - conn, err := net.Dial("unix", c.sockAddr) + stream, err := c.session.Open() if err != nil { - return nil, errors.Wrap(err, "error connecting app server") + return nil, errors.Wrap(err, "error opening stream") } hsFrame := NewHSFrameDSMGDial(c.pid, routing.Loop{ @@ -63,11 +66,12 @@ func (c *Client) Dial(addr routing.Addr) (net.Conn, error) { }, Remote: addr, }) - if _, err := conn.Write(hsFrame); err != nil { + + if _, err := stream.Write(hsFrame); err != nil { return nil, errors.Wrap(err, "error writing HS frame") } - hsFrame, err = readHSFrame(conn) + hsFrame, err = readHSFrame(stream) if err != nil { return nil, errors.Wrap(err, "error reading HS frame") } @@ -76,21 +80,30 @@ func (c *Client) Dial(addr routing.Addr) (net.Conn, error) { return nil, ErrWrongHSFrameTypeReceived } - return conn, nil + return stream, nil } -func (c *Client) Listen(addr routing.Addr) (*Listener, error) { - conn, err := net.Dial("unix", c.sockAddr) +func (c *Client) Listen(port routing.Port) (*Listener, error) { + if c.lm.portIsBound(port) { + return nil, ErrPortAlreadyBound + } + + stream, err := c.session.Open() if err != nil { - return nil, errors.Wrap(err, "error connecting app server") + return nil, errors.Wrap(err, "error opening stream") + } + + addr := routing.Addr{ + PubKey: c.PK, + Port: port, } hsFrame := NewHSFrameDMSGListen(c.pid, addr) - if _, err := conn.Write(hsFrame); err != nil { + if _, err := stream.Write(hsFrame); err != nil { return nil, errors.Wrap(err, "error writing HS frame") } - hsFrame, err = readHSFrame(conn) + hsFrame, err = readHSFrame(stream) if err != nil { return nil, errors.Wrap(err, "error reading HS frame") } @@ -99,7 +112,37 @@ func (c *Client) Listen(addr routing.Addr) (*Listener, error) { return nil, ErrWrongHSFrameTypeReceived } - return &Listener{ - conn: conn, - }, nil + l := NewListener(addr, c.lm) + if err := c.lm.add(port, l); err != nil { + return nil, err + } + + return l, nil +} + +func (c *Client) listen() error { + for { + stream, err := c.session.Accept() + if err != nil { + return errors.Wrap(err, "error accepting stream") + } + + hsFrame, err := readHSFrame(stream) + if err != nil { + c.logger.WithError(err).Error("error reading HS frame") + continue + } + + if hsFrame.FrameType() != HSFrameTypeDMSGDial { + c.logger.WithError(ErrWrongHSFrameTypeReceived).Error("on listening for Dial") + continue + } + + // TODO: handle field get gracefully + port := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) + if err := c.lm.addConn(port, stream); err != nil { + c.logger.WithError(err).Error("failed to accept") + continue + } + } } diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go new file mode 100644 index 0000000000..069fcdb0a9 --- /dev/null +++ b/pkg/app2/listener.go @@ -0,0 +1,58 @@ +package app2 + +import ( + "errors" + "net" + + "github.com/skycoin/skywire/pkg/routing" +) + +const ( + listenerBufSize = 1000 +) + +var ( + ErrListenerClosed = errors.New("listener closed") +) + +type Listener struct { + addr routing.Addr + conns chan net.Conn + lm *listenersManager +} + +func NewListener(addr routing.Addr, lm *listenersManager) *Listener { + return &Listener{ + addr: addr, + conns: make(chan net.Conn, listenerBufSize), + lm: lm, + } +} + +func (l *Listener) Accept() (net.Conn, error) { + conn, ok := <-l.conns + if !ok { + return nil, ErrListenerClosed + } + + return conn, nil +} + +func (l *Listener) Close() error { + if err := l.lm.remove(l.addr.Port); err != nil { + return err + } + + // TODO: send ListenEnd frame + close(l.conns) + + return nil +} + +func (l *Listener) Addr() net.Addr { + return l.addr +} + +func (l *Listener) addConn(conn net.Conn) { + l.conns <- conn +} diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go new file mode 100644 index 0000000000..b5bc8e3b83 --- /dev/null +++ b/pkg/app2/listeners_manager.go @@ -0,0 +1,65 @@ +package app2 + +import ( + "net" + "sync" + + "github.com/pkg/errors" + "github.com/skycoin/skywire/pkg/routing" +) + +var ( + ErrPortAlreadyBound = errors.New("port is already bound") + ErrNoListenerOnPort = errors.New("no listener on port") +) + +type listenersManager struct { + listeners map[routing.Port]*Listener + mx sync.RWMutex +} + +func newListenersManager() *listenersManager { + return &listenersManager{ + listeners: make(map[routing.Port]*Listener), + } +} + +func (lm *listenersManager) portIsBound(port routing.Port) bool { + lm.mx.RLock() + _, ok := lm.listeners[port] + lm.mx.RUnlock() + return ok +} + +func (lm *listenersManager) add(port routing.Port, l *Listener) error { + lm.mx.Lock() + if _, ok := lm.listeners[port]; ok { + lm.mx.Unlock() + return ErrPortAlreadyBound + } + lm.listeners[port] = l + lm.mx.Unlock() + return nil +} + +func (lm *listenersManager) remove(port routing.Port) error { + lm.mx.Lock() + if _, ok := lm.listeners[port]; !ok { + lm.mx.Unlock() + return ErrNoListenerOnPort + } + delete(lm.listeners, port) + lm.mx.Unlock() + return nil +} + +func (lm *listenersManager) addConn(port routing.Port, conn net.Conn) error { + lm.mx.RLock() + if _, ok := lm.listeners[port]; !ok { + lm.mx.RUnlock() + return ErrNoListenerOnPort + } + lm.listeners[port].addConn(conn) + lm.mx.RUnlock() + return nil +} From 972fa9912a4ac164d44bedb2237d7eab891400a1 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sun, 8 Sep 2019 17:26:13 +0300 Subject: [PATCH 07/43] Add some shape to the client --- pkg/app2/client.go | 56 ++++++++++++++++++------ pkg/app2/hsframe.go | 10 +++++ pkg/app2/listener.go | 25 +++++++---- pkg/app2/listeners_manager.go | 19 +++++--- pkg/app2/server.go | 82 ++++++++++++++++++++++++++++++----- 5 files changed, 154 insertions(+), 38 deletions(-) diff --git a/pkg/app2/client.go b/pkg/app2/client.go index d17cb923db..4e2e5d754d 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -3,6 +3,7 @@ package app2 import ( "encoding/binary" "net" + "sync/atomic" "github.com/hashicorp/yamux" @@ -20,20 +21,21 @@ var ( // Client is used by skywire apps. type Client struct { - PK cipher.PubKey - pid ProcID - sockAddr string - conn net.Conn - session *yamux.Session - logger *logging.Logger - lm *listenersManager + PK cipher.PubKey + pid ProcID + sockAddr string + conn net.Conn + session *yamux.Session + logger *logging.Logger + lm *listenersManager + isListening int32 } // NewClient creates a new Client. The Client needs to be provided with: // - localPK: The local public key of the parent skywire visor. // - pid: The procID assigned for the process that Client is being used by. // - sockAddr: The socket address to connect to Server. -func NewClient(localPK cipher.PubKey, pid ProcID, sockAddr string) (*Client, error) { +func NewClient(localPK cipher.PubKey, pid ProcID, sockAddr string, l *logging.Logger) (*Client, error) { conn, err := net.Dial("unix", sockAddr) if err != nil { return nil, errors.Wrap(err, "error connecting app server") @@ -44,13 +46,15 @@ func NewClient(localPK cipher.PubKey, pid ProcID, sockAddr string) (*Client, err return nil, errors.Wrap(err, "error opening yamux session") } + lm := newListenersManager() + return &Client{ PK: localPK, pid: pid, sockAddr: sockAddr, conn: conn, session: session, - lm: newListenersManager(), + lm: lm, }, nil } @@ -112,12 +116,15 @@ func (c *Client) Listen(port routing.Port) (*Listener, error) { return nil, ErrWrongHSFrameTypeReceived } - l := NewListener(addr, c.lm) - if err := c.lm.add(port, l); err != nil { - return nil, err + if atomic.CompareAndSwapInt32(&c.isListening, 0, 1) { + go func() { + if err := c.listen(); err != nil { + c.logger.WithError(err).Error("error listening") + } + }() } - return l, nil + return c.lm.add(addr, c.stopListening, c.logger) } func (c *Client) listen() error { @@ -146,3 +153,26 @@ func (c *Client) listen() error { } } } + +func (c *Client) stopListening(port routing.Port) error { + stream, err := c.session.Open() + if err != nil { + return errors.Wrap(err, "error opening stream") + } + + addr := routing.Addr{ + PubKey: c.PK, + Port: port, + } + + hsFrame := NewHSFrameDMSGStopListening(c.pid, addr) + if _, err := stream.Write(hsFrame); err != nil { + return errors.Wrap(err, "error writing HS frame") + } + + if err := stream.Close(); err != nil { + return errors.Wrap(err, "error closing stream") + } + + return nil +} diff --git a/pkg/app2/hsframe.go b/pkg/app2/hsframe.go index 8b5f0277fb..c174363ff2 100644 --- a/pkg/app2/hsframe.go +++ b/pkg/app2/hsframe.go @@ -24,6 +24,7 @@ const ( HSFrameTypeDMSGListening HSFrameTypeDMSGDial HSFrameTypeDMSGAccept + HSFrameTypeStopListening ) // HSFrame is the data unit for socket connection handshakes between Server and Client. @@ -84,6 +85,15 @@ func NewHSFrameDMSGAccept(procID ProcID, loop routing.Loop) HSFrame { return hsFrame } +func NewHSFrameDMSGStopListening(procID ProcID, local routing.Addr) HSFrame { + hsFrame := newHSFrame(procID, HSFrameTypeDMSGListen, HSFramePKLen+HSFramePortLen) + + copy(hsFrame[HSFrameHeaderLen:], local.PubKey[:]) + binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(local.Port)) + + return hsFrame +} + // ProcID gets ProcID from the HSFrame. func (f HSFrame) ProcID() ProcID { return ProcID(binary.BigEndian.Uint16(f)) diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index 069fcdb0a9..1a5c14b6ae 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -4,6 +4,8 @@ import ( "errors" "net" + "github.com/skycoin/skycoin/src/util/logging" + "github.com/skycoin/skywire/pkg/routing" ) @@ -16,16 +18,20 @@ var ( ) type Listener struct { - addr routing.Addr - conns chan net.Conn - lm *listenersManager + addr routing.Addr + conns chan net.Conn + stopListening func(port routing.Port) error + logger *logging.Logger + lm *listenersManager } -func NewListener(addr routing.Addr, lm *listenersManager) *Listener { +func NewListener(addr routing.Addr, lm *listenersManager, stopListening func(port routing.Port) error, l *logging.Logger) *Listener { return &Listener{ - addr: addr, - conns: make(chan net.Conn, listenerBufSize), - lm: lm, + addr: addr, + conns: make(chan net.Conn, listenerBufSize), + lm: lm, + stopListening: stopListening, + logger: l, } } @@ -39,11 +45,14 @@ func (l *Listener) Accept() (net.Conn, error) { } func (l *Listener) Close() error { + if err := l.stopListening(l.addr.Port); err != nil { + l.logger.WithError(err).Error("error sending DmsgStopListening") + } + if err := l.lm.remove(l.addr.Port); err != nil { return err } - // TODO: send ListenEnd frame close(l.conns) return nil diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go index b5bc8e3b83..251c8b2af0 100644 --- a/pkg/app2/listeners_manager.go +++ b/pkg/app2/listeners_manager.go @@ -4,18 +4,22 @@ import ( "net" "sync" + "github.com/skycoin/skycoin/src/util/logging" + "github.com/pkg/errors" "github.com/skycoin/skywire/pkg/routing" ) var ( - ErrPortAlreadyBound = errors.New("port is already bound") - ErrNoListenerOnPort = errors.New("no listener on port") + ErrPortAlreadyBound = errors.New("port is already bound") + ErrNoListenerOnPort = errors.New("no listener on port") + ErrListenersManagerAlreadyServing = errors.New("listeners manager already serving") ) type listenersManager struct { listeners map[routing.Port]*Listener mx sync.RWMutex + isServing int32 } func newListenersManager() *listenersManager { @@ -31,15 +35,16 @@ func (lm *listenersManager) portIsBound(port routing.Port) bool { return ok } -func (lm *listenersManager) add(port routing.Port, l *Listener) error { +func (lm *listenersManager) add(addr routing.Addr, stopListening func(port routing.Port) error, logger *logging.Logger) (*Listener, error) { lm.mx.Lock() - if _, ok := lm.listeners[port]; ok { + if _, ok := lm.listeners[addr.Port]; ok { lm.mx.Unlock() - return ErrPortAlreadyBound + return nil, ErrPortAlreadyBound } - lm.listeners[port] = l + l := NewListener(addr, lm, stopListening, logger) + lm.listeners[addr.Port] = l lm.mx.Unlock() - return nil + return l, nil } func (lm *listenersManager) remove(port routing.Port) error { diff --git a/pkg/app2/server.go b/pkg/app2/server.go index dd0b7adb79..163f46d9a9 100644 --- a/pkg/app2/server.go +++ b/pkg/app2/server.go @@ -6,6 +6,10 @@ import ( "net" "sync" + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/hashicorp/yamux" + "github.com/skycoin/dmsg" "github.com/skycoin/skywire/pkg/routing" @@ -15,18 +19,28 @@ import ( "github.com/skycoin/dmsg/cipher" ) +type clientConn struct { + conn net.Conn + session *yamux.Session + lm *listenersManager + dmsgL *dmsg.Listener +} + // Server is used by skywire visor. type Server struct { PK cipher.PubKey dmsgC *dmsg.Client - apps map[string]net.Conn + apps map[string]*clientConn appsMx sync.Mutex + logger *logging.Logger } -func NewServer(localPK cipher.PubKey, dmsgC *dmsg.Client) *Server { +func NewServer(localPK cipher.PubKey, dmsgC *dmsg.Client, l *logging.Logger) *Server { return &Server{ - PK: localPK, - dmsgC: dmsgC, + PK: localPK, + dmsgC: dmsgC, + apps: make(map[string]*clientConn), + logger: l, } } @@ -39,22 +53,40 @@ func (s *Server) Serve(sockAddr string) error { for { conn, err := l.Accept() if err != nil { - return err + return errors.Wrap(err, "error accepting client connection") } s.appsMx.Lock() - s.apps[conn.RemoteAddr().String()] = conn + if _, ok := s.apps[conn.RemoteAddr().String()]; ok { + s.logger.WithError(ErrPortAlreadyBound).Error("error storing session") + } + + session, err := yamux.Server(conn, nil) + if err != nil { + return errors.Wrap(err, "error creating yamux session") + } + + s.apps[conn.RemoteAddr().String()] = &clientConn{ + session: session, + conn: conn, + lm: newListenersManager(), + } s.appsMx.Unlock() // TODO: handle error - go s.serveConn(conn) + go s.serveClient(session) } } -func (s *Server) serveConn(conn net.Conn) error { - var hsFinished bool - +func (s *Server) serveClient(session *yamux.Session) error { for { + stream, err := session.Accept() + if err != nil { + return errors.Wrap(err, "error opening stream") + } + + go s.serveStream(stream) + hsFrame, err := readHSFrame(conn) if err != nil { return errors.Wrap(err, "error reading HS frame") @@ -78,6 +110,36 @@ func (s *Server) serveConn(conn net.Conn) error { } } +func (s *Server) serveStream(stream net.Conn) error { + for { + hsFrame, err := readHSFrame(stream) + if err != nil { + return errors.Wrap(err, "error reading HS frame") + } + + switch hsFrame.FrameType() { + case HSFrameTypeDMSGListen: + var pk cipher.PubKey + copy(pk[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) + port := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]) + dmsgL, err := s.dmsgC.Listen(port) + if err != nil { + return fmt.Errorf("error listening on port %d: %v", port, err) + } + + respHSFrame := NewHSFrameDMSGListening(hsFrame.ProcID(), routing.Addr{ + PubKey: pk, + Port: routing.Port(port), + }) + + if _, err := stream.Write(respHSFrame); err != nil { + return errors.Wrap(err, "error writing response") + } + + } + } +} + func (s *Server) handleDMSGListen(frame HSFrame) error { var local routing.Addr if err := frame.UnmarshalBody(&local); err != nil { From 0a6506d93d47ce78454ff9301de994455c7d821a Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Mon, 9 Sep 2019 15:56:46 +0300 Subject: [PATCH 08/43] Add some work on server --- pkg/app2/hsframe.go | 7 +++ pkg/app2/server.go | 135 ++++++++++++++++++++++++++++++++++++-------- 2 files changed, 118 insertions(+), 24 deletions(-) diff --git a/pkg/app2/hsframe.go b/pkg/app2/hsframe.go index c174363ff2..975589edd6 100644 --- a/pkg/app2/hsframe.go +++ b/pkg/app2/hsframe.go @@ -25,6 +25,7 @@ const ( HSFrameTypeDMSGDial HSFrameTypeDMSGAccept HSFrameTypeStopListening + HSFrameTypeError ) // HSFrame is the data unit for socket connection handshakes between Server and Client. @@ -94,6 +95,12 @@ func NewHSFrameDMSGStopListening(procID ProcID, local routing.Addr) HSFrame { return hsFrame } +func NewHSFrameError(procID ProcID) HSFrame { + hsFrame := newHSFrame(procID, HSFrameTypeError, 0) + + return hsFrame +} + // ProcID gets ProcID from the HSFrame. func (f HSFrame) ProcID() ProcID { return ProcID(binary.BigEndian.Uint16(f)) diff --git a/pkg/app2/server.go b/pkg/app2/server.go index 163f46d9a9..1c24fd22c4 100644 --- a/pkg/app2/server.go +++ b/pkg/app2/server.go @@ -1,8 +1,10 @@ package app2 import ( + "context" "encoding/binary" "fmt" + "io" "net" "sync" @@ -20,10 +22,11 @@ import ( ) type clientConn struct { - conn net.Conn - session *yamux.Session - lm *listenersManager - dmsgL *dmsg.Listener + conn net.Conn + session *yamux.Session + lm *listenersManager + dmsgListeners map[routing.Port]*dmsg.Listener + dmsgListenersMx sync.RWMutex } // Server is used by skywire visor. @@ -31,7 +34,7 @@ type Server struct { PK cipher.PubKey dmsgC *dmsg.Client apps map[string]*clientConn - appsMx sync.Mutex + appsMx sync.RWMutex logger *logging.Logger } @@ -67,9 +70,10 @@ func (s *Server) Serve(sockAddr string) error { } s.apps[conn.RemoteAddr().String()] = &clientConn{ - session: session, - conn: conn, - lm: newListenersManager(), + session: session, + conn: conn, + lm: newListenersManager(), + dmsgListeners: make(map[routing.Port]*dmsg.Listener), } s.appsMx.Unlock() @@ -78,15 +82,15 @@ func (s *Server) Serve(sockAddr string) error { } } -func (s *Server) serveClient(session *yamux.Session) error { +func (s *Server) serveClient(conn *clientConn) error { for { - stream, err := session.Accept() + stream, err := conn.session.Accept() if err != nil { return errors.Wrap(err, "error opening stream") } - go s.serveStream(stream) - + go s.serveStream(stream, conn) +/////////////////////////////// hsFrame, err := readHSFrame(conn) if err != nil { return errors.Wrap(err, "error reading HS frame") @@ -94,6 +98,7 @@ func (s *Server) serveClient(session *yamux.Session) error { switch hsFrame.FrameType() { case HSFrameTypeDMSGListen: + if s. pk := make(cipher.PubKey, 33) copy(pk, hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) port := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]) @@ -110,36 +115,118 @@ func (s *Server) serveClient(session *yamux.Session) error { } } -func (s *Server) serveStream(stream net.Conn) error { +func (s *Server) serveStream(stream net.Conn, conn *clientConn) error { for { hsFrame, err := readHSFrame(stream) if err != nil { return errors.Wrap(err, "error reading HS frame") } + var respHSFrame HSFrame switch hsFrame.FrameType() { case HSFrameTypeDMSGListen: - var pk cipher.PubKey - copy(pk[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) port := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]) - dmsgL, err := s.dmsgC.Listen(port) - if err != nil { - return fmt.Errorf("error listening on port %d: %v", port, err) + if err := conn.reserveListener(routing.Port(port)); err != nil { + respHSFrame = NewHSFrameError(hsFrame.ProcID()) + } else { + dmsgL, err := s.dmsgC.Listen(port) + if err != nil { + respHSFrame = NewHSFrameError(hsFrame.ProcID()) + } else { + if err := conn.addListener(routing.Port(port), dmsgL); err != nil { + respHSFrame = NewHSFrameError(hsFrame.ProcID()) + } else { + var pk cipher.PubKey + copy(pk[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) + + respHSFrame = NewHSFrameDMSGListening(hsFrame.ProcID(), routing.Addr{ + PubKey: pk, + Port: routing.Port(port), + }) + } + } } + case HSFrameTypeDMSGDial: + localPort := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]) + var localPK cipher.PubKey + copy(localPK[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) - respHSFrame := NewHSFrameDMSGListening(hsFrame.ProcID(), routing.Addr{ - PubKey: pk, - Port: routing.Port(port), - }) + var remotePK cipher.PubKey + copy(remotePK[:], hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen:HSFrameHeaderLen+HSFramePKLen+HSFramePortLen+HSFramePKLen]) + remotePort := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen+HSFramePKLen:]) - if _, err := stream.Write(respHSFrame); err != nil { - return errors.Wrap(err, "error writing response") + // TODO: context + tp, err := s.dmsgC.Dial(context.Background(), localPK, localPort) + if err != nil { + respHSFrame = NewHSFrameError(hsFrame.ProcID()) + } else { + respHSFrame = NewHSFrameDMSGAccept(hsFrame.ProcID(), routing.Loop{ + Local: routing.Addr{ + PubKey: localPK, + Port: routing.Port(localPort), + }, + Remote: routing.Addr{ + PubKey: remotePK, + Port: routing.Port(remotePort), + }, + }) + + go func() { + if err := s.forwardOverDMSG(stream, tp); err != nil { + s.logger.WithError(err).Error("error forwarding over DMSG") + } + }() } + } + if _, err := stream.Write(respHSFrame); err != nil { + return errors.Wrap(err, "error writing response") } } } +func (s *Server) forwardOverDMSG(stream net.Conn, tp *dmsg.Transport) error { + toStreamErrCh := make(chan error) + defer close(toStreamErrCh) + go func() { + _, err := io.Copy(stream, tp) + toStreamErrCh <- err + }() + + _, err := io.Copy(stream, tp) + if err != nil { + return err + } + + if err := <-toStreamErrCh; err != nil { + return err + } + + return nil +} + +func (c *clientConn) reserveListener(port routing.Port) error { + c.dmsgListenersMx.Lock() + if _, ok := c.dmsgListeners[port]; ok { + c.dmsgListenersMx.Unlock() + return ErrPortAlreadyBound + } + c.dmsgListeners[port] = nil + c.dmsgListenersMx.Unlock() + return nil +} + +func (c *clientConn) addListener(port routing.Port, l *dmsg.Listener) error { + c.dmsgListenersMx.Lock() + if lis, ok := c.dmsgListeners[port]; ok && lis != nil { + c.dmsgListenersMx.Unlock() + return ErrPortAlreadyBound + } + c.dmsgListeners[port] = l + c.dmsgListenersMx.Unlock() + return nil +} + func (s *Server) handleDMSGListen(frame HSFrame) error { var local routing.Addr if err := frame.UnmarshalBody(&local); err != nil { From 67ab801013e739fa24dc3bbbec3479fa8649cdd4 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Mon, 9 Sep 2019 22:07:37 +0300 Subject: [PATCH 09/43] More work on app2 --- pkg/app2/client.go | 25 +++++++++++++++++-- pkg/app2/listener.go | 29 +++++++++++++++++++--- pkg/app2/server.go | 58 +++++++++++++++++++++++--------------------- 3 files changed, 78 insertions(+), 34 deletions(-) diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 4e2e5d754d..8b70f07fdc 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -146,11 +146,32 @@ func (c *Client) listen() error { } // TODO: handle field get gracefully - port := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) - if err := c.lm.addConn(port, stream); err != nil { + remotePort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen*2+HSFramePortLen:])) + if err := c.lm.addConn(remotePort, stream); err != nil { c.logger.WithError(err).Error("failed to accept") continue } + + localPort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) + + var localPK cipher.PubKey + copy(localPK[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) + + respHSFrame := NewHSFrameDMSGAccept(c.pid, routing.Loop{ + Local: routing.Addr{ + PubKey: c.PK, + Port: remotePort, + }, + Remote: routing.Addr{ + PubKey: localPK, + Port: localPort, + }, + }) + + if _, err := stream.Write(respHSFrame); err != nil { + c.logger.WithError(err).Error("error responding with DmsgAccept") + continue + } } } diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index 1a5c14b6ae..98f997411e 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -17,21 +17,33 @@ var ( ErrListenerClosed = errors.New("listener closed") ) +type acceptedConn struct { + remote routing.Addr + net.Conn +} + +func (c *acceptedConn) Addr() net.Addr { + return c.remote +} + type Listener struct { addr routing.Addr - conns chan net.Conn + conns chan *acceptedConn stopListening func(port routing.Port) error logger *logging.Logger lm *listenersManager + procID ProcID } -func NewListener(addr routing.Addr, lm *listenersManager, stopListening func(port routing.Port) error, l *logging.Logger) *Listener { +func NewListener(addr routing.Addr, lm *listenersManager, procID ProcID, + stopListening func(port routing.Port) error, l *logging.Logger) *Listener { return &Listener{ addr: addr, - conns: make(chan net.Conn, listenerBufSize), + conns: make(chan *acceptedConn, listenerBufSize), lm: lm, stopListening: stopListening, logger: l, + procID: procID, } } @@ -41,6 +53,15 @@ func (l *Listener) Accept() (net.Conn, error) { return nil, ErrListenerClosed } + hsFrame := NewHSFrameDMSGAccept(l.procID, routing.Loop{ + Local: l.addr, + Remote: conn.remote, + }) + + if _, err := conn.Write(hsFrame); err != nil { + return nil, err + } + return conn, nil } @@ -62,6 +83,6 @@ func (l *Listener) Addr() net.Addr { return l.addr } -func (l *Listener) addConn(conn net.Conn) { +func (l *Listener) addConn(conn *acceptedConn) { l.conns <- conn } diff --git a/pkg/app2/server.go b/pkg/app2/server.go index 1c24fd22c4..6b46889ed6 100644 --- a/pkg/app2/server.go +++ b/pkg/app2/server.go @@ -22,6 +22,7 @@ import ( ) type clientConn struct { + procID ProcID conn net.Conn session *yamux.Session lm *listenersManager @@ -122,6 +123,9 @@ func (s *Server) serveStream(stream net.Conn, conn *clientConn) error { return errors.Wrap(err, "error reading HS frame") } + // TODO: ensure thread-safety + conn.procID = hsFrame.ProcID() + var respHSFrame HSFrame switch hsFrame.FrameType() { case HSFrameTypeDMSGListen: @@ -223,42 +227,40 @@ func (c *clientConn) addListener(port routing.Port, l *dmsg.Listener) error { return ErrPortAlreadyBound } c.dmsgListeners[port] = l + go c.acceptDMSG(l) c.dmsgListenersMx.Unlock() return nil } -func (s *Server) handleDMSGListen(frame HSFrame) error { - var local routing.Addr - if err := frame.UnmarshalBody(&local); err != nil { - return errors.Wrap(err, "invalid JSON body") - } +func (c *clientConn) acceptDMSG(l *dmsg.Listener) error { + for { + stream, err := c.session.Open() + if err != nil { + return errors.Wrap(err, "error opening yamux stream") + } - // TODO: check `local` for validity + remoteAddr, ok := l.Addr().(dmsg.Addr) + if !ok { + // shouldn't happen, but still + return errors.Wrap(err, "wrong type for DMSG addr") + } - dmsgL, err := s.dmsgC.Listen(uint16(local.Port)) - if err != nil { - return fmt.Errorf("error listening on port %d: %v", local.Port, err) - } + hsFrame := NewHSFrameDSMGDial(c.procID, routing.Loop{ + Local: routing.Addr{ + PubKey: remoteAddr.PK, + Port: routing.Port(remoteAddr.Port), + }, + // TODO: get local addr + Remote: routing.Addr{ + PubKey: + }, + }) -} - -func (s *Server) handleDMSGListening(frame HSFrame) error { - var local routing.Addr - if err := frame.UnmarshalBody(&local); err != nil { - return errors.Wrap(err, "invalid JSON body") - } -} + conn, err := l.Accept() + if err != nil { + return errors.Wrap(err, "error accepting DMSG conn") + } -func (s *Server) handleDMSGDial(frame HSFrame) error { - var loop routing.Loop - if err := frame.UnmarshalBody(&loop); err != nil { - return errors.Wrap(err, "invalid JSON body") - } -} -func (s *Server) handleDMSGAccept(frame HSFrame) error { - var loop routing.Loop - if err := frame.UnmarshalBody(&loop); err != nil { - return errors.Wrap(err, "invalid JSON body") } } From a07d8a017cd6d6cc1593b482f1f6dd68e265c4e3 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 10 Sep 2019 18:48:52 +0300 Subject: [PATCH 10/43] Move `listen` to lm, add `Close` to lm --- pkg/app2/listeners_manager.go | 112 +++++++++++++++++++++++++++++++--- 1 file changed, 104 insertions(+), 8 deletions(-) diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go index 251c8b2af0..a1c9925357 100644 --- a/pkg/app2/listeners_manager.go +++ b/pkg/app2/listeners_manager.go @@ -1,9 +1,14 @@ package app2 import ( + "encoding/binary" "net" "sync" + "sync/atomic" + "github.com/hashicorp/yamux" + + "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" "github.com/pkg/errors" @@ -14,20 +19,35 @@ var ( ErrPortAlreadyBound = errors.New("port is already bound") ErrNoListenerOnPort = errors.New("no listener on port") ErrListenersManagerAlreadyServing = errors.New("listeners manager already serving") + ErrWrongPID = errors.New("wrong ProcID specified in the HS frame") ) type listenersManager struct { - listeners map[routing.Port]*Listener - mx sync.RWMutex - isServing int32 + pid ProcID + pk cipher.PubKey + listeners map[routing.Port]*Listener + mx sync.RWMutex + isListening int32 + logger *logging.Logger + doneCh chan struct{} + doneWg sync.WaitGroup } -func newListenersManager() *listenersManager { +func newListenersManager(l *logging.Logger, pid ProcID, pk cipher.PubKey) *listenersManager { return &listenersManager{ + pid: pid, + pk: pk, listeners: make(map[routing.Port]*Listener), + logger: l, + doneCh: make(chan struct{}), } } +func (lm *listenersManager) close() { + close(lm.doneCh) + lm.doneWg.Wait() +} + func (lm *listenersManager) portIsBound(port routing.Port) bool { lm.mx.RLock() _, ok := lm.listeners[port] @@ -41,7 +61,7 @@ func (lm *listenersManager) add(addr routing.Addr, stopListening func(port routi lm.mx.Unlock() return nil, ErrPortAlreadyBound } - l := NewListener(addr, lm, stopListening, logger) + l := NewListener(addr, lm, lm.pid, stopListening, logger) lm.listeners[addr.Port] = l lm.mx.Unlock() return l, nil @@ -58,13 +78,89 @@ func (lm *listenersManager) remove(port routing.Port) error { return nil } -func (lm *listenersManager) addConn(port routing.Port, conn net.Conn) error { +func (lm *listenersManager) addConn(localPort routing.Port, remote routing.Addr, conn net.Conn) error { lm.mx.RLock() - if _, ok := lm.listeners[port]; !ok { + if _, ok := lm.listeners[localPort]; !ok { lm.mx.RUnlock() return ErrNoListenerOnPort } - lm.listeners[port].addConn(conn) + lm.listeners[localPort].addConn(&acceptedConn{ + remote: remote, + Conn: conn, + }) lm.mx.RUnlock() return nil } + +func (lm *listenersManager) listen(session *yamux.Session) { + // this one should only start once + if !atomic.CompareAndSwapInt32(&lm.isListening, 0, 1) { + return + } + + lm.doneWg.Add(1) + + go func() { + defer lm.doneWg.Done() + + for { + select { + case <-lm.doneCh: + return + default: + stream, err := session.Accept() + if err != nil { + lm.logger.WithError(err).Error("error accepting stream") + return + } + + hsFrame, err := readHSFrame(stream) + if err != nil { + lm.logger.WithError(err).Error("error reading HS frame") + continue + } + + if hsFrame.ProcID() != lm.pid { + lm.logger.WithError(ErrWrongPID).Error("error listening for Dial") + } + + if hsFrame.FrameType() != HSFrameTypeDMSGDial { + lm.logger.WithError(ErrWrongHSFrameTypeReceived).Error("error listening for Dial") + continue + } + + // TODO: handle field get gracefully + remotePort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen*2+HSFramePortLen:])) + localPort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) + + var localPK cipher.PubKey + copy(localPK[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) + + err = lm.addConn(remotePort, routing.Addr{ + PubKey: localPK, + Port: localPort, + }, stream) + if err != nil { + lm.logger.WithError(err).Error("failed to accept") + continue + } + + respHSFrame := NewHSFrameDMSGAccept(hsFrame.ProcID(), routing.Loop{ + Local: routing.Addr{ + PubKey: lm.pk, + Port: remotePort, + }, + Remote: routing.Addr{ + PubKey: localPK, + Port: localPort, + }, + }) + + if _, err := stream.Write(respHSFrame); err != nil { + lm.logger.WithError(err).Error("error responding with DmsgAccept") + continue + } + } + } + }() +} From 8dc151370f9930f79502625ce54bb7eb6e414043 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 10 Sep 2019 18:57:26 +0300 Subject: [PATCH 11/43] `acceptedConn` -> `clientConn`, `clientConn` -> `serverConn` --- pkg/app2/client_conn.go | 23 +++++++++++++++++++++++ pkg/app2/listener.go | 16 +++------------- pkg/app2/listeners_manager.go | 2 +- pkg/app2/server.go | 18 +++++++++--------- 4 files changed, 36 insertions(+), 23 deletions(-) create mode 100644 pkg/app2/client_conn.go diff --git a/pkg/app2/client_conn.go b/pkg/app2/client_conn.go new file mode 100644 index 0000000000..5617d30709 --- /dev/null +++ b/pkg/app2/client_conn.go @@ -0,0 +1,23 @@ +package app2 + +import ( + "net" + + "github.com/skycoin/skywire/pkg/routing" +) + +// clientConn serves as a wrapper for `net.Conn` being returned to the +// app client side from `Accept` func +type clientConn struct { + remote routing.Addr + local routing.Addr + net.Conn +} + +func (c *clientConn) RemoteAddr() net.Addr { + return c.remote +} + +func (c *clientConn) LocalAddr() net.Addr { + return c.local +} diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index 98f997411e..e062138c18 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -5,7 +5,6 @@ import ( "net" "github.com/skycoin/skycoin/src/util/logging" - "github.com/skycoin/skywire/pkg/routing" ) @@ -17,18 +16,9 @@ var ( ErrListenerClosed = errors.New("listener closed") ) -type acceptedConn struct { - remote routing.Addr - net.Conn -} - -func (c *acceptedConn) Addr() net.Addr { - return c.remote -} - type Listener struct { addr routing.Addr - conns chan *acceptedConn + conns chan *clientConn stopListening func(port routing.Port) error logger *logging.Logger lm *listenersManager @@ -39,7 +29,7 @@ func NewListener(addr routing.Addr, lm *listenersManager, procID ProcID, stopListening func(port routing.Port) error, l *logging.Logger) *Listener { return &Listener{ addr: addr, - conns: make(chan *acceptedConn, listenerBufSize), + conns: make(chan *clientConn, listenerBufSize), lm: lm, stopListening: stopListening, logger: l, @@ -83,6 +73,6 @@ func (l *Listener) Addr() net.Addr { return l.addr } -func (l *Listener) addConn(conn *acceptedConn) { +func (l *Listener) addConn(conn *clientConn) { l.conns <- conn } diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go index a1c9925357..1c1b7a7daa 100644 --- a/pkg/app2/listeners_manager.go +++ b/pkg/app2/listeners_manager.go @@ -84,7 +84,7 @@ func (lm *listenersManager) addConn(localPort routing.Port, remote routing.Addr, lm.mx.RUnlock() return ErrNoListenerOnPort } - lm.listeners[localPort].addConn(&acceptedConn{ + lm.listeners[localPort].addConn(&clientConn{ remote: remote, Conn: conn, }) diff --git a/pkg/app2/server.go b/pkg/app2/server.go index 6b46889ed6..3953266c09 100644 --- a/pkg/app2/server.go +++ b/pkg/app2/server.go @@ -21,7 +21,7 @@ import ( "github.com/skycoin/dmsg/cipher" ) -type clientConn struct { +type serverConn struct { procID ProcID conn net.Conn session *yamux.Session @@ -34,7 +34,7 @@ type clientConn struct { type Server struct { PK cipher.PubKey dmsgC *dmsg.Client - apps map[string]*clientConn + apps map[string]*serverConn appsMx sync.RWMutex logger *logging.Logger } @@ -43,7 +43,7 @@ func NewServer(localPK cipher.PubKey, dmsgC *dmsg.Client, l *logging.Logger) *Se return &Server{ PK: localPK, dmsgC: dmsgC, - apps: make(map[string]*clientConn), + apps: make(map[string]*serverConn), logger: l, } } @@ -70,7 +70,7 @@ func (s *Server) Serve(sockAddr string) error { return errors.Wrap(err, "error creating yamux session") } - s.apps[conn.RemoteAddr().String()] = &clientConn{ + s.apps[conn.RemoteAddr().String()] = &serverConn{ session: session, conn: conn, lm: newListenersManager(), @@ -83,7 +83,7 @@ func (s *Server) Serve(sockAddr string) error { } } -func (s *Server) serveClient(conn *clientConn) error { +func (s *Server) serveClient(conn *serverConn) error { for { stream, err := conn.session.Accept() if err != nil { @@ -116,7 +116,7 @@ func (s *Server) serveClient(conn *clientConn) error { } } -func (s *Server) serveStream(stream net.Conn, conn *clientConn) error { +func (s *Server) serveStream(stream net.Conn, conn *serverConn) error { for { hsFrame, err := readHSFrame(stream) if err != nil { @@ -209,7 +209,7 @@ func (s *Server) forwardOverDMSG(stream net.Conn, tp *dmsg.Transport) error { return nil } -func (c *clientConn) reserveListener(port routing.Port) error { +func (c *serverConn) reserveListener(port routing.Port) error { c.dmsgListenersMx.Lock() if _, ok := c.dmsgListeners[port]; ok { c.dmsgListenersMx.Unlock() @@ -220,7 +220,7 @@ func (c *clientConn) reserveListener(port routing.Port) error { return nil } -func (c *clientConn) addListener(port routing.Port, l *dmsg.Listener) error { +func (c *serverConn) addListener(port routing.Port, l *dmsg.Listener) error { c.dmsgListenersMx.Lock() if lis, ok := c.dmsgListeners[port]; ok && lis != nil { c.dmsgListenersMx.Unlock() @@ -232,7 +232,7 @@ func (c *clientConn) addListener(port routing.Port, l *dmsg.Listener) error { return nil } -func (c *clientConn) acceptDMSG(l *dmsg.Listener) error { +func (c *serverConn) acceptDMSG(l *dmsg.Listener) error { for { stream, err := c.session.Open() if err != nil { From a282c5a1f3480f77ff6a4e24299ec07e5767d6fc Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 10 Sep 2019 22:22:27 +0300 Subject: [PATCH 12/43] More work on app2 --- pkg/app2/client.go | 49 +++++++++--------------------- pkg/app2/hsframe.go | 56 +++++++++++++++++++++++++++++++++++ pkg/app2/listener.go | 16 +++++----- pkg/app2/listeners_manager.go | 44 +++++++++++++++------------ 4 files changed, 103 insertions(+), 62 deletions(-) diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 8b70f07fdc..3f9745e081 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -3,7 +3,6 @@ package app2 import ( "encoding/binary" "net" - "sync/atomic" "github.com/hashicorp/yamux" @@ -46,7 +45,7 @@ func NewClient(localPK cipher.PubKey, pid ProcID, sockAddr string, l *logging.Lo return nil, errors.Wrap(err, "error opening yamux session") } - lm := newListenersManager() + lm := newListenersManager(l, pid, localPK) return &Client{ PK: localPK, @@ -64,32 +63,22 @@ func (c *Client) Dial(addr routing.Addr) (net.Conn, error) { return nil, errors.Wrap(err, "error opening stream") } - hsFrame := NewHSFrameDSMGDial(c.pid, routing.Loop{ + err = dialHS(stream, c.pid, routing.Loop{ Local: routing.Addr{ PubKey: c.PK, }, Remote: addr, }) - - if _, err := stream.Write(hsFrame); err != nil { - return nil, errors.Wrap(err, "error writing HS frame") - } - - hsFrame, err = readHSFrame(stream) if err != nil { - return nil, errors.Wrap(err, "error reading HS frame") - } - - if hsFrame.FrameType() != HSFrameTypeDMSGAccept { - return nil, ErrWrongHSFrameTypeReceived + return nil, errors.Wrap(err, "error performing Dial HS") } return stream, nil } -func (c *Client) Listen(port routing.Port) (*Listener, error) { - if c.lm.portIsBound(port) { - return nil, ErrPortAlreadyBound +func (c *Client) Listen(port routing.Port) (net.Listener, error) { + if err := c.lm.reserveListener(port); err != nil { + return nil, errors.Wrap(err, "error reserving listener") } stream, err := c.session.Open() @@ -97,34 +86,24 @@ func (c *Client) Listen(port routing.Port) (*Listener, error) { return nil, errors.Wrap(err, "error opening stream") } - addr := routing.Addr{ + local := routing.Addr{ PubKey: c.PK, Port: port, } - hsFrame := NewHSFrameDMSGListen(c.pid, addr) - if _, err := stream.Write(hsFrame); err != nil { - return nil, errors.Wrap(err, "error writing HS frame") - } - - hsFrame, err = readHSFrame(stream) + err = listenHS(stream, c.pid, local) if err != nil { - return nil, errors.Wrap(err, "error reading HS frame") + return nil, errors.Wrap(err, "error performing Listen HS") } - if hsFrame.FrameType() != HSFrameTypeDMSGListening { - return nil, ErrWrongHSFrameTypeReceived - } + c.lm.listen(c.session) - if atomic.CompareAndSwapInt32(&c.isListening, 0, 1) { - go func() { - if err := c.listen(); err != nil { - c.logger.WithError(err).Error("error listening") - } - }() + l := newListener(local, c.lm, c.pid, c.stopListening, c.logger) + if err := c.lm.set(port, l); err != nil { + return nil, errors.Wrap(err, "error setting listener") } - return c.lm.add(addr, c.stopListening, c.logger) + return l, nil } func (c *Client) listen() error { diff --git a/pkg/app2/hsframe.go b/pkg/app2/hsframe.go index 975589edd6..d2ffc16544 100644 --- a/pkg/app2/hsframe.go +++ b/pkg/app2/hsframe.go @@ -3,6 +3,7 @@ package app2 import ( "encoding/binary" "io" + "net" "github.com/pkg/errors" "github.com/skycoin/skywire/pkg/routing" @@ -137,6 +138,15 @@ func readHSFrame(r io.Reader) (HSFrame, error) { return hsFrame, nil } +func writeHSFrame(w io.Writer, hsFrame HSFrame) (int, error) { + n, err := w.Write(hsFrame) + if err != nil { + return 0, errors.Wrap(err, "error writing HS frame") + } + + return n, nil +} + func readHSFrameBody(hsFrame HSFrame, r io.Reader) (HSFrame, error) { switch hsFrame.FrameType() { case HSFrameTypeDMSGListen, HSFrameTypeDMSGListening: @@ -148,3 +158,49 @@ func readHSFrameBody(hsFrame HSFrame, r io.Reader) (HSFrame, error) { _, err := io.ReadFull(r, hsFrame[HSFrameHeaderLen:]) return hsFrame, err } + +func dialHS(conn net.Conn, pid ProcID, loop routing.Loop) error { + hsFrame := NewHSFrameDSMGDial(pid, loop) + + if _, err := writeHSFrame(conn, hsFrame); err != nil { + return err + } + + hsFrame, err := readHSFrame(conn) + if err != nil { + return err + } + + if hsFrame.FrameType() != HSFrameTypeDMSGAccept { + return ErrWrongHSFrameTypeReceived + } + + if hsFrame.ProcID() != pid { + return ErrWrongPID + } + + return nil +} + +func listenHS(conn net.Conn, pid ProcID, local routing.Addr) error { + hsFrame := NewHSFrameDMSGListen(pid, local) + + if _, err := writeHSFrame(conn, hsFrame); err != nil { + return err + } + + hsFrame, err := readHSFrame(conn) + if err != nil { + return err + } + + if hsFrame.FrameType() != HSFrameTypeDMSGListening { + return ErrWrongHSFrameTypeReceived + } + + if hsFrame.ProcID() != pid { + return ErrWrongPID + } + + return nil +} diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index e062138c18..33c7c2e296 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -16,7 +16,7 @@ var ( ErrListenerClosed = errors.New("listener closed") ) -type Listener struct { +type listener struct { addr routing.Addr conns chan *clientConn stopListening func(port routing.Port) error @@ -25,9 +25,9 @@ type Listener struct { procID ProcID } -func NewListener(addr routing.Addr, lm *listenersManager, procID ProcID, - stopListening func(port routing.Port) error, l *logging.Logger) *Listener { - return &Listener{ +func newListener(addr routing.Addr, lm *listenersManager, procID ProcID, + stopListening func(port routing.Port) error, l *logging.Logger) *listener { + return &listener{ addr: addr, conns: make(chan *clientConn, listenerBufSize), lm: lm, @@ -37,7 +37,7 @@ func NewListener(addr routing.Addr, lm *listenersManager, procID ProcID, } } -func (l *Listener) Accept() (net.Conn, error) { +func (l *listener) Accept() (net.Conn, error) { conn, ok := <-l.conns if !ok { return nil, ErrListenerClosed @@ -55,7 +55,7 @@ func (l *Listener) Accept() (net.Conn, error) { return conn, nil } -func (l *Listener) Close() error { +func (l *listener) Close() error { if err := l.stopListening(l.addr.Port); err != nil { l.logger.WithError(err).Error("error sending DmsgStopListening") } @@ -69,10 +69,10 @@ func (l *Listener) Close() error { return nil } -func (l *Listener) Addr() net.Addr { +func (l *listener) Addr() net.Addr { return l.addr } -func (l *Listener) addConn(conn *clientConn) { +func (l *listener) addConn(conn *clientConn) { l.conns <- conn } diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go index 1c1b7a7daa..5619a439b9 100644 --- a/pkg/app2/listeners_manager.go +++ b/pkg/app2/listeners_manager.go @@ -7,25 +7,24 @@ import ( "sync/atomic" "github.com/hashicorp/yamux" - + "github.com/pkg/errors" "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" - "github.com/pkg/errors" "github.com/skycoin/skywire/pkg/routing" ) var ( - ErrPortAlreadyBound = errors.New("port is already bound") - ErrNoListenerOnPort = errors.New("no listener on port") - ErrListenersManagerAlreadyServing = errors.New("listeners manager already serving") - ErrWrongPID = errors.New("wrong ProcID specified in the HS frame") + ErrPortAlreadyBound = errors.New("port is already bound") + ErrNoListenerOnPort = errors.New("no listener on port") + ErrWrongPID = errors.New("wrong ProcID specified in the HS frame") ) +// listenersManager contains and manages all the instantiated listeners type listenersManager struct { pid ProcID pk cipher.PubKey - listeners map[routing.Port]*Listener + listeners map[routing.Port]*listener mx sync.RWMutex isListening int32 logger *logging.Logger @@ -37,7 +36,7 @@ func newListenersManager(l *logging.Logger, pid ProcID, pk cipher.PubKey) *liste return &listenersManager{ pid: pid, pk: pk, - listeners: make(map[routing.Port]*Listener), + listeners: make(map[routing.Port]*listener), logger: l, doneCh: make(chan struct{}), } @@ -48,23 +47,26 @@ func (lm *listenersManager) close() { lm.doneWg.Wait() } -func (lm *listenersManager) portIsBound(port routing.Port) bool { - lm.mx.RLock() - _, ok := lm.listeners[port] - lm.mx.RUnlock() - return ok +func (lm *listenersManager) set(port routing.Port, l *listener) error { + lm.mx.Lock() + if v, ok := lm.listeners[port]; !ok || v != nil { + lm.mx.Unlock() + return ErrPortAlreadyBound + } + lm.listeners[port] = l + lm.mx.Unlock() + return nil } -func (lm *listenersManager) add(addr routing.Addr, stopListening func(port routing.Port) error, logger *logging.Logger) (*Listener, error) { +func (lm *listenersManager) reserveListener(port routing.Port) error { lm.mx.Lock() - if _, ok := lm.listeners[addr.Port]; ok { + if _, ok := lm.listeners[port]; ok { lm.mx.Unlock() - return nil, ErrPortAlreadyBound + return ErrPortAlreadyBound } - l := NewListener(addr, lm, lm.pid, stopListening, logger) - lm.listeners[addr.Port] = l + lm.listeners[port] = nil lm.mx.Unlock() - return l, nil + return nil } func (lm *listenersManager) remove(port routing.Port) error { @@ -78,6 +80,7 @@ func (lm *listenersManager) remove(port routing.Port) error { return nil } +// addConn passes connection to the corresponding listener func (lm *listenersManager) addConn(localPort routing.Port, remote routing.Addr, conn net.Conn) error { lm.mx.RLock() if _, ok := lm.listeners[localPort]; !ok { @@ -92,6 +95,9 @@ func (lm *listenersManager) addConn(localPort routing.Port, remote routing.Addr, return nil } +// listen accepts all new yamux streams from the server. We want to accept only +// `DmsgDial` frames here, thus all the other frames get rejected. `DmsgDial` frames +// are being distributed between the corresponding listeners with regards to their port func (lm *listenersManager) listen(session *yamux.Session) { // this one should only start once if !atomic.CompareAndSwapInt32(&lm.isListening, 0, 1) { From e2afc9eae28def55cfe7d2d0eef915034d5c9b59 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Thu, 12 Sep 2019 14:11:22 +0300 Subject: [PATCH 13/43] Add RPC server for server --- pkg/app2/server_rpc.go | 191 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 pkg/app2/server_rpc.go diff --git a/pkg/app2/server_rpc.go b/pkg/app2/server_rpc.go new file mode 100644 index 0000000000..ba73f4fe88 --- /dev/null +++ b/pkg/app2/server_rpc.go @@ -0,0 +1,191 @@ +package app2 + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/skycoin/dmsg" + + "github.com/skycoin/skywire/pkg/routing" +) + +type ServerRPC struct { + dmsgC *dmsg.Client + conns map[uint16]net.Conn + connsMx sync.RWMutex + lstConnID uint16 + listeners map[uint16]*dmsg.Listener + listenersMx sync.RWMutex + lstLisID uint16 +} + +func (r *ServerRPC) nextConnID() (*uint16, error) { + r.connsMx.Lock() + + connID := r.lstConnID + 1 + for ; connID < r.lstConnID; connID++ { + if _, ok := r.conns[connID]; !ok { + break + } + } + + if connID == r.lstConnID { + r.connsMx.Unlock() + return nil, errors.New("no more available conns") + } + + r.conns[connID] = nil + r.lstConnID = connID + + r.connsMx.Unlock() + return &connID, nil +} + +func (r *ServerRPC) nextLisID() (*uint16, error) { + r.listenersMx.Lock() + + lisID := r.lstLisID + 1 + for ; lisID < r.lstLisID; lisID++ { + if _, ok := r.listeners[lisID]; !ok { + break + } + } + + if lisID == r.lstLisID { + r.listenersMx.Unlock() + return nil, errors.New("no more available listeners") + } + + r.listeners[lisID] = nil + r.lstLisID = lisID + + r.listenersMx.Unlock() + return &lisID, nil +} + +func (r *ServerRPC) setConn(connID uint16, conn net.Conn) error { + r.connsMx.Lock() + + if c, ok := r.conns[connID]; ok && c != nil { + r.connsMx.Unlock() + return errors.New("conn already exists") + } + + r.conns[connID] = conn + + r.connsMx.Unlock() + return nil +} + +func (r *ServerRPC) setListener(lisID uint16, lis *dmsg.Listener) error { + r.listenersMx.Lock() + + if l, ok := r.listeners[lisID]; ok && l != nil { + r.listenersMx.Unlock() + return errors.New("listener already exists") + } + + r.listeners[lisID] = lis + + r.listenersMx.Unlock() + return nil +} + +func (r *ServerRPC) getConn(connID uint16) (net.Conn, bool) { + r.connsMx.RLock() + conn, ok := r.conns[connID] + r.connsMx.RUnlock() + return conn, ok +} + +func (r *ServerRPC) getListener(lisID uint16) (*dmsg.Listener, bool) { + r.listenersMx.RLock() + lis, ok := r.listeners[lisID] + r.listenersMx.RUnlock() + return lis, ok +} + +type DialReq struct { + Remote routing.Addr +} + +func (r *ServerRPC) Dial(req *DialReq, connID *uint16) error { + connID, err := r.nextConnID() + if err != nil { + return err + } + + tp, err := r.dmsgC.Dial(context.TODO(), req.Remote.PubKey, uint16(req.Remote.Port)) + if err != nil { + return err + } + + if err := r.setConn(*connID, tp); err != nil { + return err + } + + return nil +} + +type ListenReq struct { + Local routing.Addr +} + +func (r *ServerRPC) Listen(req *ListenReq, lisID *uint16) error { + lisID, err := r.nextLisID() + if err != nil { + return err + } + + dmsgL, err := r.dmsgC.Listen(uint16(req.Local.Port)) + if err != nil { + return err + } + + if err := r.setListener(*lisID, dmsgL); err != nil { + // TODO: close listener + return err + } + + return nil +} + +func (r *ServerRPC) Accept(lisID *uint16, connID *uint16) error { + lis, ok := r.getListener(*lisID) + if !ok { + return fmt.Errorf("not listener with id %d", *lisID) + } + + connID, err := r.nextConnID() + if err != nil { + return err + } + + tp, err := lis.Accept() + if err != nil { + return err + } + + if err := r.setConn(*connID, tp); err != nil { + // TODO: close conn + return err + } + + return nil +} + +type WriteReq struct { + ConnID uint16 + B []byte +} + +func (r *ServerRPC) Write(req *WriteReq, n *int) error { + conn, ok := r.getConn(req.ConnID) + if !ok { + return fmt.Errorf("not conn with id %d", req.ConnID) + } + +} From f374c397618bb891150c38231d021158063df601 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sun, 15 Sep 2019 22:12:45 +0300 Subject: [PATCH 14/43] Finish RPC communication --- pkg/app2/client.go | 168 ++++----------------- pkg/app2/client_conn.go | 23 --- pkg/app2/conn.go | 34 +++++ pkg/app2/conns_manager.go | 83 +++++++++++ pkg/app2/hsframe.go | 206 -------------------------- pkg/app2/hsframe_test.go | 69 --------- pkg/app2/listener.go | 74 +++------- pkg/app2/listeners_manager.go | 196 +++++++------------------ pkg/app2/server.go | 266 ---------------------------------- pkg/app2/server_rpc.go | 180 +++++++++-------------- pkg/app2/server_rpc_client.go | 101 +++++++++++++ 11 files changed, 385 insertions(+), 1015 deletions(-) delete mode 100644 pkg/app2/client_conn.go create mode 100644 pkg/app2/conn.go create mode 100644 pkg/app2/conns_manager.go delete mode 100644 pkg/app2/hsframe.go delete mode 100644 pkg/app2/hsframe_test.go delete mode 100644 pkg/app2/server.go create mode 100644 pkg/app2/server_rpc_client.go diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 3f9745e081..c78d1ab7b0 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -1,178 +1,70 @@ package app2 import ( - "encoding/binary" - "net" + "net/rpc" - "github.com/hashicorp/yamux" - - "github.com/pkg/errors" - - "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/routing" "github.com/skycoin/dmsg/cipher" -) - -var ( - ErrWrongHSFrameTypeReceived = errors.New("received wrong HS frame type") + "github.com/skycoin/skycoin/src/util/logging" ) // Client is used by skywire apps. type Client struct { - PK cipher.PubKey - pid ProcID - sockAddr string - conn net.Conn - session *yamux.Session - logger *logging.Logger - lm *listenersManager - isListening int32 + PK cipher.PubKey + pid ProcID + rpc ServerRPCClient + logger *logging.Logger } // NewClient creates a new Client. The Client needs to be provided with: // - localPK: The local public key of the parent skywire visor. // - pid: The procID assigned for the process that Client is being used by. // - sockAddr: The socket address to connect to Server. -func NewClient(localPK cipher.PubKey, pid ProcID, sockAddr string, l *logging.Logger) (*Client, error) { - conn, err := net.Dial("unix", sockAddr) - if err != nil { - return nil, errors.Wrap(err, "error connecting app server") - } - - session, err := yamux.Client(conn, nil) - if err != nil { - return nil, errors.Wrap(err, "error opening yamux session") - } - - lm := newListenersManager(l, pid, localPK) - +func NewClient(localPK cipher.PubKey, pid ProcID, rpc *rpc.Client, l *logging.Logger) *Client { return &Client{ - PK: localPK, - pid: pid, - sockAddr: sockAddr, - conn: conn, - session: session, - lm: lm, - }, nil + PK: localPK, + pid: pid, + rpc: newServerRPCClient(rpc), + logger: l, + } } -func (c *Client) Dial(addr routing.Addr) (net.Conn, error) { - stream, err := c.session.Open() +func (c *Client) Dial(remote routing.Addr) (*Conn, error) { + connID, err := c.rpc.Dial(remote) if err != nil { - return nil, errors.Wrap(err, "error opening stream") + return nil, err } - err = dialHS(stream, c.pid, routing.Loop{ - Local: routing.Addr{ + conn := &Conn{ + id: connID, + rpc: c.rpc, + // TODO: port? + local: routing.Addr{ PubKey: c.PK, }, - Remote: addr, - }) - if err != nil { - return nil, errors.Wrap(err, "error performing Dial HS") + remote: remote, } - return stream, nil + return conn, nil } -func (c *Client) Listen(port routing.Port) (net.Listener, error) { - if err := c.lm.reserveListener(port); err != nil { - return nil, errors.Wrap(err, "error reserving listener") - } - - stream, err := c.session.Open() - if err != nil { - return nil, errors.Wrap(err, "error opening stream") - } - +func (c *Client) Listen(port routing.Port) (*Listener, error) { local := routing.Addr{ PubKey: c.PK, Port: port, } - err = listenHS(stream, c.pid, local) - if err != nil { - return nil, errors.Wrap(err, "error performing Listen HS") - } - - c.lm.listen(c.session) - - l := newListener(local, c.lm, c.pid, c.stopListening, c.logger) - if err := c.lm.set(port, l); err != nil { - return nil, errors.Wrap(err, "error setting listener") - } - - return l, nil -} - -func (c *Client) listen() error { - for { - stream, err := c.session.Accept() - if err != nil { - return errors.Wrap(err, "error accepting stream") - } - - hsFrame, err := readHSFrame(stream) - if err != nil { - c.logger.WithError(err).Error("error reading HS frame") - continue - } - - if hsFrame.FrameType() != HSFrameTypeDMSGDial { - c.logger.WithError(ErrWrongHSFrameTypeReceived).Error("on listening for Dial") - continue - } - - // TODO: handle field get gracefully - remotePort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen*2+HSFramePortLen:])) - if err := c.lm.addConn(remotePort, stream); err != nil { - c.logger.WithError(err).Error("failed to accept") - continue - } - - localPort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) - - var localPK cipher.PubKey - copy(localPK[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) - - respHSFrame := NewHSFrameDMSGAccept(c.pid, routing.Loop{ - Local: routing.Addr{ - PubKey: c.PK, - Port: remotePort, - }, - Remote: routing.Addr{ - PubKey: localPK, - Port: localPort, - }, - }) - - if _, err := stream.Write(respHSFrame); err != nil { - c.logger.WithError(err).Error("error responding with DmsgAccept") - continue - } - } -} - -func (c *Client) stopListening(port routing.Port) error { - stream, err := c.session.Open() + lisID, err := c.rpc.Listen(local) if err != nil { - return errors.Wrap(err, "error opening stream") - } - - addr := routing.Addr{ - PubKey: c.PK, - Port: port, - } - - hsFrame := NewHSFrameDMSGStopListening(c.pid, addr) - if _, err := stream.Write(hsFrame); err != nil { - return errors.Wrap(err, "error writing HS frame") + return nil, err } - if err := stream.Close(); err != nil { - return errors.Wrap(err, "error closing stream") + listener := &Listener{ + id: lisID, + rpc: c.rpc, + addr: local, } - return nil + return listener, nil } diff --git a/pkg/app2/client_conn.go b/pkg/app2/client_conn.go deleted file mode 100644 index 5617d30709..0000000000 --- a/pkg/app2/client_conn.go +++ /dev/null @@ -1,23 +0,0 @@ -package app2 - -import ( - "net" - - "github.com/skycoin/skywire/pkg/routing" -) - -// clientConn serves as a wrapper for `net.Conn` being returned to the -// app client side from `Accept` func -type clientConn struct { - remote routing.Addr - local routing.Addr - net.Conn -} - -func (c *clientConn) RemoteAddr() net.Addr { - return c.remote -} - -func (c *clientConn) LocalAddr() net.Addr { - return c.local -} diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go new file mode 100644 index 0000000000..c5a7e2ed91 --- /dev/null +++ b/pkg/app2/conn.go @@ -0,0 +1,34 @@ +package app2 + +import ( + "net" + + "github.com/skycoin/skywire/pkg/routing" +) + +type Conn struct { + id uint16 + rpc ConnRPCClient + local routing.Addr + remote routing.Addr +} + +func (c *Conn) Read(b []byte) (int, error) { + return c.rpc.Read(c.id, b) +} + +func (c *Conn) Write(b []byte) (int, error) { + return c.rpc.Write(c.id, b) +} + +func (c *Conn) Close() error { + return c.rpc.CloseConn(c.id) +} + +func (c *Conn) LocalAddr() net.Addr { + return c.local +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.remote +} diff --git a/pkg/app2/conns_manager.go b/pkg/app2/conns_manager.go new file mode 100644 index 0000000000..d09b61c7cd --- /dev/null +++ b/pkg/app2/conns_manager.go @@ -0,0 +1,83 @@ +package app2 + +import ( + "fmt" + "net" + "sync" + + "github.com/pkg/errors" +) + +type connsManager struct { + conns map[uint16]net.Conn + mx sync.RWMutex + lstID uint16 +} + +func newConnsManager() *connsManager { + return &connsManager{ + conns: make(map[uint16]net.Conn), + } +} + +func (m *connsManager) nextID() (*uint16, error) { + m.mx.Lock() + + connID := m.lstID + 1 + for ; connID < m.lstID; connID++ { + if _, ok := m.conns[connID]; !ok { + break + } + } + + if connID == m.lstID { + m.mx.Unlock() + return nil, errors.New("no more available conns") + } + + m.conns[connID] = nil + m.lstID = connID + + m.mx.Unlock() + return &connID, nil +} + +func (m *connsManager) getAndRemove(connID uint16) (net.Conn, error) { + m.mx.Lock() + conn, ok := m.conns[connID] + if !ok { + m.mx.Unlock() + return nil, fmt.Errorf("no conn with id %d", connID) + } + + if conn == nil { + m.mx.Unlock() + return nil, fmt.Errorf("conn with id %d is not set", connID) + } + + delete(m.conns, connID) + + m.mx.Unlock() + return conn, nil +} + +func (m *connsManager) set(connID uint16, conn net.Conn) error { + m.mx.Lock() + + if c, ok := m.conns[connID]; ok && c != nil { + m.mx.Unlock() + return errors.New("conn already exists") + } + + m.conns[connID] = conn + + m.mx.Unlock() + return nil +} + +func (m *connsManager) get(connID uint16) (net.Conn, bool) { + m.mx.RLock() + conn, ok := m.conns[connID] + m.mx.RUnlock() + return conn, ok +} diff --git a/pkg/app2/hsframe.go b/pkg/app2/hsframe.go deleted file mode 100644 index d2ffc16544..0000000000 --- a/pkg/app2/hsframe.go +++ /dev/null @@ -1,206 +0,0 @@ -package app2 - -import ( - "encoding/binary" - "io" - "net" - - "github.com/pkg/errors" - "github.com/skycoin/skywire/pkg/routing" -) - -const ( - HSFrameHeaderLen = 3 - HSFrameProcIDLen = 2 - HSFrameTypeLen = 1 - HSFramePKLen = 33 - HSFramePortLen = 2 -) - -// HSFrameType identifies the type of a handshake frame. -type HSFrameType byte - -const ( - HSFrameTypeDMSGListen HSFrameType = 10 + iota - HSFrameTypeDMSGListening - HSFrameTypeDMSGDial - HSFrameTypeDMSGAccept - HSFrameTypeStopListening - HSFrameTypeError -) - -// HSFrame is the data unit for socket connection handshakes between Server and Client. -// It consists of header and body. -// -// Header is a big-endian encoded 3 bytes and is constructed as follows: -// | ProcID (2 bytes) | HSFrameType (1 byte) | -type HSFrame []byte - -func newHSFrame(procID ProcID, frameType HSFrameType, bodyLen int) HSFrame { - hsFrame := make(HSFrame, HSFrameHeaderLen+bodyLen) - - hsFrame.SetProcID(procID) - hsFrame.SetFrameType(frameType) - - return hsFrame -} - -func NewHSFrameDMSGListen(procID ProcID, local routing.Addr) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeDMSGListen, HSFramePKLen+HSFramePortLen) - - copy(hsFrame[HSFrameHeaderLen:], local.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(local.Port)) - - return hsFrame -} - -func NewHSFrameDMSGListening(procID ProcID, local routing.Addr) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeDMSGListening, HSFramePKLen+HSFramePortLen) - - copy(hsFrame[HSFrameHeaderLen:], local.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(local.Port)) - - return hsFrame -} - -func NewHSFrameDSMGDial(procID ProcID, loop routing.Loop) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeDMSGDial, 2*HSFramePKLen+2*HSFramePortLen) - - copy(hsFrame[HSFrameHeaderLen:], loop.Local.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(loop.Local.Port)) - - copy(hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen:], loop.Remote.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+2*HSFramePKLen+HSFramePortLen:], uint16(loop.Remote.Port)) - - return hsFrame -} - -func NewHSFrameDMSGAccept(procID ProcID, loop routing.Loop) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeDMSGAccept, 2*HSFramePKLen+2*HSFramePortLen) - - copy(hsFrame[HSFrameHeaderLen:], loop.Local.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(loop.Local.Port)) - - copy(hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen:], loop.Remote.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+2*HSFramePKLen+HSFramePortLen:], uint16(loop.Remote.Port)) - - return hsFrame -} - -func NewHSFrameDMSGStopListening(procID ProcID, local routing.Addr) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeDMSGListen, HSFramePKLen+HSFramePortLen) - - copy(hsFrame[HSFrameHeaderLen:], local.PubKey[:]) - binary.BigEndian.PutUint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:], uint16(local.Port)) - - return hsFrame -} - -func NewHSFrameError(procID ProcID) HSFrame { - hsFrame := newHSFrame(procID, HSFrameTypeError, 0) - - return hsFrame -} - -// ProcID gets ProcID from the HSFrame. -func (f HSFrame) ProcID() ProcID { - return ProcID(binary.BigEndian.Uint16(f)) -} - -// SetProcID sets ProcID for the HSFrame. -func (f HSFrame) SetProcID(procID ProcID) { - binary.BigEndian.PutUint16(f, uint16(procID)) -} - -// FrameType gets FrameType from the HSFrame. -func (f HSFrame) FrameType() HSFrameType { - _ = f[HSFrameProcIDLen] // bounds check hint to compiler; see golang.org/issue/14808 - return HSFrameType(f[HSFrameProcIDLen]) -} - -// SetFrameType sets FrameType for the HSFrame. -func (f HSFrame) SetFrameType(frameType HSFrameType) { - _ = f[HSFrameProcIDLen] // bounds check hint to compiler; see golang.org/issue/14808 - f[HSFrameProcIDLen] = byte(frameType) -} - -func readHSFrame(r io.Reader) (HSFrame, error) { - hsFrame := make(HSFrame, HSFrameHeaderLen) - if _, err := io.ReadFull(r, hsFrame); err != nil { - return nil, errors.Wrap(err, "error reading HS frame header") - } - - hsFrame, err := readHSFrameBody(hsFrame, r) - if err != nil { - return nil, errors.Wrap(err, "error reading HS frame body") - } - - return hsFrame, nil -} - -func writeHSFrame(w io.Writer, hsFrame HSFrame) (int, error) { - n, err := w.Write(hsFrame) - if err != nil { - return 0, errors.Wrap(err, "error writing HS frame") - } - - return n, nil -} - -func readHSFrameBody(hsFrame HSFrame, r io.Reader) (HSFrame, error) { - switch hsFrame.FrameType() { - case HSFrameTypeDMSGListen, HSFrameTypeDMSGListening: - hsFrame = append(hsFrame, make([]byte, HSFramePKLen+HSFramePortLen)...) - case HSFrameTypeDMSGDial, HSFrameTypeDMSGAccept: - hsFrame = append(hsFrame, make([]byte, 2*HSFramePKLen+2*HSFramePortLen)...) - } - - _, err := io.ReadFull(r, hsFrame[HSFrameHeaderLen:]) - return hsFrame, err -} - -func dialHS(conn net.Conn, pid ProcID, loop routing.Loop) error { - hsFrame := NewHSFrameDSMGDial(pid, loop) - - if _, err := writeHSFrame(conn, hsFrame); err != nil { - return err - } - - hsFrame, err := readHSFrame(conn) - if err != nil { - return err - } - - if hsFrame.FrameType() != HSFrameTypeDMSGAccept { - return ErrWrongHSFrameTypeReceived - } - - if hsFrame.ProcID() != pid { - return ErrWrongPID - } - - return nil -} - -func listenHS(conn net.Conn, pid ProcID, local routing.Addr) error { - hsFrame := NewHSFrameDMSGListen(pid, local) - - if _, err := writeHSFrame(conn, hsFrame); err != nil { - return err - } - - hsFrame, err := readHSFrame(conn) - if err != nil { - return err - } - - if hsFrame.FrameType() != HSFrameTypeDMSGListening { - return ErrWrongHSFrameTypeReceived - } - - if hsFrame.ProcID() != pid { - return ErrWrongPID - } - - return nil -} diff --git a/pkg/app2/hsframe_test.go b/pkg/app2/hsframe_test.go deleted file mode 100644 index ab66e376f9..0000000000 --- a/pkg/app2/hsframe_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package app2 - -import ( - "encoding/binary" - "encoding/json" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestHSFrame(t *testing.T) { - t.Run("ok", func(t *testing.T) { - body := struct { - Test string `json:"test"` - }{ - Test: "some string", - } - - bodyBytes, err := json.Marshal(body) - require.NoError(t, err) - - procID := ProcID(1) - frameType := HSFrameTypeDMSGListen - bodyLen := len(bodyBytes) - - hsFrame, err := NewHSFrame(procID, frameType, body) - require.NoError(t, err) - - require.Equal(t, len(hsFrame), HSFrameHeaderLen+len(bodyBytes)) - - gotProcID := ProcID(binary.BigEndian.Uint16(hsFrame)) - require.Equal(t, gotProcID, procID) - - gotFrameType := HSFrameType(hsFrame[HSFrameProcIDLen]) - require.Equal(t, gotFrameType, frameType) - - gotBodyLen := int(binary.BigEndian.Uint16(hsFrame[HSFrameProcIDLen+HSFrameTypeLen:])) - require.Equal(t, gotBodyLen, bodyLen) - - require.Equal(t, bodyBytes, []byte(hsFrame[HSFrameProcIDLen+HSFrameTypeLen+HSFrameBodyLenLen:])) - - gotProcID = hsFrame.ProcID() - require.Equal(t, gotProcID, procID) - - gotFrameType = hsFrame.FrameType() - require.Equal(t, gotFrameType, frameType) - - gotBodyLen = hsFrame.BodyLen() - require.Equal(t, gotBodyLen, bodyLen) - }) - - t.Run("fail - too large body", func(t *testing.T) { - body := struct { - Test string `json:"test"` - }{ - Test: "some string", - } - - for len(body.Test) <= HSFrameMaxBodyLen { - body.Test += body.Test - } - - procID := ProcID(1) - frameType := HSFrameTypeDMSGListen - - _, err := NewHSFrame(procID, frameType, body) - require.Equal(t, err, ErrHSFrameBodyTooLarge) - }) -} diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index 33c7c2e296..fb423fbfdc 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -1,78 +1,38 @@ package app2 import ( - "errors" "net" - "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/routing" ) -const ( - listenerBufSize = 1000 -) - -var ( - ErrListenerClosed = errors.New("listener closed") -) - -type listener struct { - addr routing.Addr - conns chan *clientConn - stopListening func(port routing.Port) error - logger *logging.Logger - lm *listenersManager - procID ProcID +type Listener struct { + id uint16 + rpc ListenerRPCClient + addr routing.Addr } -func newListener(addr routing.Addr, lm *listenersManager, procID ProcID, - stopListening func(port routing.Port) error, l *logging.Logger) *listener { - return &listener{ - addr: addr, - conns: make(chan *clientConn, listenerBufSize), - lm: lm, - stopListening: stopListening, - logger: l, - procID: procID, +func (l *Listener) Accept() (*Conn, error) { + connID, err := l.rpc.Accept(l.id) + if err != nil { + return nil, err } -} -func (l *listener) Accept() (net.Conn, error) { - conn, ok := <-l.conns - if !ok { - return nil, ErrListenerClosed - } - - hsFrame := NewHSFrameDMSGAccept(l.procID, routing.Loop{ - Local: l.addr, - Remote: conn.remote, - }) - - if _, err := conn.Write(hsFrame); err != nil { - return nil, err + conn := &Conn{ + id: connID, + rpc: l.rpc, + local: l.addr, + // TODO: probably pass with response + remote: routing.Addr{}, } return conn, nil } -func (l *listener) Close() error { - if err := l.stopListening(l.addr.Port); err != nil { - l.logger.WithError(err).Error("error sending DmsgStopListening") - } - - if err := l.lm.remove(l.addr.Port); err != nil { - return err - } - - close(l.conns) - - return nil +func (l *Listener) Close() error { + return l.rpc.CloseListener(l.id) } -func (l *listener) Addr() net.Addr { +func (l *Listener) Addr() net.Addr { return l.addr } - -func (l *listener) addConn(conn *clientConn) { - l.conns <- conn -} diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go index 5619a439b9..f8c605e34c 100644 --- a/pkg/app2/listeners_manager.go +++ b/pkg/app2/listeners_manager.go @@ -1,172 +1,84 @@ package app2 import ( - "encoding/binary" - "net" + "fmt" "sync" - "sync/atomic" - "github.com/hashicorp/yamux" "github.com/pkg/errors" - "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/skycoin/src/util/logging" - - "github.com/skycoin/skywire/pkg/routing" -) - -var ( - ErrPortAlreadyBound = errors.New("port is already bound") - ErrNoListenerOnPort = errors.New("no listener on port") - ErrWrongPID = errors.New("wrong ProcID specified in the HS frame") + "github.com/skycoin/dmsg" ) // listenersManager contains and manages all the instantiated listeners type listenersManager struct { - pid ProcID - pk cipher.PubKey - listeners map[routing.Port]*listener - mx sync.RWMutex - isListening int32 - logger *logging.Logger - doneCh chan struct{} - doneWg sync.WaitGroup + listeners map[uint16]*dmsg.Listener + mx sync.RWMutex + lstID uint16 } -func newListenersManager(l *logging.Logger, pid ProcID, pk cipher.PubKey) *listenersManager { +func newListenersManager() *listenersManager { return &listenersManager{ - pid: pid, - pk: pk, - listeners: make(map[routing.Port]*listener), - logger: l, - doneCh: make(chan struct{}), + listeners: make(map[uint16]*dmsg.Listener), } } -func (lm *listenersManager) close() { - close(lm.doneCh) - lm.doneWg.Wait() -} +func (m *listenersManager) nextID() (*uint16, error) { + m.mx.Lock() -func (lm *listenersManager) set(port routing.Port, l *listener) error { - lm.mx.Lock() - if v, ok := lm.listeners[port]; !ok || v != nil { - lm.mx.Unlock() - return ErrPortAlreadyBound + lisID := m.lstID + 1 + for ; lisID < m.lstID; lisID++ { + if _, ok := m.listeners[lisID]; !ok { + break + } } - lm.listeners[port] = l - lm.mx.Unlock() - return nil -} -func (lm *listenersManager) reserveListener(port routing.Port) error { - lm.mx.Lock() - if _, ok := lm.listeners[port]; ok { - lm.mx.Unlock() - return ErrPortAlreadyBound + if lisID == m.lstID { + m.mx.Unlock() + return nil, errors.New("no more available listeners") } - lm.listeners[port] = nil - lm.mx.Unlock() - return nil + + m.listeners[lisID] = nil + m.lstID = lisID + + m.mx.Unlock() + return &lisID, nil } -func (lm *listenersManager) remove(port routing.Port) error { - lm.mx.Lock() - if _, ok := lm.listeners[port]; !ok { - lm.mx.Unlock() - return ErrNoListenerOnPort +func (m *listenersManager) getAndRemove(lisID uint16) (*dmsg.Listener, error) { + m.mx.Lock() + lis, ok := m.listeners[lisID] + if !ok { + m.mx.Unlock() + return nil, fmt.Errorf("no listener with id %d", lisID) } - delete(lm.listeners, port) - lm.mx.Unlock() - return nil -} -// addConn passes connection to the corresponding listener -func (lm *listenersManager) addConn(localPort routing.Port, remote routing.Addr, conn net.Conn) error { - lm.mx.RLock() - if _, ok := lm.listeners[localPort]; !ok { - lm.mx.RUnlock() - return ErrNoListenerOnPort + if lis == nil { + m.mx.Unlock() + return nil, fmt.Errorf("listener with id %d is not set", lisID) } - lm.listeners[localPort].addConn(&clientConn{ - remote: remote, - Conn: conn, - }) - lm.mx.RUnlock() - return nil + + delete(m.listeners, lisID) + + m.mx.Unlock() + return lis, nil } -// listen accepts all new yamux streams from the server. We want to accept only -// `DmsgDial` frames here, thus all the other frames get rejected. `DmsgDial` frames -// are being distributed between the corresponding listeners with regards to their port -func (lm *listenersManager) listen(session *yamux.Session) { - // this one should only start once - if !atomic.CompareAndSwapInt32(&lm.isListening, 0, 1) { - return +func (m *listenersManager) set(lisID uint16, lis *dmsg.Listener) error { + m.mx.Lock() + + if l, ok := m.listeners[lisID]; ok && l != nil { + m.mx.Unlock() + return errors.New("listener already exists") } - lm.doneWg.Add(1) - - go func() { - defer lm.doneWg.Done() - - for { - select { - case <-lm.doneCh: - return - default: - stream, err := session.Accept() - if err != nil { - lm.logger.WithError(err).Error("error accepting stream") - return - } - - hsFrame, err := readHSFrame(stream) - if err != nil { - lm.logger.WithError(err).Error("error reading HS frame") - continue - } - - if hsFrame.ProcID() != lm.pid { - lm.logger.WithError(ErrWrongPID).Error("error listening for Dial") - } - - if hsFrame.FrameType() != HSFrameTypeDMSGDial { - lm.logger.WithError(ErrWrongHSFrameTypeReceived).Error("error listening for Dial") - continue - } - - // TODO: handle field get gracefully - remotePort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen*2+HSFramePortLen:])) - localPort := routing.Port(binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:])) - - var localPK cipher.PubKey - copy(localPK[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) - - err = lm.addConn(remotePort, routing.Addr{ - PubKey: localPK, - Port: localPort, - }, stream) - if err != nil { - lm.logger.WithError(err).Error("failed to accept") - continue - } - - respHSFrame := NewHSFrameDMSGAccept(hsFrame.ProcID(), routing.Loop{ - Local: routing.Addr{ - PubKey: lm.pk, - Port: remotePort, - }, - Remote: routing.Addr{ - PubKey: localPK, - Port: localPort, - }, - }) - - if _, err := stream.Write(respHSFrame); err != nil { - lm.logger.WithError(err).Error("error responding with DmsgAccept") - continue - } - } - } - }() + m.listeners[lisID] = lis + + m.mx.Unlock() + return nil +} + +func (m *listenersManager) get(lisID uint16) (*dmsg.Listener, bool) { + m.mx.RLock() + lis, ok := m.listeners[lisID] + m.mx.RUnlock() + return lis, ok } diff --git a/pkg/app2/server.go b/pkg/app2/server.go deleted file mode 100644 index 3953266c09..0000000000 --- a/pkg/app2/server.go +++ /dev/null @@ -1,266 +0,0 @@ -package app2 - -import ( - "context" - "encoding/binary" - "fmt" - "io" - "net" - "sync" - - "github.com/skycoin/skycoin/src/util/logging" - - "github.com/hashicorp/yamux" - - "github.com/skycoin/dmsg" - - "github.com/skycoin/skywire/pkg/routing" - - "github.com/pkg/errors" - - "github.com/skycoin/dmsg/cipher" -) - -type serverConn struct { - procID ProcID - conn net.Conn - session *yamux.Session - lm *listenersManager - dmsgListeners map[routing.Port]*dmsg.Listener - dmsgListenersMx sync.RWMutex -} - -// Server is used by skywire visor. -type Server struct { - PK cipher.PubKey - dmsgC *dmsg.Client - apps map[string]*serverConn - appsMx sync.RWMutex - logger *logging.Logger -} - -func NewServer(localPK cipher.PubKey, dmsgC *dmsg.Client, l *logging.Logger) *Server { - return &Server{ - PK: localPK, - dmsgC: dmsgC, - apps: make(map[string]*serverConn), - logger: l, - } -} - -func (s *Server) Serve(sockAddr string) error { - l, err := net.Listen("unix", sockAddr) - if err != nil { - return errors.Wrap(err, "error listening unix socket") - } - - for { - conn, err := l.Accept() - if err != nil { - return errors.Wrap(err, "error accepting client connection") - } - - s.appsMx.Lock() - if _, ok := s.apps[conn.RemoteAddr().String()]; ok { - s.logger.WithError(ErrPortAlreadyBound).Error("error storing session") - } - - session, err := yamux.Server(conn, nil) - if err != nil { - return errors.Wrap(err, "error creating yamux session") - } - - s.apps[conn.RemoteAddr().String()] = &serverConn{ - session: session, - conn: conn, - lm: newListenersManager(), - dmsgListeners: make(map[routing.Port]*dmsg.Listener), - } - s.appsMx.Unlock() - - // TODO: handle error - go s.serveClient(session) - } -} - -func (s *Server) serveClient(conn *serverConn) error { - for { - stream, err := conn.session.Accept() - if err != nil { - return errors.Wrap(err, "error opening stream") - } - - go s.serveStream(stream, conn) -/////////////////////////////// - hsFrame, err := readHSFrame(conn) - if err != nil { - return errors.Wrap(err, "error reading HS frame") - } - - switch hsFrame.FrameType() { - case HSFrameTypeDMSGListen: - if s. - pk := make(cipher.PubKey, 33) - copy(pk, hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) - port := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]) - dmsgL, err := s.dmsgC.Listen(port) - if err != nil { - return fmt.Errorf("error listening on port %d: %v", port, err) - } - - respHSFrame := NewHSFrameDMSGListening(hsFrame.ProcID(), routing.Addr{ - PubKey: cipher.PubKey(pk), - Port: 0, - }) - } - } -} - -func (s *Server) serveStream(stream net.Conn, conn *serverConn) error { - for { - hsFrame, err := readHSFrame(stream) - if err != nil { - return errors.Wrap(err, "error reading HS frame") - } - - // TODO: ensure thread-safety - conn.procID = hsFrame.ProcID() - - var respHSFrame HSFrame - switch hsFrame.FrameType() { - case HSFrameTypeDMSGListen: - port := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]) - if err := conn.reserveListener(routing.Port(port)); err != nil { - respHSFrame = NewHSFrameError(hsFrame.ProcID()) - } else { - dmsgL, err := s.dmsgC.Listen(port) - if err != nil { - respHSFrame = NewHSFrameError(hsFrame.ProcID()) - } else { - if err := conn.addListener(routing.Port(port), dmsgL); err != nil { - respHSFrame = NewHSFrameError(hsFrame.ProcID()) - } else { - var pk cipher.PubKey - copy(pk[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) - - respHSFrame = NewHSFrameDMSGListening(hsFrame.ProcID(), routing.Addr{ - PubKey: pk, - Port: routing.Port(port), - }) - } - } - } - case HSFrameTypeDMSGDial: - localPort := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen:]) - var localPK cipher.PubKey - copy(localPK[:], hsFrame[HSFrameHeaderLen:HSFrameHeaderLen+HSFramePKLen]) - - var remotePK cipher.PubKey - copy(remotePK[:], hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen:HSFrameHeaderLen+HSFramePKLen+HSFramePortLen+HSFramePKLen]) - remotePort := binary.BigEndian.Uint16(hsFrame[HSFrameHeaderLen+HSFramePKLen+HSFramePortLen+HSFramePKLen:]) - - // TODO: context - tp, err := s.dmsgC.Dial(context.Background(), localPK, localPort) - if err != nil { - respHSFrame = NewHSFrameError(hsFrame.ProcID()) - } else { - respHSFrame = NewHSFrameDMSGAccept(hsFrame.ProcID(), routing.Loop{ - Local: routing.Addr{ - PubKey: localPK, - Port: routing.Port(localPort), - }, - Remote: routing.Addr{ - PubKey: remotePK, - Port: routing.Port(remotePort), - }, - }) - - go func() { - if err := s.forwardOverDMSG(stream, tp); err != nil { - s.logger.WithError(err).Error("error forwarding over DMSG") - } - }() - } - } - - if _, err := stream.Write(respHSFrame); err != nil { - return errors.Wrap(err, "error writing response") - } - } -} - -func (s *Server) forwardOverDMSG(stream net.Conn, tp *dmsg.Transport) error { - toStreamErrCh := make(chan error) - defer close(toStreamErrCh) - go func() { - _, err := io.Copy(stream, tp) - toStreamErrCh <- err - }() - - _, err := io.Copy(stream, tp) - if err != nil { - return err - } - - if err := <-toStreamErrCh; err != nil { - return err - } - - return nil -} - -func (c *serverConn) reserveListener(port routing.Port) error { - c.dmsgListenersMx.Lock() - if _, ok := c.dmsgListeners[port]; ok { - c.dmsgListenersMx.Unlock() - return ErrPortAlreadyBound - } - c.dmsgListeners[port] = nil - c.dmsgListenersMx.Unlock() - return nil -} - -func (c *serverConn) addListener(port routing.Port, l *dmsg.Listener) error { - c.dmsgListenersMx.Lock() - if lis, ok := c.dmsgListeners[port]; ok && lis != nil { - c.dmsgListenersMx.Unlock() - return ErrPortAlreadyBound - } - c.dmsgListeners[port] = l - go c.acceptDMSG(l) - c.dmsgListenersMx.Unlock() - return nil -} - -func (c *serverConn) acceptDMSG(l *dmsg.Listener) error { - for { - stream, err := c.session.Open() - if err != nil { - return errors.Wrap(err, "error opening yamux stream") - } - - remoteAddr, ok := l.Addr().(dmsg.Addr) - if !ok { - // shouldn't happen, but still - return errors.Wrap(err, "wrong type for DMSG addr") - } - - hsFrame := NewHSFrameDSMGDial(c.procID, routing.Loop{ - Local: routing.Addr{ - PubKey: remoteAddr.PK, - Port: routing.Port(remoteAddr.Port), - }, - // TODO: get local addr - Remote: routing.Addr{ - PubKey: - }, - }) - - conn, err := l.Accept() - if err != nil { - return errors.Wrap(err, "error accepting DMSG conn") - } - - - } -} diff --git a/pkg/app2/server_rpc.go b/pkg/app2/server_rpc.go index ba73f4fe88..ace37eb6ae 100644 --- a/pkg/app2/server_rpc.go +++ b/pkg/app2/server_rpc.go @@ -2,10 +2,7 @@ package app2 import ( "context" - "errors" "fmt" - "net" - "sync" "github.com/skycoin/dmsg" @@ -13,139 +10,49 @@ import ( ) type ServerRPC struct { - dmsgC *dmsg.Client - conns map[uint16]net.Conn - connsMx sync.RWMutex - lstConnID uint16 - listeners map[uint16]*dmsg.Listener - listenersMx sync.RWMutex - lstLisID uint16 + dmsgC *dmsg.Client + lm *listenersManager + cm *connsManager } -func (r *ServerRPC) nextConnID() (*uint16, error) { - r.connsMx.Lock() - - connID := r.lstConnID + 1 - for ; connID < r.lstConnID; connID++ { - if _, ok := r.conns[connID]; !ok { - break - } - } - - if connID == r.lstConnID { - r.connsMx.Unlock() - return nil, errors.New("no more available conns") - } - - r.conns[connID] = nil - r.lstConnID = connID - - r.connsMx.Unlock() - return &connID, nil -} - -func (r *ServerRPC) nextLisID() (*uint16, error) { - r.listenersMx.Lock() - - lisID := r.lstLisID + 1 - for ; lisID < r.lstLisID; lisID++ { - if _, ok := r.listeners[lisID]; !ok { - break - } - } - - if lisID == r.lstLisID { - r.listenersMx.Unlock() - return nil, errors.New("no more available listeners") - } - - r.listeners[lisID] = nil - r.lstLisID = lisID - - r.listenersMx.Unlock() - return &lisID, nil -} - -func (r *ServerRPC) setConn(connID uint16, conn net.Conn) error { - r.connsMx.Lock() - - if c, ok := r.conns[connID]; ok && c != nil { - r.connsMx.Unlock() - return errors.New("conn already exists") - } - - r.conns[connID] = conn - - r.connsMx.Unlock() - return nil -} - -func (r *ServerRPC) setListener(lisID uint16, lis *dmsg.Listener) error { - r.listenersMx.Lock() - - if l, ok := r.listeners[lisID]; ok && l != nil { - r.listenersMx.Unlock() - return errors.New("listener already exists") +func newServerRPC(dmsgC *dmsg.Client) *ServerRPC { + return &ServerRPC{ + dmsgC: dmsgC, + lm: newListenersManager(), + cm: newConnsManager(), } - - r.listeners[lisID] = lis - - r.listenersMx.Unlock() - return nil -} - -func (r *ServerRPC) getConn(connID uint16) (net.Conn, bool) { - r.connsMx.RLock() - conn, ok := r.conns[connID] - r.connsMx.RUnlock() - return conn, ok -} - -func (r *ServerRPC) getListener(lisID uint16) (*dmsg.Listener, bool) { - r.listenersMx.RLock() - lis, ok := r.listeners[lisID] - r.listenersMx.RUnlock() - return lis, ok -} - -type DialReq struct { - Remote routing.Addr } -func (r *ServerRPC) Dial(req *DialReq, connID *uint16) error { - connID, err := r.nextConnID() +func (r *ServerRPC) Dial(remote *routing.Addr, connID *uint16) error { + connID, err := r.cm.nextID() if err != nil { return err } - tp, err := r.dmsgC.Dial(context.TODO(), req.Remote.PubKey, uint16(req.Remote.Port)) + tp, err := r.dmsgC.Dial(context.TODO(), remote.PubKey, uint16(remote.Port)) if err != nil { return err } - if err := r.setConn(*connID, tp); err != nil { + if err := r.cm.set(*connID, tp); err != nil { return err } return nil } -type ListenReq struct { - Local routing.Addr -} - -func (r *ServerRPC) Listen(req *ListenReq, lisID *uint16) error { - lisID, err := r.nextLisID() +func (r *ServerRPC) Listen(local *routing.Addr, lisID *uint16) error { + lisID, err := r.lm.nextID() if err != nil { return err } - dmsgL, err := r.dmsgC.Listen(uint16(req.Local.Port)) + dmsgL, err := r.dmsgC.Listen(uint16(local.Port)) if err != nil { return err } - if err := r.setListener(*lisID, dmsgL); err != nil { + if err := r.lm.set(*lisID, dmsgL); err != nil { // TODO: close listener return err } @@ -154,12 +61,12 @@ func (r *ServerRPC) Listen(req *ListenReq, lisID *uint16) error { } func (r *ServerRPC) Accept(lisID *uint16, connID *uint16) error { - lis, ok := r.getListener(*lisID) + lis, ok := r.lm.get(*lisID) if !ok { return fmt.Errorf("not listener with id %d", *lisID) } - connID, err := r.nextConnID() + connID, err := r.cm.nextID() if err != nil { return err } @@ -169,7 +76,7 @@ func (r *ServerRPC) Accept(lisID *uint16, connID *uint16) error { return err } - if err := r.setConn(*connID, tp); err != nil { + if err := r.cm.set(*connID, tp); err != nil { // TODO: close conn return err } @@ -183,9 +90,54 @@ type WriteReq struct { } func (r *ServerRPC) Write(req *WriteReq, n *int) error { - conn, ok := r.getConn(req.ConnID) + conn, ok := r.cm.get(req.ConnID) + if !ok { + return fmt.Errorf("no conn with id %d", req.ConnID) + } + + var err error + *n, err = conn.Write(req.B) + if err != nil { + return err + } + + return nil +} + +type ReadResp struct { + B []byte + N int +} + +func (r *ServerRPC) Read(connID *uint16, resp *ReadResp) error { + conn, ok := r.cm.get(*connID) if !ok { - return fmt.Errorf("not conn with id %d", req.ConnID) + return fmt.Errorf("no conn with id %d", *connID) + } + + var err error + resp.N, err = conn.Read(resp.B) + if err != nil { + return err + } + + return nil +} + +func (r *ServerRPC) CloseConn(connID *uint16, _ *struct{}) error { + conn, err := r.cm.getAndRemove(*connID) + if err != nil { + return err + } + + return conn.Close() +} + +func (r *ServerRPC) CloseListener(lisID *uint16, _ *struct{}) error { + lis, err := r.lm.getAndRemove(*lisID) + if err != nil { + return err } + return lis.Close() } diff --git a/pkg/app2/server_rpc_client.go b/pkg/app2/server_rpc_client.go new file mode 100644 index 0000000000..f9e8ed3970 --- /dev/null +++ b/pkg/app2/server_rpc_client.go @@ -0,0 +1,101 @@ +package app2 + +import ( + "net/rpc" + + "github.com/skycoin/skywire/pkg/routing" +) + +type ServerRPCClient interface { + Dial(remote routing.Addr) (uint16, error) + Listen(local routing.Addr) (uint16, error) + Accept(lisID uint16) (uint16, error) + Write(connID uint16, b []byte) (int, error) + Read(connID uint16, b []byte) (int, error) + CloseConn(id uint16) error + CloseListener(id uint16) error +} + +type ListenerRPCClient interface { + Accept(id uint16) (uint16, error) + CloseListener(id uint16) error + Write(connID uint16, b []byte) (int, error) + Read(connID uint16, b []byte) (int, error) + CloseConn(id uint16) error +} + +type ConnRPCClient interface { + Write(id uint16, b []byte) (int, error) + Read(id uint16, b []byte) (int, error) + CloseConn(id uint16) error +} + +type serverRPCCLient struct { + rpc *rpc.Client +} + +func newServerRPCClient(rpc *rpc.Client) ServerRPCClient { + return &serverRPCCLient{ + rpc: rpc, + } +} + +func (c *serverRPCCLient) Dial(remote routing.Addr) (uint16, error) { + var connID uint16 + if err := c.rpc.Call("Dial", &remote, &connID); err != nil { + return 0, err + } + + return connID, nil +} + +func (c *serverRPCCLient) Listen(local routing.Addr) (uint16, error) { + var lisID uint16 + if err := c.rpc.Call("Listen", &local, &lisID); err != nil { + return 0, err + } + + return lisID, nil +} + +func (c *serverRPCCLient) Accept(lisID uint16) (uint16, error) { + var connID uint16 + if err := c.rpc.Call("Accept", &lisID, &connID); err != nil { + return 0, err + } + + return connID, nil +} + +func (c *serverRPCCLient) Write(connID uint16, b []byte) (int, error) { + req := WriteReq{ + ConnID: connID, + B: b, + } + + var n int + if err := c.rpc.Call("Write", &req, &n); err != nil { + return n, err + } + + return n, nil +} + +func (c *serverRPCCLient) Read(connID uint16, b []byte) (int, error) { + var resp ReadResp + if err := c.rpc.Call("Read", &connID, &resp); err != nil { + return 0, err + } + + copy(b[:resp.N], resp.B[:resp.N]) + + return resp.N, nil +} + +func (c *serverRPCCLient) CloseConn(id uint16) error { + return c.rpc.Call("CloseConn", &id, nil) +} + +func (c *serverRPCCLient) CloseListener(id uint16) error { + return c.rpc.Call("CloseListener", &id, nil) +} From 7bcafc95705761fe8a8d0fd77ab6fee3ba94cc83 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Mon, 16 Sep 2019 13:31:04 +0300 Subject: [PATCH 15/43] Add comments --- pkg/app2/client.go | 33 ++++++++++++++++++--------------- pkg/app2/conn.go | 1 + pkg/app2/conns_manager.go | 7 +++++++ pkg/app2/doc.go | 2 +- pkg/app2/listener.go | 2 ++ pkg/app2/listeners_manager.go | 8 +++++++- pkg/app2/server_rpc.go | 11 +++++++++++ pkg/app2/server_rpc_client.go | 14 ++++++++++++++ 8 files changed, 61 insertions(+), 17 deletions(-) diff --git a/pkg/app2/client.go b/pkg/app2/client.go index c78d1ab7b0..912134a8df 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -3,33 +3,35 @@ package app2 import ( "net/rpc" - "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/skywire/pkg/routing" ) // Client is used by skywire apps. type Client struct { - PK cipher.PubKey - pid ProcID - rpc ServerRPCClient - logger *logging.Logger + pk cipher.PubKey + pid ProcID + rpc ServerRPCClient + log *logging.Logger } -// NewClient creates a new Client. The Client needs to be provided with: +// NewClient creates a new `Client`. The `Client` needs to be provided with: +// - log: Logger instance. // - localPK: The local public key of the parent skywire visor. // - pid: The procID assigned for the process that Client is being used by. -// - sockAddr: The socket address to connect to Server. -func NewClient(localPK cipher.PubKey, pid ProcID, rpc *rpc.Client, l *logging.Logger) *Client { +// - rpc: RPC client to communicate with the server. +func NewClient(log *logging.Logger, localPK cipher.PubKey, pid ProcID, rpc *rpc.Client) *Client { return &Client{ - PK: localPK, - pid: pid, - rpc: newServerRPCClient(rpc), - logger: l, + pk: localPK, + pid: pid, + rpc: newServerRPCClient(rpc), + log: log, } } +// Dial dials the remote node using `remote`. func (c *Client) Dial(remote routing.Addr) (*Conn, error) { connID, err := c.rpc.Dial(remote) if err != nil { @@ -41,7 +43,7 @@ func (c *Client) Dial(remote routing.Addr) (*Conn, error) { rpc: c.rpc, // TODO: port? local: routing.Addr{ - PubKey: c.PK, + PubKey: c.pk, }, remote: remote, } @@ -49,9 +51,10 @@ func (c *Client) Dial(remote routing.Addr) (*Conn, error) { return conn, nil } +// Listen listens on the specified `port` for the incoming connections. func (c *Client) Listen(port routing.Port) (*Listener, error) { local := routing.Addr{ - PubKey: c.PK, + PubKey: c.pk, Port: port, } diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index c5a7e2ed91..6913315a09 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -6,6 +6,7 @@ import ( "github.com/skycoin/skywire/pkg/routing" ) +// Conn is a connection from app client to the server. type Conn struct { id uint16 rpc ConnRPCClient diff --git a/pkg/app2/conns_manager.go b/pkg/app2/conns_manager.go index d09b61c7cd..101ad42425 100644 --- a/pkg/app2/conns_manager.go +++ b/pkg/app2/conns_manager.go @@ -8,18 +8,21 @@ import ( "github.com/pkg/errors" ) +// connsManager manages connections within the app server. type connsManager struct { conns map[uint16]net.Conn mx sync.RWMutex lstID uint16 } +// newConnsManager constructs new `connsManager`. func newConnsManager() *connsManager { return &connsManager{ conns: make(map[uint16]net.Conn), } } +// `nextID` reserves slot for the next connection and returns its id. func (m *connsManager) nextID() (*uint16, error) { m.mx.Lock() @@ -42,6 +45,8 @@ func (m *connsManager) nextID() (*uint16, error) { return &connID, nil } +// getAndRemove removes connection specified by `connID` from the manager instance and +// returns it. func (m *connsManager) getAndRemove(connID uint16) (net.Conn, error) { m.mx.Lock() conn, ok := m.conns[connID] @@ -61,6 +66,7 @@ func (m *connsManager) getAndRemove(connID uint16) (net.Conn, error) { return conn, nil } +// set sets `conn` associated with `connID`. func (m *connsManager) set(connID uint16, conn net.Conn) error { m.mx.Lock() @@ -75,6 +81,7 @@ func (m *connsManager) set(connID uint16, conn net.Conn) error { return nil } +// get gets the connection associated with the `connID`. func (m *connsManager) get(connID uint16) (net.Conn, bool) { m.mx.RLock() conn, ok := m.conns[connID] diff --git a/pkg/app2/doc.go b/pkg/app2/doc.go index 0dca17f0d7..ff4dfd1b32 100644 --- a/pkg/app2/doc.go +++ b/pkg/app2/doc.go @@ -1,4 +1,4 @@ // Package app2 provides facilities to establish communication // between a visor node and a skywire application. Intended to -// replace the original `app` module +// replace the original `app` module. package app2 diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index fb423fbfdc..ee11225090 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -6,6 +6,7 @@ import ( "github.com/skycoin/skywire/pkg/routing" ) +// Listener is a listener for app server connections. type Listener struct { id uint16 rpc ListenerRPCClient @@ -29,6 +30,7 @@ func (l *Listener) Accept() (*Conn, error) { return conn, nil } +// TODO: should unblock all called `Accept`s with errors func (l *Listener) Close() error { return l.rpc.CloseListener(l.id) } diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go index f8c605e34c..48233da113 100644 --- a/pkg/app2/listeners_manager.go +++ b/pkg/app2/listeners_manager.go @@ -8,19 +8,21 @@ import ( "github.com/skycoin/dmsg" ) -// listenersManager contains and manages all the instantiated listeners +// connsManager manages listeners within the app server. type listenersManager struct { listeners map[uint16]*dmsg.Listener mx sync.RWMutex lstID uint16 } +// newListenersManager constructs new `listenersManager`. func newListenersManager() *listenersManager { return &listenersManager{ listeners: make(map[uint16]*dmsg.Listener), } } +// `nextID` reserves slot for the next listener and returns its id. func (m *listenersManager) nextID() (*uint16, error) { m.mx.Lock() @@ -43,6 +45,8 @@ func (m *listenersManager) nextID() (*uint16, error) { return &lisID, nil } +// getAndRemove removes listener specified by `lisID` from the manager instance and +// returns it. func (m *listenersManager) getAndRemove(lisID uint16) (*dmsg.Listener, error) { m.mx.Lock() lis, ok := m.listeners[lisID] @@ -62,6 +66,7 @@ func (m *listenersManager) getAndRemove(lisID uint16) (*dmsg.Listener, error) { return lis, nil } +// set sets `lis` associated with `lisID`. func (m *listenersManager) set(lisID uint16, lis *dmsg.Listener) error { m.mx.Lock() @@ -76,6 +81,7 @@ func (m *listenersManager) set(lisID uint16, lis *dmsg.Listener) error { return nil } +// get gets the listener associated with the `lisID`. func (m *listenersManager) get(lisID uint16) (*dmsg.Listener, bool) { m.mx.RLock() lis, ok := m.listeners[lisID] diff --git a/pkg/app2/server_rpc.go b/pkg/app2/server_rpc.go index ace37eb6ae..c269510864 100644 --- a/pkg/app2/server_rpc.go +++ b/pkg/app2/server_rpc.go @@ -9,12 +9,14 @@ import ( "github.com/skycoin/skywire/pkg/routing" ) +// ServerRPC is a RPC interface for the app server. type ServerRPC struct { dmsgC *dmsg.Client lm *listenersManager cm *connsManager } +// newServerRPC constructs new server RPC interface. func newServerRPC(dmsgC *dmsg.Client) *ServerRPC { return &ServerRPC{ dmsgC: dmsgC, @@ -23,6 +25,7 @@ func newServerRPC(dmsgC *dmsg.Client) *ServerRPC { } } +// Dial dials to the remote. func (r *ServerRPC) Dial(remote *routing.Addr, connID *uint16) error { connID, err := r.cm.nextID() if err != nil { @@ -41,6 +44,7 @@ func (r *ServerRPC) Dial(remote *routing.Addr, connID *uint16) error { return nil } +// Listen starts listening. func (r *ServerRPC) Listen(local *routing.Addr, lisID *uint16) error { lisID, err := r.lm.nextID() if err != nil { @@ -60,6 +64,7 @@ func (r *ServerRPC) Listen(local *routing.Addr, lisID *uint16) error { return nil } +// Accept accepts connection from the listener specified by `lisID`. func (r *ServerRPC) Accept(lisID *uint16, connID *uint16) error { lis, ok := r.lm.get(*lisID) if !ok { @@ -84,11 +89,13 @@ func (r *ServerRPC) Accept(lisID *uint16, connID *uint16) error { return nil } +// WriteReq contains arguments for `Write`. type WriteReq struct { ConnID uint16 B []byte } +// Write writes to the connection. func (r *ServerRPC) Write(req *WriteReq, n *int) error { conn, ok := r.cm.get(req.ConnID) if !ok { @@ -104,11 +111,13 @@ func (r *ServerRPC) Write(req *WriteReq, n *int) error { return nil } +// ReadResp contains response parameters for `Read`. type ReadResp struct { B []byte N int } +// Read reads data from connection specified by `connID`. func (r *ServerRPC) Read(connID *uint16, resp *ReadResp) error { conn, ok := r.cm.get(*connID) if !ok { @@ -124,6 +133,7 @@ func (r *ServerRPC) Read(connID *uint16, resp *ReadResp) error { return nil } +// CloseConn closes connection specified by `connID`. func (r *ServerRPC) CloseConn(connID *uint16, _ *struct{}) error { conn, err := r.cm.getAndRemove(*connID) if err != nil { @@ -133,6 +143,7 @@ func (r *ServerRPC) CloseConn(connID *uint16, _ *struct{}) error { return conn.Close() } +// CloseListener closes listener specified by `lisID`. func (r *ServerRPC) CloseListener(lisID *uint16, _ *struct{}) error { lis, err := r.lm.getAndRemove(*lisID) if err != nil { diff --git a/pkg/app2/server_rpc_client.go b/pkg/app2/server_rpc_client.go index f9e8ed3970..e277d18e3a 100644 --- a/pkg/app2/server_rpc_client.go +++ b/pkg/app2/server_rpc_client.go @@ -6,6 +6,7 @@ import ( "github.com/skycoin/skywire/pkg/routing" ) +// ServerRPCClient describes RPC interface to communicate with the server. type ServerRPCClient interface { Dial(remote routing.Addr) (uint16, error) Listen(local routing.Addr) (uint16, error) @@ -16,6 +17,8 @@ type ServerRPCClient interface { CloseListener(id uint16) error } +// ListenerRPCClient describes RPC interface to communicate with the server. +// Contains funcs for `Listener` and `Conn`. type ListenerRPCClient interface { Accept(id uint16) (uint16, error) CloseListener(id uint16) error @@ -24,22 +27,27 @@ type ListenerRPCClient interface { CloseConn(id uint16) error } +// ConnRPCClient describes RPC interface to communicate with the server. +// Contains funcs for `Conn`. type ConnRPCClient interface { Write(id uint16, b []byte) (int, error) Read(id uint16, b []byte) (int, error) CloseConn(id uint16) error } +// serverRPCClient implements `ServerRPCClient`. type serverRPCCLient struct { rpc *rpc.Client } +// newServerRPCClient constructs new `serverRPCClient`. func newServerRPCClient(rpc *rpc.Client) ServerRPCClient { return &serverRPCCLient{ rpc: rpc, } } +// Dial sends `Dial` command to the server. func (c *serverRPCCLient) Dial(remote routing.Addr) (uint16, error) { var connID uint16 if err := c.rpc.Call("Dial", &remote, &connID); err != nil { @@ -49,6 +57,7 @@ func (c *serverRPCCLient) Dial(remote routing.Addr) (uint16, error) { return connID, nil } +// Listen sends `Listen` command to the server. func (c *serverRPCCLient) Listen(local routing.Addr) (uint16, error) { var lisID uint16 if err := c.rpc.Call("Listen", &local, &lisID); err != nil { @@ -58,6 +67,7 @@ func (c *serverRPCCLient) Listen(local routing.Addr) (uint16, error) { return lisID, nil } +// Accept sends `Accept` command to the server. func (c *serverRPCCLient) Accept(lisID uint16) (uint16, error) { var connID uint16 if err := c.rpc.Call("Accept", &lisID, &connID); err != nil { @@ -67,6 +77,7 @@ func (c *serverRPCCLient) Accept(lisID uint16) (uint16, error) { return connID, nil } +// Write sends `Write` command to the server. func (c *serverRPCCLient) Write(connID uint16, b []byte) (int, error) { req := WriteReq{ ConnID: connID, @@ -81,6 +92,7 @@ func (c *serverRPCCLient) Write(connID uint16, b []byte) (int, error) { return n, nil } +// Read sends `Read` command to the server. func (c *serverRPCCLient) Read(connID uint16, b []byte) (int, error) { var resp ReadResp if err := c.rpc.Call("Read", &connID, &resp); err != nil { @@ -92,10 +104,12 @@ func (c *serverRPCCLient) Read(connID uint16, b []byte) (int, error) { return resp.N, nil } +// CloseConn sends `CloseConn` command to the server. func (c *serverRPCCLient) CloseConn(id uint16) error { return c.rpc.Call("CloseConn", &id, nil) } +// CloseListener sends `CloseListener` command to the server. func (c *serverRPCCLient) CloseListener(id uint16) error { return c.rpc.Call("CloseListener", &id, nil) } From cf5e1059d2944a123dd81f51c65d53e5515e339c Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Mon, 16 Sep 2019 14:35:02 +0300 Subject: [PATCH 16/43] Impement `net` interfaces --- pkg/app2/client.go | 3 ++- pkg/app2/conn.go | 16 +++++++++++++++- pkg/app2/listener.go | 4 ++-- pkg/app2/server_rpc_client.go | 18 ------------------ 4 files changed, 19 insertions(+), 22 deletions(-) diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 912134a8df..c24cca523e 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -1,6 +1,7 @@ package app2 import ( + "net" "net/rpc" "github.com/skycoin/dmsg/cipher" @@ -52,7 +53,7 @@ func (c *Client) Dial(remote routing.Addr) (*Conn, error) { } // Listen listens on the specified `port` for the incoming connections. -func (c *Client) Listen(port routing.Port) (*Listener, error) { +func (c *Client) Listen(port routing.Port) (net.Listener, error) { local := routing.Addr{ PubKey: c.pk, Port: port, diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index 6913315a09..f250ea6621 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -2,14 +2,16 @@ package app2 import ( "net" + "time" + "github.com/pkg/errors" "github.com/skycoin/skywire/pkg/routing" ) // Conn is a connection from app client to the server. type Conn struct { id uint16 - rpc ConnRPCClient + rpc ServerRPCClient local routing.Addr remote routing.Addr } @@ -33,3 +35,15 @@ func (c *Conn) LocalAddr() net.Addr { func (c *Conn) RemoteAddr() net.Addr { return c.remote } + +func (c *Conn) SetDeadline(t time.Time) error { + return errors.New("method not implemented") +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return errors.New("method not implemented") +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return errors.New("method not implemented") +} diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index ee11225090..b1854ba5e0 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -9,11 +9,11 @@ import ( // Listener is a listener for app server connections. type Listener struct { id uint16 - rpc ListenerRPCClient + rpc ServerRPCClient addr routing.Addr } -func (l *Listener) Accept() (*Conn, error) { +func (l *Listener) Accept() (net.Conn, error) { connID, err := l.rpc.Accept(l.id) if err != nil { return nil, err diff --git a/pkg/app2/server_rpc_client.go b/pkg/app2/server_rpc_client.go index e277d18e3a..bb7a5e842a 100644 --- a/pkg/app2/server_rpc_client.go +++ b/pkg/app2/server_rpc_client.go @@ -17,24 +17,6 @@ type ServerRPCClient interface { CloseListener(id uint16) error } -// ListenerRPCClient describes RPC interface to communicate with the server. -// Contains funcs for `Listener` and `Conn`. -type ListenerRPCClient interface { - Accept(id uint16) (uint16, error) - CloseListener(id uint16) error - Write(connID uint16, b []byte) (int, error) - Read(connID uint16, b []byte) (int, error) - CloseConn(id uint16) error -} - -// ConnRPCClient describes RPC interface to communicate with the server. -// Contains funcs for `Conn`. -type ConnRPCClient interface { - Write(id uint16, b []byte) (int, error) - Read(id uint16, b []byte) (int, error) - CloseConn(id uint16) error -} - // serverRPCClient implements `ServerRPCClient`. type serverRPCCLient struct { rpc *rpc.Client From f68b3e0f86bb6ace0e04ed2abe0c0f3e019a3f2c Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Mon, 16 Sep 2019 20:45:55 +0300 Subject: [PATCH 17/43] Add proper port handling --- go.mod | 3 +- go.sum | 2 + pkg/app2/client.go | 47 ++- pkg/app2/conn.go | 11 +- pkg/app2/errors.go | 7 + pkg/app2/listener.go | 20 +- pkg/app2/server_rpc.go | 38 ++- pkg/app2/server_rpc_client.go | 12 +- vendor/github.com/pkg/errors/.gitignore | 24 ++ vendor/github.com/pkg/errors/.travis.yml | 11 + vendor/github.com/pkg/errors/LICENSE | 23 ++ vendor/github.com/pkg/errors/README.md | 52 ++++ vendor/github.com/pkg/errors/appveyor.yml | 32 +++ vendor/github.com/pkg/errors/errors.go | 269 ++++++++++++++++++ vendor/github.com/pkg/errors/stack.go | 178 ++++++++++++ vendor/github.com/skycoin/dmsg/addr.go | 26 -- vendor/github.com/skycoin/dmsg/client.go | 28 +- vendor/github.com/skycoin/dmsg/client_conn.go | 172 ++++++----- vendor/github.com/skycoin/dmsg/listener.go | 121 ++++---- .../github.com/skycoin/dmsg/netutil/porter.go | 102 +++++++ .../github.com/skycoin/dmsg/port_manager.go | 76 +++-- vendor/github.com/skycoin/dmsg/server.go | 233 --------------- vendor/github.com/skycoin/dmsg/server_conn.go | 243 ++++++++++++++++ vendor/github.com/skycoin/dmsg/transport.go | 27 +- .../skycoin/dmsg/{frame.go => types.go} | 74 +++-- .../x/sys/windows/syscall_windows.go | 13 +- .../x/sys/windows/zsyscall_windows.go | 6 + vendor/modules.txt | 9 +- 28 files changed, 1317 insertions(+), 542 deletions(-) create mode 100644 pkg/app2/errors.go create mode 100644 vendor/github.com/pkg/errors/.gitignore create mode 100644 vendor/github.com/pkg/errors/.travis.yml create mode 100644 vendor/github.com/pkg/errors/LICENSE create mode 100644 vendor/github.com/pkg/errors/README.md create mode 100644 vendor/github.com/pkg/errors/appveyor.yml create mode 100644 vendor/github.com/pkg/errors/errors.go create mode 100644 vendor/github.com/pkg/errors/stack.go delete mode 100644 vendor/github.com/skycoin/dmsg/addr.go create mode 100644 vendor/github.com/skycoin/dmsg/netutil/porter.go create mode 100644 vendor/github.com/skycoin/dmsg/server_conn.go rename vendor/github.com/skycoin/dmsg/{frame.go => types.go} (71%) diff --git a/go.mod b/go.mod index d89040e5b9..3c85be4999 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d github.com/kr/pty v1.1.5 // indirect github.com/mitchellh/go-homedir v1.1.0 + github.com/pkg/errors v0.8.0 github.com/pkg/profile v1.3.0 github.com/prometheus/client_golang v1.0.0 github.com/prometheus/common v0.4.1 @@ -28,4 +29,4 @@ require ( ) // Uncomment for tests with alternate branches of 'dmsg' -//replace github.com/skycoin/dmsg => ../dmsg +replace github.com/skycoin/dmsg => ../dmsg diff --git a/go.sum b/go.sum index 1e5468a790..f60c1121a3 100644 --- a/go.sum +++ b/go.sum @@ -69,6 +69,7 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= +github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.3.0 h1:OQIvuDgm00gWVWGTf4m4mCt6W1/0YqU7Ntg0mySWgaI= github.com/pkg/profile v1.3.0/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA= @@ -143,6 +144,7 @@ golang.org/x/sys v0.0.0-20190804053845-51ab0e2deafa h1:KIDDMLT1O0Nr7TSxp8xM5tJcd golang.org/x/sys v0.0.0-20190804053845-51ab0e2deafa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a h1:aYOabOQFp6Vj6W1F80affTUvO9UxmJRx8K0gsfABByQ= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190825160603-fb81701db80f h1:LCxigP8q3fPRGNVYndYsyHnF0zRrvcoVwZMfb8iQZe4= golang.org/x/sys v0.0.0-20190825160603-fb81701db80f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= diff --git a/pkg/app2/client.go b/pkg/app2/client.go index c24cca523e..32ea18169f 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -1,9 +1,12 @@ package app2 import ( + "context" "net" "net/rpc" + "github.com/skycoin/dmsg/netutil" + "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skycoin/src/util/logging" @@ -12,10 +15,11 @@ import ( // Client is used by skywire apps. type Client struct { - pk cipher.PubKey - pid ProcID - rpc ServerRPCClient - log *logging.Logger + pk cipher.PubKey + pid ProcID + rpc ServerRPCClient + log *logging.Logger + porter *netutil.Porter } // NewClient creates a new `Client`. The `Client` needs to be provided with: @@ -23,17 +27,24 @@ type Client struct { // - localPK: The local public key of the parent skywire visor. // - pid: The procID assigned for the process that Client is being used by. // - rpc: RPC client to communicate with the server. -func NewClient(log *logging.Logger, localPK cipher.PubKey, pid ProcID, rpc *rpc.Client) *Client { +func NewClient(log *logging.Logger, localPK cipher.PubKey, pid ProcID, rpc *rpc.Client, + porter *netutil.Porter) *Client { return &Client{ - pk: localPK, - pid: pid, - rpc: newServerRPCClient(rpc), - log: log, + pk: localPK, + pid: pid, + rpc: newServerRPCClient(rpc), + log: log, + porter: porter, } } // Dial dials the remote node using `remote`. func (c *Client) Dial(remote routing.Addr) (*Conn, error) { + localPort, free, err := c.porter.ReserveEphemeral(context.TODO(), nil) + if err != nil { + return nil, err + } + connID, err := c.rpc.Dial(remote) if err != nil { return nil, err @@ -42,11 +53,12 @@ func (c *Client) Dial(remote routing.Addr) (*Conn, error) { conn := &Conn{ id: connID, rpc: c.rpc, - // TODO: port? local: routing.Addr{ PubKey: c.pk, + Port: routing.Port(localPort), }, - remote: remote, + remote: remote, + freeLocalPort: free, } return conn, nil @@ -54,6 +66,12 @@ func (c *Client) Dial(remote routing.Addr) (*Conn, error) { // Listen listens on the specified `port` for the incoming connections. func (c *Client) Listen(port routing.Port) (net.Listener, error) { + ok, free := c.porter.Reserve(uint16(port), nil) + if !ok { + free() + return nil, ErrPortAlreadyBound + } + local := routing.Addr{ PubKey: c.pk, Port: port, @@ -65,9 +83,10 @@ func (c *Client) Listen(port routing.Port) (net.Listener, error) { } listener := &Listener{ - id: lisID, - rpc: c.rpc, - addr: local, + id: lisID, + rpc: c.rpc, + addr: local, + freePort: free, } return listener, nil diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index f250ea6621..9bf669d3fa 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -10,10 +10,11 @@ import ( // Conn is a connection from app client to the server. type Conn struct { - id uint16 - rpc ServerRPCClient - local routing.Addr - remote routing.Addr + id uint16 + rpc ServerRPCClient + local routing.Addr + remote routing.Addr + freeLocalPort func() } func (c *Conn) Read(b []byte) (int, error) { @@ -25,6 +26,8 @@ func (c *Conn) Write(b []byte) (int, error) { } func (c *Conn) Close() error { + defer c.freeLocalPort() + return c.rpc.CloseConn(c.id) } diff --git a/pkg/app2/errors.go b/pkg/app2/errors.go new file mode 100644 index 0000000000..f07b063015 --- /dev/null +++ b/pkg/app2/errors.go @@ -0,0 +1,7 @@ +package app2 + +import "github.com/pkg/errors" + +var ( + ErrPortAlreadyBound = errors.New("port is already bound") +) diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index b1854ba5e0..cebc34935f 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -8,23 +8,23 @@ import ( // Listener is a listener for app server connections. type Listener struct { - id uint16 - rpc ServerRPCClient - addr routing.Addr + id uint16 + rpc ServerRPCClient + addr routing.Addr + freePort func() } func (l *Listener) Accept() (net.Conn, error) { - connID, err := l.rpc.Accept(l.id) + connID, remote, err := l.rpc.Accept(l.id) if err != nil { return nil, err } conn := &Conn{ - id: connID, - rpc: l.rpc, - local: l.addr, - // TODO: probably pass with response - remote: routing.Addr{}, + id: connID, + rpc: l.rpc, + local: l.addr, + remote: remote, } return conn, nil @@ -32,6 +32,8 @@ func (l *Listener) Accept() (net.Conn, error) { // TODO: should unblock all called `Accept`s with errors func (l *Listener) Close() error { + defer l.freePort() + return l.rpc.CloseListener(l.id) } diff --git a/pkg/app2/server_rpc.go b/pkg/app2/server_rpc.go index c269510864..2800db44bc 100644 --- a/pkg/app2/server_rpc.go +++ b/pkg/app2/server_rpc.go @@ -4,6 +4,9 @@ import ( "context" "fmt" + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/pkg/errors" "github.com/skycoin/dmsg" "github.com/skycoin/skywire/pkg/routing" @@ -14,14 +17,16 @@ type ServerRPC struct { dmsgC *dmsg.Client lm *listenersManager cm *connsManager + log *logging.Logger } // newServerRPC constructs new server RPC interface. -func newServerRPC(dmsgC *dmsg.Client) *ServerRPC { +func newServerRPC(log *logging.Logger, dmsgC *dmsg.Client) *ServerRPC { return &ServerRPC{ dmsgC: dmsgC, lm: newListenersManager(), cm: newConnsManager(), + log: log, } } @@ -57,15 +62,24 @@ func (r *ServerRPC) Listen(local *routing.Addr, lisID *uint16) error { } if err := r.lm.set(*lisID, dmsgL); err != nil { - // TODO: close listener + if err := dmsgL.Close(); err != nil { + r.log.WithError(err).Error("error closing DMSG listener") + } + return err } return nil } +// AcceptResp contains response parameters for `Accept`. +type AcceptResp struct { + Remote routing.Addr + ConnID uint16 +} + // Accept accepts connection from the listener specified by `lisID`. -func (r *ServerRPC) Accept(lisID *uint16, connID *uint16) error { +func (r *ServerRPC) Accept(lisID *uint16, resp *AcceptResp) error { lis, ok := r.lm.get(*lisID) if !ok { return fmt.Errorf("not listener with id %d", *lisID) @@ -82,10 +96,26 @@ func (r *ServerRPC) Accept(lisID *uint16, connID *uint16) error { } if err := r.cm.set(*connID, tp); err != nil { - // TODO: close conn + if err := tp.Close(); err != nil { + r.log.WithError(err).Error("error closing DMSG transport") + } + return err } + remote, ok := tp.RemoteAddr().(dmsg.Addr) + if !ok { + return errors.New("wrong type for transport remote addr") + } + + resp = &AcceptResp{ + Remote: routing.Addr{ + PubKey: remote.PK, + Port: routing.Port(remote.Port), + }, + ConnID: *connID, + } + return nil } diff --git a/pkg/app2/server_rpc_client.go b/pkg/app2/server_rpc_client.go index bb7a5e842a..a1e825733c 100644 --- a/pkg/app2/server_rpc_client.go +++ b/pkg/app2/server_rpc_client.go @@ -10,7 +10,7 @@ import ( type ServerRPCClient interface { Dial(remote routing.Addr) (uint16, error) Listen(local routing.Addr) (uint16, error) - Accept(lisID uint16) (uint16, error) + Accept(lisID uint16) (uint16, routing.Addr, error) Write(connID uint16, b []byte) (int, error) Read(connID uint16, b []byte) (int, error) CloseConn(id uint16) error @@ -50,13 +50,13 @@ func (c *serverRPCCLient) Listen(local routing.Addr) (uint16, error) { } // Accept sends `Accept` command to the server. -func (c *serverRPCCLient) Accept(lisID uint16) (uint16, error) { - var connID uint16 - if err := c.rpc.Call("Accept", &lisID, &connID); err != nil { - return 0, err +func (c *serverRPCCLient) Accept(lisID uint16) (uint16, routing.Addr, error) { + var acceptResp AcceptResp + if err := c.rpc.Call("Accept", &lisID, &acceptResp); err != nil { + return 0, routing.Addr{}, err } - return connID, nil + return acceptResp.ConnID, acceptResp.Remote, nil } // Write sends `Write` command to the server. diff --git a/vendor/github.com/pkg/errors/.gitignore b/vendor/github.com/pkg/errors/.gitignore new file mode 100644 index 0000000000..daf913b1b3 --- /dev/null +++ b/vendor/github.com/pkg/errors/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/vendor/github.com/pkg/errors/.travis.yml b/vendor/github.com/pkg/errors/.travis.yml new file mode 100644 index 0000000000..588ceca183 --- /dev/null +++ b/vendor/github.com/pkg/errors/.travis.yml @@ -0,0 +1,11 @@ +language: go +go_import_path: github.com/pkg/errors +go: + - 1.4.3 + - 1.5.4 + - 1.6.2 + - 1.7.1 + - tip + +script: + - go test -v ./... diff --git a/vendor/github.com/pkg/errors/LICENSE b/vendor/github.com/pkg/errors/LICENSE new file mode 100644 index 0000000000..835ba3e755 --- /dev/null +++ b/vendor/github.com/pkg/errors/LICENSE @@ -0,0 +1,23 @@ +Copyright (c) 2015, Dave Cheney +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/pkg/errors/README.md b/vendor/github.com/pkg/errors/README.md new file mode 100644 index 0000000000..273db3c98a --- /dev/null +++ b/vendor/github.com/pkg/errors/README.md @@ -0,0 +1,52 @@ +# errors [![Travis-CI](https://travis-ci.org/pkg/errors.svg)](https://travis-ci.org/pkg/errors) [![AppVeyor](https://ci.appveyor.com/api/projects/status/b98mptawhudj53ep/branch/master?svg=true)](https://ci.appveyor.com/project/davecheney/errors/branch/master) [![GoDoc](https://godoc.org/github.com/pkg/errors?status.svg)](http://godoc.org/github.com/pkg/errors) [![Report card](https://goreportcard.com/badge/github.com/pkg/errors)](https://goreportcard.com/report/github.com/pkg/errors) + +Package errors provides simple error handling primitives. + +`go get github.com/pkg/errors` + +The traditional error handling idiom in Go is roughly akin to +```go +if err != nil { + return err +} +``` +which applied recursively up the call stack results in error reports without context or debugging information. The errors package allows programmers to add context to the failure path in their code in a way that does not destroy the original value of the error. + +## Adding context to an error + +The errors.Wrap function returns a new error that adds context to the original error. For example +```go +_, err := ioutil.ReadAll(r) +if err != nil { + return errors.Wrap(err, "read failed") +} +``` +## Retrieving the cause of an error + +Using `errors.Wrap` constructs a stack of errors, adding context to the preceding error. Depending on the nature of the error it may be necessary to reverse the operation of errors.Wrap to retrieve the original error for inspection. Any error value which implements this interface can be inspected by `errors.Cause`. +```go +type causer interface { + Cause() error +} +``` +`errors.Cause` will recursively retrieve the topmost error which does not implement `causer`, which is assumed to be the original cause. For example: +```go +switch err := errors.Cause(err).(type) { +case *MyError: + // handle specifically +default: + // unknown error +} +``` + +[Read the package documentation for more information](https://godoc.org/github.com/pkg/errors). + +## Contributing + +We welcome pull requests, bug fixes and issue reports. With that said, the bar for adding new symbols to this package is intentionally set high. + +Before proposing a change, please discuss your change by raising an issue. + +## Licence + +BSD-2-Clause diff --git a/vendor/github.com/pkg/errors/appveyor.yml b/vendor/github.com/pkg/errors/appveyor.yml new file mode 100644 index 0000000000..a932eade02 --- /dev/null +++ b/vendor/github.com/pkg/errors/appveyor.yml @@ -0,0 +1,32 @@ +version: build-{build}.{branch} + +clone_folder: C:\gopath\src\github.com\pkg\errors +shallow_clone: true # for startup speed + +environment: + GOPATH: C:\gopath + +platform: + - x64 + +# http://www.appveyor.com/docs/installed-software +install: + # some helpful output for debugging builds + - go version + - go env + # pre-installed MinGW at C:\MinGW is 32bit only + # but MSYS2 at C:\msys64 has mingw64 + - set PATH=C:\msys64\mingw64\bin;%PATH% + - gcc --version + - g++ --version + +build_script: + - go install -v ./... + +test_script: + - set PATH=C:\gopath\bin;%PATH% + - go test -v ./... + +#artifacts: +# - path: '%GOPATH%\bin\*.exe' +deploy: off diff --git a/vendor/github.com/pkg/errors/errors.go b/vendor/github.com/pkg/errors/errors.go new file mode 100644 index 0000000000..842ee80456 --- /dev/null +++ b/vendor/github.com/pkg/errors/errors.go @@ -0,0 +1,269 @@ +// Package errors provides simple error handling primitives. +// +// The traditional error handling idiom in Go is roughly akin to +// +// if err != nil { +// return err +// } +// +// which applied recursively up the call stack results in error reports +// without context or debugging information. The errors package allows +// programmers to add context to the failure path in their code in a way +// that does not destroy the original value of the error. +// +// Adding context to an error +// +// The errors.Wrap function returns a new error that adds context to the +// original error by recording a stack trace at the point Wrap is called, +// and the supplied message. For example +// +// _, err := ioutil.ReadAll(r) +// if err != nil { +// return errors.Wrap(err, "read failed") +// } +// +// If additional control is required the errors.WithStack and errors.WithMessage +// functions destructure errors.Wrap into its component operations of annotating +// an error with a stack trace and an a message, respectively. +// +// Retrieving the cause of an error +// +// Using errors.Wrap constructs a stack of errors, adding context to the +// preceding error. Depending on the nature of the error it may be necessary +// to reverse the operation of errors.Wrap to retrieve the original error +// for inspection. Any error value which implements this interface +// +// type causer interface { +// Cause() error +// } +// +// can be inspected by errors.Cause. errors.Cause will recursively retrieve +// the topmost error which does not implement causer, which is assumed to be +// the original cause. For example: +// +// switch err := errors.Cause(err).(type) { +// case *MyError: +// // handle specifically +// default: +// // unknown error +// } +// +// causer interface is not exported by this package, but is considered a part +// of stable public API. +// +// Formatted printing of errors +// +// All error values returned from this package implement fmt.Formatter and can +// be formatted by the fmt package. The following verbs are supported +// +// %s print the error. If the error has a Cause it will be +// printed recursively +// %v see %s +// %+v extended format. Each Frame of the error's StackTrace will +// be printed in detail. +// +// Retrieving the stack trace of an error or wrapper +// +// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are +// invoked. This information can be retrieved with the following interface. +// +// type stackTracer interface { +// StackTrace() errors.StackTrace +// } +// +// Where errors.StackTrace is defined as +// +// type StackTrace []Frame +// +// The Frame type represents a call site in the stack trace. Frame supports +// the fmt.Formatter interface that can be used for printing information about +// the stack trace of this error. For example: +// +// if err, ok := err.(stackTracer); ok { +// for _, f := range err.StackTrace() { +// fmt.Printf("%+s:%d", f) +// } +// } +// +// stackTracer interface is not exported by this package, but is considered a part +// of stable public API. +// +// See the documentation for Frame.Format for more details. +package errors + +import ( + "fmt" + "io" +) + +// New returns an error with the supplied message. +// New also records the stack trace at the point it was called. +func New(message string) error { + return &fundamental{ + msg: message, + stack: callers(), + } +} + +// Errorf formats according to a format specifier and returns the string +// as a value that satisfies error. +// Errorf also records the stack trace at the point it was called. +func Errorf(format string, args ...interface{}) error { + return &fundamental{ + msg: fmt.Sprintf(format, args...), + stack: callers(), + } +} + +// fundamental is an error that has a message and a stack, but no caller. +type fundamental struct { + msg string + *stack +} + +func (f *fundamental) Error() string { return f.msg } + +func (f *fundamental) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + io.WriteString(s, f.msg) + f.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, f.msg) + case 'q': + fmt.Fprintf(s, "%q", f.msg) + } +} + +// WithStack annotates err with a stack trace at the point WithStack was called. +// If err is nil, WithStack returns nil. +func WithStack(err error) error { + if err == nil { + return nil + } + return &withStack{ + err, + callers(), + } +} + +type withStack struct { + error + *stack +} + +func (w *withStack) Cause() error { return w.error } + +func (w *withStack) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v", w.Cause()) + w.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, w.Error()) + case 'q': + fmt.Fprintf(s, "%q", w.Error()) + } +} + +// Wrap returns an error annotating err with a stack trace +// at the point Wrap is called, and the supplied message. +// If err is nil, Wrap returns nil. +func Wrap(err error, message string) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: message, + } + return &withStack{ + err, + callers(), + } +} + +// Wrapf returns an error annotating err with a stack trace +// at the point Wrapf is call, and the format specifier. +// If err is nil, Wrapf returns nil. +func Wrapf(err error, format string, args ...interface{}) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: fmt.Sprintf(format, args...), + } + return &withStack{ + err, + callers(), + } +} + +// WithMessage annotates err with a new message. +// If err is nil, WithMessage returns nil. +func WithMessage(err error, message string) error { + if err == nil { + return nil + } + return &withMessage{ + cause: err, + msg: message, + } +} + +type withMessage struct { + cause error + msg string +} + +func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } +func (w *withMessage) Cause() error { return w.cause } + +func (w *withMessage) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v\n", w.Cause()) + io.WriteString(s, w.msg) + return + } + fallthrough + case 's', 'q': + io.WriteString(s, w.Error()) + } +} + +// Cause returns the underlying cause of the error, if possible. +// An error value has a cause if it implements the following +// interface: +// +// type causer interface { +// Cause() error +// } +// +// If the error does not implement Cause, the original error will +// be returned. If the error is nil, nil will be returned without further +// investigation. +func Cause(err error) error { + type causer interface { + Cause() error + } + + for err != nil { + cause, ok := err.(causer) + if !ok { + break + } + err = cause.Cause() + } + return err +} diff --git a/vendor/github.com/pkg/errors/stack.go b/vendor/github.com/pkg/errors/stack.go new file mode 100644 index 0000000000..6b1f2891a5 --- /dev/null +++ b/vendor/github.com/pkg/errors/stack.go @@ -0,0 +1,178 @@ +package errors + +import ( + "fmt" + "io" + "path" + "runtime" + "strings" +) + +// Frame represents a program counter inside a stack frame. +type Frame uintptr + +// pc returns the program counter for this frame; +// multiple frames may have the same PC value. +func (f Frame) pc() uintptr { return uintptr(f) - 1 } + +// file returns the full path to the file that contains the +// function for this Frame's pc. +func (f Frame) file() string { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return "unknown" + } + file, _ := fn.FileLine(f.pc()) + return file +} + +// line returns the line number of source code of the +// function for this Frame's pc. +func (f Frame) line() int { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return 0 + } + _, line := fn.FileLine(f.pc()) + return line +} + +// Format formats the frame according to the fmt.Formatter interface. +// +// %s source file +// %d source line +// %n function name +// %v equivalent to %s:%d +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+s path of source file relative to the compile time GOPATH +// %+v equivalent to %+s:%d +func (f Frame) Format(s fmt.State, verb rune) { + switch verb { + case 's': + switch { + case s.Flag('+'): + pc := f.pc() + fn := runtime.FuncForPC(pc) + if fn == nil { + io.WriteString(s, "unknown") + } else { + file, _ := fn.FileLine(pc) + fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) + } + default: + io.WriteString(s, path.Base(f.file())) + } + case 'd': + fmt.Fprintf(s, "%d", f.line()) + case 'n': + name := runtime.FuncForPC(f.pc()).Name() + io.WriteString(s, funcname(name)) + case 'v': + f.Format(s, 's') + io.WriteString(s, ":") + f.Format(s, 'd') + } +} + +// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). +type StackTrace []Frame + +func (st StackTrace) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case s.Flag('+'): + for _, f := range st { + fmt.Fprintf(s, "\n%+v", f) + } + case s.Flag('#'): + fmt.Fprintf(s, "%#v", []Frame(st)) + default: + fmt.Fprintf(s, "%v", []Frame(st)) + } + case 's': + fmt.Fprintf(s, "%s", []Frame(st)) + } +} + +// stack represents a stack of program counters. +type stack []uintptr + +func (s *stack) Format(st fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case st.Flag('+'): + for _, pc := range *s { + f := Frame(pc) + fmt.Fprintf(st, "\n%+v", f) + } + } + } +} + +func (s *stack) StackTrace() StackTrace { + f := make([]Frame, len(*s)) + for i := 0; i < len(f); i++ { + f[i] = Frame((*s)[i]) + } + return f +} + +func callers() *stack { + const depth = 32 + var pcs [depth]uintptr + n := runtime.Callers(3, pcs[:]) + var st stack = pcs[0:n] + return &st +} + +// funcname removes the path prefix component of a function's name reported by func.Name(). +func funcname(name string) string { + i := strings.LastIndex(name, "/") + name = name[i+1:] + i = strings.Index(name, ".") + return name[i+1:] +} + +func trimGOPATH(name, file string) string { + // Here we want to get the source file path relative to the compile time + // GOPATH. As of Go 1.6.x there is no direct way to know the compiled + // GOPATH at runtime, but we can infer the number of path segments in the + // GOPATH. We note that fn.Name() returns the function name qualified by + // the import path, which does not include the GOPATH. Thus we can trim + // segments from the beginning of the file path until the number of path + // separators remaining is one more than the number of path separators in + // the function name. For example, given: + // + // GOPATH /home/user + // file /home/user/src/pkg/sub/file.go + // fn.Name() pkg/sub.Type.Method + // + // We want to produce: + // + // pkg/sub/file.go + // + // From this we can easily see that fn.Name() has one less path separator + // than our desired output. We count separators from the end of the file + // path until it finds two more than in the function name and then move + // one character forward to preserve the initial path segment without a + // leading separator. + const sep = "/" + goal := strings.Count(name, sep) + 2 + i := len(file) + for n := 0; n < goal; n++ { + i = strings.LastIndex(file[:i], sep) + if i == -1 { + // not enough separators found, set i so that the slice expression + // below leaves file unmodified + i = -len(sep) + break + } + } + // get back to 0 or trim the leading separator + file = file[i+len(sep):] + return file +} diff --git a/vendor/github.com/skycoin/dmsg/addr.go b/vendor/github.com/skycoin/dmsg/addr.go deleted file mode 100644 index 2be739b40b..0000000000 --- a/vendor/github.com/skycoin/dmsg/addr.go +++ /dev/null @@ -1,26 +0,0 @@ -package dmsg - -import ( - "fmt" - - "github.com/skycoin/dmsg/cipher" -) - -// Addr implements net.Addr for skywire addresses. -type Addr struct { - PK cipher.PubKey - Port uint16 -} - -// Network returns "dmsg" -func (Addr) Network() string { - return Type -} - -// String returns public key and port of node split by colon. -func (a Addr) String() string { - if a.Port == 0 { - return fmt.Sprintf("%s:~", a.PK) - } - return fmt.Sprintf("%s:%d", a.PK, a.Port) -} diff --git a/vendor/github.com/skycoin/dmsg/client.go b/vendor/github.com/skycoin/dmsg/client.go index 2e7d898a2b..f09ffefde9 100644 --- a/vendor/github.com/skycoin/dmsg/client.go +++ b/vendor/github.com/skycoin/dmsg/client.go @@ -57,7 +57,6 @@ type Client struct { pm *PortManager - // accept map[uint16]chan *transport done chan struct{} once sync.Once } @@ -70,10 +69,8 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, opts ...Cl sk: sk, dc: dc, conns: make(map[cipher.PubKey]*ClientConn), - pm: newPortManager(), - // accept: make(chan *transport, AcceptBufferSize), - // accept: make(map[uint16]chan *transport), - done: make(chan struct{}), + pm: newPortManager(pk), + done: make(chan struct{}), } for _, opt := range opts { if err := opt(c); err != nil { @@ -103,7 +100,7 @@ func (c *Client) updateDiscEntry(ctx context.Context) error { func (c *Client) setConn(ctx context.Context, conn *ClientConn) { c.mx.Lock() - c.conns[conn.remoteSrv] = conn + c.conns[conn.srvPK] = conn if err := c.updateDiscEntry(ctx); err != nil { c.log.WithError(err).Warn("updateEntry: failed") } @@ -142,7 +139,7 @@ func (c *Client) InitiateServerConnections(ctx context.Context, min int) error { if err != nil { return err } - c.log.Info("found dms_server entries:", entries) + c.log.Info("found dmsg.Server entries:", entries) if err := c.findOrConnectToServers(ctx, entries, min); err != nil { return err } @@ -158,7 +155,7 @@ func (c *Client) findServerEntries(ctx context.Context) ([]*disc.Entry, error) { return nil, fmt.Errorf("dms_servers are not available: %s", err) default: retry := time.Second - c.log.WithError(err).Warnf("no dms_servers found: trying again in %d second...", retry) + c.log.WithError(err).Warnf("no dms_servers found: trying again in %v...", retry) time.Sleep(retry) continue } @@ -213,7 +210,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) return nil, err } - conn := NewClientConn(c.log, nc, c.pk, srvPK, c.pm) + conn := NewClientConn(c.log, c.pm, nc, c.pk, srvPK) if err := conn.readOK(); err != nil { return nil, err } @@ -244,7 +241,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) // Listen creates a listener on a given port, adds it to port manager and returns the listener. func (c *Client) Listen(port uint16) (*Listener, error) { - l, ok := c.pm.NewListener(c.pk, port) + l, ok := c.pm.NewListener(port) if !ok { return nil, errors.New("port is busy") } @@ -288,7 +285,7 @@ func (c *Client) Type() string { // Close closes the dms_client and associated connections. // TODO(evaninjin): proper error handling. -func (c *Client) Close() error { +func (c *Client) Close() (err error) { if c == nil { return nil } @@ -305,13 +302,8 @@ func (c *Client) Close() error { c.conns = make(map[cipher.PubKey]*ClientConn) c.mx.Unlock() - c.pm.mu.Lock() - defer c.pm.mu.Unlock() - - for _, lis := range c.pm.listeners { - lis.close() - } + err = c.pm.Close() }) - return nil + return err } diff --git a/vendor/github.com/skycoin/dmsg/client_conn.go b/vendor/github.com/skycoin/dmsg/client_conn.go index 9ee1895af2..be48e6adbb 100644 --- a/vendor/github.com/skycoin/dmsg/client_conn.go +++ b/vendor/github.com/skycoin/dmsg/client_conn.go @@ -2,7 +2,6 @@ package dmsg import ( "context" - "encoding/json" "errors" "fmt" "net" @@ -18,9 +17,9 @@ import ( type ClientConn struct { log *logging.Logger - net.Conn // conn to dmsg server - local cipher.PubKey // local client's pk - remoteSrv cipher.PubKey // dmsg server's public key + net.Conn // conn to dmsg server + lPK cipher.PubKey // local client's pk + srvPK cipher.PubKey // dmsg server's public key // nextInitID keeps track of unused tp_ids to assign a future locally-initiated tp. // locally-initiated tps use an even tp_id between local and intermediary dms_server. @@ -38,12 +37,12 @@ type ClientConn struct { } // NewClientConn creates a new ClientConn. -func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubKey, pm *PortManager) *ClientConn { +func NewClientConn(log *logging.Logger, pm *PortManager, conn net.Conn, lPK, rPK cipher.PubKey) *ClientConn { cc := &ClientConn{ log: log, Conn: conn, - local: local, - remoteSrv: remote, + lPK: lPK, + srvPK: rPK, nextInitID: randID(true), tps: make(map[uint16]*Transport), pm: pm, @@ -54,7 +53,7 @@ func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubK } // RemotePK returns the remote Server's PK that the ClientConn is connected to. -func (c *ClientConn) RemotePK() cipher.PubKey { return c.remoteSrv } +func (c *ClientConn) RemotePK() cipher.PubKey { return c.srvPK } func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) { for { @@ -76,7 +75,7 @@ func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) { } } -func (c *ClientConn) addTp(ctx context.Context, rPK cipher.PubKey, lPort, rPort uint16) (*Transport, error) { +func (c *ClientConn) addTp(ctx context.Context, rPK cipher.PubKey, lPort, rPort uint16, closeCB func()) (*Transport, error) { c.mx.Lock() defer c.mx.Unlock() @@ -84,7 +83,10 @@ func (c *ClientConn) addTp(ctx context.Context, rPK cipher.PubKey, lPort, rPort if err != nil { return nil, err } - tp := NewTransport(c.Conn, c.log, Addr{c.local, lPort}, Addr{rPK, rPort}, id, c.delTp) + tp := NewTransport(c.Conn, c.log, Addr{c.lPK, lPort}, Addr{rPK, rPort}, id, func() { + c.delTp(id) + closeCB() + }) c.tps[id] = tp return tp, nil } @@ -116,72 +118,71 @@ func (c *ClientConn) setNextInitID(nextInitID uint16) { } func (c *ClientConn) readOK() error { - fr, err := readFrame(c.Conn) + _, df, err := readFrame(c.Conn) if err != nil { return errors.New("failed to get OK from server") } - - ft, _, _ := fr.Disassemble() - if ft != OkType { - return fmt.Errorf("wrong frame from server: %v", ft) + if df.Type != OkType { + return fmt.Errorf("wrong frame from server: %v", df.Type) } - return nil } -func (c *ClientConn) handleRequestFrame(id uint16, p []byte) (cipher.PubKey, error) { - // remotely-initiated tps should: - // - have a payload structured as HandshakePayload marshaled to JSON. - // - resp_pk should be of local client. - // - use an odd tp_id with the intermediary dmsg_server. - payload, err := unmarshalHandshakePayload(p) - if err != nil { - // TODO(nkryuchkov): When implementing reasons, send that payload format is incorrect. +// This handles 'REQUEST' frames which represent remotely-initiated tps. 'REQUEST' frames should: +// - have a HandshakePayload marshaled to JSON as payload. +// - have a resp_pk be of local client. +// - have an odd tp_id. +func (c *ClientConn) handleRequestFrame(log *logrus.Entry, id uint16, p []byte) (cipher.PubKey, error) { + + // The public key of the initiating client (or the client that sent the 'REQUEST' frame). + var initPK cipher.PubKey + + // Attempts to close tp due to given error. + // When we fail to close tp (a.k.a fail to send 'CLOSE' frame) or if the local client is closed, + // the connection to server should be closed. + // TODO(evanlinjin): derive close reason from error. + closeTp := func(origErr error) (cipher.PubKey, error) { if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { - return cipher.PubKey{}, err + log.WithError(err).Warn("handleRequestFrame: failed to close transport: ending conn to server.") + log.WithError(c.Close()).Warn("handleRequestFrame: closing connection to server.") + return initPK, origErr + } + switch origErr { + case ErrClientClosed: + log.WithError(c.Close()).Warn("handleRequestFrame: closing connection to server.") } - return cipher.PubKey{}, ErrRequestCheckFailed + return initPK, origErr } - if payload.RespPK != c.local || isInitiatorID(id) { - // TODO(nkryuchkov): When implementing reasons, send that payload is malformed. - if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { - return payload.InitPK, err - } - return payload.InitPK, ErrRequestCheckFailed + pay, err := unmarshalHandshakePayload(p) + if err != nil { + return closeTp(ErrRequestCheckFailed) // TODO(nkryuchkov): reason = payload format is incorrect. } + initPK = pay.InitAddr.PK - lis, ok := c.pm.Listener(payload.Port) + if pay.RespAddr.PK != c.lPK || isInitiatorID(id) { + return closeTp(ErrRequestCheckFailed) // TODO(nkryuchkov): reason = payload is malformed. + } + lis, ok := c.pm.Listener(pay.RespAddr.Port) if !ok { - // TODO(nkryuchkov): When implementing reasons, send that port is not listening - if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { - return payload.InitPK, err - } - return payload.InitPK, ErrPortNotListening + return closeTp(ErrPortNotListening) // TODO(nkryuchkov): reason = port is not listening. + } + if c.isClosed() { + return closeTp(ErrClientClosed) // TODO(nkryuchkov): reason = client is closed. } - tp := NewTransport(c.Conn, c.log, Addr{c.local, payload.Port}, Addr{payload.InitPK, 0}, id, c.delTp) // TODO: Have proper remote port. - - select { - case <-c.done: - if err := tp.Close(); err != nil { - log.WithError(err).Warn("Failed to close transport") - } - return payload.InitPK, ErrClientClosed - - default: - err := lis.IntroduceTransport(tp) - if err == nil || err == ErrClientAcceptMaxed { - c.setTp(tp) - } - return payload.InitPK, err + tp := NewTransport(c.Conn, c.log, pay.RespAddr, pay.InitAddr, id, func() { c.delTp(id) }) + if err := lis.IntroduceTransport(tp); err != nil { + return initPK, err } + c.setTp(tp) + return initPK, nil } // Serve handles incoming frames. // Remote-initiated tps that are successfully created are pushing into 'accept' and exposed via 'Client.Accept()'. func (c *ClientConn) Serve(ctx context.Context) (err error) { - log := c.log.WithField("remoteServer", c.remoteSrv) + log := c.log.WithField("remoteServer", c.srvPK) log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") defer func() { c.close() @@ -190,50 +191,40 @@ func (c *ClientConn) Serve(ctx context.Context) (err error) { }() for { - f, err := readFrame(c.Conn) + f, df, err := readFrame(c.Conn) if err != nil { return fmt.Errorf("read failed: %s", err) } log = log.WithField("received", f) - ft, id, p := f.Disassemble() - // If tp of tp_id exists, attempt to forward frame to tp. - // delete tp on any failure. - - if tp, ok := c.getTp(id); ok { + // Delete tp on any failure. + if tp, ok := c.getTp(df.TpID); ok { if err := tp.HandleFrame(f); err != nil { - log.WithError(err).Warnf("Rejected [%s]: Transport closed.", ft) + log.WithError(err).Warnf("Rejected [%s]: Transport closed.", df.Type) } continue } + c.delTp(df.TpID) // rm tp in case closed tp is not fully removed. // if tp does not exist, frame should be 'REQUEST'. // otherwise, handle any unexpected frames accordingly. - - c.delTp(id) // rm tp in case closed tp is not fully removed. - - switch ft { + switch df.Type { case RequestType: c.wg.Add(1) go func(log *logrus.Entry) { defer c.wg.Done() - initPK, err := c.handleRequestFrame(id, p) - if err != nil { - log.WithField("remoteClient", initPK).WithError(err).Infoln("Rejected [REQUEST]") - if isWriteError(err) || err == ErrClientClosed { - err := c.Close() - log.WithError(err).Warn("ClosingConnection") - } - return + if initPK, err := c.handleRequestFrame(log, df.TpID, df.Pay); err != nil { + log.WithField("remoteClient", initPK).WithError(err).Warn("Rejected [REQUEST]") + } else { + log.WithField("remoteClient", initPK).Info("Accepted [REQUEST]") } - log.WithField("remoteClient", initPK).Infoln("Accepted [REQUEST]") }(log) default: - log.Debugf("Ignored [%s]: No transport of given ID.", ft) - if ft != CloseType { - if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { + log.Debugf("Ignored [%s]: No transport of given ID.", df.Type) + if df.Type != CloseType { + if err := writeCloseFrame(c.Conn, df.TpID, PlaceholderReason); err != nil { return err } } @@ -242,12 +233,16 @@ func (c *ClientConn) Serve(ctx context.Context) (err error) { } // DialTransport dials a transport to remote dms_client. -func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey, port uint16) (*Transport, error) { - tp, err := c.addTp(ctx, clientPK, 0, port) // TODO: Have proper local port. +func (c *ClientConn) DialTransport(ctx context.Context, rPK cipher.PubKey, rPort uint16) (*Transport, error) { + lPort, closeCB, err := c.pm.ReserveEphemeral(ctx) if err != nil { return nil, err } - if err := tp.WriteRequest(port); err != nil { + tp, err := c.addTp(ctx, rPK, lPort, rPort, closeCB) // TODO: Have proper local port. + if err != nil { + return nil, err + } + if err := tp.WriteRequest(); err != nil { return nil, err } if err := tp.ReadAccept(ctx); err != nil { @@ -263,7 +258,7 @@ func (c *ClientConn) close() (closed bool) { } c.once.Do(func() { closed = true - c.log.WithField("remoteServer", c.remoteSrv).Infoln("ClosingConnection") + c.log.WithField("remoteServer", c.srvPK).Infoln("ClosingConnection") close(c.done) c.mx.Lock() for _, tp := range c.tps { @@ -290,12 +285,11 @@ func (c *ClientConn) Close() error { return nil } -func marshalHandshakePayload(p HandshakePayload) ([]byte, error) { - return json.Marshal(p) -} - -func unmarshalHandshakePayload(b []byte) (HandshakePayload, error) { - var p HandshakePayload - err := json.Unmarshal(b, &p) - return p, err +func (c *ClientConn) isClosed() bool { + select { + case <-c.done: + return true + default: + return false + } } diff --git a/vendor/github.com/skycoin/dmsg/listener.go b/vendor/github.com/skycoin/dmsg/listener.go index 2c685f8f1e..3fc6f48a46 100644 --- a/vendor/github.com/skycoin/dmsg/listener.go +++ b/vendor/github.com/skycoin/dmsg/listener.go @@ -1,36 +1,84 @@ package dmsg import ( + "fmt" "net" "sync" - - "github.com/skycoin/dmsg/cipher" ) // Listener listens for remote-initiated transports. type Listener struct { - pk cipher.PubKey - port uint16 - mx sync.Mutex // protects 'accept' + addr Addr // local listening address + accept chan *Transport - done chan struct{} - once sync.Once + mx sync.Mutex // protects 'accept' + + doneFunc func() // callback when done + done chan struct{} + once sync.Once } -func newListener(pk cipher.PubKey, port uint16) *Listener { +func newListener(addr Addr) *Listener { return &Listener{ - pk: pk, - port: port, + addr: addr, accept: make(chan *Transport, AcceptBufferSize), done: make(chan struct{}), } } +// AddCloseCallback adds a function that triggers when listener is closed. +// This should be called right after the listener is created and is not thread safe. +func (l *Listener) AddCloseCallback(cb func()) { l.doneFunc = cb } + +// IntroduceTransport handles a transport after receiving a REQUEST frame. +func (l *Listener) IntroduceTransport(tp *Transport) error { + if tp.LocalAddr() != l.addr { + return fmt.Errorf("failed to accept transport as local addresses does not match: we expected %s but got %s", + l.addr, tp.LocalAddr()) + } + + l.mx.Lock() + defer l.mx.Unlock() + + if l.isClosed() { + return ErrClientClosed + } + + select { + case <-l.done: + return ErrClientClosed + + case l.accept <- tp: + if err := tp.WriteAccept(); err != nil { + return err + } + go tp.Serve() + return nil + + default: + _ = tp.Close() //nolint:errcheck + return ErrClientAcceptMaxed + } +} + // Accept accepts a connection. func (l *Listener) Accept() (net.Conn, error) { return l.AcceptTransport() } +// AcceptTransport accepts a transport connection. +func (l *Listener) AcceptTransport() (*Transport, error) { + select { + case <-l.done: + return nil, ErrClientClosed + case tp, ok := <-l.accept: + if !ok { + return nil, ErrClientClosed + } + return tp, nil + } +} + // Close closes the listener. func (l *Listener) Close() error { if l.close() { @@ -42,6 +90,7 @@ func (l *Listener) Close() error { func (l *Listener) close() (closed bool) { l.once.Do(func() { closed = true + l.doneFunc() l.mx.Lock() defer l.mx.Unlock() @@ -69,55 +118,7 @@ func (l *Listener) isClosed() bool { } // Addr returns the listener's address. -func (l *Listener) Addr() net.Addr { - return Addr{ - PK: l.pk, - Port: l.port, - } -} - -// AcceptTransport accepts a transport connection. -func (l *Listener) AcceptTransport() (*Transport, error) { - select { - case <-l.done: - return nil, ErrClientClosed - case tp, ok := <-l.accept: - if !ok { - return nil, ErrClientClosed - } - return tp, nil - } -} +func (l *Listener) Addr() net.Addr { return l.addr } // Type returns the transport type. -func (l *Listener) Type() string { - return Type -} - -// IntroduceTransport handles a transport after receiving a REQUEST frame. -func (l *Listener) IntroduceTransport(tp *Transport) error { - l.mx.Lock() - defer l.mx.Unlock() - - if l.isClosed() { - return ErrClientClosed - } - - select { - case <-l.done: - return ErrClientClosed - - case l.accept <- tp: - if err := tp.WriteAccept(); err != nil { - return err - } - go tp.Serve() - return nil - - default: - if err := tp.Close(); err != nil { - log.WithError(err).Warn("Failed to close transport") - } - return ErrClientAcceptMaxed - } -} +func (l *Listener) Type() string { return Type } diff --git a/vendor/github.com/skycoin/dmsg/netutil/porter.go b/vendor/github.com/skycoin/dmsg/netutil/porter.go new file mode 100644 index 0000000000..fb0d2c1b26 --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/netutil/porter.go @@ -0,0 +1,102 @@ +package netutil + +import ( + "context" + "sync" +) + +const ( + // PorterMinEphemeral is the default minimum ephemeral port. + PorterMinEphemeral = uint16(49152) +) + +// Porter reserves ports. +type Porter struct { + sync.RWMutex + eph uint16 // current ephemeral value + minEph uint16 // minimal ephemeral port value + ports map[uint16]interface{} +} + +// NewPorter creates a new Porter with a given minimum ephemeral port value. +func NewPorter(minEph uint16) *Porter { + ports := make(map[uint16]interface{}) + ports[0] = struct{}{} // port 0 is invalid + + return &Porter{ + eph: minEph, + minEph: minEph, + ports: ports, + } +} + +// Reserve a given port. +// It returns a boolean informing whether the port is reserved, and a function to clear the reservation. +func (p *Porter) Reserve(port uint16, v interface{}) (bool, func()) { + p.Lock() + defer p.Unlock() + + if _, ok := p.ports[port]; ok { + return false, nil + } + p.ports[port] = v + return true, p.makePortFreer(port) +} + +// ReserveEphemeral reserves a new ephemeral port. +// It returns the reserved ephemeral port, a function to clear the reservation and an error (if any). +func (p *Porter) ReserveEphemeral(ctx context.Context, v interface{}) (uint16, func(), error) { + p.Lock() + defer p.Unlock() + + for { + p.eph++ + if p.eph < p.minEph { + p.eph = p.minEph + } + if _, ok := p.ports[p.eph]; ok { + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + default: + continue + } + } + p.ports[p.eph] = v + return p.eph, p.makePortFreer(p.eph), nil + } +} + +// PortValue returns the value stored under a given port. +func (p *Porter) PortValue(port uint16) (interface{}, bool) { + p.RLock() + defer p.RUnlock() + + v, ok := p.ports[port] + return v, ok +} + +// RangePortValues ranges all ports that are currently reserved. +func (p *Porter) RangePortValues(fn func(port uint16, v interface{}) (next bool)) { + p.RLock() + defer p.RUnlock() + + for port, v := range p.ports { + if next := fn(port, v); !next { + return + } + } +} + +// This returns a function that frees a given port. +// It is ensured that the function's action is only performed once. +func (p *Porter) makePortFreer(port uint16) func() { + once := new(sync.Once) + return func() { + once.Do(func() { + p.Lock() + delete(p.ports, port) + p.Unlock() + }) + } +} diff --git a/vendor/github.com/skycoin/dmsg/port_manager.go b/vendor/github.com/skycoin/dmsg/port_manager.go index 63540c7017..0ab5a18e4a 100644 --- a/vendor/github.com/skycoin/dmsg/port_manager.go +++ b/vendor/github.com/skycoin/dmsg/port_manager.go @@ -1,72 +1,66 @@ package dmsg import ( - "math/rand" + "context" "sync" - "time" "github.com/skycoin/dmsg/cipher" -) - -const ( - firstEphemeralPort = 49152 - lastEphemeralPort = 65535 + "github.com/skycoin/dmsg/netutil" ) // PortManager manages ports of nodes. type PortManager struct { - mu sync.RWMutex - rand *rand.Rand - listeners map[uint16]*Listener + lPK cipher.PubKey + p *netutil.Porter } -func newPortManager() *PortManager { +func newPortManager(lPK cipher.PubKey) *PortManager { return &PortManager{ - rand: rand.New(rand.NewSource(time.Now().UnixNano())), - listeners: make(map[uint16]*Listener), + lPK: lPK, + p: netutil.NewPorter(netutil.PorterMinEphemeral), } } // Listener returns a listener assigned to a given port. func (pm *PortManager) Listener(port uint16) (*Listener, bool) { - pm.mu.RLock() - defer pm.mu.RUnlock() - - l, ok := pm.listeners[port] + v, ok := pm.p.PortValue(port) + if !ok { + return nil, false + } + l, ok := v.(*Listener) return l, ok } // NewListener assigns listener to port if port is available. -func (pm *PortManager) NewListener(pk cipher.PubKey, port uint16) (*Listener, bool) { - pm.mu.Lock() - defer pm.mu.Unlock() - if _, ok := pm.listeners[port]; ok { +func (pm *PortManager) NewListener(port uint16) (*Listener, bool) { + l := newListener(Addr{pm.lPK, port}) + ok, clear := pm.p.Reserve(port, l) + if !ok { return nil, false } - l := newListener(pk, port) - pm.listeners[port] = l + l.AddCloseCallback(clear) return l, true } -// RemoveListener removes listener assigned to port. -func (pm *PortManager) RemoveListener(port uint16) { - pm.mu.Lock() - defer pm.mu.Unlock() - - delete(pm.listeners, port) +// ReserveEphemeral reserves an ephemeral port. +func (pm *PortManager) ReserveEphemeral(ctx context.Context) (uint16, func(), error) { + return pm.p.ReserveEphemeral(ctx, nil) } -// NextEmptyEphemeralPort returns next random ephemeral port. -// It has a value between firstEphemeralPort and lastEphemeralPort. -func (pm *PortManager) NextEmptyEphemeralPort() uint16 { - for { - port := pm.randomEphemeralPort() - if _, ok := pm.Listener(port); !ok { - return port +// Close closes all listeners. +func (pm *PortManager) Close() error { + wg := new(sync.WaitGroup) + pm.p.RangePortValues(func(_ uint16, v interface{}) (next bool) { + l, ok := v.(*Listener) + if ok { + wg.Add(1) + go func() { + l.close() + wg.Done() + }() } - } -} - -func (pm *PortManager) randomEphemeralPort() uint16 { - return uint16(firstEphemeralPort + pm.rand.Intn(lastEphemeralPort-firstEphemeralPort)) + return true + }) + wg.Wait() + return nil } diff --git a/vendor/github.com/skycoin/dmsg/server.go b/vendor/github.com/skycoin/dmsg/server.go index ba0ee3dd37..a5bfa304c2 100644 --- a/vendor/github.com/skycoin/dmsg/server.go +++ b/vendor/github.com/skycoin/dmsg/server.go @@ -19,239 +19,6 @@ import ( // ErrListenerAlreadyWrappedToNoise occurs when the provided net.Listener is already wrapped with noise.Listener var ErrListenerAlreadyWrappedToNoise = errors.New("listener is already wrapped to *noise.Listener") -// NextConn provides information on the next connection. -type NextConn struct { - conn *ServerConn - id uint16 -} - -func (r *NextConn) writeFrame(ft FrameType, p []byte) error { - if err := writeFrame(r.conn.Conn, MakeFrame(ft, r.id, p)); err != nil { - go func() { - if err := r.conn.Close(); err != nil { - log.WithError(err).Warn("Failed to close connection") - } - }() - return err - } - return nil -} - -// ServerConn is a connection between a dmsg.Server and a dmsg.Client from a server's perspective. -type ServerConn struct { - log *logging.Logger - - net.Conn - remoteClient cipher.PubKey - - nextRespID uint16 - nextConns map[uint16]*NextConn - mx sync.RWMutex -} - -// NewServerConn creates a new connection from the perspective of a dms_server. -func NewServerConn(log *logging.Logger, conn net.Conn, remoteClient cipher.PubKey) *ServerConn { - return &ServerConn{ - log: log, - Conn: conn, - remoteClient: remoteClient, - nextRespID: randID(false), - nextConns: make(map[uint16]*NextConn), - } -} - -func (c *ServerConn) delNext(id uint16) { - c.mx.Lock() - delete(c.nextConns, id) - c.mx.Unlock() -} - -func (c *ServerConn) setNext(id uint16, r *NextConn) { - c.mx.Lock() - c.nextConns[id] = r - c.mx.Unlock() -} - -func (c *ServerConn) getNext(id uint16) (*NextConn, bool) { - c.mx.RLock() - r := c.nextConns[id] - c.mx.RUnlock() - return r, r != nil -} - -func (c *ServerConn) addNext(ctx context.Context, r *NextConn) (uint16, error) { - c.mx.Lock() - defer c.mx.Unlock() - - for { - if r := c.nextConns[c.nextRespID]; r == nil { - break - } - c.nextRespID += 2 - - select { - case <-ctx.Done(): - return 0, ctx.Err() - default: - } - } - - id := c.nextRespID - c.nextRespID = id + 2 - c.nextConns[id] = r - return id, nil -} - -// PK returns the remote dms_client's public key. -func (c *ServerConn) PK() cipher.PubKey { - return c.remoteClient -} - -type getConnFunc func(pk cipher.PubKey) (*ServerConn, bool) - -// Serve handles (and forwards when necessary) incoming frames. -func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error) { - log := c.log.WithField("srcClient", c.remoteClient) - - // Only manually close the underlying net.Conn when the done signal is context-initiated. - done := make(chan struct{}) - defer close(done) - go func() { - select { - case <-done: - case <-ctx.Done(): - if err := c.Conn.Close(); err != nil { - log.WithError(err).Warn("failed to close underlying connection") - } - } - }() - - defer func() { - // Send CLOSE frames to all transports which are established with this dmsg.Client - // This ensures that all parties are informed about the transport closing. - c.mx.Lock() - for _, conn := range c.nextConns { - why := byte(0) - if err := conn.writeFrame(CloseType, []byte{why}); err != nil { - log.WithError(err).Warnf("failed to write frame: %s", err) - } - } - c.mx.Unlock() - - log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ClosingConn") - if err := c.Conn.Close(); err != nil { - log.WithError(err).Warn("Failed to close connection") - } - }() - - log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") - - err = c.writeOK() - if err != nil { - return fmt.Errorf("sending OK failed: %s", err) - } - - for { - f, err := readFrame(c.Conn) - if err != nil { - return fmt.Errorf("read failed: %s", err) - } - log := log.WithField("received", f) - - ft, id, p := f.Disassemble() - - switch ft { - case RequestType: - ctx, cancel := context.WithTimeout(ctx, TransportHandshakeTimeout) - _, why, ok := c.handleRequest(ctx, getConn, id, p) - cancel() - if !ok { - log.Debugln("FrameRejected: Erroneous request or unresponsive dstClient.") - if err := c.delChan(id, why); err != nil { - return err - } - } - log.Debugln("FrameForwarded") - - case AcceptType, FwdType, AckType, CloseType: - next, why, ok := c.forwardFrame(ft, id, p) - if !ok { - log.Debugln("FrameRejected: Failed to forward to dstClient.") - // Delete channel (and associations) on failure. - if err := c.delChan(id, why); err != nil { - return err - } - continue - } - log.Debugln("FrameForwarded") - - // On success, if Close frame, delete the associations. - if ft == CloseType { - c.delNext(id) - next.conn.delNext(next.id) - } - - default: - log.Debugln("FrameRejected: Unknown frame type.") - // Unknown frame type. - return errors.New("unknown frame of type received") - } - } -} - -func (c *ServerConn) delChan(id uint16, why byte) error { - c.delNext(id) - if err := writeCloseFrame(c.Conn, id, why); err != nil { - return fmt.Errorf("failed to write frame: %s", err) - } - return nil -} - -func (c *ServerConn) writeOK() error { - if err := writeFrame(c.Conn, MakeFrame(OkType, 0, nil)); err != nil { - return err - } - return nil -} - -// nolint:unparam -func (c *ServerConn) forwardFrame(ft FrameType, id uint16, p []byte) (*NextConn, byte, bool) { - next, ok := c.getNext(id) - if !ok { - return next, 0, false - } - if err := next.writeFrame(ft, p); err != nil { - return next, 0, false - } - return next, 0, true -} - -// nolint:unparam -func (c *ServerConn) handleRequest(ctx context.Context, getLink getConnFunc, id uint16, p []byte) (*NextConn, byte, bool) { - payload, err := unmarshalHandshakePayload(p) - if err != nil || payload.InitPK != c.PK() { - return nil, 0, false - } - respL, ok := getLink(payload.RespPK) - if !ok { - return nil, 0, false - } - - // set next relations. - respID, err := respL.addNext(ctx, &NextConn{conn: c, id: id}) - if err != nil { - return nil, 0, false - } - next := &NextConn{conn: respL, id: respID} - c.setNext(id, next) - - // forward to responding client. - if err := next.writeFrame(RequestType, p); err != nil { - return next, 0, false - } - return next, 0, true -} - // Server represents a dms_server. type Server struct { log *logging.Logger diff --git a/vendor/github.com/skycoin/dmsg/server_conn.go b/vendor/github.com/skycoin/dmsg/server_conn.go new file mode 100644 index 0000000000..a162b5102b --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/server_conn.go @@ -0,0 +1,243 @@ +package dmsg + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/dmsg/cipher" +) + +// NextConn provides information on the next connection. +type NextConn struct { + conn *ServerConn + id uint16 +} + +func (r *NextConn) writeFrame(ft FrameType, p []byte) error { + if err := writeFrame(r.conn.Conn, MakeFrame(ft, r.id, p)); err != nil { + go func() { + if err := r.conn.Close(); err != nil { + log.WithError(err).Warn("Failed to close connection") + } + }() + return err + } + return nil +} + +// ServerConn is a connection between a dmsg.Server and a dmsg.Client from a server's perspective. +type ServerConn struct { + log *logging.Logger + + net.Conn + remoteClient cipher.PubKey + + nextRespID uint16 + nextConns map[uint16]*NextConn + mx sync.RWMutex +} + +// NewServerConn creates a new connection from the perspective of a dms_server. +func NewServerConn(log *logging.Logger, conn net.Conn, remoteClient cipher.PubKey) *ServerConn { + return &ServerConn{ + log: log, + Conn: conn, + remoteClient: remoteClient, + nextRespID: randID(false), + nextConns: make(map[uint16]*NextConn), + } +} + +func (c *ServerConn) delNext(id uint16) { + c.mx.Lock() + delete(c.nextConns, id) + c.mx.Unlock() +} + +func (c *ServerConn) setNext(id uint16, r *NextConn) { + c.mx.Lock() + c.nextConns[id] = r + c.mx.Unlock() +} + +func (c *ServerConn) getNext(id uint16) (*NextConn, bool) { + c.mx.RLock() + r := c.nextConns[id] + c.mx.RUnlock() + return r, r != nil +} + +func (c *ServerConn) addNext(ctx context.Context, r *NextConn) (uint16, error) { + c.mx.Lock() + defer c.mx.Unlock() + + for { + if r := c.nextConns[c.nextRespID]; r == nil { + break + } + c.nextRespID += 2 + + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + } + } + + id := c.nextRespID + c.nextRespID = id + 2 + c.nextConns[id] = r + return id, nil +} + +// PK returns the remote dms_client's public key. +func (c *ServerConn) PK() cipher.PubKey { + return c.remoteClient +} + +type getConnFunc func(pk cipher.PubKey) (*ServerConn, bool) + +// Serve handles (and forwards when necessary) incoming frames. +func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error) { + log := c.log.WithField("srcClient", c.remoteClient) + + // Only manually close the underlying net.Conn when the done signal is context-initiated. + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-done: + case <-ctx.Done(): + if err := c.Conn.Close(); err != nil { + log.WithError(err).Warn("failed to close underlying connection") + } + } + }() + + defer func() { + // Send CLOSE frames to all transports which are established with this dmsg.Client + // This ensures that all parties are informed about the transport closing. + c.mx.Lock() + for _, conn := range c.nextConns { + why := byte(0) + if err := conn.writeFrame(CloseType, []byte{why}); err != nil { + log.WithError(err).Warnf("failed to write frame: %s", err) + } + } + c.mx.Unlock() + + log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ClosingConn") + if err := c.Conn.Close(); err != nil { + log.WithError(err).Warn("Failed to close connection") + } + }() + + log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") + + err = c.writeOK() + if err != nil { + return fmt.Errorf("sending OK failed: %s", err) + } + + for { + f, df, err := readFrame(c.Conn) + if err != nil { + return fmt.Errorf("read failed: %s", err) + } + log := log.WithField("received", f) + + switch df.Type { + case RequestType: + ctx, cancel := context.WithTimeout(ctx, TransportHandshakeTimeout) + _, why, ok := c.handleRequest(ctx, getConn, df.TpID, df.Pay) + cancel() + if !ok { + log.Debugln("FrameRejected: Erroneous request or unresponsive dstClient.") + if err := c.delChan(df.TpID, why); err != nil { + return err + } + } + log.Debugln("FrameForwarded") + + case AcceptType, FwdType, AckType, CloseType: + next, why, ok := c.forwardFrame(df.Type, df.TpID, df.Pay) + if !ok { + log.Debugln("FrameRejected: Failed to forward to dstClient.") + // Delete channel (and associations) on failure. + if err := c.delChan(df.TpID, why); err != nil { + return err + } + continue + } + log.Debugln("FrameForwarded") + + // On success, if Close frame, delete the associations. + if df.Type == CloseType { + c.delNext(df.TpID) + next.conn.delNext(next.id) + } + + default: + log.Debugln("FrameRejected: Unknown frame type.") + return errors.New("unknown frame of type received") + } + } +} + +func (c *ServerConn) delChan(id uint16, why byte) error { + c.delNext(id) + if err := writeCloseFrame(c.Conn, id, why); err != nil { + return fmt.Errorf("failed to write frame: %s", err) + } + return nil +} + +func (c *ServerConn) writeOK() error { + if err := writeFrame(c.Conn, MakeFrame(OkType, 0, nil)); err != nil { + return err + } + return nil +} + +// nolint:unparam +func (c *ServerConn) forwardFrame(ft FrameType, id uint16, p []byte) (*NextConn, byte, bool) { + next, ok := c.getNext(id) + if !ok { + return next, 0, false + } + if err := next.writeFrame(ft, p); err != nil { + return next, 0, false + } + return next, 0, true +} + +// nolint:unparam +func (c *ServerConn) handleRequest(ctx context.Context, getLink getConnFunc, id uint16, p []byte) (*NextConn, byte, bool) { + payload, err := unmarshalHandshakePayload(p) + if err != nil || payload.InitAddr.PK != c.PK() { + return nil, 0, false + } + respL, ok := getLink(payload.RespAddr.PK) + if !ok { + return nil, 0, false + } + + // set next relations. + respID, err := respL.addNext(ctx, &NextConn{conn: c, id: id}) + if err != nil { + return nil, 0, false + } + next := &NextConn{conn: respL, id: respID} + c.setNext(id, next) + + // forward to responding client. + if err := next.writeFrame(RequestType, p); err != nil { + return next, 0, false + } + return next, 0, true +} diff --git a/vendor/github.com/skycoin/dmsg/transport.go b/vendor/github.com/skycoin/dmsg/transport.go index 2b1da95a74..5a7467172b 100644 --- a/vendor/github.com/skycoin/dmsg/transport.go +++ b/vendor/github.com/skycoin/dmsg/transport.go @@ -41,17 +41,16 @@ type Transport struct { bufCh chan struct{} // chan for indicating whether this is a new FWD frame bufSize int // keeps track of the total size of 'buf' bufMx sync.Mutex // protects fields responsible for handling FWD and ACK frames - rMx sync.Mutex // TODO: (WORKAROUND) concurrent reads seem problematic right now. - serving chan struct{} // chan which closes when serving begins - servingOnce sync.Once // ensures 'serving' only closes once - done chan struct{} // chan which closes when transport stops serving - doneOnce sync.Once // ensures 'done' only closes once - doneFunc func(id uint16) // contains a method to remove the transport from dmsg.Client + serving chan struct{} // chan which closes when serving begins + servingOnce sync.Once // ensures 'serving' only closes once + done chan struct{} // chan which closes when transport stops serving + doneOnce sync.Once // ensures 'done' only closes once + doneFunc func() // contains a method that triggers when dmsg.Client closes } // NewTransport creates a new dms_tp. -func NewTransport(conn net.Conn, log *logging.Logger, local, remote Addr, id uint16, doneFunc func(id uint16)) *Transport { +func NewTransport(conn net.Conn, log *logging.Logger, local, remote Addr, id uint16, doneFunc func()) *Transport { tp := &Transport{ Conn: conn, log: log, @@ -96,7 +95,7 @@ func (tp *Transport) close() (closed bool) { closed = true close(tp.done) - tp.doneFunc(tp.id) + tp.doneFunc() tp.bufMx.Lock() close(tp.bufCh) @@ -170,12 +169,11 @@ func (tp *Transport) HandleFrame(f Frame) error { } // WriteRequest writes a REQUEST frame to dmsg_server to be forwarded to associated client. -func (tp *Transport) WriteRequest(port uint16) error { +func (tp *Transport) WriteRequest() error { payload := HandshakePayload{ - Version: HandshakePayloadVersion, - InitPK: tp.local.PK, - RespPK: tp.remote.PK, - Port: port, + Version: HandshakePayloadVersion, + InitAddr: tp.local, + RespAddr: tp.remote, } payloadBytes, err := marshalHandshakePayload(payload) if err != nil { @@ -360,9 +358,6 @@ func (tp *Transport) Serve() { func (tp *Transport) Read(p []byte) (n int, err error) { <-tp.serving - tp.rMx.Lock() - defer tp.rMx.Unlock() - startRead: tp.bufMx.Lock() n, err = tp.buf.Read(p) diff --git a/vendor/github.com/skycoin/dmsg/frame.go b/vendor/github.com/skycoin/dmsg/types.go similarity index 71% rename from vendor/github.com/skycoin/dmsg/frame.go rename to vendor/github.com/skycoin/dmsg/types.go index 33b354ef95..dcaabe6db6 100644 --- a/vendor/github.com/skycoin/dmsg/frame.go +++ b/vendor/github.com/skycoin/dmsg/types.go @@ -2,6 +2,7 @@ package dmsg import ( "encoding/binary" + "encoding/json" "fmt" "io" "math" @@ -18,7 +19,7 @@ const ( Type = "dmsg" // HandshakePayloadVersion contains payload version to maintain compatibility with future versions // of HandshakePayload format. - HandshakePayloadVersion = "1" + HandshakePayloadVersion = "2.0" tpBufCap = math.MaxUint16 tpBufFrameCap = math.MaxUint8 @@ -34,15 +35,43 @@ var ( AcceptBufferSize = 20 ) +// Addr implements net.Addr for dmsg addresses. +type Addr struct { + PK cipher.PubKey `json:"public_key"` + Port uint16 `json:"port"` +} + +// Network returns "dmsg" +func (Addr) Network() string { + return Type +} + +// String returns public key and port of node split by colon. +func (a Addr) String() string { + if a.Port == 0 { + return fmt.Sprintf("%s:~", a.PK) + } + return fmt.Sprintf("%s:%d", a.PK, a.Port) +} + // HandshakePayload represents format of payload sent with REQUEST frames. -// TODO(evanlinjin): Use 'dmsg.Addr' for PK:Port pair. type HandshakePayload struct { - Version string `json:"version"` // just in case the struct changes. - InitPK cipher.PubKey `json:"init_pk"` - RespPK cipher.PubKey `json:"resp_pk"` - Port uint16 `json:"port"` + Version string `json:"version"` // just in case the struct changes. + InitAddr Addr `json:"init_address"` + RespAddr Addr `json:"resp_address"` +} + +func marshalHandshakePayload(p HandshakePayload) ([]byte, error) { + return json.Marshal(p) } +func unmarshalHandshakePayload(b []byte) (HandshakePayload, error) { + var p HandshakePayload + err := json.Unmarshal(b, &p) + return p, err +} + +// determines whether the transport ID is of an initiator or responder. func isInitiatorID(tpID uint16) bool { return tpID%2 == 0 } func randID(initiator bool) uint16 { @@ -55,6 +84,7 @@ func randID(initiator bool) uint16 { } } +// serveCount records the number of dmsg.Servers connected var serveCount int64 func incrementServeCount() int64 { return atomic.AddInt64(&serveCount, 1) } @@ -133,24 +163,36 @@ func (f Frame) String() string { return fmt.Sprintf("%s", f.Type(), f.TpID(), f.PayLen(), p) } -func readFrame(r io.Reader) (Frame, error) { - f := make(Frame, headerLen) - if _, err := io.ReadFull(r, f); err != nil { - return nil, err +type disassembledFrame struct { + Type FrameType + TpID uint16 + Pay []byte +} + +// read and disassembles frame from reader +func readFrame(r io.Reader) (f Frame, df disassembledFrame, err error) { + f = make(Frame, headerLen) + if _, err = io.ReadFull(r, f); err != nil { + return } f = append(f, make([]byte, f.PayLen())...) - _, err := io.ReadFull(r, f[headerLen:]) - return f, err + if _, err = io.ReadFull(r, f[headerLen:]); err != nil { + return + } + t, id, p := f.Disassemble() + df = disassembledFrame{Type: t, TpID: id, Pay: p} + return } type writeError struct{ error } func (e *writeError) Error() string { return "write error: " + e.error.Error() } -func isWriteError(err error) bool { - _, ok := err.(*writeError) - return ok -} +// TODO(evanlinjin): Determine if this is still needed, may be useful elsewhere. +//func isWriteError(err error) bool { +// _, ok := err.(*writeError) +// return ok +//} func writeFrame(w io.Writer, f Frame) error { _, err := w.Write(f) diff --git a/vendor/golang.org/x/sys/windows/syscall_windows.go b/vendor/golang.org/x/sys/windows/syscall_windows.go index b23050924f..452d44126d 100644 --- a/vendor/golang.org/x/sys/windows/syscall_windows.go +++ b/vendor/golang.org/x/sys/windows/syscall_windows.go @@ -296,6 +296,7 @@ func NewCallbackCDecl(fn interface{}) uintptr { //sys coCreateGuid(pguid *GUID) (ret error) = ole32.CoCreateGuid //sys CoTaskMemFree(address unsafe.Pointer) = ole32.CoTaskMemFree //sys rtlGetVersion(info *OsVersionInfoEx) (ret error) = ntdll.RtlGetVersion +//sys rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32) = ntdll.RtlGetNtVersionNumbers // syscall interface implementation for other packages @@ -1306,8 +1307,8 @@ func (t Token) KnownFolderPath(folderID *KNOWNFOLDERID, flags uint32) (string, e return UTF16ToString((*[(1 << 30) - 1]uint16)(unsafe.Pointer(p))[:]), nil } -// RtlGetVersion returns the true version of the underlying operating system, ignoring -// any manifesting or compatibility layers on top of the win32 layer. +// RtlGetVersion returns the version of the underlying operating system, ignoring +// manifest semantics but is affected by the application compatibility layer. func RtlGetVersion() *OsVersionInfoEx { info := &OsVersionInfoEx{} info.osVersionInfoSize = uint32(unsafe.Sizeof(*info)) @@ -1318,3 +1319,11 @@ func RtlGetVersion() *OsVersionInfoEx { _ = rtlGetVersion(info) return info } + +// RtlGetNtVersionNumbers returns the version of the underlying operating system, +// ignoring manifest semantics and the application compatibility layer. +func RtlGetNtVersionNumbers() (majorVersion, minorVersion, buildNumber uint32) { + rtlGetNtVersionNumbers(&majorVersion, &minorVersion, &buildNumber) + buildNumber &= 0xffff + return +} diff --git a/vendor/golang.org/x/sys/windows/zsyscall_windows.go b/vendor/golang.org/x/sys/windows/zsyscall_windows.go index d461bed98a..e5d62f3bf5 100644 --- a/vendor/golang.org/x/sys/windows/zsyscall_windows.go +++ b/vendor/golang.org/x/sys/windows/zsyscall_windows.go @@ -234,6 +234,7 @@ var ( procCoCreateGuid = modole32.NewProc("CoCreateGuid") procCoTaskMemFree = modole32.NewProc("CoTaskMemFree") procRtlGetVersion = modntdll.NewProc("RtlGetVersion") + procRtlGetNtVersionNumbers = modntdll.NewProc("RtlGetNtVersionNumbers") procWSAStartup = modws2_32.NewProc("WSAStartup") procWSACleanup = modws2_32.NewProc("WSACleanup") procWSAIoctl = modws2_32.NewProc("WSAIoctl") @@ -2530,6 +2531,11 @@ func rtlGetVersion(info *OsVersionInfoEx) (ret error) { return } +func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32) { + syscall.Syscall(procRtlGetNtVersionNumbers.Addr(), 3, uintptr(unsafe.Pointer(majorVersion)), uintptr(unsafe.Pointer(minorVersion)), uintptr(unsafe.Pointer(buildNumber))) + return +} + func WSAStartup(verreq uint32, data *WSAData) (sockerr error) { r0, _, _ := syscall.Syscall(procWSAStartup.Addr(), 2, uintptr(verreq), uintptr(unsafe.Pointer(data)), 0) if r0 != 0 { diff --git a/vendor/modules.txt b/vendor/modules.txt index 6163337b5b..7f20ceff5d 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -40,6 +40,8 @@ github.com/matttproud/golang_protobuf_extensions/pbutil github.com/mgutz/ansi # github.com/mitchellh/go-homedir v1.1.0 github.com/mitchellh/go-homedir +# github.com/pkg/errors v0.8.0 +github.com/pkg/errors # github.com/pkg/profile v1.3.0 github.com/pkg/profile # github.com/pmezard/go-difflib v1.0.0 @@ -62,12 +64,13 @@ github.com/prometheus/procfs/internal/fs # github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus github.com/sirupsen/logrus/hooks/syslog -# github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f => ../dmsg +# github.com/skycoin/dmsg v0.0.0-20190816104216-d18ee6aa05cb => ../dmsg github.com/skycoin/dmsg/cipher github.com/skycoin/dmsg github.com/skycoin/dmsg/disc github.com/skycoin/dmsg/noise github.com/skycoin/dmsg/ioutil +github.com/skycoin/dmsg/netutil # github.com/skycoin/skycoin v0.26.0 github.com/skycoin/skycoin/src/util/logging github.com/skycoin/skycoin/src/cipher @@ -84,7 +87,7 @@ github.com/stretchr/testify/assert github.com/stretchr/testify/require # go.etcd.io/bbolt v1.3.3 go.etcd.io/bbolt -# golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 +# golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 golang.org/x/crypto/ssh/terminal golang.org/x/crypto/blake2b golang.org/x/crypto/blake2s @@ -98,7 +101,7 @@ golang.org/x/net/nettest golang.org/x/net/context golang.org/x/net/proxy golang.org/x/net/internal/socks -# golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a +# golang.org/x/sys v0.0.0-20190825160603-fb81701db80f golang.org/x/sys/unix golang.org/x/sys/windows golang.org/x/sys/windows/svc/eventlog From a161a2ddb0071ebd474f518c7a9f43320f397514 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Mon, 16 Sep 2019 20:47:24 +0300 Subject: [PATCH 18/43] Remove `freeLocalPort` from `Conn` --- pkg/app2/conn.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index 9bf669d3fa..f250ea6621 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -10,11 +10,10 @@ import ( // Conn is a connection from app client to the server. type Conn struct { - id uint16 - rpc ServerRPCClient - local routing.Addr - remote routing.Addr - freeLocalPort func() + id uint16 + rpc ServerRPCClient + local routing.Addr + remote routing.Addr } func (c *Conn) Read(b []byte) (int, error) { @@ -26,8 +25,6 @@ func (c *Conn) Write(b []byte) (int, error) { } func (c *Conn) Close() error { - defer c.freeLocalPort() - return c.rpc.CloseConn(c.id) } From a790f2c1b1c49b360441f3b21f319967e048681f Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Mon, 16 Sep 2019 21:26:12 +0300 Subject: [PATCH 19/43] Start to merge managers --- pkg/app2/listeners_manager.go | 2 +- pkg/app2/manager.go | 95 +++++++++++++++++++++++++++++++++++ pkg/app2/server_rpc.go | 74 ++++++++++++++++++++++++--- 3 files changed, 163 insertions(+), 8 deletions(-) create mode 100644 pkg/app2/manager.go diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go index 48233da113..420ca5c357 100644 --- a/pkg/app2/listeners_manager.go +++ b/pkg/app2/listeners_manager.go @@ -8,7 +8,7 @@ import ( "github.com/skycoin/dmsg" ) -// connsManager manages listeners within the app server. +// listenersManager manages listeners within the app server. type listenersManager struct { listeners map[uint16]*dmsg.Listener mx sync.RWMutex diff --git a/pkg/app2/manager.go b/pkg/app2/manager.go new file mode 100644 index 0000000000..a1c6bdeeae --- /dev/null +++ b/pkg/app2/manager.go @@ -0,0 +1,95 @@ +package app2 + +import ( + "fmt" + "sync" + + "github.com/pkg/errors" +) + +var ( + errNoMoreAvailableValues = errors.New("no more available values") +) + +// manager manages allows to store and retrieve arbitrary values +// associated with the `uint16` key in a thread-safe manner. +// Provides method to generate key. +type manager struct { + values map[uint16]interface{} + mx sync.RWMutex + lstKey uint16 +} + +// newManager constructs new `manager`. +func newManager() *manager { + return &manager{ + values: make(map[uint16]interface{}), + } +} + +// `nextID` reserves next free slot for the value and returns the key for it. +func (m *manager) nextID() (*uint16, error) { + m.mx.Lock() + + nxtKey := m.lstKey + 1 + for ; nxtKey != m.lstKey; nxtKey++ { + if _, ok := m.values[nxtKey]; !ok { + break + } + } + + if nxtKey == m.lstKey { + m.mx.Unlock() + return nil, errNoMoreAvailableValues + } + + m.values[nxtKey] = nil + m.lstKey = nxtKey + + m.mx.Unlock() + return &nxtKey, nil +} + +// getAndRemove removes value specified by `key` from the manager instance and +// returns it. +func (m *manager) getAndRemove(key uint16) (interface{}, error) { + m.mx.Lock() + v, ok := m.values[key] + if !ok { + m.mx.Unlock() + return nil, fmt.Errorf("no value with key %d", key) + } + + if v == nil { + m.mx.Unlock() + return nil, fmt.Errorf("value with key %d is not set", key) + } + + delete(m.values, key) + + m.mx.Unlock() + return v, nil +} + +// set sets value `v` associated with `key`. +func (m *manager) set(key uint16, v interface{}) error { + m.mx.Lock() + + if l, ok := m.values[key]; ok && l != nil { + m.mx.Unlock() + return errors.New("value already exists") + } + + m.values[key] = v + + m.mx.Unlock() + return nil +} + +// get gets the value associated with the `key`. +func (m *manager) get(key uint16) (interface{}, bool) { + m.mx.RLock() + lis, ok := m.values[key] + m.mx.RUnlock() + return lis, ok +} diff --git a/pkg/app2/server_rpc.go b/pkg/app2/server_rpc.go index 2800db44bc..761152fc57 100644 --- a/pkg/app2/server_rpc.go +++ b/pkg/app2/server_rpc.go @@ -3,6 +3,7 @@ package app2 import ( "context" "fmt" + "net" "github.com/skycoin/skycoin/src/util/logging" @@ -15,8 +16,8 @@ import ( // ServerRPC is a RPC interface for the app server. type ServerRPC struct { dmsgC *dmsg.Client - lm *listenersManager - cm *connsManager + lm *manager + cm *manager log *logging.Logger } @@ -24,8 +25,8 @@ type ServerRPC struct { func newServerRPC(log *logging.Logger, dmsgC *dmsg.Client) *ServerRPC { return &ServerRPC{ dmsgC: dmsgC, - lm: newListenersManager(), - cm: newConnsManager(), + lm: newManager(), + cm: newManager(), log: log, } } @@ -80,7 +81,7 @@ type AcceptResp struct { // Accept accepts connection from the listener specified by `lisID`. func (r *ServerRPC) Accept(lisID *uint16, resp *AcceptResp) error { - lis, ok := r.lm.get(*lisID) + lis, ok := r.getListener(*lisID) if !ok { return fmt.Errorf("not listener with id %d", *lisID) } @@ -127,7 +128,7 @@ type WriteReq struct { // Write writes to the connection. func (r *ServerRPC) Write(req *WriteReq, n *int) error { - conn, ok := r.cm.get(req.ConnID) + conn, ok := r.getConn(req.ConnID) if !ok { return fmt.Errorf("no conn with id %d", req.ConnID) } @@ -149,7 +150,7 @@ type ReadResp struct { // Read reads data from connection specified by `connID`. func (r *ServerRPC) Read(connID *uint16, resp *ReadResp) error { - conn, ok := r.cm.get(*connID) + conn, ok := r.getConn(*connID) if !ok { return fmt.Errorf("no conn with id %d", *connID) } @@ -182,3 +183,62 @@ func (r *ServerRPC) CloseListener(lisID *uint16, _ *struct{}) error { return lis.Close() } + +func (r *ServerRPC) getAndRemoveListener(lisID uint16) (*dmsg.Listener, error) { + lisIfc, err := r.lm.getAndRemove(lisID) + if err != nil { + return nil, err + } + + return r.assertListener(lisIfc) +} + +func (r *ServerRPC) getAndRemoveConn(connID uint16) (net.Conn, error) { + connIfc, err := r.cm.getAndRemove(connID) + if err != nil { + return nil, err + } + + return r.assertConn(connIfc) +} + +func (r *ServerRPC) getListener(lisID uint16) (*dmsg.Listener, error) { + lisIfc, ok := r.lm.get(lisID) + if !ok { + return nil, false + } + + return r.assertListener(lisIfc) +} + +func (r *ServerRPC) getConn(connID uint16) (net.Conn, error) { + connIfc, ok := r.cm.get(connID) + if !ok { + return nil, false + } + + conn, ok := connIfc.(net.Conn) + if !ok { + r.log.Errorln("wrong type of value stored for conn") + return nil, false + } + + return conn, true +} +func (r *ServerRPC) assertListener(v interface{}) (*dmsg.Listener, error) { + lis, ok := v.(*dmsg.Listener) + if !ok { + return nil, errors.New("wrong type of value stored for listener") + } + + return lis, nil +} + +func (r *ServerRPC) assertConn(v interface{}) (net.Conn, error) { + conn, ok := v.(net.Conn) + if !ok { + return nil, errors.New("wrong type of value stored for conn") + } + + return conn, nil +} From 3e295fd5fe2e67cef481a0ad24338b1bf0bdf618 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 17 Sep 2019 15:29:29 +0300 Subject: [PATCH 20/43] Add more comments, finish manager --- pkg/app2/client.go | 3 +- pkg/app2/conn.go | 23 +++++---- pkg/app2/conns_manager.go | 90 ----------------------------------- pkg/app2/errors.go | 7 +++ pkg/app2/listener.go | 1 + pkg/app2/listeners_manager.go | 90 ----------------------------------- pkg/app2/server_rpc.go | 48 +++++++++---------- vendor/modules.txt | 2 +- 8 files changed, 49 insertions(+), 215 deletions(-) delete mode 100644 pkg/app2/conns_manager.go delete mode 100644 pkg/app2/listeners_manager.go diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 32ea18169f..c21e78f81f 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -5,9 +5,8 @@ import ( "net" "net/rpc" - "github.com/skycoin/dmsg/netutil" - "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/dmsg/netutil" "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/routing" diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index f250ea6621..75f735ad98 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -4,16 +4,17 @@ import ( "net" "time" - "github.com/pkg/errors" "github.com/skycoin/skywire/pkg/routing" ) // Conn is a connection from app client to the server. +// Implements `net.Conn`. type Conn struct { - id uint16 - rpc ServerRPCClient - local routing.Addr - remote routing.Addr + id uint16 + rpc ServerRPCClient + local routing.Addr + remote routing.Addr + freeLocalPort func() } func (c *Conn) Read(b []byte) (int, error) { @@ -25,6 +26,12 @@ func (c *Conn) Write(b []byte) (int, error) { } func (c *Conn) Close() error { + defer func() { + if c.freeLocalPort != nil { + c.freeLocalPort() + } + }() + return c.rpc.CloseConn(c.id) } @@ -37,13 +44,13 @@ func (c *Conn) RemoteAddr() net.Addr { } func (c *Conn) SetDeadline(t time.Time) error { - return errors.New("method not implemented") + return errMethodNotImplemented } func (c *Conn) SetReadDeadline(t time.Time) error { - return errors.New("method not implemented") + return errMethodNotImplemented } func (c *Conn) SetWriteDeadline(t time.Time) error { - return errors.New("method not implemented") + return errMethodNotImplemented } diff --git a/pkg/app2/conns_manager.go b/pkg/app2/conns_manager.go deleted file mode 100644 index 101ad42425..0000000000 --- a/pkg/app2/conns_manager.go +++ /dev/null @@ -1,90 +0,0 @@ -package app2 - -import ( - "fmt" - "net" - "sync" - - "github.com/pkg/errors" -) - -// connsManager manages connections within the app server. -type connsManager struct { - conns map[uint16]net.Conn - mx sync.RWMutex - lstID uint16 -} - -// newConnsManager constructs new `connsManager`. -func newConnsManager() *connsManager { - return &connsManager{ - conns: make(map[uint16]net.Conn), - } -} - -// `nextID` reserves slot for the next connection and returns its id. -func (m *connsManager) nextID() (*uint16, error) { - m.mx.Lock() - - connID := m.lstID + 1 - for ; connID < m.lstID; connID++ { - if _, ok := m.conns[connID]; !ok { - break - } - } - - if connID == m.lstID { - m.mx.Unlock() - return nil, errors.New("no more available conns") - } - - m.conns[connID] = nil - m.lstID = connID - - m.mx.Unlock() - return &connID, nil -} - -// getAndRemove removes connection specified by `connID` from the manager instance and -// returns it. -func (m *connsManager) getAndRemove(connID uint16) (net.Conn, error) { - m.mx.Lock() - conn, ok := m.conns[connID] - if !ok { - m.mx.Unlock() - return nil, fmt.Errorf("no conn with id %d", connID) - } - - if conn == nil { - m.mx.Unlock() - return nil, fmt.Errorf("conn with id %d is not set", connID) - } - - delete(m.conns, connID) - - m.mx.Unlock() - return conn, nil -} - -// set sets `conn` associated with `connID`. -func (m *connsManager) set(connID uint16, conn net.Conn) error { - m.mx.Lock() - - if c, ok := m.conns[connID]; ok && c != nil { - m.mx.Unlock() - return errors.New("conn already exists") - } - - m.conns[connID] = conn - - m.mx.Unlock() - return nil -} - -// get gets the connection associated with the `connID`. -func (m *connsManager) get(connID uint16) (net.Conn, bool) { - m.mx.RLock() - conn, ok := m.conns[connID] - m.mx.RUnlock() - return conn, ok -} diff --git a/pkg/app2/errors.go b/pkg/app2/errors.go index f07b063015..88653a613d 100644 --- a/pkg/app2/errors.go +++ b/pkg/app2/errors.go @@ -3,5 +3,12 @@ package app2 import "github.com/pkg/errors" var ( + // ErrPortAlreadyBound is being returned when trying to bind to the port + // which is already bound to. ErrPortAlreadyBound = errors.New("port is already bound") ) + +var ( + // errMethodNotImplemented serves as a return value for non-implemented funcs (stubs). + errMethodNotImplemented = errors.New("method not implemented") +) diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index cebc34935f..9b7bc33977 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -7,6 +7,7 @@ import ( ) // Listener is a listener for app server connections. +// Implements `net.Listener`. type Listener struct { id uint16 rpc ServerRPCClient diff --git a/pkg/app2/listeners_manager.go b/pkg/app2/listeners_manager.go deleted file mode 100644 index 420ca5c357..0000000000 --- a/pkg/app2/listeners_manager.go +++ /dev/null @@ -1,90 +0,0 @@ -package app2 - -import ( - "fmt" - "sync" - - "github.com/pkg/errors" - "github.com/skycoin/dmsg" -) - -// listenersManager manages listeners within the app server. -type listenersManager struct { - listeners map[uint16]*dmsg.Listener - mx sync.RWMutex - lstID uint16 -} - -// newListenersManager constructs new `listenersManager`. -func newListenersManager() *listenersManager { - return &listenersManager{ - listeners: make(map[uint16]*dmsg.Listener), - } -} - -// `nextID` reserves slot for the next listener and returns its id. -func (m *listenersManager) nextID() (*uint16, error) { - m.mx.Lock() - - lisID := m.lstID + 1 - for ; lisID < m.lstID; lisID++ { - if _, ok := m.listeners[lisID]; !ok { - break - } - } - - if lisID == m.lstID { - m.mx.Unlock() - return nil, errors.New("no more available listeners") - } - - m.listeners[lisID] = nil - m.lstID = lisID - - m.mx.Unlock() - return &lisID, nil -} - -// getAndRemove removes listener specified by `lisID` from the manager instance and -// returns it. -func (m *listenersManager) getAndRemove(lisID uint16) (*dmsg.Listener, error) { - m.mx.Lock() - lis, ok := m.listeners[lisID] - if !ok { - m.mx.Unlock() - return nil, fmt.Errorf("no listener with id %d", lisID) - } - - if lis == nil { - m.mx.Unlock() - return nil, fmt.Errorf("listener with id %d is not set", lisID) - } - - delete(m.listeners, lisID) - - m.mx.Unlock() - return lis, nil -} - -// set sets `lis` associated with `lisID`. -func (m *listenersManager) set(lisID uint16, lis *dmsg.Listener) error { - m.mx.Lock() - - if l, ok := m.listeners[lisID]; ok && l != nil { - m.mx.Unlock() - return errors.New("listener already exists") - } - - m.listeners[lisID] = lis - - m.mx.Unlock() - return nil -} - -// get gets the listener associated with the `lisID`. -func (m *listenersManager) get(lisID uint16) (*dmsg.Listener, bool) { - m.mx.RLock() - lis, ok := m.listeners[lisID] - m.mx.RUnlock() - return lis, ok -} diff --git a/pkg/app2/server_rpc.go b/pkg/app2/server_rpc.go index 761152fc57..ac56e848bd 100644 --- a/pkg/app2/server_rpc.go +++ b/pkg/app2/server_rpc.go @@ -5,10 +5,9 @@ import ( "fmt" "net" - "github.com/skycoin/skycoin/src/util/logging" - "github.com/pkg/errors" "github.com/skycoin/dmsg" + "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/routing" ) @@ -81,9 +80,9 @@ type AcceptResp struct { // Accept accepts connection from the listener specified by `lisID`. func (r *ServerRPC) Accept(lisID *uint16, resp *AcceptResp) error { - lis, ok := r.getListener(*lisID) - if !ok { - return fmt.Errorf("not listener with id %d", *lisID) + lis, err := r.getListener(*lisID) + if err != nil { + return err } connID, err := r.cm.nextID() @@ -128,12 +127,11 @@ type WriteReq struct { // Write writes to the connection. func (r *ServerRPC) Write(req *WriteReq, n *int) error { - conn, ok := r.getConn(req.ConnID) - if !ok { - return fmt.Errorf("no conn with id %d", req.ConnID) + conn, err := r.getConn(req.ConnID) + if err != nil { + return err } - var err error *n, err = conn.Write(req.B) if err != nil { return err @@ -150,12 +148,11 @@ type ReadResp struct { // Read reads data from connection specified by `connID`. func (r *ServerRPC) Read(connID *uint16, resp *ReadResp) error { - conn, ok := r.getConn(*connID) - if !ok { - return fmt.Errorf("no conn with id %d", *connID) + conn, err := r.getConn(*connID) + if err != nil { + return err } - var err error resp.N, err = conn.Read(resp.B) if err != nil { return err @@ -166,7 +163,7 @@ func (r *ServerRPC) Read(connID *uint16, resp *ReadResp) error { // CloseConn closes connection specified by `connID`. func (r *ServerRPC) CloseConn(connID *uint16, _ *struct{}) error { - conn, err := r.cm.getAndRemove(*connID) + conn, err := r.getAndRemoveConn(*connID) if err != nil { return err } @@ -176,7 +173,7 @@ func (r *ServerRPC) CloseConn(connID *uint16, _ *struct{}) error { // CloseListener closes listener specified by `lisID`. func (r *ServerRPC) CloseListener(lisID *uint16, _ *struct{}) error { - lis, err := r.lm.getAndRemove(*lisID) + lis, err := r.getAndRemoveListener(*lisID) if err != nil { return err } @@ -184,6 +181,8 @@ func (r *ServerRPC) CloseListener(lisID *uint16, _ *struct{}) error { return lis.Close() } +// getAndRemoveListener gets listener from the manager by `lisID` and removes it. +// Handles type assertion. func (r *ServerRPC) getAndRemoveListener(lisID uint16) (*dmsg.Listener, error) { lisIfc, err := r.lm.getAndRemove(lisID) if err != nil { @@ -193,6 +192,8 @@ func (r *ServerRPC) getAndRemoveListener(lisID uint16) (*dmsg.Listener, error) { return r.assertListener(lisIfc) } +// getAndRemoveConn gets conn from the manager by `connID` and removes it. +// Handles type assertion. func (r *ServerRPC) getAndRemoveConn(connID uint16) (net.Conn, error) { connIfc, err := r.cm.getAndRemove(connID) if err != nil { @@ -202,29 +203,27 @@ func (r *ServerRPC) getAndRemoveConn(connID uint16) (net.Conn, error) { return r.assertConn(connIfc) } +// getListener gets listener from the manager by `lisID`. Handles type assertion. func (r *ServerRPC) getListener(lisID uint16) (*dmsg.Listener, error) { lisIfc, ok := r.lm.get(lisID) if !ok { - return nil, false + return nil, fmt.Errorf("no listener with key %d", lisID) } return r.assertListener(lisIfc) } +// getConn gets conn from the manager by `connID`. Handles type assertion. func (r *ServerRPC) getConn(connID uint16) (net.Conn, error) { connIfc, ok := r.cm.get(connID) if !ok { - return nil, false - } - - conn, ok := connIfc.(net.Conn) - if !ok { - r.log.Errorln("wrong type of value stored for conn") - return nil, false + return nil, fmt.Errorf("no conn with key %d", connID) } - return conn, true + return r.assertConn(connIfc) } + +// assertListener asserts that `v` is of type `*dmsg.Listener`. func (r *ServerRPC) assertListener(v interface{}) (*dmsg.Listener, error) { lis, ok := v.(*dmsg.Listener) if !ok { @@ -234,6 +233,7 @@ func (r *ServerRPC) assertListener(v interface{}) (*dmsg.Listener, error) { return lis, nil } +// assertConn asserts that `v` is of type `net.Conn`. func (r *ServerRPC) assertConn(v interface{}) (net.Conn, error) { conn, ok := v.(net.Conn) if !ok { diff --git a/vendor/modules.txt b/vendor/modules.txt index 7f20ceff5d..8d9c7d5d63 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -68,9 +68,9 @@ github.com/sirupsen/logrus/hooks/syslog github.com/skycoin/dmsg/cipher github.com/skycoin/dmsg github.com/skycoin/dmsg/disc +github.com/skycoin/dmsg/netutil github.com/skycoin/dmsg/noise github.com/skycoin/dmsg/ioutil -github.com/skycoin/dmsg/netutil # github.com/skycoin/skycoin v0.26.0 github.com/skycoin/skycoin/src/util/logging github.com/skycoin/skycoin/src/cipher From ca7fcae805b225f65441687e09c19632cec6f423 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 17 Sep 2019 15:51:58 +0300 Subject: [PATCH 21/43] Start implementing `manager` tests --- pkg/app2/manager_test.go | 99 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 pkg/app2/manager_test.go diff --git a/pkg/app2/manager_test.go b/pkg/app2/manager_test.go new file mode 100644 index 0000000000..dd42bdc51b --- /dev/null +++ b/pkg/app2/manager_test.go @@ -0,0 +1,99 @@ +package app2 + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestManager_NextID(t *testing.T) { + t.Run("simple call", func(t *testing.T) { + m := newManager() + + nextKey, err := m.nextID() + require.NoError(t, err) + require.Equal(t, *nextKey, uint16(1)) + require.Equal(t, *nextKey, m.lstKey) + + nextKey, err = m.nextID() + require.NoError(t, err) + require.Equal(t, *nextKey, uint16(2)) + require.Equal(t, *nextKey, m.lstKey) + }) + + t.Run("call on full manager", func(t *testing.T) { + m := newManager() + for i := uint16(0); i < math.MaxUint16; i++ { + m.values[i] = nil + } + m.values[math.MaxUint16] = nil + + _, err := m.nextID() + require.Error(t, err) + }) + + t.Run("concurrent run", func(t *testing.T) { + m := newManager() + + valsToReserve := 10000 + + errs := make(chan error) + for i := 0; i < valsToReserve; i++ { + go func() { + _, err := m.nextID() + errs <- err + }() + } + + for i := 0; i < valsToReserve; i++ { + require.NoError(t, <-errs) + } + close(errs) + + require.Equal(t, m.lstKey, uint16(valsToReserve)) + for i := uint16(1); i < uint16(valsToReserve); i++ { + v, ok := m.values[i] + require.True(t, ok) + require.Nil(t, v) + } + }) +} + +func TestManager_GetAndRemove(t *testing.T) { + t.Run("simple call", func(t *testing.T) { + m := newManager() + + v := "value" + + m.values[1] = v + + gotV, err := m.getAndRemove(1) + require.NoError(t, err) + require.NotNil(t, gotV) + require.Equal(t, gotV, v) + + _, ok := m.values[1] + require.False(t, ok) + }) + + t.Run("no value", func(t *testing.T) { + m := newManager() + + _, err := m.getAndRemove(1) + require.Error(t, err) + }) + + t.Run("value not set", func(t *testing.T) { + m := newManager() + + m.values[1] = nil + + _, err := m.getAndRemove(1) + require.Error(t, err) + }) + + t.Run("concurrent run", func(t *testing.T) { + // TODO(Darkren): finish + }) +} From 1daf6c05ece0b486654e5758e521f2ecc830ca6f Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 17 Sep 2019 16:38:51 +0300 Subject: [PATCH 22/43] Almost finish `manager` tests --- pkg/app2/manager.go | 14 +++- pkg/app2/manager_test.go | 176 +++++++++++++++++++++++++++++++++++++-- pkg/app2/server_rpc.go | 6 +- 3 files changed, 184 insertions(+), 12 deletions(-) diff --git a/pkg/app2/manager.go b/pkg/app2/manager.go index a1c6bdeeae..5033b1e929 100644 --- a/pkg/app2/manager.go +++ b/pkg/app2/manager.go @@ -27,8 +27,8 @@ func newManager() *manager { } } -// `nextID` reserves next free slot for the value and returns the key for it. -func (m *manager) nextID() (*uint16, error) { +// `nextKey` reserves next free slot for the value and returns the key for it. +func (m *manager) nextKey() (*uint16, error) { m.mx.Lock() nxtKey := m.lstKey + 1 @@ -75,9 +75,15 @@ func (m *manager) getAndRemove(key uint16) (interface{}, error) { func (m *manager) set(key uint16, v interface{}) error { m.mx.Lock() - if l, ok := m.values[key]; ok && l != nil { + l, ok := m.values[key] + if !ok { m.mx.Unlock() - return errors.New("value already exists") + return errors.New("key is not reserved") + } else { + if l != nil { + m.mx.Unlock() + return errors.New("value already exists") + } } m.values[key] = v diff --git a/pkg/app2/manager_test.go b/pkg/app2/manager_test.go index dd42bdc51b..1032a9aca1 100644 --- a/pkg/app2/manager_test.go +++ b/pkg/app2/manager_test.go @@ -11,13 +11,19 @@ func TestManager_NextID(t *testing.T) { t.Run("simple call", func(t *testing.T) { m := newManager() - nextKey, err := m.nextID() + nextKey, err := m.nextKey() require.NoError(t, err) + v, ok := m.values[*nextKey] + require.True(t, ok) + require.Nil(t, v) require.Equal(t, *nextKey, uint16(1)) require.Equal(t, *nextKey, m.lstKey) - nextKey, err = m.nextID() + nextKey, err = m.nextKey() require.NoError(t, err) + v, ok = m.values[*nextKey] + require.True(t, ok) + require.Nil(t, v) require.Equal(t, *nextKey, uint16(2)) require.Equal(t, *nextKey, m.lstKey) }) @@ -29,7 +35,7 @@ func TestManager_NextID(t *testing.T) { } m.values[math.MaxUint16] = nil - _, err := m.nextID() + _, err := m.nextKey() require.Error(t, err) }) @@ -41,7 +47,7 @@ func TestManager_NextID(t *testing.T) { errs := make(chan error) for i := 0; i < valsToReserve; i++ { go func() { - _, err := m.nextID() + _, err := m.nextKey() errs <- err }() } @@ -94,6 +100,166 @@ func TestManager_GetAndRemove(t *testing.T) { }) t.Run("concurrent run", func(t *testing.T) { - // TODO(Darkren): finish + m := newManager() + + m.values[1] = "value" + + concurrency := 1000 + errs := make(chan error, concurrency) + for i := uint16(0); i < uint16(concurrency); i++ { + go func() { + _, err := m.getAndRemove(1) + errs <- err + }() + } + + errsCount := 0 + for i := 0; i < concurrency; i++ { + err := <-errs + if err != nil { + errsCount++ + } + } + close(errs) + require.Equal(t, errsCount, concurrency-1) + + _, ok := m.values[1] + require.False(t, ok) + }) +} + +func TestManager_Set(t *testing.T) { + t.Run("simple call", func(t *testing.T) { + m := newManager() + + nextKey, err := m.nextKey() + require.NoError(t, err) + + v := "value" + + err = m.set(*nextKey, v) + require.NoError(t, err) + gotV, ok := m.values[*nextKey] + require.True(t, ok) + require.Equal(t, gotV, v) + }) + + t.Run("key is not reserved", func(t *testing.T) { + m := newManager() + + err := m.set(1, "value") + require.Error(t, err) + + _, ok := m.values[1] + require.False(t, ok) + }) + + t.Run("value already exists", func(t *testing.T) { + m := newManager() + + v := "value" + + m.values[1] = v + + err := m.set(1, "value2") + require.Error(t, err) + gotV, ok := m.values[1] + require.True(t, ok) + require.Equal(t, gotV, v) + }) + + t.Run("concurrent run", func(t *testing.T) { + m := newManager() + + concurrency := 1000 + + nextKeyPtr, err := m.nextKey() + require.NoError(t, err) + + nextKey := *nextKeyPtr + + errs := make(chan error) + setV := make(chan int) + for i := 0; i < concurrency; i++ { + go func(v int) { + err := m.set(nextKey, v) + errs <- err + if err == nil { + setV <- v + } + }(i) + } + + errsCount := 0 + for i := 0; i < concurrency; i++ { + err := <-errs + if err != nil { + errsCount++ + } + } + close(errs) + + v := <-setV + close(setV) + + gotV, ok := m.values[nextKey] + require.True(t, ok) + require.Equal(t, gotV, v) + }) +} + +func TestManager_Get(t *testing.T) { + prepManagerWithVal := func(v interface{}) (*manager, uint16) { + m := newManager() + + nextKey, err := m.nextKey() + require.NoError(t, err) + + err = m.set(*nextKey, v) + require.NoError(t, err) + + return m, *nextKey + } + + t.Run("simple call", func(t *testing.T) { + v := "value" + + m, key := prepManagerWithVal(v) + + gotV, ok := m.get(key) + require.True(t, ok) + require.Equal(t, gotV, v) + + _, ok = m.get(100) + require.False(t, ok) + }) + + t.Run("concurrent run", func(t *testing.T) { + v := "value" + + m, key := prepManagerWithVal(v) + + concurrency := 1000 + type getRes struct { + v interface{} + ok bool + } + res := make(chan getRes) + for i := 0; i < concurrency; i++ { + go func() { + val, ok := m.get(key) + res <- getRes{ + v: val, + ok: ok, + } + }() + } + + for i := 0; i < concurrency; i++ { + r := <-res + require.True(t, r.ok) + require.Equal(t, r.v, v) + } + close(res) }) } diff --git a/pkg/app2/server_rpc.go b/pkg/app2/server_rpc.go index ac56e848bd..db4372ef1f 100644 --- a/pkg/app2/server_rpc.go +++ b/pkg/app2/server_rpc.go @@ -32,7 +32,7 @@ func newServerRPC(log *logging.Logger, dmsgC *dmsg.Client) *ServerRPC { // Dial dials to the remote. func (r *ServerRPC) Dial(remote *routing.Addr, connID *uint16) error { - connID, err := r.cm.nextID() + connID, err := r.cm.nextKey() if err != nil { return err } @@ -51,7 +51,7 @@ func (r *ServerRPC) Dial(remote *routing.Addr, connID *uint16) error { // Listen starts listening. func (r *ServerRPC) Listen(local *routing.Addr, lisID *uint16) error { - lisID, err := r.lm.nextID() + lisID, err := r.lm.nextKey() if err != nil { return err } @@ -85,7 +85,7 @@ func (r *ServerRPC) Accept(lisID *uint16, resp *AcceptResp) error { return err } - connID, err := r.cm.nextID() + connID, err := r.cm.nextKey() if err != nil { return err } From 8b8ad335a700e75695fffd29384eb8bc5e209d51 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Wed, 18 Sep 2019 09:46:06 +0300 Subject: [PATCH 23/43] Add client tests --- go.sum | 1 + pkg/app2/client.go | 15 +-- pkg/app2/client_test.go | 142 +++++++++++++++++++++++++++ pkg/app2/mock_server_rpc_client.go | 151 +++++++++++++++++++++++++++++ pkg/app2/server_rpc_client.go | 6 +- 5 files changed, 303 insertions(+), 12 deletions(-) create mode 100644 pkg/app2/client_test.go create mode 100644 pkg/app2/mock_server_rpc_client.go diff --git a/go.sum b/go.sum index f60c1121a3..20a615895e 100644 --- a/go.sum +++ b/go.sum @@ -105,6 +105,7 @@ github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= diff --git a/pkg/app2/client.go b/pkg/app2/client.go index c21e78f81f..3d302c3024 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -3,12 +3,9 @@ package app2 import ( "context" "net" - "net/rpc" "github.com/skycoin/dmsg/cipher" "github.com/skycoin/dmsg/netutil" - "github.com/skycoin/skycoin/src/util/logging" - "github.com/skycoin/skywire/pkg/routing" ) @@ -17,7 +14,6 @@ type Client struct { pk cipher.PubKey pid ProcID rpc ServerRPCClient - log *logging.Logger porter *netutil.Porter } @@ -26,19 +22,17 @@ type Client struct { // - localPK: The local public key of the parent skywire visor. // - pid: The procID assigned for the process that Client is being used by. // - rpc: RPC client to communicate with the server. -func NewClient(log *logging.Logger, localPK cipher.PubKey, pid ProcID, rpc *rpc.Client, - porter *netutil.Porter) *Client { +func NewClient(localPK cipher.PubKey, pid ProcID, rpc ServerRPCClient, porter *netutil.Porter) *Client { return &Client{ pk: localPK, pid: pid, - rpc: newServerRPCClient(rpc), - log: log, + rpc: rpc, porter: porter, } } // Dial dials the remote node using `remote`. -func (c *Client) Dial(remote routing.Addr) (*Conn, error) { +func (c *Client) Dial(remote routing.Addr) (net.Conn, error) { localPort, free, err := c.porter.ReserveEphemeral(context.TODO(), nil) if err != nil { return nil, err @@ -46,6 +40,7 @@ func (c *Client) Dial(remote routing.Addr) (*Conn, error) { connID, err := c.rpc.Dial(remote) if err != nil { + free() return nil, err } @@ -67,7 +62,6 @@ func (c *Client) Dial(remote routing.Addr) (*Conn, error) { func (c *Client) Listen(port routing.Port) (net.Listener, error) { ok, free := c.porter.Reserve(uint16(port), nil) if !ok { - free() return nil, ErrPortAlreadyBound } @@ -78,6 +72,7 @@ func (c *Client) Listen(port routing.Port) (net.Listener, error) { lisID, err := c.rpc.Listen(local) if err != nil { + free() return nil, err } diff --git a/pkg/app2/client_test.go b/pkg/app2/client_test.go new file mode 100644 index 0000000000..dcc4a2c35a --- /dev/null +++ b/pkg/app2/client_test.go @@ -0,0 +1,142 @@ +package app2 + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire/pkg/routing" + + "github.com/skycoin/dmsg/cipher" + + "github.com/skycoin/dmsg/netutil" +) + +func TestClient_Dial(t *testing.T) { + localPK, _ := cipher.GenerateKeyPair() + pid := ProcID(1) + + remotePK, _ := cipher.GenerateKeyPair() + remotePort := routing.Port(120) + remote := routing.Addr{ + PubKey: remotePK, + Port: remotePort, + } + + t.Run("ok", func(t *testing.T) { + dialConnID := uint16(1) + var dialErr error + + rpc := &MockServerRPCClient{} + rpc.On("Dial", remote).Return(dialConnID, dialErr) + + cl := NewClient(localPK, pid, rpc, netutil.NewPorter(netutil.PorterMinEphemeral)) + + wantConn := &Conn{ + id: dialConnID, + rpc: rpc, + local: routing.Addr{ + PubKey: localPK, + }, + remote: remote, + } + + conn, err := cl.Dial(remote) + appConn, ok := conn.(*Conn) + require.True(t, ok) + + require.NoError(t, err) + require.Equal(t, wantConn.id, appConn.id) + require.Equal(t, wantConn.rpc, appConn.rpc) + require.Equal(t, wantConn.local.PubKey, appConn.local.PubKey) + require.Equal(t, wantConn.remote, appConn.remote) + require.NotNil(t, appConn.freeLocalPort) + portVal, ok := cl.porter.PortValue(uint16(appConn.local.Port)) + require.True(t, ok) + require.Nil(t, portVal) + }) + + t.Run("dial error", func(t *testing.T) { + dialErr := errors.New("dial error") + + rpc := &MockServerRPCClient{} + rpc.On("Dial", remote).Return(uint16(0), dialErr) + + cl := NewClient(localPK, pid, rpc, netutil.NewPorter(netutil.PorterMinEphemeral)) + + conn, err := cl.Dial(remote) + require.Equal(t, dialErr, err) + require.Nil(t, conn) + }) +} + +func TestClient_Listen(t *testing.T) { + localPK, _ := cipher.GenerateKeyPair() + pid := ProcID(1) + + port := routing.Port(1) + local := routing.Addr{ + PubKey: localPK, + Port: port, + } + + t.Run("ok", func(t *testing.T) { + listenLisID := uint16(1) + var listenErr error + + rpc := &MockServerRPCClient{} + rpc.On("Listen", local).Return(listenLisID, listenErr) + + cl := NewClient(localPK, pid, rpc, netutil.NewPorter(netutil.PorterMinEphemeral)) + + wantListener := &Listener{ + id: listenLisID, + rpc: rpc, + addr: local, + } + + listener, err := cl.Listen(port) + require.Nil(t, err) + appListener, ok := listener.(*Listener) + require.True(t, ok) + require.Equal(t, wantListener.id, appListener.id) + require.Equal(t, wantListener.rpc, appListener.rpc) + require.Equal(t, wantListener.addr, appListener.addr) + require.NotNil(t, appListener.freePort) + portVal, ok := cl.porter.PortValue(uint16(port)) + require.True(t, ok) + require.Nil(t, portVal) + }) + + t.Run("port is already bound", func(t *testing.T) { + porter := netutil.NewPorter(netutil.PorterMinEphemeral) + ok, _ := porter.Reserve(uint16(port), nil) + require.True(t, ok) + + rpc := &MockServerRPCClient{} + + cl := NewClient(localPK, pid, rpc, porter) + + wantErr := ErrPortAlreadyBound + + listener, err := cl.Listen(port) + require.Equal(t, wantErr, err) + require.Nil(t, listener) + }) + + t.Run("listen error", func(t *testing.T) { + listenErr := errors.New("listen error") + + rpc := &MockServerRPCClient{} + rpc.On("Listen", local).Return(uint16(0), listenErr) + + cl := NewClient(localPK, pid, rpc, netutil.NewPorter(netutil.PorterMinEphemeral)) + + listener, err := cl.Listen(port) + require.Equal(t, listenErr, err) + require.Nil(t, listener) + _, ok := cl.porter.PortValue(uint16(port)) + require.False(t, ok) + }) +} diff --git a/pkg/app2/mock_server_rpc_client.go b/pkg/app2/mock_server_rpc_client.go new file mode 100644 index 0000000000..91d3731f9c --- /dev/null +++ b/pkg/app2/mock_server_rpc_client.go @@ -0,0 +1,151 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package app2 + +import mock "github.com/stretchr/testify/mock" +import routing "github.com/skycoin/skywire/pkg/routing" + +// MockServerRPCClient is an autogenerated mock type for the ServerRPCClient type +type MockServerRPCClient struct { + mock.Mock +} + +// Accept provides a mock function with given fields: lisID +func (_m *MockServerRPCClient) Accept(lisID uint16) (uint16, routing.Addr, error) { + ret := _m.Called(lisID) + + var r0 uint16 + if rf, ok := ret.Get(0).(func(uint16) uint16); ok { + r0 = rf(lisID) + } else { + r0 = ret.Get(0).(uint16) + } + + var r1 routing.Addr + if rf, ok := ret.Get(1).(func(uint16) routing.Addr); ok { + r1 = rf(lisID) + } else { + r1 = ret.Get(1).(routing.Addr) + } + + var r2 error + if rf, ok := ret.Get(2).(func(uint16) error); ok { + r2 = rf(lisID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// CloseConn provides a mock function with given fields: id +func (_m *MockServerRPCClient) CloseConn(id uint16) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(uint16) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CloseListener provides a mock function with given fields: id +func (_m *MockServerRPCClient) CloseListener(id uint16) error { + ret := _m.Called(id) + + var r0 error + if rf, ok := ret.Get(0).(func(uint16) error); ok { + r0 = rf(id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Dial provides a mock function with given fields: remote +func (_m *MockServerRPCClient) Dial(remote routing.Addr) (uint16, error) { + ret := _m.Called(remote) + + var r0 uint16 + if rf, ok := ret.Get(0).(func(routing.Addr) uint16); ok { + r0 = rf(remote) + } else { + r0 = ret.Get(0).(uint16) + } + + var r1 error + if rf, ok := ret.Get(1).(func(routing.Addr) error); ok { + r1 = rf(remote) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Listen provides a mock function with given fields: local +func (_m *MockServerRPCClient) Listen(local routing.Addr) (uint16, error) { + ret := _m.Called(local) + + var r0 uint16 + if rf, ok := ret.Get(0).(func(routing.Addr) uint16); ok { + r0 = rf(local) + } else { + r0 = ret.Get(0).(uint16) + } + + var r1 error + if rf, ok := ret.Get(1).(func(routing.Addr) error); ok { + r1 = rf(local) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Read provides a mock function with given fields: connID, b +func (_m *MockServerRPCClient) Read(connID uint16, b []byte) (int, error) { + ret := _m.Called(connID, b) + + var r0 int + if rf, ok := ret.Get(0).(func(uint16, []byte) int); ok { + r0 = rf(connID, b) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(uint16, []byte) error); ok { + r1 = rf(connID, b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Write provides a mock function with given fields: connID, b +func (_m *MockServerRPCClient) Write(connID uint16, b []byte) (int, error) { + ret := _m.Called(connID, b) + + var r0 int + if rf, ok := ret.Get(0).(func(uint16, []byte) int); ok { + r0 = rf(connID, b) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func(uint16, []byte) error); ok { + r1 = rf(connID, b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/app2/server_rpc_client.go b/pkg/app2/server_rpc_client.go index a1e825733c..85dc1e3b39 100644 --- a/pkg/app2/server_rpc_client.go +++ b/pkg/app2/server_rpc_client.go @@ -6,6 +6,8 @@ import ( "github.com/skycoin/skywire/pkg/routing" ) +//go:generate mockery -name ServerRPCClient -case underscore -inpkg + // ServerRPCClient describes RPC interface to communicate with the server. type ServerRPCClient interface { Dial(remote routing.Addr) (uint16, error) @@ -22,8 +24,8 @@ type serverRPCCLient struct { rpc *rpc.Client } -// newServerRPCClient constructs new `serverRPCClient`. -func newServerRPCClient(rpc *rpc.Client) ServerRPCClient { +// NewServerRPCClient constructs new `serverRPCClient`. +func NewServerRPCClient(rpc *rpc.Client) ServerRPCClient { return &serverRPCCLient{ rpc: rpc, } From 0e4a24a760ac79dacdf7db14540d928862bc1e88 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Wed, 18 Sep 2019 10:38:03 +0300 Subject: [PATCH 24/43] Add conn tests --- pkg/app2/conn.go | 10 ++- pkg/app2/conn_test.go | 128 +++++++++++++++++++++++++++++ pkg/app2/mock_server_rpc_client.go | 19 +++-- pkg/app2/server_rpc_client.go | 10 +-- 4 files changed, 155 insertions(+), 12 deletions(-) create mode 100644 pkg/app2/conn_test.go diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index 75f735ad98..07cda88b33 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -18,7 +18,15 @@ type Conn struct { } func (c *Conn) Read(b []byte) (int, error) { - return c.rpc.Read(c.id, b) + n, readBytes, err := c.rpc.Read(c.id, b) + if err != nil { + return 0, err + } + + // TODO: check for slice border + copy(b[:n], readBytes[:n]) + + return n, err } func (c *Conn) Write(b []byte) (int, error) { diff --git a/pkg/app2/conn_test.go b/pkg/app2/conn_test.go new file mode 100644 index 0000000000..9ced6c05d7 --- /dev/null +++ b/pkg/app2/conn_test.go @@ -0,0 +1,128 @@ +package app2 + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestConn_Read(t *testing.T) { + connID := uint16(1) + + t.Run("ok", func(t *testing.T) { + readBuff := make([]byte, 100) + readN := 20 + readBytes := make([]byte, 100) + for i := 0; i < readN; i++ { + readBytes[i] = 2 + } + var readErr error + + rpc := &MockServerRPCClient{} + rpc.On("Read", connID, readBuff).Return(readN, readBytes, readErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + n, err := conn.Read(readBuff) + require.NoError(t, err) + require.Equal(t, n, readN) + require.Equal(t, readBuff[:n], readBytes[:n]) + }) + + t.Run("read error", func(t *testing.T) { + readBuff := make([]byte, 100) + readN := 0 + var readBytes []byte + readErr := errors.New("read error") + + rpc := &MockServerRPCClient{} + rpc.On("Read", connID, readBuff).Return(readN, readBytes, readErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + n, err := conn.Read(readBuff) + require.Equal(t, readErr, err) + require.Equal(t, readN, n) + }) +} + +func TestConn_Write(t *testing.T) { + connID := uint16(1) + + t.Run("ok", func(t *testing.T) { + writeBuff := make([]byte, 100) + writeN := 20 + var writeErr error + + rpc := &MockServerRPCClient{} + rpc.On("Write", connID, writeBuff).Return(writeN, writeErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + n, err := conn.Write(writeBuff) + require.NoError(t, err) + require.Equal(t, writeN, n) + }) + + t.Run("write error", func(t *testing.T) { + writeBuff := make([]byte, 100) + writeN := 0 + writeErr := errors.New("write error") + + rpc := &MockServerRPCClient{} + rpc.On("Write", connID, writeBuff).Return(writeN, writeErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + n, err := conn.Write(writeBuff) + require.Equal(t, writeErr, err) + require.Equal(t, writeN, n) + }) +} + +func TestConn_Close(t *testing.T) { + connID := uint16(1) + + t.Run("ok", func(t *testing.T) { + var closeErr error + + rpc := &MockServerRPCClient{} + rpc.On("CloseConn", connID).Return(closeErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + err := conn.Close() + require.NoError(t, err) + }) + + t.Run("close error", func(t *testing.T) { + closeErr := errors.New("close error") + + rpc := &MockServerRPCClient{} + rpc.On("CloseConn", connID).Return(closeErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + err := conn.Close() + require.Equal(t, closeErr, err) + }) +} diff --git a/pkg/app2/mock_server_rpc_client.go b/pkg/app2/mock_server_rpc_client.go index 91d3731f9c..26fb17717e 100644 --- a/pkg/app2/mock_server_rpc_client.go +++ b/pkg/app2/mock_server_rpc_client.go @@ -109,7 +109,7 @@ func (_m *MockServerRPCClient) Listen(local routing.Addr) (uint16, error) { } // Read provides a mock function with given fields: connID, b -func (_m *MockServerRPCClient) Read(connID uint16, b []byte) (int, error) { +func (_m *MockServerRPCClient) Read(connID uint16, b []byte) (int, []byte, error) { ret := _m.Called(connID, b) var r0 int @@ -119,14 +119,23 @@ func (_m *MockServerRPCClient) Read(connID uint16, b []byte) (int, error) { r0 = ret.Get(0).(int) } - var r1 error - if rf, ok := ret.Get(1).(func(uint16, []byte) error); ok { + var r1 []byte + if rf, ok := ret.Get(1).(func(uint16, []byte) []byte); ok { r1 = rf(connID, b) } else { - r1 = ret.Error(1) + if ret.Get(1) != nil { + r1 = ret.Get(1).([]byte) + } } - return r0, r1 + var r2 error + if rf, ok := ret.Get(2).(func(uint16, []byte) error); ok { + r2 = rf(connID, b) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 } // Write provides a mock function with given fields: connID, b diff --git a/pkg/app2/server_rpc_client.go b/pkg/app2/server_rpc_client.go index 85dc1e3b39..71c34b99e9 100644 --- a/pkg/app2/server_rpc_client.go +++ b/pkg/app2/server_rpc_client.go @@ -14,7 +14,7 @@ type ServerRPCClient interface { Listen(local routing.Addr) (uint16, error) Accept(lisID uint16) (uint16, routing.Addr, error) Write(connID uint16, b []byte) (int, error) - Read(connID uint16, b []byte) (int, error) + Read(connID uint16, b []byte) (int, []byte, error) CloseConn(id uint16) error CloseListener(id uint16) error } @@ -77,15 +77,13 @@ func (c *serverRPCCLient) Write(connID uint16, b []byte) (int, error) { } // Read sends `Read` command to the server. -func (c *serverRPCCLient) Read(connID uint16, b []byte) (int, error) { +func (c *serverRPCCLient) Read(connID uint16, b []byte) (int, []byte, error) { var resp ReadResp if err := c.rpc.Call("Read", &connID, &resp); err != nil { - return 0, err + return 0, nil, err } - copy(b[:resp.N], resp.B[:resp.N]) - - return resp.N, nil + return resp.N, resp.B, nil } // CloseConn sends `CloseConn` command to the server. From f43c6bf4f6ec55da0c390a9986940a98caf8f30b Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Wed, 18 Sep 2019 13:28:26 +0300 Subject: [PATCH 25/43] Add `Listener` tests --- pkg/app2/listener_test.go | 112 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 pkg/app2/listener_test.go diff --git a/pkg/app2/listener_test.go b/pkg/app2/listener_test.go new file mode 100644 index 0000000000..e9336d2788 --- /dev/null +++ b/pkg/app2/listener_test.go @@ -0,0 +1,112 @@ +package app2 + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skywire/pkg/routing" +) + +func TestListener_Accept(t *testing.T) { + lisID := uint16(1) + localPK, _ := cipher.GenerateKeyPair() + local := routing.Addr{ + PubKey: localPK, + Port: routing.Port(100), + } + + t.Run("ok", func(t *testing.T) { + acceptConnID := uint16(1) + remotePK, _ := cipher.GenerateKeyPair() + acceptRemote := routing.Addr{ + PubKey: remotePK, + Port: routing.Port(100), + } + var acceptErr error + + rpc := &MockServerRPCClient{} + rpc.On("Accept", acceptConnID).Return(acceptConnID, acceptRemote, acceptErr) + + lis := &Listener{ + id: lisID, + rpc: rpc, + addr: local, + } + + wantConn := &Conn{ + id: acceptConnID, + rpc: rpc, + local: local, + remote: acceptRemote, + } + + conn, err := lis.Accept() + require.NoError(t, err) + require.Equal(t, conn, wantConn) + }) + + t.Run("accept error", func(t *testing.T) { + acceptConnID := uint16(0) + acceptRemote := routing.Addr{} + acceptErr := errors.New("accept error") + + rpc := &MockServerRPCClient{} + rpc.On("Accept", lisID).Return(acceptConnID, acceptRemote, acceptErr) + + lis := &Listener{ + id: lisID, + rpc: rpc, + addr: local, + } + + conn, err := lis.Accept() + require.Equal(t, acceptErr, err) + require.Nil(t, conn) + }) +} + +func TestListener_Close(t *testing.T) { + lisID := uint16(1) + localPK, _ := cipher.GenerateKeyPair() + local := routing.Addr{ + PubKey: localPK, + Port: routing.Port(100), + } + + t.Run("ok", func(t *testing.T) { + var closeErr error + + rpc := &MockServerRPCClient{} + rpc.On("CloseListener", lisID).Return(closeErr) + + lis := &Listener{ + id: lisID, + rpc: rpc, + addr: local, + freePort: func() {}, + } + + err := lis.Close() + require.NoError(t, err) + }) + + t.Run("close error", func(t *testing.T) { + closeErr := errors.New("close error") + + rpc := &MockServerRPCClient{} + rpc.On("CloseListener", lisID).Return(closeErr) + + lis := &Listener{ + id: lisID, + rpc: rpc, + addr: local, + freePort: func() {}, + } + + err := lis.Close() + require.Equal(t, closeErr, err) + }) +} From 7ef68617632c3a455b55b8d0cd2fa330b5652286 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Wed, 18 Sep 2019 14:01:00 +0300 Subject: [PATCH 26/43] Refactor tests a bit --- pkg/app2/conn_test.go | 204 ++++++++++++++++++-------------------- pkg/app2/listener_test.go | 60 ++++++----- 2 files changed, 127 insertions(+), 137 deletions(-) diff --git a/pkg/app2/conn_test.go b/pkg/app2/conn_test.go index 9ced6c05d7..d25de2716b 100644 --- a/pkg/app2/conn_test.go +++ b/pkg/app2/conn_test.go @@ -10,119 +10,113 @@ import ( func TestConn_Read(t *testing.T) { connID := uint16(1) - t.Run("ok", func(t *testing.T) { - readBuff := make([]byte, 100) - readN := 20 - readBytes := make([]byte, 100) - for i := 0; i < readN; i++ { - readBytes[i] = 2 - } - var readErr error - - rpc := &MockServerRPCClient{} - rpc.On("Read", connID, readBuff).Return(readN, readBytes, readErr) - - conn := &Conn{ - id: connID, - rpc: rpc, - } - - n, err := conn.Read(readBuff) - require.NoError(t, err) - require.Equal(t, n, readN) - require.Equal(t, readBuff[:n], readBytes[:n]) - }) - - t.Run("read error", func(t *testing.T) { - readBuff := make([]byte, 100) - readN := 0 - var readBytes []byte - readErr := errors.New("read error") - - rpc := &MockServerRPCClient{} - rpc.On("Read", connID, readBuff).Return(readN, readBytes, readErr) - - conn := &Conn{ - id: connID, - rpc: rpc, - } - - n, err := conn.Read(readBuff) - require.Equal(t, readErr, err) - require.Equal(t, readN, n) - }) + tt := []struct { + name string + readBuff []byte + readN int + readBytes []byte + readErr error + wantBuff []byte + }{ + { + name: "ok", + readBuff: make([]byte, 10), + readN: 2, + readBytes: []byte{1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + wantBuff: []byte{1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + { + name: "read error", + readBuff: make([]byte, 10), + readErr: errors.New("read error"), + wantBuff: make([]byte, 10), + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + rpc := &MockServerRPCClient{} + rpc.On("Read", connID, tc.readBuff).Return(tc.readN, tc.readBytes, tc.readErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + n, err := conn.Read(tc.readBuff) + require.Equal(t, tc.readErr, err) + require.Equal(t, tc.readN, n) + require.Equal(t, tc.wantBuff, tc.readBuff) + }) + } } func TestConn_Write(t *testing.T) { connID := uint16(1) - t.Run("ok", func(t *testing.T) { - writeBuff := make([]byte, 100) - writeN := 20 - var writeErr error - - rpc := &MockServerRPCClient{} - rpc.On("Write", connID, writeBuff).Return(writeN, writeErr) - - conn := &Conn{ - id: connID, - rpc: rpc, - } - - n, err := conn.Write(writeBuff) - require.NoError(t, err) - require.Equal(t, writeN, n) - }) - - t.Run("write error", func(t *testing.T) { - writeBuff := make([]byte, 100) - writeN := 0 - writeErr := errors.New("write error") - - rpc := &MockServerRPCClient{} - rpc.On("Write", connID, writeBuff).Return(writeN, writeErr) - - conn := &Conn{ - id: connID, - rpc: rpc, - } - - n, err := conn.Write(writeBuff) - require.Equal(t, writeErr, err) - require.Equal(t, writeN, n) - }) + tt := []struct { + name string + writeBuff []byte + writeN int + writeErr error + }{ + { + name: "ok", + writeBuff: make([]byte, 10), + writeN: 2, + }, + { + name: "write error", + writeBuff: make([]byte, 10), + writeErr: errors.New("write error"), + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + rpc := &MockServerRPCClient{} + rpc.On("Write", connID, tc.writeBuff).Return(tc.writeN, tc.writeErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + n, err := conn.Write(tc.writeBuff) + require.Equal(t, tc.writeErr, err) + require.Equal(t, tc.writeN, n) + }) + } } func TestConn_Close(t *testing.T) { connID := uint16(1) - t.Run("ok", func(t *testing.T) { - var closeErr error - - rpc := &MockServerRPCClient{} - rpc.On("CloseConn", connID).Return(closeErr) - - conn := &Conn{ - id: connID, - rpc: rpc, - } - - err := conn.Close() - require.NoError(t, err) - }) - - t.Run("close error", func(t *testing.T) { - closeErr := errors.New("close error") - - rpc := &MockServerRPCClient{} - rpc.On("CloseConn", connID).Return(closeErr) - - conn := &Conn{ - id: connID, - rpc: rpc, - } - - err := conn.Close() - require.Equal(t, closeErr, err) - }) + tt := []struct { + name string + closeErr error + }{ + { + name: "ok", + }, + { + name: "close error", + closeErr: errors.New("close error"), + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + rpc := &MockServerRPCClient{} + rpc.On("CloseConn", connID).Return(tc.closeErr) + + conn := &Conn{ + id: connID, + rpc: rpc, + } + + err := conn.Close() + require.Equal(t, tc.closeErr, err) + }) + } } diff --git a/pkg/app2/listener_test.go b/pkg/app2/listener_test.go index e9336d2788..e43b384f86 100644 --- a/pkg/app2/listener_test.go +++ b/pkg/app2/listener_test.go @@ -76,37 +76,33 @@ func TestListener_Close(t *testing.T) { Port: routing.Port(100), } - t.Run("ok", func(t *testing.T) { - var closeErr error - - rpc := &MockServerRPCClient{} - rpc.On("CloseListener", lisID).Return(closeErr) - - lis := &Listener{ - id: lisID, - rpc: rpc, - addr: local, - freePort: func() {}, - } - - err := lis.Close() - require.NoError(t, err) - }) - - t.Run("close error", func(t *testing.T) { - closeErr := errors.New("close error") - - rpc := &MockServerRPCClient{} - rpc.On("CloseListener", lisID).Return(closeErr) - - lis := &Listener{ - id: lisID, - rpc: rpc, - addr: local, - freePort: func() {}, - } + tt := []struct { + name string + closeErr error + }{ + { + name: "ok", + }, + { + name: "close error", + closeErr: errors.New("close error"), + }, + } - err := lis.Close() - require.Equal(t, closeErr, err) - }) + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + rpc := &MockServerRPCClient{} + rpc.On("CloseListener", lisID).Return(tc.closeErr) + + lis := &Listener{ + id: lisID, + rpc: rpc, + addr: local, + freePort: func() {}, + } + + err := lis.Close() + require.Equal(t, tc.closeErr, err) + }) + } } From f5986d5514798a9b26ccd681001f3a50f548b9fb Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Fri, 20 Sep 2019 17:26:09 +0300 Subject: [PATCH 27/43] Add networker stuff --- go.mod | 2 + go.sum | 10 + pkg/app2/client.go | 1 + pkg/app2/network/addr.go | 28 ++ pkg/app2/network/networker.go | 84 +++++ pkg/app2/network/type.go | 21 ++ vendor/github.com/skycoin/dmsg/client.go | 340 ++---------------- vendor/github.com/skycoin/dmsg/client_conn.go | 295 +++++++++++++++ vendor/github.com/skycoin/dmsg/listener.go | 124 +++++++ .../github.com/skycoin/dmsg/netutil/porter.go | 102 ++++++ .../github.com/skycoin/dmsg/port_manager.go | 66 ++++ vendor/github.com/skycoin/dmsg/server.go | 232 ------------ vendor/github.com/skycoin/dmsg/server_conn.go | 243 +++++++++++++ vendor/github.com/skycoin/dmsg/testing.go | 11 +- vendor/github.com/skycoin/dmsg/transport.go | 57 +-- .../skycoin/dmsg/{frame.go => types.go} | 79 +++- vendor/modules.txt | 3 +- 17 files changed, 1122 insertions(+), 576 deletions(-) create mode 100644 pkg/app2/network/addr.go create mode 100644 pkg/app2/network/networker.go create mode 100644 pkg/app2/network/type.go create mode 100644 vendor/github.com/skycoin/dmsg/client_conn.go create mode 100644 vendor/github.com/skycoin/dmsg/listener.go create mode 100644 vendor/github.com/skycoin/dmsg/netutil/porter.go create mode 100644 vendor/github.com/skycoin/dmsg/port_manager.go create mode 100644 vendor/github.com/skycoin/dmsg/server_conn.go rename vendor/github.com/skycoin/dmsg/{frame.go => types.go} (66%) diff --git a/go.mod b/go.mod index e7947d1c40..c0cfa4c314 100644 --- a/go.mod +++ b/go.mod @@ -26,3 +26,5 @@ require ( golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 golang.org/x/net v0.0.0-20190916140828-c8589233b77d ) + +replace github.com/skycoin/dmsg => ../dmsg diff --git a/go.sum b/go.sum index 13c080a752..b9ca83793d 100644 --- a/go.sum +++ b/go.sum @@ -58,10 +58,13 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxv github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2 h1:/bC9yWikZXAL9uJdulbSfyVNIR3n3trXl+v8+1sx8mU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= @@ -101,14 +104,18 @@ github.com/prometheus/procfs v0.0.3 h1:CTwfnzjQ+8dS6MhHHu4YswVAD99sL2wjPqP+VkURm github.com/prometheus/procfs v0.0.3/go.mod h1:4A/X28fw3Fc593LaREMrKMqOKvUAntwMDaekg4FpcdQ= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f h1:WWjaxOXoj6oYelm67MNtJbg51HQALjKAyhs2WAHgpZs= github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f/go.mod h1:obZYZp8eKR7Xqz+KNhJdUE6Gvp6rEXbDO8YTlW2YXgU= +github.com/skycoin/skycoin v0.25.1/go.mod h1:78nHjQzd8KG0jJJVL/j0xMmrihXi70ti63fh8vXScJw= github.com/skycoin/skycoin v0.26.0 h1:xDxe2r8AclMntZ550Y/vUQgwgLtwrf9Wu5UYiYcN5/o= github.com/skycoin/skycoin v0.26.0/go.mod h1:78nHjQzd8KG0jJJVL/j0xMmrihXi70ti63fh8vXScJw= +github.com/skycoin/skywire v0.1.1/go.mod h1:jDuUgTG20jhiBI6Trpayj0my6xhdS+ejEO9gTSM+C/E= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/cobra v0.0.5 h1:f0B+LkLX6DtmRH1isoNA9VTtNUK9K8xYd28JNNfOv/s= github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= @@ -124,16 +131,19 @@ github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJy github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk= go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 h1:0hQKqeLdqlt5iIwVOBErRisrHJAN57yOiPRQItI20fU= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 3d302c3024..53fd6b0088 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -37,6 +37,7 @@ func (c *Client) Dial(remote routing.Addr) (net.Conn, error) { if err != nil { return nil, err } + net.Dial() connID, err := c.rpc.Dial(remote) if err != nil { diff --git a/pkg/app2/network/addr.go b/pkg/app2/network/addr.go new file mode 100644 index 0000000000..36fec8306d --- /dev/null +++ b/pkg/app2/network/addr.go @@ -0,0 +1,28 @@ +package network + +import ( + "fmt" + + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skywire/pkg/routing" +) + +// Addr implements net.Addr for network addresses. +type Addr struct { + Net Type + PubKey cipher.PubKey + Port routing.Port +} + +// Network returns "dmsg" +func (a Addr) Network() string { + return string(a.Net) +} + +// String returns public key and port of node split by colon. +func (a Addr) String() string { + if a.Port == 0 { + return fmt.Sprintf("%s:~", a.PubKey) + } + return fmt.Sprintf("%s:%d", a.PubKey, a.Port) +} diff --git a/pkg/app2/network/networker.go b/pkg/app2/network/networker.go new file mode 100644 index 0000000000..34b320f4b7 --- /dev/null +++ b/pkg/app2/network/networker.go @@ -0,0 +1,84 @@ +package network + +import ( + "context" + "errors" + "net" + "sync" +) + +var ( + // ErrNoSuchNetworker is being returned when there's no suitable networker. + ErrNoSuchNetworker = errors.New("no such networker") + // ErrNetworkerAlreadyExists is being returned when there's already one with such Network type. + ErrNetworkerAlreadyExists = errors.New("networker already exists") +) + +var ( + networkers = map[Type]Networker{} + networkersMx sync.RWMutex +) + +// AddNetworker associated Networker with the `network`. +func AddNetworker(t Type, n Networker) error { + networkersMx.Lock() + defer networkersMx.Unlock() + + if _, ok := networkers[t]; ok { + return ErrNetworkerAlreadyExists + } + + networkers[t] = n + + return nil +} + +// ResolveNetworker resolves Networker by `network`. +func ResolveNetworker(t Type) (Networker, error) { + networkersMx.RLock() + n, ok := networkers[t] + if !ok { + networkersMx.RUnlock() + return nil, ErrNoSuchNetworker + } + networkersMx.RUnlock() + return n, nil +} + +// Networker defines basic network operations, such as Dial/Listen. +type Networker interface { + Dial(addr Addr) (net.Conn, error) + DialContext(ctx context.Context, addr Addr) (net.Conn, error) + Listen(addr Addr) (net.Listener, error) + ListenContext(ctx context.Context, addr Addr) (net.Listener, error) +} + +// Dial dials the remote `addr` of the specified `network`. +func Dial(t Type, addr Addr) (net.Conn, error) { + return DialContext(context.Background(), t, addr) +} + +// DialContext dials the remote `Addr` of the specified `network` with the context. +func DialContext(ctx context.Context, t Type, addr Addr) (net.Conn, error) { + n, err := ResolveNetworker(t) + if err != nil { + return nil, err + } + + return n.DialContext(ctx, addr) +} + +// Listen starts listening on the local `addr` of the specified `network`. +func Listen(t Type, addr Addr) (net.Listener, error) { + return ListenContext(context.Background(), t, addr) +} + +// ListenContext starts listening on the local `addr` of the specified `network` with the context. +func ListenContext(ctx context.Context, t Type, addr Addr) (net.Listener, error) { + networker, err := ResolveNetworker(t) + if err != nil { + return nil, err + } + + return networker.ListenContext(ctx, addr) +} diff --git a/pkg/app2/network/type.go b/pkg/app2/network/type.go new file mode 100644 index 0000000000..d91c6a0dcc --- /dev/null +++ b/pkg/app2/network/type.go @@ -0,0 +1,21 @@ +package network + +// Type represents the network type. +type Type string + +const ( + // TypeDMSG is a network type for DMSG communication. + TypeDMSG = "dmsg" +) + +// IsValid checks whether the network contains valid value for the type. +func (n Type) IsValid() bool { + _, ok := validNetworks[n] + return ok +} + +var ( + validNetworks = map[Type]struct{}{ + TypeDMSG: {}, + } +) diff --git a/vendor/github.com/skycoin/dmsg/client.go b/vendor/github.com/skycoin/dmsg/client.go index 8e6a20fede..f09ffefde9 100644 --- a/vendor/github.com/skycoin/dmsg/client.go +++ b/vendor/github.com/skycoin/dmsg/client.go @@ -8,7 +8,6 @@ import ( "sync" "time" - "github.com/sirupsen/logrus" "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/dmsg/cipher" @@ -31,270 +30,6 @@ var ( ErrClientAcceptMaxed = errors.New("client accepts buffer maxed") ) -// ClientConn represents a connection between a dmsg.Client and dmsg.Server from a client's perspective. -type ClientConn struct { - log *logging.Logger - - net.Conn // conn to dmsg server - local cipher.PubKey // local client's pk - remoteSrv cipher.PubKey // dmsg server's public key - - // nextInitID keeps track of unused tp_ids to assign a future locally-initiated tp. - // locally-initiated tps use an even tp_id between local and intermediary dms_server. - nextInitID uint16 - - // Transports: map of transports to remote dms_clients (key: tp_id, val: transport). - tps map[uint16]*Transport - mx sync.RWMutex // to protect tps - - done chan struct{} - once sync.Once - wg sync.WaitGroup -} - -// NewClientConn creates a new ClientConn. -func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubKey) *ClientConn { - cc := &ClientConn{ - log: log, - Conn: conn, - local: local, - remoteSrv: remote, - nextInitID: randID(true), - tps: make(map[uint16]*Transport), - done: make(chan struct{}), - } - cc.wg.Add(1) - return cc -} - -// RemotePK returns the remote Server's PK that the ClientConn is connected to. -func (c *ClientConn) RemotePK() cipher.PubKey { return c.remoteSrv } - -func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) { - for { - select { - case <-c.done: - return 0, ErrClientClosed - case <-ctx.Done(): - return 0, ctx.Err() - default: - if ch := c.tps[c.nextInitID]; ch != nil && !ch.IsClosed() { - c.nextInitID += 2 - continue - } - c.tps[c.nextInitID] = nil - id := c.nextInitID - c.nextInitID = id + 2 - return id, nil - } - } -} - -func (c *ClientConn) addTp(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) { - c.mx.Lock() - defer c.mx.Unlock() - - id, err := c.getNextInitID(ctx) - if err != nil { - return nil, err - } - tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp) - c.tps[id] = tp - return tp, nil -} - -func (c *ClientConn) setTp(tp *Transport) { - c.mx.Lock() - c.tps[tp.id] = tp - c.mx.Unlock() -} - -func (c *ClientConn) delTp(id uint16) { - c.mx.Lock() - c.tps[id] = nil - c.mx.Unlock() -} - -func (c *ClientConn) getTp(id uint16) (*Transport, bool) { - c.mx.RLock() - tp := c.tps[id] - c.mx.RUnlock() - ok := tp != nil && !tp.IsClosed() - return tp, ok -} - -func (c *ClientConn) setNextInitID(nextInitID uint16) { - c.mx.Lock() - c.nextInitID = nextInitID - c.mx.Unlock() -} - -func (c *ClientConn) readOK() error { - fr, err := readFrame(c.Conn) - if err != nil { - return errors.New("failed to get OK from server") - } - - ft, _, _ := fr.Disassemble() - if ft != OkType { - return fmt.Errorf("wrong frame from server: %v", ft) - } - - return nil -} - -func (c *ClientConn) handleRequestFrame(accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) { - // remotely-initiated tps should: - // - have a payload structured as 'init_pk:resp_pk'. - // - resp_pk should be of local client. - // - use an odd tp_id with the intermediary dmsg_server. - initPK, respPK, ok := splitPKs(p) - if !ok || respPK != c.local || isInitiatorID(id) { - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return initPK, err - } - return initPK, ErrRequestCheckFailed - } - - tp := NewTransport(c.Conn, c.log, c.local, initPK, id, c.delTp) - - select { - case <-c.done: - if err := tp.Close(); err != nil { - log.WithError(err).Warn("Failed to close transport") - } - return initPK, ErrClientClosed - default: - select { - case accept <- tp: - c.setTp(tp) - if err := tp.WriteAccept(); err != nil { - return initPK, err - } - go tp.Serve() - return initPK, nil - - default: - if err := tp.Close(); err != nil { - log.WithError(err).Warn("Failed to close transport") - } - return initPK, ErrClientAcceptMaxed - } - } -} - -// Serve handles incoming frames. -// Remote-initiated tps that are successfully created are pushing into 'accept' and exposed via 'Client.Accept()'. -func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err error) { - log := c.log.WithField("remoteServer", c.remoteSrv) - log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") - defer func() { - c.close() - log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ConnectionClosed") - c.wg.Done() - }() - - for { - f, err := readFrame(c.Conn) - if err != nil { - return fmt.Errorf("read failed: %s", err) - } - log = log.WithField("received", f) - - ft, id, p := f.Disassemble() - - // If tp of tp_id exists, attempt to forward frame to tp. - // delete tp on any failure. - - if tp, ok := c.getTp(id); ok { - if err := tp.HandleFrame(f); err != nil { - log.WithError(err).Warnf("Rejected [%s]: Transport closed.", ft) - } - continue - } - - // if tp does not exist, frame should be 'REQUEST'. - // otherwise, handle any unexpected frames accordingly. - - c.delTp(id) // rm tp in case closed tp is not fully removed. - - switch ft { - case RequestType: - c.wg.Add(1) - go func(log *logrus.Entry) { - defer c.wg.Done() - initPK, err := c.handleRequestFrame(accept, id, p) - if err != nil { - log.WithField("remoteClient", initPK).WithError(err).Infoln("Rejected [REQUEST]") - if isWriteError(err) || err == ErrClientClosed { - err := c.Close() - log.WithError(err).Warn("ClosingConnection") - } - return - } - log.WithField("remoteClient", initPK).Infoln("Accepted [REQUEST]") - }(log) - - default: - log.Debugf("Ignored [%s]: No transport of given ID.", ft) - if ft != CloseType { - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return err - } - } - } - } -} - -// DialTransport dials a transport to remote dms_client. -func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) { - tp, err := c.addTp(ctx, clientPK) - if err != nil { - return nil, err - } - if err := tp.WriteRequest(); err != nil { - return nil, err - } - if err := tp.ReadAccept(ctx); err != nil { - return nil, err - } - go tp.Serve() - return tp, nil -} - -func (c *ClientConn) close() (closed bool) { - if c == nil { - return false - } - c.once.Do(func() { - closed = true - c.log.WithField("remoteServer", c.remoteSrv).Infoln("ClosingConnection") - close(c.done) - c.mx.Lock() - for _, tp := range c.tps { - tp := tp - go func() { - if err := tp.Close(); err != nil { - log.WithError(err).Warn("Failed to close transport") - } - }() - } - if err := c.Conn.Close(); err != nil { - log.WithError(err).Warn("Failed to close connection") - } - c.mx.Unlock() - }) - return closed -} - -// Close closes the connection to dms_server. -func (c *ClientConn) Close() error { - if c.close() { - c.wg.Wait() - } - return nil -} - // ClientOption represents an optional argument for Client. type ClientOption func(c *Client) error @@ -320,21 +55,22 @@ type Client struct { conns map[cipher.PubKey]*ClientConn // conns with messaging servers. Key: pk of server mx sync.RWMutex - accept chan *Transport - done chan struct{} - once sync.Once + pm *PortManager + + done chan struct{} + once sync.Once } // NewClient creates a new Client. func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc disc.APIClient, opts ...ClientOption) *Client { c := &Client{ - log: logging.MustGetLogger("dmsg_client"), - pk: pk, - sk: sk, - dc: dc, - conns: make(map[cipher.PubKey]*ClientConn), - accept: make(chan *Transport, AcceptBufferSize), - done: make(chan struct{}), + log: logging.MustGetLogger("dmsg_client"), + pk: pk, + sk: sk, + dc: dc, + conns: make(map[cipher.PubKey]*ClientConn), + pm: newPortManager(pk), + done: make(chan struct{}), } for _, opt := range opts { if err := opt(c); err != nil { @@ -364,7 +100,7 @@ func (c *Client) updateDiscEntry(ctx context.Context) error { func (c *Client) setConn(ctx context.Context, conn *ClientConn) { c.mx.Lock() - c.conns[conn.remoteSrv] = conn + c.conns[conn.srvPK] = conn if err := c.updateDiscEntry(ctx); err != nil { c.log.WithError(err).Warn("updateEntry: failed") } @@ -403,7 +139,7 @@ func (c *Client) InitiateServerConnections(ctx context.Context, min int) error { if err != nil { return err } - c.log.Info("found dms_server entries:", entries) + c.log.Info("found dmsg.Server entries:", entries) if err := c.findOrConnectToServers(ctx, entries, min); err != nil { return err } @@ -419,7 +155,7 @@ func (c *Client) findServerEntries(ctx context.Context) ([]*disc.Entry, error) { return nil, fmt.Errorf("dms_servers are not available: %s", err) default: retry := time.Second - c.log.WithError(err).Warnf("no dms_servers found: trying again in %d second...", retry) + c.log.WithError(err).Warnf("no dms_servers found: trying again in %v...", retry) time.Sleep(retry) continue } @@ -474,7 +210,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) return nil, err } - conn := NewClientConn(c.log, nc, c.pk, srvPK) + conn := NewClientConn(c.log, c.pm, nc, c.pk, srvPK) if err := conn.readOK(); err != nil { return nil, err } @@ -482,7 +218,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) c.setConn(ctx, conn) go func() { - err := conn.Serve(ctx, c.accept) + err := conn.Serve(ctx) conn.log.WithError(err).WithField("remoteServer", srvPK).Warn("connected with server closed") c.delConn(ctx, srvPK) @@ -503,23 +239,17 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) return conn, nil } -// Accept accepts remotely-initiated tps. -func (c *Client) Accept(ctx context.Context) (*Transport, error) { - select { - case tp, ok := <-c.accept: - if !ok { - return nil, ErrClientClosed - } - return tp, nil - case <-c.done: - return nil, ErrClientClosed - case <-ctx.Done(): - return nil, ctx.Err() +// Listen creates a listener on a given port, adds it to port manager and returns the listener. +func (c *Client) Listen(port uint16) (*Listener, error) { + l, ok := c.pm.NewListener(port) + if !ok { + return nil, errors.New("port is busy") } + return l, nil } // Dial dials a transport to remote dms_client. -func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (*Transport, error) { +func (c *Client) Dial(ctx context.Context, remote cipher.PubKey, port uint16) (*Transport, error) { entry, err := c.dc.Entry(ctx, remote) if err != nil { return nil, fmt.Errorf("get entry failure: %s", err) @@ -536,14 +266,16 @@ func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (*Transport, er c.log.WithError(err).Warn("failed to connect to server") continue } - return conn.DialTransport(ctx, remote) + return conn.DialTransport(ctx, remote, port) } return nil, errors.New("failed to find dms_servers for given client pk") } -// Local returns the local dms_client's public key. -func (c *Client) Local() cipher.PubKey { - return c.pk +// Addr returns the local dms_client's public key. +func (c *Client) Addr() net.Addr { + return Addr{ + PK: c.pk, + } } // Type returns the transport type. @@ -553,7 +285,7 @@ func (c *Client) Type() string { // Close closes the dms_client and associated connections. // TODO(evaninjin): proper error handling. -func (c *Client) Close() error { +func (c *Client) Close() (err error) { if c == nil { return nil } @@ -570,14 +302,8 @@ func (c *Client) Close() error { c.conns = make(map[cipher.PubKey]*ClientConn) c.mx.Unlock() - for { - select { - case <-c.accept: - default: - close(c.accept) - return - } - } + err = c.pm.Close() }) - return nil + + return err } diff --git a/vendor/github.com/skycoin/dmsg/client_conn.go b/vendor/github.com/skycoin/dmsg/client_conn.go new file mode 100644 index 0000000000..be48e6adbb --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/client_conn.go @@ -0,0 +1,295 @@ +package dmsg + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/sirupsen/logrus" + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/dmsg/cipher" +) + +// ClientConn represents a connection between a dmsg.Client and dmsg.Server from a client's perspective. +type ClientConn struct { + log *logging.Logger + + net.Conn // conn to dmsg server + lPK cipher.PubKey // local client's pk + srvPK cipher.PubKey // dmsg server's public key + + // nextInitID keeps track of unused tp_ids to assign a future locally-initiated tp. + // locally-initiated tps use an even tp_id between local and intermediary dms_server. + nextInitID uint16 + + // Transports: map of transports to remote dms_clients (key: tp_id, val: transport). + tps map[uint16]*Transport + mx sync.RWMutex // to protect tps + + pm *PortManager + + done chan struct{} + once sync.Once + wg sync.WaitGroup +} + +// NewClientConn creates a new ClientConn. +func NewClientConn(log *logging.Logger, pm *PortManager, conn net.Conn, lPK, rPK cipher.PubKey) *ClientConn { + cc := &ClientConn{ + log: log, + Conn: conn, + lPK: lPK, + srvPK: rPK, + nextInitID: randID(true), + tps: make(map[uint16]*Transport), + pm: pm, + done: make(chan struct{}), + } + cc.wg.Add(1) + return cc +} + +// RemotePK returns the remote Server's PK that the ClientConn is connected to. +func (c *ClientConn) RemotePK() cipher.PubKey { return c.srvPK } + +func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) { + for { + select { + case <-c.done: + return 0, ErrClientClosed + case <-ctx.Done(): + return 0, ctx.Err() + default: + if ch := c.tps[c.nextInitID]; ch != nil && !ch.IsClosed() { + c.nextInitID += 2 + continue + } + c.tps[c.nextInitID] = nil + id := c.nextInitID + c.nextInitID = id + 2 + return id, nil + } + } +} + +func (c *ClientConn) addTp(ctx context.Context, rPK cipher.PubKey, lPort, rPort uint16, closeCB func()) (*Transport, error) { + c.mx.Lock() + defer c.mx.Unlock() + + id, err := c.getNextInitID(ctx) + if err != nil { + return nil, err + } + tp := NewTransport(c.Conn, c.log, Addr{c.lPK, lPort}, Addr{rPK, rPort}, id, func() { + c.delTp(id) + closeCB() + }) + c.tps[id] = tp + return tp, nil +} + +func (c *ClientConn) setTp(tp *Transport) { + c.mx.Lock() + c.tps[tp.id] = tp + c.mx.Unlock() +} + +func (c *ClientConn) delTp(id uint16) { + c.mx.Lock() + c.tps[id] = nil + c.mx.Unlock() +} + +func (c *ClientConn) getTp(id uint16) (*Transport, bool) { + c.mx.RLock() + tp := c.tps[id] + c.mx.RUnlock() + ok := tp != nil && !tp.IsClosed() + return tp, ok +} + +func (c *ClientConn) setNextInitID(nextInitID uint16) { + c.mx.Lock() + c.nextInitID = nextInitID + c.mx.Unlock() +} + +func (c *ClientConn) readOK() error { + _, df, err := readFrame(c.Conn) + if err != nil { + return errors.New("failed to get OK from server") + } + if df.Type != OkType { + return fmt.Errorf("wrong frame from server: %v", df.Type) + } + return nil +} + +// This handles 'REQUEST' frames which represent remotely-initiated tps. 'REQUEST' frames should: +// - have a HandshakePayload marshaled to JSON as payload. +// - have a resp_pk be of local client. +// - have an odd tp_id. +func (c *ClientConn) handleRequestFrame(log *logrus.Entry, id uint16, p []byte) (cipher.PubKey, error) { + + // The public key of the initiating client (or the client that sent the 'REQUEST' frame). + var initPK cipher.PubKey + + // Attempts to close tp due to given error. + // When we fail to close tp (a.k.a fail to send 'CLOSE' frame) or if the local client is closed, + // the connection to server should be closed. + // TODO(evanlinjin): derive close reason from error. + closeTp := func(origErr error) (cipher.PubKey, error) { + if err := writeCloseFrame(c.Conn, id, PlaceholderReason); err != nil { + log.WithError(err).Warn("handleRequestFrame: failed to close transport: ending conn to server.") + log.WithError(c.Close()).Warn("handleRequestFrame: closing connection to server.") + return initPK, origErr + } + switch origErr { + case ErrClientClosed: + log.WithError(c.Close()).Warn("handleRequestFrame: closing connection to server.") + } + return initPK, origErr + } + + pay, err := unmarshalHandshakePayload(p) + if err != nil { + return closeTp(ErrRequestCheckFailed) // TODO(nkryuchkov): reason = payload format is incorrect. + } + initPK = pay.InitAddr.PK + + if pay.RespAddr.PK != c.lPK || isInitiatorID(id) { + return closeTp(ErrRequestCheckFailed) // TODO(nkryuchkov): reason = payload is malformed. + } + lis, ok := c.pm.Listener(pay.RespAddr.Port) + if !ok { + return closeTp(ErrPortNotListening) // TODO(nkryuchkov): reason = port is not listening. + } + if c.isClosed() { + return closeTp(ErrClientClosed) // TODO(nkryuchkov): reason = client is closed. + } + + tp := NewTransport(c.Conn, c.log, pay.RespAddr, pay.InitAddr, id, func() { c.delTp(id) }) + if err := lis.IntroduceTransport(tp); err != nil { + return initPK, err + } + c.setTp(tp) + return initPK, nil +} + +// Serve handles incoming frames. +// Remote-initiated tps that are successfully created are pushing into 'accept' and exposed via 'Client.Accept()'. +func (c *ClientConn) Serve(ctx context.Context) (err error) { + log := c.log.WithField("remoteServer", c.srvPK) + log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") + defer func() { + c.close() + log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ConnectionClosed") + c.wg.Done() + }() + + for { + f, df, err := readFrame(c.Conn) + if err != nil { + return fmt.Errorf("read failed: %s", err) + } + log = log.WithField("received", f) + + // If tp of tp_id exists, attempt to forward frame to tp. + // Delete tp on any failure. + if tp, ok := c.getTp(df.TpID); ok { + if err := tp.HandleFrame(f); err != nil { + log.WithError(err).Warnf("Rejected [%s]: Transport closed.", df.Type) + } + continue + } + c.delTp(df.TpID) // rm tp in case closed tp is not fully removed. + + // if tp does not exist, frame should be 'REQUEST'. + // otherwise, handle any unexpected frames accordingly. + switch df.Type { + case RequestType: + c.wg.Add(1) + go func(log *logrus.Entry) { + defer c.wg.Done() + if initPK, err := c.handleRequestFrame(log, df.TpID, df.Pay); err != nil { + log.WithField("remoteClient", initPK).WithError(err).Warn("Rejected [REQUEST]") + } else { + log.WithField("remoteClient", initPK).Info("Accepted [REQUEST]") + } + }(log) + + default: + log.Debugf("Ignored [%s]: No transport of given ID.", df.Type) + if df.Type != CloseType { + if err := writeCloseFrame(c.Conn, df.TpID, PlaceholderReason); err != nil { + return err + } + } + } + } +} + +// DialTransport dials a transport to remote dms_client. +func (c *ClientConn) DialTransport(ctx context.Context, rPK cipher.PubKey, rPort uint16) (*Transport, error) { + lPort, closeCB, err := c.pm.ReserveEphemeral(ctx) + if err != nil { + return nil, err + } + tp, err := c.addTp(ctx, rPK, lPort, rPort, closeCB) // TODO: Have proper local port. + if err != nil { + return nil, err + } + if err := tp.WriteRequest(); err != nil { + return nil, err + } + if err := tp.ReadAccept(ctx); err != nil { + return nil, err + } + go tp.Serve() + return tp, nil +} + +func (c *ClientConn) close() (closed bool) { + if c == nil { + return false + } + c.once.Do(func() { + closed = true + c.log.WithField("remoteServer", c.srvPK).Infoln("ClosingConnection") + close(c.done) + c.mx.Lock() + for _, tp := range c.tps { + tp := tp + go func() { + if err := tp.Close(); err != nil { + log.WithError(err).Warn("Failed to close transport") + } + }() + } + if err := c.Conn.Close(); err != nil { + log.WithError(err).Warn("Failed to close connection") + } + c.mx.Unlock() + }) + return closed +} + +// Close closes the connection to dms_server. +func (c *ClientConn) Close() error { + if c.close() { + c.wg.Wait() + } + return nil +} + +func (c *ClientConn) isClosed() bool { + select { + case <-c.done: + return true + default: + return false + } +} diff --git a/vendor/github.com/skycoin/dmsg/listener.go b/vendor/github.com/skycoin/dmsg/listener.go new file mode 100644 index 0000000000..3fc6f48a46 --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/listener.go @@ -0,0 +1,124 @@ +package dmsg + +import ( + "fmt" + "net" + "sync" +) + +// Listener listens for remote-initiated transports. +type Listener struct { + addr Addr // local listening address + + accept chan *Transport + mx sync.Mutex // protects 'accept' + + doneFunc func() // callback when done + done chan struct{} + once sync.Once +} + +func newListener(addr Addr) *Listener { + return &Listener{ + addr: addr, + accept: make(chan *Transport, AcceptBufferSize), + done: make(chan struct{}), + } +} + +// AddCloseCallback adds a function that triggers when listener is closed. +// This should be called right after the listener is created and is not thread safe. +func (l *Listener) AddCloseCallback(cb func()) { l.doneFunc = cb } + +// IntroduceTransport handles a transport after receiving a REQUEST frame. +func (l *Listener) IntroduceTransport(tp *Transport) error { + if tp.LocalAddr() != l.addr { + return fmt.Errorf("failed to accept transport as local addresses does not match: we expected %s but got %s", + l.addr, tp.LocalAddr()) + } + + l.mx.Lock() + defer l.mx.Unlock() + + if l.isClosed() { + return ErrClientClosed + } + + select { + case <-l.done: + return ErrClientClosed + + case l.accept <- tp: + if err := tp.WriteAccept(); err != nil { + return err + } + go tp.Serve() + return nil + + default: + _ = tp.Close() //nolint:errcheck + return ErrClientAcceptMaxed + } +} + +// Accept accepts a connection. +func (l *Listener) Accept() (net.Conn, error) { + return l.AcceptTransport() +} + +// AcceptTransport accepts a transport connection. +func (l *Listener) AcceptTransport() (*Transport, error) { + select { + case <-l.done: + return nil, ErrClientClosed + case tp, ok := <-l.accept: + if !ok { + return nil, ErrClientClosed + } + return tp, nil + } +} + +// Close closes the listener. +func (l *Listener) Close() error { + if l.close() { + return nil + } + return ErrClientClosed +} + +func (l *Listener) close() (closed bool) { + l.once.Do(func() { + closed = true + l.doneFunc() + + l.mx.Lock() + defer l.mx.Unlock() + + close(l.done) + for { + select { + case <-l.accept: + default: + close(l.accept) + return + } + } + }) + return closed +} + +func (l *Listener) isClosed() bool { + select { + case <-l.done: + return true + default: + return false + } +} + +// Addr returns the listener's address. +func (l *Listener) Addr() net.Addr { return l.addr } + +// Type returns the transport type. +func (l *Listener) Type() string { return Type } diff --git a/vendor/github.com/skycoin/dmsg/netutil/porter.go b/vendor/github.com/skycoin/dmsg/netutil/porter.go new file mode 100644 index 0000000000..fb0d2c1b26 --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/netutil/porter.go @@ -0,0 +1,102 @@ +package netutil + +import ( + "context" + "sync" +) + +const ( + // PorterMinEphemeral is the default minimum ephemeral port. + PorterMinEphemeral = uint16(49152) +) + +// Porter reserves ports. +type Porter struct { + sync.RWMutex + eph uint16 // current ephemeral value + minEph uint16 // minimal ephemeral port value + ports map[uint16]interface{} +} + +// NewPorter creates a new Porter with a given minimum ephemeral port value. +func NewPorter(minEph uint16) *Porter { + ports := make(map[uint16]interface{}) + ports[0] = struct{}{} // port 0 is invalid + + return &Porter{ + eph: minEph, + minEph: minEph, + ports: ports, + } +} + +// Reserve a given port. +// It returns a boolean informing whether the port is reserved, and a function to clear the reservation. +func (p *Porter) Reserve(port uint16, v interface{}) (bool, func()) { + p.Lock() + defer p.Unlock() + + if _, ok := p.ports[port]; ok { + return false, nil + } + p.ports[port] = v + return true, p.makePortFreer(port) +} + +// ReserveEphemeral reserves a new ephemeral port. +// It returns the reserved ephemeral port, a function to clear the reservation and an error (if any). +func (p *Porter) ReserveEphemeral(ctx context.Context, v interface{}) (uint16, func(), error) { + p.Lock() + defer p.Unlock() + + for { + p.eph++ + if p.eph < p.minEph { + p.eph = p.minEph + } + if _, ok := p.ports[p.eph]; ok { + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + default: + continue + } + } + p.ports[p.eph] = v + return p.eph, p.makePortFreer(p.eph), nil + } +} + +// PortValue returns the value stored under a given port. +func (p *Porter) PortValue(port uint16) (interface{}, bool) { + p.RLock() + defer p.RUnlock() + + v, ok := p.ports[port] + return v, ok +} + +// RangePortValues ranges all ports that are currently reserved. +func (p *Porter) RangePortValues(fn func(port uint16, v interface{}) (next bool)) { + p.RLock() + defer p.RUnlock() + + for port, v := range p.ports { + if next := fn(port, v); !next { + return + } + } +} + +// This returns a function that frees a given port. +// It is ensured that the function's action is only performed once. +func (p *Porter) makePortFreer(port uint16) func() { + once := new(sync.Once) + return func() { + once.Do(func() { + p.Lock() + delete(p.ports, port) + p.Unlock() + }) + } +} diff --git a/vendor/github.com/skycoin/dmsg/port_manager.go b/vendor/github.com/skycoin/dmsg/port_manager.go new file mode 100644 index 0000000000..0ab5a18e4a --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/port_manager.go @@ -0,0 +1,66 @@ +package dmsg + +import ( + "context" + "sync" + + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/dmsg/netutil" +) + +// PortManager manages ports of nodes. +type PortManager struct { + lPK cipher.PubKey + p *netutil.Porter +} + +func newPortManager(lPK cipher.PubKey) *PortManager { + return &PortManager{ + lPK: lPK, + p: netutil.NewPorter(netutil.PorterMinEphemeral), + } +} + +// Listener returns a listener assigned to a given port. +func (pm *PortManager) Listener(port uint16) (*Listener, bool) { + v, ok := pm.p.PortValue(port) + if !ok { + return nil, false + } + l, ok := v.(*Listener) + return l, ok +} + +// NewListener assigns listener to port if port is available. +func (pm *PortManager) NewListener(port uint16) (*Listener, bool) { + l := newListener(Addr{pm.lPK, port}) + ok, clear := pm.p.Reserve(port, l) + if !ok { + return nil, false + } + l.AddCloseCallback(clear) + return l, true +} + +// ReserveEphemeral reserves an ephemeral port. +func (pm *PortManager) ReserveEphemeral(ctx context.Context) (uint16, func(), error) { + return pm.p.ReserveEphemeral(ctx, nil) +} + +// Close closes all listeners. +func (pm *PortManager) Close() error { + wg := new(sync.WaitGroup) + pm.p.RangePortValues(func(_ uint16, v interface{}) (next bool) { + l, ok := v.(*Listener) + if ok { + wg.Add(1) + go func() { + l.close() + wg.Done() + }() + } + return true + }) + wg.Wait() + return nil +} diff --git a/vendor/github.com/skycoin/dmsg/server.go b/vendor/github.com/skycoin/dmsg/server.go index 4433b65cd8..a5bfa304c2 100644 --- a/vendor/github.com/skycoin/dmsg/server.go +++ b/vendor/github.com/skycoin/dmsg/server.go @@ -19,238 +19,6 @@ import ( // ErrListenerAlreadyWrappedToNoise occurs when the provided net.Listener is already wrapped with noise.Listener var ErrListenerAlreadyWrappedToNoise = errors.New("listener is already wrapped to *noise.Listener") -// NextConn provides information on the next connection. -type NextConn struct { - conn *ServerConn - id uint16 -} - -func (r *NextConn) writeFrame(ft FrameType, p []byte) error { - if err := writeFrame(r.conn.Conn, MakeFrame(ft, r.id, p)); err != nil { - go func() { - if err := r.conn.Close(); err != nil { - log.WithError(err).Warn("Failed to close connection") - } - }() - return err - } - return nil -} - -// ServerConn is a connection between a dmsg.Server and a dmsg.Client from a server's perspective. -type ServerConn struct { - log *logging.Logger - - net.Conn - remoteClient cipher.PubKey - - nextRespID uint16 - nextConns map[uint16]*NextConn - mx sync.RWMutex -} - -// NewServerConn creates a new connection from the perspective of a dms_server. -func NewServerConn(log *logging.Logger, conn net.Conn, remoteClient cipher.PubKey) *ServerConn { - return &ServerConn{ - log: log, - Conn: conn, - remoteClient: remoteClient, - nextRespID: randID(false), - nextConns: make(map[uint16]*NextConn), - } -} - -func (c *ServerConn) delNext(id uint16) { - c.mx.Lock() - delete(c.nextConns, id) - c.mx.Unlock() -} - -func (c *ServerConn) setNext(id uint16, r *NextConn) { - c.mx.Lock() - c.nextConns[id] = r - c.mx.Unlock() -} - -func (c *ServerConn) getNext(id uint16) (*NextConn, bool) { - c.mx.RLock() - r := c.nextConns[id] - c.mx.RUnlock() - return r, r != nil -} - -func (c *ServerConn) addNext(ctx context.Context, r *NextConn) (uint16, error) { - c.mx.Lock() - defer c.mx.Unlock() - - for { - if r := c.nextConns[c.nextRespID]; r == nil { - break - } - c.nextRespID += 2 - - select { - case <-ctx.Done(): - return 0, ctx.Err() - default: - } - } - - id := c.nextRespID - c.nextRespID = id + 2 - c.nextConns[id] = r - return id, nil -} - -// PK returns the remote dms_client's public key. -func (c *ServerConn) PK() cipher.PubKey { - return c.remoteClient -} - -type getConnFunc func(pk cipher.PubKey) (*ServerConn, bool) - -// Serve handles (and forwards when necessary) incoming frames. -func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error) { - log := c.log.WithField("srcClient", c.remoteClient) - - // Only manually close the underlying net.Conn when the done signal is context-initiated. - done := make(chan struct{}) - defer close(done) - go func() { - select { - case <-done: - case <-ctx.Done(): - if err := c.Conn.Close(); err != nil { - log.WithError(err).Warn("failed to close underlying connection") - } - } - }() - - defer func() { - // Send CLOSE frames to all transports which are established with this dmsg.Client - // This ensures that all parties are informed about the transport closing. - c.mx.Lock() - for _, conn := range c.nextConns { - why := byte(0) - if err := conn.writeFrame(CloseType, []byte{why}); err != nil { - log.WithError(err).Warnf("failed to write frame: %s", err) - } - } - c.mx.Unlock() - - log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ClosingConn") - if err := c.Conn.Close(); err != nil { - log.WithError(err).Warn("Failed to close connection") - } - }() - log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") - - err = c.writeOK() - if err != nil { - return fmt.Errorf("sending OK failed: %s", err) - } - - for { - f, err := readFrame(c.Conn) - if err != nil { - return fmt.Errorf("read failed: %s", err) - } - log = log.WithField("received", f) - - ft, id, p := f.Disassemble() - - switch ft { - case RequestType: - ctx, cancel := context.WithTimeout(ctx, TransportHandshakeTimeout) - _, why, ok := c.handleRequest(ctx, getConn, id, p) - cancel() - if !ok { - log.Debugln("FrameRejected: Erroneous request or unresponsive dstClient.") - if err := c.delChan(id, why); err != nil { - return err - } - } - log.Debugln("FrameForwarded") - - case AcceptType, FwdType, AckType, CloseType: - next, why, ok := c.forwardFrame(ft, id, p) - if !ok { - log.Debugln("FrameRejected: Failed to forward to dstClient.") - // Delete channel (and associations) on failure. - if err := c.delChan(id, why); err != nil { - return err - } - continue - } - log.Debugln("FrameForwarded") - - // On success, if Close frame, delete the associations. - if ft == CloseType { - c.delNext(id) - next.conn.delNext(next.id) - } - - default: - log.Debugln("FrameRejected: Unknown frame type.") - // Unknown frame type. - return errors.New("unknown frame of type received") - } - } -} - -func (c *ServerConn) delChan(id uint16, why byte) error { - c.delNext(id) - if err := writeFrame(c.Conn, MakeFrame(CloseType, id, []byte{why})); err != nil { - return fmt.Errorf("failed to write frame: %s", err) - } - return nil -} - -func (c *ServerConn) writeOK() error { - if err := writeFrame(c.Conn, MakeFrame(OkType, 0, nil)); err != nil { - return err - } - return nil -} - -// nolint:unparam -func (c *ServerConn) forwardFrame(ft FrameType, id uint16, p []byte) (*NextConn, byte, bool) { - next, ok := c.getNext(id) - if !ok { - return next, 0, false - } - if err := next.writeFrame(ft, p); err != nil { - return next, 0, false - } - return next, 0, true -} - -// nolint:unparam -func (c *ServerConn) handleRequest(ctx context.Context, getLink getConnFunc, id uint16, p []byte) (*NextConn, byte, bool) { - initPK, respPK, ok := splitPKs(p) - if !ok || initPK != c.PK() { - return nil, 0, false - } - respL, ok := getLink(respPK) - if !ok { - return nil, 0, false - } - - // set next relations. - respID, err := respL.addNext(ctx, &NextConn{conn: c, id: id}) - if err != nil { - return nil, 0, false - } - next := &NextConn{conn: respL, id: respID} - c.setNext(id, next) - - // forward to responding client. - if err := next.writeFrame(RequestType, p); err != nil { - return next, 0, false - } - return next, 0, true -} - // Server represents a dms_server. type Server struct { log *logging.Logger diff --git a/vendor/github.com/skycoin/dmsg/server_conn.go b/vendor/github.com/skycoin/dmsg/server_conn.go new file mode 100644 index 0000000000..a162b5102b --- /dev/null +++ b/vendor/github.com/skycoin/dmsg/server_conn.go @@ -0,0 +1,243 @@ +package dmsg + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/dmsg/cipher" +) + +// NextConn provides information on the next connection. +type NextConn struct { + conn *ServerConn + id uint16 +} + +func (r *NextConn) writeFrame(ft FrameType, p []byte) error { + if err := writeFrame(r.conn.Conn, MakeFrame(ft, r.id, p)); err != nil { + go func() { + if err := r.conn.Close(); err != nil { + log.WithError(err).Warn("Failed to close connection") + } + }() + return err + } + return nil +} + +// ServerConn is a connection between a dmsg.Server and a dmsg.Client from a server's perspective. +type ServerConn struct { + log *logging.Logger + + net.Conn + remoteClient cipher.PubKey + + nextRespID uint16 + nextConns map[uint16]*NextConn + mx sync.RWMutex +} + +// NewServerConn creates a new connection from the perspective of a dms_server. +func NewServerConn(log *logging.Logger, conn net.Conn, remoteClient cipher.PubKey) *ServerConn { + return &ServerConn{ + log: log, + Conn: conn, + remoteClient: remoteClient, + nextRespID: randID(false), + nextConns: make(map[uint16]*NextConn), + } +} + +func (c *ServerConn) delNext(id uint16) { + c.mx.Lock() + delete(c.nextConns, id) + c.mx.Unlock() +} + +func (c *ServerConn) setNext(id uint16, r *NextConn) { + c.mx.Lock() + c.nextConns[id] = r + c.mx.Unlock() +} + +func (c *ServerConn) getNext(id uint16) (*NextConn, bool) { + c.mx.RLock() + r := c.nextConns[id] + c.mx.RUnlock() + return r, r != nil +} + +func (c *ServerConn) addNext(ctx context.Context, r *NextConn) (uint16, error) { + c.mx.Lock() + defer c.mx.Unlock() + + for { + if r := c.nextConns[c.nextRespID]; r == nil { + break + } + c.nextRespID += 2 + + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + } + } + + id := c.nextRespID + c.nextRespID = id + 2 + c.nextConns[id] = r + return id, nil +} + +// PK returns the remote dms_client's public key. +func (c *ServerConn) PK() cipher.PubKey { + return c.remoteClient +} + +type getConnFunc func(pk cipher.PubKey) (*ServerConn, bool) + +// Serve handles (and forwards when necessary) incoming frames. +func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error) { + log := c.log.WithField("srcClient", c.remoteClient) + + // Only manually close the underlying net.Conn when the done signal is context-initiated. + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-done: + case <-ctx.Done(): + if err := c.Conn.Close(); err != nil { + log.WithError(err).Warn("failed to close underlying connection") + } + } + }() + + defer func() { + // Send CLOSE frames to all transports which are established with this dmsg.Client + // This ensures that all parties are informed about the transport closing. + c.mx.Lock() + for _, conn := range c.nextConns { + why := byte(0) + if err := conn.writeFrame(CloseType, []byte{why}); err != nil { + log.WithError(err).Warnf("failed to write frame: %s", err) + } + } + c.mx.Unlock() + + log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ClosingConn") + if err := c.Conn.Close(); err != nil { + log.WithError(err).Warn("Failed to close connection") + } + }() + + log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") + + err = c.writeOK() + if err != nil { + return fmt.Errorf("sending OK failed: %s", err) + } + + for { + f, df, err := readFrame(c.Conn) + if err != nil { + return fmt.Errorf("read failed: %s", err) + } + log := log.WithField("received", f) + + switch df.Type { + case RequestType: + ctx, cancel := context.WithTimeout(ctx, TransportHandshakeTimeout) + _, why, ok := c.handleRequest(ctx, getConn, df.TpID, df.Pay) + cancel() + if !ok { + log.Debugln("FrameRejected: Erroneous request or unresponsive dstClient.") + if err := c.delChan(df.TpID, why); err != nil { + return err + } + } + log.Debugln("FrameForwarded") + + case AcceptType, FwdType, AckType, CloseType: + next, why, ok := c.forwardFrame(df.Type, df.TpID, df.Pay) + if !ok { + log.Debugln("FrameRejected: Failed to forward to dstClient.") + // Delete channel (and associations) on failure. + if err := c.delChan(df.TpID, why); err != nil { + return err + } + continue + } + log.Debugln("FrameForwarded") + + // On success, if Close frame, delete the associations. + if df.Type == CloseType { + c.delNext(df.TpID) + next.conn.delNext(next.id) + } + + default: + log.Debugln("FrameRejected: Unknown frame type.") + return errors.New("unknown frame of type received") + } + } +} + +func (c *ServerConn) delChan(id uint16, why byte) error { + c.delNext(id) + if err := writeCloseFrame(c.Conn, id, why); err != nil { + return fmt.Errorf("failed to write frame: %s", err) + } + return nil +} + +func (c *ServerConn) writeOK() error { + if err := writeFrame(c.Conn, MakeFrame(OkType, 0, nil)); err != nil { + return err + } + return nil +} + +// nolint:unparam +func (c *ServerConn) forwardFrame(ft FrameType, id uint16, p []byte) (*NextConn, byte, bool) { + next, ok := c.getNext(id) + if !ok { + return next, 0, false + } + if err := next.writeFrame(ft, p); err != nil { + return next, 0, false + } + return next, 0, true +} + +// nolint:unparam +func (c *ServerConn) handleRequest(ctx context.Context, getLink getConnFunc, id uint16, p []byte) (*NextConn, byte, bool) { + payload, err := unmarshalHandshakePayload(p) + if err != nil || payload.InitAddr.PK != c.PK() { + return nil, 0, false + } + respL, ok := getLink(payload.RespAddr.PK) + if !ok { + return nil, 0, false + } + + // set next relations. + respID, err := respL.addNext(ctx, &NextConn{conn: c, id: id}) + if err != nil { + return nil, 0, false + } + next := &NextConn{conn: respL, id: respID} + c.setNext(id, next) + + // forward to responding client. + if err := next.writeFrame(RequestType, p); err != nil { + return next, 0, false + } + return next, 0, true +} diff --git a/vendor/github.com/skycoin/dmsg/testing.go b/vendor/github.com/skycoin/dmsg/testing.go index ef9095b9f8..49a181b755 100644 --- a/vendor/github.com/skycoin/dmsg/testing.go +++ b/vendor/github.com/skycoin/dmsg/testing.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "net" "testing" "time" @@ -42,10 +43,12 @@ func checkConnCount(t *testing.T, delay time.Duration, count int, ccs ...connCou })) } -func checkTransportsClosed(t *testing.T, transports ...*Transport) { - for _, transport := range transports { - assert.False(t, isDoneChanOpen(transport.done)) - assert.False(t, isReadChanOpen(transport.inCh)) +func checkTransportsClosed(t *testing.T, transports ...net.Conn) { + for _, tr := range transports { + if tr, ok := tr.(*Transport); ok && tr != nil { + assert.False(t, isDoneChanOpen(tr.done)) + assert.False(t, isReadChanOpen(tr.inCh)) + } } } diff --git a/vendor/github.com/skycoin/dmsg/transport.go b/vendor/github.com/skycoin/dmsg/transport.go index 734983de93..5a7467172b 100644 --- a/vendor/github.com/skycoin/dmsg/transport.go +++ b/vendor/github.com/skycoin/dmsg/transport.go @@ -19,16 +19,18 @@ var ( ErrRequestRejected = errors.New("failed to create transport: request rejected") ErrRequestCheckFailed = errors.New("failed to create transport: request check failed") ErrAcceptCheckFailed = errors.New("failed to create transport: accept check failed") + ErrPortNotListening = errors.New("failed to create transport: port not listening") ) -// Transport represents a connection from dmsg.Client to remote dmsg.Client (via dmsg.Server intermediary). +// Transport represents communication between two nodes via a single hop: +// a connection from dmsg.Client to remote dmsg.Client (via dmsg.Server intermediary). type Transport struct { net.Conn // underlying connection to dmsg.Server log *logging.Logger - id uint16 // tp ID that identifies this dmsg.Transport - local cipher.PubKey // local PK - remote cipher.PubKey // remote PK + id uint16 // tp ID that identifies this dmsg.transport + local Addr // local PK + remote Addr // remote PK inCh chan Frame // handles incoming frames (from dmsg.Client) inMx sync.Mutex // protects 'inCh' @@ -39,17 +41,16 @@ type Transport struct { bufCh chan struct{} // chan for indicating whether this is a new FWD frame bufSize int // keeps track of the total size of 'buf' bufMx sync.Mutex // protects fields responsible for handling FWD and ACK frames - rMx sync.Mutex // TODO: (WORKAROUND) concurrent reads seem problematic right now. - serving chan struct{} // chan which closes when serving begins - servingOnce sync.Once // ensures 'serving' only closes once - done chan struct{} // chan which closes when transport stops serving - doneOnce sync.Once // ensures 'done' only closes once - doneFunc func(id uint16) // contains a method to remove the transport from dmsg.Client + serving chan struct{} // chan which closes when serving begins + servingOnce sync.Once // ensures 'serving' only closes once + done chan struct{} // chan which closes when transport stops serving + doneOnce sync.Once // ensures 'done' only closes once + doneFunc func() // contains a method that triggers when dmsg.Client closes } // NewTransport creates a new dms_tp. -func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKey, id uint16, doneFunc func(id uint16)) *Transport { +func NewTransport(conn net.Conn, log *logging.Logger, local, remote Addr, id uint16, doneFunc func()) *Transport { tp := &Transport{ Conn: conn, log: log, @@ -94,7 +95,7 @@ func (tp *Transport) close() (closed bool) { closed = true close(tp.done) - tp.doneFunc(tp.id) + tp.doneFunc() tp.bufMx.Lock() close(tp.bufCh) @@ -113,7 +114,7 @@ func (tp *Transport) close() (closed bool) { // Close closes the dmsg_tp. func (tp *Transport) Close() error { if tp.close() { - if err := writeFrame(tp.Conn, MakeFrame(CloseType, tp.id, []byte{0})); err != nil { + if err := writeCloseFrame(tp.Conn, tp.id, PlaceholderReason); err != nil { log.WithError(err).Warn("Failed to write frame") } } @@ -132,14 +133,20 @@ func (tp *Transport) IsClosed() bool { // LocalPK returns the local public key of the transport. func (tp *Transport) LocalPK() cipher.PubKey { - return tp.local + return tp.local.PK } // RemotePK returns the remote public key of the transport. func (tp *Transport) RemotePK() cipher.PubKey { - return tp.remote + return tp.remote.PK } +// LocalAddr returns local address in from : +func (tp *Transport) LocalAddr() net.Addr { return tp.local } + +// RemoteAddr returns remote address in form : +func (tp *Transport) RemoteAddr() net.Addr { return tp.remote } + // Type returns the transport type. func (tp *Transport) Type() string { return Type @@ -163,7 +170,16 @@ func (tp *Transport) HandleFrame(f Frame) error { // WriteRequest writes a REQUEST frame to dmsg_server to be forwarded to associated client. func (tp *Transport) WriteRequest() error { - f := MakeFrame(RequestType, tp.id, combinePKs(tp.local, tp.remote)) + payload := HandshakePayload{ + Version: HandshakePayloadVersion, + InitAddr: tp.local, + RespAddr: tp.remote, + } + payloadBytes, err := marshalHandshakePayload(payload) + if err != nil { + return err + } + f := MakeFrame(RequestType, tp.id, payloadBytes) if err := writeFrame(tp.Conn, f); err != nil { tp.log.WithError(err).Error("HandshakeFailed") tp.close() @@ -182,7 +198,7 @@ func (tp *Transport) WriteAccept() (err error) { } }() - f := MakeFrame(AcceptType, tp.id, combinePKs(tp.remote, tp.local)) + f := MakeFrame(AcceptType, tp.id, combinePKs(tp.remote.PK, tp.local.PK)) if err = writeFrame(tp.Conn, f); err != nil { tp.close() return err @@ -225,7 +241,7 @@ func (tp *Transport) ReadAccept(ctx context.Context) (err error) { // - resp_pk should be of remote client. // - use an even number with the intermediary dmsg_server. initPK, respPK, ok := splitPKs(p) - if !ok || initPK != tp.local || respPK != tp.remote || !isInitiatorID(id) { + if !ok || initPK != tp.local.PK || respPK != tp.remote.PK || !isInitiatorID(id) { if err := tp.Close(); err != nil { log.WithError(err).Warn("Failed to close transport") } @@ -257,7 +273,7 @@ func (tp *Transport) Serve() { // also write CLOSE frame if this is the first time 'close' is triggered defer func() { if tp.close() { - if err := writeCloseFrame(tp.Conn, tp.id, 0); err != nil { + if err := writeCloseFrame(tp.Conn, tp.id, PlaceholderReason); err != nil { log.WithError(err).Warn("Failed to write close frame") } } @@ -342,9 +358,6 @@ func (tp *Transport) Serve() { func (tp *Transport) Read(p []byte) (n int, err error) { <-tp.serving - tp.rMx.Lock() - defer tp.rMx.Unlock() - startRead: tp.bufMx.Lock() n, err = tp.buf.Read(p) diff --git a/vendor/github.com/skycoin/dmsg/frame.go b/vendor/github.com/skycoin/dmsg/types.go similarity index 66% rename from vendor/github.com/skycoin/dmsg/frame.go rename to vendor/github.com/skycoin/dmsg/types.go index 78e10edf5f..dcaabe6db6 100644 --- a/vendor/github.com/skycoin/dmsg/frame.go +++ b/vendor/github.com/skycoin/dmsg/types.go @@ -2,6 +2,7 @@ package dmsg import ( "encoding/binary" + "encoding/json" "fmt" "io" "math" @@ -16,6 +17,9 @@ import ( const ( // Type returns the transport type string. Type = "dmsg" + // HandshakePayloadVersion contains payload version to maintain compatibility with future versions + // of HandshakePayload format. + HandshakePayloadVersion = "2.0" tpBufCap = math.MaxUint16 tpBufFrameCap = math.MaxUint8 @@ -31,6 +35,43 @@ var ( AcceptBufferSize = 20 ) +// Addr implements net.Addr for dmsg addresses. +type Addr struct { + PK cipher.PubKey `json:"public_key"` + Port uint16 `json:"port"` +} + +// Network returns "dmsg" +func (Addr) Network() string { + return Type +} + +// String returns public key and port of node split by colon. +func (a Addr) String() string { + if a.Port == 0 { + return fmt.Sprintf("%s:~", a.PK) + } + return fmt.Sprintf("%s:%d", a.PK, a.Port) +} + +// HandshakePayload represents format of payload sent with REQUEST frames. +type HandshakePayload struct { + Version string `json:"version"` // just in case the struct changes. + InitAddr Addr `json:"init_address"` + RespAddr Addr `json:"resp_address"` +} + +func marshalHandshakePayload(p HandshakePayload) ([]byte, error) { + return json.Marshal(p) +} + +func unmarshalHandshakePayload(b []byte) (HandshakePayload, error) { + var p HandshakePayload + err := json.Unmarshal(b, &p) + return p, err +} + +// determines whether the transport ID is of an initiator or responder. func isInitiatorID(tpID uint16) bool { return tpID%2 == 0 } func randID(initiator bool) uint16 { @@ -43,6 +84,7 @@ func randID(initiator bool) uint16 { } } +// serveCount records the number of dmsg.Servers connected var serveCount int64 func incrementServeCount() int64 { return atomic.AddInt64(&serveCount, 1) } @@ -76,6 +118,11 @@ const ( AckType = FrameType(0xb) ) +// Reasons for closing frames +const ( + PlaceholderReason = iota +) + // Frame is the dmsg data unit. type Frame []byte @@ -116,24 +163,36 @@ func (f Frame) String() string { return fmt.Sprintf("%s", f.Type(), f.TpID(), f.PayLen(), p) } -func readFrame(r io.Reader) (Frame, error) { - f := make(Frame, headerLen) - if _, err := io.ReadFull(r, f); err != nil { - return nil, err +type disassembledFrame struct { + Type FrameType + TpID uint16 + Pay []byte +} + +// read and disassembles frame from reader +func readFrame(r io.Reader) (f Frame, df disassembledFrame, err error) { + f = make(Frame, headerLen) + if _, err = io.ReadFull(r, f); err != nil { + return } f = append(f, make([]byte, f.PayLen())...) - _, err := io.ReadFull(r, f[headerLen:]) - return f, err + if _, err = io.ReadFull(r, f[headerLen:]); err != nil { + return + } + t, id, p := f.Disassemble() + df = disassembledFrame{Type: t, TpID: id, Pay: p} + return } type writeError struct{ error } func (e *writeError) Error() string { return "write error: " + e.error.Error() } -func isWriteError(err error) bool { - _, ok := err.(*writeError) - return ok -} +// TODO(evanlinjin): Determine if this is still needed, may be useful elsewhere. +//func isWriteError(err error) bool { +// _, ok := err.(*writeError) +// return ok +//} func writeFrame(w io.Writer, f Frame) error { _, err := w.Write(f) diff --git a/vendor/modules.txt b/vendor/modules.txt index b859dcfc63..4ad29bf678 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -64,10 +64,11 @@ github.com/prometheus/procfs/internal/fs # github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus github.com/sirupsen/logrus/hooks/syslog -# github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f +# github.com/skycoin/dmsg v0.0.0-20190805065636-70f4c32a994f => ../dmsg github.com/skycoin/dmsg/cipher github.com/skycoin/dmsg github.com/skycoin/dmsg/disc +github.com/skycoin/dmsg/netutil github.com/skycoin/dmsg/noise github.com/skycoin/dmsg/ioutil # github.com/skycoin/skycoin v0.26.0 From 25489ef6a68b52cc902f45eda09b9550f2669461 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Fri, 20 Sep 2019 18:09:39 +0300 Subject: [PATCH 28/43] Adjust code to the `Networker` usage --- pkg/app2/client.go | 13 +++-- pkg/app2/client_test.go | 8 ++-- pkg/app2/conn.go | 6 +-- pkg/app2/listener.go | 4 +- pkg/app2/manager.go | 2 +- pkg/app2/manager_test.go | 10 ++-- pkg/app2/mock_server_rpc_client.go | 22 ++++----- pkg/app2/network/dmsg_networker.go | 40 ++++++++++++++++ pkg/app2/network/networker.go | 16 +++---- pkg/app2/server_rpc.go | 76 +++++++++++++++--------------- pkg/app2/server_rpc_client.go | 16 +++---- 11 files changed, 128 insertions(+), 85 deletions(-) create mode 100644 pkg/app2/network/dmsg_networker.go diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 53fd6b0088..743fadc433 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -4,6 +4,8 @@ import ( "context" "net" + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/dmsg/cipher" "github.com/skycoin/dmsg/netutil" "github.com/skycoin/skywire/pkg/routing" @@ -32,12 +34,11 @@ func NewClient(localPK cipher.PubKey, pid ProcID, rpc ServerRPCClient, porter *n } // Dial dials the remote node using `remote`. -func (c *Client) Dial(remote routing.Addr) (net.Conn, error) { +func (c *Client) Dial(remote network.Addr) (net.Conn, error) { localPort, free, err := c.porter.ReserveEphemeral(context.TODO(), nil) if err != nil { return nil, err } - net.Dial() connID, err := c.rpc.Dial(remote) if err != nil { @@ -48,7 +49,8 @@ func (c *Client) Dial(remote routing.Addr) (net.Conn, error) { conn := &Conn{ id: connID, rpc: c.rpc, - local: routing.Addr{ + local: network.Addr{ + Net: remote.Net, PubKey: c.pk, Port: routing.Port(localPort), }, @@ -60,13 +62,14 @@ func (c *Client) Dial(remote routing.Addr) (net.Conn, error) { } // Listen listens on the specified `port` for the incoming connections. -func (c *Client) Listen(port routing.Port) (net.Listener, error) { +func (c *Client) Listen(n network.Type, port routing.Port) (net.Listener, error) { ok, free := c.porter.Reserve(uint16(port), nil) if !ok { return nil, ErrPortAlreadyBound } - local := routing.Addr{ + local := network.Addr{ + Net: n, PubKey: c.pk, Port: port, } diff --git a/pkg/app2/client_test.go b/pkg/app2/client_test.go index dcc4a2c35a..1c49d5c14f 100644 --- a/pkg/app2/client_test.go +++ b/pkg/app2/client_test.go @@ -3,6 +3,8 @@ package app2 import ( "testing" + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/pkg/errors" "github.com/stretchr/testify/require" @@ -19,7 +21,7 @@ func TestClient_Dial(t *testing.T) { remotePK, _ := cipher.GenerateKeyPair() remotePort := routing.Port(120) - remote := routing.Addr{ + remote := network.Addr{ PubKey: remotePK, Port: remotePort, } @@ -36,7 +38,7 @@ func TestClient_Dial(t *testing.T) { wantConn := &Conn{ id: dialConnID, rpc: rpc, - local: routing.Addr{ + local: network.Addr{ PubKey: localPK, }, remote: remote, @@ -76,7 +78,7 @@ func TestClient_Listen(t *testing.T) { pid := ProcID(1) port := routing.Port(1) - local := routing.Addr{ + local := network.Addr{ PubKey: localPK, Port: port, } diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index 07cda88b33..632fe0d464 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -4,7 +4,7 @@ import ( "net" "time" - "github.com/skycoin/skywire/pkg/routing" + "github.com/skycoin/skywire/pkg/app2/network" ) // Conn is a connection from app client to the server. @@ -12,8 +12,8 @@ import ( type Conn struct { id uint16 rpc ServerRPCClient - local routing.Addr - remote routing.Addr + local network.Addr + remote network.Addr freeLocalPort func() } diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index 9b7bc33977..bb38364a7d 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -3,7 +3,7 @@ package app2 import ( "net" - "github.com/skycoin/skywire/pkg/routing" + "github.com/skycoin/skywire/pkg/app2/network" ) // Listener is a listener for app server connections. @@ -11,7 +11,7 @@ import ( type Listener struct { id uint16 rpc ServerRPCClient - addr routing.Addr + addr network.Addr freePort func() } diff --git a/pkg/app2/manager.go b/pkg/app2/manager.go index 5033b1e929..ce0e8f3f00 100644 --- a/pkg/app2/manager.go +++ b/pkg/app2/manager.go @@ -52,7 +52,7 @@ func (m *manager) nextKey() (*uint16, error) { // getAndRemove removes value specified by `key` from the manager instance and // returns it. -func (m *manager) getAndRemove(key uint16) (interface{}, error) { +func (m *manager) pop(key uint16) (interface{}, error) { m.mx.Lock() v, ok := m.values[key] if !ok { diff --git a/pkg/app2/manager_test.go b/pkg/app2/manager_test.go index 1032a9aca1..d67fe33179 100644 --- a/pkg/app2/manager_test.go +++ b/pkg/app2/manager_test.go @@ -66,7 +66,7 @@ func TestManager_NextID(t *testing.T) { }) } -func TestManager_GetAndRemove(t *testing.T) { +func TestManager_Pop(t *testing.T) { t.Run("simple call", func(t *testing.T) { m := newManager() @@ -74,7 +74,7 @@ func TestManager_GetAndRemove(t *testing.T) { m.values[1] = v - gotV, err := m.getAndRemove(1) + gotV, err := m.pop(1) require.NoError(t, err) require.NotNil(t, gotV) require.Equal(t, gotV, v) @@ -86,7 +86,7 @@ func TestManager_GetAndRemove(t *testing.T) { t.Run("no value", func(t *testing.T) { m := newManager() - _, err := m.getAndRemove(1) + _, err := m.pop(1) require.Error(t, err) }) @@ -95,7 +95,7 @@ func TestManager_GetAndRemove(t *testing.T) { m.values[1] = nil - _, err := m.getAndRemove(1) + _, err := m.pop(1) require.Error(t, err) }) @@ -108,7 +108,7 @@ func TestManager_GetAndRemove(t *testing.T) { errs := make(chan error, concurrency) for i := uint16(0); i < uint16(concurrency); i++ { go func() { - _, err := m.getAndRemove(1) + _, err := m.pop(1) errs <- err }() } diff --git a/pkg/app2/mock_server_rpc_client.go b/pkg/app2/mock_server_rpc_client.go index 26fb17717e..58e88cab76 100644 --- a/pkg/app2/mock_server_rpc_client.go +++ b/pkg/app2/mock_server_rpc_client.go @@ -3,7 +3,7 @@ package app2 import mock "github.com/stretchr/testify/mock" -import routing "github.com/skycoin/skywire/pkg/routing" +import network "github.com/skycoin/skywire/pkg/app2/network" // MockServerRPCClient is an autogenerated mock type for the ServerRPCClient type type MockServerRPCClient struct { @@ -11,7 +11,7 @@ type MockServerRPCClient struct { } // Accept provides a mock function with given fields: lisID -func (_m *MockServerRPCClient) Accept(lisID uint16) (uint16, routing.Addr, error) { +func (_m *MockServerRPCClient) Accept(lisID uint16) (uint16, network.Addr, error) { ret := _m.Called(lisID) var r0 uint16 @@ -21,11 +21,11 @@ func (_m *MockServerRPCClient) Accept(lisID uint16) (uint16, routing.Addr, error r0 = ret.Get(0).(uint16) } - var r1 routing.Addr - if rf, ok := ret.Get(1).(func(uint16) routing.Addr); ok { + var r1 network.Addr + if rf, ok := ret.Get(1).(func(uint16) network.Addr); ok { r1 = rf(lisID) } else { - r1 = ret.Get(1).(routing.Addr) + r1 = ret.Get(1).(network.Addr) } var r2 error @@ -67,18 +67,18 @@ func (_m *MockServerRPCClient) CloseListener(id uint16) error { } // Dial provides a mock function with given fields: remote -func (_m *MockServerRPCClient) Dial(remote routing.Addr) (uint16, error) { +func (_m *MockServerRPCClient) Dial(remote network.Addr) (uint16, error) { ret := _m.Called(remote) var r0 uint16 - if rf, ok := ret.Get(0).(func(routing.Addr) uint16); ok { + if rf, ok := ret.Get(0).(func(network.Addr) uint16); ok { r0 = rf(remote) } else { r0 = ret.Get(0).(uint16) } var r1 error - if rf, ok := ret.Get(1).(func(routing.Addr) error); ok { + if rf, ok := ret.Get(1).(func(network.Addr) error); ok { r1 = rf(remote) } else { r1 = ret.Error(1) @@ -88,18 +88,18 @@ func (_m *MockServerRPCClient) Dial(remote routing.Addr) (uint16, error) { } // Listen provides a mock function with given fields: local -func (_m *MockServerRPCClient) Listen(local routing.Addr) (uint16, error) { +func (_m *MockServerRPCClient) Listen(local network.Addr) (uint16, error) { ret := _m.Called(local) var r0 uint16 - if rf, ok := ret.Get(0).(func(routing.Addr) uint16); ok { + if rf, ok := ret.Get(0).(func(network.Addr) uint16); ok { r0 = rf(local) } else { r0 = ret.Get(0).(uint16) } var r1 error - if rf, ok := ret.Get(1).(func(routing.Addr) error); ok { + if rf, ok := ret.Get(1).(func(network.Addr) error); ok { r1 = rf(local) } else { r1 = ret.Error(1) diff --git a/pkg/app2/network/dmsg_networker.go b/pkg/app2/network/dmsg_networker.go new file mode 100644 index 0000000000..424b0df466 --- /dev/null +++ b/pkg/app2/network/dmsg_networker.go @@ -0,0 +1,40 @@ +package network + +import ( + "context" + "net" + + "github.com/skycoin/dmsg" +) + +// DMSGNetworker implements `Networker` for dmsg network. +type DMSGNetworker struct { + dmsgC *dmsg.Client +} + +// NewDMSGNetworker constructs new `DMSGNetworker`. +func NewDMSGNetworker(dmsgC *dmsg.Client) Networker { + return &DMSGNetworker{ + dmsgC: dmsgC, + } +} + +// Dial dials remote `addr` via dmsg network. +func (n *DMSGNetworker) Dial(addr Addr) (net.Conn, error) { + return n.DialContext(context.Background(), addr) +} + +// DialContext dials remote `addr` via dmsg network with context. +func (n *DMSGNetworker) DialContext(ctx context.Context, addr Addr) (net.Conn, error) { + return n.dmsgC.Dial(ctx, addr.PubKey, uint16(addr.Port)) +} + +// Listen starts listening on local `addr` in the dmsg network. +func (n *DMSGNetworker) Listen(addr Addr) (net.Listener, error) { + return n.ListenContext(context.Background(), addr) +} + +// ListenContext starts listening on local `addr` in the dmsg network with context. +func (n *DMSGNetworker) ListenContext(ctx context.Context, addr Addr) (net.Listener, error) { + return n.dmsgC.Listen(uint16(addr.Port)) +} diff --git a/pkg/app2/network/networker.go b/pkg/app2/network/networker.go index 34b320f4b7..ed6277c88e 100644 --- a/pkg/app2/network/networker.go +++ b/pkg/app2/network/networker.go @@ -54,13 +54,13 @@ type Networker interface { } // Dial dials the remote `addr` of the specified `network`. -func Dial(t Type, addr Addr) (net.Conn, error) { - return DialContext(context.Background(), t, addr) +func Dial(addr Addr) (net.Conn, error) { + return DialContext(context.Background(), addr) } // DialContext dials the remote `Addr` of the specified `network` with the context. -func DialContext(ctx context.Context, t Type, addr Addr) (net.Conn, error) { - n, err := ResolveNetworker(t) +func DialContext(ctx context.Context, addr Addr) (net.Conn, error) { + n, err := ResolveNetworker(addr.Net) if err != nil { return nil, err } @@ -69,13 +69,13 @@ func DialContext(ctx context.Context, t Type, addr Addr) (net.Conn, error) { } // Listen starts listening on the local `addr` of the specified `network`. -func Listen(t Type, addr Addr) (net.Listener, error) { - return ListenContext(context.Background(), t, addr) +func Listen(addr Addr) (net.Listener, error) { + return ListenContext(context.Background(), addr) } // ListenContext starts listening on the local `addr` of the specified `network` with the context. -func ListenContext(ctx context.Context, t Type, addr Addr) (net.Listener, error) { - networker, err := ResolveNetworker(t) +func ListenContext(ctx context.Context, addr Addr) (net.Listener, error) { + networker, err := ResolveNetworker(addr.Net) if err != nil { return nil, err } diff --git a/pkg/app2/server_rpc.go b/pkg/app2/server_rpc.go index db4372ef1f..bc716e681a 100644 --- a/pkg/app2/server_rpc.go +++ b/pkg/app2/server_rpc.go @@ -1,48 +1,49 @@ package app2 import ( - "context" "fmt" "net" + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/pkg/errors" "github.com/skycoin/dmsg" "github.com/skycoin/skycoin/src/util/logging" - - "github.com/skycoin/skywire/pkg/routing" ) // ServerRPC is a RPC interface for the app server. type ServerRPC struct { - dmsgC *dmsg.Client - lm *manager - cm *manager - log *logging.Logger + lm *manager + cm *manager + log *logging.Logger } // newServerRPC constructs new server RPC interface. -func newServerRPC(log *logging.Logger, dmsgC *dmsg.Client) *ServerRPC { +func newServerRPC(log *logging.Logger) *ServerRPC { return &ServerRPC{ - dmsgC: dmsgC, - lm: newManager(), - cm: newManager(), - log: log, + lm: newManager(), + cm: newManager(), + log: log, } } // Dial dials to the remote. -func (r *ServerRPC) Dial(remote *routing.Addr, connID *uint16) error { +func (r *ServerRPC) Dial(remote *network.Addr, connID *uint16) error { connID, err := r.cm.nextKey() if err != nil { return err } - tp, err := r.dmsgC.Dial(context.TODO(), remote.PubKey, uint16(remote.Port)) + conn, err := network.Dial(*remote) if err != nil { return err } - if err := r.cm.set(*connID, tp); err != nil { + if err := r.cm.set(*connID, conn); err != nil { + if err := conn.Close(); err != nil { + r.log.WithError(err).Error("error closing conn") + } + return err } @@ -50,20 +51,20 @@ func (r *ServerRPC) Dial(remote *routing.Addr, connID *uint16) error { } // Listen starts listening. -func (r *ServerRPC) Listen(local *routing.Addr, lisID *uint16) error { +func (r *ServerRPC) Listen(local *network.Addr, lisID *uint16) error { lisID, err := r.lm.nextKey() if err != nil { return err } - dmsgL, err := r.dmsgC.Listen(uint16(local.Port)) + l, err := network.Listen(*local) if err != nil { return err } - if err := r.lm.set(*lisID, dmsgL); err != nil { - if err := dmsgL.Close(); err != nil { - r.log.WithError(err).Error("error closing DMSG listener") + if err := r.lm.set(*lisID, l); err != nil { + if err := l.Close(); err != nil { + r.log.WithError(err).Error("error closing listener") } return err @@ -74,7 +75,7 @@ func (r *ServerRPC) Listen(local *routing.Addr, lisID *uint16) error { // AcceptResp contains response parameters for `Accept`. type AcceptResp struct { - Remote routing.Addr + Remote network.Addr ConnID uint16 } @@ -90,29 +91,26 @@ func (r *ServerRPC) Accept(lisID *uint16, resp *AcceptResp) error { return err } - tp, err := lis.Accept() + conn, err := lis.Accept() if err != nil { return err } - if err := r.cm.set(*connID, tp); err != nil { - if err := tp.Close(); err != nil { + if err := r.cm.set(*connID, conn); err != nil { + if err := conn.Close(); err != nil { r.log.WithError(err).Error("error closing DMSG transport") } return err } - remote, ok := tp.RemoteAddr().(dmsg.Addr) + remote, ok := conn.RemoteAddr().(network.Addr) if !ok { - return errors.New("wrong type for transport remote addr") + return errors.New("wrong type for remote addr") } resp = &AcceptResp{ - Remote: routing.Addr{ - PubKey: remote.PK, - Port: routing.Port(remote.Port), - }, + Remote: remote, ConnID: *connID, } @@ -163,7 +161,7 @@ func (r *ServerRPC) Read(connID *uint16, resp *ReadResp) error { // CloseConn closes connection specified by `connID`. func (r *ServerRPC) CloseConn(connID *uint16, _ *struct{}) error { - conn, err := r.getAndRemoveConn(*connID) + conn, err := r.popConn(*connID) if err != nil { return err } @@ -173,7 +171,7 @@ func (r *ServerRPC) CloseConn(connID *uint16, _ *struct{}) error { // CloseListener closes listener specified by `lisID`. func (r *ServerRPC) CloseListener(lisID *uint16, _ *struct{}) error { - lis, err := r.getAndRemoveListener(*lisID) + lis, err := r.popListener(*lisID) if err != nil { return err } @@ -181,10 +179,10 @@ func (r *ServerRPC) CloseListener(lisID *uint16, _ *struct{}) error { return lis.Close() } -// getAndRemoveListener gets listener from the manager by `lisID` and removes it. +// popListener gets listener from the manager by `lisID` and removes it. // Handles type assertion. -func (r *ServerRPC) getAndRemoveListener(lisID uint16) (*dmsg.Listener, error) { - lisIfc, err := r.lm.getAndRemove(lisID) +func (r *ServerRPC) popListener(lisID uint16) (*dmsg.Listener, error) { + lisIfc, err := r.lm.pop(lisID) if err != nil { return nil, err } @@ -192,10 +190,10 @@ func (r *ServerRPC) getAndRemoveListener(lisID uint16) (*dmsg.Listener, error) { return r.assertListener(lisIfc) } -// getAndRemoveConn gets conn from the manager by `connID` and removes it. +// popConn gets conn from the manager by `connID` and removes it. // Handles type assertion. -func (r *ServerRPC) getAndRemoveConn(connID uint16) (net.Conn, error) { - connIfc, err := r.cm.getAndRemove(connID) +func (r *ServerRPC) popConn(connID uint16) (net.Conn, error) { + connIfc, err := r.cm.pop(connID) if err != nil { return nil, err } @@ -223,7 +221,7 @@ func (r *ServerRPC) getConn(connID uint16) (net.Conn, error) { return r.assertConn(connIfc) } -// assertListener asserts that `v` is of type `*dmsg.Listener`. +// assertListener asserts that `v` is of type `net.Listener`. func (r *ServerRPC) assertListener(v interface{}) (*dmsg.Listener, error) { lis, ok := v.(*dmsg.Listener) if !ok { diff --git a/pkg/app2/server_rpc_client.go b/pkg/app2/server_rpc_client.go index 71c34b99e9..d3dde3056b 100644 --- a/pkg/app2/server_rpc_client.go +++ b/pkg/app2/server_rpc_client.go @@ -3,16 +3,16 @@ package app2 import ( "net/rpc" - "github.com/skycoin/skywire/pkg/routing" + "github.com/skycoin/skywire/pkg/app2/network" ) //go:generate mockery -name ServerRPCClient -case underscore -inpkg // ServerRPCClient describes RPC interface to communicate with the server. type ServerRPCClient interface { - Dial(remote routing.Addr) (uint16, error) - Listen(local routing.Addr) (uint16, error) - Accept(lisID uint16) (uint16, routing.Addr, error) + Dial(remote network.Addr) (uint16, error) + Listen(local network.Addr) (uint16, error) + Accept(lisID uint16) (uint16, network.Addr, error) Write(connID uint16, b []byte) (int, error) Read(connID uint16, b []byte) (int, []byte, error) CloseConn(id uint16) error @@ -32,7 +32,7 @@ func NewServerRPCClient(rpc *rpc.Client) ServerRPCClient { } // Dial sends `Dial` command to the server. -func (c *serverRPCCLient) Dial(remote routing.Addr) (uint16, error) { +func (c *serverRPCCLient) Dial(remote network.Addr) (uint16, error) { var connID uint16 if err := c.rpc.Call("Dial", &remote, &connID); err != nil { return 0, err @@ -42,7 +42,7 @@ func (c *serverRPCCLient) Dial(remote routing.Addr) (uint16, error) { } // Listen sends `Listen` command to the server. -func (c *serverRPCCLient) Listen(local routing.Addr) (uint16, error) { +func (c *serverRPCCLient) Listen(local network.Addr) (uint16, error) { var lisID uint16 if err := c.rpc.Call("Listen", &local, &lisID); err != nil { return 0, err @@ -52,10 +52,10 @@ func (c *serverRPCCLient) Listen(local routing.Addr) (uint16, error) { } // Accept sends `Accept` command to the server. -func (c *serverRPCCLient) Accept(lisID uint16) (uint16, routing.Addr, error) { +func (c *serverRPCCLient) Accept(lisID uint16) (uint16, network.Addr, error) { var acceptResp AcceptResp if err := c.rpc.Call("Accept", &lisID, &acceptResp); err != nil { - return 0, routing.Addr{}, err + return 0, network.Addr{}, err } return acceptResp.ConnID, acceptResp.Remote, nil From eb42bca02d0a9485789d4a8e977a8d07059807e2 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sat, 21 Sep 2019 20:36:35 +0300 Subject: [PATCH 29/43] Refactor a bit --- go.mod | 1 + go.sum | 3 ++ pkg/app2/client.go | 11 +++-- pkg/app2/client_test.go | 40 ++++++++--------- pkg/app2/conn.go | 2 +- pkg/app2/conn_test.go | 6 +-- pkg/app2/listener.go | 2 +- pkg/app2/listener_test.go | 20 +++++---- ...erver_rpc_client.go => mock_rpc_client.go} | 24 ++++++----- pkg/app2/network/addr.go | 3 +- pkg/app2/network/networker.go | 8 ++-- .../{server_rpc_client.go => rpc_client.go} | 30 ++++++------- pkg/app2/{server_rpc.go => rpc_gateway.go} | 43 +++++++++---------- vendor/github.com/skycoin/dmsg/go.mod | 4 +- vendor/github.com/skycoin/dmsg/go.sum | 6 --- 15 files changed, 101 insertions(+), 102 deletions(-) rename pkg/app2/{mock_server_rpc_client.go => mock_rpc_client.go} (77%) rename pkg/app2/{server_rpc_client.go => rpc_client.go} (64%) rename pkg/app2/{server_rpc.go => rpc_gateway.go} (77%) diff --git a/go.mod b/go.mod index c0cfa4c314..ec315f5700 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/skycoin/skycoin v0.26.0 github.com/spf13/cobra v0.0.5 github.com/stretchr/testify v1.4.0 + github.com/vektra/mockery v0.0.0-20181123154057-e78b021dcbb5 // indirect go.etcd.io/bbolt v1.3.3 golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7 golang.org/x/net v0.0.0-20190916140828-c8589233b77d diff --git a/go.sum b/go.sum index b9ca83793d..ba051af7ca 100644 --- a/go.sum +++ b/go.sum @@ -130,6 +130,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= +github.com/vektra/mockery v0.0.0-20181123154057-e78b021dcbb5 h1:Xim2mBRFdXzXmKRO8DJg/FJtn/8Fj9NOEpO6+WuMPmk= +github.com/vektra/mockery v0.0.0-20181123154057-e78b021dcbb5/go.mod h1:ppEjwdhyy7Y31EnHRDm1JkChoC7LXIJ7Ex0VYLWtZtQ= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk= @@ -165,6 +167,7 @@ golang.org/x/sys v0.0.0-20190801041406-cbf593c0f2f3/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181112210238-4b1f3b6b1646/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190627182818-9947fec5c3ab/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 743fadc433..bd9e203905 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -4,10 +4,10 @@ import ( "context" "net" - "github.com/skycoin/skywire/pkg/app2/network" - "github.com/skycoin/dmsg/cipher" "github.com/skycoin/dmsg/netutil" + + "github.com/skycoin/skywire/pkg/app2/network" "github.com/skycoin/skywire/pkg/routing" ) @@ -15,21 +15,20 @@ import ( type Client struct { pk cipher.PubKey pid ProcID - rpc ServerRPCClient + rpc RPCClient porter *netutil.Porter } // NewClient creates a new `Client`. The `Client` needs to be provided with: -// - log: Logger instance. // - localPK: The local public key of the parent skywire visor. // - pid: The procID assigned for the process that Client is being used by. // - rpc: RPC client to communicate with the server. -func NewClient(localPK cipher.PubKey, pid ProcID, rpc ServerRPCClient, porter *netutil.Porter) *Client { +func NewClient(localPK cipher.PubKey, pid ProcID, rpc RPCClient) *Client { return &Client{ pk: localPK, pid: pid, rpc: rpc, - porter: porter, + porter: netutil.NewPorter(netutil.PorterMinEphemeral), } } diff --git a/pkg/app2/client_test.go b/pkg/app2/client_test.go index 1c49d5c14f..6bc19e2a8d 100644 --- a/pkg/app2/client_test.go +++ b/pkg/app2/client_test.go @@ -3,16 +3,12 @@ package app2 import ( "testing" - "github.com/skycoin/skywire/pkg/app2/network" - "github.com/pkg/errors" + "github.com/skycoin/dmsg/cipher" "github.com/stretchr/testify/require" + "github.com/skycoin/skywire/pkg/app2/network" "github.com/skycoin/skywire/pkg/routing" - - "github.com/skycoin/dmsg/cipher" - - "github.com/skycoin/dmsg/netutil" ) func TestClient_Dial(t *testing.T) { @@ -30,10 +26,10 @@ func TestClient_Dial(t *testing.T) { dialConnID := uint16(1) var dialErr error - rpc := &MockServerRPCClient{} + rpc := &MockRPCClient{} rpc.On("Dial", remote).Return(dialConnID, dialErr) - cl := NewClient(localPK, pid, rpc, netutil.NewPorter(netutil.PorterMinEphemeral)) + cl := NewClient(localPK, pid, rpc) wantConn := &Conn{ id: dialConnID, @@ -62,10 +58,10 @@ func TestClient_Dial(t *testing.T) { t.Run("dial error", func(t *testing.T) { dialErr := errors.New("dial error") - rpc := &MockServerRPCClient{} + rpc := &MockRPCClient{} rpc.On("Dial", remote).Return(uint16(0), dialErr) - cl := NewClient(localPK, pid, rpc, netutil.NewPorter(netutil.PorterMinEphemeral)) + cl := NewClient(localPK, pid, rpc) conn, err := cl.Dial(remote) require.Equal(t, dialErr, err) @@ -79,6 +75,7 @@ func TestClient_Listen(t *testing.T) { port := routing.Port(1) local := network.Addr{ + Net: network.TypeDMSG, PubKey: localPK, Port: port, } @@ -87,10 +84,10 @@ func TestClient_Listen(t *testing.T) { listenLisID := uint16(1) var listenErr error - rpc := &MockServerRPCClient{} + rpc := &MockRPCClient{} rpc.On("Listen", local).Return(listenLisID, listenErr) - cl := NewClient(localPK, pid, rpc, netutil.NewPorter(netutil.PorterMinEphemeral)) + cl := NewClient(localPK, pid, rpc) wantListener := &Listener{ id: listenLisID, @@ -98,7 +95,7 @@ func TestClient_Listen(t *testing.T) { addr: local, } - listener, err := cl.Listen(port) + listener, err := cl.Listen(network.TypeDMSG, port) require.Nil(t, err) appListener, ok := listener.(*Listener) require.True(t, ok) @@ -112,17 +109,16 @@ func TestClient_Listen(t *testing.T) { }) t.Run("port is already bound", func(t *testing.T) { - porter := netutil.NewPorter(netutil.PorterMinEphemeral) - ok, _ := porter.Reserve(uint16(port), nil) - require.True(t, ok) + rpc := &MockRPCClient{} - rpc := &MockServerRPCClient{} + cl := NewClient(localPK, pid, rpc) - cl := NewClient(localPK, pid, rpc, porter) + ok, _ := cl.porter.Reserve(uint16(port), nil) + require.True(t, ok) wantErr := ErrPortAlreadyBound - listener, err := cl.Listen(port) + listener, err := cl.Listen(network.TypeDMSG, port) require.Equal(t, wantErr, err) require.Nil(t, listener) }) @@ -130,12 +126,12 @@ func TestClient_Listen(t *testing.T) { t.Run("listen error", func(t *testing.T) { listenErr := errors.New("listen error") - rpc := &MockServerRPCClient{} + rpc := &MockRPCClient{} rpc.On("Listen", local).Return(uint16(0), listenErr) - cl := NewClient(localPK, pid, rpc, netutil.NewPorter(netutil.PorterMinEphemeral)) + cl := NewClient(localPK, pid, rpc) - listener, err := cl.Listen(port) + listener, err := cl.Listen(network.TypeDMSG, port) require.Equal(t, listenErr, err) require.Nil(t, listener) _, ok := cl.porter.PortValue(uint16(port)) diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index 632fe0d464..d201f22a62 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -11,7 +11,7 @@ import ( // Implements `net.Conn`. type Conn struct { id uint16 - rpc ServerRPCClient + rpc RPCClient local network.Addr remote network.Addr freeLocalPort func() diff --git a/pkg/app2/conn_test.go b/pkg/app2/conn_test.go index d25de2716b..ef860a6983 100644 --- a/pkg/app2/conn_test.go +++ b/pkg/app2/conn_test.go @@ -35,7 +35,7 @@ func TestConn_Read(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - rpc := &MockServerRPCClient{} + rpc := &MockRPCClient{} rpc.On("Read", connID, tc.readBuff).Return(tc.readN, tc.readBytes, tc.readErr) conn := &Conn{ @@ -74,7 +74,7 @@ func TestConn_Write(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - rpc := &MockServerRPCClient{} + rpc := &MockRPCClient{} rpc.On("Write", connID, tc.writeBuff).Return(tc.writeN, tc.writeErr) conn := &Conn{ @@ -107,7 +107,7 @@ func TestConn_Close(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - rpc := &MockServerRPCClient{} + rpc := &MockRPCClient{} rpc.On("CloseConn", connID).Return(tc.closeErr) conn := &Conn{ diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index bb38364a7d..56e52160e3 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -10,7 +10,7 @@ import ( // Implements `net.Listener`. type Listener struct { id uint16 - rpc ServerRPCClient + rpc RPCClient addr network.Addr freePort func() } diff --git a/pkg/app2/listener_test.go b/pkg/app2/listener_test.go index e43b384f86..058b070657 100644 --- a/pkg/app2/listener_test.go +++ b/pkg/app2/listener_test.go @@ -4,16 +4,18 @@ import ( "testing" "github.com/pkg/errors" + "github.com/skycoin/dmsg/cipher" "github.com/stretchr/testify/require" - "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skywire/pkg/app2/network" "github.com/skycoin/skywire/pkg/routing" ) func TestListener_Accept(t *testing.T) { lisID := uint16(1) localPK, _ := cipher.GenerateKeyPair() - local := routing.Addr{ + local := network.Addr{ + Net: network.TypeDMSG, PubKey: localPK, Port: routing.Port(100), } @@ -21,13 +23,14 @@ func TestListener_Accept(t *testing.T) { t.Run("ok", func(t *testing.T) { acceptConnID := uint16(1) remotePK, _ := cipher.GenerateKeyPair() - acceptRemote := routing.Addr{ + acceptRemote := network.Addr{ + Net: network.TypeDMSG, PubKey: remotePK, Port: routing.Port(100), } var acceptErr error - rpc := &MockServerRPCClient{} + rpc := &MockRPCClient{} rpc.On("Accept", acceptConnID).Return(acceptConnID, acceptRemote, acceptErr) lis := &Listener{ @@ -50,10 +53,10 @@ func TestListener_Accept(t *testing.T) { t.Run("accept error", func(t *testing.T) { acceptConnID := uint16(0) - acceptRemote := routing.Addr{} + acceptRemote := network.Addr{} acceptErr := errors.New("accept error") - rpc := &MockServerRPCClient{} + rpc := &MockRPCClient{} rpc.On("Accept", lisID).Return(acceptConnID, acceptRemote, acceptErr) lis := &Listener{ @@ -71,7 +74,8 @@ func TestListener_Accept(t *testing.T) { func TestListener_Close(t *testing.T) { lisID := uint16(1) localPK, _ := cipher.GenerateKeyPair() - local := routing.Addr{ + local := network.Addr{ + Net: network.TypeDMSG, PubKey: localPK, Port: routing.Port(100), } @@ -91,7 +95,7 @@ func TestListener_Close(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - rpc := &MockServerRPCClient{} + rpc := &MockRPCClient{} rpc.On("CloseListener", lisID).Return(tc.closeErr) lis := &Listener{ diff --git a/pkg/app2/mock_server_rpc_client.go b/pkg/app2/mock_rpc_client.go similarity index 77% rename from pkg/app2/mock_server_rpc_client.go rename to pkg/app2/mock_rpc_client.go index 58e88cab76..af8f62fb6c 100644 --- a/pkg/app2/mock_server_rpc_client.go +++ b/pkg/app2/mock_rpc_client.go @@ -2,16 +2,18 @@ package app2 -import mock "github.com/stretchr/testify/mock" -import network "github.com/skycoin/skywire/pkg/app2/network" +import ( + network "github.com/skycoin/skywire/pkg/app2/network" + mock "github.com/stretchr/testify/mock" +) -// MockServerRPCClient is an autogenerated mock type for the ServerRPCClient type -type MockServerRPCClient struct { +// MockRPCClient is an autogenerated mock type for the RPCClient type +type MockRPCClient struct { mock.Mock } // Accept provides a mock function with given fields: lisID -func (_m *MockServerRPCClient) Accept(lisID uint16) (uint16, network.Addr, error) { +func (_m *MockRPCClient) Accept(lisID uint16) (uint16, network.Addr, error) { ret := _m.Called(lisID) var r0 uint16 @@ -39,7 +41,7 @@ func (_m *MockServerRPCClient) Accept(lisID uint16) (uint16, network.Addr, error } // CloseConn provides a mock function with given fields: id -func (_m *MockServerRPCClient) CloseConn(id uint16) error { +func (_m *MockRPCClient) CloseConn(id uint16) error { ret := _m.Called(id) var r0 error @@ -53,7 +55,7 @@ func (_m *MockServerRPCClient) CloseConn(id uint16) error { } // CloseListener provides a mock function with given fields: id -func (_m *MockServerRPCClient) CloseListener(id uint16) error { +func (_m *MockRPCClient) CloseListener(id uint16) error { ret := _m.Called(id) var r0 error @@ -67,7 +69,7 @@ func (_m *MockServerRPCClient) CloseListener(id uint16) error { } // Dial provides a mock function with given fields: remote -func (_m *MockServerRPCClient) Dial(remote network.Addr) (uint16, error) { +func (_m *MockRPCClient) Dial(remote network.Addr) (uint16, error) { ret := _m.Called(remote) var r0 uint16 @@ -88,7 +90,7 @@ func (_m *MockServerRPCClient) Dial(remote network.Addr) (uint16, error) { } // Listen provides a mock function with given fields: local -func (_m *MockServerRPCClient) Listen(local network.Addr) (uint16, error) { +func (_m *MockRPCClient) Listen(local network.Addr) (uint16, error) { ret := _m.Called(local) var r0 uint16 @@ -109,7 +111,7 @@ func (_m *MockServerRPCClient) Listen(local network.Addr) (uint16, error) { } // Read provides a mock function with given fields: connID, b -func (_m *MockServerRPCClient) Read(connID uint16, b []byte) (int, []byte, error) { +func (_m *MockRPCClient) Read(connID uint16, b []byte) (int, []byte, error) { ret := _m.Called(connID, b) var r0 int @@ -139,7 +141,7 @@ func (_m *MockServerRPCClient) Read(connID uint16, b []byte) (int, []byte, error } // Write provides a mock function with given fields: connID, b -func (_m *MockServerRPCClient) Write(connID uint16, b []byte) (int, error) { +func (_m *MockRPCClient) Write(connID uint16, b []byte) (int, error) { ret := _m.Called(connID, b) var r0 int diff --git a/pkg/app2/network/addr.go b/pkg/app2/network/addr.go index 36fec8306d..2db24213e5 100644 --- a/pkg/app2/network/addr.go +++ b/pkg/app2/network/addr.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skywire/pkg/routing" ) @@ -14,7 +15,7 @@ type Addr struct { Port routing.Port } -// Network returns "dmsg" +// Network returns network type. func (a Addr) Network() string { return string(a.Net) } diff --git a/pkg/app2/network/networker.go b/pkg/app2/network/networker.go index ed6277c88e..708b707e78 100644 --- a/pkg/app2/network/networker.go +++ b/pkg/app2/network/networker.go @@ -53,12 +53,12 @@ type Networker interface { ListenContext(ctx context.Context, addr Addr) (net.Listener, error) } -// Dial dials the remote `addr` of the specified `network`. +// Dial dials the remote `addr`. func Dial(addr Addr) (net.Conn, error) { return DialContext(context.Background(), addr) } -// DialContext dials the remote `Addr` of the specified `network` with the context. +// DialContext dials the remote `addr` with the context. func DialContext(ctx context.Context, addr Addr) (net.Conn, error) { n, err := ResolveNetworker(addr.Net) if err != nil { @@ -68,12 +68,12 @@ func DialContext(ctx context.Context, addr Addr) (net.Conn, error) { return n.DialContext(ctx, addr) } -// Listen starts listening on the local `addr` of the specified `network`. +// Listen starts listening on the local `addr`. func Listen(addr Addr) (net.Listener, error) { return ListenContext(context.Background(), addr) } -// ListenContext starts listening on the local `addr` of the specified `network` with the context. +// ListenContext starts listening on the local `addr` with the context. func ListenContext(ctx context.Context, addr Addr) (net.Listener, error) { networker, err := ResolveNetworker(addr.Net) if err != nil { diff --git a/pkg/app2/server_rpc_client.go b/pkg/app2/rpc_client.go similarity index 64% rename from pkg/app2/server_rpc_client.go rename to pkg/app2/rpc_client.go index d3dde3056b..1ed5b93b4d 100644 --- a/pkg/app2/server_rpc_client.go +++ b/pkg/app2/rpc_client.go @@ -6,10 +6,10 @@ import ( "github.com/skycoin/skywire/pkg/app2/network" ) -//go:generate mockery -name ServerRPCClient -case underscore -inpkg +//go:generate mockery -name RPCClient -case underscore -inpkg -// ServerRPCClient describes RPC interface to communicate with the server. -type ServerRPCClient interface { +// RPCClient describes RPC interface to communicate with the server. +type RPCClient interface { Dial(remote network.Addr) (uint16, error) Listen(local network.Addr) (uint16, error) Accept(lisID uint16) (uint16, network.Addr, error) @@ -19,20 +19,20 @@ type ServerRPCClient interface { CloseListener(id uint16) error } -// serverRPCClient implements `ServerRPCClient`. -type serverRPCCLient struct { +// rpcClient implements `RPCClient`. +type rpcCLient struct { rpc *rpc.Client } -// NewServerRPCClient constructs new `serverRPCClient`. -func NewServerRPCClient(rpc *rpc.Client) ServerRPCClient { - return &serverRPCCLient{ +// NewRPCClient constructs new `rpcClient`. +func NewRPCClient(rpc *rpc.Client) RPCClient { + return &rpcCLient{ rpc: rpc, } } // Dial sends `Dial` command to the server. -func (c *serverRPCCLient) Dial(remote network.Addr) (uint16, error) { +func (c *rpcCLient) Dial(remote network.Addr) (uint16, error) { var connID uint16 if err := c.rpc.Call("Dial", &remote, &connID); err != nil { return 0, err @@ -42,7 +42,7 @@ func (c *serverRPCCLient) Dial(remote network.Addr) (uint16, error) { } // Listen sends `Listen` command to the server. -func (c *serverRPCCLient) Listen(local network.Addr) (uint16, error) { +func (c *rpcCLient) Listen(local network.Addr) (uint16, error) { var lisID uint16 if err := c.rpc.Call("Listen", &local, &lisID); err != nil { return 0, err @@ -52,7 +52,7 @@ func (c *serverRPCCLient) Listen(local network.Addr) (uint16, error) { } // Accept sends `Accept` command to the server. -func (c *serverRPCCLient) Accept(lisID uint16) (uint16, network.Addr, error) { +func (c *rpcCLient) Accept(lisID uint16) (uint16, network.Addr, error) { var acceptResp AcceptResp if err := c.rpc.Call("Accept", &lisID, &acceptResp); err != nil { return 0, network.Addr{}, err @@ -62,7 +62,7 @@ func (c *serverRPCCLient) Accept(lisID uint16) (uint16, network.Addr, error) { } // Write sends `Write` command to the server. -func (c *serverRPCCLient) Write(connID uint16, b []byte) (int, error) { +func (c *rpcCLient) Write(connID uint16, b []byte) (int, error) { req := WriteReq{ ConnID: connID, B: b, @@ -77,7 +77,7 @@ func (c *serverRPCCLient) Write(connID uint16, b []byte) (int, error) { } // Read sends `Read` command to the server. -func (c *serverRPCCLient) Read(connID uint16, b []byte) (int, []byte, error) { +func (c *rpcCLient) Read(connID uint16, b []byte) (int, []byte, error) { var resp ReadResp if err := c.rpc.Call("Read", &connID, &resp); err != nil { return 0, nil, err @@ -87,11 +87,11 @@ func (c *serverRPCCLient) Read(connID uint16, b []byte) (int, []byte, error) { } // CloseConn sends `CloseConn` command to the server. -func (c *serverRPCCLient) CloseConn(id uint16) error { +func (c *rpcCLient) CloseConn(id uint16) error { return c.rpc.Call("CloseConn", &id, nil) } // CloseListener sends `CloseListener` command to the server. -func (c *serverRPCCLient) CloseListener(id uint16) error { +func (c *rpcCLient) CloseListener(id uint16) error { return c.rpc.Call("CloseListener", &id, nil) } diff --git a/pkg/app2/server_rpc.go b/pkg/app2/rpc_gateway.go similarity index 77% rename from pkg/app2/server_rpc.go rename to pkg/app2/rpc_gateway.go index bc716e681a..e5e9c7ec0b 100644 --- a/pkg/app2/server_rpc.go +++ b/pkg/app2/rpc_gateway.go @@ -4,23 +4,22 @@ import ( "fmt" "net" - "github.com/skycoin/skywire/pkg/app2/network" - "github.com/pkg/errors" - "github.com/skycoin/dmsg" "github.com/skycoin/skycoin/src/util/logging" + + "github.com/skycoin/skywire/pkg/app2/network" ) -// ServerRPC is a RPC interface for the app server. -type ServerRPC struct { +// RPCGateway is a RPC interface for the app server. +type RPCGateway struct { lm *manager cm *manager log *logging.Logger } -// newServerRPC constructs new server RPC interface. -func newServerRPC(log *logging.Logger) *ServerRPC { - return &ServerRPC{ +// newRPCGateway constructs new server RPC interface. +func newRPCGateway(log *logging.Logger) *RPCGateway { + return &RPCGateway{ lm: newManager(), cm: newManager(), log: log, @@ -28,7 +27,7 @@ func newServerRPC(log *logging.Logger) *ServerRPC { } // Dial dials to the remote. -func (r *ServerRPC) Dial(remote *network.Addr, connID *uint16) error { +func (r *RPCGateway) Dial(remote *network.Addr, connID *uint16) error { connID, err := r.cm.nextKey() if err != nil { return err @@ -51,7 +50,7 @@ func (r *ServerRPC) Dial(remote *network.Addr, connID *uint16) error { } // Listen starts listening. -func (r *ServerRPC) Listen(local *network.Addr, lisID *uint16) error { +func (r *RPCGateway) Listen(local *network.Addr, lisID *uint16) error { lisID, err := r.lm.nextKey() if err != nil { return err @@ -80,7 +79,7 @@ type AcceptResp struct { } // Accept accepts connection from the listener specified by `lisID`. -func (r *ServerRPC) Accept(lisID *uint16, resp *AcceptResp) error { +func (r *RPCGateway) Accept(lisID *uint16, resp *AcceptResp) error { lis, err := r.getListener(*lisID) if err != nil { return err @@ -124,7 +123,7 @@ type WriteReq struct { } // Write writes to the connection. -func (r *ServerRPC) Write(req *WriteReq, n *int) error { +func (r *RPCGateway) Write(req *WriteReq, n *int) error { conn, err := r.getConn(req.ConnID) if err != nil { return err @@ -145,7 +144,7 @@ type ReadResp struct { } // Read reads data from connection specified by `connID`. -func (r *ServerRPC) Read(connID *uint16, resp *ReadResp) error { +func (r *RPCGateway) Read(connID *uint16, resp *ReadResp) error { conn, err := r.getConn(*connID) if err != nil { return err @@ -160,7 +159,7 @@ func (r *ServerRPC) Read(connID *uint16, resp *ReadResp) error { } // CloseConn closes connection specified by `connID`. -func (r *ServerRPC) CloseConn(connID *uint16, _ *struct{}) error { +func (r *RPCGateway) CloseConn(connID *uint16, _ *struct{}) error { conn, err := r.popConn(*connID) if err != nil { return err @@ -170,7 +169,7 @@ func (r *ServerRPC) CloseConn(connID *uint16, _ *struct{}) error { } // CloseListener closes listener specified by `lisID`. -func (r *ServerRPC) CloseListener(lisID *uint16, _ *struct{}) error { +func (r *RPCGateway) CloseListener(lisID *uint16, _ *struct{}) error { lis, err := r.popListener(*lisID) if err != nil { return err @@ -181,7 +180,7 @@ func (r *ServerRPC) CloseListener(lisID *uint16, _ *struct{}) error { // popListener gets listener from the manager by `lisID` and removes it. // Handles type assertion. -func (r *ServerRPC) popListener(lisID uint16) (*dmsg.Listener, error) { +func (r *RPCGateway) popListener(lisID uint16) (net.Listener, error) { lisIfc, err := r.lm.pop(lisID) if err != nil { return nil, err @@ -192,7 +191,7 @@ func (r *ServerRPC) popListener(lisID uint16) (*dmsg.Listener, error) { // popConn gets conn from the manager by `connID` and removes it. // Handles type assertion. -func (r *ServerRPC) popConn(connID uint16) (net.Conn, error) { +func (r *RPCGateway) popConn(connID uint16) (net.Conn, error) { connIfc, err := r.cm.pop(connID) if err != nil { return nil, err @@ -202,7 +201,7 @@ func (r *ServerRPC) popConn(connID uint16) (net.Conn, error) { } // getListener gets listener from the manager by `lisID`. Handles type assertion. -func (r *ServerRPC) getListener(lisID uint16) (*dmsg.Listener, error) { +func (r *RPCGateway) getListener(lisID uint16) (net.Listener, error) { lisIfc, ok := r.lm.get(lisID) if !ok { return nil, fmt.Errorf("no listener with key %d", lisID) @@ -212,7 +211,7 @@ func (r *ServerRPC) getListener(lisID uint16) (*dmsg.Listener, error) { } // getConn gets conn from the manager by `connID`. Handles type assertion. -func (r *ServerRPC) getConn(connID uint16) (net.Conn, error) { +func (r *RPCGateway) getConn(connID uint16) (net.Conn, error) { connIfc, ok := r.cm.get(connID) if !ok { return nil, fmt.Errorf("no conn with key %d", connID) @@ -222,8 +221,8 @@ func (r *ServerRPC) getConn(connID uint16) (net.Conn, error) { } // assertListener asserts that `v` is of type `net.Listener`. -func (r *ServerRPC) assertListener(v interface{}) (*dmsg.Listener, error) { - lis, ok := v.(*dmsg.Listener) +func (r *RPCGateway) assertListener(v interface{}) (net.Listener, error) { + lis, ok := v.(net.Listener) if !ok { return nil, errors.New("wrong type of value stored for listener") } @@ -232,7 +231,7 @@ func (r *ServerRPC) assertListener(v interface{}) (*dmsg.Listener, error) { } // assertConn asserts that `v` is of type `net.Conn`. -func (r *ServerRPC) assertConn(v interface{}) (net.Conn, error) { +func (r *RPCGateway) assertConn(v interface{}) (net.Conn, error) { conn, ok := v.(net.Conn) if !ok { return nil, errors.New("wrong type of value stored for conn") diff --git a/vendor/github.com/skycoin/dmsg/go.mod b/vendor/github.com/skycoin/dmsg/go.mod index a24455c1f5..1ef2c47f66 100644 --- a/vendor/github.com/skycoin/dmsg/go.mod +++ b/vendor/github.com/skycoin/dmsg/go.mod @@ -14,7 +14,7 @@ require ( golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4 // indirect golang.org/x/net v0.0.0-20190620200207-3b0461eec859 golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb // indirect - golang.org/x/text v0.3.2 // indirect - golang.org/x/tools v0.0.0-20190627182818-9947fec5c3ab // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) + +replace github.com/skycoin/dmsg => ../dmsg diff --git a/vendor/github.com/skycoin/dmsg/go.sum b/vendor/github.com/skycoin/dmsg/go.sum index c6a730a9e6..624818fed7 100644 --- a/vendor/github.com/skycoin/dmsg/go.sum +++ b/vendor/github.com/skycoin/dmsg/go.sum @@ -32,11 +32,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4 h1:ydJNl0ENAG67pFbB+9tfhiL2pYqLhfoaZFw/cjLhY4A= golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -45,9 +43,5 @@ golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb h1:fgwFCsaw9buMuxNd6+DQfAuSF golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190627182818-9947fec5c3ab h1:uOzhX2fm3C4BmBwW2a7lnJQD7qel2+4uhmTc8czKBCU= -golang.org/x/tools v0.0.0-20190627182818-9947fec5c3ab/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 0714e553a37123ba6780ced865e6fd9b9115945b Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sat, 21 Sep 2019 22:12:30 +0300 Subject: [PATCH 30/43] Add more tests --- pkg/app2/network/mock_networker.go | 107 +++++++++++++++++++++++++ pkg/app2/network/networker.go | 6 +- pkg/app2/network/networker_test.go | 123 +++++++++++++++++++++++++++++ pkg/app2/network/type_test.go | 32 ++++++++ 4 files changed, 266 insertions(+), 2 deletions(-) create mode 100644 pkg/app2/network/mock_networker.go create mode 100644 pkg/app2/network/networker_test.go create mode 100644 pkg/app2/network/type_test.go diff --git a/pkg/app2/network/mock_networker.go b/pkg/app2/network/mock_networker.go new file mode 100644 index 0000000000..560de13d97 --- /dev/null +++ b/pkg/app2/network/mock_networker.go @@ -0,0 +1,107 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package network + +import ( + context "context" + net "net" + + mock "github.com/stretchr/testify/mock" +) + +// MockNetworker is an autogenerated mock type for the Networker type +type MockNetworker struct { + mock.Mock +} + +// Dial provides a mock function with given fields: addr +func (_m *MockNetworker) Dial(addr Addr) (net.Conn, error) { + ret := _m.Called(addr) + + var r0 net.Conn + if rf, ok := ret.Get(0).(func(Addr) net.Conn); ok { + r0 = rf(addr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Conn) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(Addr) error); ok { + r1 = rf(addr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DialContext provides a mock function with given fields: ctx, addr +func (_m *MockNetworker) DialContext(ctx context.Context, addr Addr) (net.Conn, error) { + ret := _m.Called(ctx, addr) + + var r0 net.Conn + if rf, ok := ret.Get(0).(func(context.Context, Addr) net.Conn); ok { + r0 = rf(ctx, addr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Conn) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, Addr) error); ok { + r1 = rf(ctx, addr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Listen provides a mock function with given fields: addr +func (_m *MockNetworker) Listen(addr Addr) (net.Listener, error) { + ret := _m.Called(addr) + + var r0 net.Listener + if rf, ok := ret.Get(0).(func(Addr) net.Listener); ok { + r0 = rf(addr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Listener) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(Addr) error); ok { + r1 = rf(addr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ListenContext provides a mock function with given fields: ctx, addr +func (_m *MockNetworker) ListenContext(ctx context.Context, addr Addr) (net.Listener, error) { + ret := _m.Called(ctx, addr) + + var r0 net.Listener + if rf, ok := ret.Get(0).(func(context.Context, Addr) net.Listener); ok { + r0 = rf(ctx, addr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Listener) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, Addr) error); ok { + r1 = rf(ctx, addr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/app2/network/networker.go b/pkg/app2/network/networker.go index 708b707e78..e25c05b137 100644 --- a/pkg/app2/network/networker.go +++ b/pkg/app2/network/networker.go @@ -7,6 +7,8 @@ import ( "sync" ) +//go:generate mockery -name Networker -case underscore -inpkg + var ( // ErrNoSuchNetworker is being returned when there's no suitable networker. ErrNoSuchNetworker = errors.New("no such networker") @@ -15,11 +17,11 @@ var ( ) var ( - networkers = map[Type]Networker{} + networkers = make(map[Type]Networker) networkersMx sync.RWMutex ) -// AddNetworker associated Networker with the `network`. +// AddNetworker associates Networker with the `network`. func AddNetworker(t Type, n Networker) error { networkersMx.Lock() defer networkersMx.Unlock() diff --git a/pkg/app2/network/networker_test.go b/pkg/app2/network/networker_test.go new file mode 100644 index 0000000000..526801daf9 --- /dev/null +++ b/pkg/app2/network/networker_test.go @@ -0,0 +1,123 @@ +package network + +import ( + "context" + "net" + "testing" + + "github.com/skycoin/skywire/pkg/routing" + + "github.com/skycoin/dmsg/cipher" + + "github.com/stretchr/testify/require" +) + +func TestAddNetworker(t *testing.T) { + clearNetworkers() + + nType := Type(TypeDMSG) + var n Networker + + err := AddNetworker(nType, n) + require.NoError(t, err) + + err = AddNetworker(nType, n) + require.Equal(t, err, ErrNetworkerAlreadyExists) +} + +func TestResolveNetworker(t *testing.T) { + clearNetworkers() + + nType := Type(TypeDMSG) + var n Networker + + n, err := ResolveNetworker(nType) + require.Equal(t, err, ErrNoSuchNetworker) + + err = AddNetworker(nType, n) + require.NoError(t, err) + + gotN, err := ResolveNetworker(nType) + require.NoError(t, err) + require.Equal(t, gotN, n) +} + +func TestDial(t *testing.T) { + addr := prepAddr() + + t.Run("no such networker", func(t *testing.T) { + clearNetworkers() + + _, err := Dial(addr) + require.Equal(t, err, ErrNoSuchNetworker) + }) + + t.Run("ok", func(t *testing.T) { + clearNetworkers() + + dialCtx := context.Background() + var ( + dialConn net.Conn + dialErr error + ) + + n := &MockNetworker{} + n.On("DialContext", dialCtx, addr).Return(dialConn, dialErr) + + err := AddNetworker(addr.Net, n) + require.NoError(t, err) + + conn, err := Dial(addr) + require.NoError(t, err) + require.Equal(t, conn, dialConn) + }) +} + +func TestListen(t *testing.T) { + addr := prepAddr() + + t.Run("no such networker", func(t *testing.T) { + clearNetworkers() + + _, err := Listen(addr) + require.Equal(t, err, ErrNoSuchNetworker) + }) + + t.Run("ok", func(t *testing.T) { + clearNetworkers() + + listenCtx := context.Background() + var ( + listenLis net.Listener + listenErr error + ) + + n := &MockNetworker{} + n.On("ListenContext", listenCtx, addr).Return(listenLis, listenErr) + + err := AddNetworker(addr.Net, n) + require.NoError(t, err) + + lis, err := Listen(addr) + require.NoError(t, err) + require.Equal(t, lis, listenLis) + }) +} + +func prepAddr() Addr { + addrPK, _ := cipher.GenerateKeyPair() + addrPort := routing.Port(100) + + return Addr{ + Net: TypeDMSG, + PubKey: addrPK, + Port: addrPort, + } +} + +func clearNetworkers() { + networkersMx.Lock() + defer networkersMx.Unlock() + + networkers = make(map[Type]Networker) +} diff --git a/pkg/app2/network/type_test.go b/pkg/app2/network/type_test.go new file mode 100644 index 0000000000..632ade9c7d --- /dev/null +++ b/pkg/app2/network/type_test.go @@ -0,0 +1,32 @@ +package network + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestType_IsValid(t *testing.T) { + tt := []struct { + name string + t Type + want bool + }{ + { + name: "valid", + t: TypeDMSG, + want: true, + }, + { + name: "not valid", + t: "not valid", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + valid := tc.t.IsValid() + require.Equal(t, tc.want, valid) + }) + } +} From 34919949a2b73d7b05459958960c3577ea88d981 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sat, 21 Sep 2019 22:35:06 +0300 Subject: [PATCH 31/43] Add even more tests --- pkg/app2/network/networker_test.go | 4 +-- pkg/app2/network/type.go | 2 +- pkg/app2/rpc_gateway.go | 6 ++-- pkg/app2/rpc_gateway_test.go | 49 ++++++++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 5 deletions(-) create mode 100644 pkg/app2/rpc_gateway_test.go diff --git a/pkg/app2/network/networker_test.go b/pkg/app2/network/networker_test.go index 526801daf9..c03bc02bcf 100644 --- a/pkg/app2/network/networker_test.go +++ b/pkg/app2/network/networker_test.go @@ -15,7 +15,7 @@ import ( func TestAddNetworker(t *testing.T) { clearNetworkers() - nType := Type(TypeDMSG) + nType := TypeDMSG var n Networker err := AddNetworker(nType, n) @@ -28,7 +28,7 @@ func TestAddNetworker(t *testing.T) { func TestResolveNetworker(t *testing.T) { clearNetworkers() - nType := Type(TypeDMSG) + nType := TypeDMSG var n Networker n, err := ResolveNetworker(nType) diff --git a/pkg/app2/network/type.go b/pkg/app2/network/type.go index d91c6a0dcc..c91f9128d2 100644 --- a/pkg/app2/network/type.go +++ b/pkg/app2/network/type.go @@ -5,7 +5,7 @@ type Type string const ( // TypeDMSG is a network type for DMSG communication. - TypeDMSG = "dmsg" + TypeDMSG Type = "dmsg" ) // IsValid checks whether the network contains valid value for the type. diff --git a/pkg/app2/rpc_gateway.go b/pkg/app2/rpc_gateway.go index e5e9c7ec0b..d85cc50946 100644 --- a/pkg/app2/rpc_gateway.go +++ b/pkg/app2/rpc_gateway.go @@ -28,7 +28,7 @@ func newRPCGateway(log *logging.Logger) *RPCGateway { // Dial dials to the remote. func (r *RPCGateway) Dial(remote *network.Addr, connID *uint16) error { - connID, err := r.cm.nextKey() + reservedConnID, err := r.cm.nextKey() if err != nil { return err } @@ -38,7 +38,7 @@ func (r *RPCGateway) Dial(remote *network.Addr, connID *uint16) error { return err } - if err := r.cm.set(*connID, conn); err != nil { + if err := r.cm.set(*reservedConnID, conn); err != nil { if err := conn.Close(); err != nil { r.log.WithError(err).Error("error closing conn") } @@ -46,6 +46,8 @@ func (r *RPCGateway) Dial(remote *network.Addr, connID *uint16) error { return err } + *connID = *reservedConnID + return nil } diff --git a/pkg/app2/rpc_gateway_test.go b/pkg/app2/rpc_gateway_test.go new file mode 100644 index 0000000000..d6c89de2e6 --- /dev/null +++ b/pkg/app2/rpc_gateway_test.go @@ -0,0 +1,49 @@ +package app2 + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/skycoin/dmsg" + + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skywire/pkg/routing" + + "github.com/skycoin/skywire/pkg/app2/network" + + "github.com/skycoin/skycoin/src/util/logging" +) + +func TestRPCGateway_Dial(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + nType := network.TypeDMSG + + dialCtx := context.Background() + dialAddrPK, _ := cipher.GenerateKeyPair() + dialAddrPort := routing.Port(100) + dialAddr := network.Addr{ + Net: nType, + PubKey: dialAddrPK, + Port: dialAddrPort, + } + dialConn := &dmsg.Transport{} + var dialErr error + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, dialAddr).Return(dialConn, dialErr) + + err := network.AddNetworker(nType, n) + require.NoError(t, err) + + rpc := newRPCGateway(l) + + t.Run("ok", func(t *testing.T) { + var connID uint16 + + err := rpc.Dial(&dialAddr, &connID) + require.NoError(t, err) + require.Equal(t, connID, uint16(1)) + }) +} From 07d5b80ffd2525bb4ad9335ae1b9809146a21e6c Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Sun, 22 Sep 2019 19:06:37 +0300 Subject: [PATCH 32/43] And more tests --- pkg/app2/mock_conn.go | 141 +++++++++++++++++++++++++ pkg/app2/mock_listener.go | 65 ++++++++++++ pkg/app2/network/dmsg_conn.go | 64 +++++++++++ pkg/app2/network/dmsg_networker.go | 7 +- pkg/app2/network/networker.go | 8 ++ pkg/app2/network/networker_test.go | 24 ++--- pkg/app2/rpc_gateway.go | 6 +- pkg/app2/rpc_gateway_test.go | 164 ++++++++++++++++++++++++++--- 8 files changed, 443 insertions(+), 36 deletions(-) create mode 100644 pkg/app2/mock_conn.go create mode 100644 pkg/app2/mock_listener.go create mode 100644 pkg/app2/network/dmsg_conn.go diff --git a/pkg/app2/mock_conn.go b/pkg/app2/mock_conn.go new file mode 100644 index 0000000000..981c8c3084 --- /dev/null +++ b/pkg/app2/mock_conn.go @@ -0,0 +1,141 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package app2 + +import ( + "net" + "time" + + "github.com/stretchr/testify/mock" +) + +// MockConn is an autogenerated mock type for the Conn type +type MockConn struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *MockConn) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// LocalAddr provides a mock function with given fields: +func (_m *MockConn) LocalAddr() net.Addr { + ret := _m.Called() + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(net.Addr) + } + + return r0 +} + +// Read provides a mock function with given fields: b +func (_m *MockConn) Read(b []byte) (int, error) { + ret := _m.Called(b) + + var r0 int + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(b) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RemoteAddr provides a mock function with given fields: +func (_m *MockConn) RemoteAddr() net.Addr { + ret := _m.Called() + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(net.Addr) + } + + return r0 +} + +// SetDeadline provides a mock function with given fields: t +func (_m *MockConn) SetDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetReadDeadline provides a mock function with given fields: t +func (_m *MockConn) SetReadDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetWriteDeadline provides a mock function with given fields: t +func (_m *MockConn) SetWriteDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Write provides a mock function with given fields: b +func (_m *MockConn) Write(b []byte) (int, error) { + ret := _m.Called(b) + + var r0 int + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(b) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/pkg/app2/mock_listener.go b/pkg/app2/mock_listener.go new file mode 100644 index 0000000000..44fda81bd5 --- /dev/null +++ b/pkg/app2/mock_listener.go @@ -0,0 +1,65 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package app2 + +import ( + "net" + + "github.com/stretchr/testify/mock" +) + +// MockListener is an autogenerated mock type for the Listener type +type MockListener struct { + mock.Mock +} + +// Accept provides a mock function with given fields: +func (_m *MockListener) Accept() (net.Conn, error) { + ret := _m.Called() + + var r0 net.Conn + if rf, ok := ret.Get(0).(func() net.Conn); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Conn) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Addr provides a mock function with given fields: +func (_m *MockListener) Addr() net.Addr { + ret := _m.Called() + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(net.Addr) + } + + return r0 +} + +// Close provides a mock function with given fields: +func (_m *MockListener) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/pkg/app2/network/dmsg_conn.go b/pkg/app2/network/dmsg_conn.go new file mode 100644 index 0000000000..9c6e8220bc --- /dev/null +++ b/pkg/app2/network/dmsg_conn.go @@ -0,0 +1,64 @@ +package network + +import ( + "net" + "time" + + "github.com/skycoin/skywire/pkg/routing" + + "github.com/skycoin/dmsg" +) + +type DMSGConn struct { + tp *dmsg.Transport +} + +func (c *DMSGConn) Read(b []byte) (n int, err error) { + return c.tp.Read(b) +} + +func (c *DMSGConn) Write(b []byte) (n int, err error) { + return c.tp.Write(b) +} + +func (c *DMSGConn) Close() error { + return c.tp.Close() +} + +func (c *DMSGConn) LocalAddr() net.Addr { + dmsgAddr, ok := c.tp.LocalAddr().(dmsg.Addr) + if !ok { + return c.tp.LocalAddr() + } + + return Addr{ + Net: TypeDMSG, + PubKey: dmsgAddr.PK, + Port: routing.Port(dmsgAddr.Port), + } +} + +func (c *DMSGConn) RemoteAddr() net.Addr { + dmsgAddr, ok := c.tp.RemoteAddr().(dmsg.Addr) + if !ok { + return c.tp.RemoteAddr() + } + + return Addr{ + Net: TypeDMSG, + PubKey: dmsgAddr.PK, + Port: routing.Port(dmsgAddr.Port), + } +} + +func (c *DMSGConn) SetDeadline(t time.Time) error { + return c.tp.SetDeadline(t) +} + +func (c *DMSGConn) SetReadDeadline(t time.Time) error { + return c.tp.SetReadDeadline(t) +} + +func (c *DMSGConn) SetWriteDeadline(t time.Time) error { + return c.tp.SetWriteDeadline(t) +} diff --git a/pkg/app2/network/dmsg_networker.go b/pkg/app2/network/dmsg_networker.go index 424b0df466..d772260df4 100644 --- a/pkg/app2/network/dmsg_networker.go +++ b/pkg/app2/network/dmsg_networker.go @@ -26,7 +26,12 @@ func (n *DMSGNetworker) Dial(addr Addr) (net.Conn, error) { // DialContext dials remote `addr` via dmsg network with context. func (n *DMSGNetworker) DialContext(ctx context.Context, addr Addr) (net.Conn, error) { - return n.dmsgC.Dial(ctx, addr.PubKey, uint16(addr.Port)) + tp, err := n.dmsgC.Dial(ctx, addr.PubKey, uint16(addr.Port)) + if err != nil { + return nil, err + } + + return &DMSGConn{tp: tp}, nil } // Listen starts listening on local `addr` in the dmsg network. diff --git a/pkg/app2/network/networker.go b/pkg/app2/network/networker.go index e25c05b137..41fb18b6cd 100644 --- a/pkg/app2/network/networker.go +++ b/pkg/app2/network/networker.go @@ -47,6 +47,14 @@ func ResolveNetworker(t Type) (Networker, error) { return n, nil } +// ClearNetworkers removes all the stored networkers. +func ClearNetworkers() { + networkersMx.Lock() + defer networkersMx.Unlock() + + networkers = make(map[Type]Networker) +} + // Networker defines basic network operations, such as Dial/Listen. type Networker interface { Dial(addr Addr) (net.Conn, error) diff --git a/pkg/app2/network/networker_test.go b/pkg/app2/network/networker_test.go index c03bc02bcf..f6692b838a 100644 --- a/pkg/app2/network/networker_test.go +++ b/pkg/app2/network/networker_test.go @@ -5,15 +5,14 @@ import ( "net" "testing" - "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/dmsg/cipher" - "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire/pkg/routing" ) func TestAddNetworker(t *testing.T) { - clearNetworkers() + ClearNetworkers() nType := TypeDMSG var n Networker @@ -26,7 +25,7 @@ func TestAddNetworker(t *testing.T) { } func TestResolveNetworker(t *testing.T) { - clearNetworkers() + ClearNetworkers() nType := TypeDMSG var n Networker @@ -46,14 +45,14 @@ func TestDial(t *testing.T) { addr := prepAddr() t.Run("no such networker", func(t *testing.T) { - clearNetworkers() + ClearNetworkers() _, err := Dial(addr) require.Equal(t, err, ErrNoSuchNetworker) }) t.Run("ok", func(t *testing.T) { - clearNetworkers() + ClearNetworkers() dialCtx := context.Background() var ( @@ -77,14 +76,14 @@ func TestListen(t *testing.T) { addr := prepAddr() t.Run("no such networker", func(t *testing.T) { - clearNetworkers() + ClearNetworkers() _, err := Listen(addr) require.Equal(t, err, ErrNoSuchNetworker) }) t.Run("ok", func(t *testing.T) { - clearNetworkers() + ClearNetworkers() listenCtx := context.Background() var ( @@ -114,10 +113,3 @@ func prepAddr() Addr { Port: addrPort, } } - -func clearNetworkers() { - networkersMx.Lock() - defer networkersMx.Unlock() - - networkers = make(map[Type]Networker) -} diff --git a/pkg/app2/rpc_gateway.go b/pkg/app2/rpc_gateway.go index d85cc50946..185eedc63f 100644 --- a/pkg/app2/rpc_gateway.go +++ b/pkg/app2/rpc_gateway.go @@ -53,7 +53,7 @@ func (r *RPCGateway) Dial(remote *network.Addr, connID *uint16) error { // Listen starts listening. func (r *RPCGateway) Listen(local *network.Addr, lisID *uint16) error { - lisID, err := r.lm.nextKey() + nextLisID, err := r.lm.nextKey() if err != nil { return err } @@ -63,7 +63,7 @@ func (r *RPCGateway) Listen(local *network.Addr, lisID *uint16) error { return err } - if err := r.lm.set(*lisID, l); err != nil { + if err := r.lm.set(*nextLisID, l); err != nil { if err := l.Close(); err != nil { r.log.WithError(err).Error("error closing listener") } @@ -71,6 +71,8 @@ func (r *RPCGateway) Listen(local *network.Addr, lisID *uint16) error { return err } + *lisID = *nextLisID + return nil } diff --git a/pkg/app2/rpc_gateway_test.go b/pkg/app2/rpc_gateway_test.go index d6c89de2e6..d230ed120a 100644 --- a/pkg/app2/rpc_gateway_test.go +++ b/pkg/app2/rpc_gateway_test.go @@ -2,8 +2,11 @@ package app2 import ( "context" + "math" + "net" "testing" + "github.com/pkg/errors" "github.com/stretchr/testify/require" "github.com/skycoin/dmsg" @@ -20,30 +23,157 @@ func TestRPCGateway_Dial(t *testing.T) { l := logging.MustGetLogger("rpc_gateway") nType := network.TypeDMSG - dialCtx := context.Background() - dialAddrPK, _ := cipher.GenerateKeyPair() - dialAddrPort := routing.Port(100) - dialAddr := network.Addr{ - Net: nType, - PubKey: dialAddrPK, - Port: dialAddrPort, - } - dialConn := &dmsg.Transport{} - var dialErr error + dialAddr := prepAddr(nType) - n := &network.MockNetworker{} - n.On("DialContext", dialCtx, dialAddr).Return(dialConn, dialErr) + t.Run("ok", func(t *testing.T) { + network.ClearNetworkers() - err := network.AddNetworker(nType, n) - require.NoError(t, err) + dialCtx := context.Background() + dialConn := &dmsg.Transport{} + var dialErr error - rpc := newRPCGateway(l) + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, dialAddr).Return(dialConn, dialErr) + + err := network.AddNetworker(nType, n) + require.NoError(t, err) + + rpc := newRPCGateway(l) - t.Run("ok", func(t *testing.T) { var connID uint16 - err := rpc.Dial(&dialAddr, &connID) + err = rpc.Dial(&dialAddr, &connID) require.NoError(t, err) require.Equal(t, connID, uint16(1)) }) + + t.Run("no more slots for a new conn", func(t *testing.T) { + rpc := newRPCGateway(l) + for i := uint16(0); i < math.MaxUint16; i++ { + rpc.cm.values[i] = nil + } + rpc.cm.values[math.MaxUint16] = nil + + var connID uint16 + + err := rpc.Dial(&dialAddr, &connID) + require.Equal(t, err, errNoMoreAvailableValues) + }) + + t.Run("dial error", func(t *testing.T) { + network.ClearNetworkers() + + dialCtx := context.Background() + var dialConn net.Conn + dialErr := errors.New("dial error") + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, dialAddr).Return(dialConn, dialErr) + + err := network.AddNetworker(nType, n) + require.NoError(t, err) + + rpc := newRPCGateway(l) + + var connID uint16 + + err = rpc.Dial(&dialAddr, &connID) + require.Equal(t, err, dialErr) + }) +} + +func TestRPCGateway_Listen(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + nType := network.TypeDMSG + + listenAddr := prepAddr(nType) + + t.Run("ok", func(t *testing.T) { + network.ClearNetworkers() + + listenCtx := context.Background() + listenLis := &dmsg.Listener{} + var listenErr error + + n := &network.MockNetworker{} + n.On("ListenContext", listenCtx, listenAddr).Return(listenLis, listenErr) + + err := network.AddNetworker(nType, n) + require.Equal(t, err, listenErr) + + rpc := newRPCGateway(l) + + var lisID uint16 + + err = rpc.Listen(&listenAddr, &lisID) + require.NoError(t, err) + require.Equal(t, lisID, uint16(1)) + }) + + t.Run("no more slots for a new listener", func(t *testing.T) { + rpc := newRPCGateway(l) + for i := uint16(0); i < math.MaxUint16; i++ { + rpc.lm.values[i] = nil + } + rpc.lm.values[math.MaxUint16] = nil + + var lisID uint16 + + err := rpc.Listen(&listenAddr, &lisID) + require.Equal(t, err, errNoMoreAvailableValues) + }) + + t.Run("listen error", func(t *testing.T) { + network.ClearNetworkers() + + listenCtx := context.Background() + var listenLis net.Listener + listenErr := errors.New("listen error") + + n := &network.MockNetworker{} + n.On("ListenContext", listenCtx, listenAddr).Return(listenLis, listenErr) + + err := network.AddNetworker(nType, n) + require.NoError(t, err) + + rpc := newRPCGateway(l) + + var lisID uint16 + + err = rpc.Listen(&listenAddr, &lisID) + require.Equal(t, err, listenErr) + }) +} + +func TestRPCGateway_Accept(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + + rpc := newRPCGateway(l) + + lisID, err := rpc.lm.nextKey() + require.NoError(t, err) + + acceptConn := &dmsg.Transport{} + var acceptErr error + + lis := &MockListener{} + lis.On("Accept").Return(acceptConn, acceptErr) + + err = rpc.lm.set(*lisID, lis) + require.NoError(t, err) + + var resp AcceptResp + err = rpc.Accept(lisID, &resp) + require.NoError(t, err) +} + +func prepAddr(nType network.Type) network.Addr { + pk, _ := cipher.GenerateKeyPair() + port := routing.Port(100) + + return network.Addr{ + Net: nType, + PubKey: pk, + Port: port, + } } From b2dcbbb73fb7504301d6e5a560b68442f8a594d8 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Mon, 23 Sep 2019 14:22:12 +0300 Subject: [PATCH 33/43] Fix some queries --- pkg/app2/manager.go | 20 ++++++++++---------- pkg/app2/manager_test.go | 34 +++++++++++++++++----------------- pkg/app2/rpc_client.go | 13 ++++++++++--- pkg/app2/rpc_gateway.go | 24 +++++++++++++++++------- 4 files changed, 54 insertions(+), 37 deletions(-) diff --git a/pkg/app2/manager.go b/pkg/app2/manager.go index ce0e8f3f00..ffb6ebbfcf 100644 --- a/pkg/app2/manager.go +++ b/pkg/app2/manager.go @@ -11,24 +11,24 @@ var ( errNoMoreAvailableValues = errors.New("no more available values") ) -// manager manages allows to store and retrieve arbitrary values +// idManager manages allows to store and retrieve arbitrary values // associated with the `uint16` key in a thread-safe manner. // Provides method to generate key. -type manager struct { +type idManager struct { values map[uint16]interface{} mx sync.RWMutex lstKey uint16 } -// newManager constructs new `manager`. -func newManager() *manager { - return &manager{ +// newIDManager constructs new `idManager`. +func newIDManager() *idManager { + return &idManager{ values: make(map[uint16]interface{}), } } // `nextKey` reserves next free slot for the value and returns the key for it. -func (m *manager) nextKey() (*uint16, error) { +func (m *idManager) nextKey() (*uint16, error) { m.mx.Lock() nxtKey := m.lstKey + 1 @@ -50,9 +50,9 @@ func (m *manager) nextKey() (*uint16, error) { return &nxtKey, nil } -// getAndRemove removes value specified by `key` from the manager instance and +// pop removes value specified by `key` from the idManager instance and // returns it. -func (m *manager) pop(key uint16) (interface{}, error) { +func (m *idManager) pop(key uint16) (interface{}, error) { m.mx.Lock() v, ok := m.values[key] if !ok { @@ -72,7 +72,7 @@ func (m *manager) pop(key uint16) (interface{}, error) { } // set sets value `v` associated with `key`. -func (m *manager) set(key uint16, v interface{}) error { +func (m *idManager) set(key uint16, v interface{}) error { m.mx.Lock() l, ok := m.values[key] @@ -93,7 +93,7 @@ func (m *manager) set(key uint16, v interface{}) error { } // get gets the value associated with the `key`. -func (m *manager) get(key uint16) (interface{}, bool) { +func (m *idManager) get(key uint16) (interface{}, bool) { m.mx.RLock() lis, ok := m.values[key] m.mx.RUnlock() diff --git a/pkg/app2/manager_test.go b/pkg/app2/manager_test.go index d67fe33179..0b9de7bae4 100644 --- a/pkg/app2/manager_test.go +++ b/pkg/app2/manager_test.go @@ -7,9 +7,9 @@ import ( "github.com/stretchr/testify/require" ) -func TestManager_NextID(t *testing.T) { +func TestIDManager_NextID(t *testing.T) { t.Run("simple call", func(t *testing.T) { - m := newManager() + m := newIDManager() nextKey, err := m.nextKey() require.NoError(t, err) @@ -29,7 +29,7 @@ func TestManager_NextID(t *testing.T) { }) t.Run("call on full manager", func(t *testing.T) { - m := newManager() + m := newIDManager() for i := uint16(0); i < math.MaxUint16; i++ { m.values[i] = nil } @@ -40,7 +40,7 @@ func TestManager_NextID(t *testing.T) { }) t.Run("concurrent run", func(t *testing.T) { - m := newManager() + m := newIDManager() valsToReserve := 10000 @@ -66,9 +66,9 @@ func TestManager_NextID(t *testing.T) { }) } -func TestManager_Pop(t *testing.T) { +func TestIDManager_Pop(t *testing.T) { t.Run("simple call", func(t *testing.T) { - m := newManager() + m := newIDManager() v := "value" @@ -84,14 +84,14 @@ func TestManager_Pop(t *testing.T) { }) t.Run("no value", func(t *testing.T) { - m := newManager() + m := newIDManager() _, err := m.pop(1) require.Error(t, err) }) t.Run("value not set", func(t *testing.T) { - m := newManager() + m := newIDManager() m.values[1] = nil @@ -100,7 +100,7 @@ func TestManager_Pop(t *testing.T) { }) t.Run("concurrent run", func(t *testing.T) { - m := newManager() + m := newIDManager() m.values[1] = "value" @@ -128,9 +128,9 @@ func TestManager_Pop(t *testing.T) { }) } -func TestManager_Set(t *testing.T) { +func TestIDManager_Set(t *testing.T) { t.Run("simple call", func(t *testing.T) { - m := newManager() + m := newIDManager() nextKey, err := m.nextKey() require.NoError(t, err) @@ -145,7 +145,7 @@ func TestManager_Set(t *testing.T) { }) t.Run("key is not reserved", func(t *testing.T) { - m := newManager() + m := newIDManager() err := m.set(1, "value") require.Error(t, err) @@ -155,7 +155,7 @@ func TestManager_Set(t *testing.T) { }) t.Run("value already exists", func(t *testing.T) { - m := newManager() + m := newIDManager() v := "value" @@ -169,7 +169,7 @@ func TestManager_Set(t *testing.T) { }) t.Run("concurrent run", func(t *testing.T) { - m := newManager() + m := newIDManager() concurrency := 1000 @@ -208,9 +208,9 @@ func TestManager_Set(t *testing.T) { }) } -func TestManager_Get(t *testing.T) { - prepManagerWithVal := func(v interface{}) (*manager, uint16) { - m := newManager() +func TestIDManager_Get(t *testing.T) { + prepManagerWithVal := func(v interface{}) (*idManager, uint16) { + m := newIDManager() nextKey, err := m.nextKey() require.NoError(t, err) diff --git a/pkg/app2/rpc_client.go b/pkg/app2/rpc_client.go index 1ed5b93b4d..1e6b0377a5 100644 --- a/pkg/app2/rpc_client.go +++ b/pkg/app2/rpc_client.go @@ -12,7 +12,7 @@ import ( type RPCClient interface { Dial(remote network.Addr) (uint16, error) Listen(local network.Addr) (uint16, error) - Accept(lisID uint16) (uint16, network.Addr, error) + Accept(lisID uint16) (connID uint16, remote network.Addr, err error) Write(connID uint16, b []byte) (int, error) Read(connID uint16, b []byte) (int, []byte, error) CloseConn(id uint16) error @@ -52,7 +52,7 @@ func (c *rpcCLient) Listen(local network.Addr) (uint16, error) { } // Accept sends `Accept` command to the server. -func (c *rpcCLient) Accept(lisID uint16) (uint16, network.Addr, error) { +func (c *rpcCLient) Accept(lisID uint16) (connID uint16, remote network.Addr, err error) { var acceptResp AcceptResp if err := c.rpc.Call("Accept", &lisID, &acceptResp); err != nil { return 0, network.Addr{}, err @@ -78,11 +78,18 @@ func (c *rpcCLient) Write(connID uint16, b []byte) (int, error) { // Read sends `Read` command to the server. func (c *rpcCLient) Read(connID uint16, b []byte) (int, []byte, error) { + req := ReadReq{ + ConnID: connID, + BufLen: len(b), + } + var resp ReadResp - if err := c.rpc.Call("Read", &connID, &resp); err != nil { + if err := c.rpc.Call("Read", &req, &resp); err != nil { return 0, nil, err } + copy(b[:resp.N], resp.B[:resp.N]) + return resp.N, resp.B, nil } diff --git a/pkg/app2/rpc_gateway.go b/pkg/app2/rpc_gateway.go index 185eedc63f..a84652105c 100644 --- a/pkg/app2/rpc_gateway.go +++ b/pkg/app2/rpc_gateway.go @@ -12,16 +12,16 @@ import ( // RPCGateway is a RPC interface for the app server. type RPCGateway struct { - lm *manager - cm *manager + lm *idManager + cm *idManager log *logging.Logger } // newRPCGateway constructs new server RPC interface. func newRPCGateway(log *logging.Logger) *RPCGateway { return &RPCGateway{ - lm: newManager(), - cm: newManager(), + lm: newIDManager(), // contains listeners associated with their IDs + cm: newIDManager(), // contains connections associated with their IDs log: log, } } @@ -141,6 +141,12 @@ func (r *RPCGateway) Write(req *WriteReq, n *int) error { return nil } +// ReadReq contains arguments for `Read`. +type ReadReq struct { + ConnID uint16 + BufLen int +} + // ReadResp contains response parameters for `Read`. type ReadResp struct { B []byte @@ -148,17 +154,21 @@ type ReadResp struct { } // Read reads data from connection specified by `connID`. -func (r *RPCGateway) Read(connID *uint16, resp *ReadResp) error { - conn, err := r.getConn(*connID) +func (r *RPCGateway) Read(req *ReadReq, resp *ReadResp) error { + conn, err := r.getConn(req.ConnID) if err != nil { return err } - resp.N, err = conn.Read(resp.B) + buf := make([]byte, req.BufLen) + resp.N, err = conn.Read(buf) if err != nil { return err } + resp.B = make([]byte, resp.N) + copy(resp.B, buf[:resp.N]) + return nil } From e243192a1e1beccb0fcc3658d9ac6a4cb5243094 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Mon, 23 Sep 2019 16:01:44 +0300 Subject: [PATCH 34/43] Finish RPC gateway tests --- pkg/app2/{manager.go => id_manager.go} | 3 + .../{manager_test.go => id_manager_test.go} | 0 pkg/app2/network/dmsg_conn.go | 4 + pkg/app2/network/dmsg_networker.go | 2 +- pkg/app2/rpc_gateway.go | 10 +- pkg/app2/rpc_gateway_test.go | 389 +++++++++++++++++- vendor/github.com/skycoin/dmsg/go.mod | 4 +- vendor/github.com/skycoin/dmsg/go.sum | 6 + 8 files changed, 397 insertions(+), 21 deletions(-) rename pkg/app2/{manager.go => id_manager.go} (97%) rename pkg/app2/{manager_test.go => id_manager_test.go} (100%) diff --git a/pkg/app2/manager.go b/pkg/app2/id_manager.go similarity index 97% rename from pkg/app2/manager.go rename to pkg/app2/id_manager.go index ffb6ebbfcf..771d7f9ea7 100644 --- a/pkg/app2/manager.go +++ b/pkg/app2/id_manager.go @@ -97,5 +97,8 @@ func (m *idManager) get(key uint16) (interface{}, bool) { m.mx.RLock() lis, ok := m.values[key] m.mx.RUnlock() + if lis == nil { + return nil, false + } return lis, ok } diff --git a/pkg/app2/manager_test.go b/pkg/app2/id_manager_test.go similarity index 100% rename from pkg/app2/manager_test.go rename to pkg/app2/id_manager_test.go diff --git a/pkg/app2/network/dmsg_conn.go b/pkg/app2/network/dmsg_conn.go index 9c6e8220bc..c653f8d5d1 100644 --- a/pkg/app2/network/dmsg_conn.go +++ b/pkg/app2/network/dmsg_conn.go @@ -13,6 +13,10 @@ type DMSGConn struct { tp *dmsg.Transport } +func NewDMSGConn(tp *dmsg.Transport) *DMSGConn { + return &DMSGConn{tp: tp} +} + func (c *DMSGConn) Read(b []byte) (n int, err error) { return c.tp.Read(b) } diff --git a/pkg/app2/network/dmsg_networker.go b/pkg/app2/network/dmsg_networker.go index d772260df4..0cdffbb632 100644 --- a/pkg/app2/network/dmsg_networker.go +++ b/pkg/app2/network/dmsg_networker.go @@ -31,7 +31,7 @@ func (n *DMSGNetworker) DialContext(ctx context.Context, addr Addr) (net.Conn, e return nil, err } - return &DMSGConn{tp: tp}, nil + return NewDMSGConn(tp), nil } // Listen starts listening on local `addr` in the dmsg network. diff --git a/pkg/app2/rpc_gateway.go b/pkg/app2/rpc_gateway.go index a84652105c..bb868399e9 100644 --- a/pkg/app2/rpc_gateway.go +++ b/pkg/app2/rpc_gateway.go @@ -112,10 +112,8 @@ func (r *RPCGateway) Accept(lisID *uint16, resp *AcceptResp) error { return errors.New("wrong type for remote addr") } - resp = &AcceptResp{ - Remote: remote, - ConnID: *connID, - } + resp.Remote = remote + resp.ConnID = *connID return nil } @@ -197,7 +195,7 @@ func (r *RPCGateway) CloseListener(lisID *uint16, _ *struct{}) error { func (r *RPCGateway) popListener(lisID uint16) (net.Listener, error) { lisIfc, err := r.lm.pop(lisID) if err != nil { - return nil, err + return nil, errors.Wrap(err, "no listener") } return r.assertListener(lisIfc) @@ -208,7 +206,7 @@ func (r *RPCGateway) popListener(lisID uint16) (net.Listener, error) { func (r *RPCGateway) popConn(connID uint16) (net.Conn, error) { connIfc, err := r.cm.pop(connID) if err != nil { - return nil, err + return nil, errors.Wrap(err, "no conn") } return r.assertConn(connIfc) diff --git a/pkg/app2/rpc_gateway_test.go b/pkg/app2/rpc_gateway_test.go index d230ed120a..9e5301f231 100644 --- a/pkg/app2/rpc_gateway_test.go +++ b/pkg/app2/rpc_gateway_test.go @@ -4,6 +4,7 @@ import ( "context" "math" "net" + "strings" "testing" "github.com/pkg/errors" @@ -148,23 +149,367 @@ func TestRPCGateway_Listen(t *testing.T) { func TestRPCGateway_Accept(t *testing.T) { l := logging.MustGetLogger("rpc_gateway") - rpc := newRPCGateway(l) + t.Run("ok", func(t *testing.T) { + rpc := newRPCGateway(l) - lisID, err := rpc.lm.nextKey() - require.NoError(t, err) + acceptConn := network.NewDMSGConn(&dmsg.Transport{}) + var acceptErr error - acceptConn := &dmsg.Transport{} - var acceptErr error + lis := &MockListener{} + lis.On("Accept").Return(acceptConn, acceptErr) - lis := &MockListener{} - lis.On("Accept").Return(acceptConn, acceptErr) + lisID := addListener(t, rpc, lis) - err = rpc.lm.set(*lisID, lis) - require.NoError(t, err) + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.NoError(t, err) + require.Equal(t, resp.Remote, acceptConn.RemoteAddr()) + }) - var resp AcceptResp - err = rpc.Accept(lisID, &resp) - require.NoError(t, err) + t.Run("no such listener", func(t *testing.T) { + rpc := newRPCGateway(l) + + lisID := uint16(1) + + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no listener")) + }) + + t.Run("listener is not set", func(t *testing.T) { + rpc := newRPCGateway(l) + + lisID := addListener(t, rpc, nil) + + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no listener")) + }) + + t.Run("no more slots for a new conn", func(t *testing.T) { + rpc := newRPCGateway(l) + for i := uint16(0); i < math.MaxUint16; i++ { + rpc.cm.values[i] = nil + } + rpc.cm.values[math.MaxUint16] = nil + + lisID := addListener(t, rpc, &MockListener{}) + + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.Equal(t, err, errNoMoreAvailableValues) + }) + + t.Run("accept error", func(t *testing.T) { + rpc := newRPCGateway(l) + + var acceptConn net.Conn + acceptErr := errors.New("accept error") + + lis := &MockListener{} + lis.On("Accept").Return(acceptConn, acceptErr) + + lisID := addListener(t, rpc, lis) + + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.Equal(t, err, acceptErr) + }) + + t.Run("wrong type of remote addr", func(t *testing.T) { + rpc := newRPCGateway(l) + + acceptConn := &dmsg.Transport{} + var acceptErr error + + lis := &MockListener{} + lis.On("Accept").Return(acceptConn, acceptErr) + + lisID := addListener(t, rpc, lis) + + var resp AcceptResp + err := rpc.Accept(&lisID, &resp) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "wrong type")) + }) +} + +func TestRPCGateway_Write(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + + writeBuff := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + writeN := 10 + + t.Run("ok", func(t *testing.T) { + rpc := newRPCGateway(l) + + var writeErr error + + conn := &MockConn{} + conn.On("Write", writeBuff).Return(writeN, writeErr) + + connID := addConn(t, rpc, conn) + + req := WriteReq{ + ConnID: connID, + B: writeBuff, + } + + var n int + err := rpc.Write(&req, &n) + require.NoError(t, err) + require.Equal(t, n, writeN) + }) + + t.Run("no such conn", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := uint16(1) + + req := WriteReq{ + ConnID: connID, + B: writeBuff, + } + + var n int + err := rpc.Write(&req, &n) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("conn is not set", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := addConn(t, rpc, nil) + + req := WriteReq{ + ConnID: connID, + B: writeBuff, + } + + var n int + err := rpc.Write(&req, &n) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("write error", func(t *testing.T) { + rpc := newRPCGateway(l) + + writeErr := errors.New("write error") + + conn := &MockConn{} + conn.On("Write", writeBuff).Return(writeN, writeErr) + + connID := addConn(t, rpc, conn) + + req := WriteReq{ + ConnID: connID, + B: writeBuff, + } + + var n int + err := rpc.Write(&req, &n) + require.Error(t, err) + require.Equal(t, err, writeErr) + }) +} + +func TestRPCGateway_Read(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + + readBufLen := 10 + readBuf := make([]byte, readBufLen) + + t.Run("ok", func(t *testing.T) { + rpc := newRPCGateway(l) + + readN := 10 + var readErr error + + conn := &MockConn{} + conn.On("Read", readBuf).Return(readN, readErr) + + connID := addConn(t, rpc, conn) + + req := ReadReq{ + ConnID: connID, + BufLen: readBufLen, + } + + wantResp := ReadResp{ + B: readBuf, + N: readN, + } + + var resp ReadResp + err := rpc.Read(&req, &resp) + require.NoError(t, err) + require.Equal(t, resp, wantResp) + }) + + t.Run("no such conn", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := uint16(1) + + req := ReadReq{ + ConnID: connID, + BufLen: readBufLen, + } + + var resp ReadResp + err := rpc.Read(&req, &resp) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("conn is not set", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := addConn(t, rpc, nil) + + req := ReadReq{ + ConnID: connID, + BufLen: readBufLen, + } + + var resp ReadResp + err := rpc.Read(&req, &resp) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("read error", func(t *testing.T) { + rpc := newRPCGateway(l) + + readN := 0 + readErr := errors.New("read error") + + conn := &MockConn{} + conn.On("Read", readBuf).Return(readN, readErr) + + connID := addConn(t, rpc, conn) + + req := ReadReq{ + ConnID: connID, + BufLen: readBufLen, + } + + var resp ReadResp + err := rpc.Read(&req, &resp) + require.Equal(t, err, readErr) + }) +} + +func TestRPCGateway_CloseConn(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + + t.Run("ok", func(t *testing.T) { + rpc := newRPCGateway(l) + + var closeErr error + + conn := &MockConn{} + conn.On("Close").Return(closeErr) + + connID := addConn(t, rpc, conn) + + err := rpc.CloseConn(&connID, nil) + require.NoError(t, err) + _, ok := rpc.cm.values[connID] + require.False(t, ok) + }) + + t.Run("no such conn", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := uint16(1) + + err := rpc.CloseConn(&connID, nil) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("conn is not set", func(t *testing.T) { + rpc := newRPCGateway(l) + + connID := addConn(t, rpc, nil) + + err := rpc.CloseConn(&connID, nil) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no conn")) + }) + + t.Run("close error", func(t *testing.T) { + rpc := newRPCGateway(l) + + closeErr := errors.New("close error") + + conn := &MockConn{} + conn.On("Close").Return(closeErr) + + connID := addConn(t, rpc, conn) + + err := rpc.CloseConn(&connID, nil) + require.Equal(t, err, closeErr) + }) +} + +func TestRPCGateway_CloseListener(t *testing.T) { + l := logging.MustGetLogger("rpc_gateway") + + t.Run("ok", func(t *testing.T) { + rpc := newRPCGateway(l) + + var closeErr error + + lis := &MockListener{} + lis.On("Close").Return(closeErr) + + lisID := addListener(t, rpc, lis) + + err := rpc.CloseListener(&lisID, nil) + require.NoError(t, err) + _, ok := rpc.lm.values[lisID] + require.False(t, ok) + }) + + t.Run("no such listener", func(t *testing.T) { + rpc := newRPCGateway(l) + + lisID := uint16(1) + + err := rpc.CloseListener(&lisID, nil) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no listener")) + }) + + t.Run("listener is not set", func(t *testing.T) { + rpc := newRPCGateway(l) + + lisID := addListener(t, rpc, nil) + + err := rpc.CloseListener(&lisID, nil) + require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no listener")) + }) + + t.Run("close error", func(t *testing.T) { + rpc := newRPCGateway(l) + + closeErr := errors.New("close error") + + lis := &MockListener{} + lis.On("Close").Return(closeErr) + + lisID := addListener(t, rpc, lis) + + err := rpc.CloseListener(&lisID, nil) + require.Equal(t, err, closeErr) + }) } func prepAddr(nType network.Type) network.Addr { @@ -177,3 +522,23 @@ func prepAddr(nType network.Type) network.Addr { Port: port, } } + +func addConn(t *testing.T, rpc *RPCGateway, conn net.Conn) uint16 { + connID, err := rpc.cm.nextKey() + require.NoError(t, err) + + err = rpc.cm.set(*connID, conn) + require.NoError(t, err) + + return *connID +} + +func addListener(t *testing.T, rpc *RPCGateway, lis net.Listener) uint16 { + lisID, err := rpc.lm.nextKey() + require.NoError(t, err) + + err = rpc.lm.set(*lisID, lis) + require.NoError(t, err) + + return *lisID +} diff --git a/vendor/github.com/skycoin/dmsg/go.mod b/vendor/github.com/skycoin/dmsg/go.mod index 1ef2c47f66..a24455c1f5 100644 --- a/vendor/github.com/skycoin/dmsg/go.mod +++ b/vendor/github.com/skycoin/dmsg/go.mod @@ -14,7 +14,7 @@ require ( golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4 // indirect golang.org/x/net v0.0.0-20190620200207-3b0461eec859 golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb // indirect + golang.org/x/text v0.3.2 // indirect + golang.org/x/tools v0.0.0-20190627182818-9947fec5c3ab // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) - -replace github.com/skycoin/dmsg => ../dmsg diff --git a/vendor/github.com/skycoin/dmsg/go.sum b/vendor/github.com/skycoin/dmsg/go.sum index 624818fed7..c6a730a9e6 100644 --- a/vendor/github.com/skycoin/dmsg/go.sum +++ b/vendor/github.com/skycoin/dmsg/go.sum @@ -32,9 +32,11 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4 h1:ydJNl0ENAG67pFbB+9tfhiL2pYqLhfoaZFw/cjLhY4A= golang.org/x/crypto v0.0.0-20190621222207-cc06ce4a13d4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -43,5 +45,9 @@ golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb h1:fgwFCsaw9buMuxNd6+DQfAuSF golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190627182818-9947fec5c3ab h1:uOzhX2fm3C4BmBwW2a7lnJQD7qel2+4uhmTc8czKBCU= +golang.org/x/tools v0.0.0-20190627182818-9947fec5c3ab/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From f9c933f4d2057f6c84de9cf97513eb38f941ecbc Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 24 Sep 2019 13:13:45 +0300 Subject: [PATCH 35/43] Add `WrappedConn` --- pkg/app2/listener.go | 1 - pkg/app2/network/addr.go | 22 ++++++++++ pkg/app2/network/dmsg_conn.go | 68 ------------------------------ pkg/app2/network/dmsg_networker.go | 2 +- pkg/app2/network/wrapped_conn.go | 42 ++++++++++++++++++ 5 files changed, 65 insertions(+), 70 deletions(-) delete mode 100644 pkg/app2/network/dmsg_conn.go create mode 100644 pkg/app2/network/wrapped_conn.go diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index 56e52160e3..c78befd38d 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -31,7 +31,6 @@ func (l *Listener) Accept() (net.Conn, error) { return conn, nil } -// TODO: should unblock all called `Accept`s with errors func (l *Listener) Close() error { defer l.freePort() diff --git a/pkg/app2/network/addr.go b/pkg/app2/network/addr.go index 2db24213e5..abd4a11704 100644 --- a/pkg/app2/network/addr.go +++ b/pkg/app2/network/addr.go @@ -1,13 +1,20 @@ package network import ( + "errors" "fmt" + "net" + "github.com/skycoin/dmsg" "github.com/skycoin/dmsg/cipher" "github.com/skycoin/skywire/pkg/routing" ) +var ( + errUnknownAddrType = errors.New("addr type is unknown") +) + // Addr implements net.Addr for network addresses. type Addr struct { Net Type @@ -27,3 +34,18 @@ func (a Addr) String() string { } return fmt.Sprintf("%s:%d", a.PubKey, a.Port) } + +// wrapAddrs asserts type of the passed `net.Addr` and converts it +// to `Addr` if possible. +func wrapAddr(addr net.Addr) (Addr, error) { + switch a := addr.(type) { + case dmsg.Addr: + return Addr{ + Net: TypeDMSG, + PubKey: a.PK, + Port: routing.Port(a.Port), + }, nil + default: + return Addr{}, errUnknownAddrType + } +} diff --git a/pkg/app2/network/dmsg_conn.go b/pkg/app2/network/dmsg_conn.go deleted file mode 100644 index c653f8d5d1..0000000000 --- a/pkg/app2/network/dmsg_conn.go +++ /dev/null @@ -1,68 +0,0 @@ -package network - -import ( - "net" - "time" - - "github.com/skycoin/skywire/pkg/routing" - - "github.com/skycoin/dmsg" -) - -type DMSGConn struct { - tp *dmsg.Transport -} - -func NewDMSGConn(tp *dmsg.Transport) *DMSGConn { - return &DMSGConn{tp: tp} -} - -func (c *DMSGConn) Read(b []byte) (n int, err error) { - return c.tp.Read(b) -} - -func (c *DMSGConn) Write(b []byte) (n int, err error) { - return c.tp.Write(b) -} - -func (c *DMSGConn) Close() error { - return c.tp.Close() -} - -func (c *DMSGConn) LocalAddr() net.Addr { - dmsgAddr, ok := c.tp.LocalAddr().(dmsg.Addr) - if !ok { - return c.tp.LocalAddr() - } - - return Addr{ - Net: TypeDMSG, - PubKey: dmsgAddr.PK, - Port: routing.Port(dmsgAddr.Port), - } -} - -func (c *DMSGConn) RemoteAddr() net.Addr { - dmsgAddr, ok := c.tp.RemoteAddr().(dmsg.Addr) - if !ok { - return c.tp.RemoteAddr() - } - - return Addr{ - Net: TypeDMSG, - PubKey: dmsgAddr.PK, - Port: routing.Port(dmsgAddr.Port), - } -} - -func (c *DMSGConn) SetDeadline(t time.Time) error { - return c.tp.SetDeadline(t) -} - -func (c *DMSGConn) SetReadDeadline(t time.Time) error { - return c.tp.SetReadDeadline(t) -} - -func (c *DMSGConn) SetWriteDeadline(t time.Time) error { - return c.tp.SetWriteDeadline(t) -} diff --git a/pkg/app2/network/dmsg_networker.go b/pkg/app2/network/dmsg_networker.go index 0cdffbb632..e87ef55128 100644 --- a/pkg/app2/network/dmsg_networker.go +++ b/pkg/app2/network/dmsg_networker.go @@ -31,7 +31,7 @@ func (n *DMSGNetworker) DialContext(ctx context.Context, addr Addr) (net.Conn, e return nil, err } - return NewDMSGConn(tp), nil + return WrapConn(tp) } // Listen starts listening on local `addr` in the dmsg network. diff --git a/pkg/app2/network/wrapped_conn.go b/pkg/app2/network/wrapped_conn.go new file mode 100644 index 0000000000..d5915a58f2 --- /dev/null +++ b/pkg/app2/network/wrapped_conn.go @@ -0,0 +1,42 @@ +package network + +import ( + "net" +) + +// WrappedConn wraps `net.Conn` to support address conversion between +// specific `net.Addr` implementations and `Addr`. +type WrappedConn struct { + net.Conn + local Addr + remote Addr +} + +// WrapConn wraps passed `conn`. Handles `net.Addr` type assertion. +func WrapConn(conn net.Conn) (net.Conn, error) { + l, err := wrapAddr(conn.LocalAddr()) + if err != nil { + return nil, err + } + + r, err := wrapAddr(conn.RemoteAddr()) + if err != nil { + return nil, err + } + + return &WrappedConn{ + Conn: conn, + local: l, + remote: r, + }, nil +} + +// LocalAddr returns local address. +func (c *WrappedConn) LocalAddr() net.Addr { + return c.local +} + +// RemoteAddr returns remote address. +func (c *WrappedConn) RemoteAddr() net.Addr { + return c.remote +} From 34c6e29cbb6b955f0127536331cad0721dae7572 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 24 Sep 2019 13:32:19 +0300 Subject: [PATCH 36/43] Pass assigned local port from the server --- pkg/app2/client.go | 14 ++------ pkg/app2/conn.go | 15 +++----- pkg/app2/id_manager.go | 60 +++++++++++++++++++------------- pkg/app2/network/addr.go | 4 +-- pkg/app2/network/wrapped_conn.go | 4 +-- pkg/app2/rpc_client.go | 14 ++++---- pkg/app2/rpc_gateway.go | 32 ++++++++++++++--- 7 files changed, 81 insertions(+), 62 deletions(-) diff --git a/pkg/app2/client.go b/pkg/app2/client.go index bd9e203905..2f41a956c6 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -1,7 +1,6 @@ package app2 import ( - "context" "net" "github.com/skycoin/dmsg/cipher" @@ -34,27 +33,20 @@ func NewClient(localPK cipher.PubKey, pid ProcID, rpc RPCClient) *Client { // Dial dials the remote node using `remote`. func (c *Client) Dial(remote network.Addr) (net.Conn, error) { - localPort, free, err := c.porter.ReserveEphemeral(context.TODO(), nil) + connID, assignedPort, err := c.rpc.Dial(remote) if err != nil { return nil, err } - connID, err := c.rpc.Dial(remote) - if err != nil { - free() - return nil, err - } - conn := &Conn{ id: connID, rpc: c.rpc, local: network.Addr{ Net: remote.Net, PubKey: c.pk, - Port: routing.Port(localPort), + Port: assignedPort, }, - remote: remote, - freeLocalPort: free, + remote: remote, } return conn, nil diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index d201f22a62..67f99579d7 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -10,11 +10,10 @@ import ( // Conn is a connection from app client to the server. // Implements `net.Conn`. type Conn struct { - id uint16 - rpc RPCClient - local network.Addr - remote network.Addr - freeLocalPort func() + id uint16 + rpc RPCClient + local network.Addr + remote network.Addr } func (c *Conn) Read(b []byte) (int, error) { @@ -34,12 +33,6 @@ func (c *Conn) Write(b []byte) (int, error) { } func (c *Conn) Close() error { - defer func() { - if c.freeLocalPort != nil { - c.freeLocalPort() - } - }() - return c.rpc.CloseConn(c.id) } diff --git a/pkg/app2/id_manager.go b/pkg/app2/id_manager.go index 771d7f9ea7..561a5c7e37 100644 --- a/pkg/app2/id_manager.go +++ b/pkg/app2/id_manager.go @@ -17,7 +17,7 @@ var ( type idManager struct { values map[uint16]interface{} mx sync.RWMutex - lstKey uint16 + lstID uint16 } // newIDManager constructs new `idManager`. @@ -27,58 +27,58 @@ func newIDManager() *idManager { } } -// `nextKey` reserves next free slot for the value and returns the key for it. -func (m *idManager) nextKey() (*uint16, error) { +// `reserveNextID` reserves next free slot for the value and returns the id for it. +func (m *idManager) reserveNextID() (id *uint16, free func(), err error) { m.mx.Lock() - nxtKey := m.lstKey + 1 - for ; nxtKey != m.lstKey; nxtKey++ { - if _, ok := m.values[nxtKey]; !ok { + nxtID := m.lstID + 1 + for ; nxtID != m.lstID; nxtID++ { + if _, ok := m.values[nxtID]; !ok { break } } - if nxtKey == m.lstKey { + if nxtID == m.lstID { m.mx.Unlock() - return nil, errNoMoreAvailableValues + return nil, nil, errNoMoreAvailableValues } - m.values[nxtKey] = nil - m.lstKey = nxtKey + m.values[nxtID] = nil + m.lstID = nxtID m.mx.Unlock() - return &nxtKey, nil + return &nxtID, m.constructFreeFunc(nxtID), nil } -// pop removes value specified by `key` from the idManager instance and +// pop removes value specified by `id` from the idManager instance and // returns it. -func (m *idManager) pop(key uint16) (interface{}, error) { +func (m *idManager) pop(id uint16) (interface{}, error) { m.mx.Lock() - v, ok := m.values[key] + v, ok := m.values[id] if !ok { m.mx.Unlock() - return nil, fmt.Errorf("no value with key %d", key) + return nil, fmt.Errorf("no value with id %d", id) } if v == nil { m.mx.Unlock() - return nil, fmt.Errorf("value with key %d is not set", key) + return nil, fmt.Errorf("value with id %d is not set", id) } - delete(m.values, key) + delete(m.values, id) m.mx.Unlock() return v, nil } -// set sets value `v` associated with `key`. -func (m *idManager) set(key uint16, v interface{}) error { +// set sets value `v` associated with `id`. +func (m *idManager) set(id uint16, v interface{}) error { m.mx.Lock() - l, ok := m.values[key] + l, ok := m.values[id] if !ok { m.mx.Unlock() - return errors.New("key is not reserved") + return errors.New("id is not reserved") } else { if l != nil { m.mx.Unlock() @@ -86,19 +86,29 @@ func (m *idManager) set(key uint16, v interface{}) error { } } - m.values[key] = v + m.values[id] = v m.mx.Unlock() return nil } -// get gets the value associated with the `key`. -func (m *idManager) get(key uint16) (interface{}, bool) { +// get gets the value associated with the `id`. +func (m *idManager) get(id uint16) (interface{}, bool) { m.mx.RLock() - lis, ok := m.values[key] + lis, ok := m.values[id] m.mx.RUnlock() if lis == nil { return nil, false } return lis, ok } + +// constructFreeFunc constructs new func responsible for clearing +// a slot with the specified `id`. +func (m *idManager) constructFreeFunc(id uint16) func() { + return func() { + m.mx.Lock() + delete(m.values, id) + m.mx.Unlock() + } +} diff --git a/pkg/app2/network/addr.go b/pkg/app2/network/addr.go index abd4a11704..c3fa28020e 100644 --- a/pkg/app2/network/addr.go +++ b/pkg/app2/network/addr.go @@ -35,9 +35,9 @@ func (a Addr) String() string { return fmt.Sprintf("%s:%d", a.PubKey, a.Port) } -// wrapAddrs asserts type of the passed `net.Addr` and converts it +// WrapAddr asserts type of the passed `net.Addr` and converts it // to `Addr` if possible. -func wrapAddr(addr net.Addr) (Addr, error) { +func WrapAddr(addr net.Addr) (Addr, error) { switch a := addr.(type) { case dmsg.Addr: return Addr{ diff --git a/pkg/app2/network/wrapped_conn.go b/pkg/app2/network/wrapped_conn.go index d5915a58f2..cdc1595f2c 100644 --- a/pkg/app2/network/wrapped_conn.go +++ b/pkg/app2/network/wrapped_conn.go @@ -14,12 +14,12 @@ type WrappedConn struct { // WrapConn wraps passed `conn`. Handles `net.Addr` type assertion. func WrapConn(conn net.Conn) (net.Conn, error) { - l, err := wrapAddr(conn.LocalAddr()) + l, err := WrapAddr(conn.LocalAddr()) if err != nil { return nil, err } - r, err := wrapAddr(conn.RemoteAddr()) + r, err := WrapAddr(conn.RemoteAddr()) if err != nil { return nil, err } diff --git a/pkg/app2/rpc_client.go b/pkg/app2/rpc_client.go index 1e6b0377a5..4de0dabea2 100644 --- a/pkg/app2/rpc_client.go +++ b/pkg/app2/rpc_client.go @@ -3,6 +3,8 @@ package app2 import ( "net/rpc" + "github.com/skycoin/skywire/pkg/routing" + "github.com/skycoin/skywire/pkg/app2/network" ) @@ -10,7 +12,7 @@ import ( // RPCClient describes RPC interface to communicate with the server. type RPCClient interface { - Dial(remote network.Addr) (uint16, error) + Dial(remote network.Addr) (connID uint16, assignedPort routing.Port, err error) Listen(local network.Addr) (uint16, error) Accept(lisID uint16) (connID uint16, remote network.Addr, err error) Write(connID uint16, b []byte) (int, error) @@ -32,13 +34,13 @@ func NewRPCClient(rpc *rpc.Client) RPCClient { } // Dial sends `Dial` command to the server. -func (c *rpcCLient) Dial(remote network.Addr) (uint16, error) { - var connID uint16 - if err := c.rpc.Call("Dial", &remote, &connID); err != nil { - return 0, err +func (c *rpcCLient) Dial(remote network.Addr) (connID uint16, assignedPort routing.Port, err error) { + var resp DialResp + if err := c.rpc.Call("Dial", &remote, &resp); err != nil { + return 0, 0, err } - return connID, nil + return resp.ConnID, resp.AssignedPort, nil } // Listen sends `Listen` command to the server. diff --git a/pkg/app2/rpc_gateway.go b/pkg/app2/rpc_gateway.go index bb868399e9..176b187f8f 100644 --- a/pkg/app2/rpc_gateway.go +++ b/pkg/app2/rpc_gateway.go @@ -4,6 +4,8 @@ import ( "fmt" "net" + "github.com/skycoin/skywire/pkg/routing" + "github.com/pkg/errors" "github.com/skycoin/skycoin/src/util/logging" @@ -26,15 +28,28 @@ func newRPCGateway(log *logging.Logger) *RPCGateway { } } +// DialResp contains response parameters for `Dial`. +type DialResp struct { + ConnID uint16 + AssignedPort routing.Port +} + // Dial dials to the remote. -func (r *RPCGateway) Dial(remote *network.Addr, connID *uint16) error { - reservedConnID, err := r.cm.nextKey() +func (r *RPCGateway) Dial(remote *network.Addr, resp *DialResp) error { + reservedConnID, free, err := r.cm.reserveNextID() if err != nil { return err } conn, err := network.Dial(*remote) if err != nil { + free() + return err + } + + localAddr, err := network.WrapAddr(conn.LocalAddr()) + if err != nil { + free() return err } @@ -43,23 +58,26 @@ func (r *RPCGateway) Dial(remote *network.Addr, connID *uint16) error { r.log.WithError(err).Error("error closing conn") } + free() return err } - *connID = *reservedConnID + resp.ConnID = *reservedConnID + resp.AssignedPort = localAddr.Port return nil } // Listen starts listening. func (r *RPCGateway) Listen(local *network.Addr, lisID *uint16) error { - nextLisID, err := r.lm.nextKey() + nextLisID, free, err := r.lm.reserveNextID() if err != nil { return err } l, err := network.Listen(*local) if err != nil { + free() return err } @@ -68,6 +86,7 @@ func (r *RPCGateway) Listen(local *network.Addr, lisID *uint16) error { r.log.WithError(err).Error("error closing listener") } + free() return err } @@ -89,13 +108,14 @@ func (r *RPCGateway) Accept(lisID *uint16, resp *AcceptResp) error { return err } - connID, err := r.cm.nextKey() + connID, free, err := r.cm.reserveNextID() if err != nil { return err } conn, err := lis.Accept() if err != nil { + free() return err } @@ -104,11 +124,13 @@ func (r *RPCGateway) Accept(lisID *uint16, resp *AcceptResp) error { r.log.WithError(err).Error("error closing DMSG transport") } + free() return err } remote, ok := conn.RemoteAddr().(network.Addr) if !ok { + free() return errors.New("wrong type for remote addr") } From 677fb24e7a580af2d196978b934622aff0fedcc3 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 24 Sep 2019 14:51:48 +0300 Subject: [PATCH 37/43] Add conn/listener tracking --- pkg/app2/client.go | 81 +++++++++++++++++++++++++++++++++++-- pkg/app2/conn.go | 11 +++-- pkg/app2/id_manager.go | 30 +++++++++++++- pkg/app2/id_manager_util.go | 27 +++++++++++++ pkg/app2/listener.go | 39 +++++++++++++++++- pkg/app2/rpc_client.go | 3 +- pkg/app2/rpc_gateway.go | 39 +++++------------- 7 files changed, 188 insertions(+), 42 deletions(-) create mode 100644 pkg/app2/id_manager_util.go diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 2f41a956c6..5f6e9317d1 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -5,6 +5,7 @@ import ( "github.com/skycoin/dmsg/cipher" "github.com/skycoin/dmsg/netutil" + "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/app2/network" "github.com/skycoin/skywire/pkg/routing" @@ -12,21 +13,28 @@ import ( // Client is used by skywire apps. type Client struct { + log *logging.Logger pk cipher.PubKey pid ProcID rpc RPCClient + lm *idManager // contains listeners associated with their IDs + cm *idManager // contains connections associated with their IDs porter *netutil.Porter } // NewClient creates a new `Client`. The `Client` needs to be provided with: +// - log: logger instance // - localPK: The local public key of the parent skywire visor. // - pid: The procID assigned for the process that Client is being used by. // - rpc: RPC client to communicate with the server. -func NewClient(localPK cipher.PubKey, pid ProcID, rpc RPCClient) *Client { +func NewClient(log *logging.Logger, localPK cipher.PubKey, pid ProcID, rpc RPCClient) *Client { return &Client{ + log: log, pk: localPK, pid: pid, rpc: rpc, + lm: newIDManager(), + cm: newIDManager(), porter: netutil.NewPorter(netutil.PorterMinEphemeral), } } @@ -49,12 +57,23 @@ func (c *Client) Dial(remote network.Addr) (net.Conn, error) { remote: remote, } + free, err := c.cm.add(connID, conn) + if err != nil { + if err := conn.Close(); err != nil { + c.log.WithError(err).Error("error closing conn") + } + + return nil, err + } + + conn.freeConn = free + return conn, nil } // Listen listens on the specified `port` for the incoming connections. func (c *Client) Listen(n network.Type, port routing.Port) (net.Listener, error) { - ok, free := c.porter.Reserve(uint16(port), nil) + ok, freePort := c.porter.Reserve(uint16(port), nil) if !ok { return nil, ErrPortAlreadyBound } @@ -67,16 +86,70 @@ func (c *Client) Listen(n network.Type, port routing.Port) (net.Listener, error) lisID, err := c.rpc.Listen(local) if err != nil { - free() + freePort() return nil, err } listener := &Listener{ + log: c.log, id: lisID, rpc: c.rpc, addr: local, - freePort: free, + cm: newIDManager(), + freePort: freePort, + } + + freeLis, err := c.lm.add(lisID, listener) + if err != nil { + if err := listener.Close(); err != nil { + c.log.WithError(err).Error("error closing listener") + } + + freePort() + return nil, err } + listener.freeLis = freeLis + return listener, nil } + +// Close closes client/server communication entirely. It closes all open +// listeners and connections. +func (c *Client) Close() { + var listeners []net.Listener + c.lm.doRange(func(_ uint16, v interface{}) bool { + lis, err := assertListener(v) + if err != nil { + c.log.Error(err) + return true + } + + listeners = append(listeners, lis) + return true + }) + + var conns []net.Conn + c.cm.doRange(func(_ uint16, v interface{}) bool { + conn, err := assertConn(v) + if err != nil { + c.log.Error(err) + return true + } + + conns = append(conns, conn) + return true + }) + + for _, lis := range listeners { + if err := lis.Close(); err != nil { + c.log.WithError(err).Error("error closing listener") + } + } + + for _, conn := range conns { + if err := conn.Close(); err != nil { + c.log.WithError(err).Error("error closing conn") + } + } +} diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index 67f99579d7..09847da724 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -10,10 +10,11 @@ import ( // Conn is a connection from app client to the server. // Implements `net.Conn`. type Conn struct { - id uint16 - rpc RPCClient - local network.Addr - remote network.Addr + id uint16 + rpc RPCClient + local network.Addr + remote network.Addr + freeConn func() } func (c *Conn) Read(b []byte) (int, error) { @@ -33,6 +34,8 @@ func (c *Conn) Write(b []byte) (int, error) { } func (c *Conn) Close() error { + defer c.freeConn() + return c.rpc.CloseConn(c.id) } diff --git a/pkg/app2/id_manager.go b/pkg/app2/id_manager.go index 561a5c7e37..6f087f1b3a 100644 --- a/pkg/app2/id_manager.go +++ b/pkg/app2/id_manager.go @@ -9,6 +9,7 @@ import ( var ( errNoMoreAvailableValues = errors.New("no more available values") + errValueAlreadyExists = errors.New("value already exists") ) // idManager manages allows to store and retrieve arbitrary values @@ -71,6 +72,21 @@ func (m *idManager) pop(id uint16) (interface{}, error) { return v, nil } +// add adds the new value `v` associated with `id`. +func (m *idManager) add(id uint16, v interface{}) (free func(), err error) { + m.mx.Lock() + + if _, ok := m.values[id]; ok { + m.mx.Unlock() + return nil, errValueAlreadyExists + } + + m.values[id] = v + + m.mx.Unlock() + return m.constructFreeFunc(id), nil +} + // set sets value `v` associated with `id`. func (m *idManager) set(id uint16, v interface{}) error { m.mx.Lock() @@ -82,7 +98,7 @@ func (m *idManager) set(id uint16, v interface{}) error { } else { if l != nil { m.mx.Unlock() - return errors.New("value already exists") + return errValueAlreadyExists } } @@ -103,6 +119,18 @@ func (m *idManager) get(id uint16) (interface{}, bool) { return lis, ok } +// doRange performs range over the manager contents. Loop stops when +// `next` returns false. +func (m *idManager) doRange(next func(id uint16, v interface{}) bool) { + m.mx.RLock() + for id, v := range m.values { + if !next(id, v) { + break + } + } + m.mx.RUnlock() +} + // constructFreeFunc constructs new func responsible for clearing // a slot with the specified `id`. func (m *idManager) constructFreeFunc(id uint16) func() { diff --git a/pkg/app2/id_manager_util.go b/pkg/app2/id_manager_util.go new file mode 100644 index 0000000000..174b293300 --- /dev/null +++ b/pkg/app2/id_manager_util.go @@ -0,0 +1,27 @@ +package app2 + +import ( + "net" + + "github.com/pkg/errors" +) + +// assertListener asserts that `v` is of type `net.Listener`. +func assertListener(v interface{}) (net.Listener, error) { + lis, ok := v.(net.Listener) + if !ok { + return nil, errors.New("wrong type of value stored for listener") + } + + return lis, nil +} + +// assertConn asserts that `v` is of type `net.Conn`. +func assertConn(v interface{}) (net.Conn, error) { + conn, ok := v.(net.Conn) + if !ok { + return nil, errors.New("wrong type of value stored for conn") + } + + return conn, nil +} diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index c78befd38d..41bc198e74 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -3,16 +3,21 @@ package app2 import ( "net" + "github.com/skycoin/skycoin/src/util/logging" + "github.com/skycoin/skywire/pkg/app2/network" ) // Listener is a listener for app server connections. // Implements `net.Listener`. type Listener struct { + log *logging.Logger id uint16 rpc RPCClient addr network.Addr + cm *idManager // contains conns associated with their IDs freePort func() + freeLis func() } func (l *Listener) Accept() (net.Conn, error) { @@ -28,11 +33,43 @@ func (l *Listener) Accept() (net.Conn, error) { remote: remote, } + free, err := l.cm.add(connID, conn) + if err != nil { + if err := conn.Close(); err != nil { + l.log.WithError(err).Error("error closing listener") + } + + return nil, err + } + + conn.freeConn = free + return conn, nil } func (l *Listener) Close() error { - defer l.freePort() + defer func() { + l.freePort() + l.freeLis() + + var conns []net.Conn + l.cm.doRange(func(_ uint16, v interface{}) bool { + conn, err := assertConn(v) + if err != nil { + l.log.Error(err) + return true + } + + conns = append(conns, conn) + return true + }) + + for _, conn := range conns { + if err := conn.Close(); err != nil { + l.log.WithError(err).Error("error closing listener") + } + } + }() return l.rpc.CloseListener(l.id) } diff --git a/pkg/app2/rpc_client.go b/pkg/app2/rpc_client.go index 4de0dabea2..c0440c134f 100644 --- a/pkg/app2/rpc_client.go +++ b/pkg/app2/rpc_client.go @@ -3,9 +3,8 @@ package app2 import ( "net/rpc" - "github.com/skycoin/skywire/pkg/routing" - "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" ) //go:generate mockery -name RPCClient -case underscore -inpkg diff --git a/pkg/app2/rpc_gateway.go b/pkg/app2/rpc_gateway.go index 176b187f8f..fab2245937 100644 --- a/pkg/app2/rpc_gateway.go +++ b/pkg/app2/rpc_gateway.go @@ -4,26 +4,25 @@ import ( "fmt" "net" - "github.com/skycoin/skywire/pkg/routing" - "github.com/pkg/errors" "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" ) // RPCGateway is a RPC interface for the app server. type RPCGateway struct { - lm *idManager - cm *idManager + lm *idManager // contains listeners associated with their IDs + cm *idManager // contains connections associated with their IDs log *logging.Logger } // newRPCGateway constructs new server RPC interface. func newRPCGateway(log *logging.Logger) *RPCGateway { return &RPCGateway{ - lm: newIDManager(), // contains listeners associated with their IDs - cm: newIDManager(), // contains connections associated with their IDs + lm: newIDManager(), + cm: newIDManager(), log: log, } } @@ -220,7 +219,7 @@ func (r *RPCGateway) popListener(lisID uint16) (net.Listener, error) { return nil, errors.Wrap(err, "no listener") } - return r.assertListener(lisIfc) + return assertListener(lisIfc) } // popConn gets conn from the manager by `connID` and removes it. @@ -231,7 +230,7 @@ func (r *RPCGateway) popConn(connID uint16) (net.Conn, error) { return nil, errors.Wrap(err, "no conn") } - return r.assertConn(connIfc) + return assertConn(connIfc) } // getListener gets listener from the manager by `lisID`. Handles type assertion. @@ -241,7 +240,7 @@ func (r *RPCGateway) getListener(lisID uint16) (net.Listener, error) { return nil, fmt.Errorf("no listener with key %d", lisID) } - return r.assertListener(lisIfc) + return assertListener(lisIfc) } // getConn gets conn from the manager by `connID`. Handles type assertion. @@ -251,25 +250,5 @@ func (r *RPCGateway) getConn(connID uint16) (net.Conn, error) { return nil, fmt.Errorf("no conn with key %d", connID) } - return r.assertConn(connIfc) -} - -// assertListener asserts that `v` is of type `net.Listener`. -func (r *RPCGateway) assertListener(v interface{}) (net.Listener, error) { - lis, ok := v.(net.Listener) - if !ok { - return nil, errors.New("wrong type of value stored for listener") - } - - return lis, nil -} - -// assertConn asserts that `v` is of type `net.Conn`. -func (r *RPCGateway) assertConn(v interface{}) (net.Conn, error) { - conn, ok := v.(net.Conn) - if !ok { - return nil, errors.New("wrong type of value stored for conn") - } - - return conn, nil + return assertConn(connIfc) } From 3a9fdba9c79e694bc22d076097bccf42e6cb597b Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 24 Sep 2019 15:38:24 +0300 Subject: [PATCH 38/43] Fix client tests --- pkg/app2/client_test.go | 117 +++++++++++++++++++++++++---- pkg/app2/conn.go | 6 +- pkg/app2/id_manager_test.go | 29 +++---- pkg/app2/listener.go | 4 +- pkg/app2/mock_rpc_client.go | 24 +++--- pkg/app2/network/mock_networker.go | 9 +-- pkg/app2/rpc_gateway_test.go | 23 +----- 7 files changed, 147 insertions(+), 65 deletions(-) diff --git a/pkg/app2/client_test.go b/pkg/app2/client_test.go index 6bc19e2a8d..6724cb4d27 100644 --- a/pkg/app2/client_test.go +++ b/pkg/app2/client_test.go @@ -3,6 +3,8 @@ package app2 import ( "testing" + "github.com/skycoin/skycoin/src/util/logging" + "github.com/pkg/errors" "github.com/skycoin/dmsg/cipher" "github.com/stretchr/testify/require" @@ -12,56 +14,101 @@ import ( ) func TestClient_Dial(t *testing.T) { + l := logging.MustGetLogger("app2_client") localPK, _ := cipher.GenerateKeyPair() pid := ProcID(1) remotePK, _ := cipher.GenerateKeyPair() remotePort := routing.Port(120) remote := network.Addr{ + Net: network.TypeDMSG, PubKey: remotePK, Port: remotePort, } t.Run("ok", func(t *testing.T) { dialConnID := uint16(1) + dialAssignedPort := routing.Port(1) var dialErr error rpc := &MockRPCClient{} - rpc.On("Dial", remote).Return(dialConnID, dialErr) + rpc.On("Dial", remote).Return(dialConnID, dialAssignedPort, dialErr) - cl := NewClient(localPK, pid, rpc) + cl := NewClient(l, localPK, pid, rpc) wantConn := &Conn{ id: dialConnID, rpc: rpc, local: network.Addr{ + Net: remote.Net, PubKey: localPK, + Port: dialAssignedPort, }, remote: remote, } conn, err := cl.Dial(remote) + require.NoError(t, err) + appConn, ok := conn.(*Conn) require.True(t, ok) - require.NoError(t, err) require.Equal(t, wantConn.id, appConn.id) require.Equal(t, wantConn.rpc, appConn.rpc) - require.Equal(t, wantConn.local.PubKey, appConn.local.PubKey) + require.Equal(t, wantConn.local, appConn.local) require.Equal(t, wantConn.remote, appConn.remote) - require.NotNil(t, appConn.freeLocalPort) - portVal, ok := cl.porter.PortValue(uint16(appConn.local.Port)) - require.True(t, ok) - require.Nil(t, portVal) + require.NotNil(t, appConn.freeConn) + }) + + t.Run("conn already exists", func(t *testing.T) { + dialConnID := uint16(1) + dialAssignedPort := routing.Port(1) + var dialErr error + + var closeErr error + + rpc := &MockRPCClient{} + rpc.On("Dial", remote).Return(dialConnID, dialAssignedPort, dialErr) + rpc.On("CloseConn", dialConnID).Return(closeErr) + + cl := NewClient(l, localPK, pid, rpc) + + _, err := cl.cm.add(dialConnID, nil) + require.NoError(t, err) + + conn, err := cl.Dial(remote) + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, conn) + }) + + t.Run("conn already exists, conn closed with error", func(t *testing.T) { + dialConnID := uint16(1) + dialAssignedPort := routing.Port(1) + var dialErr error + + closeErr := errors.New("close error") + + rpc := &MockRPCClient{} + rpc.On("Dial", remote).Return(dialConnID, dialAssignedPort, dialErr) + rpc.On("CloseConn", dialConnID).Return(closeErr) + + cl := NewClient(l, localPK, pid, rpc) + + _, err := cl.cm.add(dialConnID, nil) + require.NoError(t, err) + + conn, err := cl.Dial(remote) + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, conn) }) t.Run("dial error", func(t *testing.T) { dialErr := errors.New("dial error") rpc := &MockRPCClient{} - rpc.On("Dial", remote).Return(uint16(0), dialErr) + rpc.On("Dial", remote).Return(uint16(0), uint16(0), dialErr) - cl := NewClient(localPK, pid, rpc) + cl := NewClient(l, localPK, pid, rpc) conn, err := cl.Dial(remote) require.Equal(t, dialErr, err) @@ -70,6 +117,7 @@ func TestClient_Dial(t *testing.T) { } func TestClient_Listen(t *testing.T) { + l := logging.MustGetLogger("app2_client") localPK, _ := cipher.GenerateKeyPair() pid := ProcID(1) @@ -87,7 +135,7 @@ func TestClient_Listen(t *testing.T) { rpc := &MockRPCClient{} rpc.On("Listen", local).Return(listenLisID, listenErr) - cl := NewClient(localPK, pid, rpc) + cl := NewClient(l, localPK, pid, rpc) wantListener := &Listener{ id: listenLisID, @@ -97,12 +145,15 @@ func TestClient_Listen(t *testing.T) { listener, err := cl.Listen(network.TypeDMSG, port) require.Nil(t, err) + appListener, ok := listener.(*Listener) require.True(t, ok) + require.Equal(t, wantListener.id, appListener.id) require.Equal(t, wantListener.rpc, appListener.rpc) require.Equal(t, wantListener.addr, appListener.addr) require.NotNil(t, appListener.freePort) + require.NotNil(t, appListener.freeLis) portVal, ok := cl.porter.PortValue(uint16(port)) require.True(t, ok) require.Nil(t, portVal) @@ -111,7 +162,7 @@ func TestClient_Listen(t *testing.T) { t.Run("port is already bound", func(t *testing.T) { rpc := &MockRPCClient{} - cl := NewClient(localPK, pid, rpc) + cl := NewClient(l, localPK, pid, rpc) ok, _ := cl.porter.Reserve(uint16(port), nil) require.True(t, ok) @@ -123,13 +174,53 @@ func TestClient_Listen(t *testing.T) { require.Nil(t, listener) }) + t.Run("listener already exists", func(t *testing.T) { + listenLisID := uint16(1) + var listenErr error + + var closeErr error + + rpc := &MockRPCClient{} + rpc.On("Listen", local).Return(listenLisID, listenErr) + rpc.On("CloseListener", listenLisID).Return(closeErr) + + cl := NewClient(l, localPK, pid, rpc) + + _, err := cl.lm.add(listenLisID, nil) + require.NoError(t, err) + + listener, err := cl.Listen(network.TypeDMSG, port) + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, listener) + }) + + t.Run("listener already exists, listener closed with error", func(t *testing.T) { + listenLisID := uint16(1) + var listenErr error + + closeErr := errors.New("close error") + + rpc := &MockRPCClient{} + rpc.On("Listen", local).Return(listenLisID, listenErr) + rpc.On("CloseListener", listenLisID).Return(closeErr) + + cl := NewClient(l, localPK, pid, rpc) + + _, err := cl.lm.add(listenLisID, nil) + require.NoError(t, err) + + listener, err := cl.Listen(network.TypeDMSG, port) + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, listener) + }) + t.Run("listen error", func(t *testing.T) { listenErr := errors.New("listen error") rpc := &MockRPCClient{} rpc.On("Listen", local).Return(uint16(0), listenErr) - cl := NewClient(localPK, pid, rpc) + cl := NewClient(l, localPK, pid, rpc) listener, err := cl.Listen(network.TypeDMSG, port) require.Equal(t, listenErr, err) diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index 09847da724..7c0c8d2a93 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -34,7 +34,11 @@ func (c *Conn) Write(b []byte) (int, error) { } func (c *Conn) Close() error { - defer c.freeConn() + defer func() { + if c.freeConn != nil { + c.freeConn() + } + }() return c.rpc.CloseConn(c.id) } diff --git a/pkg/app2/id_manager_test.go b/pkg/app2/id_manager_test.go index 0b9de7bae4..2390cf88aa 100644 --- a/pkg/app2/id_manager_test.go +++ b/pkg/app2/id_manager_test.go @@ -11,21 +11,22 @@ func TestIDManager_NextID(t *testing.T) { t.Run("simple call", func(t *testing.T) { m := newIDManager() - nextKey, err := m.nextKey() + // TODO: use free + nextID, _, err := m.reserveNextID() require.NoError(t, err) - v, ok := m.values[*nextKey] + v, ok := m.values[*nextID] require.True(t, ok) require.Nil(t, v) - require.Equal(t, *nextKey, uint16(1)) - require.Equal(t, *nextKey, m.lstKey) + require.Equal(t, *nextID, uint16(1)) + require.Equal(t, *nextID, m.lstID) - nextKey, err = m.nextKey() + nextID, _, err = m.reserveNextID() require.NoError(t, err) - v, ok = m.values[*nextKey] + v, ok = m.values[*nextID] require.True(t, ok) require.Nil(t, v) - require.Equal(t, *nextKey, uint16(2)) - require.Equal(t, *nextKey, m.lstKey) + require.Equal(t, *nextID, uint16(2)) + require.Equal(t, *nextID, m.lstID) }) t.Run("call on full manager", func(t *testing.T) { @@ -35,7 +36,7 @@ func TestIDManager_NextID(t *testing.T) { } m.values[math.MaxUint16] = nil - _, err := m.nextKey() + _, _, err := m.reserveNextID() require.Error(t, err) }) @@ -47,7 +48,7 @@ func TestIDManager_NextID(t *testing.T) { errs := make(chan error) for i := 0; i < valsToReserve; i++ { go func() { - _, err := m.nextKey() + _, _, err := m.reserveNextID() errs <- err }() } @@ -57,7 +58,7 @@ func TestIDManager_NextID(t *testing.T) { } close(errs) - require.Equal(t, m.lstKey, uint16(valsToReserve)) + require.Equal(t, m.lstID, uint16(valsToReserve)) for i := uint16(1); i < uint16(valsToReserve); i++ { v, ok := m.values[i] require.True(t, ok) @@ -132,7 +133,7 @@ func TestIDManager_Set(t *testing.T) { t.Run("simple call", func(t *testing.T) { m := newIDManager() - nextKey, err := m.nextKey() + nextKey, _, err := m.reserveNextID() require.NoError(t, err) v := "value" @@ -173,7 +174,7 @@ func TestIDManager_Set(t *testing.T) { concurrency := 1000 - nextKeyPtr, err := m.nextKey() + nextKeyPtr, _, err := m.reserveNextID() require.NoError(t, err) nextKey := *nextKeyPtr @@ -212,7 +213,7 @@ func TestIDManager_Get(t *testing.T) { prepManagerWithVal := func(v interface{}) (*idManager, uint16) { m := newIDManager() - nextKey, err := m.nextKey() + nextKey, _, err := m.reserveNextID() require.NoError(t, err) err = m.set(*nextKey, v) diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index 41bc198e74..5f4497660b 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -50,7 +50,9 @@ func (l *Listener) Accept() (net.Conn, error) { func (l *Listener) Close() error { defer func() { l.freePort() - l.freeLis() + if l.freeLis != nil { + l.freeLis() + } var conns []net.Conn l.cm.doRange(func(_ uint16, v interface{}) bool { diff --git a/pkg/app2/mock_rpc_client.go b/pkg/app2/mock_rpc_client.go index af8f62fb6c..7caecc6aa8 100644 --- a/pkg/app2/mock_rpc_client.go +++ b/pkg/app2/mock_rpc_client.go @@ -2,10 +2,9 @@ package app2 -import ( - network "github.com/skycoin/skywire/pkg/app2/network" - mock "github.com/stretchr/testify/mock" -) +import mock "github.com/stretchr/testify/mock" +import network "github.com/skycoin/skywire/pkg/app2/network" +import routing "github.com/skycoin/skywire/pkg/routing" // MockRPCClient is an autogenerated mock type for the RPCClient type type MockRPCClient struct { @@ -69,7 +68,7 @@ func (_m *MockRPCClient) CloseListener(id uint16) error { } // Dial provides a mock function with given fields: remote -func (_m *MockRPCClient) Dial(remote network.Addr) (uint16, error) { +func (_m *MockRPCClient) Dial(remote network.Addr) (uint16, routing.Port, error) { ret := _m.Called(remote) var r0 uint16 @@ -79,14 +78,21 @@ func (_m *MockRPCClient) Dial(remote network.Addr) (uint16, error) { r0 = ret.Get(0).(uint16) } - var r1 error - if rf, ok := ret.Get(1).(func(network.Addr) error); ok { + var r1 routing.Port + if rf, ok := ret.Get(1).(func(network.Addr) routing.Port); ok { r1 = rf(remote) } else { - r1 = ret.Error(1) + r1 = ret.Get(1).(routing.Port) } - return r0, r1 + var r2 error + if rf, ok := ret.Get(2).(func(network.Addr) error); ok { + r2 = rf(remote) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 } // Listen provides a mock function with given fields: local diff --git a/pkg/app2/network/mock_networker.go b/pkg/app2/network/mock_networker.go index 560de13d97..fd4b304aea 100644 --- a/pkg/app2/network/mock_networker.go +++ b/pkg/app2/network/mock_networker.go @@ -2,12 +2,9 @@ package network -import ( - context "context" - net "net" - - mock "github.com/stretchr/testify/mock" -) +import context "context" +import mock "github.com/stretchr/testify/mock" +import net "net" // MockNetworker is an autogenerated mock type for the Networker type type MockNetworker struct { diff --git a/pkg/app2/rpc_gateway_test.go b/pkg/app2/rpc_gateway_test.go index 9e5301f231..8a007a80d1 100644 --- a/pkg/app2/rpc_gateway_test.go +++ b/pkg/app2/rpc_gateway_test.go @@ -1,26 +1,6 @@ package app2 -import ( - "context" - "math" - "net" - "strings" - "testing" - - "github.com/pkg/errors" - "github.com/stretchr/testify/require" - - "github.com/skycoin/dmsg" - - "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/skywire/pkg/routing" - - "github.com/skycoin/skywire/pkg/app2/network" - - "github.com/skycoin/skycoin/src/util/logging" -) - -func TestRPCGateway_Dial(t *testing.T) { +/*func TestRPCGateway_Dial(t *testing.T) { l := logging.MustGetLogger("rpc_gateway") nType := network.TypeDMSG @@ -542,3 +522,4 @@ func addListener(t *testing.T, rpc *RPCGateway, lis net.Listener) uint16 { return *lisID } +*/ From 8c02d73e26007223f547adc93b36a792de3115a4 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Tue, 24 Sep 2019 15:49:41 +0300 Subject: [PATCH 39/43] Partially fix `idManager` tests --- pkg/app2/client_test.go | 3 +-- pkg/app2/id_manager_test.go | 46 ++++++++++++++++++++++--------------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/pkg/app2/client_test.go b/pkg/app2/client_test.go index 6724cb4d27..a90dae35f8 100644 --- a/pkg/app2/client_test.go +++ b/pkg/app2/client_test.go @@ -3,10 +3,9 @@ package app2 import ( "testing" - "github.com/skycoin/skycoin/src/util/logging" - "github.com/pkg/errors" "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skycoin/src/util/logging" "github.com/stretchr/testify/require" "github.com/skycoin/skywire/pkg/app2/network" diff --git a/pkg/app2/id_manager_test.go b/pkg/app2/id_manager_test.go index 2390cf88aa..42ea38b8ae 100644 --- a/pkg/app2/id_manager_test.go +++ b/pkg/app2/id_manager_test.go @@ -2,6 +2,7 @@ package app2 import ( "math" + "strings" "testing" "github.com/stretchr/testify/require" @@ -11,17 +12,18 @@ func TestIDManager_NextID(t *testing.T) { t.Run("simple call", func(t *testing.T) { m := newIDManager() - // TODO: use free - nextID, _, err := m.reserveNextID() + nextID, free, err := m.reserveNextID() require.NoError(t, err) + require.NotNil(t, free) v, ok := m.values[*nextID] require.True(t, ok) require.Nil(t, v) require.Equal(t, *nextID, uint16(1)) require.Equal(t, *nextID, m.lstID) - nextID, _, err = m.reserveNextID() + nextID, free, err = m.reserveNextID() require.NoError(t, err) + require.NotNil(t, free) v, ok = m.values[*nextID] require.True(t, ok) require.Nil(t, v) @@ -89,6 +91,7 @@ func TestIDManager_Pop(t *testing.T) { _, err := m.pop(1) require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "no value")) }) t.Run("value not set", func(t *testing.T) { @@ -98,6 +101,7 @@ func TestIDManager_Pop(t *testing.T) { _, err := m.pop(1) require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "is not set")) }) t.Run("concurrent run", func(t *testing.T) { @@ -133,23 +137,24 @@ func TestIDManager_Set(t *testing.T) { t.Run("simple call", func(t *testing.T) { m := newIDManager() - nextKey, _, err := m.reserveNextID() + nextID, _, err := m.reserveNextID() require.NoError(t, err) v := "value" - err = m.set(*nextKey, v) + err = m.set(*nextID, v) require.NoError(t, err) - gotV, ok := m.values[*nextKey] + gotV, ok := m.values[*nextID] require.True(t, ok) require.Equal(t, gotV, v) }) - t.Run("key is not reserved", func(t *testing.T) { + t.Run("id is not reserved", func(t *testing.T) { m := newIDManager() err := m.set(1, "value") require.Error(t, err) + require.True(t, strings.Contains(err.Error(), "not reserved")) _, ok := m.values[1] require.False(t, ok) @@ -174,16 +179,16 @@ func TestIDManager_Set(t *testing.T) { concurrency := 1000 - nextKeyPtr, _, err := m.reserveNextID() + nextIDPtr, _, err := m.reserveNextID() require.NoError(t, err) - nextKey := *nextKeyPtr + nextID := *nextIDPtr errs := make(chan error) setV := make(chan int) for i := 0; i < concurrency; i++ { go func(v int) { - err := m.set(nextKey, v) + err := m.set(nextID, v) errs <- err if err == nil { setV <- v @@ -203,7 +208,7 @@ func TestIDManager_Set(t *testing.T) { v := <-setV close(setV) - gotV, ok := m.values[nextKey] + gotV, ok := m.values[nextID] require.True(t, ok) require.Equal(t, gotV, v) }) @@ -213,32 +218,37 @@ func TestIDManager_Get(t *testing.T) { prepManagerWithVal := func(v interface{}) (*idManager, uint16) { m := newIDManager() - nextKey, _, err := m.reserveNextID() + nextID, _, err := m.reserveNextID() require.NoError(t, err) - err = m.set(*nextKey, v) + err = m.set(*nextID, v) require.NoError(t, err) - return m, *nextKey + return m, *nextID } t.Run("simple call", func(t *testing.T) { v := "value" - m, key := prepManagerWithVal(v) + m, id := prepManagerWithVal(v) - gotV, ok := m.get(key) + gotV, ok := m.get(id) require.True(t, ok) require.Equal(t, gotV, v) _, ok = m.get(100) require.False(t, ok) + + m.values[2] = nil + gotV, ok = m.get(2) + require.False(t, ok) + require.Nil(t, gotV) }) t.Run("concurrent run", func(t *testing.T) { v := "value" - m, key := prepManagerWithVal(v) + m, id := prepManagerWithVal(v) concurrency := 1000 type getRes struct { @@ -248,7 +258,7 @@ func TestIDManager_Get(t *testing.T) { res := make(chan getRes) for i := 0; i < concurrency; i++ { go func() { - val, ok := m.get(key) + val, ok := m.get(id) res <- getRes{ v: val, ok: ok, From d7e140bc5f3caffaf1de5cfa93a37d18e412d106 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Wed, 25 Sep 2019 13:02:40 +0300 Subject: [PATCH 40/43] Get rid of porter --- pkg/app2/client.go | 45 +++++++++++++++++--------------------------- pkg/app2/listener.go | 14 ++++++-------- 2 files changed, 23 insertions(+), 36 deletions(-) diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 5f6e9317d1..686b42afd3 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -4,7 +4,6 @@ import ( "net" "github.com/skycoin/dmsg/cipher" - "github.com/skycoin/dmsg/netutil" "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/pkg/app2/network" @@ -13,13 +12,12 @@ import ( // Client is used by skywire apps. type Client struct { - log *logging.Logger - pk cipher.PubKey - pid ProcID - rpc RPCClient - lm *idManager // contains listeners associated with their IDs - cm *idManager // contains connections associated with their IDs - porter *netutil.Porter + log *logging.Logger + pk cipher.PubKey + pid ProcID + rpc RPCClient + lm *idManager // contains listeners associated with their IDs + cm *idManager // contains connections associated with their IDs } // NewClient creates a new `Client`. The `Client` needs to be provided with: @@ -29,13 +27,12 @@ type Client struct { // - rpc: RPC client to communicate with the server. func NewClient(log *logging.Logger, localPK cipher.PubKey, pid ProcID, rpc RPCClient) *Client { return &Client{ - log: log, - pk: localPK, - pid: pid, - rpc: rpc, - lm: newIDManager(), - cm: newIDManager(), - porter: netutil.NewPorter(netutil.PorterMinEphemeral), + log: log, + pk: localPK, + pid: pid, + rpc: rpc, + lm: newIDManager(), + cm: newIDManager(), } } @@ -73,11 +70,6 @@ func (c *Client) Dial(remote network.Addr) (net.Conn, error) { // Listen listens on the specified `port` for the incoming connections. func (c *Client) Listen(n network.Type, port routing.Port) (net.Listener, error) { - ok, freePort := c.porter.Reserve(uint16(port), nil) - if !ok { - return nil, ErrPortAlreadyBound - } - local := network.Addr{ Net: n, PubKey: c.pk, @@ -86,17 +78,15 @@ func (c *Client) Listen(n network.Type, port routing.Port) (net.Listener, error) lisID, err := c.rpc.Listen(local) if err != nil { - freePort() return nil, err } listener := &Listener{ - log: c.log, - id: lisID, - rpc: c.rpc, - addr: local, - cm: newIDManager(), - freePort: freePort, + log: c.log, + id: lisID, + rpc: c.rpc, + addr: local, + cm: newIDManager(), } freeLis, err := c.lm.add(lisID, listener) @@ -105,7 +95,6 @@ func (c *Client) Listen(n network.Type, port routing.Port) (net.Listener, error) c.log.WithError(err).Error("error closing listener") } - freePort() return nil, err } diff --git a/pkg/app2/listener.go b/pkg/app2/listener.go index 5f4497660b..9b20399d44 100644 --- a/pkg/app2/listener.go +++ b/pkg/app2/listener.go @@ -11,13 +11,12 @@ import ( // Listener is a listener for app server connections. // Implements `net.Listener`. type Listener struct { - log *logging.Logger - id uint16 - rpc RPCClient - addr network.Addr - cm *idManager // contains conns associated with their IDs - freePort func() - freeLis func() + log *logging.Logger + id uint16 + rpc RPCClient + addr network.Addr + cm *idManager // contains conns associated with their IDs + freeLis func() } func (l *Listener) Accept() (net.Conn, error) { @@ -49,7 +48,6 @@ func (l *Listener) Accept() (net.Conn, error) { func (l *Listener) Close() error { defer func() { - l.freePort() if l.freeLis != nil { l.freeLis() } From 17ba14623afb4a1b56e9da1b7cd17286bfbdf8c8 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Wed, 25 Sep 2019 13:16:20 +0300 Subject: [PATCH 41/43] Fix PR queries --- pkg/app2/client.go | 4 ++-- pkg/app2/client_test.go | 14 +++++++------- pkg/app2/network/addr.go | 4 ++-- pkg/app2/network/wrapped_conn.go | 4 ++-- pkg/app2/rpc_client.go | 4 ++-- pkg/app2/rpc_gateway.go | 8 ++++---- 6 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pkg/app2/client.go b/pkg/app2/client.go index 686b42afd3..3d8a5f595d 100644 --- a/pkg/app2/client.go +++ b/pkg/app2/client.go @@ -38,7 +38,7 @@ func NewClient(log *logging.Logger, localPK cipher.PubKey, pid ProcID, rpc RPCCl // Dial dials the remote node using `remote`. func (c *Client) Dial(remote network.Addr) (net.Conn, error) { - connID, assignedPort, err := c.rpc.Dial(remote) + connID, localPort, err := c.rpc.Dial(remote) if err != nil { return nil, err } @@ -49,7 +49,7 @@ func (c *Client) Dial(remote network.Addr) (net.Conn, error) { local: network.Addr{ Net: remote.Net, PubKey: c.pk, - Port: assignedPort, + Port: localPort, }, remote: remote, } diff --git a/pkg/app2/client_test.go b/pkg/app2/client_test.go index a90dae35f8..407ea90aae 100644 --- a/pkg/app2/client_test.go +++ b/pkg/app2/client_test.go @@ -27,11 +27,11 @@ func TestClient_Dial(t *testing.T) { t.Run("ok", func(t *testing.T) { dialConnID := uint16(1) - dialAssignedPort := routing.Port(1) + dialLocalPort := routing.Port(1) var dialErr error rpc := &MockRPCClient{} - rpc.On("Dial", remote).Return(dialConnID, dialAssignedPort, dialErr) + rpc.On("Dial", remote).Return(dialConnID, dialLocalPort, dialErr) cl := NewClient(l, localPK, pid, rpc) @@ -41,7 +41,7 @@ func TestClient_Dial(t *testing.T) { local: network.Addr{ Net: remote.Net, PubKey: localPK, - Port: dialAssignedPort, + Port: dialLocalPort, }, remote: remote, } @@ -61,13 +61,13 @@ func TestClient_Dial(t *testing.T) { t.Run("conn already exists", func(t *testing.T) { dialConnID := uint16(1) - dialAssignedPort := routing.Port(1) + dialLocalPort := routing.Port(1) var dialErr error var closeErr error rpc := &MockRPCClient{} - rpc.On("Dial", remote).Return(dialConnID, dialAssignedPort, dialErr) + rpc.On("Dial", remote).Return(dialConnID, dialLocalPort, dialErr) rpc.On("CloseConn", dialConnID).Return(closeErr) cl := NewClient(l, localPK, pid, rpc) @@ -82,13 +82,13 @@ func TestClient_Dial(t *testing.T) { t.Run("conn already exists, conn closed with error", func(t *testing.T) { dialConnID := uint16(1) - dialAssignedPort := routing.Port(1) + dialLocalPort := routing.Port(1) var dialErr error closeErr := errors.New("close error") rpc := &MockRPCClient{} - rpc.On("Dial", remote).Return(dialConnID, dialAssignedPort, dialErr) + rpc.On("Dial", remote).Return(dialConnID, dialLocalPort, dialErr) rpc.On("CloseConn", dialConnID).Return(closeErr) cl := NewClient(l, localPK, pid, rpc) diff --git a/pkg/app2/network/addr.go b/pkg/app2/network/addr.go index c3fa28020e..fcef27b92e 100644 --- a/pkg/app2/network/addr.go +++ b/pkg/app2/network/addr.go @@ -35,9 +35,9 @@ func (a Addr) String() string { return fmt.Sprintf("%s:%d", a.PubKey, a.Port) } -// WrapAddr asserts type of the passed `net.Addr` and converts it +// ConvertAddr asserts type of the passed `net.Addr` and converts it // to `Addr` if possible. -func WrapAddr(addr net.Addr) (Addr, error) { +func ConvertAddr(addr net.Addr) (Addr, error) { switch a := addr.(type) { case dmsg.Addr: return Addr{ diff --git a/pkg/app2/network/wrapped_conn.go b/pkg/app2/network/wrapped_conn.go index cdc1595f2c..d8e1f4df95 100644 --- a/pkg/app2/network/wrapped_conn.go +++ b/pkg/app2/network/wrapped_conn.go @@ -14,12 +14,12 @@ type WrappedConn struct { // WrapConn wraps passed `conn`. Handles `net.Addr` type assertion. func WrapConn(conn net.Conn) (net.Conn, error) { - l, err := WrapAddr(conn.LocalAddr()) + l, err := ConvertAddr(conn.LocalAddr()) if err != nil { return nil, err } - r, err := WrapAddr(conn.RemoteAddr()) + r, err := ConvertAddr(conn.RemoteAddr()) if err != nil { return nil, err } diff --git a/pkg/app2/rpc_client.go b/pkg/app2/rpc_client.go index c0440c134f..767cf304ac 100644 --- a/pkg/app2/rpc_client.go +++ b/pkg/app2/rpc_client.go @@ -11,7 +11,7 @@ import ( // RPCClient describes RPC interface to communicate with the server. type RPCClient interface { - Dial(remote network.Addr) (connID uint16, assignedPort routing.Port, err error) + Dial(remote network.Addr) (connID uint16, localPort routing.Port, err error) Listen(local network.Addr) (uint16, error) Accept(lisID uint16) (connID uint16, remote network.Addr, err error) Write(connID uint16, b []byte) (int, error) @@ -33,7 +33,7 @@ func NewRPCClient(rpc *rpc.Client) RPCClient { } // Dial sends `Dial` command to the server. -func (c *rpcCLient) Dial(remote network.Addr) (connID uint16, assignedPort routing.Port, err error) { +func (c *rpcCLient) Dial(remote network.Addr) (connID uint16, localPort routing.Port, err error) { var resp DialResp if err := c.rpc.Call("Dial", &remote, &resp); err != nil { return 0, 0, err diff --git a/pkg/app2/rpc_gateway.go b/pkg/app2/rpc_gateway.go index fab2245937..26795d564f 100644 --- a/pkg/app2/rpc_gateway.go +++ b/pkg/app2/rpc_gateway.go @@ -46,7 +46,7 @@ func (r *RPCGateway) Dial(remote *network.Addr, resp *DialResp) error { return err } - localAddr, err := network.WrapAddr(conn.LocalAddr()) + localAddr, err := network.ConvertAddr(conn.LocalAddr()) if err != nil { free() return err @@ -127,10 +127,10 @@ func (r *RPCGateway) Accept(lisID *uint16, resp *AcceptResp) error { return err } - remote, ok := conn.RemoteAddr().(network.Addr) - if !ok { + remote, err := network.ConvertAddr(conn.RemoteAddr()) + if err != nil { free() - return errors.New("wrong type for remote addr") + return err } resp.Remote = remote From d56035181740ca164d1ec46499267a792a950dc5 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Wed, 25 Sep 2019 15:20:38 +0300 Subject: [PATCH 42/43] Fix tests --- pkg/app2/client_test.go | 94 +++++++++++++----- pkg/app2/id_manager_test.go | 141 +++++++++++++++++++++++++- pkg/app2/listener_test.go | 186 +++++++++++++++++++++++++++++------ pkg/app2/network/addr.go | 4 +- pkg/app2/rpc_client.go | 2 +- pkg/app2/rpc_gateway.go | 30 +++--- pkg/app2/rpc_gateway_test.go | 91 ++++++++++++----- 7 files changed, 452 insertions(+), 96 deletions(-) diff --git a/pkg/app2/client_test.go b/pkg/app2/client_test.go index 407ea90aae..56e571cbfd 100644 --- a/pkg/app2/client_test.go +++ b/pkg/app2/client_test.go @@ -57,6 +57,14 @@ func TestClient_Dial(t *testing.T) { require.Equal(t, wantConn.local, appConn.local) require.Equal(t, wantConn.remote, appConn.remote) require.NotNil(t, appConn.freeConn) + + cmConnIfc, ok := cl.cm.values[appConn.id] + require.True(t, ok) + require.NotNil(t, cmConnIfc) + + cmConn, ok := cmConnIfc.(*Conn) + require.True(t, ok) + require.NotNil(t, cmConn.freeConn) }) t.Run("conn already exists", func(t *testing.T) { @@ -105,7 +113,7 @@ func TestClient_Dial(t *testing.T) { dialErr := errors.New("dial error") rpc := &MockRPCClient{} - rpc.On("Dial", remote).Return(uint16(0), uint16(0), dialErr) + rpc.On("Dial", remote).Return(uint16(0), routing.Port(0), dialErr) cl := NewClient(l, localPK, pid, rpc) @@ -151,26 +159,7 @@ func TestClient_Listen(t *testing.T) { require.Equal(t, wantListener.id, appListener.id) require.Equal(t, wantListener.rpc, appListener.rpc) require.Equal(t, wantListener.addr, appListener.addr) - require.NotNil(t, appListener.freePort) require.NotNil(t, appListener.freeLis) - portVal, ok := cl.porter.PortValue(uint16(port)) - require.True(t, ok) - require.Nil(t, portVal) - }) - - t.Run("port is already bound", func(t *testing.T) { - rpc := &MockRPCClient{} - - cl := NewClient(l, localPK, pid, rpc) - - ok, _ := cl.porter.Reserve(uint16(port), nil) - require.True(t, ok) - - wantErr := ErrPortAlreadyBound - - listener, err := cl.Listen(network.TypeDMSG, port) - require.Equal(t, wantErr, err) - require.Nil(t, listener) }) t.Run("listener already exists", func(t *testing.T) { @@ -224,7 +213,68 @@ func TestClient_Listen(t *testing.T) { listener, err := cl.Listen(network.TypeDMSG, port) require.Equal(t, listenErr, err) require.Nil(t, listener) - _, ok := cl.porter.PortValue(uint16(port)) - require.False(t, ok) }) } + +func TestClient_Close(t *testing.T) { + l := logging.MustGetLogger("app2_client") + localPK, _ := cipher.GenerateKeyPair() + pid := ProcID(1) + + var closeNoErr error + closeErr := errors.New("close error") + + rpc := &MockRPCClient{} + + lisID1 := uint16(1) + lisID2 := uint16(2) + + rpc.On("CloseListener", lisID1).Return(closeNoErr) + rpc.On("CloseListener", lisID2).Return(closeErr) + + lm := newIDManager() + + lis1 := &Listener{id: lisID1, rpc: rpc, cm: newIDManager()} + freeLis1, err := lm.add(lisID1, lis1) + require.NoError(t, err) + lis1.freeLis = freeLis1 + + lis2 := &Listener{id: lisID2, rpc: rpc, cm: newIDManager()} + freeLis2, err := lm.add(lisID2, lis2) + require.NoError(t, err) + lis2.freeLis = freeLis2 + + connID1 := uint16(1) + connID2 := uint16(2) + + rpc.On("CloseConn", connID1).Return(closeNoErr) + rpc.On("CloseConn", connID2).Return(closeErr) + + cm := newIDManager() + + conn1 := &Conn{id: connID1, rpc: rpc} + freeConn1, err := cm.add(connID1, conn1) + require.NoError(t, err) + conn1.freeConn = freeConn1 + + conn2 := &Conn{id: connID2, rpc: rpc} + freeConn2, err := cm.add(connID2, conn2) + require.NoError(t, err) + conn2.freeConn = freeConn2 + + cl := NewClient(l, localPK, pid, rpc) + cl.cm = cm + cl.lm = lm + + cl.Close() + + _, ok := lm.values[lisID1] + require.False(t, ok) + _, ok = lm.values[lisID2] + require.False(t, ok) + + _, ok = cm.values[connID1] + require.False(t, ok) + _, ok = cm.values[connID2] + require.False(t, ok) +} diff --git a/pkg/app2/id_manager_test.go b/pkg/app2/id_manager_test.go index 42ea38b8ae..20513ea456 100644 --- a/pkg/app2/id_manager_test.go +++ b/pkg/app2/id_manager_test.go @@ -2,13 +2,14 @@ package app2 import ( "math" + "sort" "strings" "testing" "github.com/stretchr/testify/require" ) -func TestIDManager_NextID(t *testing.T) { +func TestIDManager_ReserveNextID(t *testing.T) { t.Run("simple call", func(t *testing.T) { m := newIDManager() @@ -133,6 +134,70 @@ func TestIDManager_Pop(t *testing.T) { }) } +func TestIDManager_Add(t *testing.T) { + t.Run("simple call", func(t *testing.T) { + m := newIDManager() + + id := uint16(1) + v := "value" + + free, err := m.add(id, v) + require.Nil(t, err) + require.NotNil(t, free) + + gotV, ok := m.values[id] + require.True(t, ok) + require.Equal(t, gotV, v) + + v2 := "value2" + + free, err = m.add(id, v2) + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, free) + + gotV, ok = m.values[id] + require.True(t, ok) + require.Equal(t, gotV, v) + }) + + t.Run("concurrent run", func(t *testing.T) { + m := newIDManager() + + id := uint16(1) + + concurrency := 1000 + + addV := make(chan int) + errs := make(chan error) + for i := 0; i < concurrency; i++ { + go func(v int) { + _, err := m.add(id, v) + errs <- err + if err == nil { + addV <- v + } + }(i) + } + + errsCount := 0 + for i := 0; i < concurrency; i++ { + if err := <-errs; err != nil { + errsCount++ + } + } + close(errs) + + v := <-addV + close(addV) + + require.Equal(t, concurrency-1, errsCount) + + gotV, ok := m.values[id] + require.True(t, ok) + require.Equal(t, gotV, v) + }) +} + func TestIDManager_Set(t *testing.T) { t.Run("simple call", func(t *testing.T) { m := newIDManager() @@ -208,6 +273,8 @@ func TestIDManager_Set(t *testing.T) { v := <-setV close(setV) + require.Equal(t, concurrency-1, errsCount) + gotV, ok := m.values[nextID] require.True(t, ok) require.Equal(t, gotV, v) @@ -274,3 +341,75 @@ func TestIDManager_Get(t *testing.T) { close(res) }) } + +func TestIDManager_DoRange(t *testing.T) { + m := newIDManager() + + valsCount := 5 + + vals := make([]int, 0, valsCount) + for i := 0; i < valsCount; i++ { + vals = append(vals, i) + } + + for i, v := range vals { + _, err := m.add(uint16(i), v) + require.NoError(t, err) + } + + // run full range + gotVals := make([]int, 0, valsCount) + m.doRange(func(_ uint16, v interface{}) bool { + val, ok := v.(int) + require.True(t, ok) + + gotVals = append(gotVals, val) + + return true + }) + sort.Ints(gotVals) + require.Equal(t, gotVals, vals) + + // run part range + var gotVal int + gotValsCount := 0 + m.doRange(func(_ uint16, v interface{}) bool { + if gotValsCount == 1 { + return false + } + + val, ok := v.(int) + require.True(t, ok) + + gotVal = val + + gotValsCount++ + + return true + }) + + found := false + for _, v := range vals { + if v == gotVal { + found = true + } + } + require.True(t, found) +} + +func TestIDManager_ConstructFreeFunc(t *testing.T) { + m := newIDManager() + + id := uint16(1) + v := "value" + + free, err := m.add(id, v) + require.NoError(t, err) + require.NotNil(t, free) + + free() + + gotV, ok := m.values[id] + require.False(t, ok) + require.Nil(t, gotV) +} diff --git a/pkg/app2/listener_test.go b/pkg/app2/listener_test.go index 058b070657..9ff53cf7b4 100644 --- a/pkg/app2/listener_test.go +++ b/pkg/app2/listener_test.go @@ -5,6 +5,7 @@ import ( "github.com/pkg/errors" "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skycoin/src/util/logging" "github.com/stretchr/testify/require" "github.com/skycoin/skywire/pkg/app2/network" @@ -12,6 +13,8 @@ import ( ) func TestListener_Accept(t *testing.T) { + l := logging.MustGetLogger("app2_listener") + lisID := uint16(1) localPK, _ := cipher.GenerateKeyPair() local := network.Addr{ @@ -22,10 +25,10 @@ func TestListener_Accept(t *testing.T) { t.Run("ok", func(t *testing.T) { acceptConnID := uint16(1) - remotePK, _ := cipher.GenerateKeyPair() + acceptRemotePK, _ := cipher.GenerateKeyPair() acceptRemote := network.Addr{ Net: network.TypeDMSG, - PubKey: remotePK, + PubKey: acceptRemotePK, Port: routing.Port(100), } var acceptErr error @@ -37,6 +40,7 @@ func TestListener_Accept(t *testing.T) { id: lisID, rpc: rpc, addr: local, + cm: newIDManager(), } wantConn := &Conn{ @@ -48,7 +52,82 @@ func TestListener_Accept(t *testing.T) { conn, err := lis.Accept() require.NoError(t, err) - require.Equal(t, conn, wantConn) + + appConn, ok := conn.(*Conn) + require.True(t, ok) + require.Equal(t, wantConn.id, appConn.id) + require.Equal(t, wantConn.rpc, appConn.rpc) + require.Equal(t, wantConn.local, appConn.local) + require.Equal(t, wantConn.remote, appConn.remote) + require.NotNil(t, appConn.freeConn) + + connIfc, ok := lis.cm.values[acceptConnID] + require.True(t, ok) + + appConn, ok = connIfc.(*Conn) + require.True(t, ok) + require.NotNil(t, appConn.freeConn) + }) + + t.Run("conn already exists", func(t *testing.T) { + acceptConnID := uint16(1) + acceptRemotePK, _ := cipher.GenerateKeyPair() + acceptRemote := network.Addr{ + Net: network.TypeDMSG, + PubKey: acceptRemotePK, + Port: routing.Port(100), + } + var acceptErr error + + var closeErr error + + rpc := &MockRPCClient{} + rpc.On("Accept", acceptConnID).Return(acceptConnID, acceptRemote, acceptErr) + rpc.On("CloseConn", acceptConnID).Return(closeErr) + + lis := &Listener{ + id: lisID, + rpc: rpc, + addr: local, + cm: newIDManager(), + } + + lis.cm.values[acceptConnID] = nil + + conn, err := lis.Accept() + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, conn) + }) + + t.Run("conn already exists, conn closed with error", func(t *testing.T) { + acceptConnID := uint16(1) + acceptRemotePK, _ := cipher.GenerateKeyPair() + acceptRemote := network.Addr{ + Net: network.TypeDMSG, + PubKey: acceptRemotePK, + Port: routing.Port(100), + } + var acceptErr error + + closeErr := errors.New("close error") + + rpc := &MockRPCClient{} + rpc.On("Accept", acceptConnID).Return(acceptConnID, acceptRemote, acceptErr) + rpc.On("CloseConn", acceptConnID).Return(closeErr) + + lis := &Listener{ + log: l, + id: lisID, + rpc: rpc, + addr: local, + cm: newIDManager(), + } + + lis.cm.values[acceptConnID] = nil + + conn, err := lis.Accept() + require.Equal(t, err, errValueAlreadyExists) + require.Nil(t, conn) }) t.Run("accept error", func(t *testing.T) { @@ -63,6 +142,7 @@ func TestListener_Accept(t *testing.T) { id: lisID, rpc: rpc, addr: local, + cm: newIDManager(), } conn, err := lis.Accept() @@ -72,6 +152,8 @@ func TestListener_Accept(t *testing.T) { } func TestListener_Close(t *testing.T) { + l := logging.MustGetLogger("app2_listener") + lisID := uint16(1) localPK, _ := cipher.GenerateKeyPair() local := network.Addr{ @@ -80,33 +162,75 @@ func TestListener_Close(t *testing.T) { Port: routing.Port(100), } - tt := []struct { - name string - closeErr error - }{ - { - name: "ok", - }, - { - name: "close error", - closeErr: errors.New("close error"), - }, - } + t.Run("ok", func(t *testing.T) { + var closeNoErr error + closeErr := errors.New("close error") - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - rpc := &MockRPCClient{} - rpc.On("CloseListener", lisID).Return(tc.closeErr) - - lis := &Listener{ - id: lisID, - rpc: rpc, - addr: local, - freePort: func() {}, - } - - err := lis.Close() - require.Equal(t, tc.closeErr, err) - }) - } + rpc := &MockRPCClient{} + rpc.On("CloseListener", lisID).Return(closeNoErr) + + cm := newIDManager() + + connID1 := uint16(1) + connID2 := uint16(2) + connID3 := uint16(3) + + rpc.On("CloseConn", connID1).Return(closeNoErr) + rpc.On("CloseConn", connID2).Return(closeErr) + rpc.On("CloseConn", connID3).Return(closeNoErr) + + conn1 := &Conn{id: connID1, rpc: rpc} + free1, err := cm.add(connID1, conn1) + require.NoError(t, err) + conn1.freeConn = free1 + + conn2 := &Conn{id: connID2, rpc: rpc} + free2, err := cm.add(connID2, conn2) + require.NoError(t, err) + conn2.freeConn = free2 + + conn3 := &Conn{id: connID3, rpc: rpc} + free3, err := cm.add(connID3, conn3) + require.NoError(t, err) + conn3.freeConn = free3 + + lis := &Listener{ + log: l, + id: lisID, + rpc: rpc, + addr: local, + cm: cm, + freeLis: func() {}, + } + + err = lis.Close() + require.NoError(t, err) + + _, ok := lis.cm.values[connID1] + require.False(t, ok) + + _, ok = lis.cm.values[connID2] + require.False(t, ok) + + _, ok = lis.cm.values[connID3] + require.False(t, ok) + }) + + t.Run("close error", func(t *testing.T) { + lisCloseErr := errors.New("close error") + + rpc := &MockRPCClient{} + rpc.On("CloseListener", lisID).Return(lisCloseErr) + + lis := &Listener{ + log: l, + id: lisID, + rpc: rpc, + addr: local, + cm: newIDManager(), + } + + err := lis.Close() + require.Equal(t, err, lisCloseErr) + }) } diff --git a/pkg/app2/network/addr.go b/pkg/app2/network/addr.go index fcef27b92e..d96aabc2a2 100644 --- a/pkg/app2/network/addr.go +++ b/pkg/app2/network/addr.go @@ -12,7 +12,7 @@ import ( ) var ( - errUnknownAddrType = errors.New("addr type is unknown") + ErrUnknownAddrType = errors.New("addr type is unknown") ) // Addr implements net.Addr for network addresses. @@ -46,6 +46,6 @@ func ConvertAddr(addr net.Addr) (Addr, error) { Port: routing.Port(a.Port), }, nil default: - return Addr{}, errUnknownAddrType + return Addr{}, ErrUnknownAddrType } } diff --git a/pkg/app2/rpc_client.go b/pkg/app2/rpc_client.go index 767cf304ac..6fc7b240b6 100644 --- a/pkg/app2/rpc_client.go +++ b/pkg/app2/rpc_client.go @@ -39,7 +39,7 @@ func (c *rpcCLient) Dial(remote network.Addr) (connID uint16, localPort routing. return 0, 0, err } - return resp.ConnID, resp.AssignedPort, nil + return resp.ConnID, resp.LocalPort, nil } // Listen sends `Listen` command to the server. diff --git a/pkg/app2/rpc_gateway.go b/pkg/app2/rpc_gateway.go index 26795d564f..dd1387131c 100644 --- a/pkg/app2/rpc_gateway.go +++ b/pkg/app2/rpc_gateway.go @@ -29,8 +29,8 @@ func newRPCGateway(log *logging.Logger) *RPCGateway { // DialResp contains response parameters for `Dial`. type DialResp struct { - ConnID uint16 - AssignedPort routing.Port + ConnID uint16 + LocalPort routing.Port } // Dial dials to the remote. @@ -46,14 +46,14 @@ func (r *RPCGateway) Dial(remote *network.Addr, resp *DialResp) error { return err } - localAddr, err := network.ConvertAddr(conn.LocalAddr()) + wrappedConn, err := network.WrapConn(conn) if err != nil { free() return err } - if err := r.cm.set(*reservedConnID, conn); err != nil { - if err := conn.Close(); err != nil { + if err := r.cm.set(*reservedConnID, wrappedConn); err != nil { + if err := wrappedConn.Close(); err != nil { r.log.WithError(err).Error("error closing conn") } @@ -61,8 +61,10 @@ func (r *RPCGateway) Dial(remote *network.Addr, resp *DialResp) error { return err } + localAddr := wrappedConn.LocalAddr().(network.Addr) + resp.ConnID = *reservedConnID - resp.AssignedPort = localAddr.Port + resp.LocalPort = localAddr.Port return nil } @@ -118,21 +120,23 @@ func (r *RPCGateway) Accept(lisID *uint16, resp *AcceptResp) error { return err } - if err := r.cm.set(*connID, conn); err != nil { - if err := conn.Close(); err != nil { - r.log.WithError(err).Error("error closing DMSG transport") - } - + wrappedConn, err := network.WrapConn(conn) + if err != nil { free() return err } - remote, err := network.ConvertAddr(conn.RemoteAddr()) - if err != nil { + if err := r.cm.set(*connID, wrappedConn); err != nil { + if err := wrappedConn.Close(); err != nil { + r.log.WithError(err).Error("error closing DMSG transport") + } + free() return err } + remote := wrappedConn.RemoteAddr().(network.Addr) + resp.Remote = remote resp.ConnID = *connID diff --git a/pkg/app2/rpc_gateway_test.go b/pkg/app2/rpc_gateway_test.go index 8a007a80d1..637100a740 100644 --- a/pkg/app2/rpc_gateway_test.go +++ b/pkg/app2/rpc_gateway_test.go @@ -1,6 +1,23 @@ package app2 -/*func TestRPCGateway_Dial(t *testing.T) { +import ( + "context" + "math" + "net" + "strings" + "testing" + + "github.com/pkg/errors" + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skycoin/src/util/logging" + "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" +) + +func TestRPCGateway_Dial(t *testing.T) { l := logging.MustGetLogger("rpc_gateway") nType := network.TypeDMSG @@ -9,8 +26,10 @@ package app2 t.Run("ok", func(t *testing.T) { network.ClearNetworkers() + localPort := routing.Port(100) + dialCtx := context.Background() - dialConn := &dmsg.Transport{} + dialConn := dmsg.NewTransport(nil, nil, dmsg.Addr{Port: uint16(localPort)}, dmsg.Addr{}, 0, func() {}) var dialErr error n := &network.MockNetworker{} @@ -21,11 +40,11 @@ package app2 rpc := newRPCGateway(l) - var connID uint16 - - err = rpc.Dial(&dialAddr, &connID) + var resp DialResp + err = rpc.Dial(&dialAddr, &resp) require.NoError(t, err) - require.Equal(t, connID, uint16(1)) + require.Equal(t, resp.ConnID, uint16(1)) + require.Equal(t, resp.LocalPort, localPort) }) t.Run("no more slots for a new conn", func(t *testing.T) { @@ -35,9 +54,8 @@ package app2 } rpc.cm.values[math.MaxUint16] = nil - var connID uint16 - - err := rpc.Dial(&dialAddr, &connID) + var resp DialResp + err := rpc.Dial(&dialAddr, &resp) require.Equal(t, err, errNoMoreAvailableValues) }) @@ -56,11 +74,32 @@ package app2 rpc := newRPCGateway(l) - var connID uint16 - - err = rpc.Dial(&dialAddr, &connID) + var resp DialResp + err = rpc.Dial(&dialAddr, &resp) require.Equal(t, err, dialErr) }) + + t.Run("error wrapping conn", func(t *testing.T) { + network.ClearNetworkers() + + dialCtx := context.Background() + dialConn := &MockConn{} + dialConn.On("LocalAddr").Return(routing.Addr{}) + dialConn.On("RemoteAddr").Return(routing.Addr{}) + var dialErr error + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, dialAddr).Return(dialConn, dialErr) + + err := network.AddNetworker(nType, n) + require.NoError(t, err) + + rpc := newRPCGateway(l) + + var resp DialResp + err = rpc.Dial(&dialAddr, &resp) + require.Equal(t, err, network.ErrUnknownAddrType) + }) } func TestRPCGateway_Listen(t *testing.T) { @@ -132,7 +171,7 @@ func TestRPCGateway_Accept(t *testing.T) { t.Run("ok", func(t *testing.T) { rpc := newRPCGateway(l) - acceptConn := network.NewDMSGConn(&dmsg.Transport{}) + acceptConn := &dmsg.Transport{} var acceptErr error lis := &MockListener{} @@ -143,7 +182,7 @@ func TestRPCGateway_Accept(t *testing.T) { var resp AcceptResp err := rpc.Accept(&lisID, &resp) require.NoError(t, err) - require.Equal(t, resp.Remote, acceptConn.RemoteAddr()) + require.Equal(t, resp.Remote, network.Addr{Net: network.TypeDMSG}) }) t.Run("no such listener", func(t *testing.T) { @@ -182,11 +221,13 @@ func TestRPCGateway_Accept(t *testing.T) { require.Equal(t, err, errNoMoreAvailableValues) }) - t.Run("accept error", func(t *testing.T) { + t.Run("error wrapping conn", func(t *testing.T) { rpc := newRPCGateway(l) - var acceptConn net.Conn - acceptErr := errors.New("accept error") + acceptConn := &MockConn{} + acceptConn.On("LocalAddr").Return(routing.Addr{}) + acceptConn.On("RemoteAddr").Return(routing.Addr{}) + var acceptErr error lis := &MockListener{} lis.On("Accept").Return(acceptConn, acceptErr) @@ -195,14 +236,14 @@ func TestRPCGateway_Accept(t *testing.T) { var resp AcceptResp err := rpc.Accept(&lisID, &resp) - require.Equal(t, err, acceptErr) + require.Equal(t, err, network.ErrUnknownAddrType) }) - t.Run("wrong type of remote addr", func(t *testing.T) { + t.Run("accept error", func(t *testing.T) { rpc := newRPCGateway(l) - acceptConn := &dmsg.Transport{} - var acceptErr error + var acceptConn net.Conn + acceptErr := errors.New("accept error") lis := &MockListener{} lis.On("Accept").Return(acceptConn, acceptErr) @@ -211,8 +252,7 @@ func TestRPCGateway_Accept(t *testing.T) { var resp AcceptResp err := rpc.Accept(&lisID, &resp) - require.Error(t, err) - require.True(t, strings.Contains(err.Error(), "wrong type")) + require.Equal(t, err, acceptErr) }) } @@ -504,7 +544,7 @@ func prepAddr(nType network.Type) network.Addr { } func addConn(t *testing.T, rpc *RPCGateway, conn net.Conn) uint16 { - connID, err := rpc.cm.nextKey() + connID, _, err := rpc.cm.reserveNextID() require.NoError(t, err) err = rpc.cm.set(*connID, conn) @@ -514,7 +554,7 @@ func addConn(t *testing.T, rpc *RPCGateway, conn net.Conn) uint16 { } func addListener(t *testing.T, rpc *RPCGateway, lis net.Listener) uint16 { - lisID, err := rpc.lm.nextKey() + lisID, _, err := rpc.lm.reserveNextID() require.NoError(t, err) err = rpc.lm.set(*lisID, lis) @@ -522,4 +562,3 @@ func addListener(t *testing.T, rpc *RPCGateway, lis net.Listener) uint16 { return *lisID } -*/ From 662c229faaa47f926e04a06a48500555b8033085 Mon Sep 17 00:00:00 2001 From: Sir Darkrengarius Date: Thu, 26 Sep 2019 15:12:53 +0300 Subject: [PATCH 43/43] Add rpcClient tests --- pkg/app2/conn.go | 5 +- pkg/app2/conn_test.go | 22 +- pkg/app2/mock_rpc_client.go | 19 +- pkg/app2/network/dmsg_networker.go | 7 +- pkg/app2/rpc_client.go | 22 +- pkg/app2/rpc_client_test.go | 494 +++++++++++++++++++++++++++++ 6 files changed, 520 insertions(+), 49 deletions(-) create mode 100644 pkg/app2/rpc_client_test.go diff --git a/pkg/app2/conn.go b/pkg/app2/conn.go index 7c0c8d2a93..e6473a9eaf 100644 --- a/pkg/app2/conn.go +++ b/pkg/app2/conn.go @@ -18,14 +18,11 @@ type Conn struct { } func (c *Conn) Read(b []byte) (int, error) { - n, readBytes, err := c.rpc.Read(c.id, b) + n, err := c.rpc.Read(c.id, b) if err != nil { return 0, err } - // TODO: check for slice border - copy(b[:n], readBytes[:n]) - return n, err } diff --git a/pkg/app2/conn_test.go b/pkg/app2/conn_test.go index ef860a6983..2b185065c7 100644 --- a/pkg/app2/conn_test.go +++ b/pkg/app2/conn_test.go @@ -11,32 +11,27 @@ func TestConn_Read(t *testing.T) { connID := uint16(1) tt := []struct { - name string - readBuff []byte - readN int - readBytes []byte - readErr error - wantBuff []byte + name string + readBuff []byte + readN int + readErr error }{ { - name: "ok", - readBuff: make([]byte, 10), - readN: 2, - readBytes: []byte{1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, - wantBuff: []byte{1, 1, 0, 0, 0, 0, 0, 0, 0, 0}, + name: "ok", + readBuff: make([]byte, 10), + readN: 2, }, { name: "read error", readBuff: make([]byte, 10), readErr: errors.New("read error"), - wantBuff: make([]byte, 10), }, } for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { rpc := &MockRPCClient{} - rpc.On("Read", connID, tc.readBuff).Return(tc.readN, tc.readBytes, tc.readErr) + rpc.On("Read", connID, tc.readBuff).Return(tc.readN, tc.readErr) conn := &Conn{ id: connID, @@ -46,7 +41,6 @@ func TestConn_Read(t *testing.T) { n, err := conn.Read(tc.readBuff) require.Equal(t, tc.readErr, err) require.Equal(t, tc.readN, n) - require.Equal(t, tc.wantBuff, tc.readBuff) }) } } diff --git a/pkg/app2/mock_rpc_client.go b/pkg/app2/mock_rpc_client.go index 7caecc6aa8..bbf373f937 100644 --- a/pkg/app2/mock_rpc_client.go +++ b/pkg/app2/mock_rpc_client.go @@ -117,7 +117,7 @@ func (_m *MockRPCClient) Listen(local network.Addr) (uint16, error) { } // Read provides a mock function with given fields: connID, b -func (_m *MockRPCClient) Read(connID uint16, b []byte) (int, []byte, error) { +func (_m *MockRPCClient) Read(connID uint16, b []byte) (int, error) { ret := _m.Called(connID, b) var r0 int @@ -127,23 +127,14 @@ func (_m *MockRPCClient) Read(connID uint16, b []byte) (int, []byte, error) { r0 = ret.Get(0).(int) } - var r1 []byte - if rf, ok := ret.Get(1).(func(uint16, []byte) []byte); ok { + var r1 error + if rf, ok := ret.Get(1).(func(uint16, []byte) error); ok { r1 = rf(connID, b) } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).([]byte) - } - } - - var r2 error - if rf, ok := ret.Get(2).(func(uint16, []byte) error); ok { - r2 = rf(connID, b) - } else { - r2 = ret.Error(2) + r1 = ret.Error(1) } - return r0, r1, r2 + return r0, r1 } // Write provides a mock function with given fields: connID, b diff --git a/pkg/app2/network/dmsg_networker.go b/pkg/app2/network/dmsg_networker.go index e87ef55128..424b0df466 100644 --- a/pkg/app2/network/dmsg_networker.go +++ b/pkg/app2/network/dmsg_networker.go @@ -26,12 +26,7 @@ func (n *DMSGNetworker) Dial(addr Addr) (net.Conn, error) { // DialContext dials remote `addr` via dmsg network with context. func (n *DMSGNetworker) DialContext(ctx context.Context, addr Addr) (net.Conn, error) { - tp, err := n.dmsgC.Dial(ctx, addr.PubKey, uint16(addr.Port)) - if err != nil { - return nil, err - } - - return WrapConn(tp) + return n.dmsgC.Dial(ctx, addr.PubKey, uint16(addr.Port)) } // Listen starts listening on local `addr` in the dmsg network. diff --git a/pkg/app2/rpc_client.go b/pkg/app2/rpc_client.go index 6fc7b240b6..a97773f279 100644 --- a/pkg/app2/rpc_client.go +++ b/pkg/app2/rpc_client.go @@ -15,7 +15,7 @@ type RPCClient interface { Listen(local network.Addr) (uint16, error) Accept(lisID uint16) (connID uint16, remote network.Addr, err error) Write(connID uint16, b []byte) (int, error) - Read(connID uint16, b []byte) (int, []byte, error) + Read(connID uint16, b []byte) (int, error) CloseConn(id uint16) error CloseListener(id uint16) error } @@ -35,7 +35,7 @@ func NewRPCClient(rpc *rpc.Client) RPCClient { // Dial sends `Dial` command to the server. func (c *rpcCLient) Dial(remote network.Addr) (connID uint16, localPort routing.Port, err error) { var resp DialResp - if err := c.rpc.Call("Dial", &remote, &resp); err != nil { + if err := c.rpc.Call("RPCGateway.Dial", &remote, &resp); err != nil { return 0, 0, err } @@ -45,7 +45,7 @@ func (c *rpcCLient) Dial(remote network.Addr) (connID uint16, localPort routing. // Listen sends `Listen` command to the server. func (c *rpcCLient) Listen(local network.Addr) (uint16, error) { var lisID uint16 - if err := c.rpc.Call("Listen", &local, &lisID); err != nil { + if err := c.rpc.Call("RPCGateway.Listen", &local, &lisID); err != nil { return 0, err } @@ -55,7 +55,7 @@ func (c *rpcCLient) Listen(local network.Addr) (uint16, error) { // Accept sends `Accept` command to the server. func (c *rpcCLient) Accept(lisID uint16) (connID uint16, remote network.Addr, err error) { var acceptResp AcceptResp - if err := c.rpc.Call("Accept", &lisID, &acceptResp); err != nil { + if err := c.rpc.Call("RPCGateway.Accept", &lisID, &acceptResp); err != nil { return 0, network.Addr{}, err } @@ -70,7 +70,7 @@ func (c *rpcCLient) Write(connID uint16, b []byte) (int, error) { } var n int - if err := c.rpc.Call("Write", &req, &n); err != nil { + if err := c.rpc.Call("RPCGateway.Write", &req, &n); err != nil { return n, err } @@ -78,28 +78,28 @@ func (c *rpcCLient) Write(connID uint16, b []byte) (int, error) { } // Read sends `Read` command to the server. -func (c *rpcCLient) Read(connID uint16, b []byte) (int, []byte, error) { +func (c *rpcCLient) Read(connID uint16, b []byte) (int, error) { req := ReadReq{ ConnID: connID, BufLen: len(b), } var resp ReadResp - if err := c.rpc.Call("Read", &req, &resp); err != nil { - return 0, nil, err + if err := c.rpc.Call("RPCGateway.Read", &req, &resp); err != nil { + return 0, err } copy(b[:resp.N], resp.B[:resp.N]) - return resp.N, resp.B, nil + return resp.N, nil } // CloseConn sends `CloseConn` command to the server. func (c *rpcCLient) CloseConn(id uint16) error { - return c.rpc.Call("CloseConn", &id, nil) + return c.rpc.Call("RPCGateway.CloseConn", &id, nil) } // CloseListener sends `CloseListener` command to the server. func (c *rpcCLient) CloseListener(id uint16) error { - return c.rpc.Call("CloseListener", &id, nil) + return c.rpc.Call("RPCGateway.CloseListener", &id, nil) } diff --git a/pkg/app2/rpc_client_test.go b/pkg/app2/rpc_client_test.go new file mode 100644 index 0000000000..2d3476780e --- /dev/null +++ b/pkg/app2/rpc_client_test.go @@ -0,0 +1,494 @@ +package app2 + +import ( + "context" + "net" + "net/rpc" + "testing" + + "github.com/pkg/errors" + "github.com/skycoin/dmsg" + "github.com/skycoin/dmsg/cipher" + "github.com/skycoin/skycoin/src/util/logging" + "github.com/stretchr/testify/require" + "golang.org/x/net/nettest" + + "github.com/skycoin/skywire/pkg/app2/network" + "github.com/skycoin/skywire/pkg/routing" +) + +func TestRPCClient_Dial(t *testing.T) { + t.Run("ok", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + remoteNet := network.TypeDMSG + remotePK, _ := cipher.GenerateKeyPair() + remotePort := routing.Port(100) + remote := network.Addr{ + Net: remoteNet, + PubKey: remotePK, + Port: remotePort, + } + + localPK, _ := cipher.GenerateKeyPair() + dmsgLocal := dmsg.Addr{ + PK: localPK, + Port: 101, + } + dmsgRemote := dmsg.Addr{ + PK: remotePK, + Port: uint16(remotePort), + } + + dialCtx := context.Background() + dialConn := dmsg.NewTransport(&MockConn{}, logging.MustGetLogger("dmsg_tp"), + dmsgLocal, dmsgRemote, 0, func() {}) + var noErr error + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, remote).Return(dialConn, noErr) + + network.ClearNetworkers() + err := network.AddNetworker(remoteNet, n) + require.NoError(t, err) + + connID, localPort, err := cl.Dial(remote) + require.NoError(t, err) + require.Equal(t, connID, uint16(1)) + require.Equal(t, localPort, routing.Port(dmsgLocal.Port)) + + }) + + t.Run("dial error", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + remoteNet := network.TypeDMSG + remotePK, _ := cipher.GenerateKeyPair() + remotePort := routing.Port(100) + remote := network.Addr{ + Net: remoteNet, + PubKey: remotePK, + Port: remotePort, + } + + dialCtx := context.Background() + var dialConn net.Conn + dialErr := errors.New("dial error") + + n := &network.MockNetworker{} + n.On("DialContext", dialCtx, remote).Return(dialConn, dialErr) + + network.ClearNetworkers() + err := network.AddNetworker(remoteNet, n) + require.NoError(t, err) + + connID, localPort, err := cl.Dial(remote) + require.Error(t, err) + require.Equal(t, err.Error(), dialErr.Error()) + require.Equal(t, connID, uint16(0)) + require.Equal(t, localPort, routing.Port(0)) + }) +} + +func TestRPCClient_Listen(t *testing.T) { + t.Run("ok", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + localNet := network.TypeDMSG + localPK, _ := cipher.GenerateKeyPair() + localPort := routing.Port(100) + local := network.Addr{ + Net: localNet, + PubKey: localPK, + Port: localPort, + } + + listenCtx := context.Background() + var listenLis net.Listener + var noErr error + + n := &network.MockNetworker{} + n.On("ListenContext", listenCtx, local).Return(listenLis, noErr) + + network.ClearNetworkers() + err := network.AddNetworker(localNet, n) + require.NoError(t, err) + + lisID, err := cl.Listen(local) + require.NoError(t, err) + require.Equal(t, lisID, uint16(1)) + }) + + t.Run("listen error", func(t *testing.T) { + s := prepRPCServer(t, prepGateway()) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + localNet := network.TypeDMSG + localPK, _ := cipher.GenerateKeyPair() + localPort := routing.Port(100) + local := network.Addr{ + Net: localNet, + PubKey: localPK, + Port: localPort, + } + + listenCtx := context.Background() + var listenLis net.Listener + listenErr := errors.New("listen error") + + n := &network.MockNetworker{} + n.On("ListenContext", listenCtx, local).Return(listenLis, listenErr) + + network.ClearNetworkers() + err := network.AddNetworker(localNet, n) + require.NoError(t, err) + + lisID, err := cl.Listen(local) + require.Error(t, err) + require.Equal(t, err.Error(), listenErr.Error()) + require.Equal(t, lisID, uint16(0)) + }) +} + +func TestRPCClient_Accept(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + localPK, _ := cipher.GenerateKeyPair() + localPort := uint16(100) + dmsgLocal := dmsg.Addr{ + PK: localPK, + Port: localPort, + } + remotePK, _ := cipher.GenerateKeyPair() + remotePort := uint16(101) + dmsgRemote := dmsg.Addr{ + PK: remotePK, + Port: remotePort, + } + lisConn := dmsg.NewTransport(&MockConn{}, logging.MustGetLogger("dmsg_tp"), + dmsgLocal, dmsgRemote, 0, func() {}) + var noErr error + + lis := &MockListener{} + lis.On("Accept").Return(lisConn, noErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + wantRemote := network.Addr{ + Net: network.TypeDMSG, + PubKey: remotePK, + Port: routing.Port(remotePort), + } + + connID, remote, err := cl.Accept(lisID) + require.NoError(t, err) + require.Equal(t, connID, uint16(1)) + require.Equal(t, remote, wantRemote) + }) + + t.Run("accept error", func(t *testing.T) { + gateway := prepGateway() + + var lisConn net.Conn + listenErr := errors.New("accept error") + + lis := &MockListener{} + lis.On("Accept").Return(lisConn, listenErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + connID, remote, err := cl.Accept(lisID) + require.Error(t, err) + require.Equal(t, err.Error(), listenErr.Error()) + require.Equal(t, connID, uint16(0)) + require.Equal(t, remote, network.Addr{}) + }) +} + +func TestRPCClient_Write(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + writeBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + writeN := 10 + var noErr error + + conn := &MockConn{} + conn.On("Write", writeBuf).Return(writeN, noErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Write(connID, writeBuf) + require.NoError(t, err) + require.Equal(t, n, writeN) + }) + + t.Run("write error", func(t *testing.T) { + gateway := prepGateway() + + writeBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + writeN := 0 + writeErr := errors.New("write error") + + conn := &MockConn{} + conn.On("Write", writeBuf).Return(writeN, writeErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Write(connID, writeBuf) + require.Error(t, err) + require.Equal(t, err.Error(), writeErr.Error()) + require.Equal(t, n, 0) + }) +} + +func TestRPCClient_Read(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + readBufLen := 10 + readBuf := make([]byte, readBufLen) + readN := 5 + var noErr error + + conn := &MockConn{} + conn.On("Read", readBuf).Return(readN, noErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Read(connID, readBuf) + require.NoError(t, err) + require.Equal(t, n, readN) + }) + + t.Run("read error", func(t *testing.T) { + gateway := prepGateway() + + readBufLen := 10 + readBuf := make([]byte, readBufLen) + readN := 0 + readErr := errors.New("read error") + + conn := &MockConn{} + conn.On("Read", readBuf).Return(readN, readErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + n, err := cl.Read(connID, readBuf) + require.Error(t, err) + require.Equal(t, err.Error(), readErr.Error()) + require.Equal(t, n, readN) + }) +} + +func TestRPCClient_CloseConn(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + var noErr error + + conn := &MockConn{} + conn.On("Close").Return(noErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseConn(connID) + require.NoError(t, err) + }) + + t.Run("close error", func(t *testing.T) { + gateway := prepGateway() + + closeErr := errors.New("close error") + + conn := &MockConn{} + conn.On("Close").Return(closeErr) + + connID := uint16(1) + + _, err := gateway.cm.add(connID, conn) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseConn(connID) + require.Error(t, err) + require.Equal(t, err.Error(), closeErr.Error()) + }) +} + +func TestRPCClient_CloseListener(t *testing.T) { + t.Run("ok", func(t *testing.T) { + gateway := prepGateway() + + var noErr error + + lis := &MockListener{} + lis.On("Close").Return(noErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseListener(lisID) + require.NoError(t, err) + }) + + t.Run("close error", func(t *testing.T) { + gateway := prepGateway() + + closeErr := errors.New("close error") + + lis := &MockListener{} + lis.On("Close").Return(closeErr) + + lisID := uint16(1) + + _, err := gateway.lm.add(lisID, lis) + require.NoError(t, err) + + s := prepRPCServer(t, gateway) + rpcL, lisCleanup := prepListener(t) + defer lisCleanup() + go s.Accept(rpcL) + + cl := prepClient(t, rpcL.Addr().Network(), rpcL.Addr().String()) + + err = cl.CloseListener(lisID) + require.Error(t, err) + require.Equal(t, err.Error(), closeErr.Error()) + }) +} + +func prepGateway() *RPCGateway { + l := logging.MustGetLogger("rpc_gateway") + return newRPCGateway(l) +} + +func prepRPCServer(t *testing.T, gateway *RPCGateway) *rpc.Server { + s := rpc.NewServer() + err := s.Register(gateway) + require.NoError(t, err) + + return s +} + +func prepListener(t *testing.T) (lis net.Listener, cleanup func()) { + lis, err := nettest.NewLocalListener("tcp") + require.NoError(t, err) + + return lis, func() { + err := lis.Close() + require.NoError(t, err) + } +} + +func prepClient(t *testing.T, network, addr string) RPCClient { + rpcCl, err := rpc.Dial(network, addr) + require.NoError(t, err) + + return NewRPCClient(rpcCl) +}