Skip to content

Commit

Permalink
Merge pull request #91 from huandu/feature-driver-valuer
Browse files Browse the repository at this point in the history
fix #90 Support driver.Valuer type in Struct and interpolate methods
  • Loading branch information
huandu authored Nov 24, 2022
2 parents 167e2ad + 7507922 commit f299327
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 81 deletions.
173 changes: 103 additions & 70 deletions interpolate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package sqlbuilder

import (
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"time"
"unicode"
Expand Down Expand Up @@ -389,78 +391,13 @@ func encodeValue(buf []byte, arg interface{}, flavor Flavor) ([]byte, error) {
case nil:
buf = append(buf, "NULL"...)

case bool:
if v {
buf = append(buf, "TRUE"...)
case driver.Valuer:
if val, err := v.Value(); err != nil {
return nil, err
} else {
buf = append(buf, "FALSE"...)
return encodeValue(buf, val, flavor)
}

case int:
buf = strconv.AppendInt(buf, int64(v), 10)

case int8:
buf = strconv.AppendInt(buf, int64(v), 10)

case int16:
buf = strconv.AppendInt(buf, int64(v), 10)

case int32:
buf = strconv.AppendInt(buf, int64(v), 10)

case int64:
buf = strconv.AppendInt(buf, v, 10)

case uint:
buf = strconv.AppendUint(buf, uint64(v), 10)

case uint8:
buf = strconv.AppendUint(buf, uint64(v), 10)

case uint16:
buf = strconv.AppendUint(buf, uint64(v), 10)

case uint32:
buf = strconv.AppendUint(buf, uint64(v), 10)

case uint64:
buf = strconv.AppendUint(buf, v, 10)

case float32:
buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32)

case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)

case []byte:
if v == nil {
buf = append(buf, "NULL"...)
break
}

switch flavor {
case MySQL:
buf = append(buf, "_binary"...)
buf = quoteStringValue(buf, *(*string)(unsafe.Pointer(&v)), flavor)

case PostgreSQL:
buf = append(buf, "E'\\\\x"...)
buf = appendHex(buf, v)
buf = append(buf, "'::bytea"...)

case SQLite:
buf = append(buf, "X'"...)
buf = appendHex(buf, v)
buf = append(buf, '\'')

case SQLServer:
buf = append(buf, "0x"...)
buf = appendHex(buf, v)
}

case string:
buf = quoteStringValue(buf, v, flavor)

case time.Time:
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
Expand Down Expand Up @@ -492,7 +429,103 @@ func encodeValue(buf []byte, arg interface{}, flavor Flavor) ([]byte, error) {
buf = quoteStringValue(buf, v.String(), flavor)

default:
return nil, ErrInterpolateUnsupportedArgs
primative := reflect.ValueOf(arg)

switch k := primative.Kind(); k {
case reflect.Bool:
if primative.Bool() {
buf = append(buf, "TRUE"...)
} else {
buf = append(buf, "FALSE"...)
}

case reflect.Int:
buf = strconv.AppendInt(buf, primative.Int(), 10)

case reflect.Int8:
buf = strconv.AppendInt(buf, primative.Int(), 10)

case reflect.Int16:
buf = strconv.AppendInt(buf, primative.Int(), 10)

case reflect.Int32:
buf = strconv.AppendInt(buf, primative.Int(), 10)

case reflect.Int64:
buf = strconv.AppendInt(buf, primative.Int(), 10)

case reflect.Uint:
buf = strconv.AppendUint(buf, primative.Uint(), 10)

case reflect.Uint8:
buf = strconv.AppendUint(buf, primative.Uint(), 10)

case reflect.Uint16:
buf = strconv.AppendUint(buf, primative.Uint(), 10)

case reflect.Uint32:
buf = strconv.AppendUint(buf, primative.Uint(), 10)

case reflect.Uint64:
buf = strconv.AppendUint(buf, primative.Uint(), 10)

case reflect.Float32:
buf = strconv.AppendFloat(buf, primative.Float(), 'g', -1, 32)

case reflect.Float64:
buf = strconv.AppendFloat(buf, primative.Float(), 'g', -1, 64)

case reflect.String:
buf = quoteStringValue(buf, primative.String(), flavor)

case reflect.Slice, reflect.Array:
if k == reflect.Slice && primative.IsNil() {
buf = append(buf, "NULL"...)
break
}

if elem := primative.Type().Elem(); elem.Kind() != reflect.Uint8 {
return nil, ErrInterpolateUnsupportedArgs
}

var data []byte

// Bytes() will panic if primative is an array and cannot be addressed.
// Copy all bytes to data as a fallback.
if k == reflect.Array && !primative.CanAddr() {
l := primative.Len()
data = make([]byte, l)

for i := 0; i < l; i++ {
data[i] = byte(primative.Index(i).Uint())
}
} else {
data = primative.Bytes()
}

switch flavor {
case MySQL:
buf = append(buf, "_binary"...)
buf = quoteStringValue(buf, *(*string)(unsafe.Pointer(&data)), flavor)

case PostgreSQL:
buf = append(buf, "E'\\\\x"...)
buf = appendHex(buf, data)
buf = append(buf, "'::bytea"...)

case SQLite:
buf = append(buf, "X'"...)
buf = appendHex(buf, data)
buf = append(buf, '\'')

case SQLServer:
buf = append(buf, "0x"...)
buf = appendHex(buf, data)
}

default:
return nil, ErrInterpolateUnsupportedArgs
}
}

return buf, nil
Expand Down
42 changes: 34 additions & 8 deletions interpolate_test.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
package sqlbuilder

import (
"database/sql/driver"
"errors"
"strconv"
"testing"
"time"

"github.com/huandu/go-assert"
)

type errorValuer int

var ErrErrorValuer = errors.New("error valuer")

func (v errorValuer) Value() (driver.Value, error) {
return 0, ErrErrorValuer
}

func TestFlavorInterpolate(t *testing.T) {
a := assert.New(t)
dt := time.Date(2019, 4, 24, 12, 23, 34, 123456789, time.FixedZone("CST", 8*60*60)) // 2019-04-24 12:23:34.987654321 CST
_, errOutOfRange := strconv.ParseInt("12345678901234567890", 10, 32)
byteArr := [...]byte{'f', 'o', 'o'}
cases := []struct {
flavor Flavor
sql string
args []interface{}
query string
err error
Flavor Flavor
SQL string
Args []interface{}
Query string
Err error
}{
{
MySQL,
Expand All @@ -39,6 +50,11 @@ func TestFlavorInterpolate(t *testing.T) {
"SELECT '\\'?', \"\\\"?\", `\\`?`, \\?", []interface{}{MySQL},
"SELECT '\\'?', \"\\\"?\", `\\`?`, \\'MySQL'", nil,
},
{
MySQL,
"SELECT ?", []interface{}{byteArr},
"SELECT _binary'foo'", nil,
},
{
MySQL,
"SELECT ?", nil,
Expand All @@ -49,6 +65,16 @@ func TestFlavorInterpolate(t *testing.T) {
"SELECT ?", []interface{}{complex(1, 2)},
"", ErrInterpolateUnsupportedArgs,
},
{
MySQL,
"SELECT ?", []interface{}{[]complex128{complex(1, 2)}},
"", ErrInterpolateUnsupportedArgs,
},
{
MySQL,
"SELECT ?", []interface{}{errorValuer(1)},
"", ErrErrorValuer,
},

{
PostgreSQL,
Expand Down Expand Up @@ -141,9 +167,9 @@ func TestFlavorInterpolate(t *testing.T) {

for idx, c := range cases {
a.Use(&idx, &c)
query, err := c.flavor.Interpolate(c.sql, c.args)
query, err := c.Flavor.Interpolate(c.SQL, c.Args)

a.Equal(query, c.query)
a.Assert(err == c.err || err.Error() == c.err.Error())
a.Equal(query, c.Query)
a.Assert(err == c.Err || err.Error() == c.Err.Error())
}
}
21 changes: 18 additions & 3 deletions struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package sqlbuilder

import (
"bytes"
"database/sql/driver"
"math"
"reflect"
"regexp"
Expand Down Expand Up @@ -36,6 +37,8 @@ const (

var optRegex = regexp.MustCompile(`(?P<` + optName + `>\w+)(\((?P<` + optParams + `>.*)\))?`)

var typeOfSQLDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

// Struct represents a struct type.
//
// All methods in Struct are thread-safe.
Expand Down Expand Up @@ -179,7 +182,7 @@ func (s *Struct) UpdateForTag(table string, tag string, value interface{}) *Upda
continue
}
} else {
val = dereferencedValue(val)
val = dereferencedFieldValue(val)
}

data := val.Interface()
Expand Down Expand Up @@ -237,7 +240,7 @@ func (s *Struct) buildColsAndValuesForTag(ib *InsertBuilder, tag string, value .

for _, item := range value {
v := reflect.ValueOf(item)
v = dereferencedValue(v)
v = dereferencedFieldValue(v)

if v.Type() == s.structType {
vs = append(vs, v)
Expand Down Expand Up @@ -265,7 +268,7 @@ func (s *Struct) buildColsAndValuesForTag(ib *InsertBuilder, tag string, value .
nilCnt++
}

val = dereferencedValue(val)
val = dereferencedFieldValue(val)

if val.IsValid() {
values[i] = append(values[i], val.Interface())
Expand Down Expand Up @@ -485,6 +488,18 @@ func dereferencedValue(v reflect.Value) reflect.Value {
return v
}

func dereferencedFieldValue(v reflect.Value) reflect.Value {
for k := v.Kind(); k == reflect.Ptr || k == reflect.Interface; k = v.Kind() {
if v.Type().Implements(typeOfSQLDriverValuer) {
break
}

v = v.Elem()
}

return v
}

// isEmptyValue checks if v is zero.
// Following code is borrowed from `IsZero` method in `reflect.Value` since Go 1.13.
func isEmptyValue(v reflect.Value) bool {
Expand Down
32 changes: 32 additions & 0 deletions struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package sqlbuilder

import (
"database/sql/driver"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -811,6 +812,37 @@ func TestStructFieldAs(t *testing.T) {
a.Equal(sql, `UPDATE t SET t1 = ?, t2 = ?, t4 = ?`)
}

type structImplValuer int

func (v *structImplValuer) Value() (driver.Value, error) {
return *v * 2, nil
}

type structContainsValuer struct {
F1 string
F2 *structImplValuer
}

func TestStructFieldsImplValuer(t *testing.T) {
a := assert.New(t)
st := NewStruct(new(structContainsValuer))
f1 := "foo"
f2 := structImplValuer(100)

sql, args := st.Update("t", structContainsValuer{
F1: f1,
F2: &f2,
}).BuildWithFlavor(MySQL)

a.Equal(sql, "UPDATE t SET F1 = ?, F2 = ?")
a.Equal(args[0], f1)
a.Equal(args[1], &f2)

result, err := MySQL.Interpolate(sql, args)
a.NilError(err)
a.Equal(result, "UPDATE t SET F1 = 'foo', F2 = 200")
}

func SomeOtherMapper(string) string {
return ""
}
Expand Down

0 comments on commit f299327

Please sign in to comment.