Skip to content

Commit

Permalink
Request a new nonce and retry HTTP request if a nonce validation erro…
Browse files Browse the repository at this point in the history
…r is returned
  • Loading branch information
nkryuchkov committed May 12, 2019
1 parent 4443360 commit 8860264
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 27 deletions.
73 changes: 60 additions & 13 deletions internal/httpauth/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,16 @@ import (
"strconv"
"strings"
"sync/atomic"
"time"

"github.com/skycoin/skywire/pkg/cipher"
)

const (
retryInterval = 100 * time.Millisecond
invalidNonceErrorMessage = "SW-Nonce does not match"
)

// NextNonceResponse represents a ServeHTTP response for json encoding
type NextNonceResponse struct {
Edge cipher.PubKey `json:"edge"`
Expand Down Expand Up @@ -71,11 +77,6 @@ func NewClient(ctx context.Context, addr string, key cipher.PubKey, sec cipher.S
// Do performs a new authenticated Request and returns the response. Internally, if the request was
// successful nonce is incremented
func (c *Client) Do(req *http.Request) (*http.Response, error) {
req.Header.Add("SW-Public", c.key.Hex())

// use nonce, later, if no err from req update such nonce
req.Header.Add("SW-Nonce", strconv.FormatUint(uint64(c.getCurrentNonce()), 10))

body := make([]byte, 0)
if req.ContentLength != 0 {
auxBody, err := ioutil.ReadAll(req.Body)
Expand All @@ -86,16 +87,42 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
req.Body = ioutil.NopCloser(bytes.NewBuffer(auxBody))
body = auxBody
}
sign, err := Sign(body, c.getCurrentNonce(), c.sec)
if err != nil {
return nil, err
}

req.Header.Add("SW-Sig", sign.Hex())
var res *http.Response
for {
nonce := c.getCurrentNonce()
sign, err := Sign(body, nonce, c.sec)
if err != nil {
return nil, err
}

// use nonce, later, if no err from req update such nonce
req.Header.Set("SW-Nonce", strconv.FormatUint(uint64(nonce), 10))
req.Header.Set("SW-Sig", sign.Hex())
req.Header.Set("SW-Public", c.key.Hex())

res, err := c.client.Do(req)
if err != nil {
return nil, err
res, err = c.client.Do(req)
if err != nil {
return nil, err
}

isNonceValid, err := c.isNonceValid(res)
if err != nil {
return nil, err
}

if isNonceValid {
break
}

nonce, err = c.Nonce(context.Background(), c.key)
if err != nil {
return nil, err
}
c.SetNonce(nonce)

res.Body.Close()
time.Sleep(retryInterval)
}

if res.StatusCode == http.StatusOK {
Expand Down Expand Up @@ -144,6 +171,26 @@ func (c *Client) incrementNonce() {
atomic.AddUint64(&c.nonce, 1)
}

func (c *Client) isNonceValid(res *http.Response) (bool, error) {
var serverResponse HTTPResponse

auxRespBody, err := ioutil.ReadAll(res.Body)
if err != nil {
return false, err
}
res.Body.Close()
res.Body = ioutil.NopCloser(bytes.NewBuffer(auxRespBody))

if err := json.Unmarshal(auxRespBody, &serverResponse); err != nil || serverResponse.Error == nil {
return true, nil
}

isAuthorized := serverResponse.Error.Code != http.StatusUnauthorized
hasValidNonce := serverResponse.Error.Message != invalidNonceErrorMessage

return isAuthorized && hasValidNonce, nil
}

func sanitizedAddr(addr string) string {
if addr == "" {
return "http://localhost"
Expand Down
85 changes: 71 additions & 14 deletions internal/httpauth/client_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package httpauth

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strconv"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -15,34 +17,89 @@ import (
"github.com/skycoin/skywire/pkg/cipher"
)

const (
payload = "Hello, client\n"
errorMessage = `{"error":{"message":"SW-Nonce does not match","code":401}}`
)

func TestClient(t *testing.T) {
pk, sk := cipher.GenerateKeyPair()

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == fmt.Sprintf("/security/nonces/%s", pk) {
json.NewEncoder(w).Encode(&NextNonceResponse{pk, 1}) // nolint: errcheck
} else {
require.Equal(t, "1", r.Header.Get("Sw-Nonce"))
require.Equal(t, pk.Hex(), r.Header.Get("Sw-Public"))
sig := cipher.Sig{}
require.NoError(t, sig.UnmarshalText([]byte(r.Header.Get("Sw-Sig"))))
require.NoError(t, cipher.VerifyPubKeySignedPayload(pk, sig, PayloadWithNonce([]byte{}, 1)))
fmt.Fprintln(w, "Hello, client")
}
}))
headerCh := make(chan http.Header, 1)
ts := newTestServer(pk, headerCh)
defer ts.Close()

c, err := NewClient(context.TODO(), ts.URL, pk, sk)
require.NoError(t, err)

req, err := http.NewRequest("GET", ts.URL+"/foo", bytes.NewBufferString(payload))
require.NoError(t, err)
res, err := c.Do(req)
require.NoError(t, err)

b, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
res.Body.Close()
assert.Equal(t, []byte(payload), b)
assert.Equal(t, uint64(2), c.nonce)

headers := <-headerCh
checkResp(t, headers, b, pk, 1)
}

func TestBadNonce(t *testing.T) {
pk, sk := cipher.GenerateKeyPair()

headerCh := make(chan http.Header, 1)
ts := newTestServer(pk, headerCh)
defer ts.Close()

c, err := NewClient(context.TODO(), ts.URL, pk, sk)
require.NoError(t, err)

req, err := http.NewRequest("GET", ts.URL+"/foo", nil)
c.nonce = 999

req, err := http.NewRequest("GET", ts.URL+"/foo", bytes.NewBufferString(payload))
require.NoError(t, err)
res, err := c.Do(req)
require.NoError(t, err)

b, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
res.Body.Close()
assert.Equal(t, []byte("Hello, client\n"), b)
assert.Equal(t, uint64(2), c.nonce)

headers := <-headerCh
checkResp(t, headers, b, pk, 1)
}

func checkResp(t *testing.T, headers http.Header, body []byte, pk cipher.PubKey, nonce int) {
require.Equal(t, strconv.Itoa(nonce), headers.Get("Sw-Nonce"))
require.Equal(t, pk.Hex(), headers.Get("Sw-Public"))
sig := cipher.Sig{}
require.NoError(t, sig.UnmarshalText([]byte(headers.Get("Sw-Sig"))))
require.NoError(t, cipher.VerifyPubKeySignedPayload(pk, sig, PayloadWithNonce(body, Nonce(nonce))))
}

func newTestServer(pk cipher.PubKey, headerCh chan<- http.Header) *httptest.Server {
nonce := 1
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == fmt.Sprintf("/security/nonces/%s", pk) {
json.NewEncoder(w).Encode(&NextNonceResponse{pk, Nonce(nonce)}) // nolint: errcheck
} else {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return
}
defer r.Body.Close()
respMessage := string(body)
if r.Header.Get("Sw-Nonce") != strconv.Itoa(nonce) {
respMessage = errorMessage
} else {
headerCh <- r.Header
nonce++
}
fmt.Fprint(w, respMessage)
}
}))
}

0 comments on commit 8860264

Please sign in to comment.