Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow types defined as instantiated generic interfaces to generate mocks #790

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions cmd/mockery.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,6 @@ func (r *RootApp) Run() error {
log.Error().Err(err).Msg("unable to parse packages")
return err
}
log.Info().Msg("done parsing, loading")
if err := parser.Load(); err != nil {
log.Err(err).Msgf("failed to load parser")
return nil
}
log.Info().Msg("done loading, visiting interface nodes")
for _, iface := range parser.Interfaces() {
ifaceLog := log.
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions pkg/fixtures/instantiated_generic_interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package test

type GenericInterface[M any] interface {
Func(arg *M) int
}

type InstantiatedGenericInterface GenericInterface[float32]
1 change: 0 additions & 1 deletion pkg/fixtures/variadic.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@ type VariadicFunction = func(args1 string, args2 ...interface{}) interface{}
type Variadic interface {
VariadicFunction(str string, vFunc VariadicFunction) error
}

2 changes: 1 addition & 1 deletion pkg/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (s *GeneratorSuite) getInterfaceFromFile(interfacePath, interfaceName strin
)

s.Require().NoError(
s.parser.Load(),
s.parser.Load(context.Background()),
)

iface, err := s.parser.Find(interfaceName)
Expand Down
2 changes: 1 addition & 1 deletion pkg/outputter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ packages:
m.config.Config = confPath.String()

require.NoError(t, parser.ParsePackages(ctx, []string{tt.packagePath}))
require.NoError(t, parser.Load())
require.NoError(t, parser.Load(context.Background()))
for _, intf := range parser.Interfaces() {
t.Logf("generating interface: %s %s", intf.QualifiedName, intf.Name)
require.NoError(t, m.Generate(ctx, intf))
Expand Down
58 changes: 40 additions & 18 deletions pkg/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,27 @@ import (
"golang.org/x/tools/go/packages"
)

type parserEntry struct {
type fileEntry struct {
fileName string
pkg *packages.Package
syntax *ast.File
interfaces []string
}

func (f *fileEntry) ParseInterfaces(ctx context.Context) {
nv := NewNodeVisitor(ctx)
ast.Walk(nv, f.syntax)
f.interfaces = nv.DeclaredInterfaces()
}

type packageLoadEntry struct {
pkgs []*packages.Package
err error
}

type Parser struct {
entries []*parserEntry
entriesByFileName map[string]*parserEntry
files []*fileEntry
entriesByFileName map[string]*fileEntry
parserPackages []*types.Package
conf packages.Config
packageLoadCache map[string]packageLoadEntry
Expand All @@ -52,7 +58,7 @@ func NewParser(buildTags []string) *Parser {
}
return &Parser{
parserPackages: make([]*types.Package, 0),
entriesByFileName: map[string]*parserEntry{},
entriesByFileName: map[string]*fileEntry{},
conf: conf,
packageLoadCache: map[string]packageLoadEntry{},
}
Expand Down Expand Up @@ -86,18 +92,21 @@ func (p *Parser) ParsePackages(ctx context.Context, packageNames []string) error
Str("package", pkg.PkgPath).
Str("file", file).
Msgf("found file")
entry := parserEntry{
entry := fileEntry{
fileName: file,
pkg: pkg,
syntax: pkg.Syntax[fileIdx],
}
p.entries = append(p.entries, &entry)
entry.ParseInterfaces(ctx)
p.files = append(p.files, &entry)
p.entriesByFileName[file] = &entry
}
}
return nil
}

// DEPRECATED: Parse is part of the deprecated, legacy mockery behavior. This is not
// used when the packages feature is enabled.
func (p *Parser) Parse(ctx context.Context, path string) error {
// To support relative paths to mock targets w/ vendor deps, we need to provide eventual
// calls to build.Context.Import with an absolute path. It needs to be absolute because
Expand Down Expand Up @@ -164,30 +173,28 @@ func (p *Parser) Parse(ctx context.Context, path string) error {
if _, ok := p.entriesByFileName[f]; ok {
continue
}
entry := parserEntry{
entry := fileEntry{
fileName: f,
pkg: pkg,
syntax: pkg.Syntax[idx],
}
p.entries = append(p.entries, &entry)
p.files = append(p.files, &entry)
p.entriesByFileName[f] = &entry
}
}

return nil
}

func (p *Parser) Load() error {
for _, entry := range p.entries {
nv := NewNodeVisitor()
ast.Walk(nv, entry.syntax)
entry.interfaces = nv.DeclaredInterfaces()
func (p *Parser) Load(ctx context.Context) error {
for _, entry := range p.files {
entry.ParseInterfaces(ctx)
}
return nil
}

func (p *Parser) Find(name string) (*Interface, error) {
for _, entry := range p.entries {
for _, entry := range p.files {
for _, iface := range entry.interfaces {
if iface == name {
list := p.packageInterfaces(entry.pkg.Types, entry.fileName, []string{name}, nil)
Expand All @@ -202,7 +209,7 @@ func (p *Parser) Find(name string) (*Interface, error) {

func (p *Parser) Interfaces() []*Interface {
ifaces := make(sortableIFaceList, 0)
for _, entry := range p.entries {
for _, entry := range p.files {
declaredIfaces := entry.interfaces
ifaces = p.packageInterfaces(entry.pkg.Types, entry.fileName, declaredIfaces, ifaces)
}
Expand Down Expand Up @@ -314,12 +321,15 @@ func (s sortableIFaceList) Less(i, j int) bool {
}

type NodeVisitor struct {
declaredInterfaces []string
declaredInterfaces []string
genericInstantiationInterface map[string]any
ctx context.Context
}

func NewNodeVisitor() *NodeVisitor {
func NewNodeVisitor(ctx context.Context) *NodeVisitor {
return &NodeVisitor{
declaredInterfaces: make([]string, 0),
ctx: ctx,
}
}

Expand All @@ -328,11 +338,23 @@ func (nv *NodeVisitor) DeclaredInterfaces() []string {
}

func (nv *NodeVisitor) Visit(node ast.Node) ast.Visitor {
log := zerolog.Ctx(nv.ctx)

switch n := node.(type) {
case *ast.TypeSpec:
log := log.With().
Str("node-name", n.Name.Name).
Str("node-type", fmt.Sprintf("%T", n.Type)).
Logger()

switch n.Type.(type) {
case *ast.InterfaceType, *ast.FuncType:
case *ast.InterfaceType, *ast.FuncType, *ast.IndexExpr:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main fix.

log.Debug().
Str("node-type", fmt.Sprintf("%T", n.Type)).
Msg("found node with acceptable type for mocking")
nv.declaredInterfaces = append(nv.declaredInterfaces, n.Name.Name)
default:
log.Debug().Msg("Found node with unacceptable type for mocking. Rejecting.")
}
}
return nv
Expand Down
10 changes: 5 additions & 5 deletions pkg/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TestFileParse(t *testing.T) {
err := parser.Parse(ctx, testFile)
assert.NoError(t, err)

err = parser.Load()
err = parser.Load(context.Background())
assert.NoError(t, err)

node, err := parser.Find("Requester")
Expand All @@ -38,7 +38,7 @@ func TestBuildTagInFilename(t *testing.T) {
err = parser.Parse(ctx, getFixturePath("buildtag", "filename", "iface_freebsd.go"))
assert.NoError(t, err)

err = parser.Load()
err = parser.Load(context.Background())
assert.NoError(t, err) // Expect "redeclared in this block" if tags aren't respected

nodes := parser.Interfaces()
Expand All @@ -60,7 +60,7 @@ func TestBuildTagInComment(t *testing.T) {
err = parser.Parse(ctx, getFixturePath("buildtag", "comment", "freebsd_iface.go"))
assert.NoError(t, err)

err = parser.Load()
err = parser.Load(context.Background())
assert.NoError(t, err) // Expect "redeclared in this block" if tags aren't respected

nodes := parser.Interfaces()
Expand All @@ -78,7 +78,7 @@ func TestCustomBuildTag(t *testing.T) {
err = parser.Parse(ctx, getFixturePath("buildtag", "comment", "custom2_iface.go"))
assert.NoError(t, err)

err = parser.Load()
err = parser.Load(context.Background())
assert.NoError(t, err) // Expect "redeclared in this block" if tags aren't respected

found := false
Expand All @@ -94,6 +94,6 @@ func TestCustomBuildTag(t *testing.T) {
func TestParsePackages(t *testing.T) {
parser := NewParser([]string{})
require.NoError(t, parser.ParsePackages(context.Background(), []string{"github.com/vektra/mockery/v2/pkg/fixtures"}))
assert.NotEqual(t, 0, len(parser.entries))
assert.NotEqual(t, 0, len(parser.files))

}
Loading
Loading