Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement ExecContext #16

Merged
merged 4 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,71 @@ func TestIntegrationQueryParametersSelect(t *testing.T) {
})
}
}

func TestIntegrationExec(t *testing.T) {
db := integrationOpen(t)
defer db.Close()

_, err := db.Query(`SELECT count(*) FROM nation`)
expected := "Schema must be specified when session schema is not set"
if err == nil || !strings.Contains(err.Error(), expected) {
t.Fatalf("Expected to fail to execute query with error: %v, got: %v", expected, err)
}

result, err := db.Exec("USE tpch.sf100")
if err != nil {
t.Fatal("Failed executing query:", err.Error())
}
if result == nil {
t.Fatal("Expected exec result to be not nil")
}

a, err := result.RowsAffected()
if err != nil {
t.Fatal("Expected RowsAffected not to return any error, got:", err)
}
if a != 0 {
t.Fatal("Expected RowsAffected to be zero, got:", a)
}
rows, err := db.Query(`SELECT count(*) FROM nation`)
if err != nil {
t.Fatal("Failed executing query:", err.Error())
}
if rows == nil || !rows.Next() {
t.Fatal("Failed fetching results")
}
}

func TestIntegrationUnsupportedHeader(t *testing.T) {
dsn := integrationServerDSN(t)
dsn += "?catalog=tpch&schema=sf10"
db := integrationOpen(t, dsn)
defer db.Close()
cases := []struct {
query string
err error
}{
{
query: "SET SESSION grouped_execution=true",
err: ErrUnsupportedHeader,
},
{
query: "SET ROLE dummy",
err: ErrUnsupportedHeader,
},
{
query: "SET PATH dummy",
err: errors.New(`trino: query failed (200 OK): "io.prestosql.spi.PrestoException: SET PATH not supported by client"`),
},
{
query: "RESET SESSION grouped_execution",
err: ErrUnsupportedHeader,
},
}
for _, c := range cases {
_, err := db.Query(c.query)
if err == nil || err.Error() != c.err.Error() {
t.Fatal("unexpected error:", err)
}
}
}
156 changes: 122 additions & 34 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ var (

// ErrQueryCancelled indicates that a query has been cancelled.
ErrQueryCancelled = errors.New("trino: query cancelled")

// ErrUnsupportedHeader indicates that the server response contains an unsupported header.
ErrUnsupportedHeader = errors.New("trino: server response contains an unsupported header")
)

