Skip to content

Commit

Permalink
mimic override parse and merge within golang codegen package
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmbenton committed Oct 26, 2023
1 parent 41343ab commit 491bfe0
Show file tree
Hide file tree
Showing 8 changed files with 490 additions and 118 deletions.
2 changes: 1 addition & 1 deletion internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re
case sql.Gen.Go != nil:
out = combo.Go.Out
handler = ext.HandleFunc(golang.Generate)
opts, err := json.Marshal(pluginGoOpts(sql.Gen.Go, combo, result))
opts, err := json.Marshal(sql.Gen.Go)
if err != nil {
return "", nil, fmt.Errorf("opts marshal failed: %w", err)
}
Expand Down
71 changes: 21 additions & 50 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"encoding/json"
"strings"

goopts "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
Expand All @@ -12,7 +13,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
)

func pluginOverride(r *compiler.Result, o config.Override) goopts.Override {
func pluginOverride(r *compiler.Result, o config.Override) *plugin.Override {
var column string
var table plugin.Identifier

Expand All @@ -34,26 +35,36 @@ func pluginOverride(r *compiler.Result, o config.Override) goopts.Override {
column = colParts[3]
}
}
return goopts.Override{
CodeType: "", // FIXME

goTypeJSON, err := json.Marshal(pluginGoType(o))
if err != nil {
panic(err)
}

return &plugin.Override{
CodeType: goTypeJSON,
DbType: o.DBType,
Nullable: o.Nullable,
Unsigned: o.Unsigned,
Column: o.Column,
ColumnName: column,
Table: &table,
GoType: pluginGoType(o),
}
}

func pluginSettings(r *compiler.Result, cs config.CombinedSettings) *plugin.Settings {
var overrides []*plugin.Override
for _, o := range cs.Overrides {
overrides = append(overrides, pluginOverride(r, o))
}
return &plugin.Settings{
Version: cs.Global.Version,
Engine: string(cs.Package.Engine),
Schema: []string(cs.Package.Schema),
Queries: []string(cs.Package.Queries),
Rename: cs.Rename,
Codegen: pluginCodegen(cs, cs.Codegen),
Version: cs.Global.Version,
Engine: string(cs.Package.Engine),
Schema: []string(cs.Package.Schema),
Queries: []string(cs.Package.Queries),
Overrides: overrides,
Rename: cs.Rename,
Codegen: pluginCodegen(cs, cs.Codegen),
}
}

Expand Down Expand Up @@ -111,46 +122,6 @@ func pluginGoType(o config.Override) *goopts.ParsedGoType {
}
}

func pluginGoOpts(sqlGo *config.SQLGo, cs config.CombinedSettings, r *compiler.Result) *goopts.Options {
var overrides []goopts.Override
for _, o := range cs.Overrides {
overrides = append(overrides, pluginOverride(r, o))
}
return &goopts.Options{
EmitInterface: sqlGo.EmitInterface,
EmitJsonTags: sqlGo.EmitJSONTags,
JsonTagsIdUppercase: sqlGo.JsonTagsIDUppercase,
EmitDbTags: sqlGo.EmitDBTags,
EmitPreparedQueries: sqlGo.EmitPreparedQueries,
EmitExactTableNames: sqlGo.EmitExactTableNames,
EmitEmptySlices: sqlGo.EmitEmptySlices,
EmitExportedQueries: sqlGo.EmitExportedQueries,
EmitResultStructPointers: sqlGo.EmitResultStructPointers,
EmitParamsStructPointers: sqlGo.EmitParamsStructPointers,
EmitMethodsWithDbArgument: sqlGo.EmitMethodsWithDBArgument,
EmitPointersForNullTypes: sqlGo.EmitPointersForNullTypes,
EmitEnumValidMethod: sqlGo.EmitEnumValidMethod,
EmitAllEnumValues: sqlGo.EmitAllEnumValues,
JsonTagsCaseStyle: sqlGo.JSONTagsCaseStyle,
Package: sqlGo.Package,
Out: sqlGo.Out,
Overrides: overrides,
// Rename intentionally omitted
SqlPackage: sqlGo.SQLPackage,
SqlDriver: sqlGo.SQLDriver,
OutputBatchFileName: sqlGo.OutputBatchFileName,
OutputDbFileName: sqlGo.OutputDBFileName,
OutputModelsFileName: sqlGo.OutputModelsFileName,
OutputQuerierFileName: sqlGo.OutputQuerierFileName,
OutputCopyfromFileName: sqlGo.OutputCopyFromFileName,
OutputFilesSuffix: sqlGo.OutputFilesSuffix,
InflectionExcludeTableNames: sqlGo.InflectionExcludeTableNames,
QueryParameterLimit: sqlGo.QueryParameterLimit,
OmitUnusedStructs: sqlGo.OmitUnusedStructs,
BuildTags: sqlGo.BuildTags,
}
}

