diff --git a/cmd/apps/chat/chat.go b/cmd/apps/chat/chat.go index 6f1f63ee3..a1dbee6d4 100644 --- a/cmd/apps/chat/chat.go +++ b/cmd/apps/chat/chat.go @@ -131,10 +131,10 @@ func sseHandler(w http.ResponseWriter, req *http.Request) { } clientChan = make(chan string) + go func() { <-req.Context().Done() close(clientChan) - clientChan = nil log.Println("SSE connection were closed.") }() diff --git a/internal/ioutil/ack.go b/internal/ioutil/ack.go deleted file mode 100644 index 05f049f73..000000000 --- a/internal/ioutil/ack.go +++ /dev/null @@ -1,269 +0,0 @@ -package ioutil - -import ( - "bytes" - "errors" - "io" - "math" - "sync" - "time" - - "github.com/skycoin/skywire/pkg/cipher" -) - -// DataPacketType defines types of data packets. -type DataPacketType byte - -const ( - // DataPacketPayload represents Payload data packet. - DataPacketPayload DataPacketType = iota - // DataPacketAck represents Ack data packet. - DataPacketAck -) - -// ErrClosed is the error used for read or write operations on a closed ReadWriter. -var ErrClosed = errors.New("read/write: closed") - -// AckReadWriter is an io.ReadWriter wrapper that implements ack logic -// for writes. Writes are blocked till Ack packets are received, CRC -// check is performed using SHA256. Ack packets are either sent along -// with subsequent writes or flushed each ackInterval. -type AckReadWriter struct { - rw io.ReadWriteCloser - - sndAcks *ackList - rcvAcks *ackList - - readChan chan []byte - errChan chan error - buf *bytes.Buffer - doneChan chan struct{} - ackInterval time.Duration -} - -// NewAckReadWriter constructs a new AckReadWriter. -func NewAckReadWriter(rw io.ReadWriteCloser, ackInterval time.Duration) *AckReadWriter { - arw := &AckReadWriter{ - rw: rw, - sndAcks: newAckList(), - rcvAcks: newAckList(), - doneChan: make(chan struct{}), - readChan: make(chan []byte), - errChan: make(chan error), - ackInterval: ackInterval, - buf: new(bytes.Buffer), - } - go arw.serveLoop() - return arw -} - -func (arw *AckReadWriter) Write(p []byte) (n int, err error) { - errCh := make(chan error) - seq := arw.sndAcks.push(&ack{errCh, cipher.SumSHA256(p)}) - packet := append([]byte{byte(DataPacketPayload), seq}, p...) - - _, _, buf := arw.ackPacket() - buf = append(buf, packet...) - n, err = arw.rw.Write(buf) - if err != nil { - return - } - - if n != len(buf) { - err = io.ErrShortWrite - return - } - - select { - case <-arw.doneChan: - return 0, ErrClosed - case err = <-errCh: - return len(p), err - } -} - -func (arw *AckReadWriter) Read(p []byte) (n int, err error) { - if arw.buf.Len() != 0 { - return arw.buf.Read(p) - } - - select { - case <-arw.doneChan: - return 0, io.EOF - case err := <-arw.errChan: - return 0, err - case data, more := <-arw.readChan: - if !more { - return 0, io.EOF - } - - time.AfterFunc(arw.ackInterval, arw.flush) - - if len(data) > len(p) { - if _, err := arw.buf.Write(data[len(p):]); err != nil { - return 0, io.ErrShortBuffer - } - - return copy(p, data[:len(p)]), nil - } - - return copy(p, data), nil - } -} - -// Close implements io.Closer for AckReadWriter. -func (arw *AckReadWriter) Close() error { - select { - case <-arw.doneChan: - default: - arw.flush() - close(arw.doneChan) - close(arw.readChan) - } - - return arw.rw.Close() -} - -func (arw *AckReadWriter) serveLoop() { - buf := make([]byte, 100*1024) - for { - n, err := arw.rw.Read(buf) - if err != nil { - select { - case <-arw.doneChan: - case arw.errChan <- err: - } - return - } - - data := buf[:n] - for { - if len(data) == 0 || DataPacketType(data[0]) == DataPacketPayload { - break - } - - arw.confirm(data[1], data[2:34]) - data = data[34:] - } - - if len(data) == 0 { - continue - } - - arw.rcvAcks.set(data[1], &ack{nil, cipher.SumSHA256(data[2:])}) - go func() { - select { - case <-arw.doneChan: - case arw.readChan <- data[2:]: - } - }() - } -} - -func (arw *AckReadWriter) ackPacket() ([]byte, []*ack, []byte) { - buf := make([]byte, 0) - acks := make([]*ack, 0) - seqs := make([]byte, 0) - for { - seq, ack := arw.rcvAcks.pull() - if ack == nil { - break - } - - buf = append([]byte{byte(DataPacketAck), seq}, ack.hash[:]...) - acks = append(acks, ack) - seqs = append(seqs, seq) - } - - return seqs, acks, buf -} - -func (arw *AckReadWriter) confirm(seq byte, hash []byte) { - ack := arw.sndAcks.remove(seq) - if ack == nil { - return - } - - rcvHash, err := cipher.SHA256FromBytes(hash) - if err != nil { - ack.errChan <- err - return - } - - if ack.hash != rcvHash { - ack.errChan <- errors.New("invalid CRC") - return - } - - ack.errChan <- nil -} - -func (arw *AckReadWriter) flush() { - seqs, acks, p := arw.ackPacket() - if len(p) == 0 { - return - } - - if _, err := arw.rw.Write(p); err != nil { - for idx, ack := range acks { - arw.rcvAcks.set(seqs[idx], ack) - } - } -} - -type ack struct { - errChan chan error - hash cipher.SHA256 -} - -type ackList struct { - sync.Mutex - - acks []*ack -} - -func newAckList() *ackList { - return &ackList{acks: make([]*ack, math.MaxUint8)} -} - -func (al *ackList) push(a *ack) byte { - al.Lock() - defer al.Unlock() - - for i := byte(0); i < math.MaxUint8; i++ { - if al.acks[i] == nil { - al.acks[i] = a - return i - } - } - - panic("too many packets in flight") -} - -func (al *ackList) pull() (byte, *ack) { - al.Lock() - defer al.Unlock() - - for seq, ack := range al.acks { - if ack != nil { - al.acks[seq] = nil - return byte(seq), ack - } - } - - return 0, nil -} - -func (al *ackList) set(seq byte, a *ack) { - al.Lock() - al.acks[seq] = a - al.Unlock() -} - -func (al *ackList) remove(seq byte) *ack { - al.Lock() - a := al.acks[seq] - al.acks[seq] = nil - al.Unlock() - return a -} diff --git a/internal/ioutil/ack_test.go b/internal/ioutil/ack_test.go deleted file mode 100644 index 3c2a21205..000000000 --- a/internal/ioutil/ack_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package ioutil - -import ( - "io" - "net" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/skycoin/skywire/pkg/cipher" -) - -func TestAckReadWriter(t *testing.T) { - in, out := net.Pipe() - rw1 := NewAckReadWriter(in, 100*time.Millisecond) - rw2 := NewAckReadWriter(out, 100*time.Millisecond) - - errCh := make(chan error) - go func() { - _, err := rw1.Write([]byte("foo")) - errCh <- err - }() - - buf := make([]byte, 3) - n, err := rw2.Read(buf) - require.NoError(t, err) - assert.Equal(t, 3, n) - assert.Equal(t, []byte("foo"), buf) - - errCh2 := make(chan error) - go func() { - _, err = rw2.Write([]byte("bar")) - errCh2 <- err - }() - - require.NoError(t, <-errCh) - - buf = make([]byte, 3) - n, err = rw1.Read(buf) - require.NoError(t, err) - assert.Equal(t, 3, n) - assert.Equal(t, []byte("bar"), buf) - - require.NoError(t, rw1.Close()) - require.NoError(t, <-errCh2) - require.NoError(t, rw2.Close()) -} - -func TestAckReadWriterCRCFailure(t *testing.T) { - in, out := net.Pipe() - rw1 := NewAckReadWriter(in, 100*time.Millisecond) - rw2 := NewAckReadWriter(out, 100*time.Millisecond) - - errCh := make(chan error) - go func() { - _, err := rw1.Write([]byte("foo")) - errCh <- err - }() - - buf := make([]byte, 3) - n, err := rw2.Read(buf) - require.NoError(t, err) - assert.Equal(t, 3, n) - assert.Equal(t, []byte("foo"), buf) - - rw2.rcvAcks.set(0, &ack{nil, cipher.SumSHA256([]byte("bar"))}) - - go rw2.Write([]byte("bar")) // nolint: errcheck - - err = <-errCh - require.Error(t, err) - assert.Equal(t, "invalid CRC", err.Error()) - - buf = make([]byte, 3) - n, err = rw1.Read(buf) - require.NoError(t, err) - assert.Equal(t, 3, n) - assert.Equal(t, []byte("bar"), buf) - - require.NoError(t, rw1.Close()) - require.NoError(t, rw2.Close()) -} - -func TestAckReadWriterFlushOnClose(t *testing.T) { - in, out := net.Pipe() - rw1 := NewAckReadWriter(in, 100*time.Millisecond) - rw2 := NewAckReadWriter(out, 100*time.Millisecond) - - errCh := make(chan error) - go func() { - _, err := rw1.Write([]byte("foo")) - errCh <- err - }() - - buf := make([]byte, 3) - n, err := rw2.Read(buf) - require.NoError(t, err) - assert.Equal(t, 3, n) - assert.Equal(t, []byte("foo"), buf) - - require.NoError(t, rw2.Close()) - require.NoError(t, <-errCh) - - require.NoError(t, rw1.Close()) -} - -func TestAckReadWriterPartialRead(t *testing.T) { - in, out := net.Pipe() - rw1 := NewAckReadWriter(in, 100*time.Millisecond) - rw2 := NewAckReadWriter(out, 100*time.Millisecond) - - errCh := make(chan error) - go func() { - _, err := rw1.Write([]byte("foo")) - errCh <- err - }() - - buf := make([]byte, 2) - n, err := rw2.Read(buf) - require.NoError(t, err) - assert.Equal(t, 2, n) - assert.Equal(t, []byte("fo"), buf) - - n, err = rw2.Read(buf) - require.NoError(t, err) - assert.Equal(t, 1, n) - assert.Equal(t, []byte("o"), buf[:n]) - - require.NoError(t, rw2.Close()) - require.NoError(t, rw1.Close()) -} - -func TestAckReadWriterReadError(t *testing.T) { - in, out := net.Pipe() - rw := NewAckReadWriter(in, 100*time.Millisecond) - - errCh := make(chan error) - go func() { - _, err := rw.Read([]byte{}) - errCh <- err - }() - - require.NoError(t, out.Close()) - - err := <-errCh - require.Error(t, err) - assert.Equal(t, io.EOF, err) - - require.NoError(t, rw.Close()) -} diff --git a/pkg/app/app.go b/pkg/app/app.go index 5b0d70d98..a67441378 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -13,9 +13,6 @@ import ( "os/exec" "path/filepath" "sync" - "time" - - "github.com/skycoin/skywire/internal/ioutil" ) const ( @@ -255,13 +252,13 @@ func (app *App) confirmLoop(data []byte) error { type appConn struct { net.Conn - rw *ioutil.AckReadWriter + rw io.ReadWriteCloser laddr *Addr raddr *Addr } func newAppConn(conn net.Conn, laddr, raddr *Addr) *appConn { - return &appConn{conn, ioutil.NewAckReadWriter(conn, 100*time.Millisecond), laddr, raddr} + return &appConn{conn, conn, laddr, raddr} } func (conn *appConn) LocalAddr() net.Addr { diff --git a/pkg/messaging/channel.go b/pkg/messaging/channel.go index a14a67583..0cc95aa7e 100644 --- a/pkg/messaging/channel.go +++ b/pkg/messaging/channel.go @@ -7,7 +7,6 @@ import ( "io" "time" - "github.com/skycoin/skywire/internal/ioutil" "github.com/skycoin/skywire/internal/noise" "github.com/skycoin/skywire/pkg/cipher" "github.com/skycoin/skywire/pkg/transport" @@ -197,24 +196,3 @@ func (c *channel) readEncrypted(ctx context.Context, p []byte) (n int, err error return copy(p, data), nil } - -type ackedChannel struct { - *channel - rw *ioutil.AckReadWriter -} - -func newAckedChannel(c *channel) *ackedChannel { - return &ackedChannel{c, ioutil.NewAckReadWriter(c, 100*time.Millisecond)} -} - -func (c *ackedChannel) Write(p []byte) (n int, err error) { - return c.rw.Write(p) -} - -func (c *ackedChannel) Read(p []byte) (n int, err error) { - return c.rw.Read(p) -} - -func (c *ackedChannel) Close() error { - return c.rw.Close() -} diff --git a/pkg/messaging/client.go b/pkg/messaging/client.go index d82900217..51473829d 100644 --- a/pkg/messaging/client.go +++ b/pkg/messaging/client.go @@ -133,7 +133,7 @@ func (c *Client) Accept(ctx context.Context) (transport.Transport, error) { if !more { return nil, ErrClientClosed } - return newAckedChannel(ch), nil + return ch, nil case <-ctx.Done(): return nil, ctx.Err() } @@ -180,7 +180,7 @@ func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (transport.Tran } c.Logger.Infof("Opened new channel local ID %d, remote ID %d with %s", localID, channel.ID, remote) - return newAckedChannel(channel), nil + return channel, nil } // Local returns the local public key. diff --git a/pkg/messaging/pool_test.go b/pkg/messaging/pool_test.go index 00c5cc835..26292fdd3 100644 --- a/pkg/messaging/pool_test.go +++ b/pkg/messaging/pool_test.go @@ -1,3 +1,5 @@ +// +build !no_ci + package messaging import (