Skip to content

Commit

Permalink
Refactor dialect commons (#358)
Browse files Browse the repository at this point in the history
* Refactor dialect commons

Add Quote method to solve names escaping in the future

* Extract AfterOpen to its own interface
  • Loading branch information
stanislas-m authored Mar 3, 2019
1 parent 0a6f94b commit f9fc3a3
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 232 deletions.
8 changes: 5 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,11 @@ func (c *Connection) Open() error {
db.SetMaxIdleConns(details.IdlePool)
c.Store = &dB{db}

err = c.Dialect.afterOpen(c)
if err != nil {
c.Store = nil
if d, ok := c.Dialect.(afterOpenable); ok {
err = d.AfterOpen(c)
if err != nil {
c.Store = nil
}
}
return errors.Wrap(err, "could not open database connection")
}
Expand Down
168 changes: 3 additions & 165 deletions dialect.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,12 @@
package pop

import (
"bytes"
"encoding/gob"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"strings"

"github.com/gobuffalo/fizz"
"github.com/gobuffalo/pop/columns"
"github.com/gobuffalo/pop/logging"
"github.com/gofrs/uuid"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)

func init() {
gob.Register(uuid.UUID{})
}

type dialect interface {
Name() string
URL() string
Expand All @@ -40,156 +25,9 @@ type dialect interface {
FizzTranslator() fizz.Translator
Lock(func() error) error
TruncateAll(*Connection) error
afterOpen(*Connection) error
}

func genericCreate(s store, model *Model, cols columns.Columns) error {
keyType := model.PrimaryKeyType()
switch keyType {
case "int", "int64":
var id int64
w := cols.Writeable()
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", model.TableName(), w.String(), w.SymbolizedString())
log(logging.SQL, query)
res, err := s.NamedExec(query, model.Value)
if err != nil {
return errors.WithStack(err)
}
id, err = res.LastInsertId()
if err == nil {
model.setID(id)
}
if err != nil {
return errors.WithStack(err)
}
return nil
case "UUID", "string":
if keyType == "UUID" {
if model.ID() == emptyUUID {
u, err := uuid.NewV4()
if err != nil {
return errors.WithStack(err)
}
model.setID(u)
}
} else if model.ID() == "" {
return fmt.Errorf("missing ID value")
}
w := cols.Writeable()
w.Add("id")
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", model.TableName(), w.String(), w.SymbolizedString())
log(logging.SQL, query)
stmt, err := s.PrepareNamed(query)
if err != nil {
return errors.WithStack(err)
}
_, err = stmt.Exec(model.Value)
if err != nil {
if err := stmt.Close(); err != nil {
return errors.WithMessage(err, "failed to close statement")
}
return errors.WithStack(err)
}
return errors.WithMessage(stmt.Close(), "failed to close statement")
}
return errors.Errorf("can not use %s as a primary key type!", keyType)
Quote(key string) string
}

func genericUpdate(s store, model *Model, cols columns.Columns) error {
stmt := fmt.Sprintf("UPDATE %s SET %s WHERE %s", model.TableName(), cols.Writeable().UpdateString(), model.whereNamedID())
log(logging.SQL, stmt, model.ID())
_, err := s.NamedExec(stmt, model.Value)
if err != nil {
return errors.WithStack(err)
}
return nil
}

func genericDestroy(s store, model *Model) error {
stmt := fmt.Sprintf("DELETE FROM %s WHERE %s", model.TableName(), model.whereID())
err := genericExec(s, stmt, model.ID())
if err != nil {
return errors.WithStack(err)
}
return nil
}

func genericExec(s store, stmt string, args ...interface{}) error {
log(logging.SQL, stmt, args...)
_, err := s.Exec(stmt, args...)
if err != nil {
return errors.WithStack(err)
}
return nil
}

func genericSelectOne(s store, model *Model, query Query) error {
sql, args := query.ToSQL(model)
log(logging.SQL, sql, args...)
err := s.Get(model.Value, sql, args...)
if err != nil {
return errors.WithStack(err)
}
return nil
}

func genericSelectMany(s store, models *Model, query Query) error {
sql, args := query.ToSQL(models)
log(logging.SQL, sql, args...)
err := s.Select(models.Value, sql, args...)
if err != nil {
return errors.WithStack(err)
}
return nil
}

func genericLoadSchema(deets *ConnectionDetails, migrationURL string, r io.Reader) error {
// Open DB connection on the target DB
db, err := sqlx.Open(deets.Dialect, migrationURL)
if err != nil {
return errors.WithMessage(err, fmt.Sprintf("unable to load schema for %s", deets.Database))
}
defer db.Close()

// Get reader contents
contents, err := ioutil.ReadAll(r)
if err != nil {
return err
}

if len(contents) == 0 {
log(logging.Info, "schema is empty for %s, skipping", deets.Database)
return nil
}

_, err = db.Exec(string(contents))
if err != nil {
return errors.WithMessage(err, fmt.Sprintf("unable to load schema for %s", deets.Database))
}

log(logging.Info, "loaded schema for %s", deets.Database)
return nil
}

func genericDumpSchema(deets *ConnectionDetails, cmd *exec.Cmd, w io.Writer) error {
log(logging.SQL, strings.Join(cmd.Args, " "))

bb := &bytes.Buffer{}
mw := io.MultiWriter(w, bb)

cmd.Stdout = mw
cmd.Stderr = os.Stderr

err := cmd.Run()
if err != nil {
return err
}

x := bytes.TrimSpace(bb.Bytes())
if len(x) == 0 {
return errors.Errorf("unable to dump schema for %s", deets.Database)
}

log(logging.Info, "dumped schema for %s", deets.Database)
return nil
type afterOpenable interface {
AfterOpen(*Connection) error
}
20 changes: 8 additions & 12 deletions dialect_cockroach.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ type cockroachInfo struct {
}

type cockroach struct {
translateCache map[string]string
mu sync.Mutex
ConnectionDetails *ConnectionDetails
info cockroachInfo
commonDialect
translateCache map[string]string
mu sync.Mutex
info cockroachInfo
}

func (p *cockroach) Name() string {
Expand Down Expand Up @@ -187,10 +187,6 @@ func (p *cockroach) FizzTranslator() fizz.Translator {
return translators.NewCockroach(p.URL(), p.Details().Database)
}

func (p *cockroach) Lock(fn func() error) error {
return fn()
}

func (p *cockroach) DumpSchema(w io.Writer) error {
cmd := exec.Command("cockroach", "dump", p.Details().Database, "--dump-mode=schema")

Expand Down Expand Up @@ -239,7 +235,7 @@ func (p *cockroach) TruncateAll(tx *Connection) error {
// return tx3.RawQuery(fmt.Sprintf("truncate %s cascade;", strings.Join(tableNames, ", "))).Exec()
}

func (p *cockroach) afterOpen(c *Connection) error {
func (p *cockroach) AfterOpen(c *Connection) error {
if err := c.RawQuery(`select version() AS "version"`).First(&p.info); err != nil {
return err
}
Expand All @@ -257,9 +253,9 @@ func (p *cockroach) afterOpen(c *Connection) error {
func newCockroach(deets *ConnectionDetails) (dialect, error) {
deets.Dialect = "postgres"
d := &cockroach{
ConnectionDetails: deets,
translateCache: map[string]string{},
mu: sync.Mutex{},
commonDialect: commonDialect{ConnectionDetails: deets},
translateCache: map[string]string{},
mu: sync.Mutex{},
}
d.info.client = deets.Options["application_name"]
return d, nil
Expand Down
6 changes: 3 additions & 3 deletions dialect_cockroach_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func Test_Cockroach_URL_Raw(t *testing.T) {
}
err := cd.Finalize()
r.NoError(err)
m := &cockroach{ConnectionDetails: cd}
m := &cockroach{commonDialect: commonDialect{ConnectionDetails: cd}}
r.Equal("scheme://user:pass@host:port/database?option1=value1", m.URL())
r.Equal("postgres://user:pass@host:port/?option1=value1", m.urlWithoutDb())
}
Expand All @@ -36,7 +36,7 @@ func Test_Cockroach_URL_Build(t *testing.T) {
}
err := cd.Finalize()
r.NoError(err)
m := &cockroach{ConnectionDetails: cd}
m := &cockroach{commonDialect: commonDialect{ConnectionDetails: cd}}
r.True(strings.HasPrefix(m.URL(), "postgres://user:pass@host:port/database?"), "URL() returns %v", m.URL())
r.Contains(m.URL(), "option1=value1")
r.Contains(m.URL(), "application_name=pop.test")
Expand All @@ -58,7 +58,7 @@ func Test_Cockroach_URL_UserDefinedAppName(t *testing.T) {
}
err := cd.Finalize()
r.NoError(err)
m := &cockroach{ConnectionDetails: cd}
m := &cockroach{commonDialect: commonDialect{ConnectionDetails: cd}}
r.Contains(m.URL(), "database?application_name=myapp")
r.Contains(m.urlWithoutDb(), "/?application_name=myapp")
}
Loading

0 comments on commit f9fc3a3

Please sign in to comment.