From 8860264478e61f1e5e3ddda67d87386a76516967 Mon Sep 17 00:00:00 2001 From: Nikita Kryuchkov Date: Mon, 13 May 2019 00:26:55 +0300 Subject: [PATCH] Request a new nonce and retry HTTP request if a nonce validation error is returned --- internal/httpauth/client.go | 73 ++++++++++++++++++++++----- internal/httpauth/client_test.go | 85 ++++++++++++++++++++++++++------ 2 files changed, 131 insertions(+), 27 deletions(-) diff --git a/internal/httpauth/client.go b/internal/httpauth/client.go index 614034f99b..044b250cd0 100644 --- a/internal/httpauth/client.go +++ b/internal/httpauth/client.go @@ -13,10 +13,16 @@ import ( "strconv" "strings" "sync/atomic" + "time" "github.com/skycoin/skywire/pkg/cipher" ) +const ( + retryInterval = 100 * time.Millisecond + invalidNonceErrorMessage = "SW-Nonce does not match" +) + // NextNonceResponse represents a ServeHTTP response for json encoding type NextNonceResponse struct { Edge cipher.PubKey `json:"edge"` @@ -71,11 +77,6 @@ func NewClient(ctx context.Context, addr string, key cipher.PubKey, sec cipher.S // Do performs a new authenticated Request and returns the response. Internally, if the request was // successful nonce is incremented func (c *Client) Do(req *http.Request) (*http.Response, error) { - req.Header.Add("SW-Public", c.key.Hex()) - - // use nonce, later, if no err from req update such nonce - req.Header.Add("SW-Nonce", strconv.FormatUint(uint64(c.getCurrentNonce()), 10)) - body := make([]byte, 0) if req.ContentLength != 0 { auxBody, err := ioutil.ReadAll(req.Body) @@ -86,16 +87,42 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { req.Body = ioutil.NopCloser(bytes.NewBuffer(auxBody)) body = auxBody } - sign, err := Sign(body, c.getCurrentNonce(), c.sec) - if err != nil { - return nil, err - } - req.Header.Add("SW-Sig", sign.Hex()) + var res *http.Response + for { + nonce := c.getCurrentNonce() + sign, err := Sign(body, nonce, c.sec) + if err != nil { + return nil, err + } + + // use nonce, later, if no err from req update such nonce + req.Header.Set("SW-Nonce", strconv.FormatUint(uint64(nonce), 10)) + req.Header.Set("SW-Sig", sign.Hex()) + req.Header.Set("SW-Public", c.key.Hex()) - res, err := c.client.Do(req) - if err != nil { - return nil, err + res, err = c.client.Do(req) + if err != nil { + return nil, err + } + + isNonceValid, err := c.isNonceValid(res) + if err != nil { + return nil, err + } + + if isNonceValid { + break + } + + nonce, err = c.Nonce(context.Background(), c.key) + if err != nil { + return nil, err + } + c.SetNonce(nonce) + + res.Body.Close() + time.Sleep(retryInterval) } if res.StatusCode == http.StatusOK { @@ -144,6 +171,26 @@ func (c *Client) incrementNonce() { atomic.AddUint64(&c.nonce, 1) } +func (c *Client) isNonceValid(res *http.Response) (bool, error) { + var serverResponse HTTPResponse + + auxRespBody, err := ioutil.ReadAll(res.Body) + if err != nil { + return false, err + } + res.Body.Close() + res.Body = ioutil.NopCloser(bytes.NewBuffer(auxRespBody)) + + if err := json.Unmarshal(auxRespBody, &serverResponse); err != nil || serverResponse.Error == nil { + return true, nil + } + + isAuthorized := serverResponse.Error.Code != http.StatusUnauthorized + hasValidNonce := serverResponse.Error.Message != invalidNonceErrorMessage + + return isAuthorized && hasValidNonce, nil +} + func sanitizedAddr(addr string) string { if addr == "" { return "http://localhost" diff --git a/internal/httpauth/client_test.go b/internal/httpauth/client_test.go index 70a95eb64d..cd9f7fad31 100644 --- a/internal/httpauth/client_test.go +++ b/internal/httpauth/client_test.go @@ -1,12 +1,14 @@ package httpauth import ( + "bytes" "context" "encoding/json" "fmt" "io/ioutil" "net/http" "net/http/httptest" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -15,27 +17,49 @@ import ( "github.com/skycoin/skywire/pkg/cipher" ) +const ( + payload = "Hello, client\n" + errorMessage = `{"error":{"message":"SW-Nonce does not match","code":401}}` +) + func TestClient(t *testing.T) { pk, sk := cipher.GenerateKeyPair() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.String() == fmt.Sprintf("/security/nonces/%s", pk) { - json.NewEncoder(w).Encode(&NextNonceResponse{pk, 1}) // nolint: errcheck - } else { - require.Equal(t, "1", r.Header.Get("Sw-Nonce")) - require.Equal(t, pk.Hex(), r.Header.Get("Sw-Public")) - sig := cipher.Sig{} - require.NoError(t, sig.UnmarshalText([]byte(r.Header.Get("Sw-Sig")))) - require.NoError(t, cipher.VerifyPubKeySignedPayload(pk, sig, PayloadWithNonce([]byte{}, 1))) - fmt.Fprintln(w, "Hello, client") - } - })) + headerCh := make(chan http.Header, 1) + ts := newTestServer(pk, headerCh) + defer ts.Close() + + c, err := NewClient(context.TODO(), ts.URL, pk, sk) + require.NoError(t, err) + + req, err := http.NewRequest("GET", ts.URL+"/foo", bytes.NewBufferString(payload)) + require.NoError(t, err) + res, err := c.Do(req) + require.NoError(t, err) + + b, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + res.Body.Close() + assert.Equal(t, []byte(payload), b) + assert.Equal(t, uint64(2), c.nonce) + + headers := <-headerCh + checkResp(t, headers, b, pk, 1) +} + +func TestBadNonce(t *testing.T) { + pk, sk := cipher.GenerateKeyPair() + + headerCh := make(chan http.Header, 1) + ts := newTestServer(pk, headerCh) defer ts.Close() c, err := NewClient(context.TODO(), ts.URL, pk, sk) require.NoError(t, err) - req, err := http.NewRequest("GET", ts.URL+"/foo", nil) + c.nonce = 999 + + req, err := http.NewRequest("GET", ts.URL+"/foo", bytes.NewBufferString(payload)) require.NoError(t, err) res, err := c.Do(req) require.NoError(t, err) @@ -43,6 +67,39 @@ func TestClient(t *testing.T) { b, err := ioutil.ReadAll(res.Body) require.NoError(t, err) res.Body.Close() - assert.Equal(t, []byte("Hello, client\n"), b) assert.Equal(t, uint64(2), c.nonce) + + headers := <-headerCh + checkResp(t, headers, b, pk, 1) +} + +func checkResp(t *testing.T, headers http.Header, body []byte, pk cipher.PubKey, nonce int) { + require.Equal(t, strconv.Itoa(nonce), headers.Get("Sw-Nonce")) + require.Equal(t, pk.Hex(), headers.Get("Sw-Public")) + sig := cipher.Sig{} + require.NoError(t, sig.UnmarshalText([]byte(headers.Get("Sw-Sig")))) + require.NoError(t, cipher.VerifyPubKeySignedPayload(pk, sig, PayloadWithNonce(body, Nonce(nonce)))) +} + +func newTestServer(pk cipher.PubKey, headerCh chan<- http.Header) *httptest.Server { + nonce := 1 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.String() == fmt.Sprintf("/security/nonces/%s", pk) { + json.NewEncoder(w).Encode(&NextNonceResponse{pk, Nonce(nonce)}) // nolint: errcheck + } else { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return + } + defer r.Body.Close() + respMessage := string(body) + if r.Header.Get("Sw-Nonce") != strconv.Itoa(nonce) { + respMessage = errorMessage + } else { + headerCh <- r.Header + nonce++ + } + fmt.Fprint(w, respMessage) + } + })) }