forked from arnehormann/sqlinternals
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsqlinternals_test.go
147 lines (127 loc) · 3.94 KB
/
sqlinternals_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// sqlinternals - retrieve driver.Rows from sql.*Row / sql.*Rows
//
// Copyright 2013 Arne Hormann. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package sqlinternals
import (
"database/sql"
"database/sql/driver"
"io"
"testing"
)
// omnithing is a fake database/sql driver used by the tests in this file.
// A single value implements driver.Driver, driver.Conn, driver.Tx,
// driver.Stmt, driver.Result and driver.Rows, and every constructor method
// returns the receiver itself. Because of that identity, a test can check
// that the driver.Rows recovered from an *sql.Row / *sql.Rows is exactly the
// object handed out here.
type omnithing struct {
	numInputs int             // value reported by NumInput
	columns   []string        // column names reported by Columns
	rows      [][]interface{} // rows not yet returned; Next consumes them front to back
}

// Close satisfies the Close method shared by driver.Conn, driver.Stmt and
// driver.Rows; it never fails.
func (o *omnithing) Close() error { return nil }

// driver.Driver

// Open ignores name and returns the receiver itself as the connection.
func (o *omnithing) Open(name string) (driver.Conn, error) { return o, nil }

// driver.Conn

// Prepare ignores query and returns the receiver itself as the statement.
func (o *omnithing) Prepare(query string) (driver.Stmt, error) { return o, nil }

// Begin returns the receiver itself as the transaction.
func (o *omnithing) Begin() (driver.Tx, error) { return o, nil }

// driver.Tx

// Commit is a no-op.
func (o *omnithing) Commit() error { return nil }

// Rollback is a no-op.
func (o *omnithing) Rollback() error { return nil }

// driver.Stmt

// NumInput reports the placeholder count configured via setDB.
func (o *omnithing) NumInput() int { return o.numInputs }

// Exec ignores args and returns the receiver itself as the result.
func (o *omnithing) Exec(args []driver.Value) (driver.Result, error) { return o, nil }

// Query ignores args and returns the receiver itself as the row set.
func (o *omnithing) Query(args []driver.Value) (driver.Rows, error) { return o, nil }

// driver.Result

// LastInsertId always reports 0.
func (o *omnithing) LastInsertId() (int64, error) { return 0, nil }

// RowsAffected always reports 0.
func (o *omnithing) RowsAffected() (int64, error) { return 0, nil }

// driver.Rows

// Columns reports the column names configured via setDB.
func (o *omnithing) Columns() []string { return o.columns }

// Next copies the first remaining row into dest and removes it from the
// queue. It returns io.EOF once every row has been consumed. dest must have
// at least as many slots as the row has values.
func (o *omnithing) Next(dest []driver.Value) error {
	if len(o.rows) == 0 {
		return io.EOF
	}
	row := o.rows[0]
	o.rows = o.rows[1:]
	for i, v := range row {
		dest[i] = v
	}
	return nil
}

// setDB resets the fake database and returns the receiver for chaining.
// cells are laid out row-major: the first len(columns) values form row 0,
// the next len(columns) form row 1, and so on. It panics with
// "wrong number of cells" when cells cannot be split into whole rows. That
// includes the case of cells supplied without any columns. Rows are copied,
// so later changes to the caller's cells slice do not leak into the fake.
func (o *omnithing) setDB(numInputs int, columns []string, cells ...interface{}) *omnithing {
	numCols, numCells := len(columns), len(cells)
	// Validate before touching state; with zero columns only zero cells fit,
	// and the guard also keeps the modulo below from dividing by zero.
	if numCols == 0 {
		if numCells != 0 {
			panic("wrong number of cells")
		}
	} else if numCells%numCols != 0 {
		panic("wrong number of cells")
	}

	o.numInputs = numInputs
	o.columns = columns

	var rows [][]interface{}
	if numCols > 0 {
		rows = make([][]interface{}, 0, numCells/numCols)
		for start := 0; start < numCells; start += numCols {
			row := make([]interface{}, numCols)
			copy(row, cells[start:start+numCols])
			rows = append(rows, row)
		}
	}
	o.rows = rows
	return o
}
// querier issues a query through db and hands back whatever database/sql
// returned for it, boxed as interface{}, plus any error. The tests in this
// file return either an *sql.Row (from QueryRow) or an *sql.Rows (from Query).
type querier func(db *sql.DB) (interface{}, error)
// testdriver is the one fake driver instance shared by every test here. It
// is registered with database/sql during package init, and each test
// reconfigures it through setDB before querying.
var testdriver = &omnithing{}