func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
var schemas []*plugin.Schema
for _, s := range c.Schemas {
Expand Down
27 changes: 27 additions & 0 deletions internal/codegen/golang/opts/global_override.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package opts

import (
"github.com/sqlc-dev/sqlc/internal/codegen/sdk"
"github.com/sqlc-dev/sqlc/internal/plugin"
)

type GlobalOverride struct {
*plugin.Override

GoType *ParsedGoType
}

func (o *GlobalOverride) Convert() *plugin.Override {
return &plugin.Override{
DbType: o.DbType,
Nullable: o.Nullable,
Column: o.Column,
Table: o.Table,
ColumnName: o.ColumnName,
Unsigned: o.Unsigned,
}
}

func (o *GlobalOverride) Matches(n *plugin.Identifier, defaultSchema string) bool {
return sdk.Matches(o.Convert(), n, defaultSchema)
}
184 changes: 184 additions & 0 deletions internal/codegen/golang/opts/go_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
package opts

import (
"encoding/json"
"fmt"
"go/types"
"regexp"
"strings"

"github.com/fatih/structtag"
)

type GoType struct {
Path string `json:"import" yaml:"import"`
Package string `json:"package" yaml:"package"`
Name string `json:"type" yaml:"type"`
Pointer bool `json:"pointer" yaml:"pointer"`
Slice bool `json:"slice" yaml:"slice"`
Spec string
BuiltIn bool
}

type ParsedGoType struct {
ImportPath string `json:"import_path"`
Package string `json:"package"`
TypeName string `json:"type_name"`
BasicType bool `json:"basic_type"`
StructTags map[string]string `json:"struct_tags"`
}

func (o *GoType) UnmarshalJSON(data []byte) error {
var spec string
if err := json.Unmarshal(data, &spec); err == nil {
*o = GoType{Spec: spec}
return nil
}
type alias GoType
var a alias
if err := json.Unmarshal(data, &a); err != nil {
return err
}
*o = GoType(a)
return nil
}

func (o *GoType) UnmarshalYAML(unmarshal func(interface{}) error) error {
var spec string
if err := unmarshal(&spec); err == nil {
*o = GoType{Spec: spec}
return nil
}
type alias GoType
var a alias
if err := unmarshal(&a); err != nil {
return err
}
*o = GoType(a)
return nil
}

var validIdentifier = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
var versionNumber = regexp.MustCompile(`^v[0-9]+$`)
var invalidIdentifier = regexp.MustCompile(`[^a-zA-Z0-9_]`)

func generatePackageID(importPath string) (string, bool) {
parts := strings.Split(importPath, "/")
name := parts[len(parts)-1]
// If the last part of the import path is a valid identifier, assume that's the package name
if versionNumber.MatchString(name) && len(parts) >= 2 {
name = parts[len(parts)-2]
return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_"), true
}
if validIdentifier.MatchString(name) {
return name, false
}
return invalidIdentifier.ReplaceAllString(strings.ToLower(name), "_"), true
}

// validate GoType
func (gt GoType) Parse() (*ParsedGoType, error) {
var o ParsedGoType

if gt.Spec == "" {
// TODO: Validation
if gt.Path == "" && gt.Package != "" {
return nil, fmt.Errorf("Package override `go_type`: package name requires an import path")
}
var pkg string
var pkgNeedsAlias bool

if gt.Package == "" && gt.Path != "" {
pkg, pkgNeedsAlias = generatePackageID(gt.Path)
if pkgNeedsAlias {
o.Package = pkg
}
} else {
pkg = gt.Package
o.Package = gt.Package
}

o.ImportPath = gt.Path
o.TypeName = gt.Name
o.BasicType = gt.Path == "" && gt.Package == ""
if pkg != "" {
o.TypeName = pkg + "." + o.TypeName
}
if gt.Pointer {
o.TypeName = "*" + o.TypeName
}
if gt.Slice {
o.TypeName = "[]" + o.TypeName
}
return &o, nil
}

input := gt.Spec
lastDot := strings.LastIndex(input, ".")
lastSlash := strings.LastIndex(input, "/")
typename := input
if lastDot == -1 && lastSlash == -1 {
// if the type name has no slash and no dot, validate that the type is a basic Go type
var found bool
for _, typ := range types.Typ {
info := typ.Info()
if info == 0 {
continue
}
if info&types.IsUntyped != 0 {
continue
}
if typename == typ.Name() {
found = true
}
}
if !found {
return nil, fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", input)
}
o.BasicType = true
} else {
// assume the type lives in a Go package
if lastDot == -1 {
return nil, fmt.Errorf("Package override `go_type` specifier %q is not the proper format, expected 'package.type', e.g. 'github.com/segmentio/ksuid.KSUID'", input)
}
typename = input[lastSlash+1:]
// a package name beginning with "go-" will give syntax errors in
// generated code. We should do the right thing and get the actual
// import name, but in lieu of that, stripping the leading "go-" may get
// us what we want.
typename = strings.TrimPrefix(typename, "go-")
typename = strings.TrimSuffix(typename, "-go")
o.ImportPath = input[:lastDot]
}
o.TypeName = typename
isPointer := input[0] == '*'
if isPointer {
o.ImportPath = o.ImportPath[1:]
o.TypeName = "*" + o.TypeName
}
return &o, nil
}

// GoStructTag is a raw Go struct tag.
type GoStructTag string

// Parse parses and validates a GoStructTag.
// The output is in a form convenient for codegen.
//
// Sample valid inputs/outputs:
//
// In Out
// empty string {}
// `a:"b"` {"a": "b"}
// `a:"b" x:"y,z"` {"a": "b", "x": "y,z"}
func (s GoStructTag) Parse() (map[string]string, error) {
m := make(map[string]string)
tags, err := structtag.Parse(string(s))
if err != nil {
return nil, err
}
for _, tag := range tags.Tags() {
m[tag.Key] = tag.Value()
}
return m, nil
}
Loading

0 comments on commit 491bfe0

Please sign in to comment.