diff --git a/pkg/messaging/channel.go b/pkg/messaging/channel.go index 67007cd14..fe38281a8 100644 --- a/pkg/messaging/channel.go +++ b/pkg/messaging/channel.go @@ -237,7 +237,7 @@ func (mCh *msgChannel) readEncrypted(ctx context.Context, p []byte) (n int, err } if len(data) > len(p) { - if _, err := mCh.buf.Write(data[len(p):]); err != nil { + if _, err := mCh.buf.Write(data[len(p):]); err != nil { // TODO: data race. return 0, io.ErrShortBuffer } diff --git a/pkg/router/router.go b/pkg/router/router.go index 9390cbe10..4f4ab65ea 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -80,37 +80,29 @@ func New(config *Config) *Router { func (r *Router) Serve(ctx context.Context) error { go func() { - for tr := range r.tm.TrChan { - if tr.Accepted { - go func(t *transport.ManagedTransport) { - for { - var err error - if r.IsSetupTransport(t) { - err = r.rm.Serve(t) - } else { - err = r.serveTransport(t) - } + for tp := range r.tm.TrChan { + isAccepted, isSetup := tp.Accepted, r.IsSetupTransport(tp) + + var serve func(io.ReadWriter) error + switch { + case isAccepted && isSetup: + serve = r.rm.Serve + case !isSetup: + serve = r.serveTransport + default: + continue + } - if err != nil { - if err != io.EOF { - r.Logger.Warnf("Stopped serving Transport: %s", err) - } - return + go func(tp transport.Transport) { + for { + if err := serve(tp); err != nil { + if err != io.EOF { + r.Logger.Warnf("Stopped serving Transport: %s", err) } + return } - }(tr) - } else { - go func(t *transport.ManagedTransport) { - for { - if err := r.serveTransport(t); err != nil { - if err != io.EOF { - r.Logger.Warnf("Stopped serving Transport: %s", err) - } - return - } - } - }(tr) - } + } + }(tp) } }() @@ -178,14 +170,14 @@ func (r *Router) Close() error { return r.tm.Close() } -func (r *Router) serveTransport(tr *transport.ManagedTransport) error { +func (r *Router) serveTransport(rw io.ReadWriter) error { packet := make(routing.Packet, 6) - if _, err := io.ReadFull(tr, packet); err != nil { + if _, err := io.ReadFull(rw, packet); err != nil { return err } payload := make([]byte, packet.Size()) - if _, err := io.ReadFull(tr, payload); err != nil { + if _, err := io.ReadFull(rw, payload); err != nil { return err } diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index d07d6d0f9..d4f9c7aa0 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -487,8 +487,6 @@ func TestRouterSetupLoop(t *testing.T) { r := New(conf) errCh := make(chan error) go func() { - // acceptCh, _ := m2.Observe() - // tr := <-acceptCh var tr *transport.ManagedTransport for tr = range m2.TrChan { if tr.Accepted { diff --git a/pkg/setup/node.go b/pkg/setup/node.go index 0a4b48f6e..38b7e4e21 100644 --- a/pkg/setup/node.go +++ b/pkg/setup/node.go @@ -92,7 +92,6 @@ func (sn *Node) Serve(ctx context.Context) error { go func() { for tr := range sn.tm.TrChan { - if tr.Accepted { go func(t transport.Transport) { for { diff --git a/pkg/setup/protocol.go b/pkg/setup/protocol.go index 436ce258b..59fbaed18 100644 --- a/pkg/setup/protocol.go +++ b/pkg/setup/protocol.go @@ -76,7 +76,7 @@ func NewSetupProtocol(rw io.ReadWriter) *Protocol { // ReadPacket reads a single setup packet. func (p *Protocol) ReadPacket() (PacketType, []byte, error) { rawLen := make([]byte, 2) - if _, err := io.ReadFull(p.rw, rawLen); err != nil { + if _, err := io.ReadFull(p.rw, rawLen); err != nil { // TODO: data race. return 0, nil, err } rawBody := make([]byte, binary.BigEndian.Uint16(rawLen)) @@ -146,7 +146,7 @@ func CreateLoop(p *Protocol, l *routing.Loop) error { if err := p.WritePacket(PacketCreateLoop, l); err != nil { return err } - if err := readAndDecodePacket(p, nil); err != nil { + if err := readAndDecodePacket(p, nil); err != nil { // TODO: data race. return err } return nil @@ -187,7 +187,7 @@ func LoopClosed(p *Protocol, l *LoopData) error { } func readAndDecodePacket(p *Protocol, v interface{}) error { - t, raw, err := p.ReadPacket() + t, raw, err := p.ReadPacket() // TODO: data race. if err != nil { return err } diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index c801b2993..0af29c7bb 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -41,7 +41,7 @@ func newManagedTransport(id uuid.UUID, tr Transport, public bool, accepted bool) // Read reads using underlying func (tr *ManagedTransport) Read(p []byte) (n int, err error) { tr.mu.RLock() - n, err = tr.Transport.Read(p) + n, err = tr.Transport.Read(p) // TODO: data race. tr.mu.RUnlock() if err == nil { select {