-
Notifications
You must be signed in to change notification settings - Fork 3
/
expect.go
162 lines (125 loc) · 3.76 KB
/
expect.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
package gormtestutil
import (
"fmt"
"reflect"
"sync"
testingi "github.com/mitchellh/go-testing-interface"
"gorm.io/gorm"
)
// Defaults applied to every expectation before any ExpectOption is evaluated.
const (
// defaultTimesCalled is the number of calls expected when WithCalls is not supplied.
defaultTimesCalled = 1
// defaultStrict makes calls beyond the expected amount a test error unless WithoutMaximum is supplied.
defaultStrict = true
)
// ExpectOption customises an expectation created by the Expect* functions.
// Each option receives the expectation's configuration, pre-filled with the
// defaults, and may modify it.
type ExpectOption func(*expectConfig)
// WithCalls sets the number of times the expected operation must be invoked.
// Without this option a single invocation is expected.
func WithCalls(times int) ExpectOption {
	apply := func(cfg *expectConfig) {
		cfg.Times = times
	}

	return apply
}
// WithExpectation makes the expectation use the supplied wait group instead of
// a fresh one, so several expectations can be chained onto one wait group.
func WithExpectation(expectation *sync.WaitGroup) ExpectOption {
	apply := func(cfg *expectConfig) {
		cfg.Expectation = expectation
	}

	return apply
}
// WithoutMaximum turns off strict mode: calls beyond the expected 'times' are
// tolerated instead of causing an error, which is the default behaviour.
func WithoutMaximum() ExpectOption {
	relax := func(cfg *expectConfig) {
		cfg.Strict = false
	}

	return relax
}
// ExpectCreated sets up an expectation that a create (insert) is executed on
// the table of model. Every matching create marks the returned wait group done,
// up to the expected number of calls (one unless WithCalls is supplied).
// When the expectation cannot be set up, the failure is reported on t and nil
// is returned.
func ExpectCreated(t testingi.T, database *gorm.DB, model any, options ...ExpectOption) *sync.WaitGroup {
	t.Helper()

	const hook = "create"

	return expectHook(t, database, model, hook, options...)
}
// ExpectDeleted sets up an expectation that a delete is executed on the table
// of model. Every matching delete marks the returned wait group done, up to the
// expected number of calls (one unless WithCalls is supplied).
// When the expectation cannot be set up, the failure is reported on t and nil
// is returned.
func ExpectDeleted(t testingi.T, database *gorm.DB, model any, options ...ExpectOption) *sync.WaitGroup {
	t.Helper()

	const hook = "delete"

	return expectHook(t, database, model, hook, options...)
}
// ExpectUpdated sets up an expectation that an update is executed on the table
// of model. Every matching update marks the returned wait group done, up to the
// expected number of calls (one unless WithCalls is supplied).
// When the expectation cannot be set up, the failure is reported on t and nil
// is returned.
func ExpectUpdated(t testingi.T, database *gorm.DB, model any, options ...ExpectOption) *sync.WaitGroup {
	t.Helper()

	const hook = "update"

	return expectHook(t, database, model, hook, options...)
}
// expectConfig holds the settings of a single expectation. It starts out with
// the package defaults and is then modified by the supplied ExpectOption values.
type expectConfig struct {
// Times is the expected number of calls; the wait group is incremented by this amount.
Times int
// Strict reports calls beyond Times as a test error when true and only logs them when false.
Strict bool
// Expectation is the wait group that is marked done once for each expected call.
Expectation *sync.WaitGroup
}
// expectHook registers a GORM callback that tracks how often the given hook
// ("create", "delete" or "update") runs for the table of model.
//
// The expectation's wait group is incremented by the expected number of calls
// and marked done once for every matching call up to that number. Calls beyond
// it are reported through t.Errorf, or only logged when strict mode is disabled.
//
// An error is reported on t and nil is returned when database is nil, model is
// not a struct value, the options produce an unusable configuration (nil wait
// group or negative call count), the model cannot be parsed, the hook is
// unknown or the callback cannot be registered.
//
//nolint:cyclop // Allowing it
func expectHook(t testingi.T, database *gorm.DB, model any, hook string, options ...ExpectOption) *sync.WaitGroup {
	t.Helper()

	if database == nil {
		t.Error("database cannot be nil")
		return nil
	}

	if reflect.ValueOf(model).Kind() != reflect.Struct {
		t.Error("model must be a struct")
		return nil
	}

	// Default values
	config := &expectConfig{
		Times:       defaultTimesCalled,
		Strict:      defaultStrict,
		Expectation: &sync.WaitGroup{},
	}

	for _, option := range options {
		option(config)
	}

	// Options may have replaced the defaults with unusable values; both of
	// these would otherwise panic further down.
	if config.Expectation == nil {
		t.Error("expectation cannot be nil")
		return nil
	}

	if config.Times < 0 {
		t.Errorf("expected calls cannot be negative, got %d", config.Times)
		return nil
	}

	// Get table name of model to use in register hook
	stmt := &gorm.Statement{DB: database}
	if err := stmt.Parse(model); err != nil {
		t.Error(err)
		return nil
	}

	// Resolve the registration target up front so an unknown hook is rejected
	// before the wait group is touched.
	var register func(name string, fn func(*gorm.DB)) error

	switch hook {
	case "create":
		register = database.Callback().Create().After("gorm:after_create").Register
	case "delete":
		register = database.Callback().Delete().After("gorm:after_delete").Register
	case "update":
		register = database.Callback().Update().After("gorm:after_update").Register
	default:
		t.Errorf("unsupported hook %q", hook)
		return nil
	}

	// Set waitgroup for amount of times. This happens only after validation
	// succeeded, so a failed setup does not leave a chained wait group with
	// counts that are never released, and before registration, so the counter
	// is in place when the hook can first fire.
	config.Expectation.Add(config.Times)

	var (
		mu          sync.Mutex // guards timesCalled; the hook may run on several goroutines
		timesCalled int
	)

	assertHook := func(tx *gorm.DB) {
		t.Helper()

		// The callback fires for every table; only count the model's table.
		if tx.Statement.Table != stmt.Table {
			return
		}

		mu.Lock()
		timesCalled++
		called := timesCalled
		mu.Unlock()

		if called <= config.Times {
			config.Expectation.Done()
			return
		}

		message := fmt.Sprintf("%s hook asserts called %d times but called at least %d times\n", stmt.Table, config.Times, called)

		if config.Strict {
			// Pass the message as an argument, never as the format string.
			t.Errorf("%s", message)
			return
		}

		t.Log(message)
	}

	hookName := fmt.Sprintf("assert_%s_%v", hook, stmt.Table)
	if err := register(hookName, assertHook); err != nil {
		t.Error(err)
		return nil
	}

	return config.Expectation
}