From 9199320dbe03632ff743e0611176bec1dd6549a1 Mon Sep 17 00:00:00 2001 From: Becky Huang Date: Mon, 10 Jan 2022 10:29:56 -0800 Subject: [PATCH] refactor local.driver (#171) export local.Driver set Externs for local.Driver within client.init move local.Driver PutModules and DeleteModules to private, deleted from drivers interface remove rego validation in client.CreateCRD Signed-off-by: Becky Huang --- constraint/pkg/client/client.go | 16 ++-- constraint/pkg/client/client_test.go | 43 --------- constraint/pkg/client/drivers/interface.go | 6 -- constraint/pkg/client/drivers/local/args.go | 14 +-- constraint/pkg/client/drivers/local/local.go | 95 +++++++++---------- .../pkg/client/drivers/local/local_test.go | 20 ++-- .../client/drivers/local/local_unit_test.go | 41 ++++---- .../pkg/client/drivers/remote/remote.go | 10 -- 8 files changed, 87 insertions(+), 158 deletions(-) diff --git a/constraint/pkg/client/client.go b/constraint/pkg/client/client.go index 931740f4d..3276c0420 100644 --- a/constraint/pkg/client/client.go +++ b/constraint/pkg/client/client.go @@ -251,9 +251,6 @@ func (c *Client) CreateCRD(templ *templates.ConstraintTemplate) (*apiextensions. if err != nil { return nil, err } - if _, _, err = local.MapModules(templ, c.allowedDataFields); err != nil { - return nil, err - } return artifacts.crd, nil } @@ -277,9 +274,6 @@ func (c *Client) AddTemplate(templ *templates.ConstraintTemplate) (*types.Respon c.mtx.Lock() defer c.mtx.Unlock() - if d, ok := c.backend.driver.(interface{ AddExterns([]string) }); ok { - d.AddExterns(c.allowedDataFields) - } if err = c.backend.driver.AddTemplate(templ); err != nil { return resp, err } @@ -663,7 +657,15 @@ func (c *Client) init() error { ErrCreatingClient, err, src) } } - + if d, ok := c.backend.driver.(*local.Driver); ok { + var externs []string + for _, field := range c.allowedDataFields { + externs = append(externs, fmt.Sprintf("data.%s", field)) + } + d.SetExterns(externs) + } else { + return fmt.Errorf("%w: driver %T is not supported", ErrCreatingClient, c.backend.driver) + } return nil } diff --git a/constraint/pkg/client/client_test.go b/constraint/pkg/client/client_test.go index 485ae2613..6cdb6faaa 100644 --- a/constraint/pkg/client/client_test.go +++ b/constraint/pkg/client/client_test.go @@ -1224,49 +1224,6 @@ violation[msg] {msg := "always"}`, want: nil, wantErr: ErrInvalidConstraintTemplate, }, - { - name: "no rego", - targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, - template: &templates.ConstraintTemplate{ - ObjectMeta: v1.ObjectMeta{Name: "foo"}, - Spec: templates.ConstraintTemplateSpec{ - CRD: templates.CRD{ - Spec: templates.CRDSpec{ - Names: templates.Names{ - Kind: "Foo", - }, - }, - }, - Targets: []templates.Target{{ - Target: "handler", - }}, - }, - }, - want: nil, - wantErr: local.ErrInvalidConstraintTemplate, - }, - { - name: "empty rego package", - targets: []TargetHandler{&badHandler{Name: "handler", HasLib: true}}, - template: &templates.ConstraintTemplate{ - ObjectMeta: v1.ObjectMeta{Name: "foo"}, - Spec: templates.ConstraintTemplateSpec{ - CRD: templates.CRD{ - Spec: templates.CRDSpec{ - Names: templates.Names{ - Kind: "Foo", - }, - }, - }, - Targets: []templates.Target{{ - Target: "handler", - Rego: `package foo`, - }}, - }, - }, - want: nil, - wantErr: local.ErrInvalidConstraintTemplate, - }, { name: "multiple targets", targets: []TargetHandler{ diff --git a/constraint/pkg/client/drivers/interface.go b/constraint/pkg/client/drivers/interface.go index 314f7c637..5663ac20b 100644 --- a/constraint/pkg/client/drivers/interface.go +++ b/constraint/pkg/client/drivers/interface.go @@ -23,13 +23,7 @@ func Tracing(enabled bool) QueryOpt { type Driver interface { Init() error PutModule(name string, src string) error - // PutModules upserts a number of modules under a given prefix. - PutModules(namePrefix string, srcs []string) error - // DeleteModules deletes all modules under a given prefix and returns the - // count of modules deleted. Deletion of non-existing prefix will - // result in 0, nil being returned. - DeleteModules(namePrefix string) (int, error) // AddTemplate adds the template source code to OPA AddTemplate(ct *templates.ConstraintTemplate) error // RemoveTemplate removes the template source code from OPA diff --git a/constraint/pkg/client/drivers/local/args.go b/constraint/pkg/client/drivers/local/args.go index 85b0d80b1..29f5d7f8e 100644 --- a/constraint/pkg/client/drivers/local/args.go +++ b/constraint/pkg/client/drivers/local/args.go @@ -8,10 +8,10 @@ import ( opatypes "github.com/open-policy-agent/opa/types" ) -type Arg func(*driver) +type Arg func(*Driver) func Defaults() Arg { - return func(d *driver) { + return func(d *Driver) { if d.compiler == nil { d.compiler = ast.NewCompiler() } @@ -40,31 +40,31 @@ func Defaults() Arg { } func Tracing(enabled bool) Arg { - return func(d *driver) { + return func(d *Driver) { d.traceEnabled = enabled } } func Modules(modules map[string]*ast.Module) Arg { - return func(d *driver) { + return func(d *Driver) { d.modules = modules } } func Storage(s storage.Store) Arg { - return func(d *driver) { + return func(d *Driver) { d.storage = s } } func AddExternalDataProviderCache(providerCache *externaldata.ProviderCache) Arg { - return func(d *driver) { + return func(d *Driver) { d.providerCache = providerCache } } func DisableBuiltins(builtins ...string) Arg { - return func(d *driver) { + return func(d *Driver) { if d.capabilities == nil { d.capabilities = ast.CapabilitiesForThisVersion() } diff --git a/constraint/pkg/client/drivers/local/local.go b/constraint/pkg/client/drivers/local/local.go index 0c3b4baaa..3c65879e1 100644 --- a/constraint/pkg/client/drivers/local/local.go +++ b/constraint/pkg/client/drivers/local/local.go @@ -30,6 +30,8 @@ import ( const ( moduleSetPrefix = "__modset_" moduleSetSep = "_idx_" + libPrefix = "data.lib" + violation = "violation" ) type module struct { @@ -50,7 +52,7 @@ func (i insertParam) add(name string, src string) error { } func New(args ...Arg) drivers.Driver { - d := &driver{} + d := &Driver{} for _, arg := range args { arg(d) } @@ -62,9 +64,9 @@ func New(args ...Arg) drivers.Driver { return d } -var _ drivers.Driver = &driver{} +var _ drivers.Driver = &Driver{} -type driver struct { +type Driver struct { modulesMux sync.RWMutex compiler *ast.Compiler modules map[string]*ast.Module @@ -75,7 +77,7 @@ type driver struct { externs []string } -func (d *driver) Init() error { +func (d *Driver) Init() error { if d.providerCache != nil { rego.RegisterBuiltin1( ®o.Function{ @@ -139,7 +141,7 @@ func copyModules(modules map[string]*ast.Module) map[string]*ast.Module { return m } -func (d *driver) checkModuleName(name string) error { +func (d *Driver) checkModuleName(name string) error { if name == "" { return fmt.Errorf("%w: module %q has no name", ErrModuleName, name) @@ -153,7 +155,7 @@ func (d *driver) checkModuleName(name string) error { return nil } -func (d *driver) checkModuleSetName(name string) error { +func (d *Driver) checkModuleSetName(name string) error { if name == "" { return fmt.Errorf("%w: modules name prefix cannot be empty", ErrModulePrefix) } @@ -173,7 +175,7 @@ func toModuleSetName(prefix string, idx int) string { return fmt.Sprintf("%s%d", toModuleSetPrefix(prefix), idx) } -func (d *driver) PutModule(name string, src string) error { +func (d *Driver) PutModule(name string, src string) error { if err := d.checkModuleName(name); err != nil { return err } @@ -190,8 +192,8 @@ func (d *driver) PutModule(name string, src string) error { return err } -// PutModules implements drivers.Driver. -func (d *driver) PutModules(namePrefix string, srcs []string) error { +// PutModules upserts a number of modules under a given prefix. +func (d *Driver) putModules(namePrefix string, srcs []string) error { if err := d.checkModuleSetName(namePrefix); err != nil { return err } @@ -222,7 +224,7 @@ func (d *driver) PutModules(namePrefix string, srcs []string) error { // alterModules alters the modules in the driver by inserting and removing // the provided modules then returns the count of modules removed. // alterModules expects that the caller is holding the modulesMux lock. -func (d *driver) alterModules(insert insertParam, remove []string) (int, error) { +func (d *Driver) alterModules(insert insertParam, remove []string) (int, error) { // TODO(davis-haba): Remove this Context once it is no longer necessary. ctx := context.TODO() @@ -272,8 +274,10 @@ func (d *driver) alterModules(insert insertParam, remove []string) (int, error) return len(remove), nil } -// DeleteModules implements drivers.Driver. -func (d *driver) DeleteModules(namePrefix string) (int, error) { +// deleteModules deletes all modules under a given prefix and returns the +// count of modules deleted. Deletion of non-existing prefix will +// result in 0, nil being returned. +func (d *Driver) deleteModules(namePrefix string) (int, error) { if err := d.checkModuleSetName(namePrefix); err != nil { return 0, err } @@ -286,7 +290,7 @@ func (d *driver) DeleteModules(namePrefix string) (int, error) { // listModuleSet returns the list of names corresponding to a given module // prefix. -func (d *driver) listModuleSet(namePrefix string) []string { +func (d *Driver) listModuleSet(namePrefix string) []string { prefix := toModuleSetPrefix(namePrefix) var names []string @@ -311,7 +315,7 @@ func parsePath(path string) ([]string, error) { return p, nil } -func (d *driver) PutData(ctx context.Context, path string, data interface{}) error { +func (d *Driver) PutData(ctx context.Context, path string, data interface{}) error { d.modulesMux.RLock() defer d.modulesMux.RUnlock() @@ -358,7 +362,7 @@ func (d *driver) PutData(ctx context.Context, path string, data interface{}) err // DeleteData deletes data from OPA and returns true if data was found and deleted, false // if data was not found, and any errors. -func (d *driver) DeleteData(ctx context.Context, path string) (bool, error) { +func (d *Driver) DeleteData(ctx context.Context, path string) (bool, error) { d.modulesMux.RLock() defer d.modulesMux.RUnlock() @@ -387,7 +391,7 @@ func (d *driver) DeleteData(ctx context.Context, path string) (bool, error) { return true, nil } -func (d *driver) eval(ctx context.Context, path string, input interface{}, cfg *drivers.QueryCfg) (rego.ResultSet, *string, error) { +func (d *Driver) eval(ctx context.Context, path string, input interface{}, cfg *drivers.QueryCfg) (rego.ResultSet, *string, error) { d.modulesMux.RLock() defer d.modulesMux.RUnlock() @@ -416,7 +420,7 @@ func (d *driver) eval(ctx context.Context, path string, input interface{}, cfg * return res, t, err } -func (d *driver) Query(ctx context.Context, path string, input interface{}, opts ...drivers.QueryOpt) (*types.Response, error) { +func (d *Driver) Query(ctx context.Context, path string, input interface{}, opts ...drivers.QueryOpt) (*types.Response, error) { cfg := &drivers.QueryCfg{} for _, opt := range opts { opt(cfg) @@ -455,7 +459,7 @@ func (d *driver) Query(ctx context.Context, path string, input interface{}, opts }, nil } -func (d *driver) Dump(ctx context.Context) (string, error) { +func (d *Driver) Dump(ctx context.Context) (string, error) { d.modulesMux.RLock() defer d.modulesMux.RUnlock() @@ -498,57 +502,58 @@ func (d *driver) Dump(ctx context.Context) (string, error) { return string(b), nil } -func MapModules(templ *templates.ConstraintTemplate, extern []string) (string, []string, error) { +// AddTemplate implements drivers.Driver. +func (d *Driver) AddTemplate(templ *templates.ConstraintTemplate) error { if err := validateTargets(templ); err != nil { - return "", nil, err + return err } targetSpec := templ.Spec.Targets[0] targetHandler := targetSpec.Target kind := templ.Spec.CRD.Spec.Names.Kind - libPrefix := templateLibPrefix(targetHandler, kind) + pkgPrefix := templateLibPrefix(targetHandler, kind) rr, err := regorewriter.New( - regorewriter.NewPackagePrefixer(libPrefix), - []string{"data.lib"}, - allowedDataFields(extern)) + regorewriter.NewPackagePrefixer(pkgPrefix), + []string{libPrefix}, + d.externs) if err != nil { - return "", nil, fmt.Errorf("creating rego rewriter: %w", err) + return fmt.Errorf("creating rego rewriter: %w", err) } namePrefix := createTemplatePath(targetHandler, kind) entryPoint, err := parseModule(namePrefix, templ.Spec.Targets[0].Rego) if err != nil { - return "", nil, fmt.Errorf("%w: %v", ErrInvalidConstraintTemplate, err) + return fmt.Errorf("%w: %v", ErrInvalidConstraintTemplate, err) } if entryPoint == nil { - return "", nil, fmt.Errorf("%w: failed to parse module for unknown reason", + return fmt.Errorf("%w: failed to parse module for unknown reason", ErrInvalidConstraintTemplate) } if err = rewriteModulePackage(namePrefix, entryPoint); err != nil { - return "", nil, err + return err } - req := map[string]struct{}{"violation": {}} + req := map[string]struct{}{violation: {}} if err = requireModuleRules(entryPoint, req); err != nil { - return "", nil, fmt.Errorf("%w: invalid rego: %v", + return fmt.Errorf("%w: invalid rego: %v", ErrInvalidConstraintTemplate, err) } rr.AddEntryPointModule(namePrefix, entryPoint) for idx, libSrc := range targetSpec.Libs { - libPath := fmt.Sprintf(`%s["lib_%d"]`, libPrefix, idx) + libPath := fmt.Sprintf(`%s["lib_%d"]`, pkgPrefix, idx) if err = rr.AddLib(libPath, libSrc); err != nil { - return "", nil, fmt.Errorf("%w: %v", + return fmt.Errorf("%w: %v", ErrInvalidConstraintTemplate, err) } } sources, err := rr.Rewrite() if err != nil { - return "", nil, fmt.Errorf("%w: %v", + return fmt.Errorf("%w: %v", ErrInvalidConstraintTemplate, err) } @@ -562,33 +567,27 @@ func MapModules(templ *templates.ConstraintTemplate, extern []string) (string, [ return nil }) if err != nil { - return "", nil, fmt.Errorf("%w: %v", + return fmt.Errorf("%w: %v", ErrInvalidConstraintTemplate, err) } - return namePrefix, mods, nil -} - -// AddTemplate implements drivers.Driver. -func (d *driver) AddTemplate(templ *templates.ConstraintTemplate) error { - namePrefix, mods, err := MapModules(templ, d.externs) if err != nil { return err } - if err = d.PutModules(namePrefix, mods); err != nil { + if err = d.putModules(namePrefix, mods); err != nil { return fmt.Errorf("%w: %v", ErrCompile, err) } return nil } // RemoveTemplate implements driver.Driver. -func (d *driver) RemoveTemplate(ctx context.Context, templ *templates.ConstraintTemplate) error { +func (d *Driver) RemoveTemplate(ctx context.Context, templ *templates.ConstraintTemplate) error { if err := validateTargets(templ); err != nil { return nil } targetHandler := templ.Spec.Targets[0].Target kind := templ.Spec.CRD.Spec.Names.Kind namePrefix := createTemplatePath(targetHandler, kind) - _, err := d.DeleteModules(namePrefix) + _, err := d.deleteModules(namePrefix) return err } @@ -675,14 +674,6 @@ func validateTargets(templ *templates.ConstraintTemplate) error { } } -func (d *driver) AddExterns(fields []string) { +func (d *Driver) SetExterns(fields []string) { d.externs = fields } - -func allowedDataFields(fields []string) []string { - var externs []string - for _, field := range fields { - externs = append(externs, fmt.Sprintf("data.%s", field)) - } - return externs -} diff --git a/constraint/pkg/client/drivers/local/local_test.go b/constraint/pkg/client/drivers/local/local_test.go index f81f766af..3b0848546 100644 --- a/constraint/pkg/client/drivers/local/local_test.go +++ b/constraint/pkg/client/drivers/local/local_test.go @@ -75,9 +75,9 @@ type action struct { func (tt *compositeTestCase) run(t *testing.T) { dr := New(tt.driverArg...) - d, ok := dr.(*driver) + d, ok := dr.(*Driver) if !ok { - t.Fatalf("got driver %T, want %T", dr, &driver{}) + t.Fatalf("got driver %T, want %T", dr, &Driver{}) } for idx, a := range tt.Actions { @@ -97,7 +97,7 @@ func (tt *compositeTestCase) run(t *testing.T) { } case putModules: - err := d.PutModules(a.RuleNamePrefix, a.Rules.srcs()) + err := d.putModules(a.RuleNamePrefix, a.Rules.srcs()) if (err == nil) && a.ErrorExpected { t.Fatalf("PutModules err = nil; want non-nil") } @@ -106,7 +106,7 @@ func (tt *compositeTestCase) run(t *testing.T) { } case deleteModules: - count, err := d.DeleteModules(a.RuleNamePrefix) + count, err := d.deleteModules(a.RuleNamePrefix) if (err == nil) && a.ErrorExpected { t.Fatalf("DeleteModules err = nil; want non-nil") } @@ -378,9 +378,9 @@ func TestPutModule(t *testing.T) { ctx := context.Background() dr := New() - d, ok := dr.(*driver) + d, ok := dr.(*Driver) if !ok { - t.Fatalf("got driver %T, want %T", dr, &driver{}) + t.Fatalf("got driver %T, want %T", dr, &Driver{}) } for _, r := range tt.Rules { @@ -436,9 +436,9 @@ func TestPutData(t *testing.T) { ctx := context.Background() dr := New() - d, ok := dr.(*driver) + d, ok := dr.(*Driver) if !ok { - t.Fatalf("got driver %T, want %T", dr, &driver{}) + t.Fatalf("got driver %T, want %T", dr, &Driver{}) } for _, data := range tt.Data { @@ -514,9 +514,9 @@ func TestDeleteData(t *testing.T) { ctx := context.Background() dr := New() - d, ok := dr.(*driver) + d, ok := dr.(*Driver) if !ok { - t.Fatalf("got driver %T, want %T", dr, &driver{}) + t.Fatalf("got driver %T, want %T", dr, &Driver{}) } for _, a := range tt.Actions { diff --git a/constraint/pkg/client/drivers/local/local_unit_test.go b/constraint/pkg/client/drivers/local/local_unit_test.go index def7e2b9e..0847f4fad 100644 --- a/constraint/pkg/client/drivers/local/local_unit_test.go +++ b/constraint/pkg/client/drivers/local/local_unit_test.go @@ -116,10 +116,10 @@ func TestDriver_PutModule(t *testing.T) { t.Run(tc.name, func(t *testing.T) { d := New(Modules(tc.beforeModules)) - dr, ok := d.(*driver) + dr, ok := d.(*Driver) if !ok { t.Fatalf("got New() type = %T, want %T", - d, &driver{}) + d, &Driver{}) } gotErr := d.PutModule(tc.moduleName, tc.moduleSrc) @@ -242,21 +242,18 @@ func TestDriver_PutModules(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { d := New() - + dr, ok := d.(*Driver) + if !ok { + t.Fatalf("got New() type = %T, want %T", dr, &Driver{}) + } for prefix, src := range tc.beforeModules { - err := d.PutModules(prefix, src) + err := dr.putModules(prefix, src) if err != nil { t.Fatal(err) } } - dr, ok := d.(*driver) - if !ok { - t.Fatalf("got New() type = %T, want %T", - d, &driver{}) - } - - gotErr := d.PutModules(tc.prefix, tc.srcs) + gotErr := dr.putModules(tc.prefix, tc.srcs) if !errors.Is(gotErr, tc.wantErr) { t.Fatalf("got PutModules() error = %v, want %v", gotErr, tc.wantErr) } @@ -315,10 +312,10 @@ func TestDriver_PutModules_StorageErrors(t *testing.T) { t.Fatalf("got PutModule() err %v, want %v", err, nil) } - dr, ok := d.(*driver) + dr, ok := d.(*Driver) if !ok { t.Fatalf("got New() type = %T, want %T", - d, &driver{}) + d, &Driver{}) } gotModules := getModules(dr) @@ -399,25 +396,23 @@ func TestDriver_DeleteModules(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { d := New() - + dr, ok := d.(*Driver) + if !ok { + t.Fatalf("got New() type = %T, want %T", + d, &Driver{}) + } for prefix, count := range tc.beforeModules { modules := make([]string, count) for i := 0; i < count; i++ { modules[i] = Module } - err := d.PutModules(prefix, modules) + err := dr.putModules(prefix, modules) if err != nil { t.Fatal(err) } } - dr, ok := d.(*driver) - if !ok { - t.Fatalf("got New() type = %T, want %T", - d, &driver{}) - } - - gotDeleted, gotErr := d.DeleteModules(tc.prefix) + gotDeleted, gotErr := dr.deleteModules(tc.prefix) if gotDeleted != tc.wantDeleted { t.Errorf("got DeleteModules() = %v, want %v", gotDeleted, tc.wantDeleted) } @@ -730,7 +725,7 @@ func TestDriver_DeleteData_StorageErrors(t *testing.T) { } } -func getModules(dr *driver) []string { +func getModules(dr *Driver) []string { result := make([]string, len(dr.modules)) idx := 0 diff --git a/constraint/pkg/client/drivers/remote/remote.go b/constraint/pkg/client/drivers/remote/remote.go index 642ba6aba..05cbfe6fe 100644 --- a/constraint/pkg/client/drivers/remote/remote.go +++ b/constraint/pkg/client/drivers/remote/remote.go @@ -78,11 +78,6 @@ func (d *driver) PutModule(name string, src string) error { return d.opa.InsertPolicy(name, []byte(src)) } -// PutModules implements drivers.Driver. -func (d *driver) PutModules(namePrefix string, srcs []string) error { - panic("not implemented") -} - // DeleteModule deletes a rule from OPA and returns true if a rule was found and deleted, false // if a rule was not found, and any errors. func (d *driver) DeleteModule(name string) (bool, error) { @@ -98,11 +93,6 @@ func (d *driver) DeleteModule(name string) (bool, error) { return err == nil, err } -// DeleteModules implements drivers.Driver. -func (d *driver) DeleteModules(namePrefix string) (int, error) { - panic("not implemented") -} - // AddTemplate implements drivers.Driver. func (d *driver) AddTemplate(ct *templates.ConstraintTemplate) error { panic("not implemented")