Skip to content

Commit

Permalink
Merge branch 'feature/skymsg' of https://github.com/evanlinjin/skywire
Browse files Browse the repository at this point in the history
…into feature/skymsg-tests
  • Loading branch information
Darkren committed Jun 4, 2019
2 parents b30b3a2 + 92c28d4 commit f45e7eb
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 65 deletions.
24 changes: 12 additions & 12 deletions internal/ioutil/ack_waiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"crypto/rand"
"encoding/binary"
"io"
"math"
"sync"
)
Expand Down Expand Up @@ -45,38 +44,39 @@ func (w *Uint16AckWaiter) RandSeq() error {
}

// Wait performs the given action, and waits for given seq to be Done.
func (w *Uint16AckWaiter) Wait(ctx context.Context, done <-chan struct{}, action func(seq Uint16Seq) error) error {
func (w *Uint16AckWaiter) Wait(ctx context.Context, action func(seq Uint16Seq) error) (err error) {
ackCh := make(chan struct{})
defer close(ackCh)

w.mx.Lock()
seq := w.nextSeq
w.nextSeq++
w.waiters[seq] = ackCh
w.mx.Unlock()

if err := action(seq); err != nil {
if err = action(seq); err != nil {
return err
}

select {
case <-ackCh:
return nil
case <-done:
return io.ErrClosedPipe
case <-ctx.Done():
return ctx.Err()
err = ctx.Err()
}

w.mx.Lock()
close(ackCh)
w.waiters[seq] = nil
w.mx.Unlock()

return err
}

// Done finishes given sequence.
func (w *Uint16AckWaiter) Done(seq Uint16Seq) {
w.mx.RLock()
ackCh := w.waiters[seq]
w.mx.RUnlock()

select {
case ackCh <- struct{}{}:
case w.waiters[seq] <- struct{}{}:
default:
}
w.mx.RUnlock()
}
24 changes: 24 additions & 0 deletions internal/ioutil/ack_waiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package ioutil

import (
"context"
"testing"
)

// Ensure that no race conditions occurs.
func TestUint16AckWaiter_Wait(t *testing.T) {
w := new(Uint16AckWaiter)

seqChan := make(chan Uint16Seq)
defer close(seqChan)
for i := 0; i < 64; i++ {
go w.Wait(context.TODO(), func(seq Uint16Seq) error { //nolint:errcheck
seqChan <- seq
return nil
})
seq := <-seqChan
for j := 0; j < i; j++ {
go w.Done(seq)
}
}
}
38 changes: 22 additions & 16 deletions pkg/dmsg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,15 @@ type ClientConn struct {
// map of transports to remote dms_clients (key: tp_id, val: transport).
tps [math.MaxUint16 + 1]*Transport
mx sync.RWMutex // to protect tps.

wg sync.WaitGroup
}

// NewClientConn creates a new ClientConn.
func NewClientConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubKey) *ClientConn {
return &ClientConn{log: log, Conn: conn, local: local, remoteSrv: remote, nextInitID: randID(true)}
cc := &ClientConn{log: log, Conn: conn, local: local, remoteSrv: remote, nextInitID: randID(true)}
cc.wg.Add(1)
return cc
}

func (c *ClientConn) delTp(id uint16) {
Expand Down Expand Up @@ -92,7 +96,7 @@ func (c *ClientConn) getTp(id uint16) (*Transport, bool) {
return tp, ok
}

func (c *ClientConn) handleRequestFrame(ctx context.Context, done <-chan struct{}, accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) {
func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) {
// remote-initiated tps should:
// - have a payload structured as 'init_pk:resp_pk'.
// - resp_pk should be of local client.
Expand All @@ -115,8 +119,6 @@ func (c *ClientConn) handleRequestFrame(ctx context.Context, done <-chan struct{
c.setTp(tp)

select {
case <-done:
return initPK, ErrClientClosed
case <-ctx.Done():
return initPK, ctx.Err()
case accept <- tp:
Expand All @@ -126,19 +128,22 @@ func (c *ClientConn) handleRequestFrame(ctx context.Context, done <-chan struct{

// Serve handles incoming frames.
// Remote-initiated tps that are successfully created are pushing into 'accept' and exposed via 'Client.Accept()'.
func (c *ClientConn) Serve(ctx context.Context, done <-chan struct{}, accept chan<- *Transport) (err error) {
func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) (err error) {
defer c.wg.Done()

log := c.log.WithField("remoteServer", c.remoteSrv)

log.WithField("connCount", incrementServeCount()).Infoln("ServingConn")
defer func() {
log.WithError(err).WithField("connCount", decrementServeCount()).Infoln("ClosingConn")
}()
log.WithField("connCount", incrementServeCount()).Infoln("ServingConn")

for {
f, err := readFrame(c.Conn)
if err != nil {
return fmt.Errorf("read failed: %s", err)
}
log = log.WithField("frame", f)
log = log.WithField("received", f)

ft, id, p := f.Disassemble()

Expand All @@ -160,7 +165,7 @@ func (c *ClientConn) Serve(ctx context.Context, done <-chan struct{}, accept cha
case RequestType:
// TODO(evanlinjin): Allow for REQUEST frame handling to be done in goroutine.
// Currently this causes issues (probably because we need ACK frames).
initPK, err := c.handleRequestFrame(ctx, done, accept, id, p)
initPK, err := c.handleRequestFrame(ctx, accept, id, p)
if err != nil {
log.WithField("remoteClient", initPK).WithError(err).Infoln("FrameRejected")
if err == ErrRequestCheckFailed {
Expand Down Expand Up @@ -200,7 +205,7 @@ func (c *ClientConn) Close() error {
}
err := c.Conn.Close()
c.mx.Unlock()
//c.wg.Wait()
c.wg.Wait()
return err
}

Expand All @@ -223,7 +228,7 @@ type Client struct {
// NewClient creates a new Client.
func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc client.APIClient) *Client {
return &Client{
log: logging.MustGetLogger("dms_client"),
log: logging.MustGetLogger("dmsg_client"),
pk: pk,
sk: sk,
dc: dc,
Expand Down Expand Up @@ -260,7 +265,7 @@ func (c *Client) setConn(ctx context.Context, l *ClientConn) {
c.mx.Lock()
c.conns[l.remoteSrv] = l
if err := c.updateDiscEntry(ctx); err != nil {
c.log.WithError(err).Warn("failed to update dms_client entry")
c.log.WithError(err).Warn("updateEntry: failed")
}
c.mx.Unlock()
}
Expand All @@ -269,7 +274,7 @@ func (c *Client) delConn(ctx context.Context, pk cipher.PubKey) {
c.mx.Lock()
delete(c.conns, pk)
if err := c.updateDiscEntry(ctx); err != nil {
c.log.WithError(err).Warn("failed to update dms_client entry")
c.log.WithError(err).Warn("updateEntry: failed")
}
c.mx.Unlock()
}
Expand Down Expand Up @@ -312,8 +317,9 @@ func (c *Client) findServerEntries(ctx context.Context) ([]*client.Entry, error)
case <-ctx.Done():
return nil, fmt.Errorf("dms_servers are not available: %s", err)
default:
c.log.WithError(err).Warnf("no dms_servers found: trying again is 1 second...")
time.Sleep(time.Second)
retry := time.Second
c.log.WithError(err).Warnf("no dms_servers found: trying again in %d second...", retry)
time.Sleep(retry)
continue
}
}
Expand Down Expand Up @@ -370,7 +376,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey)
conn := NewClientConn(c.log, nc, c.pk, srvPK)
c.setConn(ctx, conn)
go func() {
if err := conn.Serve(ctx, c.done, c.accept); err != nil {
if err := conn.Serve(ctx, c.accept); err != nil {
conn.log.WithError(err).WithField("dms_server", srvPK).Warn("connected with dms_server closed")
c.delConn(ctx, srvPK)

Expand All @@ -380,7 +386,7 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey)
case <-c.done:
case <-ctx.Done():
case <-t.C:
conn.log.WithField("dms_server", srvPK).Warn("reconnecting to dms_server")
conn.log.WithField("remoteServer", srvPK).Warn("Reconnecting")
_, _ = c.findOrConnectToServer(ctx, srvPK) //nolint:errcheck
}
return
Expand Down
27 changes: 24 additions & 3 deletions pkg/dmsg/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"math"
"net"
"sync"
"time"

"github.com/skycoin/skycoin/src/util/logging"

Expand Down Expand Up @@ -117,7 +118,7 @@ func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error)
if err != nil {
return fmt.Errorf("read failed: %s", err)
}
log = log.WithField("frame", f)
log = log.WithField("received", f)

ft, id, p := f.Disassemble()

Expand Down Expand Up @@ -292,11 +293,11 @@ func (s *Server) Serve() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

if err := s.updateDiscEntry(ctx); err != nil {
if err := s.retryUpdateEntry(ctx, hsTimeout); err != nil {
return fmt.Errorf("updating server's discovery entry failed with: %s", err)
}

s.log.Infof("serving: pk(%s) addr(%s)", s.pk, s.lis.Addr())
s.log.Infof("serving: pk(%s) addr(%s)", s.pk, s.Addr())

for {
rawConn, err := s.lis.Accept()
Expand Down Expand Up @@ -335,3 +336,23 @@ func (s *Server) updateDiscEntry(ctx context.Context) error {

return s.dc.UpdateEntry(ctx, s.sk, entry)
}

func (s *Server) retryUpdateEntry(ctx context.Context, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

for {
if err := s.updateDiscEntry(ctx); err != nil {
select {
case <-ctx.Done():
return ctx.Err()
default:
retry := time.Second
s.log.WithError(err).Warnf("updateEntry failed: trying again in %d second...", retry)
time.Sleep(retry)
continue
}
}
return nil
}
}
19 changes: 12 additions & 7 deletions pkg/dmsg/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func NewTransport(conn net.Conn, log *logging.Logger, local, remote cipher.PubKe
doneCh: make(chan struct{}),
}
if err := tp.ackWaiter.RandSeq(); err != nil {
log.Fatalln("failed to set ack_water seq:", err)
log.Fatalln("failed to set ack_waiter seq:", err)
}
return tp
}
Expand Down Expand Up @@ -107,14 +107,13 @@ func (c *Transport) Handshake(ctx context.Context) error {
return err
}
} else {
c.log.Infof("tp_hs responding...")
f := MakeFrame(AcceptType, c.id, combinePKs(c.remote, c.local))
if err := writeFrame(c.Conn, f); err != nil {
c.log.WithError(err).Error("tp_hs responded with error.")
c.log.WithError(err).Error("HandshakeFailed")
c.close()
return err
}
c.log.Infoln("tp_hs responded:", f)
c.log.WithField("sent", f).Infoln("HandshakeCompleted")
}
return nil
}
Expand Down Expand Up @@ -213,16 +212,22 @@ func (c *Transport) Write(p []byte) (int, error) {
return 0, io.ErrClosedPipe
default:
ctx, cancel := context.WithTimeout(context.Background(), readTimeout)
defer cancel()

err := c.ackWaiter.Wait(ctx, c.doneCh, func(seq ioutil.Uint16Seq) error {
go func() {
select {
case <-ctx.Done():
case <-c.doneCh:
cancel()
}
}()
err := c.ackWaiter.Wait(ctx, func(seq ioutil.Uint16Seq) error {
if err := writeFwdFrame(c.Conn, c.id, seq, p); err != nil {
c.close()
return err
}
return nil
})
if err != nil {
cancel()
return 0, err
}
return len(p), nil
Expand Down
21 changes: 20 additions & 1 deletion pkg/transport/managed_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package transport
import (
"math/big"
"sync"
"time"

"github.com/skycoin/skywire/pkg/cipher"

"github.com/google/uuid"
)
Expand Down Expand Up @@ -81,6 +84,22 @@ func (tr *ManagedTransport) Write(p []byte) (n int, err error) {
return
}

// Edges returns the edges of underlying transport.
func (tr *ManagedTransport) Edges() [2]cipher.PubKey {
tr.mu.RLock()
edges := tr.Transport.Edges()
tr.mu.RUnlock()
return edges
}

// SetDeadline sets the deadline of the underlying transport.
func (tr *ManagedTransport) SetDeadline(t time.Time) error {
tr.mu.RLock()
err := tr.Transport.SetDeadline(t)
tr.mu.RUnlock()
return err
}

// IsClosing determines whether is closing.
func (tr *ManagedTransport) IsClosing() bool {
select {
Expand All @@ -102,6 +121,6 @@ func (tr *ManagedTransport) Close() (err error) {

func (tr *ManagedTransport) updateTransport(newTr Transport) {
tr.mu.Lock()
tr.Transport = newTr
tr.Transport = newTr // TODO: data race.
tr.mu.Unlock()
}
Loading

0 comments on commit f45e7eb

Please sign in to comment.