diff --git a/dgraph/cmd/alpha/run.go b/dgraph/cmd/alpha/run.go index 0807ce52dca..4bab03f44fc 100644 --- a/dgraph/cmd/alpha/run.go +++ b/dgraph/cmd/alpha/run.go @@ -513,6 +513,23 @@ func setupServer(closer *z.Closer) { adminSchemaHandler(w, r, adminServer) }))) + http.Handle("/admin/schema/validate", http.HandlerFunc(func(w http.ResponseWriter, + r *http.Request) { + schema := readRequest(w, r) + w.Header().Set("Content-Type", "application/json") + + err := admin.SchemaValidate(string(schema)) + if err == nil { + w.WriteHeader(http.StatusOK) + x.SetStatus(w, "success", "Schema is valid") + return + } + + w.WriteHeader(http.StatusBadRequest) + errs := strings.Split(strings.TrimSpace(err.Error()), "\n") + x.SetStatusWithErrors(w, x.ErrorInvalidRequest, errs) + })) + http.Handle("/admin/shutdown", allowedMethodsHandler(allowedMethods{http.MethodGet: true}, adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { shutDownHandler(w, r, adminServer) diff --git a/graphql/admin/admin.go b/graphql/admin/admin.go index 9a98724b489..9bcd0e03aa2 100644 --- a/graphql/admin/admin.go +++ b/graphql/admin/admin.go @@ -327,6 +327,16 @@ var ( mainHealthStore = &GraphQLHealthStore{} ) +func SchemaValidate(sch string) error { + schHandler, err := schema.NewHandler(sch, true) + if err != nil { + return err + } + + _, err = schema.FromString(schHandler.GQLSchema()) + return err +} + // GraphQLHealth is used to report the health status of a GraphQL server. // It is required for kubernetes probing. type GraphQLHealth struct { @@ -557,7 +567,7 @@ func getCurrentGraphQLSchema() (*gqlSchema, error) { } func generateGQLSchema(sch *gqlSchema) (schema.Schema, error) { - schHandler, err := schema.NewHandler(sch.Schema) + schHandler, err := schema.NewHandler(sch.Schema, false) if err != nil { return nil, err } diff --git a/graphql/admin/schema.go b/graphql/admin/schema.go index 7e1c23448f6..b38d7da897d 100644 --- a/graphql/admin/schema.go +++ b/graphql/admin/schema.go @@ -49,7 +49,9 @@ func resolveUpdateGQLSchema(ctx context.Context, m schema.Mutation) (*resolve.Re return resolve.EmptyResult(m, err), false } - schHandler, err := schema.NewHandler(input.Set.Schema) + // We just need to validate the schema. Schema is later set in `resetSchema()` when the schema + // is returned from badger. + schHandler, err := schema.NewHandler(input.Set.Schema, true) if err != nil { return resolve.EmptyResult(m, err), false } diff --git a/graphql/authorization/auth.go b/graphql/authorization/auth.go index 849b54b45d1..c6fb3a9f803 100644 --- a/graphql/authorization/auth.go +++ b/graphql/authorization/auth.go @@ -26,6 +26,7 @@ import ( "net/http" "regexp" "strings" + "sync" "github.com/golang/glog" "github.com/vektah/gqlparser/v2/gqlerror" @@ -47,7 +48,7 @@ const ( ) var ( - metainfo AuthMeta + authMeta = &AuthMeta{} ) type AuthMeta struct { @@ -57,6 +58,7 @@ type AuthMeta struct { Namespace string Algo string Audience []string + sync.RWMutex } // Validate required fields. @@ -84,17 +86,17 @@ func (a *AuthMeta) validate() error { return nil } -func Parse(schema string) (AuthMeta, error) { +func Parse(schema string) (*AuthMeta, error) { var meta AuthMeta authInfoIdx := strings.LastIndex(schema, AuthMetaHeader) if authInfoIdx == -1 { - return meta, nil + return nil, nil } authInfo := schema[authInfoIdx:] err := json.Unmarshal([]byte(authInfo[len(AuthMetaHeader):]), &meta) if err == nil { - return meta, meta.validate() + return &meta, meta.validate() } glog.Warningln("Falling back to parsing `Dgraph.Authorization` in old format." + @@ -113,13 +115,13 @@ func Parse(schema string) (AuthMeta, error) { authMetaRegex, err := regexp.Compile(`^#[\s]([^\s]+)[\s]+([^\s]+)[\s]+([^\s]+)[\s]+([^\s]+)[\s]+"([^\"]+)"`) if err != nil { - return meta, gqlerror.Errorf("JWT parsing failed: %v", err) + return nil, gqlerror.Errorf("JWT parsing failed: %v", err) } idx := authMetaRegex.FindAllStringSubmatchIndex(authInfo, -1) if len(idx) != 1 || len(idx[0]) != 12 || !strings.HasPrefix(authInfo, authInfo[idx[0][0]:idx[0][1]]) { - return meta, gqlerror.Errorf("Invalid `Dgraph.Authorization` format: %s", authInfo) + return nil, gqlerror.Errorf("Invalid `Dgraph.Authorization` format: %s", authInfo) } meta.Header = authInfo[idx[0][4]:idx[0][5]] @@ -127,46 +129,63 @@ func Parse(schema string) (AuthMeta, error) { meta.Algo = authInfo[idx[0][8]:idx[0][9]] meta.VerificationKey = authInfo[idx[0][10]:idx[0][11]] if meta.Algo == HMAC256 { - return meta, nil + return &meta, nil } if meta.Algo != RSA256 { - return meta, errors.Errorf( + return nil, errors.Errorf( "invalid jwt algorithm: found %s, but supported options are HS256 or RS256", meta.Algo) } - return meta, nil + return &meta, nil } -func ParseAuthMeta(schema string) error { - var err error - metainfo, err = Parse(schema) +func ParseAuthMeta(schema string) (*AuthMeta, error) { + metaInfo, err := Parse(schema) if err != nil { - return err + return nil, err } - if metainfo.Algo != RSA256 { - return err + if metaInfo.Algo != RSA256 { + return metaInfo, nil } // The jwt library internally uses `bytes.IndexByte(data, '\n')` to fetch new line and fails // if we have newline "\n" as ASCII value {92,110} instead of the actual ASCII value of 10. // To fix this we replace "\n" with new line's ASCII value. - bytekey := bytes.ReplaceAll([]byte(metainfo.VerificationKey), []byte{92, 110}, []byte{10}) + bytekey := bytes.ReplaceAll([]byte(metaInfo.VerificationKey), []byte{92, 110}, []byte{10}) - metainfo.RSAPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(bytekey) - return err + if metaInfo.RSAPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(bytekey); err != nil { + return nil, err + } + return metaInfo, nil } func GetHeader() string { - return metainfo.Header + authMeta.RLock() + defer authMeta.RUnlock() + return authMeta.Header } -func GetAuthMeta() AuthMeta { - return metainfo +func GetAuthMeta() *AuthMeta { + authMeta.RLock() + defer authMeta.RUnlock() + return authMeta +} + +func SetAuthMeta(m *AuthMeta) { + authMeta.Lock() + defer authMeta.Unlock() + + authMeta.VerificationKey = m.VerificationKey + authMeta.RSAPublicKey = m.RSAPublicKey + authMeta.Header = m.Header + authMeta.Namespace = m.Namespace + authMeta.Algo = m.Algo + authMeta.Audience = m.Audience } // AttachAuthorizationJwt adds any incoming JWT authorization data into the grpc context metadata. func AttachAuthorizationJwt(ctx context.Context, r *http.Request) context.Context { - authorizationJwt := r.Header.Get(metainfo.Header) + authorizationJwt := r.Header.Get(authMeta.Header) if authorizationJwt == "" { return ctx } @@ -198,7 +217,7 @@ func (c *CustomClaims) UnmarshalJSON(data []byte) error { } // Unmarshal the auth variables for a particular namespace. - if authValue, ok := result[metainfo.Namespace]; ok { + if authValue, ok := result[authMeta.Namespace]; ok { if authJson, ok := authValue.(string); ok { if err := json.Unmarshal([]byte(authJson), &c.AuthVariables); err != nil { return err @@ -217,13 +236,13 @@ func (c *CustomClaims) validateAudience() error { } // If there is an audience claim, but no value provided, fail - if metainfo.Audience == nil { + if authMeta.Audience == nil { return fmt.Errorf("audience value was expected but not provided") } var match = false for _, audStr := range c.Audience { - for _, expectedAudStr := range metainfo.Audience { + for _, expectedAudStr := range authMeta.Audience { if subtle.ConstantTimeCompare([]byte(audStr), []byte(expectedAudStr)) == 1 { match = true break @@ -253,7 +272,10 @@ func ExtractCustomClaims(ctx context.Context) (*CustomClaims, error) { } func validateJWTCustomClaims(jwtStr string) (*CustomClaims, error) { - if metainfo.Algo == "" { + authMeta.RLock() + defer authMeta.RUnlock() + + if authMeta.Algo == "" { return nil, fmt.Errorf( "jwt token cannot be validated because verification algorithm is not set") } @@ -264,17 +286,17 @@ func validateJWTCustomClaims(jwtStr string) (*CustomClaims, error) { token, err := jwt.ParseWithClaims(jwtStr, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { algo, _ := token.Header["alg"].(string) - if algo != metainfo.Algo { + if algo != authMeta.Algo { return nil, errors.Errorf("unexpected signing method: Expected %s Found %s", - metainfo.Algo, algo) + authMeta.Algo, algo) } if algo == HMAC256 { if _, ok := token.Method.(*jwt.SigningMethodHMAC); ok { - return []byte(metainfo.VerificationKey), nil + return []byte(authMeta.VerificationKey), nil } } else if algo == RSA256 { if _, ok := token.Method.(*jwt.SigningMethodRSA); ok { - return metainfo.RSAPublicKey, nil + return authMeta.RSAPublicKey, nil } } return nil, errors.Errorf("couldn't parse signing method from token header: %s", algo) diff --git a/graphql/e2e/auth/auth_test.go b/graphql/e2e/auth/auth_test.go index cc21dcde262..104597e5cd7 100644 --- a/graphql/e2e/auth/auth_test.go +++ b/graphql/e2e/auth/auth_test.go @@ -1093,11 +1093,7 @@ func TestMain(m *testing.M) { panic(err) } - authMeta, err := authorization.Parse(string(authSchema)) - if err != nil { - panic(err) - } - + authMeta := testutil.SetAuthMeta(string(authSchema)) metaInfo = &testutil.AuthMeta{ PublicKey: authMeta.VerificationKey, Namespace: authMeta.Namespace, diff --git a/graphql/e2e/common/fragment.go b/graphql/e2e/common/fragment.go index de91dc515ae..3b1b8ff2549 100644 --- a/graphql/e2e/common/fragment.go +++ b/graphql/e2e/common/fragment.go @@ -50,7 +50,8 @@ func fragmentInMutation(t *testing.T) { gqlResponse := addStarshipParams.ExecuteAsPost(t, graphqlURL) RequireNoGQLErrors(t, gqlResponse) - addStarshipExpected := `{"addStarship":{ + addStarshipExpected := ` + {"addStarship":{ "starship":[{ "name":"Millennium Falcon", "length":2 diff --git a/graphql/e2e/schema/schema_test.go b/graphql/e2e/schema/schema_test.go index c2f5bc2f66b..184454eb4cf 100644 --- a/graphql/e2e/schema/schema_test.go +++ b/graphql/e2e/schema/schema_test.go @@ -17,10 +17,12 @@ package schema import ( + "bytes" "context" "encoding/json" "fmt" "io/ioutil" + "net/http" "sync" "testing" "time" @@ -30,6 +32,7 @@ import ( "github.com/dgraph-io/dgraph/graphql/e2e/common" "github.com/dgraph-io/dgraph/testutil" "github.com/dgraph-io/dgraph/worker" + "github.com/dgraph-io/dgraph/x" "github.com/stretchr/testify/require" ) @@ -366,6 +369,84 @@ func TestUpdateGQLSchemaFields(t *testing.T) { require.Equal(t, string(generatedSchema), updateResp.UpdateGQLSchema.GQLSchema.GeneratedSchema) } +// verifyEmptySchema verifies that the schema is not set in the GraphQL server. +func verifyEmptySchema(t *testing.T) { + schema := getGQLSchema(t, groupOneAdminServer) + require.Empty(t, schema.Schema) +} + +func TestGQLSchemaValidate(t *testing.T) { + testCases := []struct { + schema string + errors x.GqlErrorList + valid bool + }{ + { + schema: ` + type Task @auth( + query: { rule: "{$USERROLE: { eq: \"USER\"}}" } + ) { + id: ID! + name: String! + occurrences: [TaskOccurrence] @hasInverse(field: task) + } + + type TaskOccurrence @auth( + query: { rule: "query { queryTaskOccurrence { task { id } } }" } + ) { + id: ID! + due: DateTime + comp: DateTime + task: Task @hasInverse(field: occurrences) + } + `, + valid: true, + }, + { + schema: ` + type X { + id: ID @dgraph(pred: "X.id") + name: String + } + type Y { + f1: String! @dgraph(pred:"~movie") + } + `, + errors: x.GqlErrorList{{Message: "input:3: Type X; Field id: has the @dgraph directive but fields of type ID can't have the @dgraph directive."}, {Message: "input:7: Type Y; Field f1 is of type String, but reverse predicate in @dgraph directive only applies to fields with object types."}}, + valid: false, + }, + } + + dg, err := testutil.DgraphClient(groupOnegRPC) + require.NoError(t, err) + testutil.DropAll(t, dg) + + validateUrl := groupOneAdminServer + "/schema/validate" + var response x.QueryResWithData + for _, tcase := range testCases { + resp, err := http.Post(validateUrl, "text/plain", bytes.NewBuffer([]byte(tcase.schema))) + require.NoError(t, err) + + decoder := json.NewDecoder(resp.Body) + err = decoder.Decode(&response) + require.NoError(t, err) + + // Verify that we only validate the schema and not set it. + verifyEmptySchema(t) + + if tcase.valid { + require.Equal(t, resp.StatusCode, http.StatusOK) + continue + } + require.Equal(t, resp.StatusCode, http.StatusBadRequest) + require.NotNil(t, response.Errors) + require.Equal(t, len(response.Errors), len(tcase.errors)) + for idx, err := range response.Errors { + require.Equal(t, err.Message, tcase.errors[idx].Message) + } + } +} + func updateGQLSchema(t *testing.T, schema, url string) *common.GraphQLResponse { req := &common.GraphQLParams{ Query: `mutation updateGQLSchema($sch: String!) { diff --git a/graphql/resolve/auth_test.go b/graphql/resolve/auth_test.go index fa2c79bc017..c303295fc72 100644 --- a/graphql/resolve/auth_test.go +++ b/graphql/resolve/auth_test.go @@ -162,6 +162,7 @@ func TestStringCustomClaim(t *testing.T) { require.NoError(t, err) test.LoadSchemaFromString(t, string(authSchema)) + testutil.SetAuthMeta(string(authSchema)) // Token with string custom claim // "https://xyz.io/jwt/claims": "{\"USER\": \"50950b40-262f-4b26-88a7-cbbb780b2176\", \"ROLE\": \"ADMIN\"}", @@ -187,6 +188,7 @@ func TestAudienceClaim(t *testing.T) { require.NoError(t, err) test.LoadSchemaFromString(t, string(authSchema)) + testutil.SetAuthMeta(string(authSchema)) // Verify that authorization information is set correctly. metainfo := authorization.GetAuthMeta() @@ -250,6 +252,7 @@ func TestJWTExpiry(t *testing.T) { require.NoError(t, err) test.LoadSchemaFromString(t, string(authSchema)) + testutil.SetAuthMeta(string(authSchema)) // Verify that authorization information is set correctly. metainfo := authorization.GetAuthMeta() @@ -707,6 +710,8 @@ func TestAuthQueryRewriting(t *testing.T) { strSchema := string(result) authMeta, err := authorization.Parse(strSchema) + authorization.SetAuthMeta(authMeta) + metaInfo := &testutil.AuthMeta{ PublicKey: authMeta.VerificationKey, Namespace: authMeta.Namespace, diff --git a/graphql/schema/schemagen.go b/graphql/schema/schemagen.go index c6bbe7ebb4b..d0f576a6c23 100644 --- a/graphql/schema/schemagen.go +++ b/graphql/schema/schemagen.go @@ -72,7 +72,7 @@ func (s *handler) DGSchema() string { return s.dgraphSchema } -func parseSecrets(sch string) (map[string]string, error) { +func parseSecrets(sch string) (map[string]string, *authorization.AuthMeta, error) { m := make(map[string]string) scanner := bufio.NewScanner(strings.NewReader(sch)) authSecret := "" @@ -81,7 +81,7 @@ func parseSecrets(sch string) (map[string]string, error) { if strings.HasPrefix(text, "# Dgraph.Authorization") { if authSecret != "" { - return nil, errors.Errorf("Dgraph.Authorization should be only be specified once in "+ + return nil, nil, errors.Errorf("Dgraph.Authorization should be only be specified once in "+ "a schema, found second mention: %v", text) } authSecret = text @@ -94,12 +94,12 @@ func parseSecrets(sch string) (map[string]string, error) { const doubleQuotesCode = 34 if len(parts) < 4 { - return nil, errors.Errorf("incorrect format for specifying Dgraph secret found for "+ + return nil, nil, errors.Errorf("incorrect format for specifying Dgraph secret found for "+ "comment: `%s`, it should be `# Dgraph.Secret key value`", text) } val := strings.Join(parts[3:], " ") if strings.Count(val, `"`) != 2 || val[0] != doubleQuotesCode || val[len(val)-1] != doubleQuotesCode { - return nil, errors.Errorf("incorrect format for specifying Dgraph secret found for "+ + return nil, nil, errors.Errorf("incorrect format for specifying Dgraph secret found for "+ "comment: `%s`, it should be `# Dgraph.Secret key value`", text) } @@ -109,23 +109,28 @@ func parseSecrets(sch string) (map[string]string, error) { } if err := scanner.Err(); err != nil { - return nil, errors.Wrapf(err, "while trying to parse secrets from schema file") + return nil, nil, errors.Wrapf(err, "while trying to parse secrets from schema file") } + if authSecret == "" { - return m, nil + return m, nil, nil } - err := authorization.ParseAuthMeta(authSecret) - return m, err + + metaInfo, err := authorization.ParseAuthMeta(authSecret) + if err != nil { + return nil, nil, err + } + return m, metaInfo, nil } // NewHandler processes the input schema. If there are no errors, it returns // a valid Handler, otherwise it returns nil and an error. -func NewHandler(input string) (Handler, error) { +func NewHandler(input string, validateOnly bool) (Handler, error) { if input == "" { return nil, gqlerror.Errorf("No schema specified") } - secrets, err := parseSecrets(input) + secrets, metaInfo, err := parseSecrets(input) if err != nil { return nil, err } @@ -201,7 +206,12 @@ func NewHandler(input string) (Handler, error) { return nil, gqlErrList } - headers := getAllowedHeaders(sch, defns) + var authHeader string + if metaInfo != nil { + authHeader = metaInfo.Header + } + + headers := getAllowedHeaders(sch, defns, authHeader) dgSchema := genDgSchema(sch, typesToComplete) completeSchema(sch, typesToComplete) @@ -209,17 +219,27 @@ func NewHandler(input string) (Handler, error) { return nil, gqlerror.Errorf("No query or mutation found in the generated schema") } + handler := &handler{ + input: input, + dgraphSchema: dgSchema, + completeSchema: sch, + originalDefs: defns, + } + + // Return early since we are only validating the schema. + if validateOnly { + return handler, nil + } + hc.Lock() hc.allowed = headers hc.secrets = schemaSecrets hc.Unlock() - return &handler{ - input: input, - dgraphSchema: dgSchema, - completeSchema: sch, - originalDefs: defns, - }, nil + if metaInfo != nil { + authorization.SetAuthMeta(metaInfo) + } + return handler, nil } type headersConfig struct { @@ -237,7 +257,7 @@ var hc = headersConfig{ allowed: x.AccessControlAllowedHeaders, } -func getAllowedHeaders(sch *ast.Schema, definitions []string) string { +func getAllowedHeaders(sch *ast.Schema, definitions []string, authHeader string) string { headers := make(map[string]struct{}) setHeaders := func(dir *ast.Directive) { @@ -278,8 +298,8 @@ func getAllowedHeaders(sch *ast.Schema, definitions []string) string { } // Add Auth Header to allowed headers list - if authorization.GetHeader() != "" { - finalHeaders = append(finalHeaders, authorization.GetHeader()) + if authHeader != "" { + finalHeaders = append(finalHeaders, authHeader) } allowed := x.AccessControlAllowedHeaders diff --git a/graphql/schema/schemagen_test.go b/graphql/schema/schemagen_test.go index 8c6dd0e0ece..dbcbebc2ec6 100644 --- a/graphql/schema/schemagen_test.go +++ b/graphql/schema/schemagen_test.go @@ -52,7 +52,7 @@ func TestDGSchemaGen(t *testing.T) { for _, sch := range schemas { t.Run(sch.Name, func(t *testing.T) { - schHandler, errs := NewHandler(sch.Input) + schHandler, errs := NewHandler(sch.Input, false) require.NoError(t, errs) dgSchema := schHandler.DGSchema() @@ -80,7 +80,7 @@ func TestSchemaString(t *testing.T) { str1, err := ioutil.ReadFile(inputFileName) require.NoError(t, err) - schHandler, errs := NewHandler(string(str1)) + schHandler, errs := NewHandler(string(str1), false) require.NoError(t, errs) newSchemaStr := schHandler.GQLSchema() @@ -111,7 +111,7 @@ func TestSchemas(t *testing.T) { t.Run("Valid Schemas", func(t *testing.T) { for _, sch := range tests["valid_schemas"] { t.Run(sch.Name, func(t *testing.T) { - schHandler, errlist := NewHandler(sch.Input) + schHandler, errlist := NewHandler(sch.Input, false) require.NoError(t, errlist, sch.Name) newSchemaStr := schHandler.GQLSchema() @@ -125,7 +125,7 @@ func TestSchemas(t *testing.T) { t.Run("Invalid Schemas", func(t *testing.T) { for _, sch := range tests["invalid_schemas"] { t.Run(sch.Name, func(t *testing.T) { - _, errlist := NewHandler(sch.Input) + _, errlist := NewHandler(sch.Input, false) if diff := cmp.Diff(sch.Errlist, errlist); diff != "" { t.Errorf("error mismatch (-want +got):\n%s", diff) } @@ -151,7 +151,7 @@ func TestAuthSchemas(t *testing.T) { t.Run("Valid Schemas", func(t *testing.T) { for _, sch := range tests["valid_schemas"] { t.Run(sch.Name, func(t *testing.T) { - schHandler, errlist := NewHandler(sch.Input) + schHandler, errlist := NewHandler(sch.Input, false) require.NoError(t, errlist, sch.Name) _, authError := FromString(schHandler.GQLSchema()) @@ -163,7 +163,7 @@ func TestAuthSchemas(t *testing.T) { t.Run("Invalid Schemas", func(t *testing.T) { for _, sch := range tests["invalid_schemas"] { t.Run(sch.Name, func(t *testing.T) { - schHandler, errlist := NewHandler(sch.Input) + schHandler, errlist := NewHandler(sch.Input, false) require.NoError(t, errlist, sch.Name) _, authError := FromString(schHandler.GQLSchema()) @@ -302,7 +302,7 @@ func TestOnlyCorrectSearchArgsWork(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - _, errlist := NewHandler(test.schema) + _, errlist := NewHandler(test.schema, false) require.Len(t, errlist, test.expectedErrors, "every field in this test applies @search wrongly and should raise an error") }) diff --git a/graphql/schema/wrappers_test.go b/graphql/schema/wrappers_test.go index bffeb40b707..5537480ed12 100644 --- a/graphql/schema/wrappers_test.go +++ b/graphql/schema/wrappers_test.go @@ -22,7 +22,6 @@ import ( "strings" "testing" - "github.com/dgraph-io/dgraph/graphql/authorization" "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "github.com/stretchr/testify/require" @@ -84,7 +83,7 @@ type Starship { length: Float }` - schHandler, errs := NewHandler(schemaStr) + schHandler, errs := NewHandler(schemaStr, false) require.NoError(t, errs) sch, err := FromString(schHandler.GQLSchema()) require.NoError(t, err) @@ -206,7 +205,7 @@ func TestDgraphMapping_WithDirectives(t *testing.T) { length: Float }` - schHandler, errs := NewHandler(schemaStr) + schHandler, errs := NewHandler(schemaStr, false) require.NoError(t, errs) sch, err := FromString(schHandler.GQLSchema()) require.NoError(t, err) @@ -730,7 +729,7 @@ func TestGraphQLQueryInCustomHTTPConfig(t *testing.T) { for _, tcase := range tests { t.Run(tcase.Name, func(t *testing.T) { - schHandler, errs := NewHandler(tcase.GQLSchema) + schHandler, errs := NewHandler(tcase.GQLSchema, false) require.NoError(t, errs) sch, err := FromString(schHandler.GQLSchema()) require.NoError(t, err) @@ -770,7 +769,7 @@ func TestGraphQLQueryInCustomHTTPConfig(t *testing.T) { c, err := field.CustomHTTPConfig() require.NoError(t, err) - remoteSchemaHandler, errs := NewHandler(tcase.RemoteSchema) + remoteSchemaHandler, errs := NewHandler(tcase.RemoteSchema, false) require.NoError(t, errs) remoteSchema, err := FromString(remoteSchemaHandler.GQLSchema()) require.NoError(t, err) @@ -830,7 +829,7 @@ func TestAllowedHeadersList(t *testing.T) { } for _, test := range tcases { t.Run(test.name, func(t *testing.T) { - schHandler, errs := NewHandler(test.schemaStr) + schHandler, errs := NewHandler(test.schemaStr, false) require.NoError(t, errs) _, err := FromString(schHandler.GQLSchema()) require.NoError(t, err) @@ -913,7 +912,7 @@ func TestCustomLogicHeaders(t *testing.T) { } for _, test := range tcases { t.Run(test.name, func(t *testing.T) { - _, err := NewHandler(test.schemaStr) + _, err := NewHandler(test.schemaStr, false) require.EqualError(t, err, test.err.Error()) }) } @@ -1069,15 +1068,15 @@ func TestParseSecrets(t *testing.T) { } for _, test := range tcases { t.Run(test.name, func(t *testing.T) { - s, err := parseSecrets(test.schemaStr) + s, authMeta, err := parseSecrets(test.schemaStr) if test.err != nil || err != nil { require.EqualError(t, err, test.err.Error()) return } - require.Equal(t, test.expectedSecrets, s) if test.expectedAuthHeader != "" { - require.Equal(t, test.expectedAuthHeader, authorization.GetHeader()) + require.NotNil(t, authMeta) + require.Equal(t, test.expectedAuthHeader, authMeta.Header) } }) } diff --git a/graphql/test/test.go b/graphql/test/test.go index 62210bb4d51..b1ece7a01ca 100644 --- a/graphql/test/test.go +++ b/graphql/test/test.go @@ -56,7 +56,7 @@ func LoadSchemaFromFile(t *testing.T, gqlFile string) schema.Schema { } func LoadSchemaFromString(t *testing.T, sch string) schema.Schema { - handler, err := schema.NewHandler(string(sch)) + handler, err := schema.NewHandler(string(sch), false) requireNoGQLErrors(t, err) return LoadSchema(t, handler.GQLSchema()) diff --git a/testutil/graphql.go b/testutil/graphql.go index 8e68f13790d..2582ad53460 100644 --- a/testutil/graphql.go +++ b/testutil/graphql.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "encoding/json" + "github.com/dgraph-io/dgraph/graphql/authorization" "io/ioutil" "net/http" "testing" @@ -222,3 +223,12 @@ func AppendAuthInfo(schema []byte, algo, publicKeyFile string) ([]byte, error) { authInfo := `# Dgraph.Authorization {"VerificationKey":"` + string(keyData) + `","Header":"X-Test-Auth","Namespace":"https://xyz.io/jwt/claims","Algo":"RS256","Audience":["aud1","63do0q16n6ebjgkumu05kkeian","aud5"]}` return append(schema, []byte(authInfo)...), nil } + +func SetAuthMeta(strSchema string) *authorization.AuthMeta { + authMeta, err := authorization.Parse(strSchema) + if err != nil { + panic(err) + } + authorization.SetAuthMeta(authMeta) + return authMeta +} diff --git a/x/x.go b/x/x.go index d8d0ab28450..7412975be92 100644 --- a/x/x.go +++ b/x/x.go @@ -302,6 +302,22 @@ func SetStatus(w http.ResponseWriter, code, msg string) { } } +func SetStatusWithErrors(w http.ResponseWriter, code string, errs []string) { + var qr queryRes + ext := make(map[string]interface{}) + ext["code"] = code + for _, err := range errs { + qr.Errors = append(qr.Errors, &GqlError{Message: err, Extensions: ext}) + } + if js, err := json.Marshal(qr); err == nil { + if _, err := w.Write(js); err != nil { + glog.Errorf("Error while writing: %+v", err) + } + } else { + Panic(errors.Errorf("Unable to marshal: %+v", qr)) + } +} + // SetHttpStatus is similar to SetStatus but sets a proper HTTP status code // in the response instead of always returning HTTP 200 (OK). func SetHttpStatus(w http.ResponseWriter, code int, msg string) {