From 39eea9c4b9aac8dde701b833d1e41bce1ad9e3cd Mon Sep 17 00:00:00 2001 From: Ben Bieker Date: Tue, 13 Oct 2015 10:14:30 +0200 Subject: [PATCH] sparse fieldsets filtering support if the appropriate query parameters are available. The api goes through all the attributes and only includes the ones from the query parameter. For example: `?fields[posts]=title,content` If there are invalid fields in the query, errors with all of them will be returned --- api.go | 121 +++++++++++++++++++++++++++++- api_test.go | 206 +++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 324 insertions(+), 3 deletions(-) diff --git a/api.go b/api.go index 6d02761..f98cdd3 100644 --- a/api.go +++ b/api.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "reflect" + "regexp" "strconv" "strings" @@ -16,7 +17,12 @@ import ( "github.com/manyminds/api2go/routing" ) -const defaultContentTypHeader = "application/vnd.api+json" +const ( + codeInvalidQueryFields = "API2GO_INVALID_FIELD_QUERY_PARAM" + defaultContentTypHeader = "application/vnd.api+json" +) + +var queryFieldsRegex = regexp.MustCompile(`^fields\[(\w+)\]$`) type response struct { Meta map[string]interface{} @@ -994,7 +1000,11 @@ func unmarshalRequest(r *http.Request, marshalers map[string]ContentMarshaler) ( func marshalResponse(resp interface{}, w http.ResponseWriter, status int, r *http.Request, marshalers map[string]ContentMarshaler) error { marshaler, contentType := selectContentMarshaler(r, marshalers) - result, err := marshaler.Marshal(resp) + filtered, err := filterSparseFields(resp, r) + if err != nil { + return err + } + result, err := marshaler.Marshal(filtered) if err != nil { return err } @@ -1002,6 +1012,113 @@ func marshalResponse(resp interface{}, w http.ResponseWriter, status int, r *htt return nil } +func filterSparseFields(resp interface{}, r *http.Request) (interface{}, error) { + query := r.URL.Query() + queryParams := parseQueryFields(&query) + if len(queryParams) < 1 { + return resp, nil + } + + if content, ok := resp.(map[string]interface{}); ok { + wrongFields := map[string][]string{} + + // single entry in data + if data, ok := content["data"].(map[string]interface{}); ok { + errors := replaceAttributes(&queryParams, &data) + for t, v := range errors { + wrongFields[t] = v + } + } + + // data can be a slice too + if datas, ok := content["data"].([]map[string]interface{}); ok { + for index, data := range datas { + errors := replaceAttributes(&queryParams, &data) + for t, v := range errors { + wrongFields[t] = v + } + datas[index] = data + } + } + + // included slice + if included, ok := content["included"].([]map[string]interface{}); ok { + for index, include := range included { + errors := replaceAttributes(&queryParams, &include) + for t, v := range errors { + wrongFields[t] = v + } + included[index] = include + } + } + + if len(wrongFields) > 0 { + httpError := NewHTTPError(nil, "Some requested fields were invalid", http.StatusBadRequest) + for k, v := range wrongFields { + for _, field := range v { + httpError.Errors = append(httpError.Errors, Error{ + Status: "Bad Request", + Code: codeInvalidQueryFields, + Title: fmt.Sprintf(`Field "%s" does not exist for type "%s"`, field, k), + Detail: "Please make sure you do only request existing fields", + Source: &ErrorSource{ + Parameter: fmt.Sprintf("fields[%s]", k), + }, + }) + } + } + return nil, httpError + } + } + return resp, nil +} + +func parseQueryFields(query *url.Values) (result map[string][]string) { + result = map[string][]string{} + for name, param := range *query { + matches := queryFieldsRegex.FindStringSubmatch(name) + if len(matches) > 1 { + match := matches[1] + result[match] = strings.Split(param[0], ",") + } + } + + return +} + +func filterAttributes(attributes map[string]interface{}, fields []string) (filteredAttributes map[string]interface{}, wrongFields []string) { + wrongFields = []string{} + filteredAttributes = map[string]interface{}{} + + for _, field := range fields { + if attribute, ok := attributes[field]; ok { + filteredAttributes[field] = attribute + } else { + wrongFields = append(wrongFields, field) + } + } + + return +} + +func replaceAttributes(query *map[string][]string, entry *map[string]interface{}) map[string][]string { + fieldType := (*entry)["type"].(string) + fields := (*query)[fieldType] + if len(fields) > 0 { + if attributes, ok := (*entry)["attributes"]; ok { + var wrongFields []string + (*entry)["attributes"], wrongFields = filterAttributes(attributes.(map[string]interface{}), fields) + if len(wrongFields) > 0 { + return map[string][]string{ + fieldType: wrongFields, + } + } + } + } + + return nil +} + func selectContentMarshaler(r *http.Request, marshalers map[string]ContentMarshaler) (marshaler ContentMarshaler, contentType string) { if _, found := r.Header["Accept"]; found { var contentTypes []string diff --git a/api_test.go b/api_test.go index fca286e..49d4cc9 100644 --- a/api_test.go +++ b/api_test.go @@ -199,6 +199,7 @@ func (b Banana) GetID() string { type User struct { ID string `jsonapi:"-"` Name string + Info string } func (u User) GetID() string { @@ -524,6 +525,7 @@ var _ = Describe("RestHandler", func() { "type": "users", "attributes": map[string]interface{}{ "name": "Dieter", + "info": "", }, }, { @@ -650,7 +652,8 @@ var _ = Describe("RestHandler", func() { "id": "1", "type": "users", "attributes": { - "name": "Dieter" + "name": "Dieter", + "info": "" } }}`)) }) @@ -1587,4 +1590,205 @@ var _ = Describe("RestHandler", func() { Expect(rec.Body.Bytes()).To(ContainSubstring(expected)) }) }) + + Context("Sparse Fieldsets", func() { + var ( + source *fixtureSource + api *API + rec *httptest.ResponseRecorder + ) + + BeforeEach(func() { + author := User{ID: "666", Name: "Tester", Info: "Is curious about testing"} + source = &fixtureSource{map[string]*Post{ + "1": {ID: "1", Title: "Nice Post", Value: null.FloatFrom(13.37), Author: &author}, + }, false} + api = NewAPI("") + api.AddResource(Post{}, source) + rec = httptest.NewRecorder() + }) + + It("only returns requested post fields for single post", func() { + req, err := http.NewRequest("GET", "/posts/1?fields[posts]=title,value", nil) + Expect(err).ToNot(HaveOccurred()) + api.Handler().ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(rec.Body.Bytes()).To(MatchJSON(` + {"data": { + "id": "1", + "type": "posts", + "attributes": { + "title": "Nice Post", + "value": 13.37 + }, + "relationships": { + "author": { + "data": { + "id": "666", + "type": "users" + }, + "links": { + "related": "/posts/1/author", + "self": "/posts/1/relationships/author" + } + }, + "bananas": { + "data": [], + "links": { + "related": "/posts/1/bananas", + "self": "/posts/1/relationships/bananas" + } + }, + "comments": { + "data": [], + "links": { + "related": "/posts/1/comments", + "self": "/posts/1/relationships/comments" + } + } + } + }, + "included": [ + { + "attributes": { + "info": "Is curious about testing", + "name": "Tester" + }, + "id": "666", + "type": "users" + } + ] + }`)) + }) + + It("FindOne: only returns requested post field for single post and includes", func() { + req, err := http.NewRequest("GET", "/posts/1?fields[posts]=title&fields[users]=name", nil) + Expect(err).ToNot(HaveOccurred()) + api.Handler().ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(rec.Body.Bytes()).To(MatchJSON(` + {"data": { + "id": "1", + "type": "posts", + "attributes": { + "title": "Nice Post" + }, + "relationships": { + "author": { + "data": { + "id": "666", + "type": "users" + }, + "links": { + "related": "/posts/1/author", + "self": "/posts/1/relationships/author" + } + }, + "bananas": { + "data": [], + "links": { + "related": "/posts/1/bananas", + "self": "/posts/1/relationships/bananas" + } + }, + "comments": { + "data": [], + "links": { + "related": "/posts/1/comments", + "self": "/posts/1/relationships/comments" + } + } + } + }, + "included": [ + { + "attributes": { + "name": "Tester" + }, + "id": "666", + "type": "users" + } + ] + }`)) + }) + + It("FindAll: only returns requested post field for single post and includes", func() { + req, err := http.NewRequest("GET", "/posts?fields[posts]=title&fields[users]=name", nil) + Expect(err).ToNot(HaveOccurred()) + api.Handler().ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(rec.Body.Bytes()).To(MatchJSON(` + {"data": [{ + "id": "1", + "type": "posts", + "attributes": { + "title": "Nice Post" + }, + "relationships": { + "author": { + "data": { + "id": "666", + "type": "users" + }, + "links": { + "related": "/posts/1/author", + "self": "/posts/1/relationships/author" + } + }, + "bananas": { + "data": [], + "links": { + "related": "/posts/1/bananas", + "self": "/posts/1/relationships/bananas" + } + }, + "comments": { + "data": [], + "links": { + "related": "/posts/1/comments", + "self": "/posts/1/relationships/comments" + } + } + } + }], + "included": [ + { + "attributes": { + "name": "Tester" + }, + "id": "666", + "type": "users" + } + ] + }`)) + }) + + It("Summarize all invalid field query parameters as error", func() { + req, err := http.NewRequest("GET", "/posts?fields[posts]=title,nonexistent&fields[users]=name,title,fluffy,pink", nil) + Expect(err).ToNot(HaveOccurred()) + api.Handler().ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + error := HTTPError{} + err = json.Unmarshal(rec.Body.Bytes(), &error) + Expect(err).ToNot(HaveOccurred()) + + expectedError := func(field, objType string) Error { + return Error{ + Status: "Bad Request", + Code: codeInvalidQueryFields, + Title: fmt.Sprintf(`Field "%s" does not exist for type "%s"`, field, objType), + Detail: "Please make sure you do only request existing fields", + Source: &ErrorSource{ + Parameter: fmt.Sprintf("fields[%s]", objType), + }, + } + } + + Expect(error.Errors).To(HaveLen(4)) + Expect(error.Errors).To(ContainElement(expectedError("nonexistent", "posts"))) + Expect(error.Errors).To(ContainElement(expectedError("title", "users"))) + Expect(error.Errors).To(ContainElement(expectedError("fluffy", "users"))) + Expect(error.Errors).To(ContainElement(expectedError("pink", "users"))) + }) + }) })