Skip to content

Commit

Permalink
address review comments, rip out go type parsing from config
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmbenton committed Oct 27, 2023
1 parent 6400a47 commit 9c7d17d
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 329 deletions.
23 changes: 1 addition & 22 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package cmd

import (
"encoding/json"
"strings"

goopts "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
"github.com/sqlc-dev/sqlc/internal/compiler"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/config/convert"
Expand Down Expand Up @@ -36,13 +34,8 @@ func pluginOverride(r *compiler.Result, o config.Override) *plugin.Override {
}
}

goTypeJSON, err := json.Marshal(pluginGoType(o))
if err != nil {
panic(err)
}

return &plugin.Override{
CodeType: goTypeJSON,
CodeType: o.CodeType,
DbType: o.DBType,
Nullable: o.Nullable,
Unsigned: o.Unsigned,
Expand Down Expand Up @@ -108,20 +101,6 @@ func pluginWASM(p config.Plugin) *plugin.Codegen_WASM {
return nil
}

func pluginGoType(o config.Override) *goopts.ParsedGoType {
// Note that there is a slight mismatch between this and the
// proto api. The GoType on the override is the unparsed type,
// which could be a qualified path or an object, as per
// https://docs.sqlc.dev/en/v1.18.0/reference/config.html#type-overriding
return &goopts.ParsedGoType{
ImportPath: o.GoImportPath,
Package: o.GoPackage,
TypeName: o.GoTypeName,
BasicType: o.GoBasicType,
StructTags: o.GoStructTags,
}
}

func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
var schemas []*plugin.Schema
for _, s := range c.Schemas {
Expand Down
13 changes: 13 additions & 0 deletions internal/codegen/golang/opts/go_override.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,16 @@ func (o *GoOverride) Convert() *plugin.Override {
func (o *GoOverride) Matches(n *plugin.Identifier, defaultSchema string) bool {
return sdk.Matches(o.Convert(), n, defaultSchema)
}

func NewGoOverride(po *plugin.Override, o Override) GoOverride {
return GoOverride{
po,
&ParsedGoType{
ImportPath: o.GoImportPath,
Package: o.GoPackage,
TypeName: o.GoTypeName,
BasicType: o.GoBasicType,
StructTags: o.GoStructTags,
},
}
}
35 changes: 19 additions & 16 deletions internal/codegen/golang/opts/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,31 @@ func ParseOpts(req *plugin.CodeGenRequest) (*Options, error) {
return options, fmt.Errorf("unmarshalling options: %w", err)
}

for i := range options.QuerySetOverrides {
if err := options.QuerySetOverrides[i].Parse(); err != nil {
for _, override := range req.Settings.Overrides {
var actualOverride Override
if err := json.Unmarshal(override.CodeType, &actualOverride); err != nil {
return options, err
}

// construct a "plugin"-style override to make the next loop simpler
override := pluginOverride(req.Catalog.DefaultSchema, options.QuerySetOverrides[i])

// in sqlc config.Combine() the "package"-level overrides were appended to
// global overrides, so we mimic that behavior here
req.Settings.Overrides = append(req.Settings.Overrides, override)
if err := actualOverride.Parse(); err != nil {
return options, err
}
options.Overrides = append(options.Overrides, NewGoOverride(
override,
actualOverride,
))
}

for _, override := range req.Settings.Overrides {
var goType ParsedGoType
if err := json.Unmarshal(override.CodeType, &goType); err != nil {
// in sqlc config.Combine() the "package"-level overrides were appended to
// global overrides, so we mimic that behavior here
for i := range options.QuerySetOverrides {
if err := options.QuerySetOverrides[i].Parse(); err != nil {
return options, err
}
options.Overrides = append(options.Overrides, GoOverride{
override,
&goType,
})

options.Overrides = append(options.Overrides, NewGoOverride(
pluginOverride(req.Catalog.DefaultSchema, options.QuerySetOverrides[i]),
options.QuerySetOverrides[i],
))
}

if options.QueryParameterLimit == nil {
Expand Down
5 changes: 4 additions & 1 deletion internal/codegen/golang/opts/override.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copied from github.com/sqlc-dev/sqlc/internal/config/override.go and removed Engine field from Override
// Copied from github.com/sqlc-dev/sqlc/internal/config/override.go
package opts

import (
Expand All @@ -21,6 +21,9 @@ type Override struct {
DBType string `json:"db_type" yaml:"db_type"`
Deprecated_PostgresType string `json:"postgres_type" yaml:"postgres_type"`

// for global overrides only when two different engines are in use
Engine string `json:"engine,omitempty" yaml:"engine"`

// True if the GoType should override if the matching type is nullable
Nullable bool `json:"nullable" yaml:"nullable"`

Expand Down
117 changes: 117 additions & 0 deletions internal/codegen/golang/opts/override_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package opts

import (
"testing"

"github.com/google/go-cmp/cmp"
)

func TestTypeOverrides(t *testing.T) {
for _, test := range []struct {
override Override
pkg string
typeName string
basic bool
}{
{
Override{
DBType: "uuid",
GoType: GoType{Spec: "github.com/segmentio/ksuid.KSUID"},
},
"github.com/segmentio/ksuid",
"ksuid.KSUID",
false,
},
// TODO: Add test for struct pointers
//
// {
// Override{
// DBType: "uuid",
// GoType: "github.com/segmentio/*ksuid.KSUID",
// },
// "github.com/segmentio/ksuid",
// "*ksuid.KSUID",
// false,
// },
{
Override{
DBType: "citext",
GoType: GoType{Spec: "string"},
},
"",
"string",
true,
},
{
Override{
DBType: "timestamp",
GoType: GoType{Spec: "time.Time"},
},
"time",
"time.Time",
false,
},
} {
tt := test
t.Run(tt.override.GoType.Spec, func(t *testing.T) {
if err := tt.override.Parse(); err != nil {
t.Fatalf("override parsing failed; %s", err)
}
if diff := cmp.Diff(tt.pkg, tt.override.GoImportPath); diff != "" {
t.Errorf("package mismatch;\n%s", diff)
}
if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" {
t.Errorf("type name mismatch;\n%s", diff)
}
if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" {
t.Errorf("basic mismatch;\n%s", diff)
}
})
}
for _, test := range []struct {
override Override
err string
}{
{
Override{
DBType: "uuid",
GoType: GoType{Spec: "Pointer"},
},
"Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'",
},
{
Override{
DBType: "uuid",
GoType: GoType{Spec: "untyped rune"},
},
"Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'",
},
} {
tt := test
t.Run(tt.override.GoType.Spec, func(t *testing.T) {
err := tt.override.Parse()
if err == nil {
t.Fatalf("expected parse to fail; got nil")
}
if diff := cmp.Diff(tt.err, err.Error()); diff != "" {
t.Errorf("error mismatch;\n%s", diff)
}
})
}
}

func FuzzOverride(f *testing.F) {
for _, spec := range []string{
"string",
"github.com/gofrs/uuid.UUID",
"github.com/segmentio/ksuid.KSUID",
} {
f.Add(spec)
}
f.Fuzz(func(t *testing.T, s string) {
o := Override{
GoType: GoType{Spec: s},
}
o.Parse()
})
}
21 changes: 0 additions & 21 deletions internal/codegen/golang/opts/plugin_override.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
package opts

import (
"encoding/json"
"strings"

"github.com/sqlc-dev/sqlc/internal/plugin"
Expand Down Expand Up @@ -31,13 +30,7 @@ func pluginOverride(defaultSchema string, o Override) *plugin.Override {
}
}

goTypeJSON, err := json.Marshal(pluginGoType(o))
if err != nil {
panic(err)
}

return &plugin.Override{
CodeType: goTypeJSON,
DbType: o.DBType,
Nullable: o.Nullable,
Unsigned: o.Unsigned,
Expand All @@ -46,17 +39,3 @@ func pluginOverride(defaultSchema string, o Override) *plugin.Override {
Table: &table,
}
}

func pluginGoType(o Override) *ParsedGoType {
// Note that there is a slight mismatch between this and the
// proto api. The GoType on the override is the unparsed type,
// which could be a qualified path or an object, as per
// https://docs.sqlc.dev/en/v1.18.0/reference/config.html#type-overriding
return &ParsedGoType{
ImportPath: o.GoImportPath,
Package: o.GoPackage,
TypeName: o.GoTypeName,
BasicType: o.GoBasicType,
StructTags: o.GoStructTags,
}
}
110 changes: 0 additions & 110 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,113 +89,3 @@ func TestInvalidConfig(t *testing.T) {
t.Errorf("expected err; got nil")
}
}

func TestTypeOverrides(t *testing.T) {
for _, test := range []struct {
override Override
pkg string
typeName string
basic bool
}{
{
Override{
DBType: "uuid",
GoType: GoType{Spec: "github.com/segmentio/ksuid.KSUID"},
},
"github.com/segmentio/ksuid",
"ksuid.KSUID",
false,
},
// TODO: Add test for struct pointers
//
// {
// Override{
// DBType: "uuid",
// GoType: "github.com/segmentio/*ksuid.KSUID",
// },
// "github.com/segmentio/ksuid",
// "*ksuid.KSUID",
// false,
// },
{
Override{
DBType: "citext",
GoType: GoType{Spec: "string"},
},
"",
"string",
true,
},
{
Override{
DBType: "timestamp",
GoType: GoType{Spec: "time.Time"},
},
"time",
"time.Time",
false,
},
} {
tt := test
t.Run(tt.override.GoType.Spec, func(t *testing.T) {
if err := tt.override.Parse(); err != nil {
t.Fatalf("override parsing failed; %s", err)
}
if diff := cmp.Diff(tt.pkg, tt.override.GoImportPath); diff != "" {
t.Errorf("package mismatch;\n%s", diff)
}
if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" {
t.Errorf("type name mismatch;\n%s", diff)
}
if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" {
t.Errorf("basic mismatch;\n%s", diff)
}
})
}
for _, test := range []struct {
override Override
err string
}{
{
Override{
DBType: "uuid",
GoType: GoType{Spec: "Pointer"},
},
"Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'",
},
{
Override{
DBType: "uuid",
GoType: GoType{Spec: "untyped rune"},
},
"Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'",
},
} {
tt := test
t.Run(tt.override.GoType.Spec, func(t *testing.T) {
err := tt.override.Parse()
if err == nil {
t.Fatalf("expected parse to fail; got nil")
}
if diff := cmp.Diff(tt.err, err.Error()); diff != "" {
t.Errorf("error mismatch;\n%s", diff)
}
})
}
}

func FuzzOverride(f *testing.F) {
for _, spec := range []string{
"string",
"github.com/gofrs/uuid.UUID",
"github.com/segmentio/ksuid.KSUID",
} {
f.Add(spec)
}
f.Fuzz(func(t *testing.T, s string) {
o := Override{
GoType: GoType{Spec: s},
}
o.Parse()
})
}
Loading

0 comments on commit 9c7d17d

Please sign in to comment.