const (
Expand All @@ -103,6 +106,12 @@ const (
trinoCatalogHeader = "X-Presto-Catalog"
trinoSchemaHeader = "X-Presto-Schema"
trinoSessionHeader = "X-Presto-Session"
trinoSetCatalogHeader = "X-Presto-Set-Catalog"
trinoSetSchemaHeader = "X-Presto-Set-Schema"
trinoSetPathHeader = "X-Presto-Set-Path"
trinoSetSessionHeader = "X-Presto-Set-Session"
trinoClearSessionHeader = "X-Presto-Clear-Session"
trinoSetRoleHeader = "X-Presto-Set-Role"

KerberosEnabledConfig = "KerberosEnabled"
kerberosKeytabPathConfig = "KerberosKeytabPath"
Expand All @@ -112,6 +121,19 @@ const (
SSLCertPathConfig = "SSLCertPath"
)

var (
responseToRequestHeaderMap = map[string]string{
trinoSetSchemaHeader: trinoSchemaHeader,
trinoSetCatalogHeader: trinoCatalogHeader,
}
unsupportedResponseHeaders = []string{
trinoSetPathHeader,
trinoSetSessionHeader,
trinoClearSessionHeader,
trinoSetRoleHeader,
}
)

type sqldriver struct{}

func (d *sqldriver) Open(name string) (driver.Conn, error) {
Expand Down Expand Up @@ -406,7 +428,7 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response
case <-timer.C:
timeout := DefaultQueryTimeout
if deadline, ok := ctx.Deadline(); ok {
timeout = deadline.Sub(time.Now())
timeout = time.Until(deadline)
}
client := c.httpClient
client.Timeout = timeout
Expand All @@ -416,6 +438,16 @@ func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response
}
switch resp.StatusCode {
case http.StatusOK:
for src, dst := range responseToRequestHeaderMap {
if v := resp.Header.Get(src); v != "" {
c.httpHeaders.Set(dst, v)
}
}
for _, name := range unsupportedResponseHeaders {
if v := resp.Header.Get(name); v != "" {
return nil, ErrUnsupportedHeader
}
}
return resp, nil
case http.StatusServiceUnavailable:
resp.Body.Close()
Expand Down Expand Up @@ -470,6 +502,7 @@ type driverStmt struct {
var (
_ driver.Stmt = &driverStmt{}
_ driver.StmtQueryContext = &driverStmt{}
_ driver.StmtExecContext = &driverStmt{}
)

func (st *driverStmt) Close() error {
Expand All @@ -481,15 +514,38 @@ func (st *driverStmt) NumInput() int {
}

func (st *driverStmt) Exec(args []driver.Value) (driver.Result, error) {
return nil, ErrOperationNotSupported
return nil, driver.ErrSkip
}

func (st *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
sr, err := st.exec(ctx, args)
if err != nil {
return nil, err
}
rows := &driverRows{
ctx: ctx,
stmt: st,
nextURI: sr.NextURI,
rowsAffected: sr.UpdateCount,
}
// consume all results, if there are any
for err == nil {
err = rows.fetch(true)
}
if err != nil && err != io.EOF {
return nil, err
}
return rows, nil
}

type stmtResponse struct {
ID string `json:"id"`
InfoURI string `json:"infoUri"`
NextURI string `json:"nextUri"`
Stats stmtStats `json:"stats"`
Error stmtError `json:"error"`
ID string `json:"id"`
InfoURI string `json:"infoUri"`
NextURI string `json:"nextUri"`
Stats stmtStats `json:"stats"`
Error stmtError `json:"error"`
UpdateType string `json:"updateType"`
UpdateCount int64 `json:"updateCount"`
}

type stmtStats struct {
Expand Down Expand Up @@ -553,6 +609,22 @@ func (st *driverStmt) Query(args []driver.Value) (driver.Rows, error) {
}

func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
sr, err := st.exec(ctx, args)
if err != nil {
return nil, err
}
rows := &driverRows{
ctx: ctx,
stmt: st,
nextURI: sr.NextURI,
}
if err = rows.fetch(false); err != nil {
return nil, err
}
return rows, nil
}

func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmtResponse, error) {
query := st.query
var hs http.Header

Expand Down Expand Up @@ -588,6 +660,7 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue
if err != nil {
return nil, err
}

defer resp.Body.Close()
var sr stmtResponse
d := json.NewDecoder(resp.Body)
Expand All @@ -596,35 +669,26 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue
if err != nil {
return nil, fmt.Errorf("trino: %v", err)
}
err = handleResponseError(resp.StatusCode, sr.Error)
if err != nil {
return nil, err
}
rows := &driverRows{
ctx: ctx,
stmt: st,
nextURI: sr.NextURI,
}
if err = rows.fetch(false); err != nil {
return nil, err
}
return rows, nil
return &sr, handleResponseError(resp.StatusCode, sr.Error)
}

type driverRows struct {
ctx context.Context
stmt *driverStmt
nextURI string

err error
rowindex int
columns []string
coltype []*typeConverter
data []queryData
err error
rowindex int
columns []string
coltype []*typeConverter
data []queryData
rowsAffected int64
}

var _ driver.Rows = &driverRows{}
var _ driver.Result = &driverRows{}

// Close closes the rows iterator.
func (qr *driverRows) Close() error {
if qr.nextURI != "" {
hs := make(http.Header)
Expand Down Expand Up @@ -652,6 +716,7 @@ func (qr *driverRows) Close() error {
return qr.err
}

// Columns returns the names of the columns.
func (qr *driverRows) Columns() []string {
if qr.err != nil {
return []string{}
Expand All @@ -675,6 +740,11 @@ func (qr *driverRows) ColumnTypeDatabaseTypeName(index int) string {
return name
}

// Next is called to populate the next row of data into
// the provided slice. The provided slice will be the same
// size as the Columns() are wide.
//
// Next should return io.EOF when there are no more rows.
func (qr *driverRows) Next(dest []driver.Value) error {
if qr.err != nil {
return qr.err
Expand Down Expand Up @@ -705,6 +775,18 @@ func (qr *driverRows) Next(dest []driver.Value) error {
return nil
}

// LastInsertId returns the database's auto-generated ID
// after, for example, an INSERT into a table with primary
// key.
func (qr driverRows) LastInsertId() (int64, error) {
return 0, ErrOperationNotSupported
}

// RowsAffected returns the number of rows affected by the query.
func (qr driverRows) RowsAffected() (int64, error) {
return qr.rowsAffected, qr.err
}

type queryResponse struct {
ID string `json:"id"`
InfoURI string `json:"infoUri"`
Expand All @@ -714,6 +796,8 @@ type queryResponse struct {
Data []queryData `json:"data"`
Stats stmtStats `json:"stats"`
Error stmtError `json:"error"`
UpdateType string `json:"updateType"`
UpdateCount int64 `json:"updateCount"`
}

type queryColumn struct {
Expand All @@ -730,11 +814,6 @@ type typeSignature struct {
LiteralArguments []interface{} `json:"literalArguments"`
}

type infoResponse struct {
QueryID string `json:"queryId"`
State string `json:"state"`
}

func handleResponseError(status int, respErr stmtError) error {
switch respErr.ErrorName {
case "":
Expand All @@ -750,6 +829,12 @@ func handleResponseError(status int, respErr stmtError) error {
}

func (qr *driverRows) fetch(allowEOF bool) error {
if qr.nextURI == "" {
if allowEOF {
return io.EOF
}
return nil
}
hs := make(http.Header)
hs.Add(trinoUserHeader, qr.stmt.user)
req, err := qr.stmt.conn.newRequest("GET", qr.nextURI, nil, hs)
Expand All @@ -772,6 +857,7 @@ func (qr *driverRows) fetch(allowEOF bool) error {
if err != nil {
return err
}

qr.rowindex = 0
qr.data = qresp.Data
qr.nextURI = qresp.NextURI
Expand All @@ -786,6 +872,7 @@ func (qr *driverRows) fetch(allowEOF bool) error {
if qr.columns == nil && len(qresp.Columns) > 0 {
qr.initColumns(&qresp)
}
qr.rowsAffected = qresp.UpdateCount
return nil
}

Expand Down Expand Up @@ -1359,11 +1446,11 @@ type NullTime struct {

// Scan implements the sql.Scanner interface.
func (s *NullTime) Scan(value interface{}) error {
switch value.(type) {
switch t := value.(type) {
case time.Time:
s.Time, s.Valid = value.(time.Time)
s.Time, s.Valid = t, true
case NullTime:
*s = value.(NullTime)
*s = t
}
return nil
}
Expand Down Expand Up @@ -1488,7 +1575,8 @@ func (s *NullSliceMap) Scan(value interface{}) error {
return fmt.Errorf("cannot convert %v (%T) to []NullMap", value, value)
}
m := NullMap{}
m.Scan(vs[i])
// this scan can never fail
_ = m.Scan(vs[i])
slice[i] = m
}
s.SliceMap = slice
Expand Down
Loading