Skip to content

Commit

Permalink
fix(login): fix login based on refresh token logic (#7637)
Browse files Browse the repository at this point in the history
We were fetching the namespace from the api.LoginRequest.Namespace even if the user has sent the refresh token. We should not do that and extract the namespace from the refresh token itself.
  • Loading branch information
NamanJain8 authored and aman-bansal committed Mar 25, 2021
1 parent f091ce5 commit 21509c9
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 22 deletions.
52 changes: 31 additions & 21 deletions edgraph/access_ee.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ func (s *Server) Login(ctx context.Context,
glog.Infof("%s logged in successfully", user.UserID)

resp := &api.Response{}
accessJwt, err := getAccessJwt(user.UserID, user.Groups, request.Namespace)
accessJwt, err := getAccessJwt(user.UserID, user.Groups, user.Namespace)
if err != nil {
errMsg := fmt.Sprintf("unable to get access jwt (userid=%s,addr=%s):%v",
user.UserID, addr, err)
glog.Errorf(errMsg)
return nil, errors.Errorf(errMsg)
}

refreshJwt, err := getRefreshJwt(user.UserID, request.Namespace)
refreshJwt, err := getRefreshJwt(user.UserID, user.Namespace)
if err != nil {
errMsg := fmt.Sprintf("unable to get refresh jwt (userid=%s,addr=%s):%v",
user.UserID, addr, err)
Expand Down Expand Up @@ -122,9 +122,6 @@ func (s *Server) authenticateLogin(ctx context.Context, request *api.LoginReques
return nil, errors.Wrapf(err, "invalid login request")
}

// In case of login, we can't extract namespace from JWT because we have not yet given JWT
// to the user, so the login request should contain the namespace, which is then set to ctx.
ctx = x.AttachNamespace(ctx, request.Namespace)
var user *acl.User
if len(request.RefreshToken) > 0 {
userData, err := validateToken(request.RefreshToken)
Expand All @@ -133,7 +130,8 @@ func (s *Server) authenticateLogin(ctx context.Context, request *api.LoginReques
request.RefreshToken)
}

userId := userData[0]
userId := userData.userId
ctx = x.AttachNamespace(ctx, userData.namespace)
user, err = authorizeUser(ctx, userId, "")
if err != nil {
return nil, errors.Wrapf(err, "while querying user with id %v", userId)
Expand All @@ -143,10 +141,15 @@ func (s *Server) authenticateLogin(ctx context.Context, request *api.LoginReques
return nil, errors.Errorf("unable to authenticate: invalid credentials")
}

user.Namespace = userData.namespace
glog.Infof("Authenticated user %s through refresh token", userId)
return user, nil
}

// In case of login, we can't extract namespace from JWT because we have not yet given JWT
// to the user, so the login request should contain the namespace, which is then set to ctx.
ctx = x.AttachNamespace(ctx, request.Namespace)

// authorize the user using password
var err error
user, err = authorizeUser(ctx, request.Userid, request.Password)
Expand All @@ -161,13 +164,20 @@ func (s *Server) authenticateLogin(ctx context.Context, request *api.LoginReques
if !user.PasswordMatch {
return nil, x.ErrorInvalidLogin
}
user.Namespace = request.Namespace
return user, nil
}

type userData struct {
namespace uint64
userId string
groupIds []string
}

// validateToken verifies the signature and expiration of the jwt, and if validation passes,
// returns a slice of strings, where the first element is the extracted userId
// and the rest are groupIds encoded in the jwt.
func validateToken(jwtStr string) ([]string, error) {
func validateToken(jwtStr string) (*userData, error) {
claims, err := x.ParseJWT(jwtStr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -203,7 +213,7 @@ func validateToken(jwtStr string) ([]string, error) {
groupIds = append(groupIds, groupId)
}
}
return append([]string{userId}, groupIds...), nil
return &userData{namespace: uint64(namespace), userId: userId, groupIds: groupIds}, nil
}

// validateLoginRequest validates that the login request has either the refresh token or the
Expand Down Expand Up @@ -574,7 +584,7 @@ func upsertGroot(ctx context.Context, passwd string) error {
}

// extract the userId, groupIds from the accessJwt in the context
func extractUserAndGroups(ctx context.Context) ([]string, error) {
func extractUserAndGroups(ctx context.Context) (*userData, error) {
accessJwt, err := x.ExtractJwt(ctx)
if err != nil {
return nil, err
Expand All @@ -587,15 +597,15 @@ type authPredResult struct {
blocked map[string]struct{}
}

func authorizePreds(ctx context.Context, userData, preds []string,
func authorizePreds(ctx context.Context, userData *userData, preds []string,
aclOp *acl.Operation) (*authPredResult, error) {

ns, err := x.ExtractNamespace(ctx)
if err != nil {
return nil, errors.Wrapf(err, "While authorizing preds")
}
userId := userData[0]
groupIds := userData[1:]
userId := userData.userId
groupIds := userData.groupIds
blockedPreds := make(map[string]struct{})
for _, pred := range preds {
nsPred := x.NamespaceAttr(ns, pred)
Expand Down Expand Up @@ -662,8 +672,8 @@ func authorizeAlter(ctx context.Context, op *api.Operation) error {
return status.Error(codes.Unauthenticated, err.Error())
}

userId = userData[0]
groupIds = userData[1:]
userId = userData.userId
groupIds = userData.groupIds

if x.IsGuardian(groupIds) {
// Members of guardian group are allowed to alter anything.
Expand Down Expand Up @@ -776,8 +786,8 @@ func authorizeMutation(ctx context.Context, gmu *gql.Mutation) error {
return status.Error(codes.Unauthenticated, err.Error())
}

userId = userData[0]
groupIds = userData[1:]
userId = userData.userId
groupIds = userData.groupIds

if x.IsGuardian(groupIds) {
// Members of guardians group are allowed to mutate anything
Expand Down Expand Up @@ -929,8 +939,8 @@ func authorizeQuery(ctx context.Context, parsedReq *gql.Result, graphql bool) er
return nil, nil, status.Error(codes.Unauthenticated, err.Error())
}

userId = userData[0]
groupIds = userData[1:]
userId = userData.userId
groupIds = userData.groupIds

if x.IsGuardian(groupIds) {
// Members of guardian groups are allowed to query anything.
Expand Down Expand Up @@ -1018,7 +1028,7 @@ func authorizeSchemaQuery(ctx context.Context, er *query.ExecutionResult) error
return nil, status.Error(codes.Unauthenticated, err.Error())
}

groupIds := userData[1:]
groupIds := userData.groupIds
if x.IsGuardian(groupIds) {
// Members of guardian groups are allowed to query anything.
return nil, nil
Expand Down Expand Up @@ -1094,8 +1104,8 @@ func AuthorizeGuardians(ctx context.Context) error {
case err != nil:
return status.Error(codes.Unauthenticated, err.Error())
default:
userId := userData[0]
groupIds := userData[1:]
userId := userData.userId
groupIds := userData.groupIds

if !x.IsGuardian(groupIds) {
// Deny access for members of non-guardian groups
Expand Down
1 change: 1 addition & 0 deletions ee/acl/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ type User struct {
Uid string `json:"uid"`
UserID string `json:"dgraph.xid"`
Password string `json:"dgraph.password"`
Namespace uint64 `json:"namespace"`
PasswordMatch bool `json:"password_match"`
Groups []Group `json:"dgraph.user.group"`
}
Expand Down
21 changes: 20 additions & 1 deletion systest/multi-tenancy/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,26 @@ func TestPersistentQuery(t *testing.T) {
require.Contains(t, resp.Errors[0].Message, "no accessJwt available")
}

func TestTokenExpired(t *testing.T) {
prepare(t)
galaxyToken := testutil.Login(t,
&testutil.LoginParams{UserID: "groot", Passwd: "password", Namespace: x.GalaxyNamespace})

// Create a new namespace
ns, err := testutil.CreateNamespaceWithRetry(t, galaxyToken)
require.NoError(t, err)
token := testutil.Login(t,
&testutil.LoginParams{UserID: "groot", Passwd: "password", Namespace: ns})

// Relogin using refresh JWT.
token = testutil.Login(t,
&testutil.LoginParams{RefreshJwt: token.RefreshToken})
_, err = testutil.CreateNamespaceWithRetry(t, token)
require.Error(t, err)
require.Contains(t, err.Error(), "Only guardian of galaxy is allowed to do this operation")
}

func TestMain(m *testing.M) {
fmt.Printf("Using adminEndpoint : %s for multy-tenancy test.\n", testutil.AdminUrl())
fmt.Printf("Using adminEndpoint : %s for multi-tenancy test.\n", testutil.AdminUrl())
os.Exit(m.Run())
}

0 comments on commit 21509c9

Please sign in to comment.