From b21cc62342ea2551c94b76faaba9b536cb8dcb38 Mon Sep 17 00:00:00 2001 From: Dmitry Doroginin Date: Thu, 21 Jun 2018 16:42:18 +0300 Subject: [PATCH] Add impl gen, optional swagger gen, support std http.ServeMux, swagger json options, go_package supporting --- cmd/protoc-gen-goclay/genhandler/handler.go | 237 ++++++++++--- cmd/protoc-gen-goclay/genhandler/options.go | 48 +++ cmd/protoc-gen-goclay/genhandler/template.go | 334 +++++++++++-------- cmd/protoc-gen-goclay/main.go | 77 +++-- cmd/protoc-gen-goclay/swagger.go | 11 +- transport/handlers.go | 14 +- transport/swagopts.go | 46 +++ 7 files changed, 541 insertions(+), 226 deletions(-) create mode 100644 cmd/protoc-gen-goclay/genhandler/options.go create mode 100644 transport/swagopts.go diff --git a/cmd/protoc-gen-goclay/genhandler/handler.go b/cmd/protoc-gen-goclay/genhandler/handler.go index 6b03c80..662be03 100644 --- a/cmd/protoc-gen-goclay/genhandler/handler.go +++ b/cmd/protoc-gen-goclay/genhandler/handler.go @@ -1,107 +1,156 @@ package genhandler import ( + "bytes" + "encoding/json" "fmt" + "go/build" "go/format" + "os" "path" "path/filepath" "strings" + "github.com/go-openapi/spec" "github.com/golang/glog" "github.com/golang/protobuf/proto" plugin "github.com/golang/protobuf/protoc-gen-go/plugin" "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor" - options "google.golang.org/genproto/googleapis/api/annotations" + "google.golang.org/genproto/googleapis/api/annotations" ) type Generator struct { - reg *descriptor.Registry - baseImports []descriptor.GoPackage + options options + reg *descriptor.Registry + imports []descriptor.GoPackage // common imports } // New returns a new generator which generates handler wrappers. -func New(reg *descriptor.Registry) *Generator { - var imports []descriptor.GoPackage - for _, pkgpath := range []string{ - "net/http", - - "github.com/utrack/clay/transport", - "github.com/utrack/clay/transport/httpruntime", - - "github.com/grpc-ecosystem/grpc-gateway/runtime", +func New(reg *descriptor.Registry, opts ...Option) *Generator { + o := options{} + for _, opt := range opts { + opt(&o) + } + g := &Generator{ + options: o, + reg: reg, + } + g.imports = append(g.imports, + g.newGoPackage("github.com/pkg/errors"), + g.newGoPackage("github.com/utrack/clay/transport"), + ) + return g +} - "google.golang.org/grpc", - "github.com/go-chi/chi", - "github.com/pkg/errors", - } { - pkg := descriptor.GoPackage{ - Path: pkgpath, - Name: path.Base(pkgpath), - } - if err := reg.ReserveGoPackageAlias(pkg.Name, pkg.Path); err != nil { - for i := 0; ; i++ { - alias := fmt.Sprintf("%s_%d", pkg.Name, i) - if err := reg.ReserveGoPackageAlias(alias, pkg.Path); err != nil { - continue - } - pkg.Alias = alias - break +func (g *Generator) newGoPackage(pkgPath string) descriptor.GoPackage { + gopkg := descriptor.GoPackage{ + Path: pkgPath, + Name: path.Base(pkgPath), + } + if err := g.reg.ReserveGoPackageAlias(gopkg.Name, gopkg.Path); err != nil { + for i := 0; ; i++ { + alias := fmt.Sprintf("%s_%d", gopkg.Name, i) + if err := g.reg.ReserveGoPackageAlias(alias, gopkg.Path); err != nil { + continue } + gopkg.Alias = alias + break } - imports = append(imports, pkg) } - return &Generator{reg: reg, baseImports: imports} + return gopkg } -func (g *Generator) Generate(targets []*descriptor.File, fileToSwagger map[string][]byte) ([]*plugin.CodeGeneratorResponse_File, error) { +func (g *Generator) Generate(targets []*descriptor.File) ([]*plugin.CodeGeneratorResponse_File, error) { var files []*plugin.CodeGeneratorResponse_File for _, file := range targets { glog.V(1).Infof("Processing %s", file.GetName()) - code, err := g.getTemplate(fileToSwagger[file.GetName()], file) - if err == errNoTargetService { - glog.V(1).Infof("%s: %v", file.GetName(), err) + if len(file.Services) == 0 { + glog.V(0).Infof("%s: %v", file.GetName(), errNoTargetService) continue } + descCode, err := g.getDescTemplate(g.options.SwaggerDef[file.GetName()], file) + if err != nil { return nil, err } - formatted, err := format.Source([]byte(code)) + formatted, err := format.Source([]byte(descCode)) if err != nil { - - glog.Errorf("%v: %s", err, annotateString(code)) + glog.Errorf("%v: %s", err, annotateString(descCode)) return nil, err } name := file.GetName() ext := filepath.Ext(name) base := strings.TrimSuffix(name, ext) - output := fmt.Sprintf("%s.pb.goclay.go", base) + + goPkg := "" + if file.GoPkg.Path != "." { + goPkg = file.GoPkg.Path + } + output := fmt.Sprintf(filepath.Join(goPkg, g.options.DescPath, "%s.pb.goclay.go"), base) + output = filepath.Clean(output) + files = append(files, &plugin.CodeGeneratorResponse_File{ Name: proto.String(output), Content: proto.String(string(formatted)), }) glog.V(1).Infof("Will emit %s", output) + + if g.options.Impl { + output := fmt.Sprintf(filepath.Join(goPkg, g.options.ImplPath, "%s.pb.impl.go"), base) + output = filepath.Clean(output) + + if !g.options.Force && fileExists(output) { + glog.V(0).Infof("Implementation will not be emitted: file '%s' already exists", output) + continue + } + + implCode, err := g.getImplTemplate(file) + if err != nil { + return nil, err + } + formatted, err := format.Source([]byte(implCode)) + if err != nil { + glog.Errorf("%v: %s", err, annotateString(implCode)) + return nil, err + } + + files = append(files, &plugin.CodeGeneratorResponse_File{ + Name: proto.String(output), + Content: proto.String(string(formatted)), + }) + glog.V(1).Infof("Will emit %s", output) + } } return files, nil } -func (g *Generator) getTemplate(swagBuffer []byte, f *descriptor.File) (string, error) { - if len(f.Services) == 0 { - return "", errNoTargetService - } +func (g *Generator) getDescTemplate(swagger *spec.Swagger, f *descriptor.File) (string, error) { pkgSeen := make(map[string]bool) var imports []descriptor.GoPackage - for _, pkg := range g.baseImports { + for _, pkg := range g.imports { pkgSeen[pkg.Path] = true imports = append(imports, pkg) } + for _, pkg := range []string{ + "net/http", + "github.com/utrack/clay/transport/httpruntime", + "github.com/grpc-ecosystem/grpc-gateway/runtime", + "google.golang.org/grpc", + "github.com/go-chi/chi", + "github.com/go-openapi/spec", + } { + pkgSeen[pkg] = true + imports = append(imports, g.newGoPackage(pkg)) + } + for _, svc := range f.Services { for _, m := range svc.Methods { pkg := m.RequestType.File.GoPkg // Add request type package to imports if needed - if m.Options == nil || !proto.HasExtension(m.Options, options.E_Http) || + if m.Options == nil || !proto.HasExtension(m.Options, annotations.E_Http) || pkg == f.GoPkg || pkgSeen[pkg.Path] { continue } @@ -109,8 +158,61 @@ func (g *Generator) getTemplate(swagBuffer []byte, f *descriptor.File) (string, imports = append(imports, pkg) } } + p := param{File: f, Imports: imports} + if swagger != nil { + b, err := swagger.MarshalJSON() + var buf bytes.Buffer + err = json.Indent(&buf, b, "", " ") + if err != nil { + return "", err + } + p.SwaggerBuffer = buf.Bytes() + } + return applyDescTemplate(p) +} - return applyTemplate(param{SwagBuffer: swagBuffer, File: f, Imports: imports}) +func (g *Generator) getImplTemplate(f *descriptor.File) (string, error) { + pkgSeen := make(map[string]bool) + var imports []descriptor.GoPackage + for _, pkg := range g.imports { + pkgSeen[pkg.Path] = true + imports = append(imports, pkg) + } + for _, pkg := range []string{ + "context", + } { + pkgSeen[pkg] = true + imports = append(imports, g.newGoPackage(pkg)) + } + p := param{ + File: f, + } + if g.options.ImplPath != g.options.DescPath { + pkg := filepath.Join(getRootImportPath(f), g.options.DescPath) + pkgSeen[pkg] = true + gopkg := g.newGoPackage(pkg) + imports = append(imports, gopkg) + if gopkg.Alias != "" { + p.DescPrefix = gopkg.Alias + "." + } else { + p.DescPrefix = gopkg.Name + "." + } + } + for _, svc := range f.Services { + for _, m := range svc.Methods { + pkg := m.RequestType.File.GoPkg + // Add request type package to imports if needed + if m.Options == nil || !proto.HasExtension(m.Options, annotations.E_Http) || + pkg == f.GoPkg || pkgSeen[pkg.Path] { + continue + } + pkgSeen[pkg.Path] = true + imports = append(imports, pkg) + } + } + p.Imports = imports + + return applyImplTemplate(p) } func annotateString(str string) string { @@ -120,3 +222,48 @@ func annotateString(str string) string { } return strings.Join(strs, "\n") } + +func fileExists(path string) bool { + dir, err := filepath.Abs(filepath.Dir(os.Args[0])) + if err != nil { + glog.V(0).Info(err) + } + dir, err = filepath.EvalSymlinks(dir) + if err != nil { + glog.V(0).Info(err) + } + if _, err := os.Stat(filepath.Join(dir, path)); err == nil { + return true + } + return false +} + +func getRootImportPath(file *descriptor.File) string { + goImportPath := "" + if file.GoPkg.Path != "." { + goImportPath = file.GoPkg.Path + } + dir, err := filepath.Abs(filepath.Dir(os.Args[0])) + if err != nil { + glog.V(0).Info(err) + } + dir, err = filepath.EvalSymlinks(dir) + if err != nil { + glog.V(0).Info(err) + } + for _, gp := range strings.Split(build.Default.GOPATH, ":") { + agp, _ := filepath.Abs(gp) + agp, _ = filepath.EvalSymlinks(agp) + if strings.HasPrefix(dir, agp) { + currentPath := strings.TrimPrefix(dir, agp+"/src/") + if strings.HasPrefix(goImportPath, currentPath) { + return goImportPath + } else if goImportPath != "" { + return filepath.Join(currentPath, goImportPath) + } else { + return currentPath + } + } + } + return "" +} diff --git a/cmd/protoc-gen-goclay/genhandler/options.go b/cmd/protoc-gen-goclay/genhandler/options.go new file mode 100644 index 0000000..f403494 --- /dev/null +++ b/cmd/protoc-gen-goclay/genhandler/options.go @@ -0,0 +1,48 @@ +package genhandler + +import "github.com/go-openapi/spec" + +type options struct { + ImplPath string + DescPath string + SwaggerDef map[string]*spec.Swagger + Impl bool + Force bool +} + +type Option func(*options) + +// SwaggerDef sets map of spec.Swagger per proto file +func SwaggerDef(swaggerDef map[string]*spec.Swagger) Option { + return func(o *options) { + o.SwaggerDef = swaggerDef + } +} + +// Impl sets Impl flag option (if true implementation will be generated) +func Impl(impl bool) Option { + return func(o *options) { + o.Impl = impl + } +} + +// ImplPath sets path for implementation file +func ImplPath(path string) Option { + return func(o *options) { + o.ImplPath = path + } +} + +// DescPath sets path for description and swagger file +func DescPath(path string) Option { + return func(o *options) { + o.DescPath = path + } +} + +// Force sets force mode for generation implementation +func Force(force bool) Option { + return func(o *options) { + o.Force = force + } +} diff --git a/cmd/protoc-gen-goclay/genhandler/template.go b/cmd/protoc-gen-goclay/genhandler/template.go index 400552e..4bd0334 100644 --- a/cmd/protoc-gen-goclay/genhandler/template.go +++ b/cmd/protoc-gen-goclay/genhandler/template.go @@ -1,75 +1,83 @@ package genhandler import ( - "bytes" - "strings" - "text/template" + "bytes" + "strings" + "text/template" - "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor" - "github.com/pkg/errors" + "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor" + "github.com/pkg/errors" ) var ( - errNoTargetService = errors.New("no target service defined in the file") + errNoTargetService = errors.New("no target service defined in the file") ) type param struct { - *descriptor.File - Imports []descriptor.GoPackage - SwagBuffer []byte + *descriptor.File + Imports []descriptor.GoPackage + SwaggerBuffer []byte + DescPrefix string } -func applyTemplate(p param) (string, error) { - // r := &http.Request{} - // r.URL.Query() - w := bytes.NewBuffer(nil) - if err := headerTemplate.Execute(w, p); err != nil { - return "", err - } - - if err := regTemplate.ExecuteTemplate(w, "base", p); err != nil { - return "", err - } - - type swaggerTmpl struct { - FileName string - Swagger string - } - - if err := footerTemplate.Execute(w, p); err != nil { - return "", err - } - - if err := patternsTemplate.ExecuteTemplate(w, "base", p); err != nil { - return "", err - } - //spew.Dump(p.Services[0].Methods[0].Bindings) - - return w.String(), nil +func applyImplTemplate(p param) (string, error) { + w := bytes.NewBuffer(nil) + + if err := implTemplate.Execute(w, p); err != nil { + return "", err + } + + return w.String(), nil +} + +func applyDescTemplate(p param) (string, error) { + // r := &http.Request{} + // r.URL.Query() + w := bytes.NewBuffer(nil) + if err := headerTemplate.Execute(w, p); err != nil { + return "", err + } + + if err := regTemplate.ExecuteTemplate(w, "base", p); err != nil { + return "", err + } + + if len(p.SwaggerBuffer) > 0 { + if err := footerTemplate.Execute(w, p); err != nil { + return "", err + } + } + + if err := patternsTemplate.ExecuteTemplate(w, "base", p); err != nil { + return "", err + } + //spew.Dump(p.Services[0].Methods[0].Bindings) + + return w.String(), nil } var ( - funcMap = template.FuncMap{ - "dotToUnderscore": func(s string) string { return strings.Replace(strings.Replace(s, ".", "_", -1), "/", "_", -1) }, - "byteStr": func(b []byte) string { return string(b) }, - "escapeBackTicks": func(s string) string { return strings.Replace(s, "`", "` + \"``\" + `", -1) }, - } + funcMap = template.FuncMap{ + "dotToUnderscore": func(s string) string { return strings.Replace(strings.Replace(s, ".", "_", -1), "/", "_", -1) }, + "byteStr": func(b []byte) string { return string(b) }, + "escapeBackTicks": func(s string) string { return strings.Replace(s, "`", "` + \"``\" + `", -1) }, + } - headerTemplate = template.Must(template.New("header").Parse(` + headerTemplate = template.Must(template.New("header").Parse(` // Code generated by protoc-gen-goclay -// source: {{.GetName}} +// source: {{ .GetName }} // DO NOT EDIT! /* -Package {{.GoPkg.Name}} is a self-registering gRPC and JSON+Swagger service definition. +Package {{ .GoPkg.Name }} is a self-registering gRPC and JSON+Swagger service definition. It conforms to the github.com/utrack/clay Service interface. */ -package {{.GoPkg.Name}} +package {{ .GoPkg.Name }} import ( - {{range $i := .Imports}}{{if $i.Standard}}{{$i | printf "%s\n"}}{{end}}{{end}} + {{ range $i := .Imports }}{{ if $i.Standard }}{{ $i | printf "%s\n" }}{{ end }}{{ end }} - {{range $i := .Imports}}{{if not $i.Standard}}{{$i | printf "%s\n"}}{{end}}{{end}} + {{ range $i := .Imports }}{{ if not $i.Standard }}{{ $i | printf "%s\n" }}{{ end }}{{ end }} ) // Update your shared lib or downgrade generator to v1 if there's an error @@ -78,109 +86,175 @@ var _ = transport.IsVersion2 var _ chi.Router var _ runtime.Marshaler `)) - regTemplate = template.Must(template.New("svc-reg").Funcs(funcMap).Parse(` -{{define "base"}} -{{range $svc := .Services}} -// {{$svc.GetName}}Desc is a descriptor/registrator for the {{$svc.GetName}}Server. -type {{$svc.GetName}}Desc struct { - svc {{$svc.GetName}}Server + regTemplate = template.Must(template.New("svc-reg").Funcs(funcMap).Parse(` +{{ define "base" }} +{{ range $svc := .Services }} +// {{ $svc.GetName }}Desc is a descriptor/registrator for the {{ $svc.GetName }}Server. +type {{ $svc.GetName }}Desc struct { + svc {{ $svc.GetName }}Server } -// New{{$svc.GetName}}ServiceDesc creates new registrator for the {{$svc.GetName}}Server. -func New{{$svc.GetName}}ServiceDesc(svc {{$svc.GetName}}Server) *{{$svc.GetName}}Desc { - return &{{$svc.GetName}}Desc{svc:svc} +// New{{ $svc.GetName }}ServiceDesc creates new registrator for the {{ $svc.GetName }}Server. +func New{{ $svc.GetName }}ServiceDesc(svc {{ $svc.GetName }}Server) *{{ $svc.GetName }}Desc { + return &{{ $svc.GetName }}Desc{svc:svc} } // RegisterGRPC implements service registrator interface. -func (d *{{$svc.GetName}}Desc) RegisterGRPC(s *grpc.Server) { - Register{{$svc.GetName}}Server(s,d.svc) +func (d *{{ $svc.GetName }}Desc) RegisterGRPC(s *grpc.Server) { + Register{{ $svc.GetName }}Server(s,d.svc) } // SwaggerDef returns this file's Swagger definition. -func (d *{{$svc.GetName}}Desc) SwaggerDef() []byte { - return _swaggerDef_{{dotToUnderscore $.GetName}} +func (d *{{ $svc.GetName }}Desc) SwaggerDef(options ...transport.SwaggerOption) (result []byte) { + {{ if $.SwaggerBuffer }}if len(options) > 0 { + var err error + var swagger = &spec.Swagger{} + if err = swagger.UnmarshalJSON(_swaggerDef_{{ dotToUnderscore $.GetName }}); err != nil { + panic("Bad swagger definition: " + err.Error()) + } + for _, o := range options { + o(swagger) + } + if result, err = swagger.MarshalJSON(); err != nil { + panic("Failed marshal spec.Swagger definition: " + err.Error()) + } + } else { + result = _swaggerDef_{{ dotToUnderscore $.GetName }} + }{{ end }} + return result } // RegisterHTTP registers this service's HTTP handlers/bindings. -func (d *{{$svc.GetName}}Desc) RegisterHTTP(mux transport.Router) { - {{range $m := $svc.Methods}} - // Handlers for {{$m.GetName}} - {{range $b := $m.Bindings}} - mux.MethodFunc("{{$b.HTTPMethod}}",pattern_goclay_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, r *http.Request) { - defer r.Body.Close() - - var req {{$m.RequestType.GetName}} - err := unmarshaler_goclay_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(r,&req) - if err != nil { - httpruntime.SetError(r.Context(),r,w,errors.Wrap(err,"couldn't parse request")) - return - } - - ret,err := d.svc.{{$m.GetName}}(r.Context(),&req) - if err != nil { - httpruntime.SetError(r.Context(),r,w,errors.Wrap(err,"returned from handler")) - return - } - - _,outbound := httpruntime.MarshalerForRequest(r) - w.Header().Set("Content-Type", outbound.ContentType()) - err = outbound.Marshal(w, ret) - if err != nil { - httpruntime.SetError(r.Context(),r,w,errors.Wrap(err,"couldn't write response")) - return - } - }) - {{end}} - {{end}} +func (d *{{ $svc.GetName }}Desc) RegisterHTTP(mux transport.Router) error { + chiMux, isChi := mux.(chi.Router) + var h http.HandlerFunc + {{ range $m := $svc.Methods }} + {{ range $b := $m.Bindings -}} + // Handler for {{ $m.GetName }}, binding: {{ $b.HTTPMethod }} {{ $b.PathTmpl.Template }} + h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + var req {{ $m.RequestType.GetName }} + err := unmarshaler_goclay_{{ $svc.GetName }}_{{ $m.GetName }}_{{ $b.Index }}(r,&req) + if err != nil { + httpruntime.SetError(r.Context(),r,w,errors.Wrap(err,"couldn't parse request")) + return + } + + ret,err := d.svc.{{ $m.GetName }}(r.Context(),&req) + if err != nil { + httpruntime.SetError(r.Context(),r,w,errors.Wrap(err,"returned from handler")) + return + } + + _,outbound := httpruntime.MarshalerForRequest(r) + w.Header().Set("Content-Type", outbound.ContentType()) + err = outbound.Marshal(w, ret) + if err != nil { + httpruntime.SetError(r.Context(),r,w,errors.Wrap(err,"couldn't write response")) + return + } + }) + if isChi { + chiMux.Method("{{ $b.HTTPMethod }}",pattern_goclay_{{ $svc.GetName }}_{{ $m.GetName }}_{{ $b.Index }}, h) + } else { + {{if $b.PathParams -}} + return errors.New("query URI params supported only for chi.Router") + {{- else -}} + mux.Handle(pattern_goclay_{{ $svc.GetName }}_{{ $m.GetName }}_{{ $b.Index }}, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "{{ $b.HTTPMethod }}" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + h(w, r) + })) + {{- end }} + } + {{ end }} + {{ end }} + return nil } -{{end}} -{{end}} // base service handler ended +{{ end }} +{{ end }} // base service handler ended `)) - footerTemplate = template.Must(template.New("footer").Funcs(funcMap).Parse(` -var _swaggerDef_{{dotToUnderscore .GetName}} = []byte(` + "`" + `{{escapeBackTicks (byteStr .SwagBuffer)}}` + ` + footerTemplate = template.Must(template.New("footer").Funcs(funcMap).Parse(` + var _swaggerDef_{{ dotToUnderscore .GetName }} = []byte(` + "`" + `{{ escapeBackTicks (byteStr .SwaggerBuffer) }}` + ` ` + "`)" + ` `)) - patternsTemplate = template.Must(template.New("patterns").Parse(` -{{define "base"}} + patternsTemplate = template.Must(template.New("patterns").Parse(` +{{ define "base" }} var ( -{{range $svc := .Services}} -{{range $m := $svc.Methods}} -{{range $b := $m.Bindings}} - pattern_goclay_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = "{{$b.PathTmpl.Template}}" - unmarshaler_goclay_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} = func(r *http.Request,req *{{$m.RequestType.GetName}}) error { - - var err error - {{if $b.Body}} - {{template "unmbody" .}} - {{end}} - {{if $b.PathParams}} - {{template "unmpath" .}} - {{end}} - +{{ range $svc := .Services }} +{{ range $m := $svc.Methods }} +{{ range $b := $m.Bindings }} + pattern_goclay_{{ $svc.GetName }}_{{ $m.GetName }}_{{ $b.Index }} = "{{ $b.PathTmpl.Template }}" + unmarshaler_goclay_{{ $svc.GetName }}_{{ $m.GetName }}_{{ $b.Index }} = func(r *http.Request,req *{{ $m.RequestType.GetName }}) error { + {{- if $b.Body -}} + {{- template "unmbody" . -}} + {{- end -}} + {{- if $b.PathParams -}} + {{- template "unmpath" . -}} + {{- end -}} + return nil + } +{{ end }} +{{ end }} +{{ end }} +) +{{ end }} +{{ define "unmbody" }} + inbound,_ := httpruntime.MarshalerForRequest(r) + if err := errors.Wrap(inbound.Unmarshal(r.Body,req),"couldn't read request JSON"); err != nil { return err + } +{{ end }} +{{ define "unmpath" }} + rctx := chi.RouteContext(r.Context()) + if rctx == nil { + panic("Only chi router is supported for GETs atm") + } + for pos,k := range rctx.URLParams.Keys { + if err := errors.Wrap(runtime.PopulateFieldFromPath(req, k, rctx.URLParams.Values[pos]), "couldn't populate field from URL"); err != nil { + return err } -{{end}} -{{end}} -{{end}} + } +{{ end }} +`)) + + implTemplate = template.Must(template.New("impl").Funcs(funcMap).Parse(` +// Code generated by protoc-gen-goclay, but your can (must) modify it. +// source: {{ .GetName }} + +package {{ .GoPkg.Name }} + +import ( + {{ range $i, $import := .Imports -}} + {{ $import -}}; + {{ end }} ) -{{end}} -{{define "unmbody"}} - inbound,_ := httpruntime.MarshalerForRequest(r) - err = errors.Wrap(inbound.Unmarshal(r.Body,req),"couldn't read request JSON") - if err != nil { - return err - } -{{end}} -{{define "unmpath"}} - rctx := chi.RouteContext(r.Context()) - if rctx == nil { - panic("Only chi router is supported for GETs atm") - } - for pos,k := range rctx.URLParams.Keys { - runtime.PopulateFieldFromPath(req, k, rctx.URLParams.Values[pos]) - } -{{end}} + +{{ range $service := .Services }} + +type {{ $service.GetName }}Implementation struct {} + +func New{{ $service.GetName }}() *{{ $service.GetName }}Implementation { + return &{{ $service.GetName }}Implementation{} +} + +{{ range $method := $service.Methods }} +func (i *{{ $service.GetName }}Implementation) {{ $method.Name }}(ctx context.Context, req *{{ $.DescPrefix }}{{ $method.RequestType.GetName }}) (*{{ $.DescPrefix }}{{ $method.ResponseType.GetName }}, error) { + return nil, errors.New("not implemented") +} +{{ end }} + +// GetDescription is a simple alias to the ServiceDesc constructor. +// It makes it possible to register the service implementation @ the server. +func (i *{{ $service.GetName }}Implementation) GetDescription() transport.ServiceDesc { + return {{ $.DescPrefix }}New{{ $service.GetName }}ServiceDesc(i) +} + +{{ end }} `)) ) diff --git a/cmd/protoc-gen-goclay/main.go b/cmd/protoc-gen-goclay/main.go index bbb1e8f..e11516d 100644 --- a/cmd/protoc-gen-goclay/main.go +++ b/cmd/protoc-gen-goclay/main.go @@ -4,17 +4,15 @@ import ( "flag" "fmt" "io" - "io/ioutil" "os" "strings" - "github.com/utrack/clay/cmd/protoc-gen-goclay/genhandler" - "github.com/golang/glog" "github.com/golang/protobuf/proto" plugin "github.com/golang/protobuf/protoc-gen-go/plugin" "github.com/grpc-ecosystem/grpc-gateway/codegenerator" "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor" + "github.com/utrack/clay/cmd/protoc-gen-goclay/genhandler" ) var ( @@ -22,31 +20,21 @@ var ( file = flag.String("file", "-", "where to load data from") allowDeleteBody = flag.Bool("allow_delete_body", false, "unless set, HTTP DELETE methods may not have a body") grpcAPIConfiguration = flag.String("grpc_api_configuration", "", "path to gRPC API Configuration in YAML format") + withImpl = flag.Bool("impl", false, "generate simple implementations for proto Services. Implementation will not be generated if it already exists. See also `force` option") + withSwagger = flag.Bool("swagger", true, "generate swagger.json") + descPath = flag.String("desc_path", "", "path where the http description is generated") + implPath = flag.String("impl_path", "", "path where the implementation is generated (for impl = true)") + forceImpl = flag.Bool("force", false, "force regenerate implementation if it already exists (for impl = true)") ) -func parseReq(r io.Reader) (*plugin.CodeGeneratorRequest, error) { - glog.V(1).Info("Parsing code generator request") - input, err := ioutil.ReadAll(r) - if err != nil { - glog.Errorf("Failed to read code generator request: %v", err) - return nil, err - } - req := new(plugin.CodeGeneratorRequest) - if err = proto.Unmarshal(input, req); err != nil { - glog.Errorf("Failed to unmarshal code generator request: %v", err) - return nil, err - } - glog.V(1).Info("Parsed code generator request") - return req, nil -} - func main() { flag.Parse() + flag.Lookup("logtostderr").Value.Set("true") defer glog.Flush() reg := descriptor.NewRegistry() - glog.V(1).Info("Processing code generator request") + glog.V(2).Info("Processing code generator request") fs := os.Stdin if *file != "-" { var err error @@ -55,7 +43,8 @@ func main() { glog.Fatal(err) } } - glog.V(1).Info("Parsing code generator request") + + glog.V(2).Info("Parsing code generator request") req, err := codegenerator.ParseRequest(fs) if err != nil { glog.Fatal(err) @@ -81,13 +70,29 @@ func main() { } } - g := genhandler.New(reg) - if err = reg.Load(req); err != nil { emitError(err) return } + opts := []genhandler.Option{ + genhandler.Impl(*withImpl), + genhandler.ImplPath(*implPath), + genhandler.DescPath(*descPath), + genhandler.Force(*forceImpl), + } + + if *withSwagger { + swagBuf, err := genSwaggerDef(req, pkgMap) + if err != nil { + emitError(err) + return + } + opts = append(opts, genhandler.SwaggerDef(swagBuf)) + } + + g := genhandler.New(reg, opts...) + var targets []*descriptor.File for _, target := range req.FileToGenerate { var f *descriptor.File @@ -98,14 +103,8 @@ func main() { targets = append(targets, f) } - swagBuf, err := genSwaggerDef(req, pkgMap) - if err != nil { - emitError(err) - return - } - - out, err := g.Generate(targets, swagBuf) - glog.V(1).Info("Processed code generator request") + out, err := g.Generate(targets) + glog.V(2).Info("Processed code generator request") if err != nil { emitError(err) return @@ -113,14 +112,6 @@ func main() { emitFiles(os.Stdout, out) } -func ptrOfInt32(i int32) *int32 { - return &i -} - -func ptrOfString(s string) *string { - return &s -} - // parseReqParam parses a CodeGeneratorRequest parameter and adds the // extracted values to the given FlagSet and pkgMap. Returns a non-nil // error if setting a flag failed. @@ -153,6 +144,14 @@ func parseReqParam(param string, f *flag.FlagSet, pkgMap map[string]string) erro return fmt.Errorf("Cannot set flag %s: %v", p, err) } } + *descPath = strings.Trim(*descPath, "/") + if *descPath == "." { + *descPath = "" + } + *implPath = strings.Trim(*implPath, "/") + if *implPath == "." { + *implPath = "" + } return nil } diff --git a/cmd/protoc-gen-goclay/swagger.go b/cmd/protoc-gen-goclay/swagger.go index f97c7cb..7e7ce03 100644 --- a/cmd/protoc-gen-goclay/swagger.go +++ b/cmd/protoc-gen-goclay/swagger.go @@ -1,13 +1,14 @@ package main import ( + "github.com/go-openapi/spec" "github.com/golang/glog" plugin "github.com/golang/protobuf/protoc-gen-go/plugin" "github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor" "github.com/utrack/grpc-gateway/protoc-gen-swagger/genswagger" ) -func genSwaggerDef(req *plugin.CodeGeneratorRequest, pkgMap map[string]string) (map[string][]byte, error) { +func genSwaggerDef(req *plugin.CodeGeneratorRequest, pkgMap map[string]string) (map[string]*spec.Swagger, error) { reg := descriptor.NewRegistry() reg.SetPrefix(*importPrefix) reg.SetAllowDeleteBody(*allowDeleteBody) @@ -41,9 +42,13 @@ func genSwaggerDef(req *plugin.CodeGeneratorRequest, pkgMap map[string]string) ( if err != nil { return nil, err } - ret := make(map[string][]byte, len(outSwag)) + ret := make(map[string]*spec.Swagger, len(outSwag)) for pos := range outSwag { - ret[req.FileToGenerate[pos]] = []byte(outSwag[pos].GetContent()) + s := &spec.Swagger{} + if err := s.UnmarshalJSON([]byte(outSwag[pos].GetContent())); err != nil { + return nil, err + } + ret[req.FileToGenerate[pos]] = s } return ret, nil } diff --git a/transport/handlers.go b/transport/handlers.go index ea9b291..067b1b5 100644 --- a/transport/handlers.go +++ b/transport/handlers.go @@ -7,24 +7,20 @@ import ( ) // Service is a registerable collection of endpoints. -// These functions should be autogenerated by LZD Protobuf codegenerator. +// These functions should be autogenerated by protoc-gen-goclay. type Service interface { GetDescription() ServiceDesc } // ServiceDesc is a description of an endpoints' collection. -// These functions should be autogenerated by LZD Protobuf codegenerator. +// These functions should be autogenerated by protoc-gen-goclay. type ServiceDesc interface { RegisterGRPC(*grpc.Server) - RegisterHTTP(Router) - SwaggerDef() []byte + RegisterHTTP(Router) error + SwaggerDef(options ...SwaggerOption) []byte } // Router routes HTTP requests around. type Router interface { - http.Handler - Method(method string, pattern string, h http.Handler) - MethodFunc(method string, pattern string, h http.HandlerFunc) - // Use makes this router use middlewares passed. - Use(middlewares ...func(http.Handler) http.Handler) + Handle(pattern string, h http.Handler) } diff --git a/transport/swagopts.go b/transport/swagopts.go new file mode 100644 index 0000000..f7a8a5e --- /dev/null +++ b/transport/swagopts.go @@ -0,0 +1,46 @@ +package transport + +import ( + "github.com/go-openapi/spec" +) + +type SwaggerOption func(swagger *spec.Swagger) + +func WithHost(host string) SwaggerOption { + return func(swagger *spec.Swagger) { + swagger.Host = host + } +} + +func WithVersion(version string) SwaggerOption { + return func(swagger *spec.Swagger) { + if swagger.Info == nil { + swagger.Info = &spec.Info{} + } + swagger.Info.Version = version + } +} + +func WithTitle(title string) SwaggerOption { + return func(swagger *spec.Swagger) { + if swagger.Info == nil { + swagger.Info = &spec.Info{} + } + swagger.Info.Title = title + } +} + +func WithDescription(desc string) SwaggerOption { + return func(swagger *spec.Swagger) { + if swagger.Info == nil { + swagger.Info = &spec.Info{} + } + swagger.Info.Description = desc + } +} + +func WithSecurityDefinitions(secDef spec.SecurityDefinitions) SwaggerOption { + return func(swagger *spec.Swagger) { + swagger.SecurityDefinitions = secDef + } +} \ No newline at end of file