-
Notifications
You must be signed in to change notification settings - Fork 0
/
configured_oper.go
196 lines (173 loc) · 4.83 KB
/
configured_oper.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
package main
import (
"errors"
"fmt"
"io"
"os"
"strings"
"sync"
"time"
"github.com/baalimago/repeater/internal/output"
)
// incrementPlaceholder is the token that must appear in at least one command
// argument when increment mode is enabled (checked by containsIncrementPlaceholder).
const (
	incrementPlaceholder = "INC"
)
// configuredOper is a validated repeater configuration (built by New) plus the
// bookkeeping state used while running it.
type configuredOper struct {
	// am is the number of repetitions; workers is the number of workers.
	am, workers int
	// args is the command to repeat; may contain incrementPlaceholder.
	args []string

	// Destinations for progress reporting and for command output.
	progress       output.Mode
	progressFormat string
	output         output.Mode
	// outputFile is the report file; nil when no report file name was given.
	outputFile *os.File
	// NOTE(review): presumably guards writes to outputFile; not used in this view.
	outputFileMu *sync.Mutex
	// outputFileMode is the cleaned reply recorded by getFile.
	outputFileMode string

	increment bool
	// NOTE(review): not set by New; presumably filled in after the run.
	runtime time.Duration
	results []Result
	// resultFile is opened from the resultFlag passed to New; nil if unset.
	resultFile *os.File

	// workerWg has Add(workers) called on it in New.
	workerWg *sync.WaitGroup
	// amIdleWorkers starts at workers; amSuccess starts at 0.
	amIdleWorkers, amSuccess int
	workPlanMu               *sync.Mutex
	retryOnFail              bool
}
// userQuitError is a string-based error type, which allows UserQuitError to be
// declared as a constant.
type userQuitError string

// Error returns the error message unchanged.
func (e userQuitError) Error() string { return string(e) }

// UserQuitError is returned when the user chooses to quit at a file prompt.
const UserQuitError userQuitError = "user quit"
// incrementConfigError reports that increment mode was requested but none of
// the command arguments contain incrementPlaceholder.
type incrementConfigError struct {
	args []string
}

