diff --git a/internal/vpn/client.go b/internal/vpn/client.go index 675bbac269..1669e5d636 100644 --- a/internal/vpn/client.go +++ b/internal/vpn/client.go @@ -638,14 +638,16 @@ func (c *Client) shakeHands(conn net.Conn) (TUNIP, TUNGateway net.IP, err error) Passcode: c.cfg.Passcode, } + const handshakeTimeout = 5 * time.Second + fmt.Printf("Sending client hello: %v\n", cHello) - if err := WriteJSON(conn, &cHello); err != nil { + if err := WriteJSONWithTimeout(conn, &cHello, handshakeTimeout); err != nil { return nil, nil, fmt.Errorf("error sending client hello: %w", err) } var sHello ServerHello - if err := ReadJSON(conn, &sHello); err != nil { + if err := ReadJSONWithTimeout(conn, &sHello, handshakeTimeout); err != nil { return nil, nil, fmt.Errorf("error reading server hello: %w", err) } diff --git a/internal/vpn/net.go b/internal/vpn/net.go index 1a992ff7a9..8dd9b89c36 100644 --- a/internal/vpn/net.go +++ b/internal/vpn/net.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "time" "github.com/skycoin/dmsg/cipher" "github.com/skycoin/dmsg/noise" @@ -12,6 +13,23 @@ import ( "github.com/skycoin/skywire/pkg/app/appnet" ) +// WriteJSONWithTimeout marshals `data` and sends it over the `conn` with the specified write `timeout`. +func WriteJSONWithTimeout(conn net.Conn, data interface{}, timeout time.Duration) error { + if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil { + return fmt.Errorf("failed to set write deadline: %w", err) + } + + if err := WriteJSON(conn, data); err != nil { + return err + } + + if err := conn.SetWriteDeadline(time.Time{}); err != nil { + return fmt.Errorf("failed to remove write deadline: %w", err) + } + + return nil +} + // WriteJSON marshals `data` and sends it over the `conn`. func WriteJSON(conn net.Conn, data interface{}) error { dataBytes, err := json.Marshal(data) @@ -31,6 +49,24 @@ func WriteJSON(conn net.Conn, data interface{}) error { return nil } +// ReadJSONWithTimeout reads portion of data from the `conn` and unmarshals it into `data` with the +// specified read `timeout`. +func ReadJSONWithTimeout(conn net.Conn, data interface{}, timeout time.Duration) error { + if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { + return fmt.Errorf("failed to set read deadline: %w", err) + } + + if err := ReadJSON(conn, data); err != nil { + return err + } + + if err := conn.SetReadDeadline(time.Time{}); err != nil { + return fmt.Errorf("failed to remove read deadline: %w", err) + } + + return nil +} + // ReadJSON reads portion of data from the `conn` and unmarshals it into `data`. func ReadJSON(conn net.Conn, data interface{}) error { const bufSize = 1024