Skip to content

Commit

Permalink
Some fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
林志宇 committed Jun 1, 2019
1 parent e8d87f0 commit e6c5dab
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 41 deletions.
85 changes: 45 additions & 40 deletions pkg/dms/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,36 +96,45 @@ func (c *ClientConn) getTp(id uint16) (*Transport, bool) {
return tp, ok
}

func (c *ClientConn) handleRequestFrame(ctx context.Context, id uint16, p []byte) (*Transport, error) {
func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) {
// remote-initiated tps should:
// - have a payload structured as 'init_pk:resp_pk'.
// - resp_pk should be of local client.
// - use an odd tp_id with the intermediary dms_server.
initPK, respPK, ok := splitPKs(p)
if !ok || respPK != c.local || isInitiatorID(id) {
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return nil, err
c.Close()
return initPK, err
}
return nil, ErrRequestCheckFailed
return initPK, ErrRequestCheckFailed
}

tp := NewTransport(c.Conn, c.local, initPK, id)
if err := tp.Handshake(ctx); err != nil {
// return err here as response handshake is send via ClientConn and that shouldn't fail.
return nil, err
c.Close()
return initPK, err
}
c.setTp(tp)

return tp, nil
select {
case accept <- tp:
case <-ctx.Done():
return initPK, ctx.Err()
}
return initPK, nil
}

// Serve handles incoming frames.
// Remote-initiated tps that are successfully created are pushing into 'accept' and exposed via 'Client.Accept()'.
func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) error {

c.wg.Add(1)
defer c.wg.Done()

log := c.log.WithField("remoteSrv", c.remoteSrv)
log := c.log.WithField("srv", c.remoteSrv)
defer log.Warnf("serveLoopClosed")

for {
f, err := readFrame(c.Conn)
Expand All @@ -136,45 +145,41 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) error
tp, ok := c.getTp(id)
log.Infof("readFrame: frameType(%v) channelID(%v) payloadLen(%v)", ft, id, f.PayLen())

// if tp does not exist, frame should be 'REQUEST'.
// otherwise, handle any unexpected frames accordingly.
if !ok {
c.delTp(id) // rm tp in case closed tp is not fully removed.
switch ft {
case RequestType:
tp, err := c.handleRequestFrame(ctx, id, p)
if err != nil {
log.WithError(err).Infof("transportRejected: remoteClient(%v) channelID(%v)", tp.remoteClient, tp.id)
if err == ErrRequestCheckFailed {
continue
}
return err
}
log.Infof("transportAccepted: remoteClient(%v) channelID(%v)", tp.remoteClient, tp.id)
select {
case accept <- tp:
case <-ctx.Done():
return ctx.Err()
}
case CloseType:
log.Infof("closeFrameIgnored: transport untracked locally.")
default:
log.Infof("unexpectedFrameReceived: transport untracked locally.")
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return err
}
if ok {
// If tp of tp_id exists, attempt to forward frame to tp.
// delete tp on any failure.
if !tp.AwaitRead(f) {
log.Infof("failed to injest to local_tp: id(%d) dstClient(%s)", id, tp.remoteClient)
c.delTp(id)
}
log.Infof("successfully injested to local_tp: id(%d) dstClient(%s)", id, tp.remoteClient)
continue
}

// If tp of tp_id exists, attempt to forward frame to tp.
// delete tp on any failure.

if !tp.AwaitRead(f) {
log.Infof("failed to injest to local_tp: id(%d) dstClient(%s)", id, tp.remoteClient)
c.delTp(id)
// if tp does not exist, frame should be 'REQUEST'.
// otherwise, handle any unexpected frames accordingly.
c.delTp(id) // rm tp in case closed tp is not fully removed.
switch ft {
case RequestType:
// TODO(evanlinjin): Allow for REQUEST frame handling to be done in goroutine.
// Currently this causes issues (probably because we need ACK frames).
initPK, err := c.handleRequestFrame(ctx, accept, id, p)
if err != nil {
log.WithError(err).Infof("transportRejected: channelID(%v) client(%s)", id, initPK)
if err == ErrRequestCheckFailed {
continue
}
return err
}
log.Infof("transportAccepted: channelID(%v) client(%s)", id, initPK)
case CloseType:
log.Infof("closeFrameIgnored: transport untracked locally.")
default:
log.Infof("unexpectedFrameReceived: transport untracked locally.")
if err := writeCloseFrame(c.Conn, id, 0); err != nil {
return err
}
}
log.Infof("successfully injested to local_tp: id(%d) dstClient(%s)", id, tp.remoteClient)
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/dms/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (c *Transport) close() (closed bool) {
// Handshake performs a tp handshake (before tp is considered valid).
func (c *Transport) Handshake(ctx context.Context) error {
// if channel ID is even, client is initiator.
if init := isInitiatorID(c.id); init {
if isInitiatorID(c.id) {

pks := combinePKs(c.local, c.remoteClient)
f := MakeFrame(RequestType, c.id, pks)
Expand Down

0 comments on commit e6c5dab

Please sign in to comment.