Skip to content

Commit

Permalink
fix(/commit): protect the commit endpoint via acl (#7608)
Browse files Browse the repository at this point in the history
/commit endpoint was not ACL protected. In a multi-tenant system, it could be disastrous where a malicious user can commit or abort the transactions of any namespace. This PR partially fixes the issue.
  • Loading branch information
NamanJain8 authored Mar 24, 2021
1 parent d5299b9 commit 6c0e3aa
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
23 changes: 11 additions & 12 deletions dgraph/cmd/alpha/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,11 @@ import (

"github.com/dgraph-io/dgraph/graphql/admin"

"github.com/dgraph-io/dgo/v200"
"github.com/dgraph-io/dgo/v200/protos/api"
"github.com/dgraph-io/dgraph/edgraph"
"github.com/dgraph-io/dgraph/gql"
"github.com/dgraph-io/dgraph/graphql/schema"
"github.com/dgraph-io/dgraph/query"
"github.com/dgraph-io/dgraph/worker"
"github.com/dgraph-io/dgraph/x"
"github.com/gogo/protobuf/jsonpb"
"github.com/golang/glog"
Expand Down Expand Up @@ -471,17 +469,18 @@ func commitHandler(w http.ResponseWriter, r *http.Request) {
return
}

ctx := x.AttachAccessJwt(context.Background(), r)
var response map[string]interface{}
if abort {
response, err = handleAbort(startTs)
response, err = handleAbort(ctx, startTs)
} else {
// Keys are sent as an array in the body.
reqText := readRequest(w, r)
if reqText == nil {
return
}

response, err = handleCommit(startTs, reqText)
response, err = handleCommit(ctx, startTs, reqText)
}
if err != nil {
x.SetStatus(w, x.ErrorInvalidRequest, err.Error())
Expand All @@ -497,27 +496,28 @@ func commitHandler(w http.ResponseWriter, r *http.Request) {
_, _ = x.WriteResponse(w, r, js)
}

func handleAbort(startTs uint64) (map[string]interface{}, error) {
func handleAbort(ctx context.Context, startTs uint64) (map[string]interface{}, error) {
tc := &api.TxnContext{
StartTs: startTs,
Aborted: true,
}

_, err := worker.CommitOverNetwork(context.Background(), tc)
switch err {
case dgo.ErrAborted:
tctx, err := (&edgraph.Server{}).CommitOrAbort(ctx, tc)
switch {
case tctx.Aborted:
return map[string]interface{}{
"code": x.Success,
"message": "Done",
}, nil
case nil:
case err == nil:
return nil, errors.Errorf("transaction could not be aborted")
default:
return nil, err
}
}

func handleCommit(startTs uint64, reqText []byte) (map[string]interface{}, error) {
func handleCommit(ctx context.Context, startTs uint64, reqText []byte) (map[string]interface{},
error) {
tc := &api.TxnContext{
StartTs: startTs,
}
Expand All @@ -540,14 +540,13 @@ func handleCommit(startTs uint64, reqText []byte) (map[string]interface{}, error
tc.Preds = reqMap["preds"]
}

cts, err := worker.CommitOverNetwork(context.Background(), tc)
tc, err := (&edgraph.Server{}).CommitOrAbort(ctx, tc)
if err != nil {
return nil, err
}

resp := &api.Response{}
resp.Txn = tc
resp.Txn.CommitTs = cts
e := query.Extensions{
Txn: resp.Txn,
}
Expand Down
35 changes: 34 additions & 1 deletion edgraph/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,33 @@ func authorizeRequest(ctx context.Context, qc *queryContext) error {
return nil
}

func validateNamespace(ctx context.Context, preds []string) error {
ns, err := x.ExtractJWTNamespace(ctx)
if err != nil {
return err
}

// Do a basic validation that all the predicates passed in transaction context matches the
// claimed namespace and user is not accidently commiting a transaction that it did not create.
for _, pred := range preds {
// Format for Preds in TxnContext is gid-<namespace><pred> (see fillPreds in posting pkg)
splits := strings.Split(pred, "-")
if len(splits) < 2 {
return errors.Errorf("Unable to find group id in %s", pred)
}
pred = strings.Join(splits[1:], "-")
if len(pred) < 8 {
return errors.Errorf("found invalid pred %s of length < 8 in transaction context", pred)
}
if parsedNs := x.ParseNamespace(pred); parsedNs != ns {
return errors.Errorf("Please login into correct namespace. "+
"Currently logged in namespace %#x", ns)
}
}

return nil
}

// CommitOrAbort commits or aborts a transaction.
func (s *Server) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) {
ctx, span := otrace.StartSpan(ctx, "Server.CommitOrAbort")
Expand All @@ -1480,6 +1507,12 @@ func (s *Server) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.Tx
return &api.TxnContext{}, err
}

if x.WorkerConfig.AclEnabled {
if err := validateNamespace(ctx, tc.Preds); err != nil {
return &api.TxnContext{}, err
}
}

tctx := &api.TxnContext{}
if tc.StartTs == 0 {
return &api.TxnContext{}, errors.Errorf(
Expand All @@ -1492,11 +1525,11 @@ func (s *Server) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.Tx
if err == dgo.ErrAborted {
// If err returned is dgo.ErrAborted and tc.Aborted was set, that means the client has
// aborted the transaction by calling txn.Discard(). Hence return a nil error.
tctx.Aborted = true
if tc.Aborted {
return tctx, nil
}

tctx.Aborted = true
return tctx, status.Errorf(codes.Aborted, err.Error())
}
tctx.StartTs = tc.StartTs
Expand Down

0 comments on commit 6c0e3aa

Please sign in to comment.