Skip to content

Commit

Permalink
feat(debug): Add databases=managed debug option (#2898)
Browse files Browse the repository at this point in the history
* feat(debug): Add databases=managed debug option

- Remove the --no-database flag
- Refactor the shfmt package

---------

Co-authored-by: Andrew Benton <[email protected]>
  • Loading branch information
kyleconroy and andrewmbenton authored Oct 23, 2023
1 parent 5f87be3 commit df4c05b
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 62 deletions.
20 changes: 8 additions & 12 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int
rootCmd.PersistentFlags().BoolP("experimental", "x", false, "DEPRECATED: enable experimental features (default: false)")
rootCmd.PersistentFlags().Bool("no-remote", false, "disable remote execution (default: false)")
rootCmd.PersistentFlags().Bool("remote", false, "enable remote execution (default: false)")
rootCmd.PersistentFlags().Bool("no-database", false, "disable database connections (default: false)")

rootCmd.AddCommand(checkCmd)
rootCmd.AddCommand(createDBCmd)
Expand Down Expand Up @@ -137,24 +136,21 @@ var initCmd = &cobra.Command{
}

type Env struct {
DryRun bool
Debug opts.Debug
Remote bool
NoRemote bool
NoDatabase bool
DryRun bool
Debug opts.Debug
Remote bool
NoRemote bool
}

func ParseEnv(c *cobra.Command) Env {
dr := c.Flag("dry-run")
r := c.Flag("remote")
nr := c.Flag("no-remote")
nodb := c.Flag("no-database")
return Env{
DryRun: dr != nil && dr.Changed,
Debug: opts.DebugFromEnv(),
Remote: r != nil && nr.Value.String() == "true",
NoRemote: nr != nil && nr.Value.String() == "true",
NoDatabase: nodb != nil && nodb.Value.String() == "true",
DryRun: dr != nil && dr.Changed,
Debug: opts.DebugFromEnv(),
Remote: r != nil && nr.Value.String() == "true",
NoRemote: nr != nil && nr.Value.String() == "true",
}
}

Expand Down
43 changes: 18 additions & 25 deletions internal/cmd/vet.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,13 @@ func Vet(ctx context.Context, dir, filename string, opts *Options) error {
}

c := checker{
Rules: rules,
Conf: conf,
Dir: dir,
Env: env,
Envmap: map[string]string{},
Stderr: stderr,
NoDatabase: e.NoDatabase,
Rules: rules,
Conf: conf,
Dir: dir,
Env: env,
Stderr: stderr,
OnlyManagedDB: e.Debug.OnlyManagedDatabases,
Replacer: shfmt.NewReplacer(nil),
}
errored := false
for _, sql := range conf.SQL {
Expand Down Expand Up @@ -379,14 +379,14 @@ type rule struct {
}

type checker struct {
Rules map[string]rule
Conf *config.Config
Dir string
Env *cel.Env
Envmap map[string]string
Stderr io.Writer
NoDatabase bool
Client pb.QuickClient
Rules map[string]rule
Conf *config.Config
Dir string
Env *cel.Env
Stderr io.Writer
OnlyManagedDB bool
Client pb.QuickClient
Replacer *shfmt.Replacer
}

func (c *checker) fetchDatabaseUri(ctx context.Context, s config.SQL) (string, func() error, error) {
Expand Down Expand Up @@ -448,14 +448,7 @@ func (c *checker) fetchDatabaseUri(ctx context.Context, s config.SQL) (string, f
}

func (c *checker) DSN(dsn string) (string, error) {
// Populate the environment variable map if it is empty
if len(c.Envmap) == 0 {
for _, e := range os.Environ() {
k, v, _ := strings.Cut(e, "=")
c.Envmap[k] = v
}
}
return shfmt.Replace(dsn, c.Envmap), nil
return c.Replacer.Replace(dsn), nil
}

func (c *checker) checkSQL(ctx context.Context, s config.SQL) error {
Expand Down Expand Up @@ -488,8 +481,8 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error {
var prep preparer
var expl explainer
if s.Database != nil { // TODO only set up a database connection if a rule evaluation requires it
if c.NoDatabase {
return fmt.Errorf("database: connections disabled via command line flag")
if s.Database.URI != "" && c.OnlyManagedDB {
return fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed")
}
dburl, cleanup, err := c.fetchDatabaseUri(ctx, s)
if err != nil {
Expand Down
27 changes: 17 additions & 10 deletions internal/engine/postgresql/analyzer/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,31 @@ import (

core "github.com/sqlc-dev/sqlc/internal/analysis"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/opts"
pb "github.com/sqlc-dev/sqlc/internal/quickdb/v1"
"github.com/sqlc-dev/sqlc/internal/shfmt"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/named"
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
)

type Analyzer struct {
db config.Database
client pb.QuickClient
pool *pgxpool.Pool

formats sync.Map
columns sync.Map
tables sync.Map
db config.Database
client pb.QuickClient
pool *pgxpool.Pool
dbg opts.Debug
replacer *shfmt.Replacer
formats sync.Map
columns sync.Map
tables sync.Map
}

func New(client pb.QuickClient, db config.Database) *Analyzer {
return &Analyzer{
db: db,
client: client,
db: db,
dbg: opts.DebugFromEnv(),
client: client,
replacer: shfmt.NewReplacer(nil),
}
}

Expand Down Expand Up @@ -204,8 +209,10 @@ func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrat
return nil, err
}
uri = edb.Uri
} else if a.dbg.OnlyManagedDatabases {
return nil, fmt.Errorf("database: connections disabled via SQLCDEBUG=databases=managed")
} else {
uri = a.db.URI
uri = a.replacer.Replace(a.db.URI)
}
conf, err := pgxpool.ParseConfig(uri)
if err != nil {
Expand Down
16 changes: 10 additions & 6 deletions internal/opts/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@ import (
// dumpcatalog: setting dumpcatalog=1 will print the parsed database schema
// trace: setting trace=<path> will output a trace
// processplugins: setting processplugins=0 will disable process-based plugins
// databases: setting databases=managed will disable connections to databases via URI
// dumpvetenv: setting dumpvetenv=1 will print the variables available to
// a vet rule during evaluation
// dumpexplain: setting dumpexplain=1 will print the JSON-formatted output
// from executing EXPLAIN ... on a query during vet rule evaluation

type Debug struct {
DumpAST bool
DumpCatalog bool
Trace string
ProcessPlugins bool
DumpVetEnv bool
DumpExplain bool
DumpAST bool
DumpCatalog bool
Trace string
ProcessPlugins bool
OnlyManagedDatabases bool
DumpVetEnv bool
DumpExplain bool
}

func DebugFromEnv() Debug {
Expand Down Expand Up @@ -53,6 +55,8 @@ func DebugFromString(val string) Debug {
}
case pair == "processplugins=0":
d.ProcessPlugins = false
case pair == "databases=managed":
d.OnlyManagedDatabases = true
case pair == "dumpvetenv=1":
d.DumpVetEnv = true
case pair == "dumpexplain=1":
Expand Down
26 changes: 24 additions & 2 deletions internal/shfmt/shfmt.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,38 @@
package shfmt

import (
"os"
"regexp"
"strings"
)

var pat = regexp.MustCompile(`\$\{[A-Z_]+\}`)

func Replace(f string, vars map[string]string) string {
type Replacer struct {
envmap map[string]string
}

func (r *Replacer) Replace(f string) string {
return pat.ReplaceAllStringFunc(f, func(s string) string {
s = strings.TrimPrefix(s, "${")
s = strings.TrimSuffix(s, "}")
return vars[s]
return r.envmap[s]
})
}

func NewReplacer(env []string) *Replacer {
r := Replacer{
envmap: map[string]string{},
}
if env == nil {
env = os.Environ()
}
for _, e := range env {
k, v, _ := strings.Cut(e, "=")
if k == "SQLC_AUTH_TOKEN" {
continue
}
r.envmap[k] = v
}
return &r
}
14 changes: 7 additions & 7 deletions internal/shfmt/shfmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import "testing"

func TestReplace(t *testing.T) {
s := "POSTGRES_SQL://${PG_USER}:${PG_PASSWORD}@${PG_HOST}:${PG_PORT}/AUTHORS"
env := map[string]string{
"PG_USER": "user",
"PG_PASSWORD": "password",
"PG_HOST": "host",
"PG_PORT": "port",
}
r := NewReplacer([]string{
"PG_USER=user",
"PG_PASSWORD=password",
"PG_HOST=host",
"PG_PORT=port",
})
e := "POSTGRES_SQL://user:password@host:port/AUTHORS"
if v := Replace(s, env); v != e {
if v := r.Replace(s); v != e {
t.Errorf("%s != %s", v, e)
}
}

0 comments on commit df4c05b

Please sign in to comment.