Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(/commit): protect the commit endpoint via acl #7608

Merged
merged 9 commits into from
Mar 24, 2021
Merged
23 changes: 11 additions & 12 deletions dgraph/cmd/alpha/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,11 @@ import (

"github.com/dgraph-io/dgraph/graphql/admin"

"github.com/dgraph-io/dgo/v200"
"github.com/dgraph-io/dgo/v200/protos/api"
"github.com/dgraph-io/dgraph/edgraph"
"github.com/dgraph-io/dgraph/gql"
"github.com/dgraph-io/dgraph/graphql/schema"
"github.com/dgraph-io/dgraph/query"
"github.com/dgraph-io/dgraph/worker"
"github.com/dgraph-io/dgraph/x"
"github.com/gogo/protobuf/jsonpb"
"github.com/golang/glog"
Expand Down Expand Up @@ -471,17 +469,18 @@ func commitHandler(w http.ResponseWriter, r *http.Request) {
return
}

ctx := x.AttachAccessJwt(context.Background(), r)
var response map[string]interface{}
if abort {
response, err = handleAbort(startTs)
response, err = handleAbort(ctx, startTs)
} else {
// Keys are sent as an array in the body.
reqText := readRequest(w, r)
if reqText == nil {
return
}

response, err = handleCommit(startTs, reqText)
response, err = handleCommit(ctx, startTs, reqText)
}
if err != nil {
x.SetStatus(w, x.ErrorInvalidRequest, err.Error())
Expand All @@ -497,27 +496,28 @@ func commitHandler(w http.ResponseWriter, r *http.Request) {
_, _ = x.WriteResponse(w, r, js)
}

func handleAbort(startTs uint64) (map[string]interface{}, error) {
func handleAbort(ctx context.Context, startTs uint64) (map[string]interface{}, error) {
tc := &api.TxnContext{
StartTs: startTs,
Aborted: true,
}

_, err := worker.CommitOverNetwork(context.Background(), tc)
switch err {
case dgo.ErrAborted:
tctx, err := (&edgraph.Server{}).CommitOrAbort(ctx, tc)
switch {
case tctx.Aborted:
return map[string]interface{}{
"code": x.Success,
"message": "Done",
}, nil
case nil:
case err == nil:
return nil, errors.Errorf("transaction could not be aborted")
default:
return nil, err
}
}

func handleCommit(startTs uint64, reqText []byte) (map[string]interface{}, error) {
func handleCommit(ctx context.Context, startTs uint64, reqText []byte) (map[string]interface{},
error) {
tc := &api.TxnContext{
StartTs: startTs,
}
Expand All @@ -540,14 +540,13 @@ func handleCommit(startTs uint64, reqText []byte) (map[string]interface{}, error
tc.Preds = reqMap["preds"]
}

cts, err := worker.CommitOverNetwork(context.Background(), tc)
tc, err := (&edgraph.Server{}).CommitOrAbort(ctx, tc)
if err != nil {
return nil, err
}

resp := &api.Response{}
resp.Txn = tc
resp.Txn.CommitTs = cts
e := query.Extensions{
Txn: resp.Txn,
}
Expand Down
35 changes: 34 additions & 1 deletion edgraph/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,33 @@ func authorizeRequest(ctx context.Context, qc *queryContext) error {
return nil
}

func validateNamespace(ctx context.Context, preds []string) error {
ns, err := x.ExtractJWTNamespace(ctx)
if err != nil {
return err
}

// Do a basic validation that all the predicates passed in transaction context matches the
// claimed namespace and user is not accidently commiting a transaction that it did not create.
for _, pred := range preds {
// Format for Preds in TxnContext is gid-<namespace><pred> (see fillPreds in posting pkg)
splits := strings.Split(pred, "-")
if len(splits) < 2 {
return errors.Errorf("Unable to find group id in %s", pred)
}
pred = strings.Join(splits[1:], "-")
if len(pred) < 8 {
return errors.Errorf("found invalid pred %s of length < 8 in transaction context", pred)
}
if parsedNs := x.ParseNamespace(pred); parsedNs != ns {
return errors.Errorf("Please login into correct namespace. "+
"Currently logged in namespace %#x", ns)
}
}

return nil
}

// CommitOrAbort commits or aborts a transaction.
func (s *Server) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) {
ctx, span := otrace.StartSpan(ctx, "Server.CommitOrAbort")
Expand All @@ -1480,6 +1507,12 @@ func (s *Server) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.Tx
return &api.TxnContext{}, err
}

if x.WorkerConfig.AclEnabled {
if err := validateNamespace(ctx, tc.Preds); err != nil {
return &api.TxnContext{}, err
}
}

tctx := &api.TxnContext{}
if tc.StartTs == 0 {
return &api.TxnContext{}, errors.Errorf(
Expand All @@ -1492,11 +1525,11 @@ func (s *Server) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.Tx
if err == dgo.ErrAborted {
// If err returned is dgo.ErrAborted and tc.Aborted was set, that means the client has
// aborted the transaction by calling txn.Discard(). Hence return a nil error.
tctx.Aborted = true
if tc.Aborted {
return tctx, nil
}

tctx.Aborted = true
return tctx, status.Errorf(codes.Aborted, err.Error())
}
tctx.StartTs = tc.StartTs
Expand Down