Skip to content

Commit

Permalink
introduce runTestsParallel
Browse files Browse the repository at this point in the history
  • Loading branch information
shogo82148 committed Oct 7, 2023
1 parent 3798012 commit 8afa6c5
Showing 1 changed file with 108 additions and 39 deletions.
147 changes: 108 additions & 39 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package mysql
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"database/sql"
"database/sql/driver"
Expand Down Expand Up @@ -149,16 +150,84 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
}

for _, test := range tests {
test := test
t.Run("default", func(t *testing.T) {
dbt := &DBTest{t, db}
t.Cleanup(func() {
dbt.db.Exec("DROP TABLE IF EXISTS test")
})
test(dbt)
dbt.db.Exec("DROP TABLE IF EXISTS test")
})
if db2 != nil {
t.Run("interpolateParams", func(t *testing.T) {
dbt2 := &DBTest{t, db2}
t.Cleanup(func() {
dbt2.db.Exec("DROP TABLE IF EXISTS test")
})
test(dbt2)
dbt2.db.Exec("DROP TABLE IF EXISTS test")
})
}
}
}

func runTestsParallel(t *testing.T, dsn string, tests ...func(dbt *DBTest, tableName string)) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}

t.Parallel()

for _, test := range tests {
test := test
t.Run("default", func(t *testing.T) {
t.Parallel()

var buf [8]byte
if _, err := rand.Read(buf[:]); err != nil {
t.Fatal(err)
}
tableName := fmt.Sprintf("test_%x", buf[:])

db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
t.Cleanup(func() {
db.Exec("DROP TABLE IF EXISTS " + tableName)
db.Close()
})

db.Exec("DROP TABLE IF EXISTS " + tableName)
dbt := &DBTest{t, db}
test(dbt, tableName)
dbt.db.Exec("DROP TABLE IF EXISTS test")
})

dsn2 := dsn + "&interpolateParams=true"
if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
t.Run("interpolateParams", func(t *testing.T) {
t.Parallel()

var buf [8]byte
if _, err := rand.Read(buf[:]); err != nil {
t.Fatal(err)
}
tableName := fmt.Sprintf("test_%x", buf[:])

db, err := sql.Open("mysql", dsn2)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
t.Cleanup(func() {
db.Exec("DROP TABLE IF EXISTS " + tableName)
db.Close()
})

db.Exec("DROP TABLE IF EXISTS " + tableName)
dbt := &DBTest{t, db}
test(dbt, tableName)
dbt.db.Exec("DROP TABLE IF EXISTS test")
})
}
}
Expand Down Expand Up @@ -199,7 +268,7 @@ func maybeSkip(t *testing.T, err error, skipErrno uint16) {
}

