Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor import-loading to simplify the type-generation code #101

Merged
merged 2 commits into from
Sep 17, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/genqlient.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ bindings:
# time.Time
# map[string]interface{}
# github.com/you/yourpkg/subpkg.MyType
# Specifically, this can be any of the following expressions:
# - any named type (qualified by the full package path)
# - any predeclared basic type (string, int, etc.)
# - interface{}
# - for any allowed type T, *T, []T, [N]T, and map[string]T
# but can't be, for example:
# - an inline (unnamed) struct or interface type
# - a map whose key-type is not string
# - a nonstandard way of spelling those, (interface {/* hi */},
# map[ string ]T)
type: time.Time

# To bind an object type:
Expand Down
4 changes: 2 additions & 2 deletions generate/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (g *generator) convertType(
// bind GraphQL named types, at least for now.
localBinding := options.Bind
if localBinding != "" && localBinding != "-" {
goRef, err := g.addRef(localBinding)
goRef, err := g.ref(localBinding)
return &goOpaqueType{goRef, typ.Name()}, err
}

Expand Down Expand Up @@ -217,7 +217,7 @@ func (g *generator) convertDefinition(
return nil, err
}
}
goRef, err := g.addRef(globalBinding.Type)
goRef, err := g.ref(globalBinding.Type)
return &goOpaqueType{goRef, def.Name}, err
}
goBuiltinName, ok := builtinTypes[def.Name]
Expand Down
108 changes: 65 additions & 43 deletions generate/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ package generate
import (
"bytes"
"encoding/json"
"fmt"
"go/format"
"io"
"sort"
"strings"
"text/template"
Expand All @@ -25,12 +25,15 @@ type generator struct {
// The config for which we are generating code.
Config *Config
// The list of operations for which to generate code.
Operations []operation
Operations []*operation
// The types needed for these operations.
typeMap map[string]goType
// Imports needed for these operations, path -> alias and alias -> true
imports map[string]string
usedAliases map[string]bool
// True if we've already written out the imports (in which case they can't
// be modified).
importsLocked bool
// Cache of loaded templates.
templateCache map[string]*template.Template
// Schema we are generating code against
Expand Down Expand Up @@ -58,10 +61,12 @@ type operation struct {
ResponseName string `json:"-"`
// The original filename from which we got this query.
SourceFilename string `json:"sourceLocation"`
// The config within which we are generating code.
Config *Config `json:"-"`
}

type exportedOperations struct {
Operations []operation `json:"operations"`
Operations []*operation `json:"operations"`
}

type argument struct {
Expand All @@ -76,7 +81,7 @@ func newGenerator(
config *Config,
schema *ast.Schema,
fragments ast.FragmentDefinitionList,
) (*generator, error) {
) *generator {
g := generator{
Config: config,
typeMap: map[string]goType{},
Expand All @@ -91,29 +96,10 @@ func newGenerator(
g.fragments[fragment.Name] = fragment
}

_, err := g.addRef("github.com/Khan/genqlient/graphql.Client")
if err != nil {
return nil, err
}

if g.Config.ClientGetter != "" {
_, err := g.addRef(g.Config.ClientGetter)
if err != nil {
return nil, fmt.Errorf("invalid client_getter: %w", err)
}
}

if g.Config.ContextType != "-" {
_, err := g.addRef(g.Config.ContextType)
if err != nil {
return nil, fmt.Errorf("invalid context_type: %w", err)
}
}

return &g, nil
return &g
}

func (g *generator) Types() (string, error) {
func (g *generator) WriteTypes(w io.Writer) error {
names := make([]string, 0, len(g.typeMap))
for name := range g.typeMap {
names = append(names, name)
Expand All @@ -124,17 +110,19 @@ func (g *generator) Types() (string, error) {
// vaguely aligned to the structure of the queries.
sort.Strings(names)

defs := make([]string, 0, len(g.typeMap))
var builder strings.Builder
for _, name := range names {
builder.Reset()
err := g.typeMap[name].WriteDefinition(&builder, g)
err := g.typeMap[name].WriteDefinition(w, g)
if err != nil {
return err
}
// Make sure we have blank lines between types (and between the last
// type and the first operation)
_, err = io.WriteString(w, "\n\n")
if err != nil {
return "", err
return err
}
defs = append(defs, builder.String())
}
return strings.Join(defs, "\n\n"), nil
return nil
}

func (g *generator) getArgument(
Expand Down Expand Up @@ -317,15 +305,18 @@ func (g *generator) addOperation(op *ast.OperationDefinition) error {
sourceFilename = sourceFilename[:i]
}

g.Operations = append(g.Operations, operation{
g.Operations = append(g.Operations, &operation{
Type: op.Operation,
Name: op.Name,
Doc: docComment,
// The newline just makes it format a little nicer.
// The newline just makes it format a little nicer. We add it here
// rather than in the template so exported operations will match
// *exactly* what we send to the server.
Body: "\n" + builder.String(),
Args: args,
ResponseName: responseType.Reference(),
SourceFilename: sourceFilename,
Config: g.Config, // for the convenience of the template
})

return nil
Expand Down Expand Up @@ -362,24 +353,55 @@ func Generate(config *Config) (map[string][]byte, error) {

// Step 2: For each operation and fragment, convert it into data structures
// representing Go types (defined in types.go). The bulk of this logic is
// in convert.go.
g, err := newGenerator(config, schema, document.Fragments)
// in convert.go, and it additionally updates g.typeMap to include all the
// types it needs.
g := newGenerator(config, schema, document.Fragments)
for _, op := range document.Operations {
if err = g.addOperation(op); err != nil {
return nil, err
}
}

// Step 3: Glue it all together!
//
// First, write the types (from g.typeMap) and operations to a temporary
// buffer, since they affect what imports we'll put in the header.
var bodyBuf bytes.Buffer
err = g.WriteTypes(&bodyBuf)
if err != nil {
return nil, err
}
for _, op := range document.Operations {
if err = g.addOperation(op); err != nil {
for _, operation := range g.Operations {
err = g.execute("operation.go.tmpl", &bodyBuf, operation)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know "execute" is based on text/template terminology, but that isn't super clear in the context of the generator, IMO. Perhaps this function could be named renderTemplate or something similar.

if err != nil {
return nil, err
}
}

// The header also needs to reference some context types, which it does
// after it writes the imports, so we need to preregister those imports.
if g.Config.ContextType != "-" {
_, err = g.ref("context.Context")
if err != nil {
return nil, err
}
if g.Config.ContextType != "context.Context" {
_, err = g.ref(g.Config.ContextType)
if err != nil {
return nil, err
}
}
}

// Step 3: Glue it all together! Most of this is done inline in the
// template, but the call to g.Types() in the template calls out to
// types.go to actually generate the code for each type.
// Now really glue it all together, and format.
var buf bytes.Buffer
err = g.execute("operation.go.tmpl", &buf, g)
err = g.execute("header.go.tmpl", &buf, g)
if err != nil {
return nil, errorf(nil, "could not render template: %v", err)
return nil, err
}
_, err = io.Copy(&buf, &bodyBuf)
if err != nil {
return nil, err
}

unformatted := buf.Bytes()
Expand Down
10 changes: 10 additions & 0 deletions generate/header.go.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package {{.Config.Package}}

// Code generated by github.com/Khan/genqlient, DO NOT EDIT.

{{.Imports}}

{{if and (ne .Config.ContextType "-") (ne .Config.ContextType "context.Context") }}
// Check that context_type from genqlient.yaml implements context.Context.
var _ {{ref "context.Context"}} = ({{ref .Config.ContextType}})(nil)
{{end}}
74 changes: 29 additions & 45 deletions generate/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,60 +9,46 @@ import (
)

func (g *generator) addImportFor(pkgPath string) (alias string) {
if existingAlias, ok := g.imports[pkgPath]; ok {
return existingAlias
}

pkgName := pkgPath[strings.LastIndex(pkgPath, "/")+1:]
alias = pkgName
suffix := 2
for g.usedAliases[alias] {
alias = pkgName + strconv.Itoa(suffix)
suffix++
}

g.imports[pkgPath] = alias
g.usedAliases[alias] = true
return alias
}

// addRef adds any imports necessary to refer to the given name, and returns a
// reference alias.Name for it.
func (g *generator) addRef(fullyQualifiedName string) (qualifiedName string, err error) {
return g.getRef(fullyQualifiedName, true)
}

// ref returns a reference alias.Name for the given import, if its package was
// already added (e.g. via addRef), and an error if not.
func (g *generator) ref(fullyQualifiedName string) (qualifiedName string, err error) {
return g.getRef(fullyQualifiedName, false)
}

var _sliceOrMapPrefixRegexp = regexp.MustCompile(`^(\*|\[\d*\]|map\[string\])*`)

func (g *generator) getRef(fullyQualifiedName string, addImport bool) (qualifiedName string, err error) {
// Ideally, we want to allow a reference to basically an arbitrary symbol.
// But that's very hard, because it might be quite complicated, like
// struct{ F []map[mypkg.K]otherpkg.V }
// Now in practice, using an unnamed struct is not a great idea, but we do
// want to allow as much as we can that encoding/json knows how to work
// with, since you would reasonably expect us to accept, say,
// map[string][]interface{}. So we allow:
// - any named type (mypkg.T)
// - any predeclared basic type (string, int, etc.)
// - interface{}
// - for any allowed type T, *T, []T, [N]T, and map[string]T
// which effectively excludes:
// - unnamed struct types
// - map[K]V where K is a named type wrapping string
// - any nonstandard spelling of those (interface {/* hi */},
// map[ string ]T)
// TODO: document that somewhere visible

// ref takes a Go fully-qualified name, ensures that any necessary symbols are
// imported, and returns an appropriate reference.
//
// Ideally, we want to allow a reference to basically an arbitrary symbol.
// But that's very hard, because it might be quite complicated, like
// struct{ F []map[mypkg.K]otherpkg.V }
// Now in practice, using an unnamed struct is not a great idea, but we do
// want to allow as much as we can that encoding/json knows how to work
// with, since you would reasonably expect us to accept, say,
// map[string][]interface{}. So we allow:
// - any named type (mypkg.T)
// - any predeclared basic type (string, int, etc.)
// - interface{}
// - for any allowed type T, *T, []T, [N]T, and map[string]T
// which effectively excludes:
// - unnamed struct types
// - map[K]V where K is a named type wrapping string
// - any nonstandard spelling of those (interface {/* hi */},
// map[ string ]T)
// (This is documented in docs/genqlient.yaml)
func (g *generator) ref(fullyQualifiedName string) (qualifiedName string, err error) {
errorMsg := `invalid type-name "%v" (%v); expected a builtin, ` +
`path/to/package.Name, interface{}, or a slice, map, or pointer of those`

if strings.Contains(fullyQualifiedName, " ") {
// TODO: pass in pos here and below
return "", errorf(nil, errorMsg, fullyQualifiedName, "contains spaces")
}

Expand All @@ -80,22 +66,20 @@ func (g *generator) getRef(fullyQualifiedName string, addImport bool) (qualified

pkgPath := nameToImport[:i]
localName := nameToImport[i+1:]
var alias string
if addImport {
alias = g.addImportFor(pkgPath)
} else {
var ok bool
alias, ok = g.imports[pkgPath]
if !ok {
// This is an internal error, not a user error.
return "", errorf(nil, `no alias defined for package "%v"`, pkgPath)
alias, ok := g.imports[pkgPath]
if !ok {
if g.importsLocked {
return "", errorf(nil,
`genqlient internal error: imports locked but no alias defined for package "%v"`, pkgPath)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps: "imports locked but package "%v" has not been imported"?

}
alias = g.addImportFor(pkgPath)
}
return prefix + alias + "." + localName, nil
}

// Returns the import-clause to use in the generated code.
func (g *generator) Imports() string {
g.importsLocked = true
if len(g.imports) == 0 {
return ""
}
Expand Down
Loading