diff --git a/pkg/dms/client.go b/pkg/dms/client.go index d3ebd7d16e..4fbf79a81b 100644 --- a/pkg/dms/client.go +++ b/pkg/dms/client.go @@ -96,7 +96,7 @@ func (c *ClientConn) getTp(id uint16) (*Transport, bool) { return tp, ok } -func (c *ClientConn) handleRequestFrame(ctx context.Context, id uint16, p []byte) (*Transport, error) { +func (c *ClientConn) handleRequestFrame(ctx context.Context, accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) { // remote-initiated tps should: // - have a payload structured as 'init_pk:resp_pk'. // - resp_pk should be of local client. @@ -104,28 +104,37 @@ func (c *ClientConn) handleRequestFrame(ctx context.Context, id uint16, p []byte initPK, respPK, ok := splitPKs(p) if !ok || respPK != c.local || isInitiatorID(id) { if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return nil, err + c.Close() + return initPK, err } - return nil, ErrRequestCheckFailed + return initPK, ErrRequestCheckFailed } tp := NewTransport(c.Conn, c.local, initPK, id) if err := tp.Handshake(ctx); err != nil { // return err here as response handshake is send via ClientConn and that shouldn't fail. - return nil, err + c.Close() + return initPK, err } c.setTp(tp) - return tp, nil + select { + case accept <- tp: + case <-ctx.Done(): + return initPK, ctx.Err() + } + return initPK, nil } // Serve handles incoming frames. // Remote-initiated tps that are successfully created are pushing into 'accept' and exposed via 'Client.Accept()'. func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) error { + c.wg.Add(1) defer c.wg.Done() - log := c.log.WithField("remoteSrv", c.remoteSrv) + log := c.log.WithField("srv", c.remoteSrv) + defer log.Warnf("serveLoopClosed") for { f, err := readFrame(c.Conn) @@ -136,45 +145,41 @@ func (c *ClientConn) Serve(ctx context.Context, accept chan<- *Transport) error tp, ok := c.getTp(id) log.Infof("readFrame: frameType(%v) channelID(%v) payloadLen(%v)", ft, id, f.PayLen()) - // if tp does not exist, frame should be 'REQUEST'. - // otherwise, handle any unexpected frames accordingly. - if !ok { - c.delTp(id) // rm tp in case closed tp is not fully removed. - switch ft { - case RequestType: - tp, err := c.handleRequestFrame(ctx, id, p) - if err != nil { - log.WithError(err).Infof("transportRejected: remoteClient(%v) channelID(%v)", tp.remoteClient, tp.id) - if err == ErrRequestCheckFailed { - continue - } - return err - } - log.Infof("transportAccepted: remoteClient(%v) channelID(%v)", tp.remoteClient, tp.id) - select { - case accept <- tp: - case <-ctx.Done(): - return ctx.Err() - } - case CloseType: - log.Infof("closeFrameIgnored: transport untracked locally.") - default: - log.Infof("unexpectedFrameReceived: transport untracked locally.") - if err := writeCloseFrame(c.Conn, id, 0); err != nil { - return err - } + if ok { + // If tp of tp_id exists, attempt to forward frame to tp. + // delete tp on any failure. + if !tp.AwaitRead(f) { + log.Infof("failed to injest to local_tp: id(%d) dstClient(%s)", id, tp.remoteClient) + c.delTp(id) } + log.Infof("successfully injested to local_tp: id(%d) dstClient(%s)", id, tp.remoteClient) continue } - // If tp of tp_id exists, attempt to forward frame to tp. - // delete tp on any failure. - - if !tp.AwaitRead(f) { - log.Infof("failed to injest to local_tp: id(%d) dstClient(%s)", id, tp.remoteClient) - c.delTp(id) + // if tp does not exist, frame should be 'REQUEST'. + // otherwise, handle any unexpected frames accordingly. + c.delTp(id) // rm tp in case closed tp is not fully removed. + switch ft { + case RequestType: + // TODO(evanlinjin): Allow for REQUEST frame handling to be done in goroutine. + // Currently this causes issues (probably because we need ACK frames). + initPK, err := c.handleRequestFrame(ctx, accept, id, p) + if err != nil { + log.WithError(err).Infof("transportRejected: channelID(%v) client(%s)", id, initPK) + if err == ErrRequestCheckFailed { + continue + } + return err + } + log.Infof("transportAccepted: channelID(%v) client(%s)", id, initPK) + case CloseType: + log.Infof("closeFrameIgnored: transport untracked locally.") + default: + log.Infof("unexpectedFrameReceived: transport untracked locally.") + if err := writeCloseFrame(c.Conn, id, 0); err != nil { + return err + } } - log.Infof("successfully injested to local_tp: id(%d) dstClient(%s)", id, tp.remoteClient) } } diff --git a/pkg/dms/transport.go b/pkg/dms/transport.go index e36e7fcf84..b8cb514313 100644 --- a/pkg/dms/transport.go +++ b/pkg/dms/transport.go @@ -75,7 +75,7 @@ func (c *Transport) close() (closed bool) { // Handshake performs a tp handshake (before tp is considered valid). func (c *Transport) Handshake(ctx context.Context) error { // if channel ID is even, client is initiator. - if init := isInitiatorID(c.id); init { + if isInitiatorID(c.id) { pks := combinePKs(c.local, c.remoteClient) f := MakeFrame(RequestType, c.id, pks)