diff --git a/docs/genqlient.yaml b/docs/genqlient.yaml index d1a12bfe..8dc540b2 100644 --- a/docs/genqlient.yaml +++ b/docs/genqlient.yaml @@ -93,6 +93,16 @@ bindings: # time.Time # map[string]interface{} # github.com/you/yourpkg/subpkg.MyType + # Specifically, this can be any of the following expressions: + # - any named type (qualified by the full package path) + # - any predeclared basic type (string, int, etc.) + # - interface{} + # - for any allowed type T, *T, []T, [N]T, and map[string]T + # but can't be, for example: + # - an inline (unnamed) struct or interface type + # - a map whose key-type is not string + # - a nonstandard way of spelling those, (interface {/* hi */}, + # map[ string ]T) type: time.Time # To bind an object type: diff --git a/generate/convert.go b/generate/convert.go index 609c9932..06271e8e 100644 --- a/generate/convert.go +++ b/generate/convert.go @@ -165,7 +165,7 @@ func (g *generator) convertType( // bind GraphQL named types, at least for now. localBinding := options.Bind if localBinding != "" && localBinding != "-" { - goRef, err := g.addRef(localBinding) + goRef, err := g.ref(localBinding) return &goOpaqueType{goRef, typ.Name()}, err } @@ -217,7 +217,7 @@ func (g *generator) convertDefinition( return nil, err } } - goRef, err := g.addRef(globalBinding.Type) + goRef, err := g.ref(globalBinding.Type) return &goOpaqueType{goRef, def.Name}, err } goBuiltinName, ok := builtinTypes[def.Name] diff --git a/generate/generate.go b/generate/generate.go index 2adbf3ce..08f2ce5e 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -7,8 +7,8 @@ package generate import ( "bytes" "encoding/json" - "fmt" "go/format" + "io" "sort" "strings" "text/template" @@ -25,12 +25,15 @@ type generator struct { // The config for which we are generating code. Config *Config // The list of operations for which to generate code. - Operations []operation + Operations []*operation // The types needed for these operations. typeMap map[string]goType // Imports needed for these operations, path -> alias and alias -> true imports map[string]string usedAliases map[string]bool + // True if we've already written out the imports (in which case they can't + // be modified). + importsLocked bool // Cache of loaded templates. templateCache map[string]*template.Template // Schema we are generating code against @@ -58,10 +61,12 @@ type operation struct { ResponseName string `json:"-"` // The original filename from which we got this query. SourceFilename string `json:"sourceLocation"` + // The config within which we are generating code. + Config *Config `json:"-"` } type exportedOperations struct { - Operations []operation `json:"operations"` + Operations []*operation `json:"operations"` } type argument struct { @@ -76,7 +81,7 @@ func newGenerator( config *Config, schema *ast.Schema, fragments ast.FragmentDefinitionList, -) (*generator, error) { +) *generator { g := generator{ Config: config, typeMap: map[string]goType{}, @@ -91,29 +96,10 @@ func newGenerator( g.fragments[fragment.Name] = fragment } - _, err := g.addRef("github.com/Khan/genqlient/graphql.Client") - if err != nil { - return nil, err - } - - if g.Config.ClientGetter != "" { - _, err := g.addRef(g.Config.ClientGetter) - if err != nil { - return nil, fmt.Errorf("invalid client_getter: %w", err) - } - } - - if g.Config.ContextType != "-" { - _, err := g.addRef(g.Config.ContextType) - if err != nil { - return nil, fmt.Errorf("invalid context_type: %w", err) - } - } - - return &g, nil + return &g } -func (g *generator) Types() (string, error) { +func (g *generator) WriteTypes(w io.Writer) error { names := make([]string, 0, len(g.typeMap)) for name := range g.typeMap { names = append(names, name) @@ -124,17 +110,19 @@ func (g *generator) Types() (string, error) { // vaguely aligned to the structure of the queries. sort.Strings(names) - defs := make([]string, 0, len(g.typeMap)) - var builder strings.Builder for _, name := range names { - builder.Reset() - err := g.typeMap[name].WriteDefinition(&builder, g) + err := g.typeMap[name].WriteDefinition(w, g) + if err != nil { + return err + } + // Make sure we have blank lines between types (and between the last + // type and the first operation) + _, err = io.WriteString(w, "\n\n") if err != nil { - return "", err + return err } - defs = append(defs, builder.String()) } - return strings.Join(defs, "\n\n"), nil + return nil } func (g *generator) getArgument( @@ -317,15 +305,18 @@ func (g *generator) addOperation(op *ast.OperationDefinition) error { sourceFilename = sourceFilename[:i] } - g.Operations = append(g.Operations, operation{ + g.Operations = append(g.Operations, &operation{ Type: op.Operation, Name: op.Name, Doc: docComment, - // The newline just makes it format a little nicer. + // The newline just makes it format a little nicer. We add it here + // rather than in the template so exported operations will match + // *exactly* what we send to the server. Body: "\n" + builder.String(), Args: args, ResponseName: responseType.Reference(), SourceFilename: sourceFilename, + Config: g.Config, // for the convenience of the template }) return nil @@ -362,24 +353,55 @@ func Generate(config *Config) (map[string][]byte, error) { // Step 2: For each operation and fragment, convert it into data structures // representing Go types (defined in types.go). The bulk of this logic is - // in convert.go. - g, err := newGenerator(config, schema, document.Fragments) + // in convert.go, and it additionally updates g.typeMap to include all the + // types it needs. + g := newGenerator(config, schema, document.Fragments) + for _, op := range document.Operations { + if err = g.addOperation(op); err != nil { + return nil, err + } + } + + // Step 3: Glue it all together! + // + // First, write the types (from g.typeMap) and operations to a temporary + // buffer, since they affect what imports we'll put in the header. + var bodyBuf bytes.Buffer + err = g.WriteTypes(&bodyBuf) if err != nil { return nil, err } - for _, op := range document.Operations { - if err = g.addOperation(op); err != nil { + for _, operation := range g.Operations { + err = g.render("operation.go.tmpl", &bodyBuf, operation) + if err != nil { + return nil, err + } + } + + // The header also needs to reference some context types, which it does + // after it writes the imports, so we need to preregister those imports. + if g.Config.ContextType != "-" { + _, err = g.ref("context.Context") + if err != nil { return nil, err } + if g.Config.ContextType != "context.Context" { + _, err = g.ref(g.Config.ContextType) + if err != nil { + return nil, err + } + } } - // Step 3: Glue it all together! Most of this is done inline in the - // template, but the call to g.Types() in the template calls out to - // types.go to actually generate the code for each type. + // Now really glue it all together, and format. var buf bytes.Buffer - err = g.execute("operation.go.tmpl", &buf, g) + err = g.render("header.go.tmpl", &buf, g) if err != nil { - return nil, errorf(nil, "could not render template: %v", err) + return nil, err + } + _, err = io.Copy(&buf, &bodyBuf) + if err != nil { + return nil, err } unformatted := buf.Bytes() diff --git a/generate/header.go.tmpl b/generate/header.go.tmpl new file mode 100644 index 00000000..0055504b --- /dev/null +++ b/generate/header.go.tmpl @@ -0,0 +1,10 @@ +package {{.Config.Package}} + +// Code generated by github.com/Khan/genqlient, DO NOT EDIT. + +{{.Imports}} + +{{if and (ne .Config.ContextType "-") (ne .Config.ContextType "context.Context") }} +// Check that context_type from genqlient.yaml implements context.Context. +var _ {{ref "context.Context"}} = ({{ref .Config.ContextType}})(nil) +{{end}} diff --git a/generate/imports.go b/generate/imports.go index e7a60079..89e5182c 100644 --- a/generate/imports.go +++ b/generate/imports.go @@ -9,15 +9,12 @@ import ( ) func (g *generator) addImportFor(pkgPath string) (alias string) { - if existingAlias, ok := g.imports[pkgPath]; ok { - return existingAlias - } - pkgName := pkgPath[strings.LastIndex(pkgPath, "/")+1:] alias = pkgName suffix := 2 for g.usedAliases[alias] { alias = pkgName + strconv.Itoa(suffix) + suffix++ } g.imports[pkgPath] = alias @@ -25,44 +22,33 @@ func (g *generator) addImportFor(pkgPath string) (alias string) { return alias } -// addRef adds any imports necessary to refer to the given name, and returns a -// reference alias.Name for it. -func (g *generator) addRef(fullyQualifiedName string) (qualifiedName string, err error) { - return g.getRef(fullyQualifiedName, true) -} - -// ref returns a reference alias.Name for the given import, if its package was -// already added (e.g. via addRef), and an error if not. -func (g *generator) ref(fullyQualifiedName string) (qualifiedName string, err error) { - return g.getRef(fullyQualifiedName, false) -} - var _sliceOrMapPrefixRegexp = regexp.MustCompile(`^(\*|\[\d*\]|map\[string\])*`) -func (g *generator) getRef(fullyQualifiedName string, addImport bool) (qualifiedName string, err error) { - // Ideally, we want to allow a reference to basically an arbitrary symbol. - // But that's very hard, because it might be quite complicated, like - // struct{ F []map[mypkg.K]otherpkg.V } - // Now in practice, using an unnamed struct is not a great idea, but we do - // want to allow as much as we can that encoding/json knows how to work - // with, since you would reasonably expect us to accept, say, - // map[string][]interface{}. So we allow: - // - any named type (mypkg.T) - // - any predeclared basic type (string, int, etc.) - // - interface{} - // - for any allowed type T, *T, []T, [N]T, and map[string]T - // which effectively excludes: - // - unnamed struct types - // - map[K]V where K is a named type wrapping string - // - any nonstandard spelling of those (interface {/* hi */}, - // map[ string ]T) - // TODO: document that somewhere visible - +// ref takes a Go fully-qualified name, ensures that any necessary symbols are +// imported, and returns an appropriate reference. +// +// Ideally, we want to allow a reference to basically an arbitrary symbol. +// But that's very hard, because it might be quite complicated, like +// struct{ F []map[mypkg.K]otherpkg.V } +// Now in practice, using an unnamed struct is not a great idea, but we do +// want to allow as much as we can that encoding/json knows how to work +// with, since you would reasonably expect us to accept, say, +// map[string][]interface{}. So we allow: +// - any named type (mypkg.T) +// - any predeclared basic type (string, int, etc.) +// - interface{} +// - for any allowed type T, *T, []T, [N]T, and map[string]T +// which effectively excludes: +// - unnamed struct types +// - map[K]V where K is a named type wrapping string +// - any nonstandard spelling of those (interface {/* hi */}, +// map[ string ]T) +// (This is documented in docs/genqlient.yaml) +func (g *generator) ref(fullyQualifiedName string) (qualifiedName string, err error) { errorMsg := `invalid type-name "%v" (%v); expected a builtin, ` + `path/to/package.Name, interface{}, or a slice, map, or pointer of those` if strings.Contains(fullyQualifiedName, " ") { - // TODO: pass in pos here and below return "", errorf(nil, errorMsg, fullyQualifiedName, "contains spaces") } @@ -80,22 +66,20 @@ func (g *generator) getRef(fullyQualifiedName string, addImport bool) (qualified pkgPath := nameToImport[:i] localName := nameToImport[i+1:] - var alias string - if addImport { - alias = g.addImportFor(pkgPath) - } else { - var ok bool - alias, ok = g.imports[pkgPath] - if !ok { - // This is an internal error, not a user error. - return "", errorf(nil, `no alias defined for package "%v"`, pkgPath) + alias, ok := g.imports[pkgPath] + if !ok { + if g.importsLocked { + return "", errorf(nil, + `genqlient internal error: imports locked but package "%v" has not been imported`, pkgPath) } + alias = g.addImportFor(pkgPath) } return prefix + alias + "." + localName, nil } // Returns the import-clause to use in the generated code. func (g *generator) Imports() string { + g.importsLocked = true if len(g.imports) == 0 { return "" } diff --git a/generate/operation.go.tmpl b/generate/operation.go.tmpl index c5614b51..e5cadf68 100644 --- a/generate/operation.go.tmpl +++ b/generate/operation.go.tmpl @@ -1,20 +1,9 @@ -package {{.Config.Package}} - -// Code generated by github.com/Khan/genqlient, DO NOT EDIT. - -{{.Imports}} - -{{/* TODO: type-assert that your ctx type implements context.Context */}} - -{{.Types}} - -{{range .Operations}} {{.Doc}} func {{.Name}}( - {{if ne $.Config.ContextType "-" -}} - ctx {{ref $.Config.ContextType}}, + {{if ne .Config.ContextType "-" -}} + ctx {{ref .Config.ContextType}}, {{end}} - {{- if not $.Config.ClientGetter -}} + {{- if not .Config.ClientGetter -}} client {{ref "github.com/Khan/genqlient/graphql.Client"}}, {{end}} {{- range .Args -}} @@ -46,8 +35,8 @@ func {{.Name}}( {{end -}} var err error - {{if $.Config.ClientGetter -}} - client, err := {{ref $.Config.ClientGetter}}({{if ne $.Config.ContextType "-"}}ctx{{else}}{{end}}) + {{if .Config.ClientGetter -}} + client, err := {{ref .Config.ClientGetter}}({{if ne .Config.ContextType "-"}}ctx{{else}}{{end}}) if err != nil { return nil, err } @@ -55,7 +44,7 @@ func {{.Name}}( var retval {{.ResponseName}} err = client.MakeRequest( - {{if ne $.Config.ContextType "-"}}ctx{{else}}nil{{end}}, + {{if ne .Config.ContextType "-"}}ctx{{else}}nil{{end}}, "{{.Name}}", `{{.Body}}`, &retval, @@ -63,4 +52,3 @@ func {{.Name}}( ) return &retval, err } -{{end}} diff --git a/generate/template.go b/generate/template.go index 984fdf61..6e3ab438 100644 --- a/generate/template.go +++ b/generate/template.go @@ -33,8 +33,8 @@ func intRange(n int) []int { func sub(x, y int) int { return x - y } -// execute executes the given template with the funcs from this generator. -func (g *generator) execute(tmplRelFilename string, w io.Writer, data interface{}) error { +// render executes the given template with the funcs from this generator. +func (g *generator) render(tmplRelFilename string, w io.Writer, data interface{}) error { tmpl := g.templateCache[tmplRelFilename] if tmpl == nil { absFilename := filepath.Join(thisDir, tmplRelFilename) diff --git a/generate/testdata/snapshots/TestGenerateWithConfig-ClientGetterCustomContext-testdata-queries-generated.go b/generate/testdata/snapshots/TestGenerateWithConfig-ClientGetterCustomContext-testdata-queries-generated.go index 107279c1..098bc980 100644 --- a/generate/testdata/snapshots/TestGenerateWithConfig-ClientGetterCustomContext-testdata-queries-generated.go +++ b/generate/testdata/snapshots/TestGenerateWithConfig-ClientGetterCustomContext-testdata-queries-generated.go @@ -3,9 +3,14 @@ package queries // Code generated by github.com/Khan/genqlient, DO NOT EDIT. import ( + "context" + "github.com/Khan/genqlient/internal/testutil" ) +// Check that context_type from genqlient.yaml implements context.Context. +var _ context.Context = (testutil.MyContext)(nil) + // SimpleQueryResponse is returned by SimpleQuery on success. type SimpleQueryResponse struct { // user looks up a user by some stuff. diff --git a/generate/testdata/snapshots/TestGenerateWithConfig-CustomContext-testdata-queries-generated.go b/generate/testdata/snapshots/TestGenerateWithConfig-CustomContext-testdata-queries-generated.go index ff9aa398..3c7a7ef0 100644 --- a/generate/testdata/snapshots/TestGenerateWithConfig-CustomContext-testdata-queries-generated.go +++ b/generate/testdata/snapshots/TestGenerateWithConfig-CustomContext-testdata-queries-generated.go @@ -3,10 +3,15 @@ package queries // Code generated by github.com/Khan/genqlient, DO NOT EDIT. import ( + "context" + "github.com/Khan/genqlient/graphql" "github.com/Khan/genqlient/internal/testutil" ) +// Check that context_type from genqlient.yaml implements context.Context. +var _ context.Context = (testutil.MyContext)(nil) + // SimpleQueryResponse is returned by SimpleQuery on success. type SimpleQueryResponse struct { // user looks up a user by some stuff. diff --git a/generate/types.go b/generate/types.go index 653c0be5..e29f32a4 100644 --- a/generate/types.go +++ b/generate/types.go @@ -202,14 +202,7 @@ func (typ *goStructType) WriteDefinition(w io.Writer, g *generator) error { return nil } - // TODO(benkraft): Avoid having to enumerate these in advance; just let the - // template add them directly. - _, err := g.addRef("encoding/json.Unmarshal") - if err != nil { - return err - } - - return g.execute("unmarshal.go.tmpl", w, typ) + return g.render("unmarshal.go.tmpl", w, typ) } func (typ *goStructType) Reference() string { return typ.GoName } @@ -287,19 +280,7 @@ func (typ *goInterfaceType) WriteDefinition(w io.Writer, g *generator) error { // Finally, write the unmarshal-helper, which will be called by struct // fields referencing this type (see goStructType.WriteDefinition). - // - // TODO(benkraft): Avoid having to enumerate these refs in advance; just - // let the template add them directly. - _, err := g.addRef("encoding/json.Unmarshal") - if err != nil { - return err - } - _, err = g.addRef("fmt.Errorf") - if err != nil { - return err - } - - return g.execute("unmarshal_helper.go.tmpl", w, typ) + return g.render("unmarshal_helper.go.tmpl", w, typ) } func (typ *goInterfaceType) Reference() string { return typ.GoName }