diff --git a/README.md b/README.md index 6f692bc..eb66750 100644 --- a/README.md +++ b/README.md @@ -262,11 +262,11 @@ types: passed to Trino as a time with a time zone * the result of `trino.Timestamp(year, month, day, hour, minute, second, nanosecond)` - passed to Trino as a timestamp without a time zone +* `time.Duration` - passed to Trino as an interval day to second. Because Trino does not support nanosecond precision for intervals, if the nanosecond part of the value is not zero, an error will be returned. It's not yet possible to pass: * `float32` or `float64` * `byte` -* `time.Duration` * `json.RawMessage` * maps diff --git a/trino/integration_test.go b/trino/integration_test.go index ff80be7..828ce2a 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -30,6 +30,7 @@ import ( "fmt" "io" "log" + "math" "math/big" "net/http" "os" @@ -987,3 +988,127 @@ func contextSleep(ctx context.Context, d time.Duration) error { return ctx.Err() } } + +func TestIntegrationDayToHourIntervalMilliPrecision(t *testing.T) { + db := integrationOpen(t) + defer db.Close() + tests := []struct { + name string + arg time.Duration + wantErr bool + }{ + { + name: "valid 1234567891s", + arg: time.Duration(1234567891) * time.Second, + wantErr: false, + }, + { + name: "valid 123456789.1s", + arg: time.Duration(123456789100) * time.Millisecond, + wantErr: false, + }, + { + name: "valid 12345678.91s", + arg: time.Duration(12345678910) * time.Millisecond, + wantErr: false, + }, + { + name: "valid 1234567.891s", + arg: time.Duration(1234567891) * time.Millisecond, + wantErr: false, + }, + { + name: "valid -1234567891s", + arg: time.Duration(-1234567891) * time.Second, + wantErr: false, + }, + { + name: "valid -123456789.1s", + arg: time.Duration(-123456789100) * time.Millisecond, + wantErr: false, + }, + { + name: "valid -12345678.91s", + arg: time.Duration(-12345678910) * time.Millisecond, + wantErr: false, + }, + { + name: "valid -1234567.891s", + arg: time.Duration(-1234567891) * time.Millisecond, + wantErr: false, + }, + { + name: "invalid 1234567891.2s", + arg: time.Duration(1234567891200) * time.Millisecond, + wantErr: true, + }, + { + name: "invalid 123456789.12s", + arg: time.Duration(123456789120) * time.Millisecond, + wantErr: true, + }, + { + name: "invalid 12345678.912s", + arg: time.Duration(12345678912) * time.Millisecond, + wantErr: true, + }, + { + name: "invalid -1234567891.2s", + arg: time.Duration(-1234567891200) * time.Millisecond, + wantErr: true, + }, + { + name: "invalid -123456789.12s", + arg: time.Duration(-123456789120) * time.Millisecond, + wantErr: true, + }, + { + name: "invalid -12345678.912s", + arg: time.Duration(-12345678912) * time.Millisecond, + wantErr: true, + }, + { + name: "invalid max seconds (9223372036)", + arg: time.Duration(math.MaxInt64) / time.Second * time.Second, + wantErr: true, + }, + { + name: "invalid min seconds (-9223372036)", + arg: time.Duration(math.MinInt64) / time.Second * time.Second, + wantErr: true, + }, + { + name: "valid max seconds (2147483647)", + arg: math.MaxInt32 * time.Second, + }, + { + name: "valid min seconds (-2147483647)", + arg: -math.MaxInt32 * time.Second, + }, + { + name: "valid max minutes (153722867)", + arg: time.Duration(math.MaxInt64) / time.Minute * time.Minute, + }, + { + name: "valid min minutes (-153722867)", + arg: time.Duration(math.MinInt64) / time.Minute * time.Minute, + }, + { + name: "valid max hours (2562047)", + arg: time.Duration(math.MaxInt64) / time.Hour * time.Hour, + }, + { + name: "valid min hours (-2562047)", + arg: time.Duration(math.MinInt64) / time.Hour * time.Hour, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := db.Exec("SELECT ?", test.arg) + if (err != nil) != test.wantErr { + t.Errorf("Exec() error = %v, wantErr %v", err, test.wantErr) + return + } + }) + } +} diff --git a/trino/serial.go b/trino/serial.go index 88e9a9d..5a778a5 100644 --- a/trino/serial.go +++ b/trino/serial.go @@ -17,6 +17,7 @@ package trino import ( "encoding/json" "fmt" + "math" "reflect" "strconv" "strings" @@ -163,7 +164,7 @@ func Serial(v interface{}) (string, error) { return "TIMESTAMP " + time.Time(x).Format("'2006-01-02 15:04:05.999999999 Z07:00'"), nil case time.Duration: - return "", UnsupportedArgError{"time.Duration"} + return serialDuration(x) // TODO - json.RawMesssage should probably be matched to 'JSON' in Trino case json.RawMessage: @@ -208,3 +209,51 @@ func serialSlice(v []interface{}) (string, error) { return "ARRAY[" + strings.Join(ss, ", ") + "]", nil } + +const ( + // For seconds with milliseconds there is a maximum length of 10 digits + // or 11 characters with the dot and 12 characters with the minus sign and dot + maxIntervalStrLenWithDot = 11 // 123456789.1 and 12345678.91 are valid +) + +func serialDuration(dur time.Duration) (string, error) { + switch { + case dur%time.Hour == 0: + return serialHoursInterval(dur), nil + case dur%time.Minute == 0: + return serialMinutesInterval(dur), nil + case dur%time.Second == 0: + return serialSecondsInterval(dur) + case dur%time.Millisecond == 0: + return serialMillisecondsInterval(dur) + default: + return "", fmt.Errorf("trino: duration %v is not a multiple of hours, minutes, seconds or milliseconds", dur) + } +} + +func serialHoursInterval(dur time.Duration) string { + return "INTERVAL '" + strconv.Itoa(int(dur/time.Hour)) + "' HOUR" +} + +func serialMinutesInterval(dur time.Duration) string { + return "INTERVAL '" + strconv.Itoa(int(dur/time.Minute)) + "' MINUTE" +} + +func serialSecondsInterval(dur time.Duration) (string, error) { + seconds := int64(dur / time.Second) + if seconds <= math.MinInt32 || seconds > math.MaxInt32 { + return "", fmt.Errorf("trino: duration %v is out of range for interval of seconds type", dur) + } + return "INTERVAL '" + strconv.FormatInt(seconds, 10) + "' SECOND", nil +} + +func serialMillisecondsInterval(dur time.Duration) (string, error) { + seconds := int64(dur / time.Second) + millisInSecond := dur.Abs().Milliseconds() % 1000 + intervalNr := strings.TrimRight(fmt.Sprintf("%d.%03d", seconds, millisInSecond), "0") + if seconds > 0 && len(intervalNr) > maxIntervalStrLenWithDot || + seconds < 0 && len(intervalNr) > maxIntervalStrLenWithDot+1 { // +1 for the minus sign + return "", fmt.Errorf("trino: duration %v is out of range for interval of seconds with millis type", dur) + } + return "INTERVAL '" + intervalNr + "' SECOND", nil +} diff --git a/trino/serial_test.go b/trino/serial_test.go index fc0aa2f..aa91145 100644 --- a/trino/serial_test.go +++ b/trino/serial_test.go @@ -15,6 +15,7 @@ package trino import ( + "math" "testing" "time" @@ -160,6 +161,86 @@ func TestSerial(t *testing.T) { value: time.Date(2017, 7, 10, 11, 34, 25, 123456, time.UTC), expectedSerial: "TIMESTAMP '2017-07-10 11:34:25.000123456 Z'", }, + { + name: "duration", + value: 10*time.Second + 5*time.Millisecond, + expectedSerial: "INTERVAL '10.005' SECOND", + }, + { + name: "duration with negative value", + value: -(10*time.Second + 5*time.Millisecond), + expectedSerial: "INTERVAL '-10.005' SECOND", + }, + { + name: "minute duration", + value: 10 * time.Minute, + expectedSerial: "INTERVAL '10' MINUTE", + }, + { + name: "hour duration", + value: 23 * time.Hour, + expectedSerial: "INTERVAL '23' HOUR", + }, + { + name: "max hour duration", + value: (math.MaxInt64 / time.Hour) * time.Hour, + expectedSerial: "INTERVAL '2562047' HOUR", + }, + { + name: "min hour duration", + value: (math.MinInt64 / time.Hour) * time.Hour, + expectedSerial: "INTERVAL '-2562047' HOUR", + }, + { + name: "max minute duration", + value: (math.MaxInt64 / time.Minute) * time.Minute, + expectedSerial: "INTERVAL '153722867' MINUTE", + }, + { + name: "min minute duration", + value: (math.MinInt64 / time.Minute) * time.Minute, + expectedSerial: "INTERVAL '-153722867' MINUTE", + }, + { + name: "too big second duration", + value: (math.MaxInt64 / time.Second) * time.Second, + expectedError: true, + }, + { + name: "too small second duration", + value: (math.MinInt64 / time.Second) * time.Second, + expectedError: true, + }, + { + name: "too big millisecond duration", + value: time.Millisecond*912 + time.Second*12345678, + expectedError: true, + }, + { + name: "too small millisecond duration", + value: -(time.Millisecond*910 + time.Second*123456789), + expectedError: true, + }, + { + name: "max allowed second duration", + value: math.MaxInt32 * time.Second, + expectedSerial: "INTERVAL '2147483647' SECOND", + }, + { + name: "min allowed second duration", + value: -math.MaxInt32 * time.Second, + expectedSerial: "INTERVAL '-2147483647' SECOND", + }, + { + name: "max allowed second with milliseconds duration", + value: 999999999*time.Second + 900*time.Millisecond, + expectedSerial: "INTERVAL '999999999.9' SECOND", + }, + { + name: "min allowed second with milliseconds duration", + value: -999999999*time.Second - 900*time.Millisecond, + expectedSerial: "INTERVAL '-999999999.9' SECOND", + }, { name: "nil", value: nil, diff --git a/trino/trino.go b/trino/trino.go index d8361c7..a7b9d0c 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -682,7 +682,7 @@ func (st *driverStmt) CheckNamedValue(arg *driver.NamedValue) error { switch arg.Value.(type) { case nil: return nil - case Numeric, trinoDate, trinoTime, trinoTimeTz, trinoTimestamp: + case Numeric, trinoDate, trinoTime, trinoTimeTz, trinoTimestamp, time.Duration: return nil default: {