// Error describes the offending arguments and the placeholder that was expected.
func (e incrementConfigError) Error() string {
	const layout = "increment is true, but args: %v, does not contain string '%s'"
	return fmt.Sprintf(layout, e.args, incrementPlaceholder)
}
// New validates the repeater configuration and returns a configuredOper.
//
// am is the number of repetitions and workers the number of workers; workers
// must not be negative and must not exceed am. args is the command; when
// increment is set, at least one argument must contain incrementPlaceholder.
// pMode and oMode select where progress and command output go; if either
// includes a file, outputFile must be non-empty. outputFileMode ("t", "a" or
// "q", case-insensitive) decides how an existing outputFile is treated; when
// empty, the user is asked. resultFlag optionally names a result file; if that
// file already exists, the user is always asked how to treat it.
//
// Errors: a UserQuitError from a prompt is returned as-is; a missing
// placeholder yields incrementConfigError; other configuration or file
// problems yield a descriptive error.
func New(am, workers int,
	args []string,
	pMode output.Mode,
	progressFormat string,
	oMode output.Mode,
	outputFile string,
	outputFileMode string,
	increment bool,
	resultFlag string,
	retryOnFail bool,
) (configuredOper, error) {
	shouldHaveReportFile := pMode == output.BOTH || pMode == output.FILE ||
		oMode == output.BOTH || oMode == output.FILE
	if shouldHaveReportFile && outputFile == "" {
		return configuredOper{}, fmt.Errorf("progress mode '%v', or output mode '%v', requires a report file but none is set. Use flag --file <file_name>", pMode, oMode)
	}
	if increment && !containsIncrementPlaceholder(args) {
		return configuredOper{}, incrementConfigError{args: args}
	}
	if workers > am {
		return configuredOper{}, fmt.Errorf("please use less workers than repetitions. Am workers: %v, am repetitions: %v", workers, am)
	}
	if workers < 0 {
		// sync.WaitGroup.Add panics on a negative counter; fail cleanly instead.
		return configuredOper{}, fmt.Errorf("amount of workers must not be negative, got: %v", workers)
	}
	c := configuredOper{
		am:             am,
		workers:        workers,
		args:           args,
		progress:       pMode,
		progressFormat: progressFormat,
		output:         oMode,
		outputFileMu:   &sync.Mutex{},
		increment:      increment,
		amSuccess:      0,
		workerWg:       &sync.WaitGroup{},
		amIdleWorkers:  workers,
		workPlanMu:     &sync.Mutex{},
		retryOnFail:    retryOnFail,
	}
	c.workerWg.Add(workers)

	file, err := c.getFile(outputFile, outputFileMode)
	if err != nil {
		if errors.Is(err, UserQuitError) {
			return c, err
		}
		return c, fmt.Errorf("failed to get file: %w", err)
	}
	c.outputFile = file
	// getFile stores the chosen mode on c.outputFileMode as a side effect.
	// Remember the report file's mode so the result-file prompt below cannot
	// overwrite what String() reports as "report file mode".
	reportFileMode := c.outputFileMode

	file, err = c.getFile(resultFlag, "")
	c.outputFileMode = reportFileMode
	if err != nil {
		// Don't leak the report file that was already opened above.
		if c.outputFile != nil {
			_ = c.outputFile.Close() // best-effort: we are already failing
			c.outputFile = nil
		}
		if errors.Is(err, UserQuitError) {
			return c, err
		}
		return c, fmt.Errorf("failed to get file: %w", err)
	}
	c.resultFile = file
	return c, nil
}
// getFile opens the file at path s for writing, creating it if it does not
// exist. An empty s yields (nil, nil).
//
// If the file already exists, fileMode decides how it is treated. When
// fileMode is empty, the user is asked on stdin instead. The reply is trimmed
// and lower-cased, stored in c.outputFileMode, and then interpreted:
//   - "t": truncate the file (os.Create)
//   - "a": open it for appending
//   - "q": return UserQuitError
//
// Any other reply yields an "unrecognized reply" error. If the file's existence
// cannot be determined (os.Stat fails with anything but "not exist"), that
// error is returned.
func (c *configuredOper) getFile(s, fileMode string) (*os.File, error) {
	if s == "" {
		return nil, nil
	}
	_, statErr := os.Stat(s)
	switch {
	case errors.Is(statErr, os.ErrNotExist):
		return os.Create(s)
	case statErr != nil:
		// e.g. permission denied: don't claim the file exists and prompt the
		// user about it; surface the real problem instead.
		return nil, statErr
	}

	userResp := fileMode
	if fileMode == "" {
		printWarn(fmt.Sprintf("file: \"%v\", already exists. Would you like to [t]runcate, [a]ppend or [q]uit? [t/a/q]: ", s))
		// Scan errors are deliberately ignored: on an empty line or closed stdin
		// userResp stays empty and is rejected below as an unrecognized reply.
		_, _ = fmt.Scanln(&userResp)
	}
	cleanedUserResp := strings.ToLower(strings.TrimSpace(userResp))
	c.outputFileMode = cleanedUserResp
	switch cleanedUserResp {
	case "t":
		return os.Create(s)
	case "a":
		return os.OpenFile(s, os.O_APPEND|os.O_RDWR, 0o644)
	case "q":
		return nil, UserQuitError
	default:
		return nil, fmt.Errorf("unrecognized reply: \"%v\", valid options are [tT], [aA] or [qQ]", userResp)
	}
}
// String renders the configuration as a multi-line, human-readable summary.
// The report file is shown as "HIDDEN" when no report file is open.
func (c *configuredOper) String() string {
	name := "HIDDEN"
	if f := c.outputFile; f != nil {
		name = f.Name()
	}
	const layout = "am: %v\n" +
		"command: %v\n" +
		"increment: %v\n" +
		"workers: %v\n" +
		"progress: %s\n" +
		"progress format: %q\n" +
		"output: %s\n" +
		"report file: %v\n" +
		"report file mode: %v"
	return fmt.Sprintf(layout,
		c.am, c.args, c.increment, c.workers,
		c.progress, c.progressFormat, c.output,
		name, c.outputFileMode)
}
// writeOutput prints res.Output to the destination(s) selected by c.output:
// stdout, the report file, or both (stdout first). Any other mode writes
// nothing. Write errors are not reported.
func (c *configuredOper) writeOutput(res *Result) {
	var sinks []io.Writer
	switch c.output {
	case output.STDOUT:
		sinks = []io.Writer{os.Stdout}
	case output.FILE:
		sinks = []io.Writer{c.outputFile}
	case output.BOTH:
		sinks = []io.Writer{os.Stdout, c.outputFile}
	}
	for _, w := range sinks {
		fmt.Fprintf(w, "%v", res.Output)
	}
}
// setupProgressStreams returns the writers progress should go to, as selected
// by c.progress: stdout, the report file, or both (stdout first). Any other
// mode yields an empty slice.
func (c *configuredOper) setupProgressStreams() []io.Writer {
	streams := make([]io.Writer, 0, 2)
	if c.progress == output.STDOUT || c.progress == output.BOTH {
		streams = append(streams, os.Stdout)
	}
	if c.progress == output.FILE || c.progress == output.BOTH {
		streams = append(streams, c.outputFile)
	}
	return streams
}
// containsIncrementPlaceholder reports whether any argument contains
// incrementPlaceholder as a (case-sensitive) substring.
func containsIncrementPlaceholder(args []string) bool {
	found := false
	for i := 0; i < len(args) && !found; i++ {
		found = strings.Contains(args[i], incrementPlaceholder)
	}
	return found
}