-
-
Notifications
You must be signed in to change notification settings - Fork 80
/
sqlserver.go
145 lines (124 loc) · 3.46 KB
/
sqlserver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package testfixtures
import (
"database/sql"
"fmt"
"strings"
)
// sqlserver is the helper used for Microsoft SQL Server databases.
type sqlserver struct {
baseHelper
// paramTypeCache holds the bind-parameter style detected by init
// (paramTypeQuestion or paramTypeAtSign) and returned by paramType.
paramTypeCache int
// tables caches the "schema.table" names of all base tables, filled by
// init via tableNames; disableReferentialIntegrity iterates over it.
tables []string
}
// init prepares the helper for use with db: it detects which bind-parameter
// style the driver accepts and caches the list of base tables.
func (h *sqlserver) init(db *sql.DB) error {
	// NOTE(@andreynering): The SQL Server lib (github.com/denisenkom/go-mssqldb)
	// supports both the "?" style (when using the deprecated "mssql" driver)
	// and the "@p1" style (when using the new "sqlserver" driver).
	//
	// Since we don't have a way to know which driver it's been used,
	// this is a small hack to detect the allowed param style.
	h.paramTypeCache = paramTypeAtSign
	var probe int
	if probeErr := db.QueryRow("SELECT ?", 1).Scan(&probe); probeErr == nil && probe == 1 {
		h.paramTypeCache = paramTypeQuestion
	}

	// h.tables is overwritten even on failure (tableNames then yields nil).
	var err error
	h.tables, err = h.tableNames(db)
	return err
}
// paramType returns the bind-parameter style that init detected and cached.
func (h *sqlserver) paramType() int {
return h.paramTypeCache
}
// quoteKeyword turns a possibly schema-qualified name such as "dbo.users"
// into a bracket-delimited SQL Server identifier ("[dbo].[users]").
//
// Each dot-separated part is quoted on its own. A literal "]" inside a part
// is doubled to "]]", which is how SQL Server escapes it inside brackets;
// without this a name containing "]" would produce malformed SQL.
func (*sqlserver) quoteKeyword(s string) string {
	parts := strings.Split(s, ".")
	for i, p := range parts {
		parts[i] = "[" + strings.ReplaceAll(p, "]", "]]") + "]"
	}
	return strings.Join(parts, ".")
}
// databaseName reports the name of the database the connection is using,
// as returned by SQL Server's DB_NAME().
func (*sqlserver) databaseName(q queryable) (string, error) {
	row := q.QueryRow("SELECT DB_NAME()")
	var name string
	if err := row.Scan(&name); err != nil {
		return name, err
	}
	return name, nil
}
// tableNames lists every base table from INFORMATION_SCHEMA.TABLES as
// "schema.table", excluding the table named spt_values.
func (*sqlserver) tableNames(q queryable) ([]string, error) {
	const query = "SELECT table_schema + '.' + table_name FROM INFORMATION_SCHEMA.TABLES WHERE table_name <> 'spt_values' AND table_type = 'BASE TABLE'"

	rows, err := q.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if scanErr := rows.Scan(&name); scanErr != nil {
			return nil, scanErr
		}
		names = append(names, name)
	}
	// Surface any error that terminated the iteration early.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return names, nil
}
// tableHasIdentityColumn reports whether tableName has at least one row in
// sys.identity_columns, i.e. whether it has an identity column.
//
// tableName is embedded inside a '...' string literal passed to OBJECT_ID,
// so single quotes in it are doubled to keep the literal well formed and to
// avoid injecting arbitrary SQL through a crafted table name.
func (h *sqlserver) tableHasIdentityColumn(q queryable, tableName string) (bool, error) {
	// Named "query" (not "sql") to avoid shadowing the database/sql package.
	query := fmt.Sprintf(`
		SELECT COUNT(*)
		FROM sys.identity_columns
		WHERE OBJECT_ID = OBJECT_ID('%s')
	`, strings.ReplaceAll(tableName, "'", "''"))

	var count int
	if err := q.QueryRow(query).Scan(&count); err != nil {
		return false, err
	}
	return count > 0, nil
}
// whileInsertOnTable runs fn inside tx. When tableName has an identity
// column, fn is bracketed by SET IDENTITY_INSERT ON/OFF so that fixtures may
// supply explicit values for that column.
//
// Errors from the identity-column lookup, from enabling identity insert, or
// from fn are returned as-is or wrapped; an error from disabling identity
// insert is returned only when nothing earlier failed.
func (h *sqlserver) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) (err error) {
	hasIdentityColumn, err := h.tableHasIdentityColumn(tx, tableName)
	if err != nil {
		return err
	}
	if !hasIdentityColumn {
		return fn()
	}

	table := h.quoteKeyword(tableName)
	// Assign to the named result rather than shadowing it with ":=".
	if _, err = tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", table)); err != nil {
		return fmt.Errorf("testfixtures: could not enable identity insert: %w", err)
	}
	// Registered only once ON has succeeded, so OFF is issued only when
	// there is something to turn off. It runs even if fn fails, and reports
	// its own failure through the named result only if fn succeeded.
	defer func() {
		if _, offErr := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", table)); offErr != nil && err == nil {
			err = fmt.Errorf("testfixtures: could not disable identity insert: %w", offErr)
		}
	}()
	return fn()
}
// disableReferentialIntegrity turns off every constraint on the cached
// tables (NOCHECK CONSTRAINT ALL), runs loadFn inside a transaction, and
// then turns the constraints back on WITH CHECK, so rows inserted while they
// were off are validated when they are re-enabled.
//
// Re-enabling is deferred and always runs, even if disabling or loading
// failed; its error is returned only when nothing earlier failed.
func (h *sqlserver) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) (err error) {
	// alterAll builds a single batch applying clause to every cached table.
	// It returns "" when there are no tables, so no empty batch is sent.
	alterAll := func(clause string) string {
		var b strings.Builder
		for _, table := range h.tables {
			fmt.Fprintf(&b, "ALTER TABLE %s %s;", h.quoteKeyword(table), clause)
		}
		return b.String()
	}

	// Ensure the constraints are re-enabled (and re-checked) after all.
	defer func() {
		stmt := alterAll("WITH CHECK CHECK CONSTRAINT ALL")
		if stmt == "" {
			return
		}
		if _, err2 := db.Exec(stmt); err2 != nil && err == nil {
			err = err2
		}
	}()

	if stmt := alterAll("NOCHECK CONSTRAINT ALL"); stmt != "" {
		if _, err = db.Exec(stmt); err != nil {
			return err
		}
	}

	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit, Rollback only returns sql.ErrTxDone, so its
	// error is deliberately ignored; its job is to undo a failed load.
	defer func() { _ = tx.Rollback() }()

	if err = loadFn(tx); err != nil {
		return err
	}
	return tx.Commit()
}