Skip to content

Commit

Permalink
fix: auto fix handler (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
notJoon authored Sep 13, 2024
1 parent 4f77c4e commit 977e9f5
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 98 deletions.
87 changes: 38 additions & 49 deletions cmd/tlin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func main() {
logger, _ := zap.NewProduction()
defer logger.Sync()

config := parseFlags()
config := parseFlags(os.Args[1:])

ctx, cancel := context.WithTimeout(context.Background(), config.Timeout)
defer cancel()
Expand All @@ -74,41 +74,38 @@ func main() {
runWithTimeout(ctx, func() {
runCyclomaticComplexityAnalysis(ctx, logger, config.Paths, config.CyclomaticThreshold)
})
} else {
} else if config.AutoFix {
runWithTimeout(ctx, func() {
runNormalLintProcess(ctx, logger, engine, config.Paths)
runAutoFix(ctx, logger, engine, config.Paths, config.DryRun, config.ConfidenceThreshold)
})
}

if config.AutoFix {
if config.ConfidenceThreshold < 0 || config.ConfidenceThreshold > 1 {
fmt.Println("error: confidence threshold must be between 0 and 1")
os.Exit(1)
}
runAutoFix(ctx, logger, engine, config.Paths, config.DryRun, config.ConfidenceThreshold)
} else {
runWithTimeout(ctx, func() {
runNormalLintProcess(ctx, logger, engine, config.Paths)
})
}
}

func parseFlags() Config {
func parseFlags(args []string) Config {
flagSet := flag.NewFlagSet("tlin", flag.ExitOnError)
config := Config{}
flag.DurationVar(&config.Timeout, "timeout", defaultTimeout, "Set a timeout for the linter. example: 1s, 1m, 1h")
flag.BoolVar(&config.CyclomaticComplexity, "cyclo", false, "Run cyclomatic complexity analysis")
flag.IntVar(&config.CyclomaticThreshold, "threshold", 10, "Cyclomatic complexity threshold")
flag.StringVar(&config.IgnoreRules, "ignore", "", "Comma-separated list of lint rules to ignore")
flag.BoolVar(&config.CFGAnalysis, "cfg", false, "Run control flow graph analysis")
flag.StringVar(&config.FuncName, "func", "", "Function name for CFG analysis")

flag.BoolVar(&config.AutoFix, "fix", false, "Automatically fix issues")
flag.BoolVar(&config.DryRun, "dry-run", false, "Show what would be fixed without actually fixing")
flag.Float64Var(&config.ConfidenceThreshold, "confidence", defaultConfidenceThreshold, "Minimum confidence threshold for fixing issues")

flag.Parse()
flagSet.DurationVar(&config.Timeout, "timeout", defaultTimeout, "Set a timeout for the linter. example: 1s, 1m, 1h")
flagSet.BoolVar(&config.CyclomaticComplexity, "cyclo", false, "Run cyclomatic complexity analysis")
flagSet.IntVar(&config.CyclomaticThreshold, "threshold", 10, "Cyclomatic complexity threshold")
flagSet.StringVar(&config.IgnoreRules, "ignore", "", "Comma-separated list of lint rules to ignore")
flagSet.BoolVar(&config.CFGAnalysis, "cfg", false, "Run control flow graph analysis")
flagSet.StringVar(&config.FuncName, "func", "", "Function name for CFG analysis")
flagSet.BoolVar(&config.AutoFix, "fix", false, "Automatically fix issues")
flagSet.BoolVar(&config.DryRun, "dry-run", false, "Run in dry-run mode (show fixes without applying them)")
flagSet.Float64Var(&config.ConfidenceThreshold, "confidence", defaultConfidenceThreshold, "Confidence threshold for auto-fixing (0.0 to 1.0)")

err := flagSet.Parse(args)
if err != nil {
fmt.Println("Error parsing flags:", err)
os.Exit(1)
}

config.Paths = flag.Args()
config.Paths = flagSet.Args()
if len(config.Paths) == 0 {
fmt.Println("error: Please provide file or directory paths")
os.Exit(1)
Expand Down Expand Up @@ -147,31 +144,6 @@ func runNormalLintProcess(ctx context.Context, logger *zap.Logger, engine LintEn
}
}

func runAutoFix(ctx context.Context, logger *zap.Logger, engine LintEngine, paths []string, dryRun bool, confidenceThreshold float64) {
fix := fixer.New(dryRun, confidenceThreshold)

for _, path := range paths {
issues, err := processPath(ctx, logger, engine, path, processFile)
if err != nil {
logger.Error(
"error processing path",
zap.String("path", path),
zap.Error(err),
)
continue
}

err = fix.Fix(path, issues)
if err != nil {
logger.Error(
"error fixing issues",
zap.String("path", path),
zap.Error(err),
)
}
}
}

func runCyclomaticComplexityAnalysis(ctx context.Context, logger *zap.Logger, paths []string, threshold int) {
issues, err := processFiles(ctx, logger, nil, paths, func(_ LintEngine, path string) ([]tt.Issue, error) {
return processCyclomaticComplexity(path, threshold)
Expand Down Expand Up @@ -217,6 +189,23 @@ func runCFGAnalysis(_ context.Context, logger *zap.Logger, paths []string, funcN
}
}

func runAutoFix(ctx context.Context, logger *zap.Logger, engine LintEngine, paths []string, dryRun bool, confidenceThreshold float64) {
fix := fixer.New(dryRun, confidenceThreshold)

for _, path := range paths {
issues, err := processPath(ctx, logger, engine, path, processFile)
if err != nil {
logger.Error("error processing path", zap.String("path", path), zap.Error(err))
continue
}

err = fix.Fix(path, issues)
if err != nil {
logger.Error("error fixing issues", zap.String("path", path), zap.Error(err))
}
}
}

func processFiles(ctx context.Context, logger *zap.Logger, engine LintEngine, paths []string, processor func(LintEngine, string) ([]tt.Issue, error)) ([]tt.Issue, error) {
var allIssues []tt.Issue
for _, path := range paths {
Expand Down
140 changes: 130 additions & 10 deletions cmd/tlin/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,55 @@ func (m *MockLintEngine) IgnoreRule(rule string) {
}

func TestParseFlags(t *testing.T) {
t.Parallel()
oldArgs := os.Args
defer func() { os.Args = oldArgs }()
tests := []struct {
name string
args []string
expected Config
}{
{
name: "AutoFix",
args: []string{"-fix", "file.go"},
expected: Config{
AutoFix: true,
Paths: []string{"file.go"},
ConfidenceThreshold: defaultConfidenceThreshold,
},
},
{
name: "AutoFix with DryRun",
args: []string{"-fix", "-dry-run", "file.go"},
expected: Config{
AutoFix: true,
DryRun: true,
Paths: []string{"file.go"},
ConfidenceThreshold: defaultConfidenceThreshold,
},
},
{
name: "AutoFix with custom confidence",
args: []string{"-fix", "-confidence", "0.9", "file.go"},
expected: Config{
AutoFix: true,
Paths: []string{"file.go"},
ConfidenceThreshold: 0.9,
},
},
}

os.Args = []string{"cmd", "-timeout", "10s", "-cyclo", "-threshold", "15", "-ignore", "rule1,rule2", "file1.go", "file2.go"}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oldArgs := os.Args
defer func() { os.Args = oldArgs }()

config := parseFlags()
os.Args = append([]string{"cmd"}, tt.args...)
config := parseFlags(tt.args)

assert.Equal(t, 10*time.Second, config.Timeout)
assert.True(t, config.CyclomaticComplexity)
assert.Equal(t, 15, config.CyclomaticThreshold)
assert.Equal(t, "rule1,rule2", config.IgnoreRules)
assert.Equal(t, []string{"file1.go", "file2.go"}, config.Paths)
assert.Equal(t, tt.expected.AutoFix, config.AutoFix)
assert.Equal(t, tt.expected.DryRun, config.DryRun)
assert.Equal(t, tt.expected.ConfidenceThreshold, config.ConfidenceThreshold)
assert.Equal(t, tt.expected.Paths, config.Paths)
})
}
}

func TestProcessFile(t *testing.T) {
Expand Down Expand Up @@ -261,3 +297,87 @@ func ignoredFunc() { // 19

assert.Contains(t, output, "Function not found: nonExistentFunc")
}

func TestRunAutoFix(t *testing.T) {
logger, _ := zap.NewProduction()
mockEngine := new(MockLintEngine)
ctx := context.Background()

tempDir, err := os.MkdirTemp("", "autofix-test")
assert.NoError(t, err)
defer os.RemoveAll(tempDir)

testFile := filepath.Join(tempDir, "test.go")
initialContent := `package main
func main() {
slice := []int{1, 2, 3}
_ = slice[:len(slice)]
}
`
err = os.WriteFile(testFile, []byte(initialContent), 0644)
assert.NoError(t, err)

expectedIssues := []types.Issue{
{
Rule: "simplify-slice-range",
Filename: testFile,
Message: "unnecessary use of len() in slice expression, can be simplified",
Start: token.Position{Line: 5, Column: 5},
End: token.Position{Line: 5, Column: 24},
Suggestion: "_ = slice[:]",
Confidence: 0.9,
},
}

mockEngine.On("Run", testFile).Return(expectedIssues, nil)

// Capture stdout
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w

// Run auto-fix
runAutoFix(ctx, logger, mockEngine, []string{testFile}, false, 0.8)

// Restore stdout
w.Close()
os.Stdout = oldStdout
var buf bytes.Buffer
io.Copy(&buf, r)
output := buf.String()

// Check if the fix was applied
content, err := os.ReadFile(testFile)
assert.NoError(t, err)

expectedContent := `package main
func main() {
slice := []int{1, 2, 3}
_ = slice[:]
}
`
assert.Equal(t, expectedContent, string(content))
assert.Contains(t, output, "Fixed issues in")

// Test dry-run mode
err = os.WriteFile(testFile, []byte(initialContent), 0644)
assert.NoError(t, err)

r, w, _ = os.Pipe()
os.Stdout = w

runAutoFix(ctx, logger, mockEngine, []string{testFile}, true, 0.8)

w.Close()
os.Stdout = oldStdout
buf.Reset()
io.Copy(&buf, r)
output = buf.String()

content, err = os.ReadFile(testFile)
assert.NoError(t, err)
assert.Equal(t, initialContent, string(content))
assert.Contains(t, output, "Would fix issue in")
}
17 changes: 9 additions & 8 deletions internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Engine struct {
SymbolTable *SymbolTable
rules []LintRule
ignoredRules map[string]bool
defaultRules []LintRule
}

// NewEngine creates a new lint engine.
Expand All @@ -26,14 +27,18 @@ func NewEngine(rootDir string) (*Engine, error) {
}

engine := &Engine{SymbolTable: st}
engine.registerDefaultRules()
engine.initDefaultRules()

return engine, nil
}

// registerDefaultRules adds the default set of lint rules to the engine.
func (e *Engine) registerDefaultRules() {
e.rules = append(e.rules,
e.rules = append(e.rules, e.defaultRules...)
}

func (e *Engine) initDefaultRules() {
e.defaultRules = []LintRule{
&GolangciLintRule{},
&EarlyReturnOpportunityRule{},
&SimplifySliceExprRule{},
Expand All @@ -46,12 +51,8 @@ func (e *Engine) registerDefaultRules() {
&UselessBreakRule{},
&DeferRule{},
&MissingModPackageRule{},
)
}

// AddRule allows adding custom lint rules to the engine.
func (e *Engine) AddRule(rule LintRule) {
e.rules = append(e.rules, rule)
}
e.registerDefaultRules()
}

// Run applies all lint rules to the given file and returns a slice of Issues.
Expand Down
25 changes: 0 additions & 25 deletions internal/fixer/fixer.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package fixer

import (
"bufio"
"bytes"
"fmt"
"go/format"
Expand All @@ -16,14 +15,12 @@ import (

type Fixer struct {
DryRun bool
autoConfirm bool // testing purposes
MinConfidence float64 // threshold for fixing issues
}

func New(dryRun bool, threshold float64) *Fixer {
return &Fixer{
DryRun: dryRun,
autoConfirm: false,
MinConfidence: threshold,
}
}
Expand Down Expand Up @@ -51,10 +48,6 @@ func (f *Fixer) Fix(filename string, issues []tt.Issue) error {
continue
}

if !f.confirmFix(issue) && !f.autoConfirm {
continue
}

startLine := issue.Start.Line - 1
endLine := issue.End.Line - 1

Expand Down Expand Up @@ -89,24 +82,6 @@ func (f *Fixer) Fix(filename string, issues []tt.Issue) error {
return nil
}

func (f *Fixer) confirmFix(issue tt.Issue) bool {
if f.autoConfirm {
return true
}

fmt.Printf(
"Fix issue in %s at line %d? (confidence: %.2f)\n",
issue.Filename, issue.Start.Line, issue.Confidence,
)
fmt.Printf("Message: %s\n", issue.Message)
fmt.Printf("Suggestion:\n%s\n", issue.Suggestion)
fmt.Print("Apply this fix? (y/N): ")

reader := bufio.NewReader(os.Stdin)
resp, _ := reader.ReadString('\n')
return strings.ToLower(strings.TrimSpace(resp)) == "y"
}

func (c *Fixer) extractIndent(line string) string {
return line[:len(line)-len(strings.TrimLeft(line, " \t"))]
}
Expand Down
Loading

0 comments on commit 977e9f5

Please sign in to comment.