diff --git a/pkg/dmsg/client.go b/pkg/dmsg/client.go index 1dc4ab210..016994659 100644 --- a/pkg/dmsg/client.go +++ b/pkg/dmsg/client.go @@ -125,6 +125,20 @@ func (c *ClientConn) setNextInitID(nextInitID uint16) { c.mx.Unlock() } +func (c *ClientConn) readOK() error { + fr, err := readFrame(c.Conn) + if err != nil { + return errors.New("failed to get OK from server") + } + + ft, _, _ := fr.Disassemble() + if ft != OkType { + return fmt.Errorf("wrong frame from server: %v", ft) + } + + return nil +} + func (c *ClientConn) handleRequestFrame(accept chan<- *Transport, id uint16, p []byte) (cipher.PubKey, error) { // remotely-initiated tps should: // - have a payload structured as 'init_pk:resp_pk'. @@ -429,7 +443,12 @@ func (c *Client) findOrConnectToServer(ctx context.Context, srvPK cipher.PubKey) } conn := NewClientConn(c.log, nc, c.pk, srvPK) + if err := conn.waitOKFrame(); err != nil { + return nil, err + } + c.setConn(ctx, conn) + go func() { err := conn.Serve(ctx, c.accept) conn.log.WithError(err).WithField("remoteServer", srvPK).Warn("connected with server closed") diff --git a/pkg/dmsg/frame.go b/pkg/dmsg/frame.go index d3c2f31ac..43a1c53b9 100644 --- a/pkg/dmsg/frame.go +++ b/pkg/dmsg/frame.go @@ -52,6 +52,7 @@ func (ft FrameType) String() string { CloseType: "CLOSE", FwdType: "FWD", AckType: "ACK", + OkType: "OK", } if int(ft) >= len(names) { return fmt.Sprintf("UNKNOWN:%d", ft) @@ -61,6 +62,7 @@ func (ft FrameType) String() string { // Frame types. const ( + OkType = FrameType(0x0) RequestType = FrameType(0x1) AcceptType = FrameType(0x2) CloseType = FrameType(0x3) diff --git a/pkg/dmsg/server.go b/pkg/dmsg/server.go index 5c67eb2cd..8b1e66b44 100644 --- a/pkg/dmsg/server.go +++ b/pkg/dmsg/server.go @@ -114,6 +114,11 @@ func (c *ServerConn) Serve(ctx context.Context, getConn getConnFunc) (err error) }() log.WithField("connCount", incrementServeCount()).Infoln("ServingConn") + err = c.sendOK() + if err != nil { + return fmt.Errorf("sending OK failed: %s", err) + } + for { f, err := readFrame(c.Conn) if err != nil { @@ -170,6 +175,13 @@ func (c *ServerConn) delChan(id uint16, why byte) error { return nil } +func (c *ServerConn) writeOK() error { + if err := writeFrame(c.Conn, MakeFrame(OkType, 0, nil)); err != nil { + return err + } + return nil +} + func (c *ServerConn) forwardFrame(ft FrameType, id uint16, p []byte) (*NextConn, byte, bool) { //nolint:unparam next, ok := c.getNext(id) if !ok {