Skip to content

Commit

Permalink
comments and various tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
林志宇 committed May 31, 2019
1 parent 5a7d4bd commit 183decd
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 87 deletions.
158 changes: 98 additions & 60 deletions pkg/dms/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,31 @@ import (
var (
// ErrNoSrv indicate that remote client does not have DelegatedServers in entry.
ErrNoSrv = errors.New("remote has no DelegatedServers")
// ErrRejected indicates that ChannelOpen frame was rejected by remote server.
ErrRejected = errors.New("rejected")
// ErrChannelClosed indicates that underlying channel is being closed and writes are prohibited.
ErrChannelClosed = errors.New("channel closed")
// ErrDeadlineExceeded indicates that read/write operation failed due to timeout.
ErrDeadlineExceeded = errors.New("deadline exceeded")
// ErrClientClosed indicates that client is closed and not accepting new connections.
ErrClientClosed = errors.New("client closed")
)

// Conn represents a connection between a dms.Client and dms.Server from a client's perspective.
type Conn struct {
log *logging.Logger
net.Conn // conn to dms server
local cipher.PubKey // local client's pk
remoteSrv cipher.PubKey // dms server's public key
nextID uint16 // next unused channel ID
tps [math.MaxUint16]*Transport // channels to dms clients
mx sync.RWMutex
wg sync.WaitGroup
log *logging.Logger

net.Conn // conn to dms server
local cipher.PubKey // local client's pk
remoteSrv cipher.PubKey // dms server's public key

// nextID keeps track of unused tp_ids to assign a future locally-initiated tp.
// locally-initiated tps use an even tp_id between local and intermediary dms_server.
nextID uint16

// map of transports to remote dms_clients (key: tp_id, val: transport).
tps [math.MaxUint16]*Transport
mx sync.RWMutex // to protect tps.

// awaits .Serve() to end before considering properly closed.
wg sync.WaitGroup
}

// NewConn creates a new Conn.
func NewConn(log *logging.Logger, conn net.Conn, local, remote cipher.PubKey) *Conn {
return &Conn{log: log, Conn: conn, local: local, remoteSrv: remote, nextID: 0}
}
Expand All @@ -53,12 +56,14 @@ func (c *Conn) delTp(id uint16) {
c.mx.Unlock()
}

func (c *Conn) setTp(ch *Transport) {
func (c *Conn) setTp(tp *Transport) {
c.mx.Lock()
c.tps[ch.id] = ch
c.tps[tp.id] = tp
c.mx.Unlock()
}

// keeps record of a locally-initiated tp to 'clientPK'.
// assigns an even tp_id and keeps track of it in tps map.
func (c *Conn) addTp(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) {
c.mx.Lock()
defer c.mx.Unlock()
Expand Down Expand Up @@ -91,23 +96,30 @@ func (c *Conn) getTp(id uint16) (*Transport, bool) {
return tp, ok
}

// local: local client pk (also responding client).
// remote: remote client pk (also initiating client).
func checkRequest(local cipher.PubKey, id uint16, p []byte) (remote cipher.PubKey, ok bool) {
// server-initiated channels should have odd channel ID
if isEven(id) {
return cipher.PubKey{}, false
}

// check expected request payload
func (c *Conn) handleRequestFrame(ctx context.Context, id uint16, p []byte) (*Transport, error) {
// remote-initiated tps should:
// - have a payload structured as 'init_pk:resp_pk'.
// - resp_pk should be of local client.
// - use an odd tp_id with the intermediary dms_server.
initPK, respPK, ok := splitPKs(p)
if !ok || respPK != local {
return cipher.PubKey{}, false
if !ok || respPK != c.local || isEven(id) {
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return nil, err
}
return nil, ErrRequestCheckFailed
}

return initPK, true
tp := NewTransport(c.Conn, c.local, initPK, id)
if err := tp.Handshake(ctx); err != nil {
// return err here as response handshake is send via Conn and that shouldn't fail.
return nil, err
}
c.setTp(tp)
return tp, nil
}

// Serve handles incoming frames.
// Remote-initiated tps that are successfully created are pushing into 'accept' and exposed via 'Client.Accept()'.
func (c *Conn) Serve(ctx context.Context, accept chan<- *Transport) error {
c.wg.Add(1)
defer c.wg.Done()
Expand All @@ -123,36 +135,46 @@ func (c *Conn) Serve(ctx context.Context, accept chan<- *Transport) error {
tp, ok := c.getTp(id)
log.Infof("readFrame: frameType(%v) channelID(%v) payloadLen(%v)", ft, id, f.PayLen())

// if tp does not exist, frame should be 'REQUEST'.
// otherwise, handle any unexpected frames accordingly.
if !ok {
c.delTp(id)
c.delTp(id) // rm tp in case closed tp is not fully removed.
switch ft {
case RequestType:
remote, ok := checkRequest(c.local, id, p)
if !ok {
_ = writeFrame(c.Conn, MakeFrame(CloseType, id, []byte{0}))
} else {
tp = NewTransport(c.Conn, c.local, remote, id)
c.setTp(tp)
if err := tp.Handshake(ctx); err != nil {
return err
}
select {
case accept <- tp:
log.Infof("channelAccepted: remoteClient(%v) channelID(%v)", tp.remoteClient, tp.id)
case <-ctx.Done():
return ctx.Err()
tp, err := c.handleRequestFrame(ctx, id, p)
if err != nil {
log.WithError(err).Infof("transportRejected: remoteClient(%v) channelID(%v)", tp.remoteClient, tp.id)
if err == ErrRequestCheckFailed {
continue
}
return err
}
log.Infof("transportAccepted: remoteClient(%v) channelID(%v)", tp.remoteClient, tp.id)
select {
case accept <- tp:
case <-ctx.Done():
return ctx.Err()
}
case CloseType:
log.Infof("closeFrameIgnored: transport untracked locally.")
default:
_ = writeFrame(c.Conn, MakeFrame(CloseType, id, []byte{0}))
log.Infof("unexpectedFrameReceived: transport untracked locally.")
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return err
}
}
} else if !tp.AwaitRead(f) {
continue
}

// If tp of tp_id exists, attempt to forward frame to tp.
// delete tp on any failure.
if !tp.AwaitRead(f) {
c.delTp(id)
}
}
}

// DialTransport dials a transport to remote dms_client.
func (c *Conn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Transport, error) {
tp, err := c.addTp(ctx, clientPK)
if err != nil {
Expand All @@ -161,12 +183,13 @@ func (c *Conn) DialTransport(ctx context.Context, clientPK cipher.PubKey) (*Tran
return tp, tp.Handshake(ctx)
}

// Close closes the connection to dms_server.
func (c *Conn) Close() error {
c.log.Infof("closingLink: remoteSrv(%v)", c.remoteSrv)
c.mx.Lock()
for _, ch := range c.tps {
if ch != nil {
_ = ch.Close()
for _, tp := range c.tps {
if tp != nil {
_ = tp.Close()
}
}
err := c.Conn.Close()
Expand All @@ -175,6 +198,7 @@ func (c *Conn) Close() error {
return err
}

// Client implements transport.Factory
type Client struct {
log *logging.Logger

Expand All @@ -189,6 +213,7 @@ type Client struct {
once sync.Once
}

// NewClient creates a new Client.
func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc client.APIClient) *Client {
return &Client{
log: logging.MustGetLogger("dms_client"),
Expand All @@ -200,28 +225,31 @@ func NewClient(pk cipher.PubKey, sk cipher.SecKey, dc client.APIClient) *Client
}
}

// SetLogger sets the dms_client's logger.
func (c *Client) SetLogger(log *logging.Logger) {
c.log = log
}

func (c *Client) setConn(l *Conn) {
c.mx.Lock()
c.conns[l.remoteSrv] = l
c.mx.Unlock()
}
// TODO: re-connect logic.
//func (c *Client) setConn(l *Conn) {
// c.mx.Lock()
// c.conns[l.remoteSrv] = l
// c.mx.Unlock()
//}

func (c *Client) delConn(pk cipher.PubKey) {
c.mx.Lock()
delete(c.conns, pk)
c.mx.Unlock()
}

func (c *Client) getConn(pk cipher.PubKey) (*Conn, bool) {
c.mx.RLock()
l, ok := c.conns[pk]
c.mx.RUnlock()
return l, ok
}
// TODO: re-connect logic.
//func (c *Client) getConn(pk cipher.PubKey) (*Conn, bool) {
// c.mx.RLock()
// l, ok := c.conns[pk]
// c.mx.RUnlock()
// return l, ok
//}

func (c *Client) newConn(ctx context.Context, srvPK cipher.PubKey, addr string) (*Conn, error) {
conn, err := net.Dial("tcp", addr)
Expand All @@ -245,12 +273,16 @@ func (c *Client) newConn(ctx context.Context, srvPK cipher.PubKey, addr string)
go func() {
if err := l.Serve(ctx, c.accept); err != nil {
l.log.WithError(err).WithField("srv_pk", l.remoteSrv).Warn("link with server closed")
if err := c.updateDiscEntry(ctx); err != nil {
c.log.WithError(err).Error("failed to update entry after server close.")
}
c.delConn(l.remoteSrv)
}
}()
return l, nil
}

// InitiateServers initiates connections with dms_servers.
func (c *Client) InitiateServers(ctx context.Context, n int) error {
if n == 0 {
return nil
Expand All @@ -262,7 +294,8 @@ func (c *Client) InitiateServers(ctx context.Context, n int) error {
select {
case <-ctx.Done():
return fmt.Errorf("messaging servers are not available: %s", err)
case <-time.Tick(time.Second):
default:
time.Sleep(time.Second)
continue
}
}
Expand Down Expand Up @@ -332,6 +365,7 @@ func (c *Client) updateDiscEntry(ctx context.Context) error {
return c.dc.UpdateEntry(ctx, c.sk, entry)
}

// Accept accepts remotely-initiated tps.
func (c *Client) Accept(ctx context.Context) (transport.Transport, error) {
select {
case tp, ok := <-c.accept:
Expand All @@ -344,6 +378,7 @@ func (c *Client) Accept(ctx context.Context) (transport.Transport, error) {
}
}

// Dial dials a transport to remote dms_client.
func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (transport.Transport, error) {
c.mx.Lock()
defer c.mx.Unlock()
Expand All @@ -362,14 +397,17 @@ func (c *Client) Dial(ctx context.Context, remote cipher.PubKey) (transport.Tran
return conn.DialTransport(ctx, remote)
}

// Local returns the local dms_client's public key.
func (c *Client) Local() cipher.PubKey {
return c.pk
}

// Type returns the transport type.
func (c *Client) Type() string {
return Type
}

// Close closes the dms_client and associated connections.
// TODO(evaninjin): proper error handling.
func (c *Client) Close() error {
c.mx.Lock()
Expand Down
25 changes: 21 additions & 4 deletions pkg/dms/frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
)

const (
// Type returns the transport type string.
Type = "dms"

hsTimeout = time.Second * 10
Expand All @@ -19,6 +20,7 @@ const (

func isEven(chID uint16) bool { return chID%2 == 0 }

// FrameType represents the frame type.
type FrameType byte

func (ft FrameType) String() string {
Expand All @@ -34,15 +36,18 @@ func (ft FrameType) String() string {
return names[ft]
}

// Frame types.
const (
RequestType = FrameType(1)
AcceptType = FrameType(2)
CloseType = FrameType(3)
SendType = FrameType(10)
)

// Frame is the dms data unit.
type Frame []byte

// MakeFrame creates a new Frame.
func MakeFrame(ft FrameType, chID uint16, pay []byte) Frame {
f := make(Frame, headerLen+len(pay))
f[0] = byte(ft)
Expand All @@ -52,13 +57,21 @@ func MakeFrame(ft FrameType, chID uint16, pay []byte) Frame {
return f
}

// Type returns the frame's type.
func (f Frame) Type() FrameType { return FrameType(f[0]) }
func (f Frame) ChID() uint16 { return binary.BigEndian.Uint16(f[1:3]) }
func (f Frame) PayLen() int { return int(binary.BigEndian.Uint16(f[3:5])) }
func (f Frame) Pay() []byte { return f[headerLen:] }

// TpID returns the frame's tp_id.
func (f Frame) TpID() uint16 { return binary.BigEndian.Uint16(f[1:3]) }

// PayLen returns the expected payload len.
func (f Frame) PayLen() int { return int(binary.BigEndian.Uint16(f[3:5])) }

// Pay returns the payload.
func (f Frame) Pay() []byte { return f[headerLen:] }

// Disassemble splits the frame into fields.
func (f Frame) Disassemble() (ft FrameType, id uint16, p []byte) {
return f.Type(), f.ChID(), f.Pay()
return f.Type(), f.TpID(), f.Pay()
}

func readFrame(r io.Reader) (Frame, error) {
Expand All @@ -76,6 +89,10 @@ func writeFrame(w io.Writer, f Frame) error {
return err
}

func writeCloseFrame(w io.Writer, id uint16, reason byte) error {
return writeFrame(w, MakeFrame(CloseType, id, []byte{reason}))
}

func combinePKs(initPK, respPK cipher.PubKey) []byte {
return append(initPK[:], respPK[:]...)
}
Expand Down
Loading

0 comments on commit 183decd

Please sign in to comment.