func TestEmptyQuery(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
runTestsParallel(t, dsn, func(dbt *DBTest, _ string) {
// just a comment, no query
rows := dbt.mustQuery("--")
defer rows.Close()
Expand All @@ -211,20 +280,20 @@ func TestEmptyQuery(t *testing.T) {
}

func TestCRUD(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) {
// Create Table
dbt.mustExec("CREATE TABLE test (value BOOL)")
dbt.mustExec("CREATE TABLE " + tbl + " (value BOOL)")

// Test for unexpected data
var out bool
rows := dbt.mustQuery("SELECT * FROM test")
rows := dbt.mustQuery("SELECT * FROM " + tbl)
if rows.Next() {
dbt.Error("unexpected data in empty table")
}
rows.Close()

// Create Data
res := dbt.mustExec("INSERT INTO test VALUES (1)")
res := dbt.mustExec("INSERT INTO " + tbl + " VALUES (1)")
count, err := res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
Expand All @@ -242,7 +311,7 @@ func TestCRUD(t *testing.T) {
}

// Read
rows = dbt.mustQuery("SELECT value FROM test")
rows = dbt.mustQuery("SELECT value FROM " + tbl)
if rows.Next() {
rows.Scan(&out)
if true != out {
Expand All @@ -258,7 +327,7 @@ func TestCRUD(t *testing.T) {
rows.Close()

// Update
res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true)
res = dbt.mustExec("UPDATE "+tbl+" SET value = ? WHERE value = ?", false, true)
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
Expand All @@ -268,7 +337,7 @@ func TestCRUD(t *testing.T) {
}

// Check Update
rows = dbt.mustQuery("SELECT value FROM test")
rows = dbt.mustQuery("SELECT value FROM " + tbl)
if rows.Next() {
rows.Scan(&out)
if false != out {
Expand All @@ -284,7 +353,7 @@ func TestCRUD(t *testing.T) {
rows.Close()

// Delete
res = dbt.mustExec("DELETE FROM test WHERE value = ?", false)
res = dbt.mustExec("DELETE FROM "+tbl+" WHERE value = ?", false)
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
Expand All @@ -294,7 +363,7 @@ func TestCRUD(t *testing.T) {
}

// Check for unexpected rows
res = dbt.mustExec("DELETE FROM test")
res = dbt.mustExec("DELETE FROM " + tbl)
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
Expand All @@ -308,13 +377,13 @@ func TestCRUD(t *testing.T) {
// TestNumbers test that selecting numeric columns.
// Both of textRows and binaryRows should return same type and value.
func TestNumbersToAny(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE `test` (id INT PRIMARY KEY, b BOOL, i8 TINYINT, " +
runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) {
dbt.mustExec("CREATE TABLE " + tbl + " (id INT PRIMARY KEY, b BOOL, i8 TINYINT, " +
"i16 SMALLINT, i32 INT, i64 BIGINT, f32 FLOAT, f64 DOUBLE)")
dbt.mustExec("INSERT INTO `test` VALUES (1, true, 127, 32767, 2147483647, 9223372036854775807, 1.25, 2.5)")
dbt.mustExec("INSERT INTO " + tbl + " VALUES (1, true, 127, 32767, 2147483647, 9223372036854775807, 1.25, 2.5)")

// Use binaryRows for intarpolateParams=false and textRows for intarpolateParams=true.
rows := dbt.mustQuery("SELECT b, i8, i16, i32, i64, f32, f64 FROM `test` WHERE id=?", 1)
rows := dbt.mustQuery("SELECT b, i8, i16, i32, i64, f32, f64 FROM "+tbl+" WHERE id=?", 1)
if !rows.Next() {
dbt.Fatal("no data")
}
Expand Down Expand Up @@ -393,19 +462,19 @@ func TestMultiQuery(t *testing.T) {
}

func TestInt(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) {
types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}
in := int64(42)
var out int64
var rows *sql.Rows

// SIGNED
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + ")")
dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + ")")

dbt.mustExec("INSERT INTO test VALUES (?)", in)
dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in)

rows = dbt.mustQuery("SELECT value FROM test")
rows = dbt.mustQuery("SELECT value FROM " + tbl)
if rows.Next() {
rows.Scan(&out)
if in != out {
Expand All @@ -416,16 +485,16 @@ func TestInt(t *testing.T) {
}
rows.Close()

dbt.mustExec("DROP TABLE IF EXISTS test")
dbt.mustExec("DROP TABLE IF EXISTS " + tbl)
}

// UNSIGNED ZEROFILL
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)")
dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + " ZEROFILL)")

dbt.mustExec("INSERT INTO test VALUES (?)", in)
dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in)

rows = dbt.mustQuery("SELECT value FROM test")
rows = dbt.mustQuery("SELECT value FROM " + tbl)
if rows.Next() {
rows.Scan(&out)
if in != out {
Expand All @@ -436,21 +505,21 @@ func TestInt(t *testing.T) {
}
rows.Close()

dbt.mustExec("DROP TABLE IF EXISTS test")
dbt.mustExec("DROP TABLE IF EXISTS " + tbl)
}
})
}

func TestFloat32(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) {
types := [2]string{"FLOAT", "DOUBLE"}
in := float32(42.23)
var out float32
var rows *sql.Rows
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + ")")
dbt.mustExec("INSERT INTO test VALUES (?)", in)
rows = dbt.mustQuery("SELECT value FROM test")
dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + ")")
dbt.mustExec("INSERT INTO "+tbl+" VALUES (?)", in)
rows = dbt.mustQuery("SELECT value FROM " + tbl)
if rows.Next() {
rows.Scan(&out)
if in != out {
Expand All @@ -460,21 +529,21 @@ func TestFloat32(t *testing.T) {
dbt.Errorf("%s: no data", v)
}
rows.Close()
dbt.mustExec("DROP TABLE IF EXISTS test")
dbt.mustExec("DROP TABLE IF EXISTS " + tbl)
}
})
}

func TestFloat64(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) {
types := [2]string{"FLOAT", "DOUBLE"}
var expected float64 = 42.23
var out float64
var rows *sql.Rows
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + ")")
dbt.mustExec("INSERT INTO test VALUES (42.23)")
rows = dbt.mustQuery("SELECT value FROM test")
dbt.mustExec("CREATE TABLE " + tbl + " (value " + v + ")")
dbt.mustExec("INSERT INTO " + tbl + " VALUES (42.23)")
rows = dbt.mustQuery("SELECT value FROM " + tbl)
if rows.Next() {
rows.Scan(&out)
if expected != out {
Expand All @@ -484,21 +553,21 @@ func TestFloat64(t *testing.T) {
dbt.Errorf("%s: no data", v)
}
rows.Close()
dbt.mustExec("DROP TABLE IF EXISTS test")
dbt.mustExec("DROP TABLE IF EXISTS " + tbl)
}
})
}

func TestFloat64Placeholder(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
runTestsParallel(t, dsn, func(dbt *DBTest, tbl string) {
types := [2]string{"FLOAT", "DOUBLE"}
var expected float64 = 42.23
var out float64
var rows *sql.Rows
for _, v := range types {
dbt.mustExec("CREATE TABLE test (id int, value " + v + ")")
dbt.mustExec("INSERT INTO test VALUES (1, 42.23)")
rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1)
dbt.mustExec("CREATE TABLE " + tbl + " (id int, value " + v + ")")
dbt.mustExec("INSERT INTO " + tbl + " VALUES (1, 42.23)")
rows = dbt.mustQuery("SELECT value FROM "+tbl+" WHERE id = ?", 1)
if rows.Next() {
rows.Scan(&out)
if expected != out {
Expand All @@ -508,7 +577,7 @@ func TestFloat64Placeholder(t *testing.T) {
dbt.Errorf("%s: no data", v)
}
rows.Close()
dbt.mustExec("DROP TABLE IF EXISTS test")
dbt.mustExec("DROP TABLE IF EXISTS " + tbl)
}
})
}
Expand Down

0 comments on commit 8afa6c5

Please sign in to comment.