// Compile-time proof that *omnithing satisfies each driver interface it
// stands in for.
var (
	_ driver.Driver = (*omnithing)(nil)
	_ driver.Conn   = (*omnithing)(nil)
	_ driver.Tx     = (*omnithing)(nil)
	_ driver.Stmt   = (*omnithing)(nil)
	_ driver.Result = (*omnithing)(nil)
	_ driver.Rows   = (*omnithing)(nil)
)
// driverType is the name the fake driver is registered under and the name
// passed to sql.Open to reach it.
const (
	driverType = "test"
)
func init() {
sql.Register(driverType, testdriver)
}
// runRowsTest loads the fake driver with the given columns and row-major
// cells, runs query against a fresh *sql.DB, and checks that Inspect
// recovers exactly testdriver as the underlying driver.Rows. Any setup
// failure (open, query, inspect) aborts the test immediately.
func runRowsTest(t *testing.T, query querier, numInputs int, columns []string, cells ...interface{}) {
	// set initial state before usage
	testdriver.setDB(numInputs, columns, cells...)

	conn, err := sql.Open(driverType, "")
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	// run a query, retrieve *sql.Row or *sql.Rows
	rowOrRows, err := query(conn)
	if err != nil {
		// On error rowOrRows may be a typed-nil *sql.Rows; closing it would
		// panic, so bail out before registering the deferred Close.
		t.Fatal(err)
	}
	// Close the result when it supports it (e.g. *sql.Rows). This is a test
	// cleanup, so the Close error is deliberately ignored.
	if closer, ok := rowOrRows.(io.Closer); ok {
		defer closer.Close()
	}

	// check that it is accessible and matches the one in testdriver
	unwrapped, err := Inspect(rowOrRows)
	if err != nil {
		t.Fatal(err)
	}
	myrows, ok := unwrapped.(*omnithing)
	if !ok || myrows != testdriver {
		t.Error("returned driver.Rows must match those passed in.")
	}
}
func TestRowWithoutArgs(t *testing.T) {
query := func(conn *sql.DB) (interface{}, error) {
return conn.QueryRow(`SELECT "test"`), nil
}
runRowsTest(t, query, 0, []string{"header"}, "test")
}
// TestRowWithArgs checks that the *sql.Row from a QueryRow with one bound
// argument can be unwrapped back to the fake driver's rows.
func TestRowWithArgs(t *testing.T) {
	header := []string{"header"}
	runRowsTest(t, func(db *sql.DB) (interface{}, error) {
		return db.QueryRow(`SELECT ?`, "test"), nil
	}, 1, header, "test")
}
// TestRowsWithoutArgs checks that the *sql.Rows from a placeholder-free
// Query can be unwrapped back to the fake driver's rows.
func TestRowsWithoutArgs(t *testing.T) {
	selectLiteral := func(db *sql.DB) (interface{}, error) {
		rows, err := db.Query(`SELECT "test"`)
		return rows, err
	}
	runRowsTest(t, selectLiteral, 0, []string{"header"}, "test")
}
// TestRowsWithArgs checks that the *sql.Rows from a Query with one bound
// argument can be unwrapped back to the fake driver's rows.
func TestRowsWithArgs(t *testing.T) {
	const arg = "test"
	runRowsTest(t, querier(func(db *sql.DB) (interface{}, error) {
		return db.Query(`SELECT ?`, arg)
	}), 1, []string{"header"}, arg)
}