Skip to content

Commit

Permalink
Merge pull request #349 from nkryuchkov/bug/retry-with-correct-nonce-345
Browse files Browse the repository at this point in the history
BUG: Ensure modules using internal/httpauth retries with correct nonce after 401.
  • Loading branch information
志宇 authored May 16, 2019
2 parents 1c1c6ca + 051119c commit 627e4c3
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 29 deletions.
81 changes: 68 additions & 13 deletions internal/httpauth/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ import (
"github.com/skycoin/skywire/pkg/cipher"
)

const (
invalidNonceErrorMessage = "SW-Nonce does not match"
)

// NextNonceResponse represents a ServeHTTP response for json encoding
type NextNonceResponse struct {
Edge cipher.PubKey `json:"edge"`
Expand All @@ -40,8 +44,8 @@ type Client struct {
client http.Client
key cipher.PubKey
sec cipher.SecKey
Addr string // sanitized address of the client, which may differ from addr used in NewClient
nonce uint64
addr string // sanitized address of the client, which may differ from addr used in NewClient
nonce uint64 // has to be handled with the atomic package at all time
}

// NewClient creates a new client setting a public key to the client to be used for Auth.
Expand All @@ -55,7 +59,7 @@ func NewClient(ctx context.Context, addr string, key cipher.PubKey, sec cipher.S
client: http.Client{},
key: key,
sec: sec,
Addr: sanitizedAddr(addr),
addr: sanitizedAddr(addr),
}

// request server for a nonce
Expand All @@ -71,11 +75,6 @@ func NewClient(ctx context.Context, addr string, key cipher.PubKey, sec cipher.S
// Do performs a new authenticated Request and returns the response. Internally, if the request was
// successful nonce is incremented
func (c *Client) Do(req *http.Request) (*http.Response, error) {
req.Header.Add("SW-Public", c.key.Hex())

// use nonce, later, if no err from req update such nonce
req.Header.Add("SW-Nonce", strconv.FormatUint(uint64(c.getCurrentNonce()), 10))

body := make([]byte, 0)
if req.ContentLength != 0 {
auxBody, err := ioutil.ReadAll(req.Body)
Expand All @@ -86,18 +85,31 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
req.Body = ioutil.NopCloser(bytes.NewBuffer(auxBody))
body = auxBody
}
sign, err := Sign(body, c.getCurrentNonce(), c.sec)

res, err := c.doRequest(req, body)
if err != nil {
return nil, err
}

req.Header.Add("SW-Sig", sign.Hex())

res, err := c.client.Do(req)
isNonceValid, err := isNonceValid(res)
if err != nil {
return nil, err
}

if !isNonceValid {
nonce, err := c.Nonce(context.Background(), c.key)
if err != nil {
return nil, err
}
c.SetNonce(nonce)

res.Body.Close()
res, err = c.doRequest(req, body)
if err != nil {
return nil, err
}
}

if res.StatusCode == http.StatusOK {
c.incrementNonce()
}
Expand All @@ -107,7 +119,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {

// Nonce calls the remote API to retrieve the next expected nonce
func (c *Client) Nonce(ctx context.Context, key cipher.PubKey) (Nonce, error) {
req, err := http.NewRequest(http.MethodGet, c.Addr+"/security/nonces/"+key.Hex(), nil)
req, err := http.NewRequest(http.MethodGet, c.addr+"/security/nonces/"+key.Hex(), nil)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -136,6 +148,26 @@ func (c *Client) SetNonce(n Nonce) {
atomic.StoreUint64(&c.nonce, uint64(n))
}

// Addr returns sanitized address of the client
func (c *Client) Addr() string {
return c.addr
}

func (c *Client) doRequest(req *http.Request, body []byte) (*http.Response, error) {
nonce := c.getCurrentNonce()
sign, err := Sign(body, nonce, c.sec)
if err != nil {
return nil, err
}

// use nonce, later, if no err from req update such nonce
req.Header.Set("SW-Nonce", strconv.FormatUint(uint64(nonce), 10))
req.Header.Set("SW-Sig", sign.Hex())
req.Header.Set("SW-Public", c.key.Hex())

return c.client.Do(req)
}

func (c *Client) getCurrentNonce() Nonce {
return Nonce(atomic.LoadUint64(&c.nonce))
}
Expand All @@ -144,6 +176,29 @@ func (c *Client) incrementNonce() {
atomic.AddUint64(&c.nonce, 1)
}

// isNonceValid checks if `res` contains an invalid nonce error.
// The error is occurred if status code equals to `http.StatusUnauthorized`
// and body contains `invalidNonceErrorMessage`.
func isNonceValid(res *http.Response) (bool, error) {
var serverResponse HTTPResponse

auxRespBody, err := ioutil.ReadAll(res.Body)
if err != nil {
return false, err
}
res.Body.Close()
res.Body = ioutil.NopCloser(bytes.NewBuffer(auxRespBody))

if err := json.Unmarshal(auxRespBody, &serverResponse); err != nil || serverResponse.Error == nil {
return true, nil
}

isAuthorized := serverResponse.Error.Code != http.StatusUnauthorized
hasValidNonce := serverResponse.Error.Message != invalidNonceErrorMessage

return isAuthorized && hasValidNonce, nil
}

func sanitizedAddr(addr string) string {
if addr == "" {
return "http://localhost"
Expand Down
86 changes: 72 additions & 14 deletions internal/httpauth/client_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package httpauth

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strconv"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -15,34 +17,90 @@ import (
"github.com/skycoin/skywire/pkg/cipher"
)

const (
payload = "Hello, client\n"
errorMessage = `{"error":{"message":"SW-Nonce does not match","code":401}}`
)

func TestClient(t *testing.T) {
pk, sk := cipher.GenerateKeyPair()

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == fmt.Sprintf("/security/nonces/%s", pk) {
json.NewEncoder(w).Encode(&NextNonceResponse{pk, 1}) // nolint: errcheck
} else {
require.Equal(t, "1", r.Header.Get("Sw-Nonce"))
require.Equal(t, pk.Hex(), r.Header.Get("Sw-Public"))
sig := cipher.Sig{}
require.NoError(t, sig.UnmarshalText([]byte(r.Header.Get("Sw-Sig"))))
require.NoError(t, cipher.VerifyPubKeySignedPayload(pk, sig, PayloadWithNonce([]byte{}, 1)))
fmt.Fprintln(w, "Hello, client")
}
}))
headerCh := make(chan http.Header, 1)
ts := newTestServer(pk, headerCh)
defer ts.Close()

c, err := NewClient(context.TODO(), ts.URL, pk, sk)
require.NoError(t, err)

req, err := http.NewRequest("GET", ts.URL+"/foo", bytes.NewBufferString(payload))
require.NoError(t, err)
res, err := c.Do(req)
require.NoError(t, err)

b, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
res.Body.Close()
assert.Equal(t, []byte(payload), b)
assert.Equal(t, uint64(2), c.nonce)

headers := <-headerCh
checkResp(t, headers, b, pk, 1)
}

// TestClient_BadNonce tests if `Client` retries the request if an invalid nonce is set.
func TestClient_BadNonce(t *testing.T) {
pk, sk := cipher.GenerateKeyPair()

headerCh := make(chan http.Header, 1)
ts := newTestServer(pk, headerCh)
defer ts.Close()

c, err := NewClient(context.TODO(), ts.URL, pk, sk)
require.NoError(t, err)

req, err := http.NewRequest("GET", ts.URL+"/foo", nil)
c.nonce = 999

req, err := http.NewRequest("GET", ts.URL+"/foo", bytes.NewBufferString(payload))
require.NoError(t, err)
res, err := c.Do(req)
require.NoError(t, err)

b, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
res.Body.Close()
assert.Equal(t, []byte("Hello, client\n"), b)
assert.Equal(t, uint64(2), c.nonce)

headers := <-headerCh
checkResp(t, headers, b, pk, 1)
}

func checkResp(t *testing.T, headers http.Header, body []byte, pk cipher.PubKey, nonce int) {
require.Equal(t, strconv.Itoa(nonce), headers.Get("Sw-Nonce"))
require.Equal(t, pk.Hex(), headers.Get("Sw-Public"))
sig := cipher.Sig{}
require.NoError(t, sig.UnmarshalText([]byte(headers.Get("Sw-Sig"))))
require.NoError(t, cipher.VerifyPubKeySignedPayload(pk, sig, PayloadWithNonce(body, Nonce(nonce))))
}

func newTestServer(pk cipher.PubKey, headerCh chan<- http.Header) *httptest.Server {
nonce := 1
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == fmt.Sprintf("/security/nonces/%s", pk) {
json.NewEncoder(w).Encode(&NextNonceResponse{pk, Nonce(nonce)}) // nolint: errcheck
} else {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
defer r.Body.Close()
respMessage := string(body)
if r.Header.Get("Sw-Nonce") != strconv.Itoa(nonce) {
respMessage = errorMessage
} else {
headerCh <- r.Header
nonce++
}
fmt.Fprint(w, respMessage)
}
}))
}
4 changes: 2 additions & 2 deletions pkg/transport-discovery/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (c *apiClient) Post(ctx context.Context, path string, payload interface{})
return nil, err
}

req, err := http.NewRequest("POST", c.client.Addr+path, body)
req, err := http.NewRequest("POST", c.client.Addr()+path, body)
if err != nil {
return nil, err
}
Expand All @@ -63,7 +63,7 @@ func (c *apiClient) Post(ctx context.Context, path string, payload interface{})

// Get performs a new GET request.
func (c *apiClient) Get(ctx context.Context, path string) (*http.Response, error) {
req, err := http.NewRequest("GET", c.client.Addr+path, new(bytes.Buffer))
req, err := http.NewRequest("GET", c.client.Addr()+path, new(bytes.Buffer))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 627e4c3

Please sign in to comment.