diff --git a/client.go b/client.go index 0588021..7fe2171 100644 --- a/client.go +++ b/client.go @@ -535,8 +535,19 @@ func getHTTPResp(cli HTTP, r *http.Request, resp *Response) error { } defer w.Body.Close() - decoder := json.NewDecoder(w.Body) - return decoder.Decode(resp) + prevRespData := resp.Data + err = json.NewDecoder(w.Body).Decode(resp) + if err != nil { + return err + } + + // NOTE: in case of error, the server may return the data is nil; we must not accept this value but keep + // the other Response values. This is because if it is set to nil, our outpointer writing would fail and do nothing on retry + if resp.Data == nil { + resp.Data = prevRespData + } + + return nil } // MockClient builds a client that ignores certs and talks to the given host. diff --git a/client_test.go b/client_test.go index 8d84adc..125da35 100644 --- a/client_test.go +++ b/client_test.go @@ -75,9 +75,9 @@ func buildServer(code int, body []byte, a func(r *http.Request)) *httptest.Serve })) } -func buildConcurrentServer(code int, t *testing.T, a func(r *http.Request) []byte) *httptest.Server { +func buildConcurrentServer(a func(r *http.Request) (resp []byte, code int)) *httptest.Server { return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resp := a(r) + resp, code := a(r) w.WriteHeader(code) w.Header().Set("Content-Type", "application/json") w.Write(resp) @@ -357,7 +357,7 @@ func TestPutAccess(t *testing.T) { func TestConcurrentDeletes(t *testing.T) { var ops uint64 - srv := buildConcurrentServer(200, t, func(r *http.Request) []byte { + srv := buildConcurrentServer(func(r *http.Request) (resp []byte, code int) { if r.Method != "DELETE" { t.Fatalf("%s is not DELETE", r.Method) } @@ -366,7 +366,6 @@ func TestConcurrentDeletes(t *testing.T) { t.Fatalf("%s is not the path for testkey1 or testkey2", r.URL.Path) } atomic.AddUint64(&ops, 1) - var resp []byte var err error if ops%2 == 0 { resp, err = buildGoodResponse("") @@ -379,7 +378,7 @@ func TestConcurrentDeletes(t *testing.T) { t.Fatalf("%s is not nil", err) } } - return resp + return resp, 200 }) defer srv.Close() @@ -401,6 +400,62 @@ func TestConcurrentDeletes(t *testing.T) { } } +func TestConcurrentAddVersion(t *testing.T) { + var ops uint64 + expected := uint64(123) + expected2 := uint64(124) + srv := buildConcurrentServer(func(r *http.Request) (resp []byte, code int) { + if r.Method != "POST" { + t.Fatalf("%s is not POST", r.Method) + } + if r.URL.Path != "/v0/keys/testkey1/versions/" { + t.Fatalf("%s is not the path for testkey1", r.URL.Path) + } + atomic.AddUint64(&ops, 1) + var err error + switch ops { + case 1: + resp, err = buildGoodResponse(expected) + code = 200 + case 2: + resp, err = buildInternalServerErrorResponse(nil) + code = 500 + case 3: + resp, err = buildGoodResponse(expected2) + code = 200 + default: + } + if err != nil { + t.Fatalf("%s is not nil", err) + } + return resp, code + }) + defer srv.Close() + + cli := MockClient(srv.Listener.Addr().String(), "") + + // Put a new version of the same key in succession + respData, err := cli.AddVersion("testkey1", []byte("data")) + if err != nil { + t.Fatalf("%s is not nil", err) + } + if respData != expected { + t.Fatalf("expected %d but got %d", expected, respData) + } + respData, err = cli.AddVersion("testkey1", []byte("data")) + if err != nil { + t.Fatalf("%s is not nil", err) + } + if respData != expected2 { + t.Fatalf("expected %d but got %d", expected2, respData) + } + + // Verify that our atomic counter was incremented 3 times + if ops != 3 { + t.Fatalf("%d total client attempts is not 3", ops) + } +} + func TestGetKeyWithStatus(t *testing.T) { expected := Key{ ID: "testkey", diff --git a/server/routes.go b/server/routes.go index 032d5b5..7de21ba 100644 --- a/server/routes.go +++ b/server/routes.go @@ -341,7 +341,7 @@ func putAccessHandler(m KeyManager, principal knox.Principal, parameters map[str // postVersionHandler creates a new key version. This version is immediately // added as an Active key. -// The route for this handler is PUT /v0/keys//versions/ +// The route for this handler is POST /v0/keys//versions/ // The principal needs Write access. func postVersionHandler(m KeyManager, principal knox.Principal, parameters map[string]string) (interface{}, *HTTPError) {