diff --git a/ast.go b/ast.go index c682480..6b8809f 100644 --- a/ast.go +++ b/ast.go @@ -9,28 +9,29 @@ import ( "github.com/sonalys/fake/internal/packages" ) -func (file *ParsedFile) importConflictResolution(importUsedName string, importPath string) string { - info, ok := file.OriginalImports[importUsedName] - // If the original import is found, use either name or alias. - if ok { - pkgInfo := file.ImportsPathMap[importPath] - if pkgInfo.Alias != "" { - file.UsedImports[pkgInfo.Alias] = struct{}{} - return pkgInfo.Alias - } - file.UsedImports[pkgInfo.Name] = struct{}{} - return pkgInfo.Name +func (file *ParsedFile) importConflictResolution() string { + if file.importResolved { + return file.importAlias + } + var alias string = file.PkgName + info, ok := file.Imports[file.PkgName] + // Conflict detected for a package with different path. + if ok && info.Path != file.PkgPath { + alias = fmt.Sprintf("%s1", alias) } info = &imports.ImportEntry{ - PackageInfo: packages.PackageInfo{ + PackageInfo: &packages.PackageInfo{ Path: file.PkgPath, Name: file.PkgName, }, + Alias: alias, } - file.Imports[file.PkgName] = info + file.Imports[alias] = info file.ImportsPathMap[file.PkgPath] = info - file.UsedImports[file.PkgName] = struct{}{} - return file.PkgName + file.UsedImports[alias] = struct{}{} + file.importAlias = alias + file.importResolved = true + return alias } func (f *ParsedInterface) printAstExpr(expr ast.Expr) string { @@ -52,7 +53,7 @@ func (f *ParsedInterface) printAstExpr(expr ast.Expr) string { return fieldType.Name } } - return fmt.Sprintf("%s.%s", file.importConflictResolution(file.PkgName, file.PkgPath), fieldType.Name) + return fmt.Sprintf("%s.%s", file.importConflictResolution(), fieldType.Name) case *ast.SelectorExpr: // Type from another package. pkgName := fmt.Sprint(fieldType.X) diff --git a/file.go b/file.go index 78d1b43..c5348e6 100644 --- a/file.go +++ b/file.go @@ -7,6 +7,7 @@ import ( "io" "slices" + "github.com/rs/zerolog/log" "github.com/sonalys/fake/internal/imports" ) @@ -20,6 +21,9 @@ type ParsedFile struct { OriginalImports map[string]*imports.ImportEntry ImportsPathMap map[string]*imports.ImportEntry UsedImports map[string]struct{} + + importResolved bool + importAlias string } func (f *ParsedFile) ListInterfaces(names ...string) []*ParsedInterface { @@ -59,9 +63,12 @@ func (f *ParsedFile) writeImports(w io.Writer) { fmt.Fprintf(w, "\t\"testing\"\n") fmt.Fprintf(w, "\tmockSetup \"github.com/sonalys/fake/boilerplate\"\n") for name := range f.UsedImports { - info := f.Imports[name] + info, ok := f.Imports[name] + if !ok { + log.Fatal().Msg("inconsistency between usedImports and imports state") + } fmt.Fprintf(w, "\t") - if info.Alias != "" { + if info.Alias != "" && info.Alias != info.PackageInfo.Name { fmt.Fprintf(w, "%s ", info.Alias) } fmt.Fprintf(w, "\"%s\"\n", info.Path) diff --git a/generator.go b/generator.go index fd2a157..257ab18 100644 --- a/generator.go +++ b/generator.go @@ -1,19 +1,48 @@ package fake import ( + "go/ast" "go/token" + "os" + "path" + + "github.com/sonalys/fake/internal/files" + "github.com/sonalys/fake/internal/imports" + "golang.org/x/mod/modfile" ) // Generator is the controller for the whole module, caching files and holding metadata. type Generator struct { FileSet *token.FileSet MockPackageName string + + cachedPackageInfo func(f *ast.File) (nameMap, pathMap map[string]*imports.ImportEntry) + goModFilename string + goMod *modfile.File } // NewGenerator will create a new mock generator for the specified module. -func NewGenerator(n string) *Generator { - return &Generator{ - FileSet: token.NewFileSet(), - MockPackageName: n, +func NewGenerator(pkgName, baseDir string) (*Generator, error) { + goModPath, err := files.FindFile(baseDir, "go.mod") + if err != nil { + return nil, err + } + // Read the contents of the go.mod file + modFileContent, err := os.ReadFile(goModPath) + if err != nil { + return nil, err } + // Parse the go.mod file + modFile, err := modfile.Parse(goModPath, modFileContent, nil) + if err != nil { + return nil, err + } + + return &Generator{ + FileSet: token.NewFileSet(), + goModFilename: goModPath, + goMod: modFile, + MockPackageName: pkgName, + cachedPackageInfo: imports.CachedImportInformation(path.Dir(goModPath)), + }, nil } diff --git a/generator_test.go b/generator_test.go index 71cd5d5..55c08bd 100644 --- a/generator_test.go +++ b/generator_test.go @@ -8,10 +8,12 @@ import ( ) func Test_Generate(t *testing.T) { - // output := t.TempDir() - output := "out" + output := t.TempDir() + // output := "out" + // os.RemoveAll(output) // no caching Run([]string{"testdata"}, output, nil) - g := NewGenerator("mocks") - _, err := g.ParseFile(path.Join(output, "testdata", "stub.gen.go")) + g, err := NewGenerator("mocks", "testdata") + require.NoError(t, err) + _, err = g.ParseFile(path.Join(output, "testdata", "stub.gen.go")) require.NoError(t, err) } diff --git a/interface.go b/interface.go index b2da682..1a10cc4 100644 --- a/interface.go +++ b/interface.go @@ -24,10 +24,16 @@ type ParsedInterface struct { // type B[J any] interface{ Method() J } // it should have method Method() T when implementing A mock. TranslateGenericNames []string + + fieldsCache []*ParsedField } func (i *ParsedInterface) ListFields() []*ParsedField { - return i.ParsedFile.Generator.listInterfaceFields(i, i.ParsedFile.Imports) + if i.fieldsCache != nil { + return i.fieldsCache + } + i.fieldsCache = i.ParsedFile.Generator.listInterfaceFields(i, i.ParsedFile.Imports) + return i.fieldsCache } // ListInterfaceFields receives an interface to translate fields into fields. @@ -37,6 +43,9 @@ func (g *Generator) listInterfaceFields(i *ParsedInterface, imports map[string]* if i == nil || i.Ref.Methods == nil { return nil } + if i.fieldsCache != nil { + return i.fieldsCache + } var resp []*ParsedField for _, field := range i.Ref.Methods.List { switch t := field.Type.(type) { @@ -90,6 +99,7 @@ func (g *Generator) listInterfaceFields(i *ParsedInterface, imports map[string]* } } resp = deduplicatedResp + i.fieldsCache = resp return resp } @@ -124,7 +134,7 @@ func (g *Generator) parseInterface(ident *ast.SelectorExpr, f *ParsedFile) *Pars if !ok { return nil } - pkg, ok := pkgs.Parse(pkgInfo.Path) + pkg, ok := pkgs.Parse(g.goModFilename, pkgInfo.Path) if !ok { return nil } diff --git a/internal/caching/hash.go b/internal/caching/hash.go index a59ff41..ef1fa25 100644 --- a/internal/caching/hash.go +++ b/internal/caching/hash.go @@ -10,6 +10,7 @@ import ( "path" "sort" "strings" + "time" "github.com/rs/zerolog/log" "github.com/sonalys/fake/internal/files" @@ -29,8 +30,11 @@ func getImportsHash(filePath string, dependencies map[string]string) (string, er sort.Strings(imports) var b strings.Builder for _, importPath := range imports { - if hash, ok := dependencies[importPath]; ok { - b.WriteString(hash) + for path, hash := range dependencies { + if strings.Contains(importPath, path) { + b.WriteString(hash) + break + } } } return b.String(), nil @@ -38,17 +42,13 @@ func getImportsHash(filePath string, dependencies map[string]string) (string, er func GetUncachedFiles(inputs, ignore []string, outputDir string) (map[string]LockfileHandler, error) { lockFilePath := path.Join(outputDir, lockFilename) - lockFilePath = strings.ReplaceAll(lockFilePath, "internal", "internal_") groupLockFiles, err := readLockFile(lockFilePath) if err != nil { return nil, fmt.Errorf("reading %s file: %w", lockFilename, err) } - var dependencies map[string]string - if len(groupLockFiles) > 0 { - dependencies, err = gosum.Parse(inputs[0]) - if err != nil { - return nil, fmt.Errorf("parsing go.sum file: %w", err) - } + dependencies, err := gosum.Parse(inputs[0]) + if err != nil { + return nil, fmt.Errorf("parsing go.sum file: %w", err) } goFiles, err := files.ListGoFiles(inputs, append(ignore, outputDir)) if err != nil { @@ -57,52 +57,69 @@ func GetUncachedFiles(inputs, ignore []string, outputDir string) (map[string]Loc out := make(map[string]LockfileHandler, len(groupLockFiles)) cachedHasher := getFileHasher(len(goFiles)) - // TODO: split into a function. - for _, filePathList := range files.GroupByDirectory(goFiles) { - for _, filePath := range filePathList { - entry, ok := groupLockFiles[filePath] - // If file is not in lock file hashes, then we delay hash calculation for after the mock generation. - // this makes it faster by avoiding calculation of useless files. - if !ok { - out[filePath] = &UnhashedLockFile{ - Filepath: filePath, - Dependencies: dependencies, - } - continue - } - stat, _ := os.Stat(filePath) - if !entry.ModifiedAt.IsZero() && !stat.ModTime().IsZero() && stat.ModTime().Equal(entry.ModifiedAt) { - entry.exists = true - out[filePath] = &entry - continue - } - importsHash, err := getImportsHash(filePath, dependencies) - if err != nil { - return nil, err + + gomod, err := files.FindFile(inputs[0], "go.mod") + if err != nil { + return nil, fmt.Errorf("input is not part of a go module") + } + + for _, absPath := range goFiles { + relPath, err := files.GetRelativePath(gomod, absPath) + if err != nil { + return nil, err + } + entry, ok := groupLockFiles[relPath] + // If file is not in lock file hashes, then we delay hash calculation for after the mock generation. + // this makes it faster by avoiding calculation of useless files. + if !ok { + out[relPath] = &UnhashedLockFile{ + Filepath: absPath, + Dependencies: dependencies, } - hash, err := cachedHasher(filePath) + continue + } + var modAt time.Time + if !entry.ModifiedAt.IsZero() { + stat, err := os.Stat(absPath) if err != nil { - return nil, fmt.Errorf("hashing file: %w", err) + return nil, fmt.Errorf("could not get file stats: %w", err) } - if entry.Hash == hash && entry.Dependencies == importsHash { - // Mark file as processed, to further delete unused entries. + modAt := stat.ModTime() + if !stat.ModTime().IsZero() && modAt.Equal(entry.ModifiedAt) { entry.exists = true - out[filePath] = &entry + entry.filepath = absPath + out[relPath] = &entry continue } - out[filePath] = &HashedLockFile{ - changed: true, - exists: true, - Hash: hash, - Dependencies: importsHash, - ModifiedAt: stat.ModTime(), - } + } + importsHash, err := getImportsHash(absPath, dependencies) + if err != nil { + return nil, err + } + hash, err := cachedHasher(absPath) + if err != nil { + return nil, fmt.Errorf("hashing file: %w", err) + } + if entry.Hash == hash && entry.Dependencies == importsHash { + // Mark file as processed, to further delete unused entries. + entry.exists = true + entry.filepath = absPath + out[relPath] = &entry + continue + } + out[relPath] = &HashedLockFile{ + changed: true, + exists: true, + Hash: hash, + filepath: absPath, + Dependencies: importsHash, + ModifiedAt: modAt, } } - for filePath := range groupLockFiles { - if _, ok := out[filePath]; !ok { + for relPath := range groupLockFiles { + if _, ok := out[relPath]; !ok { // Remove empty files from our new lock file. - rmFileName := files.GenerateOutputFileName(filePath, outputDir) + rmFileName := files.GenerateOutputFileName(relPath, outputDir) os.Remove(rmFileName) log.Info().Msgf("removing legacy mock from %s", rmFileName) } @@ -119,7 +136,7 @@ func loadPackageImports(file string) ([]string, error) { if err != nil { return nil, err } - var imports []string + imports := make([]string, 0, 30) for _, pkg := range pkgs { for imp := range pkg.Imports { imports = append(imports, imp) diff --git a/internal/caching/lockfile.go b/internal/caching/lockfile.go index ddda5bd..0d1eeb9 100644 --- a/internal/caching/lockfile.go +++ b/internal/caching/lockfile.go @@ -7,6 +7,7 @@ import ( "path/filepath" "time" + "github.com/rs/zerolog/log" "github.com/sonalys/fake/internal/files" ) @@ -21,8 +22,9 @@ type ( Hash string `json:"hash"` Dependencies string `json:"dependencies,omitempty"` // Changed is used as an in-memory flag to say that a file lock changed. - changed bool `json:"-"` - exists bool `json:"-"` + filepath string `json:"-"` + changed bool `json:"-"` + exists bool `json:"-"` } LockFilePackage map[string]HashedLockFile @@ -30,22 +32,37 @@ type ( type LockfileHandler interface { Changed() bool + AbsolutePath() string Exists() bool - Compute() HashedLockFile + Compute() *HashedLockFile } func (f *UnhashedLockFile) Changed() bool { return true } +func (f *UnhashedLockFile) AbsolutePath() string { + return f.Filepath +} + +func (f *HashedLockFile) AbsolutePath() string { + return f.filepath +} + func (f *UnhashedLockFile) Exists() bool { return true } -func (f *UnhashedLockFile) Compute() HashedLockFile { - hash, _ := hashFiles(f.Filepath) - dep, _ := getImportsHash(f.Filepath, f.Dependencies) - return HashedLockFile{ +func (f *UnhashedLockFile) Compute() *HashedLockFile { + hash, err := hashFiles(f.Filepath) + if err != nil { + log.Error().Err(err).Msg("could not compute file hash") + } + dep, err := getImportsHash(f.Filepath, f.Dependencies) + if err != nil { + log.Error().Err(err).Msg("could not compute file imports hash") + } + return &HashedLockFile{ Hash: hash, Dependencies: dep, } @@ -59,8 +76,8 @@ func (f *HashedLockFile) Exists() bool { return f.exists } -func (f *HashedLockFile) Compute() HashedLockFile { - return *f +func (f *HashedLockFile) Compute() *HashedLockFile { + return f } // readLockFile reads and parses the json model from the fake.lock.json file @@ -84,7 +101,7 @@ and the target directory (dir), as well as a hash map (hash). It saves file at path output/{dir}/fake.lock.json */ func WriteLockFile(output string, hash map[string]LockfileHandler) error { - var out = make(map[string]HashedLockFile, len(hash)) + var out = make(map[string]*HashedLockFile, len(hash)) for file, entry := range hash { if entry.Exists() { out[file] = entry.Compute() diff --git a/internal/files/walk.go b/internal/files/walk.go index a89f4c7..1b2183e 100644 --- a/internal/files/walk.go +++ b/internal/files/walk.go @@ -80,31 +80,20 @@ func FindFile(childDir, fileName string) (string, error) { // GetPackagePath returns the absolute path of the file package, including the module path. // This function is not considering packages with different names from their respective folders, // the reason is that this software is not made for psychopaths. -func GetPackagePath(filename string) (string, error) { - goModPath, err := FindFile(filepath.Dir(filename), "go.mod") - if err != nil { - return "", err - } - // Read the contents of the go.mod file - modFileContent, err := os.ReadFile(goModPath) - if err != nil { - return "", err - } - // Parse the go.mod file - modFile, err := modfile.Parse(goModPath, modFileContent, nil) +func GetPackagePath(goModPath string, modFile *modfile.File, filename string) (string, error) { + // Retrieve the module path + modulePath := modFile.Module.Mod.Path + pkgPath, err := GetRelativePath(goModPath, path.Dir(filename)) if err != nil { return "", err } - // Retrieve the module path - modulePath := modFile.Module.Mod.Path - pkgPath, _ := getRelativePath(goModPath, path.Dir(filename)) return path.Join(modulePath, pkgPath), nil } -// getRelativePath returns the shared path between two paths. +// GetRelativePath returns the shared path between two paths. // if they are in the same folder, they will return the same folder path. // Example: /path1/folder1/file and /path1/folder2/file2 should return /path1. -func getRelativePath(path1, path2 string) (string, error) { +func GetRelativePath(path1, path2 string) (string, error) { // Make the paths absolute to ensure accurate relative path calculation absPath1, err := filepath.Abs(path1) if err != nil { @@ -121,40 +110,3 @@ func getRelativePath(path1, path2 string) (string, error) { } return relativePath, nil } - -/* -GroupByDirectory groups files by their directory -Example: - -Input: - - files := []string{ - "/home/user/documents/file1.txt", - "/home/user/documents/file2.txt", - "/home/user/images/image1.png", - "/home/user/images/image2.png", - "/home/user/images/image3.png", - } - -Output: - - { - "/home/user/documents": []string{ - "/home/user/documents/file1.txt", - "/home/user/documents/file2.txt", - }, - "/home/user/images": []string{ - "/home/user/images/image1.png", - "/home/user/images/image2.png", - "/home/user/images/image3.png", - }, - } -*/ -func GroupByDirectory(files []string) map[string][]string { - groups := make(map[string][]string) - for _, file := range files { - dir := filepath.Dir(file) - groups[dir] = append(groups[dir], file) - } - return groups -} diff --git a/internal/gosum/parser.go b/internal/gosum/parser.go index aedf3d5..3b62a00 100644 --- a/internal/gosum/parser.go +++ b/internal/gosum/parser.go @@ -17,6 +17,7 @@ func readGoSum(path string, goMod map[string]string) (map[string]string, error) if err != nil { return nil, err } + defer file.Close() dependencies := make(map[string]string) scanner := bufio.NewScanner(file) for scanner.Scan() { diff --git a/internal/imports/parser.go b/internal/imports/parser.go index fe47445..810e750 100644 --- a/internal/imports/parser.go +++ b/internal/imports/parser.go @@ -9,63 +9,40 @@ import ( type ( ImportEntry struct { - packages.PackageInfo + *packages.PackageInfo Alias string } ) -func FileListUsedImports(f *ast.File) (nameMap, pathMap map[string]*ImportEntry) { - importNamePathMap := make(map[string]*ImportEntry, len(f.Imports)) - for _, i := range f.Imports { - trimmedPath := strings.Trim(i.Path.Value, "\"") - info, ok := packages.Parse(trimmedPath) - if !ok { - continue - } - var importEntry = &ImportEntry{ - PackageInfo: *info, - Alias: "", - } - usedName := info.Name - if i.Name != nil { - usedName = i.Name.Name - importEntry.Alias = usedName - } - importNamePathMap[usedName] = importEntry - } - // We want all imports used by interfaces. - importChecker := getUsedInterfacePackages(f) - nameMap = make(map[string]*ImportEntry, len(importChecker)) - pathMap = make(map[string]*ImportEntry, len(importChecker)) - for name := range importChecker { - nameMap[name] = importNamePathMap[name] - pathMap[importNamePathMap[name].Path] = importNamePathMap[name] - } - return nameMap, pathMap -} +func CachedImportInformation(dir string) func(f *ast.File) (nameMap, pathMap map[string]*ImportEntry) { + cache := make(map[string]*packages.PackageInfo, 100) + return func(f *ast.File) (nameMap map[string]*ImportEntry, pathMap map[string]*ImportEntry) { + nameMap = make(map[string]*ImportEntry, len(f.Imports)) + pathMap = make(map[string]*ImportEntry, len(f.Imports)) -// getUsedInterfacePackages returns all packages imported by interfaces. -func getUsedInterfacePackages(f *ast.File) map[string]struct{} { - resp := make(map[string]struct{}, len(f.Imports)) - ast.Inspect(f, func(n ast.Node) bool { - switch x := n.(type) { - case *ast.TypeSpec: - if _, ok := x.Type.(*ast.InterfaceType); !ok { - return true - } - ast.Inspect(x.Type, func(n ast.Node) bool { - // We don't consider imports from interfaces, as they will also be traversed. - sel, ok := n.(*ast.SelectorExpr) + for _, i := range f.Imports { + trimmedPath := strings.Trim(i.Path.Value, "\"") + var info *packages.PackageInfo + if cachedInfo, ok := cache[trimmedPath]; ok { + info = cachedInfo + } else { + info, ok = packages.Parse(dir, trimmedPath) if !ok { - return true - } - if ident, ok := sel.X.(*ast.Ident); ok { - resp[ident.Name] = struct{}{} + continue } - return true - }) + } + var importEntry = &ImportEntry{ + PackageInfo: info, + Alias: "", + } + usedName := info.Name + if i.Name != nil && i.Name.Name != "" && info.Name != i.Name.Name { + usedName = i.Name.Name + importEntry.Alias = usedName + } + nameMap[usedName] = importEntry + pathMap[info.Path] = importEntry } - return true - }) - return resp + return nameMap, pathMap + } } diff --git a/internal/imports/parser_test.go b/internal/imports/parser_test.go index d0eb831..c5b4357 100644 --- a/internal/imports/parser_test.go +++ b/internal/imports/parser_test.go @@ -17,13 +17,13 @@ func Test(t *testing.T) { f, err := parser.ParseFile(fset, "../../testdata/stub.go", nil, 0) require.NoError(t, err) - got, _ := FileListUsedImports(f) + got, _ := CachedImportInformation("")(f) exp := []ImportEntry{ - {PackageInfo: packages.PackageInfo{Name: "anotherpkg", Path: "github.com/sonalys/fake/testdata/anotherpkg"}}, - {PackageInfo: packages.PackageInfo{Name: "time", Path: "time"}}, - {PackageInfo: packages.PackageInfo{Name: "testing", Path: "testing"}}, - {PackageInfo: packages.PackageInfo{Name: "require", Path: "github.com/stretchr/testify/require"}, Alias: "stub"}, + {PackageInfo: &packages.PackageInfo{Name: "anotherpkg", Path: "github.com/sonalys/fake/testdata/anotherpkg"}}, + {PackageInfo: &packages.PackageInfo{Name: "time", Path: "time"}}, + {PackageInfo: &packages.PackageInfo{Name: "testing", Path: "testing"}}, + {PackageInfo: &packages.PackageInfo{Name: "require", Path: "github.com/stretchr/testify/require"}, Alias: "stub"}, } require.ElementsMatch(t, exp, got) } diff --git a/internal/packages/scanner.go b/internal/packages/scanner.go index 3da5008..1a1392a 100644 --- a/internal/packages/scanner.go +++ b/internal/packages/scanner.go @@ -9,9 +9,10 @@ type PackageInfo struct { } // Parse parses the specified package and returns its package name and import path. -func Parse(importPath string) (*PackageInfo, bool) { +func Parse(dir, importPath string) (*PackageInfo, bool) { cfg := &packages.Config{ Mode: packages.NeedName | packages.NeedFiles, + Dir: dir, } pkgs, err := packages.Load(cfg, importPath) if err != nil { diff --git a/parser.go b/parser.go index ee33542..a2854d7 100644 --- a/parser.go +++ b/parser.go @@ -3,18 +3,21 @@ package fake import ( "go/parser" + "github.com/rs/zerolog/log" "github.com/sonalys/fake/internal/files" - "github.com/sonalys/fake/internal/imports" ) func (g *Generator) ParseFile(input string) (*ParsedFile, error) { - file, err := parser.ParseFile(g.FileSet, input, nil, parser.Mode(0)) + file, err := parser.ParseFile(g.FileSet, input, nil, parser.SkipObjectResolution) if err != nil { return nil, err } - packagePath, _ := files.GetPackagePath(input) - - imports, importsPathMap := imports.FileListUsedImports(file) + packagePath, err := files.GetPackagePath(g.goModFilename, g.goMod, input) + if err != nil { + log.Error().Err(err).Msg("failed to get package info") + return nil, err + } + imports, importsPathMap := g.cachedPackageInfo(file) return &ParsedFile{ Generator: g, Ref: file, diff --git a/run.go b/run.go index 95bfaec..0c06baa 100644 --- a/run.go +++ b/run.go @@ -23,7 +23,10 @@ func GenerateInterface(c GenerateInterfaceConfig) { log.Fatal().Err(err).Msg("error comparing file hashes") } log.Info().Msgf("scanning %d files for interface %s", len(fileHashes), c.InterfaceName) - gen := NewGenerator(c.PackageName) + gen, err := NewGenerator(c.PackageName, c.Inputs[0]) + if err != nil { + log.Fatal().Err(err).Msg("error creating mock generator") + } for curFilePath := range fileHashes { b := gen.GenerateFile(curFilePath, c.InterfaceName) if b == nil { @@ -40,29 +43,34 @@ func GenerateInterface(c GenerateInterfaceConfig) { outputFile.Write(b) outputFile.Close() } + if err := caching.WriteLockFile(path.Dir(gen.goModFilename), fileHashes); err != nil { + log.Error().Err(err).Msg("error saving lock file") + } } func Run(inputs []string, output string, ignore []string, interfaces ...string) { - gen := NewGenerator("mocks") + gen, err := NewGenerator("mocks", inputs[0]) + if err != nil { + log.Fatal().Err(err).Msg("error creating mock generator") + } fileHashes, err := caching.GetUncachedFiles(inputs, append(ignore, output), output) if err != nil { log.Fatal().Err(err).Msg("error comparing file hashes") } - log.Info().Msgf("scanning %d files", len(fileHashes)) var counter int - for curFilePath, lockFile := range fileHashes { + for relPath, lockFile := range fileHashes { if !lockFile.Changed() { continue } - if b := gen.GenerateFile(curFilePath); len(b) > 0 { - log.Info().Msgf("generating mock for %s", curFilePath) + if b := gen.GenerateFile(lockFile.AbsolutePath()); len(b) > 0 { + log.Info().Msgf("generating mock for %s", relPath) counter++ - outputFile := openOutputFile(curFilePath, output) + outputFile := openOutputFile(relPath, output) outputFile.Write(b) outputFile.Close() } else { // Remove empty files from our new lock file. - delete(fileHashes, curFilePath) + // delete(fileHashes, relPath) } } if len(fileHashes) == 0 || counter == 0 { diff --git a/writer.go b/writer.go index fe5c6d1..74d0b0b 100644 --- a/writer.go +++ b/writer.go @@ -6,22 +6,34 @@ import ( "go/format" "io" "os" + "sync" "github.com/rs/zerolog/log" "github.com/sonalys/fake/internal/files" ) +var pool = sync.Pool{ + New: func() any { + newBuf := make([]byte, 0, 1024*10) + return &newBuf + }, +} + func (g *Generator) GenerateFile(input string, interfaceNames ...string) []byte { parsedFile, err := g.ParseFile(input) if err != nil { - log.Panic().Msgf("failed to parse file: %s", input) + log.Panic().Err(err).Msgf("failed to parse file: %s", input) } interfaces := parsedFile.ListInterfaces(interfaceNames...) if len(interfaces) == 0 { return nil } - header := bytes.NewBuffer(make([]byte, 0, parsedFile.Size)) - body := bytes.NewBuffer(make([]byte, 0, parsedFile.Size)) + buf1 := pool.Get().(*[]byte) + buf2 := pool.Get().(*[]byte) + defer pool.Put(buf1) + defer pool.Put(buf2) + header := bytes.NewBuffer(*buf1) + body := bytes.NewBuffer(*buf2) if g.MockPackageName == "" { g.MockPackageName = parsedFile.PkgName } @@ -32,7 +44,7 @@ func (g *Generator) GenerateFile(input string, interfaceNames ...string) []byte } // writeImports comes after interfaces because we only add external dependencies after generating interfaces. parsedFile.writeImports(header) - header.Write(body.Bytes()) + header.ReadFrom(body) return formatCode(header.Bytes()) }