Skip to content

Commit

Permalink
Additional nil-safety checks with corresponding test updates (#1073)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Nov 19, 2024
1 parent 72e0977 commit ba74bf6
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 15 deletions.
21 changes: 19 additions & 2 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,14 @@ func TestResidualAstMacros(t *testing.T) {
}
}

func TestResidualAstNil(t *testing.T) {
env := testEnv(t)
ast, err := env.ResidualAst(nil, nil)
if err == nil || !strings.Contains(err.Error(), "unsupported expr") {
t.Errorf("env.ResidualAst() got (%v, %v) wanted unsupported expr error", ast, err)
}
}

func BenchmarkEvalOptions(b *testing.B) {
env := testEnv(b,
Variable("ai", IntType),
Expand Down Expand Up @@ -1323,7 +1331,7 @@ func TestEnvExtensionIsolation(t *testing.T) {
func TestVariadicLogicalOperators(t *testing.T) {
env := testEnv(t, variadicLogicalOperatorASTs())
ast, iss := env.Compile(
`(false || false || false || false || true) &&
`(false || false || false || false || true) &&
(true && true && true && true && false)`)
if iss.Err() != nil {
t.Fatalf("Compile() failed: %v", iss.Err())
Expand Down Expand Up @@ -2293,7 +2301,7 @@ func TestOptionalValuesCompile(t *testing.T) {
if iss.Err() != nil {
t.Fatalf("%v failed: %v", tc.expr, iss.Err())
}
for id, reference := range ast.impl.ReferenceMap() {
for id, reference := range ast.NativeRep().ReferenceMap() {
other, found := tc.references[id]
if !found {
t.Errorf("Compile(%v) expected reference %d: %v", tc.expr, id, reference)
Expand Down Expand Up @@ -2955,6 +2963,15 @@ func BenchmarkDynamicDispatch(b *testing.B) {
})
}

func TestAstProgramNilValue(t *testing.T) {
var ast *Ast = nil
env := testEnv(t)
prg, err := env.Program(ast)
if err == nil || !strings.Contains(err.Error(), "unsupported expr") {
t.Errorf("env.Program() got (%v,%v) wanted unsupported expr error", prg, err)
}
}

// TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package
type testCostEstimator struct {
hints map[string]uint64
Expand Down
5 changes: 3 additions & 2 deletions cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,8 @@ func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) {
// TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an
// Ast format and then Program again.
func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
pruned := interpreter.PruneAst(a.impl.Expr(), a.impl.SourceInfo().MacroCalls(), details.State())
ast := a.NativeRep()
pruned := interpreter.PruneAst(ast.Expr(), ast.SourceInfo().MacroCalls(), details.State())
newAST := &Ast{source: a.Source(), impl: pruned}
expr, err := AstToString(newAST)
if err != nil {
Expand All @@ -582,7 +583,7 @@ func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...ch
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(ast.impl, estimator, extendedOpts...)
return checker.Cost(ast.NativeRep(), estimator, extendedOpts...)
}

// configure applies a series of EnvOptions to the current environment.
Expand Down
8 changes: 4 additions & 4 deletions cel/folding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -650,11 +650,11 @@ func TestConstantFoldingNormalizeIDs(t *testing.T) {
t.Fatalf("Compile() failed: %v", iss.Err())
}
preOpt := newIDCollector()
ast.PostOrderVisit(checked.impl.Expr(), preOpt)
ast.PostOrderVisit(checked.NativeRep().Expr(), preOpt)
if !reflect.DeepEqual(preOpt.IDs(), tc.ids) {
t.Errorf("Compile() got ids %v, expected %v", preOpt.IDs(), tc.ids)
}
for id, call := range checked.impl.SourceInfo().MacroCalls() {
for id, call := range checked.NativeRep().SourceInfo().MacroCalls() {
macroText, found := tc.macros[id]
if !found {
t.Fatalf("Compile() did not find macro %d", id)
Expand Down Expand Up @@ -682,11 +682,11 @@ func TestConstantFoldingNormalizeIDs(t *testing.T) {
t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err())
}
postOpt := newIDCollector()
ast.PostOrderVisit(optimized.impl.Expr(), postOpt)
ast.PostOrderVisit(optimized.NativeRep().Expr(), postOpt)
if !reflect.DeepEqual(postOpt.IDs(), tc.normalizedIDs) {
t.Errorf("Optimize() got ids %v, expected %v", postOpt.IDs(), tc.normalizedIDs)
}
for id, call := range optimized.impl.SourceInfo().MacroCalls() {
for id, call := range optimized.NativeRep().SourceInfo().MacroCalls() {
macroText, found := tc.normalizedMacros[id]
if !found {
t.Fatalf("Optimize() did not find macro %d", id)
Expand Down
2 changes: 1 addition & 1 deletion cel/inlining.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func NewInlineVariable(name string, definition *Ast) *InlineVariable {
// If the variable occurs more than once, the provided alias will be used to replace the expressions
// where the variable name occurs.
func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable {
return &InlineVariable{name: name, alias: alias, def: definition.impl}
return &InlineVariable{name: name, alias: alias, def: definition.NativeRep()}
}

// NewInliningOptimizer creates and optimizer which replaces variables with expression definitions.
Expand Down
4 changes: 2 additions & 2 deletions cel/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) {
if !a.IsChecked() {
return nil, fmt.Errorf("cannot convert unchecked ast")
}
return ast.ToProto(a.impl)
return ast.ToProto(a.NativeRep())
}

// ParsedExprToAst converts a parsed expression proto message to an Ast.
Expand Down Expand Up @@ -99,7 +99,7 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) {
// Note, the conversion may not be an exact replica of the original expression, but will produce
// a string that is semantically equivalent and whose textual representation is stable.
func AstToString(a *Ast) (string, error) {
return parser.Unparse(a.impl.Expr(), a.impl.SourceInfo())
return parser.Unparse(a.NativeRep().Expr(), a.NativeRep().SourceInfo())
}

// RefValueToValue converts between ref.Val and google.api.expr.v1alpha1.Value.
Expand Down
22 changes: 22 additions & 0 deletions cel/io_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package cel

import (
"fmt"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -144,6 +145,27 @@ func TestAstToString(t *testing.T) {
}
}

func TestAstToStringNil(t *testing.T) {
expr, err := AstToString(nil)
if err == nil || !strings.Contains(err.Error(), "unsupported expr") {
t.Errorf("env.AstToString() got (%v, %v) wanted unsupported expr error", expr, err)
}
}

func TestAstToCheckedExprNil(t *testing.T) {
expr, err := AstToCheckedExpr(nil)
if err == nil || !strings.Contains(err.Error(), "cannot convert unchecked ast") {
t.Errorf("env.AstToCheckedExpr() got (%v, %v) wanted conversion error", expr, err)
}
}

func TestAstToParsedExprNil(t *testing.T) {
expr, err := AstToParsedExpr(nil)
if err != nil {
t.Errorf("env.AstToParsedExpr() got (%v, %v) wanted conversion error", expr, err)
}
}

func TestCheckedExprToAstConstantExpr(t *testing.T) {
stdEnv, err := NewEnv()
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions cel/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer {
// If issues are encountered, the Issues.Err() return value will be non-nil.
func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
// Make a copy of the AST to be optimized.
optimized := ast.Copy(a.impl)
ids := newIDGenerator(ast.MaxID(a.impl))
optimized := ast.Copy(a.NativeRep())
ids := newIDGenerator(ast.MaxID(a.NativeRep()))

// Create the optimizer context, could be pooled in the future.
issues := NewIssues(common.NewErrors(a.Source()))
Expand Down Expand Up @@ -86,7 +86,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
if iss.Err() != nil {
return nil, iss
}
optimized = checked.impl
optimized = checked.NativeRep()
}

// Return the optimized result.
Expand Down
10 changes: 10 additions & 0 deletions cel/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package cel_test

import (
"sort"
"strings"
"testing"

"github.com/google/cel-go/cel"
Expand Down Expand Up @@ -201,6 +202,15 @@ func TestStaticOptimizerNewAST(t *testing.T) {
}
}

func TestStaticOptimizerNilAST(t *testing.T) {
env := optimizerEnv(t)
opt := cel.NewStaticOptimizer(&identityOptimizer{t: t})
optAST, iss := opt.Optimize(env, nil)
if iss.Err() == nil || !strings.Contains(iss.Err().Error(), "unexpected unspecified type") {
t.Errorf("opt.Optimize(env, nil) got (%v, %v), wanted unexpected unspecified type", optAST, iss)
}
}

type identityOptimizer struct {
t *testing.T
}
Expand Down
3 changes: 3 additions & 0 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ type EvalDetails struct {
// State of the evaluation, non-nil if the OptTrackState or OptExhaustiveEval is specified
// within EvalOptions.
func (ed *EvalDetails) State() interpreter.EvalState {
if ed == nil {
return interpreter.NewEvalState()
}
return ed.state
}

Expand Down
6 changes: 5 additions & 1 deletion common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ type Errors struct {

// NewErrors creates a new instance of the Errors type.
func NewErrors(source Source) *Errors {
src := source
if src == nil {
src = NewTextSource("")
}
return &Errors{
errors: []*Error{},
source: source,
source: src,
maxErrorsToReport: 100,
}
}
Expand Down

0 comments on commit ba74bf6

Please sign in to comment.