Skip to content

Commit

Permalink
Return scan types in rows.ColumnTypes()
Browse files Browse the repository at this point in the history
  • Loading branch information
nineinchnick authored and losipiuk committed Jul 21, 2022
1 parent cc992c5 commit 7410412
Show file tree
Hide file tree
Showing 3 changed files with 743 additions and 42 deletions.
13 changes: 10 additions & 3 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,12 @@ func TestIntegrationSessionProperties(t *testing.T) {
}

func TestIntegrationTypeConversion(t *testing.T) {
err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}})
if err != nil {
t.Fatal(err)
}
dsn := *integrationServerFlag
dsn += "?session_properties=parse_decimal_literals_as_double=true"
dsn += "?session_properties=parse_decimal_literals_as_double=true&custom_client=uncompressed"
db := integrationOpen(t, dsn)
var (
goTime time.Time
Expand All @@ -351,8 +355,9 @@ func TestIntegrationTypeConversion(t *testing.T) {
nullFloat64Slice3 NullSlice3Float64
goMap map[string]interface{}
nullMap NullMap
goRow []interface{}
)
err := db.QueryRow(`
err = db.QueryRow(`
SELECT
TIMESTAMP '2017-07-10 01:02:03.004 UTC',
CAST(NULL AS TIMESTAMP),
Expand All @@ -368,7 +373,8 @@ func TestIntegrationTypeConversion(t *testing.T) {
ARRAY[ARRAY[1.1, 1.1, 1.1], NULL],
ARRAY[ARRAY[ARRAY[1.1, 1.1, 1.1], NULL], NULL],
MAP(ARRAY['a', 'b'], ARRAY['c', 'd']),
CAST(NULL AS MAP(ARRAY(INTEGER), ARRAY(INTEGER)))
CAST(NULL AS MAP(ARRAY(INTEGER), ARRAY(INTEGER))),
ROW(1, 'a', CAST('2017-07-10 01:02:03.004 UTC' AS TIMESTAMP(6) WITH TIME ZONE), ARRAY['c'])
`).Scan(
&goTime,
&nullTime,
Expand All @@ -385,6 +391,7 @@ func TestIntegrationTypeConversion(t *testing.T) {
&nullFloat64Slice3,
&goMap,
&nullMap,
&goRow,
)
if err != nil {
t.Fatal(err)
Expand Down
259 changes: 230 additions & 29 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ import (
"net/http"
"net/url"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -98,6 +97,9 @@ var (

// ErrUnsupportedHeader indicates that the server response contains an unsupported header.
ErrUnsupportedHeader = errors.New("trino: server response contains an unsupported header")

// ErrInvalidResponseType indicates that the server returned an invalid type definition.
ErrInvalidResponseType = errors.New("trino: server response contains an invalid type")
)

const (
Expand Down Expand Up @@ -752,6 +754,10 @@ type driverRows struct {

var _ driver.Rows = &driverRows{}
var _ driver.Result = &driverRows{}
var _ driver.RowsColumnTypeScanType = &driverRows{}
var _ driver.RowsColumnTypeDatabaseTypeName = &driverRows{}
var _ driver.RowsColumnTypeLength = &driverRows{}
var _ driver.RowsColumnTypePrecisionScale = &driverRows{}

// Close closes the rows iterator.
func (qr *driverRows) Close() error {
Expand Down Expand Up @@ -796,14 +802,24 @@ func (qr *driverRows) Columns() []string {
return qr.columns
}

var coltypeLengthSuffix = regexp.MustCompile(`\(\d+\)$`)

func (qr *driverRows) ColumnTypeDatabaseTypeName(index int) string {
name := qr.coltype[index].typeName
if m := coltypeLengthSuffix.FindStringSubmatch(name); m != nil {
name = name[0 : len(name)-len(m[0])]
typeName := qr.coltype[index].parsedType[0]
if typeName == "map" || typeName == "array" || typeName == "row" {
typeName = qr.coltype[index].typeName
}
return name
return strings.ToUpper(typeName)
}

func (qr *driverRows) ColumnTypeScanType(index int) reflect.Type {
return qr.coltype[index].scanType
}

func (qr *driverRows) ColumnTypeLength(index int) (int64, bool) {
return qr.coltype[index].size.value, qr.coltype[index].size.hasValue
}

func (qr *driverRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
return qr.coltype[index].precision.value, qr.coltype[index].scale.value, qr.coltype[index].precision.hasValue
}

// Next is called to populate the next row of data into
Expand Down Expand Up @@ -874,10 +890,35 @@ type queryColumn struct {

type queryData []interface{}

type namedTypeSignature struct {
FieldName string `json:"fieldName"`
TypeSignature typeSignature `json:"typeSignature"`
}

type typeSignature struct {
RawType string `json:"rawType"`
TypeArguments []interface{} `json:"typeArguments"`
LiteralArguments []interface{} `json:"literalArguments"`
RawType string `json:"rawType"`
Arguments []typeArgument `json:"arguments"`
}

type typeKind string

const (
KIND_TYPE = typeKind("TYPE")
KIND_NAMED_TYPE = typeKind("NAMED_TYPE")
KIND_LONG = typeKind("LONG")
KIND_VARIABLE = typeKind("VARIABLE")
)

type typeArgument struct {
// Kind determines if the typeSignature, namedTypeSignature, or long field has a value
Kind typeKind `json:"kind"`
Value json.RawMessage `json:"value"`
// typeSignature decoded from Value when Kind is TYPE
typeSignature typeSignature
// namedTypeSignature decoded from Value when Kind is NAMED_TYPE
namedTypeSignature namedTypeSignature
// long decoded from Value when Kind is LONG
long int64
}

func handleResponseError(status int, respErr stmtError) error {
Expand Down Expand Up @@ -941,48 +982,208 @@ func (qr *driverRows) fetch(allowEOF bool) error {
}
}
if qr.columns == nil && len(qresp.Columns) > 0 {
qr.initColumns(&qresp)
for i := range qresp.Columns {
err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
if err != nil {
return fmt.Errorf("error decoding column type signature: %w", err)
}
}
err = qr.initColumns(&qresp)
if err != nil {
return err
}
}
qr.rowsAffected = qresp.UpdateCount
return nil
}

func (qr *driverRows) initColumns(qresp *queryResponse) {
func unmarshalArguments(signature *typeSignature) error {
for i, argument := range signature.Arguments {
var payload interface{}
switch argument.Kind {
case KIND_TYPE:
payload = &(signature.Arguments[i].typeSignature)
case KIND_NAMED_TYPE:
payload = &(signature.Arguments[i].namedTypeSignature)
case KIND_LONG:
payload = &(signature.Arguments[i].long)
}
err := json.Unmarshal(argument.Value, payload)
if err != nil {
return err
}
switch argument.Kind {
case KIND_TYPE:
err = unmarshalArguments(&(signature.Arguments[i].typeSignature))
case KIND_NAMED_TYPE:
err = unmarshalArguments(&(signature.Arguments[i].namedTypeSignature.TypeSignature))
}
if err != nil {
return err
}
}
return nil
}

func (qr *driverRows) initColumns(qresp *queryResponse) error {
qr.columns = make([]string, len(qresp.Columns))
qr.coltype = make([]*typeConverter, len(qresp.Columns))
var err error
for i, col := range qresp.Columns {
qr.columns[i] = col.Name
qr.coltype[i] = newTypeConverter(col.Type)
qr.coltype[i], err = newTypeConverter(col.Type, col.TypeSignature)
if err != nil {
return err
}
}
return nil
}

type typeConverter struct {
typeName string
parsedType []string // e.g. array, array, varchar, for [][]string
parsedType []string
scanType reflect.Type
precision optionalInt64
scale optionalInt64
size optionalInt64
}

func newTypeConverter(typeName string) *typeConverter {
return &typeConverter{
type optionalInt64 struct {
value int64
hasValue bool
}

func newOptionalInt64(value int64) optionalInt64 {
return optionalInt64{value: value, hasValue: true}
}

func newTypeConverter(typeName string, signature typeSignature) (*typeConverter, error) {
result := &typeConverter{
typeName: typeName,
parsedType: parseType(typeName),
parsedType: getNestedTypes([]string{}, signature),
}
var err error
result.scanType, err = getScanType(result.parsedType)
if err != nil {
return nil, err
}
switch signature.RawType {
case "char", "varchar":
if len(signature.Arguments) > 0 {
if signature.Arguments[0].Kind != KIND_LONG {
return nil, ErrInvalidResponseType
}
result.size = newOptionalInt64(signature.Arguments[0].long)
}
case "decimal":
if len(signature.Arguments) > 0 {
if signature.Arguments[0].Kind != KIND_LONG {
return nil, ErrInvalidResponseType
}
result.precision = newOptionalInt64(signature.Arguments[0].long)
}
if len(signature.Arguments) > 1 {
if signature.Arguments[1].Kind != KIND_LONG {
return nil, ErrInvalidResponseType
}
result.scale = newOptionalInt64(signature.Arguments[1].long)
}
}
return result, nil
}

// parses Trino types, e.g. array(varchar(10)) to "array", "varchar"
// TODO: Use queryColumn.TypeSignature instead.
func parseType(name string) []string {
parts := strings.Split(strings.ToLower(name), "(")
if len(parts) == 1 {
return parts
func getNestedTypes(types []string, signature typeSignature) []string {
types = append(types, signature.RawType)
if len(signature.Arguments) == 1 {
switch signature.Arguments[0].Kind {
case KIND_TYPE:
types = getNestedTypes(types, signature.Arguments[0].typeSignature)
case KIND_NAMED_TYPE:
types = getNestedTypes(types, signature.Arguments[0].namedTypeSignature.TypeSignature)
}
}
last := len(parts) - 1
parts[last] = strings.TrimRight(parts[last], ")")
if len(parts[last]) > 0 {
if _, err := strconv.Atoi(parts[last]); err == nil {
parts = parts[:last]
return types
}

func getScanType(typeNames []string) (reflect.Type, error) {
var v interface{}
switch typeNames[0] {
case "boolean":
v = sql.NullBool{}
case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "unknown":
v = sql.NullString{}
case "tinyint", "smallint":
v = sql.NullInt32{}
case "integer":
v = sql.NullInt32{}
case "bigint":
v = sql.NullInt64{}
case "real", "double":
v = sql.NullFloat64{}
case "date", "time", "time with time zone", "timestamp", "timestamp with time zone":
v = sql.NullTime{}
case "map":
v = NullMap{}
case "array":
if len(typeNames) <= 1 {
return nil, ErrInvalidResponseType
}
switch typeNames[1] {
case "boolean":
v = NullSliceBool{}
case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "unknown":
v = NullSliceString{}
case "tinyint", "smallint", "integer", "bigint":
v = NullSliceInt64{}
case "real", "double":
v = NullSliceFloat64{}
case "date", "time", "time with time zone", "timestamp", "timestamp with time zone":
v = NullSliceTime{}
case "map":
v = NullSliceMap{}
case "array":
if len(typeNames) <= 2 {
return nil, ErrInvalidResponseType
}
switch typeNames[2] {
case "boolean":
v = NullSlice2Bool{}
case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "unknown":
v = NullSlice2String{}
case "tinyint", "smallint", "integer", "bigint":
v = NullSlice2Int64{}
case "real", "double":
v = NullSlice2Float64{}
case "date", "time", "time with time zone", "timestamp", "timestamp with time zone":
v = NullSlice2Time{}
case "map":
v = NullSlice2Map{}
case "array":
if len(typeNames) <= 3 {
return nil, ErrInvalidResponseType
}
switch typeNames[3] {
case "boolean":
v = NullSlice3Bool{}
case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "unknown":
v = NullSlice3String{}
case "tinyint", "smallint", "integer", "bigint":
v = NullSlice3Int64{}
case "real", "double":
v = NullSlice3Float64{}
case "date", "time", "time with time zone", "timestamp", "timestamp with time zone":
v = NullSlice3Time{}
case "map":
v = NullSlice3Map{}
}
// if this is a 4 or more dimensional array, scan type will be an empty interface
}
}
}
return parts
if v == nil {
return reflect.TypeOf(new(interface{})).Elem(), nil
}
return reflect.TypeOf(v), nil
}

// ConvertValue implements the driver.ValueConverter interface.
Expand Down
Loading

0 comments on commit 7410412

Please sign in to comment.