Skip to content

Commit

Permalink
Reduce allocations, chapter III (#7222)
Browse files Browse the repository at this point in the history
My last PR for a while in the ongoing "reduce allocations in eval" quest.
Motivated initially mostly to speed up `regal lint`, but most of the changes
here positively impact evaluation performance for most policies.

The changes with the highest impact in this PR:

* Use `sync.Pool`s to avoid the most costly allocations, including heavy `*eval`
  pointers created each time a child or closure scope is evaluated.
* When tracing is disabled, avoid variable escaping to heap in `evalStep` function
  whose value is only read when tracing is enabled.
* Save one allocation per iteration in `walkNoPath` by reusing an AST array instead
  of creating a new one for each call.

Also a few minor fixes here and there which either fixed some correctness issue, or
had a measurable (although minor) positive impact on performance.

**regal lint bundle (main)**
```
BenchmarkRegalLintingItself-10   1	2015560750 ns/op	4335625360 B/op	83728460 allocs/op
```

**regal lint bundle (now)**
```
BenchmarkRegalLintingItself-10   1	1828754125 ns/op	3541027496 B/op	70080568 allocs/op
```

About 10% faster eval, with almost a gigabyte less memory allocated, and 13 million+ fewer
allocations performed.

Another topic discussed recently has been the cost of calling custom functions in hot paths.
While this PR doesn't address that problem fully, the benefits of the change are still quite
noticeable. A benchmark for that case specifically is also included in the PR, and the change
compared to main is noted below:

**main**
```
BenchmarkCustomFunctionInHotPath-10    	55  18543908 ns/op  20821043 B/op  284611 allocs/op
```

**pr**
```
BenchmarkCustomFunctionInHotPath-10    	73  16247587 ns/op  13048108 B/op  228406 allocs/op
```

It's worth noting however that this benchmark benefits "unfairly" from the improvements made
in the `walkNoPath` function, and perhaps more so than custom function evaluation getting
that much more efficient.

Signed-off-by: Anders Eknert <[email protected]>
  • Loading branch information
anderseknert authored Dec 18, 2024
1 parent 2ddcade commit 50b5ee5
Show file tree
Hide file tree
Showing 17 changed files with 8,124 additions and 189 deletions.
28 changes: 18 additions & 10 deletions v1/ast/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ type typeChecker struct {
// newTypeChecker returns a new typeChecker object that has no errors.
func newTypeChecker() *typeChecker {
return &typeChecker{
builtins: make(map[string]*Builtin),
schemaTypes: make(map[string]types.Type),
exprCheckers: map[string]exprChecker{
"eq": checkExprEq,
},
Expand All @@ -62,6 +60,7 @@ func (tc *typeChecker) copy() *typeChecker {
return newTypeChecker().
WithVarRewriter(tc.varRewriter).
WithSchemaSet(tc.ss).
WithSchemaTypes(tc.schemaTypes).
WithAllowNet(tc.allowNet).
WithInputType(tc.input).
WithAllowUndefinedFunctionCalls(tc.allowUndefinedFuncs).
Expand All @@ -84,6 +83,11 @@ func (tc *typeChecker) WithSchemaSet(ss *SchemaSet) *typeChecker {
return tc
}

// WithSchemaTypes sets the type checker's schemaTypes map and returns the
// receiver so that calls can be chained. The map is stored as-is rather than
// copied, so the checker shares it with the caller.
func (tc *typeChecker) WithSchemaTypes(schemaTypes map[string]types.Type) *typeChecker {
tc.schemaTypes = schemaTypes
return tc
}

func (tc *typeChecker) WithAllowNet(hosts []string) *typeChecker {
tc.allowNet = hosts
return tc
Expand Down Expand Up @@ -124,6 +128,7 @@ func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) {

errors := []*Error{}
env = tc.newEnv(env)
vis := newRefChecker(env, tc.varRewriter)

WalkExprs(body, func(expr *Expr) bool {

Expand All @@ -134,7 +139,8 @@ func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) {

hasClosureErrors := len(closureErrs) > 0

vis := newRefChecker(env, tc.varRewriter)
// reset errors from previous iteration
vis.errs = nil
NewGenericVisitor(vis.Visit).Walk(expr)
for _, err := range vis.errs {
errors = append(errors, err)
Expand Down Expand Up @@ -200,6 +206,10 @@ func (tc *typeChecker) checkClosures(env *TypeEnv, expr *Expr) Errors {
}

func (tc *typeChecker) getSchemaType(schemaAnnot *SchemaAnnotation, rule *Rule) (types.Type, *Error) {
if tc.schemaTypes == nil {
tc.schemaTypes = make(map[string]types.Type)
}

if refType, exists := tc.schemaTypes[schemaAnnot.Schema.String()]; exists {
return refType, nil
}
Expand Down Expand Up @@ -353,7 +363,7 @@ func (tc *typeChecker) checkExpr(env *TypeEnv, expr *Expr) *Error {
// If the type checker wasn't provided with a required capabilities
// structure then just skip. In some cases, type checking might be run
// without the need to record what builtins are required.
if tc.required != nil {
if tc.required != nil && tc.builtins != nil {
if bi, ok := tc.builtins[operator]; ok {
tc.required.addBuiltinSorted(bi)
}
Expand Down Expand Up @@ -433,14 +443,13 @@ func (tc *typeChecker) checkExprBuiltin(env *TypeEnv, expr *Expr) *Error {
func checkExprEq(env *TypeEnv, expr *Expr) *Error {

pre := getArgTypes(env, expr.Operands())
exp := Equality.Decl.FuncArgs()

if len(pre) < len(exp.Args) {
return newArgError(expr.Location, expr.Operator(), "too few arguments", pre, exp)
if len(pre) < Equality.Decl.Arity() {
return newArgError(expr.Location, expr.Operator(), "too few arguments", pre, Equality.Decl.FuncArgs())
}

if len(exp.Args) < len(pre) {
return newArgError(expr.Location, expr.Operator(), "too many arguments", pre, exp)
if Equality.Decl.Arity() < len(pre) {
return newArgError(expr.Location, expr.Operator(), "too many arguments", pre, Equality.Decl.FuncArgs())
}

a, b := expr.Operand(0), expr.Operand(1)
Expand Down Expand Up @@ -684,7 +693,6 @@ func rewriteVarsNop(node Ref) Ref {
}

func newRefChecker(env *TypeEnv, f varRewriter) *refChecker {

if f == nil {
f = rewriteVarsNop
}
Expand Down
18 changes: 15 additions & 3 deletions v1/ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,16 @@ func (expr *Expr) Complement() *Expr {
return &cpy
}

// ComplementNoWith produces a shallow copy of expr whose Negated flag is
// inverted and whose With modifiers are dropped. The result matches what
// .Complement().NoWith() yields, but only a single copy is made.
func (expr *Expr) ComplementNoWith() *Expr {
	result := *expr
	result.With = nil
	result.Negated = !expr.Negated
	return &result
}

// Equal returns true if this Expr equals the other Expr.
func (expr *Expr) Equal(other *Expr) bool {
return expr.Compare(other) == 0
Expand Down Expand Up @@ -1441,9 +1451,11 @@ func (expr *Expr) sortOrder() int {
func (expr *Expr) CopyWithoutTerms() *Expr {
cpy := *expr

cpy.With = make([]*With, len(expr.With))
for i := range expr.With {
cpy.With[i] = expr.With[i].Copy()
if expr.With != nil {
cpy.With = make([]*With, len(expr.With))
for i := range expr.With {
cpy.With[i] = expr.With[i].Copy()
}
}

return &cpy
Expand Down
21 changes: 14 additions & 7 deletions v1/ast/term.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,30 +71,37 @@ func InterfaceToValue(x interface{}) (Value, error) {
return intNumber(x), nil
case string:
return String(x), nil
case []interface{}:
r := make([]*Term, len(x))
case []any:
r := util.NewPtrSlice[Term](len(x))
for i, e := range x {
e, err := InterfaceToValue(e)
if err != nil {
return nil, err
}
r[i] = &Term{Value: e}
r[i].Value = e
}
return NewArray(r...), nil
case map[string]interface{}:
r := newobject(len(x))
case map[string]any:
kvs := util.NewPtrSlice[Term](len(x) * 2)
idx := 0
for k, v := range x {
k, err := InterfaceToValue(k)
if err != nil {
return nil, err
}
kvs[idx].Value = k
v, err := InterfaceToValue(v)
if err != nil {
return nil, err
}
r.Insert(NewTerm(k), NewTerm(v))
kvs[idx+1].Value = v
idx += 2
}
return r, nil
tuples := make([][2]*Term, len(kvs)/2)
for i := 0; i < len(kvs); i += 2 {
tuples[i/2] = *(*[2]*Term)(kvs[i : i+2])
}
return NewObject(tuples...), nil
case map[string]string:
r := newobject(len(x))
for k, v := range x {
Expand Down
10 changes: 5 additions & 5 deletions v1/ast/visit.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,14 +357,14 @@ func (vis *GenericVisitor) Walk(x interface{}) {
vis.Walk(x.Get(k))
})
case Object:
x.Foreach(func(k, _ *Term) {
for _, k := range x.Keys() {
vis.Walk(k)
vis.Walk(x.Get(k))
})
}
case *Array:
x.Foreach(func(t *Term) {
vis.Walk(t)
})
for i := 0; i < x.Len(); i++ {
vis.Walk(x.Elem(i))
}
case Set:
xSlice := x.Slice()
for i := range xSlice {
Expand Down
3 changes: 0 additions & 3 deletions v1/rego/rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -1735,9 +1735,6 @@ func (r *Rego) PrepareForEval(ctx context.Context, opts ...PrepareOption) (Prepa
}

txnErr := txnClose(ctx, err) // Always call closer
if err != nil {
return PreparedEvalQuery{}, err
}
if txnErr != nil {
return PreparedEvalQuery{}, txnErr
}
Expand Down
51 changes: 51 additions & 0 deletions v1/rego/rego_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package rego

import (
"context"
"encoding/json"
"fmt"
"os"
"testing"

"github.com/open-policy-agent/opa/internal/runtime"
Expand Down Expand Up @@ -68,3 +70,52 @@ func BenchmarkPartialObjectRuleCrossModule(b *testing.B) {
})
}
}

// BenchmarkCustomFunctionInHotPath measures evaluation cost when a
// user-defined function (is_ref) is invoked for every value produced by
// walk over the input. The input is loaded from testdata/ast.json, and each
// iteration verifies that the query binds x to 402 for that fixture.
func BenchmarkCustomFunctionInHotPath(b *testing.B) {
	ctx := context.Background()

	bs, err := os.ReadFile("testdata/ast.json")
	if err != nil {
		b.Fatal(err)
	}

	input := ast.MustParseTerm(string(bs))
	module := ast.MustParseModule(`package test
import rego.v1
r := count(refs)
refs contains value if {
walk(input, [_, value])
is_ref(value)
}
is_ref(value) if value.type == "ref"
is_ref(value) if value[0].type == "ref"`)

	r := New(Query("data.test.r = x"), ParsedModule(module))

	pq, err := r.PrepareForEval(ctx)
	if err != nil {
		b.Fatal(err)
	}

	// Exclude fixture loading and query preparation from the measurement.
	b.ResetTimer()
	b.ReportAllocs()

	for i := 0; i < b.N; i++ {
		res, err := pq.Eval(ctx, EvalParsedInput(input.Value))
		if err != nil {
			b.Fatal(err)
		}

		// Check the length rather than nil: an empty but non-nil result set
		// would otherwise panic on res[0] below.
		if len(res) == 0 {
			b.Fatal("expected result")
		}

		// Use a checked type assertion so an unexpected binding type is
		// reported as a benchmark failure instead of a panic.
		got, ok := res[0].Bindings["x"].(json.Number)
		if !ok || got != "402" {
			b.Fatalf("expected 402, got %v", res[0].Bindings["x"])
		}
	}
}
Loading

0 comments on commit 50b5ee5

Please sign in to comment.