From 27054e07bec784f809cd5ddaa8bdacd839a38134 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wa=C5=9B?= Date: Wed, 24 Aug 2022 23:44:24 +0200 Subject: [PATCH] Fetch and decode query results concurrently --- trino/trino.go | 212 +++++++++++++++++++++++++++++--------------- trino/trino_test.go | 1 + 2 files changed, 142 insertions(+), 71 deletions(-) diff --git a/trino/trino.go b/trino/trino.go index 8f8a7e3..549231a 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -547,11 +547,15 @@ func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed { } type driverStmt struct { - conn *Conn - query string - user string - statsCh chan QueryProgressInfo - doneCh chan struct{} + conn *Conn + query string + user string + nextURIs chan string + httpResponses chan *http.Response + queryResponses chan queryResponse + statsCh chan QueryProgressInfo + errors chan error + doneCh chan struct{} } var ( @@ -563,12 +567,26 @@ var ( // Close closes statement just before releasing connection func (st *driverStmt) Close() error { - if st.doneCh != nil { - close(st.doneCh) + if st.doneCh == nil { + return nil } + close(st.doneCh) if st.statsCh != nil { <-st.statsCh + st.statsCh = nil + } + go func() { + // drain errors chan to allow goroutines to write to it + for range st.errors { + } + }() + for range st.queryResponses { + } + for range st.httpResponses { } + close(st.nextURIs) + close(st.errors) + st.doneCh = nil return nil } @@ -596,7 +614,7 @@ func (st *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) } // consume all results, if there are any for err == nil { - err = rows.fetch(true) + err = rows.fetch() } if err != nil && err != io.EOF { @@ -707,7 +725,7 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue statsCh: st.statsCh, doneCh: st.doneCh, } - if err = rows.fetch(false); err != nil { + if err = rows.fetch(); err != nil && err != io.EOF { return nil, err } return rows, nil @@ -780,9 +798,89 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt return nil, fmt.Errorf("trino: %v", err) } + st.doneCh = make(chan struct{}) + st.nextURIs = make(chan string) + st.httpResponses = make(chan *http.Response) + st.queryResponses = make(chan queryResponse) + st.errors = make(chan error) + go func() { + defer close(st.httpResponses) + for { + select { + case nextURI := <-st.nextURIs: + if nextURI == "" { + return + } + hs := make(http.Header) + hs.Add(trinoUserHeader, st.user) + req, err := st.conn.newRequest("GET", nextURI, nil, hs) + if err != nil { + st.errors <- err + return + } + resp, err := st.conn.roundTrip(ctx, req) + if err != nil { + if ctx.Err() == context.Canceled { + st.errors <- context.Canceled + return + } + st.errors <- err + return + } + select { + case st.httpResponses <- resp: + case <-st.doneCh: + return + } + case <-st.doneCh: + return + } + } + }() + go func() { + defer close(st.queryResponses) + for { + select { + case resp := <-st.httpResponses: + if resp == nil { + return + } + var qresp queryResponse + d := json.NewDecoder(resp.Body) + d.UseNumber() + err = d.Decode(&qresp) + if err != nil { + st.errors <- fmt.Errorf("trino: %v", err) + return + } + err = resp.Body.Close() + if err != nil { + st.errors <- err + return + } + err = handleResponseError(resp.StatusCode, qresp.Error) + if err != nil { + st.errors <- err + return + } + select { + case st.nextURIs <- qresp.NextURI: + case <-st.doneCh: + return + } + select { + case st.queryResponses <- qresp: + case <-st.doneCh: + return + } + case <-st.doneCh: + return + } + } + }() + st.nextURIs <- sr.NextURI if st.conn.progressUpdater != nil { st.statsCh = make(chan QueryProgressInfo) - st.doneCh = make(chan struct{}) // progress updater go func go func() { @@ -810,7 +908,6 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt st.conn.progressUpdaterPeriod.LastCallbackTime = time.Now() st.conn.progressUpdaterPeriod.LastQueryState = sr.Stats.State } - return &sr, handleResponseError(resp.StatusCode, sr.Error) } @@ -873,7 +970,7 @@ func (qr *driverRows) Columns() []string { return []string{} } if qr.columns == nil { - if err := qr.fetch(false); err != nil { + if err := qr.fetch(); err != nil && err != io.EOF { qr.err = err return []string{} } @@ -915,7 +1012,7 @@ func (qr *driverRows) Next(dest []driver.Value) error { qr.err = io.EOF return qr.err } - if err := qr.fetch(true); err != nil { + if err := qr.fetch(); err != nil { qr.err = err return err } @@ -925,6 +1022,9 @@ func (qr *driverRows) Next(dest []driver.Value) error { return qr.err } for i, v := range qr.coltype { + if i > len(dest)-1 { + break + } vv, err := v.ConvertValue(qr.data[qr.rowindex][i]) if err != nil { qr.err = err @@ -945,7 +1045,7 @@ func (qr driverRows) LastInsertId() (int64, error) { // RowsAffected returns the number of rows affected by the query. func (qr driverRows) RowsAffected() (int64, error) { - return qr.rowsAffected, qr.err + return qr.rowsAffected, nil } type queryResponse struct { @@ -1014,71 +1114,34 @@ func handleResponseError(status int, respErr stmtError) error { } } -func (qr *driverRows) fetch(allowEOF bool) error { - if qr.nextURI == "" { - if allowEOF { - return io.EOF - } - return nil - } - - for qr.nextURI != "" { - var qresp queryResponse - err := qr.executeFetchRequest(&qresp) - if err != nil { - return err - } - - qr.rowindex = 0 - qr.data = qresp.Data - qr.nextURI = qresp.NextURI - qr.rowsAffected = qresp.UpdateCount - qr.scheduleProgressUpdate(qresp.ID, qresp.Stats) - - if len(qr.data) == 0 { - if qr.nextURI != "" { - continue - } - if allowEOF { - qr.err = io.EOF - return qr.err +func (qr *driverRows) fetch() error { + var qresp queryResponse + var err error + for { + select { + case qresp = <-qr.stmt.queryResponses: + if qresp.ID == "" { + return io.EOF } - } - if qr.columns == nil && len(qresp.Columns) > 0 { err = qr.initColumns(&qresp) if err != nil { return err } - } - return nil - } - return nil -} - -func (qr *driverRows) executeFetchRequest(qresp *queryResponse) error { - hs := make(http.Header) - hs.Add(trinoUserHeader, qr.stmt.user) - req, err := qr.stmt.conn.newRequest("GET", qr.nextURI, nil, hs) - if err != nil { - return err - } - resp, err := qr.stmt.conn.roundTrip(qr.ctx, req) - if err != nil { - if qr.ctx.Err() == context.Canceled { - qr.Close() + qr.rowindex = 0 + qr.data = qresp.Data + qr.rowsAffected = qresp.UpdateCount + qr.scheduleProgressUpdate(qresp.ID, qresp.Stats) + if len(qr.data) != 0 { + return nil + } + case err = <-qr.stmt.errors: + if err == context.Canceled { + qr.Close() + } + qr.err = err return err } - return err - } - defer resp.Body.Close() - - d := json.NewDecoder(resp.Body) - d.UseNumber() - err = d.Decode(&qresp) - if err != nil { - return fmt.Errorf("trino: %v", err) } - return handleResponseError(resp.StatusCode, qresp.Error) } func unmarshalArguments(signature *typeSignature) error { @@ -1110,6 +1173,9 @@ func unmarshalArguments(signature *typeSignature) error { } func (qr *driverRows) initColumns(qresp *queryResponse) error { + if qr.columns != nil || len(qresp.Columns) == 0 { + return nil + } var err error for i := range qresp.Columns { err = unmarshalArguments(&(qresp.Columns[i].TypeSignature)) @@ -1120,6 +1186,10 @@ func (qr *driverRows) initColumns(qresp *queryResponse) error { qr.columns = make([]string, len(qresp.Columns)) qr.coltype = make([]*typeConverter, len(qresp.Columns)) for i, col := range qresp.Columns { + err = unmarshalArguments(&(qresp.Columns[i].TypeSignature)) + if err != nil { + return fmt.Errorf("error decoding column type signature: %w", err) + } qr.columns[i] = col.Name qr.coltype[i], err = newTypeConverter(col.Type, col.TypeSignature) if err != nil { diff --git a/trino/trino_test.go b/trino/trino_test.go index 25b78a3..f3077d3 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -800,6 +800,7 @@ func TestFetchNoStackOverflow(t *testing.T) { if buf == nil { buf = new(bytes.Buffer) json.NewEncoder(buf).Encode(&stmtResponse{ + ID: "fake-query", NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", }) }