diff --git a/.gitignore b/.gitignore index 58ee84bfe..b8589c08e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ *.dylib *.test *.out +.DS_Store .idea/ diff --git a/internal/ioutil/ack_waiter.go b/internal/ioutil/ack_waiter.go index 4e7279576..ef5ec65e9 100644 --- a/internal/ioutil/ack_waiter.go +++ b/internal/ioutil/ack_waiter.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/binary" + "io" "math" "sync" ) @@ -43,6 +44,22 @@ func (w *Uint16AckWaiter) RandSeq() error { return nil } +func (w *Uint16AckWaiter) stopWaiter(seq Uint16Seq) { + if waiter := w.waiters[seq]; waiter != nil { + close(waiter) + w.waiters[seq] = nil + } +} + +// StopAll stops all active waiters. +func (w *Uint16AckWaiter) StopAll() { + w.mx.Lock() + for seq := range w.waiters { + w.stopWaiter(Uint16Seq(seq)) + } + w.mx.Unlock() +} + // Wait performs the given action, and waits for given seq to be Done. func (w *Uint16AckWaiter) Wait(ctx context.Context, action func(seq Uint16Seq) error) (err error) { ackCh := make(chan struct{}) @@ -58,16 +75,18 @@ func (w *Uint16AckWaiter) Wait(ctx context.Context, action func(seq Uint16Seq) e } select { - case <-ackCh: + case _, ok := <-ackCh: + if !ok { + // waiter stopped manually. + return io.ErrClosedPipe + } case <-ctx.Done(): err = ctx.Err() } w.mx.Lock() - close(ackCh) - w.waiters[seq] = nil + w.stopWaiter(seq) w.mx.Unlock() - return err } diff --git a/internal/ioutil/ack_waiter_test.go b/internal/ioutil/ack_waiter_test.go index 6a14aab1a..3c9365c30 100644 --- a/internal/ioutil/ack_waiter_test.go +++ b/internal/ioutil/ack_waiter_test.go @@ -12,7 +12,7 @@ func TestUint16AckWaiter_Wait(t *testing.T) { seqChan := make(chan Uint16Seq) defer close(seqChan) for i := 0; i < 64; i++ { - go w.Wait(context.TODO(), func(seq Uint16Seq) error { //nolint:errcheck + go w.Wait(context.TODO(), func(seq Uint16Seq) error { //nolint:errcheck,unparam seqChan <- seq return nil }) diff --git a/internal/ioutil/atomic_bool.go b/internal/ioutil/atomic_bool.go new file mode 100644 index 000000000..e2b903fad --- /dev/null +++ b/internal/ioutil/atomic_bool.go @@ -0,0 +1,23 @@ +package ioutil + +import "sync/atomic" + +// AtomicBool implements a thread-safe boolean value. +type AtomicBool struct { + flag int32 +} + +// Set set's the boolean to specified value +// and returns true if the value is changed. +func (b *AtomicBool) Set(v bool) bool { + newF := int32(0) + if v { + newF = 1 + } + return newF != atomic.SwapInt32(&b.flag, newF) +} + +// Get obtains the current boolean value. +func (b *AtomicBool) Get() bool { + return atomic.LoadInt32(&b.flag) == 1 +} diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index b6be47dd9..d7eddf44d 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "github.com/sirupsen/logrus" + "github.com/skycoin/skycoin/src/util/logging" "github.com/skycoin/skywire/internal/noise" @@ -36,94 +38,119 @@ type ClientConn struct { // locally-initiated tps use an even tp_id between local and intermediary dms_server. nextInitID uint16 - // map of transports to remote dms_clients (key: tp_id, val: transport). + // Transports: map of transports to remote dms_clients (key: tp_id, val: transport). tps [math.MaxUint16 + 1]*Transport - mx sync.RWMutex // to protect tps. + mx sync.RWMutex // to protect tps - wg sync.WaitGroup + done chan struct{} + once sync.Once + wg sync.WaitGroup } // NewClientConn creates a new ClientConn. func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubKey) *ClientConn { - cc := &ClientConn{log: log, Conn: conn, local: local, remoteSrv: remote, nextInitID: randID(true)} + cc := &ClientConn{ + log: log, + Conn: conn, + local: local, + remoteSrv: remote, + nextInitID: randID(true), + done: make(chan struct{}), + } cc.wg.Add(1) return cc } -func (c *ClientConn) delTp(id uint16) { - c.mx.Lock() - c.tps[id] = nil - c.mx.Unlock() -} - -func (c *ClientConn) setTp(tp *Transport) { - c.mx.Lock() - c.tps[tp.id] = tp - c.mx.Unlock() +func (c *ClientConn) getNextInitID(ctx context.Context) (uint16, error) { + for { + select { + case <-c.done: + return 0, ErrClientClosed + case <-ctx.Done(): + return 0, ctx.Err() + default: + if ch := c.tps[c.nextInitID]; ch != nil && !ch.IsClosed() { + c.nextInitID += 2 + continue + } + c.tps[c.nextInitID] = nil + id := c.nextInitID + c.nextInitID = id + 2 + return id, nil + } + } } -// keeps record of a locally-initiated tp to 'clientPK'. -// assigns an even tp_id and keeps track of it in tps map. func (c *ClientConn) addTp(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) { c.mx.Lock() defer c.mx.Unlock() - for { - if ch := c.tps[c.nextInitID]; ch == nil || ch.IsDone() { - break - } - c.nextInitID += 2 + id, err := c.getNextInitID(ctx) + if err != nil { + return nil, err + } + tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp) + c.tps[id] = tp + return tp, nil +} - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } +func (c *ClientConn) acceptTp(clientPK cipher.PubKey, id uint16) (*Transport, error) { + tp := NewTransport(c.Conn, c.log, c.local, clientPK, id, c.delTp) + + c.mx.Lock() + c.tps[tp.id] = tp + c.mx.Unlock() + + if err := tp.WriteAccept(); err != nil { + return nil, err } + return tp, nil +} - id := c.nextInitID - c.nextInitID = id + 2 - ch := NewTransport(c.Conn, c.log, c.local, clientPK, id) - c.tps[id] = ch - return ch, nil +func (c *ClientConn) delTp(id uint16) { + c.mx.Lock() + c.tps[id] = nil + c.mx.Unlock() } func (c *ClientConn) getTp(id uint16) (*Transport, bool) { c.mx.RLock() tp := c.tps[id] c.mx.RUnlock() - ok := tp != nil && !tp.IsDone() + ok := tp != nil && !tp.IsClosed() return tp, ok } func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) { - // remote-initiated tps should: + // remotely-initiated tps should: // - have a payload structured as 'init_pk:resp_pk'. // - resp_pk should be of local client. - // - use an odd tp_id with the intermediary dms_server. + // - use an odd tp_id with the intermediary dmsg_server. initPK, respPK, ok := splitPKs(p) if !ok || respPK != c.local || isInitiatorID(id) { if err := writeCloseFrame(c.Conn, id, 0); err != nil { - c.Close() return initPK, err } return initPK, ErrRequestCheckFailed } - - tp := NewTransport(c.Conn, c.log, c.local, initPK, id) - if err := tp.Handshake(ctx); err != nil { - // return err here as response handshake is send via ClientConn and that shouldn't fail. - c.Close() + tp, err := c.acceptTp(initPK, id) + if err != nil { return initPK, err } - c.setTp(tp) + go tp.Serve() select { + case <-c.done: + _ = tp.Close() //nolint:errcheck + return initPK, ErrClientClosed + case <-ctx.Done(): + _ = tp.Close() //nolint:errcheck return initPK, ctx.Err() + case accept <- tp: + return initPK, nil } - return initPK, nil } // Serve handles incoming frames. @@ -133,11 +160,18 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e log := c.log.WithField("remoteServer", c.remoteSrv) - log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") + log.WithField("connCount", incrementServeCount()). + Infoln("ServingConn") defer func() { - log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ClosingConn") + log.WithError(err). + WithField("connCount", decrementServeCount()). + Infoln("ClosingConn") }() + closeConn := func(log *logrus.Entry) { + log.WithError(c.Close()).Warn("ClosingConnection") + } + for { f, err := readFrame(c.Conn) if err != nil { @@ -147,39 +181,49 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e ft, id, p := f.Disassemble() + // If tp of tp_id exists, attempt to forward frame to tp. + // delete tp on any failure. + if tp, ok := c.getTp(id); ok { - // If tp of tp_id exists, attempt to forward frame to tp. - // delete tp on any failure. - if !tp.InjectRead(f) { - log.WithField("remoteClient", tp.remote).Infoln("FrameTrashed") - c.delTp(id) + if err := tp.Inject(f); err != nil { + log.WithError(err).Warnf("Rejected [%s]: Transport closed.", ft) } - log.WithField("remoteClient", tp.remote).Infoln("FrameInjected") continue } // if tp does not exist, frame should be 'REQUEST'. // otherwise, handle any unexpected frames accordingly. + c.delTp(id) // rm tp in case closed tp is not fully removed. + switch ft { case RequestType: - // TODO(evanlinjin): Allow for REQUEST frame handling to be done in goroutine. - // Currently this causes issues (probably because we need ACK frames). - initPK, err := c.handleRequestFrame(ctx, accept, id, p) - if err != nil { - log.WithField("remoteClient", initPK).WithError(err).Infoln("FrameRejected") - if err == ErrRequestCheckFailed { - continue + c.wg.Add(1) + go func(log *logrus.Entry) { + defer c.wg.Done() + + initPK, err := c.handleRequestFrame(ctx, accept, id, p) + if err != nil { + log. + WithField("remoteClient", initPK). + WithError(err). + Infoln("Rejected [REQUEST]") + if isWriteError(err) || err == ErrClientClosed { + closeConn(log) + } + return } - return err - } - log.WithField("remoteClient", initPK).Infoln("FrameAccepted") - case CloseType: - log.Infoln("FrameIgnored") + log. + WithField("remoteClient", initPK). + Infoln("Accepted [REQUEST]") + }(log) + default: - log.Infoln("FrameUnexpected") - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return err + log.Infof("Ignored [%s]: No transport of given ID.", ft) + if ft != CloseType { + if err := writeCloseFrame(c.Conn, id, 0); err != nil { + return err + } } } } @@ -187,16 +231,26 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err e // DialTransport dials a transport to remote dms_client. func (c *ClientConn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) { + c.log.Warn("DialTransport...") tp, err := c.addTp(ctx, clientPK) if err != nil { return nil, err } - return tp, tp.Handshake(ctx) + if err := tp.WriteRequest(); err != nil { + return nil, err + } + if err := tp.ReadAccept(ctx); err != nil { + return nil, err + } + c.log.Warn("DialTransport: Accepted.") + go tp.Serve() + return tp, nil } // Close closes the connection to dms_server. func (c *ClientConn) Close() error { - c.log.Infof("closingLink: remoteSrv(%v)", c.remoteSrv) + c.log.WithField("remoteServer", c.remoteSrv).Infoln("ClosingConnection") + c.once.Do(func() { close(c.done) }) c.mx.Lock() for _, tp := range c.tps { if tp != nil { @@ -377,7 +431,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) c.setConn(ctx, conn) go func() { if err := conn.Serve(ctx, c.accept); err != nil { - conn.log.WithError(err).WithField("dms_server", srvPK).Warn("connected with dms_server closed") + conn.log.WithError(err).WithField("remoteServer", srvPK).Warn("connected with server closed") c.delConn(ctx, srvPK) // reconnect logic. diff --git a/pkg/dmsg/frame.go b/pkg/dmsg/frame.go index 560538f9c..27c7af601 100644 --- a/pkg/dmsg/frame.go +++ b/pkg/dmsg/frame.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "io" + "math" "sync/atomic" "time" @@ -17,9 +18,9 @@ const ( Type = "dmsg" hsTimeout = time.Second * 10 - readTimeout = time.Second * 10 - acceptChSize = 1 - readChSize = 20 + tpBufCap = math.MaxUint16 + tpAckCap = math.MaxUint8 + acceptChSize = 20 headerLen = 5 // fType(1 byte), chID(2 byte), payLen(2 byte) ) @@ -116,9 +117,21 @@ func readFrame(r io.Reader) (Frame, error) { return f, err } +type writeError struct{ error } + +func (e *writeError) Error() string { return "write error: " + e.error.Error() } + +func isWriteError(err error) bool { + _, ok := err.(*writeError) + return ok +} + func writeFrame(w io.Writer, f Frame) error { _, err := w.Write(f) - return err + if err != nil { + return &writeError{err} + } + return nil } func writeFwdFrame(w io.Writer, id uint16, seq ioutil.Uint16Seq, p []byte) error { diff --git a/pkg/dmsg/server_test.go b/pkg/dmsg/server_test.go index cab8c574e..1e33ff719 100644 --- a/pkg/dmsg/server_test.go +++ b/pkg/dmsg/server_test.go @@ -484,8 +484,8 @@ func TestNewClient(t *testing.T) { pay := []byte(fmt.Sprintf("This is message %d!", j)) n, err := aTp.Write(pay) - require.Equal(t, len(pay), n) require.NoError(t, err) + require.Equal(t, len(pay), n) got := make([]byte, len(pay)) n, err = bTp.Read(got) @@ -498,7 +498,7 @@ func TestNewClient(t *testing.T) { require.NoError(t, aTp.Close()) require.NoError(t, bTp.Close()) } - wg.Wait() + //wg.Wait() // Close server. assert.NoError(t, s.Close()) diff --git a/pkg/dmsg/transport.go b/pkg/dmsg/transport.go index e17c21a76..8dd076ecd 100644 --- a/pkg/dmsg/transport.go +++ b/pkg/dmsg/transport.go @@ -1,9 +1,9 @@ package dmsg import ( - "bytes" "context" "errors" + "fmt" "io" "net" "sync" @@ -15,10 +15,11 @@ import ( "github.com/skycoin/skywire/pkg/transport" ) -// Errors related to REQUESTs. +// Errors related to REQUEST frames. var ( - ErrRequestRejected = errors.New("request rejected") - ErrRequestCheckFailed = errors.New("request check failed") + ErrRequestRejected = errors.New("failed to create transport: request rejected") + ErrRequestCheckFailed = errors.New("failed to create transport: request check failed") + ErrAcceptCheckFailed = errors.New("failed to create transport: accept check failed") ) // Transport represents a connection from dmsg.Client to remote dmsg.Client (via dmsg.Server intermediary). @@ -31,24 +32,34 @@ type Transport struct { local cipher.PubKey remote cipher.PubKey // remote PK + inCh chan Frame + inMx sync.RWMutex + ackWaiter ioutil.Uint16AckWaiter - readBuf bytes.Buffer - readMx sync.Mutex // This is for protecting 'readBuf'. - readCh chan Frame - doneCh chan struct{} // stop writing - doneOnce sync.Once + ackBuf []byte + buf net.Buffers + bufCh chan struct{} + bufSize int + bufMx sync.Mutex // protects 'buf' and 'bufCh' + + once sync.Once + done chan struct{} + doneFunc func(id uint16) } // NewTransport creates a new dms_tp. -func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKey, id uint16) *Transport { +func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKey, id uint16, doneFunc func(id uint16)) *Transport { tp := &Transport{ - Conn: conn, - log: log, - id: id, - local: local, - remote: remote, - readCh: make(chan Frame, readChSize), - doneCh: make(chan struct{}), + Conn: conn, + log: log, + id: id, + local: local, + remote: remote, + inCh: make(chan Frame), + ackBuf: make([]byte, 0, tpAckCap), + bufCh: make(chan struct{}, 1), + done: make(chan struct{}), + doneFunc: doneFunc, } if err := tp.ackWaiter.RandSeq(); err != nil { log.Fatalln("failed to set ack_waiter seq:", err) @@ -56,199 +67,255 @@ func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKe return tp } -func (c *Transport) close() (closed bool) { - c.doneOnce.Do(func() { +func (tp *Transport) close() (closed bool) { + tp.once.Do(func() { closed = true - close(c.doneCh) - // Kill all goroutines pushing to `c.readCh` before closing it. - // No more goroutines pushing to `c.readCh` should be created once `c.doneCh` is closed. - for { - select { - case <-c.readCh: - default: - close(c.readCh) - return - } - } + close(tp.done) + tp.doneFunc(tp.id) + + tp.bufMx.Lock() + close(tp.bufCh) + tp.bufMx.Unlock() + + tp.inMx.Lock() + close(tp.inCh) + tp.inMx.Unlock() + }) - return closed -} -func (c *Transport) awaitResponse(ctx context.Context) error { - select { - case <-c.doneCh: - return ErrRequestRejected - case <-ctx.Done(): - return ctx.Err() - case f, ok := <-c.readCh: - if !ok { - return io.ErrClosedPipe - } - if f.Type() == AcceptType { - return nil - } - return errors.New("invalid remote response") - } + tp.ackWaiter.StopAll() + return closed } -// Handshake performs a tp handshake (before tp is considered valid). -func (c *Transport) Handshake(ctx context.Context) error { - // if channel ID is even, client is initiator. - if isInitiatorID(c.id) { - pks := combinePKs(c.local, c.remote) - f := MakeFrame(RequestType, c.id, pks) - if err := writeFrame(c.Conn, f); err != nil { - c.close() - return err - } - if err := c.awaitResponse(ctx); err != nil { - c.close() - return err - } - } else { - f := MakeFrame(AcceptType, c.id, combinePKs(c.remote, c.local)) - if err := writeFrame(c.Conn, f); err != nil { - c.log.WithError(err).Error("HandshakeFailed") - c.close() - return err - } - c.log.WithField("sent", f).Infoln("HandshakeCompleted") +// Close closes the dmsg_tp. +func (tp *Transport) Close() error { + if tp.close() { + _ = writeFrame(tp.Conn, MakeFrame(CloseType, tp.id, []byte{0})) //nolint:errcheck } return nil } -// IsDone returns whether dms_tp is closed. -func (c *Transport) IsDone() bool { +// IsClosed returns whether dms_tp is closed. +func (tp *Transport) IsClosed() bool { select { - case <-c.doneCh: + case <-tp.done: return true default: return false } } -// InjectRead blocks until frame is read. -// Returns false when read fails (e.g. when tp is closed). -func (c *Transport) InjectRead(f Frame) bool { - ok := c.injectRead(f) - if !ok { - c.close() - } - return ok +// Edges returns the local/remote edges of the transport (dms_client to dms_client). +func (tp *Transport) Edges() [2]cipher.PubKey { + return transport.SortPubKeys(tp.local, tp.remote) } -func (c *Transport) injectRead(f Frame) bool { - push := func(f Frame) bool { - select { - case <-c.doneCh: - return false - case c.readCh <- f: - return true - default: - return false - } - } - - switch f.Type() { - case CloseType: - return false +// Type returns the transport type. +func (tp *Transport) Type() string { + return Type +} - case AckType: - p := f.Pay() - if len(p) != 2 { - return false - } - c.ackWaiter.Done(ioutil.DecodeUint16Seq(p)) - return true +// Inject injects a frame from 'ClientConn' to transport. +// Frame is then handled by 'tp.Serve'. +func (tp *Transport) Inject(f Frame) error { + if tp.IsClosed() { + return io.ErrClosedPipe + } - case FwdType: - p := f.Pay() - if len(p) < 2 { - return false - } - if ok := push(f); !ok { - return false - } - go func() { - if err := writeFrame(c.Conn, MakeFrame(AckType, c.id, p[:2])); err != nil { - c.close() - } - }() - return true + tp.inMx.RLock() + defer tp.inMx.RUnlock() - default: - return push(f) + select { + case <-tp.done: + return io.ErrClosedPipe + case tp.inCh <- f: + return nil } } -// Read implements io.Reader -func (c *Transport) Read(p []byte) (n int, err error) { - c.readMx.Lock() - defer c.readMx.Unlock() +// WriteRequest writes a REQUEST frame to dmsg_server to be forwarded to associated client. +func (tp *Transport) WriteRequest() error { + f := MakeFrame(RequestType, tp.id, combinePKs(tp.local, tp.remote)) + if err := writeFrame(tp.Conn, f); err != nil { + tp.log.WithError(err).Error("HandshakeFailed") + tp.close() + return err + } + return nil +} - if c.readBuf.Len() != 0 { - return c.readBuf.Read(p) +// WriteAccept writes an ACCEPT frame to dmsg_server to be forwarded to associated client. +func (tp *Transport) WriteAccept() error { + f := MakeFrame(AcceptType, tp.id, combinePKs(tp.remote, tp.local)) + if err := writeFrame(tp.Conn, f); err != nil { + tp.log.WithError(err).Error("HandshakeFailed") + tp.close() + return err } + tp.log.WithField("sent", f).Infoln("HandshakeCompleted") + return nil +} + +// ReadAccept awaits for an ACCEPT frame to be read from the remote client. +// TODO(evanlinjin): Cleanup errors. +func (tp *Transport) ReadAccept(ctx context.Context) (err error) { + defer func() { + tp.log.WithError(err).WithField("success", err == nil).Infoln("HandshakeDone") + }() select { - case <-c.doneCh: - return 0, io.ErrClosedPipe - case f, ok := <-c.readCh: + case <-tp.done: + tp.close() + return io.ErrClosedPipe + + case <-ctx.Done(): + _ = tp.Close() //nolint:errcheck + return ctx.Err() + + case f, ok := <-tp.inCh: if !ok { - return 0, io.ErrClosedPipe + tp.close() + return io.ErrClosedPipe } - if f.Type() == FwdType { - return ioutil.BufRead(&c.readBuf, f.Pay()[2:], p) + switch ft, id, p := f.Disassemble(); ft { + case AcceptType: + // locally-initiated tps should: + // - have a payload structured as 'init_pk:resp_pk'. + // - init_pk should be of local client. + // - resp_pk should be of remote client. + // - use an even number with the intermediary dmsg_server. + initPK, respPK, ok := splitPKs(p) + if !ok || initPK != tp.local || respPK != tp.remote || !isInitiatorID(id) { + _ = tp.Close() //nolint:errcheck + return ErrAcceptCheckFailed + } + return nil + + case CloseType: + tp.close() + return ErrRequestRejected + + default: + _ = tp.Close() //nolint:errcheck + return ErrAcceptCheckFailed } - return 0, errors.New("unexpected frame") } } -// Write implements io.Writer -func (c *Transport) Write(p []byte) (int, error) { - select { - case <-c.doneCh: - return 0, io.ErrClosedPipe - default: - ctx, cancel := context.WithTimeout(context.Background(), readTimeout) - go func() { - select { - case <-ctx.Done(): - case <-c.doneCh: - cancel() +// Serve handles received frames. +func (tp *Transport) Serve() { + defer func() { + if tp.close() { + _ = writeCloseFrame(tp.Conn, tp.id, 0) //nolint:errcheck + } + }() + + for { + select { + case <-tp.done: + return + + case f, ok := <-tp.inCh: + if !ok { + return } - }() - err := c.ackWaiter.Wait(ctx, func(seq ioutil.Uint16Seq) error { - if err := writeFwdFrame(c.Conn, c.id, seq, p); err != nil { - c.close() - return err + log := tp.log. + WithField("remoteClient", tp.remote). + WithField("received", f) + + switch p := f.Pay(); f.Type() { + case FwdType: + if len(p) < 2 { + log.Warnln("Rejected [FWD]: Invalid payload size.") + return + } + ack := MakeFrame(AckType, tp.id, p[:2]) + + tp.bufMx.Lock() + if tp.bufSize += len(p[2:]); tp.bufSize > tpBufCap { + tp.ackBuf = append(tp.ackBuf, ack...) + } else { + go func() { + if err := writeFrame(tp.Conn, ack); err != nil { + tp.close() + } + }() + } + tp.buf = append(tp.buf, p[2:]) + select { + case <-tp.done: + case tp.bufCh <- struct{}{}: + default: + } + log.WithField("bufSize", fmt.Sprintf("%d/%d", tp.bufSize, tpBufCap)).Infoln("Injected [FWD]") + tp.bufMx.Unlock() + + case AckType: + if len(p) != 2 { + log.Warnln("Rejected [ACK]: Invalid payload size.") + return + } + tp.ackWaiter.Done(ioutil.DecodeUint16Seq(p[:2])) + log.Infoln("Injected [ACK]") + + case CloseType: + log.Infoln("Injected [CLOSE]: Closing transport...") + return + + case RequestType: + log.Warnln("Rejected [REQUEST]: ID already occupied, malicious server.") + _ = tp.Conn.Close() + return + + default: + tp.log.Infof("Rejected [%s]: Unexpected frame, malicious server (ignored for now).", f.Type()) } - return nil - }) - if err != nil { - cancel() - return 0, err } - return len(p), nil } } -// Close closes the dms_tp. -func (c *Transport) Close() error { - if c.close() { - _ = writeFrame(c.Conn, MakeFrame(CloseType, c.id, []byte{0})) //nolint:errcheck - return nil - } - return io.ErrClosedPipe -} +// Read implements io.Reader +// TODO(evanlinjin): read deadline. +func (tp *Transport) Read(p []byte) (n int, err error) { +startRead: + tp.bufMx.Lock() + n, err = tp.buf.Read(p) + go func() { + if tp.bufSize -= n; tp.bufSize < tpBufCap { + if err := writeFrame(tp.Conn, tp.ackBuf); err != nil { + tp.close() + } + tp.ackBuf = make([]byte, 0, tpAckCap) + } + tp.bufMx.Unlock() + }() -// Edges returns the local/remote edges of the transport (dms_client to dms_client). -func (c *Transport) Edges() [2]cipher.PubKey { - return transport.SortPubKeys(c.local, c.remote) + if tp.IsClosed() { + return n, err + } + if n > 0 { + return n, nil + } + <-tp.bufCh + goto startRead } -// Type returns the transport type. -func (c *Transport) Type() string { - return Type +// Write implements io.Writer +// TODO(evanlinjin): write deadline. +func (tp *Transport) Write(p []byte) (int, error) { + if tp.IsClosed() { + return 0, io.ErrClosedPipe + } + err := tp.ackWaiter.Wait(context.Background(), func(seq ioutil.Uint16Seq) error { + if err := writeFwdFrame(tp.Conn, tp.id, seq, p); err != nil { + tp.close() + return err + } + return nil + }) + if err != nil { + return 0, err + } + return len(p), nil } diff --git a/pkg/router/router.go b/pkg/router/router.go index 0ef310a35..2f1851fdb 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -87,6 +87,8 @@ func (r *Router) Serve(ctx context.Context) error { isAccepted, isSetup := tp.Accepted, r.IsSetupTransport(tp) r.mu.Unlock() + r.Logger.Infof("New transport: isAccepted: %v, isSetup: %v", isAccepted, isSetup) + var serve func(io.ReadWriter) error switch { case isAccepted && isSetup: @@ -430,15 +432,25 @@ func (r *Router) setupProto(ctx context.Context) (*setup.Protocol, transport.Tra return sProto, tr, nil } -func (r *Router) fetchBestRoutes(source, destination cipher.PubKey) (routing.Route, routing.Route, error) { +func (r *Router) fetchBestRoutes(source, destination cipher.PubKey) (fwd routing.Route, rev routing.Route, err error) { r.Logger.Infof("Requesting new routes from %s to %s", source, destination) - forwardRoutes, reverseRoutes, err := r.config.RouteFinder.PairedRoutes(source, destination, minHops, maxHops) + + timer := time.NewTimer(time.Second * 10) + defer timer.Stop() + +fetchRoutesAgain: + fwdRoutes, revRoutes, err := r.config.RouteFinder.PairedRoutes(source, destination, minHops, maxHops) if err != nil { - return nil, nil, err + select { + case <-timer.C: + return nil, nil, err + default: + goto fetchRoutesAgain + } } - r.Logger.Infof("Found routes Forward: %s. Reverse %s", forwardRoutes, reverseRoutes) - return forwardRoutes[0], reverseRoutes[0], nil + r.Logger.Infof("Found routes Forward: %s. Reverse %s", fwdRoutes, revRoutes) + return fwdRoutes[0], revRoutes[0], nil } func (r *Router) advanceNoiseHandshake(addr *app.LoopAddr, noiseMsg []byte) (ni *noise.Noise, noiseRes []byte, err error) { diff --git a/pkg/transport/discovery_test.go b/pkg/transport/discovery_test.go index db0126143..190f86de4 100644 --- a/pkg/transport/discovery_test.go +++ b/pkg/transport/discovery_test.go @@ -1,19 +1,20 @@ -package transport +package transport_test import ( "context" "fmt" "github.com/skycoin/skywire/pkg/cipher" + "github.com/skycoin/skywire/pkg/transport" ) func ExampleNewDiscoveryMock() { - dc := NewDiscoveryMock() + dc := transport.NewDiscoveryMock() pk1, _ := cipher.GenerateKeyPair() pk2, _ := cipher.GenerateKeyPair() - entry := &Entry{Type: "mock", EdgeKeys: SortPubKeys(pk1, pk2)} + entry := &transport.Entry{Type: "mock", EdgeKeys: transport.SortPubKeys(pk1, pk2)} - sEntry := &SignedEntry{Entry: entry} + sEntry := &transport.SignedEntry{Entry: entry} if err := dc.RegisterTransports(context.TODO(), sEntry); err == nil { fmt.Println("RegisterTransport success") @@ -33,7 +34,7 @@ func ExampleNewDiscoveryMock() { fmt.Printf("entriesWS[0].Entry.Edges()[0] == entry.Edges()[0] is %v\n", entriesWS[0].Entry.Edges()[0] == entry.Edges()[0]) } - if _, err := dc.UpdateStatuses(context.TODO(), &Status{}); err == nil { + if _, err := dc.UpdateStatuses(context.TODO(), &transport.Status{}); err == nil { fmt.Println("UpdateStatuses success") } else { fmt.Println(err.Error()) diff --git a/pkg/transport/entry_test.go b/pkg/transport/entry_test.go index 974dee815..8e882ffc4 100644 --- a/pkg/transport/entry_test.go +++ b/pkg/transport/entry_test.go @@ -1,4 +1,4 @@ -package transport +package transport_test import ( "fmt" @@ -6,6 +6,7 @@ import ( "github.com/google/uuid" "github.com/skycoin/skywire/pkg/cipher" + "github.com/skycoin/skywire/pkg/transport" ) // ExampleNewEntry shows that with different order of edges: @@ -15,8 +16,8 @@ func ExampleNewEntry() { pkA, _ := cipher.GenerateKeyPair() pkB, _ := cipher.GenerateKeyPair() - entryAB := NewEntry(pkA, pkB, "", true) - entryBA := NewEntry(pkB, pkA, "", true) + entryAB := transport.NewEntry(pkA, pkB, "", true) + entryBA := transport.NewEntry(pkB, pkA, "", true) if entryAB.ID == entryBA.ID { fmt.Println("entryAB.ID == entryBA.ID") @@ -32,14 +33,14 @@ func ExampleEntry_Edges() { pkA, _ := cipher.GenerateKeyPair() pkB, _ := cipher.GenerateKeyPair() - entryAB := Entry{ + entryAB := transport.Entry{ ID: uuid.UUID{}, EdgeKeys: [2]cipher.PubKey{pkA, pkB}, Type: "", Public: true, } - entryBA := Entry{ + entryBA := transport.Entry{ ID: uuid.UUID{}, EdgeKeys: [2]cipher.PubKey{pkB, pkA}, Type: "", @@ -62,7 +63,7 @@ func ExampleEntry_SetEdges() { pkA, _ := cipher.GenerateKeyPair() pkB, _ := cipher.GenerateKeyPair() - entryAB, entryBA := Entry{}, Entry{} + entryAB, entryBA := transport.Entry{}, transport.Entry{} entryAB.SetEdges([2]cipher.PubKey{pkA, pkB}) entryBA.SetEdges([2]cipher.PubKey{pkA, pkB}) @@ -85,8 +86,8 @@ func ExampleSignedEntry_Sign() { pkA, skA := cipher.GenerateKeyPair() pkB, skB := cipher.GenerateKeyPair() - entry := NewEntry(pkA, pkB, "mock", true) - sEntry := &SignedEntry{Entry: entry} + entry := transport.NewEntry(pkA, pkB, "mock", true) + sEntry := &transport.SignedEntry{Entry: entry} if sEntry.Signatures[0].Null() && sEntry.Signatures[1].Null() { fmt.Println("No signatures set") @@ -119,8 +120,8 @@ func ExampleSignedEntry_Signature() { pkA, skA := cipher.GenerateKeyPair() pkB, skB := cipher.GenerateKeyPair() - entry := NewEntry(pkA, pkB, "mock", true) - sEntry := &SignedEntry{Entry: entry} + entry := transport.NewEntry(pkA, pkB, "mock", true) + sEntry := &transport.SignedEntry{Entry: entry} if ok := sEntry.Sign(pkA, skA); !ok { fmt.Println("Error signing sEntry with (pkA,skA)") } diff --git a/pkg/transport/handshake.go b/pkg/transport/handshake.go index 0d6505e59..7c8f829e3 100644 --- a/pkg/transport/handshake.go +++ b/pkg/transport/handshake.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "time" "github.com/skycoin/skywire/pkg/cipher" @@ -26,118 +27,100 @@ func (handshake settlementHandshake) Do(tm *Manager, tr Transport, timeout time. } } -func settlementInitiatorHandshake(public bool) settlementHandshake { - return func(tm *Manager, tr Transport) (*Entry, error) { - entry := &Entry{ - ID: MakeTransportID(tr.Edges()[0], tr.Edges()[1], tr.Type(), public), - EdgeKeys: tr.Edges(), - Type: tr.Type(), - Public: public, - } - - sEntry, ok := NewSignedEntry(entry, tm.config.PubKey, tm.config.SecKey) - if !ok { - return nil, errors.New("error creating signed entry") - } - if err := validateSignedEntry(sEntry, tr, tm.config.PubKey); err != nil { - return nil, fmt.Errorf("settlementInitiatorHandshake NewSignedEntry: %s\n sEntry: %v", err, sEntry) - } - - if err := json.NewEncoder(tr).Encode(sEntry); err != nil { - return nil, fmt.Errorf("write: %s", err) - } - - respSEntry := &SignedEntry{} - if err := json.NewDecoder(tr).Decode(respSEntry); err != nil { - return nil, fmt.Errorf("read: %s", err) - } - - // Verifying remote signature - remote, ok := tm.Remote(tr.Edges()) - if !ok { - return nil, errors.New("configured PubKey not found in edges") - } - if err := verifySig(respSEntry, remote); err != nil { - return nil, err - } - - newEntry := tm.walkEntries(func(e *Entry) bool { return *e == *respSEntry.Entry }) == nil - if newEntry { - tm.addEntry(entry) - } - - return respSEntry.Entry, nil +func makeEntry(tp Transport, public bool) *Entry { + return &Entry{ + ID: MakeTransportID(tp.Edges()[0], tp.Edges()[1], tp.Type(), public), + EdgeKeys: tp.Edges(), + Type: tp.Type(), + Public: public, } } -func settlementResponderHandshake(tm *Manager, tr Transport) (*Entry, error) { - sEntry := &SignedEntry{} - if err := json.NewDecoder(tr).Decode(sEntry); err != nil { - return nil, fmt.Errorf("read: %s", err) +func compareEntries(expected, received *Entry, checkPublic bool) error { + if !checkPublic { + expected.Public = received.Public + expected.ID = MakeTransportID(expected.EdgeKeys[0], expected.EdgeKeys[1], expected.Type, expected.Public) } - - remote, ok := tm.Remote(tr.Edges()) - if !ok { - return nil, errors.New("configured PubKey not found in edges") + if expected.ID != received.ID { + return errors.New("received entry's 'tp_id' is not of expected") } - - if err := validateSignedEntry(sEntry, tr, remote); err != nil { - return nil, err + if expected.EdgeKeys != received.EdgeKeys { + return errors.New("received entry's 'edges' is not of expected") } - - if ok := sEntry.Sign(tm.Local(), tm.config.SecKey); !ok { - return nil, errors.New("invalid pubkey for signing entry") + if expected.Type != received.Type { + return errors.New("received entry's 'type' is not of expected") } - - newEntry := tm.walkEntries(func(e *Entry) bool { return *e == *sEntry.Entry }) == nil - - var err error - if sEntry.Entry.Public { - if !newEntry { - _, err = tm.config.DiscoveryClient.UpdateStatuses(context.Background(), &Status{ID: sEntry.Entry.ID, IsUp: true}) - } else { - err = tm.config.DiscoveryClient.RegisterTransports(context.Background(), sEntry) - } - } - - if err != nil { - return nil, fmt.Errorf("entry set: %s", err) + if expected.Public != received.Public { + return errors.New("received entry's 'public' is not of expected") } + return nil +} - if err := json.NewEncoder(tr).Encode(sEntry); err != nil { - return nil, fmt.Errorf("write: %s", err) +func receiveAndVerifyEntry(r io.Reader, expected *Entry, remotePK cipher.PubKey, checkPublic bool) (*SignedEntry, error) { + var recvSE SignedEntry + if err := json.NewDecoder(r).Decode(&recvSE); err != nil { + return nil, fmt.Errorf("failed to read entry: %s", err) } - - if newEntry { - tm.addEntry(sEntry.Entry) + if err := compareEntries(expected, recvSE.Entry, checkPublic); err != nil { + return nil, err } - - return sEntry.Entry, nil -} - -func validateSignedEntry(sEntry *SignedEntry, tr Transport, pk cipher.PubKey) error { - entry := sEntry.Entry - if entry.Type != tr.Type() { - return errors.New("invalid entry type") + sig, ok := recvSE.Signature(remotePK) + if !ok { + return nil, errors.New("invalid remote signature") } - - if entry.Edges() != tr.Edges() { - return errors.New("invalid entry edges") + if err := cipher.VerifyPubKeySignedPayload(remotePK, sig, recvSE.Entry.ToBinary()); err != nil { + return nil, err } + return &recvSE, nil +} - // Weak check here - if sEntry.Signatures[0].Null() && sEntry.Signatures[1].Null() { - return errors.New("invalid entry signature") +func settlementInitiatorHandshake(public bool) settlementHandshake { + return func(tm *Manager, tp Transport) (*Entry, error) { + entry := makeEntry(tp, public) + se, ok := NewSignedEntry(entry, tm.config.PubKey, tm.config.SecKey) + if !ok { + return nil, errors.New("failed to sign entry") + } + if err := json.NewEncoder(tp).Encode(se); err != nil { + return nil, fmt.Errorf("failed to write entry: %v", err) + } + remotePK, ok := tm.Remote(tp.Edges()) + if !ok { + return nil, errors.New("invalid public key") + } + if _, err := receiveAndVerifyEntry(tp, entry, remotePK, true); err != nil { + return nil, err + } + tm.addEntry(entry) + return entry, nil } - - return verifySig(sEntry, pk) } -func verifySig(sEntry *SignedEntry, pk cipher.PubKey) error { - sig, ok := sEntry.Signature(pk) - if !ok { - return errors.New("invalid pubkey for retrieving signature") +func settlementResponderHandshake() settlementHandshake { + return func(tm *Manager, tr Transport) (*Entry, error) { + expectedEntry := makeEntry(tr, false) + remotePK, ok := tm.Remote(tr.Edges()) + if !ok { + return nil, errors.New("invalid public key") + } + recvSignedEntry, err := receiveAndVerifyEntry(tr, expectedEntry, remotePK, false) + if err != nil { + return nil, err + } + if ok := recvSignedEntry.Sign(tm.Local(), tm.config.SecKey); !ok { + return nil, errors.New("failed to sign received entry") + } + if isNew := tm.addIfNotExist(expectedEntry); !isNew { + _, err = tm.config.DiscoveryClient.UpdateStatuses(context.Background(), &Status{ID: recvSignedEntry.Entry.ID, IsUp: true}) + } else { + err = tm.config.DiscoveryClient.RegisterTransports(context.Background(), recvSignedEntry) + } + if err != nil { + return nil, err + } + if err := json.NewEncoder(tr).Encode(recvSignedEntry); err != nil { + return nil, fmt.Errorf("failed to write entry: %s", err) + } + return expectedEntry, nil } - - return cipher.VerifyPubKeySignedPayload(pk, sig, sEntry.Entry.ToBinary()) } diff --git a/pkg/transport/handshake_test.go b/pkg/transport/handshake_test.go index 8f7093aab..86e639ecc 100644 --- a/pkg/transport/handshake_test.go +++ b/pkg/transport/handshake_test.go @@ -88,87 +88,86 @@ func Example_newHsMock() { // err2 is nil: true } -func Example_validateEntry() { - pk1, sk1 := cipher.GenerateKeyPair() - pk2, _ := cipher.GenerateKeyPair() - pk3, _ := cipher.GenerateKeyPair() - tr := NewMockTransport(nil, pk1, pk2) - - entryInvalidEdges := &SignedEntry{ - Entry: &Entry{Type: "mock", - EdgeKeys: SortPubKeys(pk2, pk3), - }} - if err := validateSignedEntry(entryInvalidEdges, tr, pk1); err != nil { - fmt.Println(err.Error()) - } - - entry := NewEntry(pk1, pk2, "mock", true) - sEntry, ok := NewSignedEntry(entry, pk1, sk1) - if !ok { - fmt.Println("error creating signed entry") - } - if err := validateSignedEntry(sEntry, tr, pk1); err != nil { - fmt.Println(err.Error()) - } - - // Output: invalid entry edges -} - -func TestValidateEntry(t *testing.T) { - pk1, sk1 := cipher.GenerateKeyPair() - pk2, sk2 := cipher.GenerateKeyPair() - pk3, _ := cipher.GenerateKeyPair() - tr := NewMockTransport(nil, pk1, pk2) - - entry := &Entry{Type: "mock", EdgeKeys: SortPubKeys(pk2, pk1)} - tcs := []struct { - sEntry *SignedEntry - err string - }{ - { - &SignedEntry{Entry: &Entry{Type: "foo"}}, - "invalid entry type", - }, - { - &SignedEntry{Entry: &Entry{Type: "mock", EdgeKeys: SortPubKeys(pk1, pk3)}}, - "invalid entry edges", - }, - { - &SignedEntry{Entry: &Entry{Type: "mock", EdgeKeys: SortPubKeys(pk2, pk1)}}, - "invalid entry signature", - }, - { - &SignedEntry{Entry: entry, Signatures: [2]cipher.Sig{}}, - "invalid entry signature", - }, - { - func() *SignedEntry { - sEntry := &SignedEntry{Entry: entry, Signatures: [2]cipher.Sig{}} - _ = sEntry.Sign(pk1, sk2) // nolint - _ = sEntry.Sign(pk2, sk1) // nolint - return sEntry - }(), - "Recovered pubkey does not match pubkey", - }, - } - - for _, tc := range tcs { - t.Run(tc.err, func(t *testing.T) { - err := validateSignedEntry(tc.sEntry, tr, pk2) - require.Error(t, err) - assert.Equal(t, tc.err, err.Error()) - }) - } - - sEntry := &SignedEntry{Entry: entry, Signatures: [2]cipher.Sig{}} - require.True(t, sEntry.Sign(pk1, sk1)) - require.True(t, sEntry.Sign(pk2, sk2)) - - require.NoError(t, validateSignedEntry(sEntry, tr, pk1)) -} +//func Example_validateEntry() { +// pk1, sk1 := cipher.GenerateKeyPair() +// pk2, _ := cipher.GenerateKeyPair() +// pk3, _ := cipher.GenerateKeyPair() +// tr := NewMockTransport(nil, pk1, pk2) +// +// entryInvalidEdges := &SignedEntry{ +// Entry: &Entry{Type: "mock", +// EdgeKeys: SortPubKeys(pk2, pk3), +// }} +// if err := validateSignedEntry(entryInvalidEdges, tr, pk1); err != nil { +// fmt.Println(err.Error()) +// } +// +// entry := NewEntry(pk1, pk2, "mock", true) +// sEntry, ok := NewSignedEntry(entry, pk1, sk1) +// if !ok { +// fmt.Println("error creating signed entry") +// } +// if err := validateSignedEntry(sEntry, tr, pk1); err != nil { +// fmt.Println(err.Error()) +// } +// +// // Output: invalid entry edges +//} + +//func TestValidateEntry(t *testing.T) { +// pk1, sk1 := cipher.GenerateKeyPair() +// pk2, sk2 := cipher.GenerateKeyPair() +// pk3, _ := cipher.GenerateKeyPair() +// tr := NewMockTransport(nil, pk1, pk2) +// +// entry := &Entry{Type: "mock", EdgeKeys: SortPubKeys(pk2, pk1)} +// tcs := []struct { +// sEntry *SignedEntry +// err string +// }{ +// { +// &SignedEntry{Entry: &Entry{Type: "foo"}}, +// "invalid entry type", +// }, +// { +// &SignedEntry{Entry: &Entry{Type: "mock", EdgeKeys: SortPubKeys(pk1, pk3)}}, +// "invalid entry edges", +// }, +// { +// &SignedEntry{Entry: &Entry{Type: "mock", EdgeKeys: SortPubKeys(pk2, pk1)}}, +// "invalid entry signature", +// }, +// { +// &SignedEntry{Entry: entry, Signatures: [2]cipher.Sig{}}, +// "invalid entry signature", +// }, +// { +// func() *SignedEntry { +// sEntry := &SignedEntry{Entry: entry, Signatures: [2]cipher.Sig{}} +// _ = sEntry.Sign(pk1, sk2) // nolint +// _ = sEntry.Sign(pk2, sk1) // nolint +// return sEntry +// }(), +// "Recovered pubkey does not match pubkey", +// }, +// } +// +// for _, tc := range tcs { +// t.Run(tc.err, func(t *testing.T) { +// err := validateSignedEntry(tc.sEntry, tr, pk2) +// require.Error(t, err) +// assert.Equal(t, tc.err, err.Error()) +// }) +// } +// +// sEntry := &SignedEntry{Entry: entry, Signatures: [2]cipher.Sig{}} +// require.True(t, sEntry.Sign(pk1, sk1)) +// require.True(t, sEntry.Sign(pk2, sk2)) +// +// require.NoError(t, validateSignedEntry(sEntry, tr, pk1)) +//} func TestSettlementHandshake(t *testing.T) { - mockEnv := newHsMockEnv() t.Run("Create Mock Env", func(t *testing.T) { require.NoError(t, mockEnv.err1) @@ -178,7 +177,7 @@ func TestSettlementHandshake(t *testing.T) { errCh := make(chan error) var resEntry *Entry go func() { - e, err := settlementResponderHandshake(mockEnv.m2, mockEnv.tr2) + e, err := settlementResponderHandshake()(mockEnv.m2, mockEnv.tr2) resEntry = e errCh <- err }() @@ -199,30 +198,6 @@ func TestSettlementHandshake(t *testing.T) { } -/* -func TestSettlementHandshakeInvalidSig(t *testing.T) { - mockEnv := newHsMockEnv() - - require.NoError(t, mockEnv.err1) - require.NoError(t, mockEnv.err2) - - go settlementInitiatorHandshake(true)(mockEnv.m2, mockEnv.tr1) // nolint: errcheck - _, err := settlementResponderHandshake(mockEnv.m2, mockEnv.tr2) - require.Error(t, err) - assert.Equal(t, "Recovered pubkey does not match pubkey", err.Error()) - - in, out := net.Pipe() - tr1 := NewMockTransport(in, mockEnv.pk1, mockEnv.pk2) - tr2 := NewMockTransport(out, mockEnv.pk2, mockEnv.pk1) - - go settlementResponderHandshake(mockEnv.m1, tr2) // nolint: errcheck - _, err = settlementInitiatorHandshake(true)(mockEnv.m1, tr1) - require.Error(t, err) - assert.Equal(t, "Recovered pubkey does not match pubkey", err.Error()) - -} -*/ - func TestSettlementHandshakePrivate(t *testing.T) { mockEnv := newHsMockEnv() @@ -232,7 +207,7 @@ func TestSettlementHandshakePrivate(t *testing.T) { errCh := make(chan error) var resEntry *Entry go func() { - e, err := settlementResponderHandshake(mockEnv.m2, mockEnv.tr2) + e, err := settlementResponderHandshake()(mockEnv.m2, mockEnv.tr2) resEntry = e errCh <- err }() @@ -246,7 +221,7 @@ func TestSettlementHandshakePrivate(t *testing.T) { assert.Equal(t, entry.ID, resEntry.ID) _, err = mockEnv.client.GetTransportByID(context.TODO(), entry.ID) - require.Error(t, err) + require.NoError(t, err) } @@ -279,7 +254,7 @@ func TestSettlementHandshakeExistingTransport(t *testing.T) { errCh := make(chan error) var resEntry *Entry go func() { - e, err := settlementResponderHandshake(mockEnv.m2, mockEnv.tr2) + e, err := settlementResponderHandshake()(mockEnv.m2, mockEnv.tr2) resEntry = e errCh <- err }() @@ -299,22 +274,22 @@ func TestSettlementHandshakeExistingTransport(t *testing.T) { } -func Example_validateSignedEntry() { - mockEnv := newHsMockEnv() - - tm, tr := mockEnv.m1, mockEnv.tr1 - entry := NewEntry(mockEnv.pk1, mockEnv.pk2, "mock", true) - sEntry, ok := NewSignedEntry(entry, tm.config.PubKey, tm.config.SecKey) - if !ok { - fmt.Println("error creating signed entry") - } - if err := validateSignedEntry(sEntry, tr, tm.config.PubKey); err != nil { - fmt.Printf("NewSignedEntry: %v", err.Error()) - } - - fmt.Printf("System is working") - // Output: System is working -} +//func Example_validateSignedEntry() { +// mockEnv := newHsMockEnv() +// +// tm, tr := mockEnv.m1, mockEnv.tr1 +// entry := NewEntry(mockEnv.pk1, mockEnv.pk2, "mock", true) +// sEntry, ok := NewSignedEntry(entry, tm.config.PubKey, tm.config.SecKey) +// if !ok { +// fmt.Println("error creating signed entry") +// } +// if err := validateSignedEntry(sEntry, tr, tm.config.PubKey); err != nil { +// fmt.Printf("NewSignedEntry: %v", err.Error()) +// } +// +// fmt.Printf("System is working") +// // Output: System is working +//} func Example_settlementInitiatorHandshake() { mockEnv := newHsMockEnv() @@ -333,7 +308,7 @@ func Example_settlementInitiatorHandshake() { }() go func() { - if _, err := respondHandshake(mockEnv.m2, mockEnv.tr2); err != nil { + if _, err := respondHandshake()(mockEnv.m2, mockEnv.tr2); err != nil { fmt.Printf("respondHandshake error: %v\n", err.Error()) errCh <- err } diff --git a/pkg/transport/log_test.go b/pkg/transport/log_test.go index 092f75bf8..b118f57de 100644 --- a/pkg/transport/log_test.go +++ b/pkg/transport/log_test.go @@ -1,4 +1,4 @@ -package transport +package transport_test import ( "io/ioutil" @@ -9,15 +9,17 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/skycoin/skywire/pkg/transport" ) -func testTransportLogStore(t *testing.T, logStore LogStore) { +func testTransportLogStore(t *testing.T, logStore transport.LogStore) { t.Helper() id1 := uuid.New() - entry1 := &LogEntry{big.NewInt(100), big.NewInt(200)} + entry1 := &transport.LogEntry{big.NewInt(100), big.NewInt(200)} id2 := uuid.New() - entry2 := &LogEntry{big.NewInt(300), big.NewInt(400)} + entry2 := &transport.LogEntry{big.NewInt(300), big.NewInt(400)} require.NoError(t, logStore.Record(id1, entry1)) require.NoError(t, logStore.Record(id2, entry2)) @@ -29,7 +31,7 @@ func testTransportLogStore(t *testing.T, logStore LogStore) { } func TestInMemoryTransportLogStore(t *testing.T) { - testTransportLogStore(t, InMemoryTransportLogStore()) + testTransportLogStore(t, transport.InMemoryTransportLogStore()) } func TestFileTransportLogStore(t *testing.T) { @@ -37,7 +39,7 @@ func TestFileTransportLogStore(t *testing.T) { require.NoError(t, err) defer os.RemoveAll(dir) - ls, err := FileTransportLogStore(dir) + ls, err := transport.FileTransportLogStore(dir) require.NoError(t, err) testTransportLogStore(t, ls) } diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index f7ab42dcc..0873f9b2d 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -3,9 +3,6 @@ package transport import ( "math/big" "sync" - "time" - - "github.com/skycoin/skywire/pkg/cipher" "github.com/google/uuid" ) @@ -20,9 +17,9 @@ type ManagedTransport struct { LogEntry *LogEntry doneChan chan struct{} - doneOnce sync.Once errChan chan error mu sync.RWMutex + once sync.Once readLogChan chan int writeLogChan chan int @@ -36,8 +33,8 @@ func newManagedTransport(id uuid.UUID, tr Transport, public bool, accepted bool) Accepted: accepted, doneChan: make(chan struct{}), errChan: make(chan error), - readLogChan: make(chan int), - writeLogChan: make(chan int), + readLogChan: make(chan int, 16), + writeLogChan: make(chan int, 16), LogEntry: &LogEntry{new(big.Int), new(big.Int)}, } } @@ -47,19 +44,12 @@ func (tr *ManagedTransport) Read(p []byte) (n int, err error) { tr.mu.RLock() n, err = tr.Transport.Read(p) // TODO: data race. tr.mu.RUnlock() - if err == nil { - select { - case <-tr.doneChan: - return - case tr.readLogChan <- n: - } - return - } - select { - case <-tr.doneChan: - return - case tr.errChan <- err: + + if err != nil { + tr.errChan <- err } + + tr.readLogChan <- n return } @@ -68,40 +58,35 @@ func (tr *ManagedTransport) Write(p []byte) (n int, err error) { tr.mu.RLock() n, err = tr.Transport.Write(p) tr.mu.RUnlock() - if err == nil { - select { - case <-tr.doneChan: - return - case tr.writeLogChan <- n: - } - return - } - select { - case <-tr.doneChan: + + if err != nil { + tr.errChan <- err return - case tr.errChan <- err: } + tr.writeLogChan <- n + return } -// Edges returns the edges of underlying transport. -func (tr *ManagedTransport) Edges() [2]cipher.PubKey { - tr.mu.RLock() - edges := tr.Transport.Edges() - tr.mu.RUnlock() - return edges +// killWorker sends signal to Manager.manageTransport goroutine to exit +// it's safe to call it multiple times +func (tr *ManagedTransport) killWorker() { + tr.once.Do(func() { + close(tr.doneChan) + }) } -// SetDeadline sets the deadline of the underlying transport. -func (tr *ManagedTransport) SetDeadline(t time.Time) error { +// Close closes underlying +func (tr *ManagedTransport) Close() error { tr.mu.RLock() - err := tr.Transport.SetDeadline(t) + err := tr.Transport.Close() tr.mu.RUnlock() + + tr.killWorker() return err } -// IsClosing determines whether is closing. -func (tr *ManagedTransport) IsClosing() bool { +func (tr *ManagedTransport) isClosing() bool { select { case <-tr.doneChan: return true @@ -110,17 +95,8 @@ func (tr *ManagedTransport) IsClosing() bool { } } -// Close closes underlying -func (tr *ManagedTransport) Close() (err error) { - tr.mu.RLock() - err = tr.Transport.Close() - tr.mu.RUnlock() - tr.doneOnce.Do(func() { close(tr.doneChan) }) - return err -} - func (tr *ManagedTransport) updateTransport(newTr Transport) { tr.mu.Lock() - tr.Transport = newTr // TODO: data race. + tr.Transport = newTr tr.mu.Unlock() } diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index 869f93aa7..42fef0c14 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -6,6 +6,7 @@ import ( "math/big" "strings" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -34,7 +35,9 @@ type Manager struct { doneChan chan struct{} TrChan chan *ManagedTransport - mx sync.RWMutex + mu sync.RWMutex + + mgrQty int32 // Count of spawned manageTransport goroutines } // NewManager creates a Manager with the provided configuration and transport factories. @@ -53,7 +56,7 @@ func NewManager(config *ManagerConfig, factories ...Factory) (*Manager, error) { } return &Manager{ - Logger: logging.MustGetLogger("tpmanager"), + Logger: logging.MustGetLogger("trmanager"), config: config, factories: fMap, transports: make(map[uuid.UUID]*ManagedTransport), @@ -74,32 +77,31 @@ func (tm *Manager) Factories() []string { // Transport obtains a Transport via a given Transport ID. func (tm *Manager) Transport(id uuid.UUID) *ManagedTransport { - tm.mx.RLock() + tm.mu.RLock() tr := tm.transports[id] - tm.mx.RUnlock() + tm.mu.RUnlock() return tr } // WalkTransports ranges through all transports. func (tm *Manager) WalkTransports(walk func(tp *ManagedTransport) bool) { - tm.mx.RLock() + tm.mu.RLock() for _, tp := range tm.transports { - if ok := walk(tp); !ok { // TODO: data race. + if ok := walk(tp); !ok { break } } - tm.mx.RUnlock() + tm.mu.RUnlock() } // reconnectTransports tries to reconnect previously established transports. func (tm *Manager) reconnectTransports(ctx context.Context) { - tm.mx.RLock() + tm.mu.RLock() entries := make(map[Entry]struct{}) for tmEntry := range tm.entries { entries[tmEntry] = struct{}{} } - tm.mx.RUnlock() - + tm.mu.RUnlock() for entry := range entries { if tm.Transport(entry.ID) != nil { continue @@ -155,7 +157,7 @@ func (tm *Manager) createDefaultTransports(ctx context.Context) { if exist { continue } - _, err := tm.CreateTransport(ctx, pk, "dmsg", true) + _, err := tm.CreateTransport(ctx, pk, "messaging", true) if err != nil { tm.Logger.Warnf("Failed to establish transport to a node %s: %s", pk, err) } @@ -175,8 +177,10 @@ func (tm *Manager) Serve(ctx context.Context) error { for { select { case <-ctx.Done(): + tm.Logger.Info("Received ctx.Done()") return case <-tm.doneChan: + tm.Logger.Info("Received tm.doneCh") return default: if _, err := tm.acceptTransport(ctx, f); err != nil { @@ -187,6 +191,7 @@ func (tm *Manager) Serve(ctx context.Context) error { tm.Logger.Warnf("Failed to accept connection: %s", err) } } + } }(factory) } @@ -196,38 +201,6 @@ func (tm *Manager) Serve(ctx context.Context) error { return nil } -// MakeTransportID generates uuid.UUID from pair of keys + type + public -// Generated uuid is: -// - always the same for a given pair -// - GenTransportUUID(keyA,keyB) == GenTransportUUID(keyB, keyA) -func MakeTransportID(keyA, keyB cipher.PubKey, tpType string, public bool) uuid.UUID { - keys := SortPubKeys(keyA, keyB) - if public { - return uuid.NewSHA1(uuid.UUID{}, - append(append(append(keys[0][:], keys[1][:]...), []byte(tpType)...), 1)) - } - return uuid.NewSHA1(uuid.UUID{}, - append(append(append(keys[0][:], keys[1][:]...), []byte(tpType)...), 0)) -} - -// SortPubKeys sorts keys so that least-significant comes first -func SortPubKeys(keyA, keyB cipher.PubKey) [2]cipher.PubKey { - for i := 0; i < 33; i++ { - if keyA[i] != keyB[i] { - if keyA[i] < keyB[i] { - return [2]cipher.PubKey{keyA, keyB} - } - return [2]cipher.PubKey{keyB, keyA} - } - } - return [2]cipher.PubKey{keyA, keyB} -} - -// SortEdges sorts edges so that list-significant comes firs -func SortEdges(edges [2]cipher.PubKey) [2]cipher.PubKey { - return SortPubKeys(edges[0], edges[1]) -} - // CreateTransport begins to attempt to establish transports to the given 'remote' node. func (tm *Manager) CreateTransport(ctx context.Context, remote cipher.PubKey, tpType string, public bool) (*ManagedTransport, error) { return tm.createTransport(ctx, remote, tpType, public) @@ -235,18 +208,20 @@ func (tm *Manager) CreateTransport(ctx context.Context, remote cipher.PubKey, tp // DeleteTransport disconnects and removes the Transport of Transport ID. func (tm *Manager) DeleteTransport(id uuid.UUID) error { - tm.mx.Lock() - tp := tm.transports[id] + tm.mu.Lock() + tr := tm.transports[id] delete(tm.transports, id) - tm.mx.Unlock() + tm.mu.Unlock() + + tr.Close() if _, err := tm.config.DiscoveryClient.UpdateStatuses(context.Background(), &Status{ID: id, IsUp: false}); err != nil { tm.Logger.Warnf("Failed to change transport status: %s", err) } tm.Logger.Infof("Unregistered transport %s", id) - if tp != nil { - tp.Close() + if tr != nil { + return tr.Close() } return nil @@ -254,11 +229,10 @@ func (tm *Manager) DeleteTransport(id uuid.UUID) error { // Close closes opened transports and registered factories. func (tm *Manager) Close() error { - close(tm.doneChan) tm.Logger.Info("Closing transport manager") - tm.mx.Lock() + tm.mu.Lock() statuses := make([]*Status, 0) for _, tr := range tm.transports { if !tr.Public { @@ -266,9 +240,9 @@ func (tm *Manager) Close() error { } statuses = append(statuses, &Status{ID: tr.ID, IsUp: false}) - go tr.Close() + tr.Close() } - tm.mx.Unlock() + tm.mu.Unlock() if _, err := tm.config.DiscoveryClient.UpdateStatuses(context.Background(), statuses...); err != nil { tm.Logger.Warnf("Failed to change transport status: %s", err) @@ -283,6 +257,10 @@ func (tm *Manager) Close() error { func (tm *Manager) dialTransport(ctx context.Context, factory Factory, remote cipher.PubKey, public bool) (Transport, *Entry, error) { + if tm.isClosing() { + return nil, nil, errors.New("transport.Manager is closing. Skipping dialling transport") + } + tr, err := factory.Dial(ctx, remote) if err != nil { return nil, nil, err @@ -308,22 +286,23 @@ func (tm *Manager) createTransport(ctx context.Context, remote cipher.PubKey, tp return nil, err } - tm.Logger.Infof("Dialed to %s using %s factory. Transport ID: %s", remote, tpType, entry.ID) - managedTr := newManagedTransport(entry.ID, tr, entry.Public, false) - tm.mx.Lock() - tm.transports[entry.ID] = managedTr - select { - case <-tm.doneChan: - case tm.TrChan <- managedTr: - default: + oldTr := tm.Transport(entry.ID) + if oldTr != nil { + oldTr.killWorker() } - tm.mx.Unlock() - go tm.manageTransport(ctx, managedTr, factory, remote, public, false) + tm.Logger.Infof("Dialed to %s using %s factory. Transport ID: %s", remote, tpType, entry.ID) + mTr := newManagedTransport(entry.ID, tr, entry.Public, false) + + tm.mu.Lock() + tm.transports[entry.ID] = mTr + tm.mu.Unlock() - go tm.manageTransportLogs(managedTr) + tm.TrChan <- mTr - return managedTr, nil + go tm.manageTransport(ctx, mTr, factory, remote, public, false) + + return mTr, nil } func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*ManagedTransport, error) { @@ -332,8 +311,11 @@ func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*Manag return nil, err } - var handshake settlementHandshake = settlementResponderHandshake - entry, err := handshake.Do(tm, tr, 30*time.Second) + if tm.isClosing() { + return nil, errors.New("transport.Manager is closing. Skipping incoming transport") + } + + entry, err := settlementResponderHandshake().Do(tm, tr, 30*time.Second) if err != nil { tr.Close() return nil, err @@ -345,85 +327,90 @@ func (tm *Manager) acceptTransport(ctx context.Context, factory Factory) (*Manag } tm.Logger.Infof("Accepted new transport with type %s from %s. ID: %s", factory.Type(), remote, entry.ID) - managedTr := newManagedTransport(entry.ID, tr, entry.Public, true) - tm.mx.Lock() - tm.transports[entry.ID] = managedTr - select { - case <-tm.doneChan: - case tm.TrChan <- managedTr: - default: + oldTr := tm.Transport(entry.ID) + if oldTr != nil { + oldTr.killWorker() } - tm.mx.Unlock() - - go tm.manageTransport(ctx, managedTr, factory, remote, true, true) + mTr := newManagedTransport(entry.ID, tr, entry.Public, true) - go tm.manageTransportLogs(managedTr) - - return managedTr, nil -} + tm.mu.Lock() + tm.transports[entry.ID] = mTr + tm.mu.Unlock() -func (tm *Manager) walkEntries(walkFunc func(*Entry) bool) *Entry { - tm.mx.Lock() - defer tm.mx.Unlock() + tm.TrChan <- mTr - for entry := range tm.entries { - if walkFunc(&entry) { - return &entry - } - } + go tm.manageTransport(ctx, mTr, factory, remote, true, true) - return nil + return mTr, nil } func (tm *Manager) addEntry(entry *Entry) { - tm.mx.Lock() + tm.mu.Lock() tm.entries[*entry] = struct{}{} - tm.mx.Unlock() + tm.mu.Unlock() } -func (tm *Manager) manageTransport(ctx context.Context, managedTr *ManagedTransport, factory Factory, remote cipher.PubKey, public bool, accepted bool) { +func (tm *Manager) addIfNotExist(entry *Entry) (isNew bool) { + tm.mu.Lock() + if _, ok := tm.entries[*entry]; !ok { + tm.entries[*entry] = struct{}{} + isNew = true + } + tm.mu.Unlock() + return isNew +} + +func (tm *Manager) isClosing() bool { select { - case <-managedTr.doneChan: - tm.Logger.Infof("Transport %s closed", managedTr.ID) - return - case err := <-managedTr.errChan: - if !managedTr.IsClosing() { - tm.Logger.Infof("Transport %s failed with error: %s. Re-dialing...", managedTr.ID, err) - if accepted { - if err := tm.DeleteTransport(managedTr.ID); err != nil { - tm.Logger.Warnf("Failed to delete accepted transport: %s", err) - } - } else { - tr, _, err := tm.dialTransport(ctx, factory, remote, public) - if err != nil { - tm.Logger.Infof("Failed to re-dial Transport %s: %s", managedTr.ID, err) - if err := tm.DeleteTransport(managedTr.ID); err != nil { - tm.Logger.Warnf("Failed to delete re-dialled transport: %s", err) - } - } else { - managedTr.updateTransport(tr) - } - } - } else { - tm.Logger.Infof("Transport %s is already closing. Skipped error: %s", managedTr.ID, err) - } + case <-tm.doneChan: + return true + default: + return false } } -func (tm *Manager) manageTransportLogs(tr *ManagedTransport) { +func (tm *Manager) manageTransport(ctx context.Context, mTr *ManagedTransport, factory Factory, remote cipher.PubKey, public bool, accepted bool) { + mgrQty := atomic.AddInt32(&tm.mgrQty, 1) + tm.Logger.Infof("Spawned manageTransport for mTr.ID: %v. mgrQty: %v", mTr.ID, mgrQty) for { select { - case <-tr.doneChan: + case <-mTr.doneChan: + mgrQty := atomic.AddInt32(&tm.mgrQty, -1) + tm.Logger.Infof("manageTransport exit for %v. mgrQty: %v", mTr.ID, mgrQty) return - case n := <-tr.readLogChan: - tr.LogEntry.ReceivedBytes.Add(tr.LogEntry.ReceivedBytes, big.NewInt(int64(n))) - case n := <-tr.writeLogChan: - tr.LogEntry.SentBytes.Add(tr.LogEntry.SentBytes, big.NewInt(int64(n))) - } - - if err := tm.config.LogStore.Record(tr.ID, tr.LogEntry); err != nil { - tm.Logger.Warnf("Failed to record log entry: %s", err) + case err := <-mTr.errChan: + if !mTr.isClosing() { + tm.Logger.Infof("Transport %s failed with error: %s. Re-dialing...", mTr.ID, err) + if accepted { + if err := tm.DeleteTransport(mTr.ID); err != nil { + tm.Logger.Warnf("Failed to delete accepted transport: %s", err) + } + } else { + tr, _, err := tm.dialTransport(ctx, factory, remote, public) + if err != nil { + tm.Logger.Infof("Failed to re-dial Transport %s: %s", mTr.ID, err) + if err := tm.DeleteTransport(mTr.ID); err != nil { + tm.Logger.Warnf("Failed to delete re-dialled transport: %s", err) + } + } else { + tm.Logger.Infof("Updating transport %s", mTr.ID) + mTr.updateTransport(tr) + } + } + } else { + tm.Logger.Infof("Transport %s is already closing. Skipped error: %s", mTr.ID, err) + } + case n := <-mTr.readLogChan: + mTr.LogEntry.ReceivedBytes.Add(mTr.LogEntry.ReceivedBytes, big.NewInt(int64(n))) + if err := tm.config.LogStore.Record(mTr.ID, mTr.LogEntry); err != nil { + tm.Logger.Warnf("Failed to record log entry: %s", err) + } + case n := <-mTr.writeLogChan: + mTr.LogEntry.SentBytes.Add(mTr.LogEntry.SentBytes, big.NewInt(int64(n))) + if err := tm.config.LogStore.Record(mTr.ID, mTr.LogEntry); err != nil { + tm.Logger.Warnf("Failed to record log entry: %s", err) + } } } } diff --git a/pkg/transport/tcp_transport_test.go b/pkg/transport/tcp_transport_test.go index 0715a9c38..cd60504f8 100644 --- a/pkg/transport/tcp_transport_test.go +++ b/pkg/transport/tcp_transport_test.go @@ -1,4 +1,4 @@ -package transport +package transport_test import ( "context" @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/skycoin/skywire/pkg/cipher" + "github.com/skycoin/skywire/pkg/transport" ) func TestTCPFactory(t *testing.T) { @@ -28,10 +29,10 @@ func TestTCPFactory(t *testing.T) { l2, err := net.ListenTCP("tcp", addr2) require.NoError(t, err) - pkt1 := InMemoryPubKeyTable(map[cipher.PubKey]*net.TCPAddr{pk2: addr2}) - pkt2 := InMemoryPubKeyTable(map[cipher.PubKey]*net.TCPAddr{pk1: addr1}) + pkt1 := transport.InMemoryPubKeyTable(map[cipher.PubKey]*net.TCPAddr{pk2: addr2}) + pkt2 := transport.InMemoryPubKeyTable(map[cipher.PubKey]*net.TCPAddr{pk1: addr1}) - f1 := NewTCPFactory(pk1, pkt1, l1) + f1 := transport.NewTCPFactory(pk1, pkt1, l1) errCh := make(chan error) go func() { tr, err := f1.Accept(context.TODO()) @@ -48,7 +49,7 @@ func TestTCPFactory(t *testing.T) { errCh <- nil }() - f2 := NewTCPFactory(pk2, pkt2, l2) + f2 := transport.NewTCPFactory(pk2, pkt2, l2) assert.Equal(t, "tcp", f2.Type()) assert.Equal(t, pk2, f2.Local()) @@ -79,7 +80,7 @@ func TestFilePKTable(t *testing.T) { _, err = tmpfile.Write([]byte(fmt.Sprintf("%s\t%s\n", pk, addr))) require.NoError(t, err) - pkt, err := FilePubKeyTable(tmpfile.Name()) + pkt, err := transport.FilePubKeyTable(tmpfile.Name()) require.NoError(t, err) raddr := pkt.RemoteAddr(pk) diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go index 8e66224af..9b2a1b2e4 100644 --- a/pkg/transport/transport.go +++ b/pkg/transport/transport.go @@ -6,6 +6,8 @@ import ( "context" "time" + "github.com/google/uuid" + "github.com/skycoin/skywire/pkg/cipher" ) @@ -50,3 +52,35 @@ type Factory interface { // Type returns the Transport type. Type() string } + +// MakeTransportID generates uuid.UUID from pair of keys + type + public +// Generated uuid is: +// - always the same for a given pair +// - GenTransportUUID(keyA,keyB) == GenTransportUUID(keyB, keyA) +func MakeTransportID(keyA, keyB cipher.PubKey, tpType string, public bool) uuid.UUID { + keys := SortPubKeys(keyA, keyB) + if public { + return uuid.NewSHA1(uuid.UUID{}, + append(append(append(keys[0][:], keys[1][:]...), []byte(tpType)...), 1)) + } + return uuid.NewSHA1(uuid.UUID{}, + append(append(append(keys[0][:], keys[1][:]...), []byte(tpType)...), 0)) +} + +// SortPubKeys sorts keys so that least-significant comes first +func SortPubKeys(keyA, keyB cipher.PubKey) [2]cipher.PubKey { + for i := 0; i < 33; i++ { + if keyA[i] != keyB[i] { + if keyA[i] < keyB[i] { + return [2]cipher.PubKey{keyA, keyB} + } + return [2]cipher.PubKey{keyB, keyA} + } + } + return [2]cipher.PubKey{keyA, keyB} +} + +// SortEdges sorts edges so that list-significant comes firs +func SortEdges(edges [2]cipher.PubKey) [2]cipher.PubKey { + return SortPubKeys(edges[0], edges[1]) +}