Skip to content

Commit

Permalink
feat(GraphQL): Add GraphQL schema validation Endpoint. (#6250)
Browse files Browse the repository at this point in the history
* Add GraphQL schema validation Endpoint.

(cherry picked from commit df1c7c9)
  • Loading branch information
Arijit Das authored and arijitAD committed Sep 17, 2020
1 parent 9c8c994 commit 7d6d881
Show file tree
Hide file tree
Showing 14 changed files with 255 additions and 76 deletions.
17 changes: 17 additions & 0 deletions dgraph/cmd/alpha/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,23 @@ func setupServer(closer *z.Closer) {
adminSchemaHandler(w, r, adminServer)
})))

http.Handle("/admin/schema/validate", http.HandlerFunc(func(w http.ResponseWriter,
r *http.Request) {
schema := readRequest(w, r)
w.Header().Set("Content-Type", "application/json")

err := admin.SchemaValidate(string(schema))
if err == nil {
w.WriteHeader(http.StatusOK)
x.SetStatus(w, "success", "Schema is valid")
return
}

w.WriteHeader(http.StatusBadRequest)
errs := strings.Split(strings.TrimSpace(err.Error()), "\n")
x.SetStatusWithErrors(w, x.ErrorInvalidRequest, errs)
}))

http.Handle("/admin/shutdown", allowedMethodsHandler(allowedMethods{http.MethodGet: true},
adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
shutDownHandler(w, r, adminServer)
Expand Down
12 changes: 11 additions & 1 deletion graphql/admin/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,16 @@ var (
mainHealthStore = &GraphQLHealthStore{}
)

func SchemaValidate(sch string) error {
schHandler, err := schema.NewHandler(sch, true)
if err != nil {
return err
}

_, err = schema.FromString(schHandler.GQLSchema())
return err
}

// GraphQLHealth is used to report the health status of a GraphQL server.
// It is required for kubernetes probing.
type GraphQLHealth struct {
Expand Down Expand Up @@ -557,7 +567,7 @@ func getCurrentGraphQLSchema() (*gqlSchema, error) {
}

func generateGQLSchema(sch *gqlSchema) (schema.Schema, error) {
schHandler, err := schema.NewHandler(sch.Schema)
schHandler, err := schema.NewHandler(sch.Schema, false)
if err != nil {
return nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion graphql/admin/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ func resolveUpdateGQLSchema(ctx context.Context, m schema.Mutation) (*resolve.Re
return resolve.EmptyResult(m, err), false
}

schHandler, err := schema.NewHandler(input.Set.Schema)
// We just need to validate the schema. Schema is later set in `resetSchema()` when the schema
// is returned from badger.
schHandler, err := schema.NewHandler(input.Set.Schema, true)
if err != nil {
return resolve.EmptyResult(m, err), false
}
Expand Down
82 changes: 52 additions & 30 deletions graphql/authorization/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/http"
"regexp"
"strings"
"sync"

"github.com/golang/glog"
"github.com/vektah/gqlparser/v2/gqlerror"
Expand All @@ -47,7 +48,7 @@ const (
)

var (
metainfo AuthMeta
authMeta = &AuthMeta{}
)

type AuthMeta struct {
Expand All @@ -57,6 +58,7 @@ type AuthMeta struct {
Namespace string
Algo string
Audience []string
sync.RWMutex
}

// Validate required fields.
Expand Down Expand Up @@ -84,17 +86,17 @@ func (a *AuthMeta) validate() error {
return nil
}

func Parse(schema string) (AuthMeta, error) {
func Parse(schema string) (*AuthMeta, error) {
var meta AuthMeta
authInfoIdx := strings.LastIndex(schema, AuthMetaHeader)
if authInfoIdx == -1 {
return meta, nil
return nil, nil
}
authInfo := schema[authInfoIdx:]

err := json.Unmarshal([]byte(authInfo[len(AuthMetaHeader):]), &meta)
if err == nil {
return meta, meta.validate()
return &meta, meta.validate()
}

glog.Warningln("Falling back to parsing `Dgraph.Authorization` in old format." +
Expand All @@ -113,60 +115,77 @@ func Parse(schema string) (AuthMeta, error) {
authMetaRegex, err :=
regexp.Compile(`^#[\s]([^\s]+)[\s]+([^\s]+)[\s]+([^\s]+)[\s]+([^\s]+)[\s]+"([^\"]+)"`)
if err != nil {
return meta, gqlerror.Errorf("JWT parsing failed: %v", err)
return nil, gqlerror.Errorf("JWT parsing failed: %v", err)
}

idx := authMetaRegex.FindAllStringSubmatchIndex(authInfo, -1)
if len(idx) != 1 || len(idx[0]) != 12 ||
!strings.HasPrefix(authInfo, authInfo[idx[0][0]:idx[0][1]]) {
return meta, gqlerror.Errorf("Invalid `Dgraph.Authorization` format: %s", authInfo)
return nil, gqlerror.Errorf("Invalid `Dgraph.Authorization` format: %s", authInfo)
}

meta.Header = authInfo[idx[0][4]:idx[0][5]]
meta.Namespace = authInfo[idx[0][6]:idx[0][7]]
meta.Algo = authInfo[idx[0][8]:idx[0][9]]
meta.VerificationKey = authInfo[idx[0][10]:idx[0][11]]
if meta.Algo == HMAC256 {
return meta, nil
return &meta, nil
}
if meta.Algo != RSA256 {
return meta, errors.Errorf(
return nil, errors.Errorf(
"invalid jwt algorithm: found %s, but supported options are HS256 or RS256", meta.Algo)
}
return meta, nil
return &meta, nil
}

func ParseAuthMeta(schema string) error {
var err error
metainfo, err = Parse(schema)
func ParseAuthMeta(schema string) (*AuthMeta, error) {
metaInfo, err := Parse(schema)
if err != nil {
return err
return nil, err
}

if metainfo.Algo != RSA256 {
return err
if metaInfo.Algo != RSA256 {
return metaInfo, nil
}

// The jwt library internally uses `bytes.IndexByte(data, '\n')` to fetch new line and fails
// if we have newline "\n" as ASCII value {92,110} instead of the actual ASCII value of 10.
// To fix this we replace "\n" with new line's ASCII value.
bytekey := bytes.ReplaceAll([]byte(metainfo.VerificationKey), []byte{92, 110}, []byte{10})
bytekey := bytes.ReplaceAll([]byte(metaInfo.VerificationKey), []byte{92, 110}, []byte{10})

metainfo.RSAPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(bytekey)
return err
if metaInfo.RSAPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(bytekey); err != nil {
return nil, err
}
return metaInfo, nil
}

func GetHeader() string {
return metainfo.Header
authMeta.RLock()
defer authMeta.RUnlock()
return authMeta.Header
}

func GetAuthMeta() AuthMeta {
return metainfo
func GetAuthMeta() *AuthMeta {
authMeta.RLock()
defer authMeta.RUnlock()
return authMeta
}

func SetAuthMeta(m *AuthMeta) {
authMeta.Lock()
defer authMeta.Unlock()

authMeta.VerificationKey = m.VerificationKey
authMeta.RSAPublicKey = m.RSAPublicKey
authMeta.Header = m.Header
authMeta.Namespace = m.Namespace
authMeta.Algo = m.Algo
authMeta.Audience = m.Audience
}

// AttachAuthorizationJwt adds any incoming JWT authorization data into the grpc context metadata.
func AttachAuthorizationJwt(ctx context.Context, r *http.Request) context.Context {
authorizationJwt := r.Header.Get(metainfo.Header)
authorizationJwt := r.Header.Get(authMeta.Header)
if authorizationJwt == "" {
return ctx
}
Expand Down Expand Up @@ -198,7 +217,7 @@ func (c *CustomClaims) UnmarshalJSON(data []byte) error {
}

// Unmarshal the auth variables for a particular namespace.
if authValue, ok := result[metainfo.Namespace]; ok {
if authValue, ok := result[authMeta.Namespace]; ok {
if authJson, ok := authValue.(string); ok {
if err := json.Unmarshal([]byte(authJson), &c.AuthVariables); err != nil {
return err
Expand All @@ -217,13 +236,13 @@ func (c *CustomClaims) validateAudience() error {
}

// If there is an audience claim, but no value provided, fail
if metainfo.Audience == nil {
if authMeta.Audience == nil {
return fmt.Errorf("audience value was expected but not provided")
}

var match = false
for _, audStr := range c.Audience {
for _, expectedAudStr := range metainfo.Audience {
for _, expectedAudStr := range authMeta.Audience {
if subtle.ConstantTimeCompare([]byte(audStr), []byte(expectedAudStr)) == 1 {
match = true
break
Expand Down Expand Up @@ -253,7 +272,10 @@ func ExtractCustomClaims(ctx context.Context) (*CustomClaims, error) {
}

func validateJWTCustomClaims(jwtStr string) (*CustomClaims, error) {
if metainfo.Algo == "" {
authMeta.RLock()
defer authMeta.RUnlock()

if authMeta.Algo == "" {
return nil, fmt.Errorf(
"jwt token cannot be validated because verification algorithm is not set")
}
Expand All @@ -264,17 +286,17 @@ func validateJWTCustomClaims(jwtStr string) (*CustomClaims, error) {
token, err :=
jwt.ParseWithClaims(jwtStr, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) {
algo, _ := token.Header["alg"].(string)
if algo != metainfo.Algo {
if algo != authMeta.Algo {
return nil, errors.Errorf("unexpected signing method: Expected %s Found %s",
metainfo.Algo, algo)
authMeta.Algo, algo)
}
if algo == HMAC256 {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); ok {
return []byte(metainfo.VerificationKey), nil
return []byte(authMeta.VerificationKey), nil
}
} else if algo == RSA256 {
if _, ok := token.Method.(*jwt.SigningMethodRSA); ok {
return metainfo.RSAPublicKey, nil
return authMeta.RSAPublicKey, nil
}
}
return nil, errors.Errorf("couldn't parse signing method from token header: %s", algo)
Expand Down
6 changes: 1 addition & 5 deletions graphql/e2e/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1093,11 +1093,7 @@ func TestMain(m *testing.M) {
panic(err)
}

authMeta, err := authorization.Parse(string(authSchema))
if err != nil {
panic(err)
}

authMeta := testutil.SetAuthMeta(string(authSchema))
metaInfo = &testutil.AuthMeta{
PublicKey: authMeta.VerificationKey,
Namespace: authMeta.Namespace,
Expand Down
3 changes: 2 additions & 1 deletion graphql/e2e/common/fragment.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ func fragmentInMutation(t *testing.T) {
gqlResponse := addStarshipParams.ExecuteAsPost(t, graphqlURL)
RequireNoGQLErrors(t, gqlResponse)

addStarshipExpected := `{"addStarship":{
addStarshipExpected := `
{"addStarship":{
"starship":[{
"name":"Millennium Falcon",
"length":2
Expand Down
81 changes: 81 additions & 0 deletions graphql/e2e/schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
package schema

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"sync"
"testing"
"time"
Expand All @@ -30,6 +32,7 @@ import (
"github.com/dgraph-io/dgraph/graphql/e2e/common"
"github.com/dgraph-io/dgraph/testutil"
"github.com/dgraph-io/dgraph/worker"
"github.com/dgraph-io/dgraph/x"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -366,6 +369,84 @@ func TestUpdateGQLSchemaFields(t *testing.T) {
require.Equal(t, string(generatedSchema), updateResp.UpdateGQLSchema.GQLSchema.GeneratedSchema)
}

// verifyEmptySchema verifies that the schema is not set in the GraphQL server.
func verifyEmptySchema(t *testing.T) {
schema := getGQLSchema(t, groupOneAdminServer)
require.Empty(t, schema.Schema)
}

func TestGQLSchemaValidate(t *testing.T) {
testCases := []struct {
schema string
errors x.GqlErrorList
valid bool
}{
{
schema: `
type Task @auth(
query: { rule: "{$USERROLE: { eq: \"USER\"}}" }
) {
id: ID!
name: String!
occurrences: [TaskOccurrence] @hasInverse(field: task)
}
type TaskOccurrence @auth(
query: { rule: "query { queryTaskOccurrence { task { id } } }" }
) {
id: ID!
due: DateTime
comp: DateTime
task: Task @hasInverse(field: occurrences)
}
`,
valid: true,
},
{
schema: `
type X {
id: ID @dgraph(pred: "X.id")
name: String
}
type Y {
f1: String! @dgraph(pred:"~movie")
}
`,
errors: x.GqlErrorList{{Message: "input:3: Type X; Field id: has the @dgraph directive but fields of type ID can't have the @dgraph directive."}, {Message: "input:7: Type Y; Field f1 is of type String, but reverse predicate in @dgraph directive only applies to fields with object types."}},
valid: false,
},
}

dg, err := testutil.DgraphClient(groupOnegRPC)
require.NoError(t, err)
testutil.DropAll(t, dg)

validateUrl := groupOneAdminServer + "/schema/validate"
var response x.QueryResWithData
for _, tcase := range testCases {
resp, err := http.Post(validateUrl, "text/plain", bytes.NewBuffer([]byte(tcase.schema)))
require.NoError(t, err)

decoder := json.NewDecoder(resp.Body)
err = decoder.Decode(&response)
require.NoError(t, err)

// Verify that we only validate the schema and not set it.
verifyEmptySchema(t)

if tcase.valid {
require.Equal(t, resp.StatusCode, http.StatusOK)
continue
}
require.Equal(t, resp.StatusCode, http.StatusBadRequest)
require.NotNil(t, response.Errors)
require.Equal(t, len(response.Errors), len(tcase.errors))
for idx, err := range response.Errors {
require.Equal(t, err.Message, tcase.errors[idx].Message)
}
}
}

func updateGQLSchema(t *testing.T, schema, url string) *common.GraphQLResponse {
req := &common.GraphQLParams{
Query: `mutation updateGQLSchema($sch: String!) {
Expand Down
Loading

0 comments on commit 7d6d881

Please sign in to comment.