Skip to content

Commit

Permalink
Convert recursion in fetch into loop
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet authored and losipiuk committed Aug 31, 2022
1 parent 9b9640d commit 1de9316
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 23 deletions.
48 changes: 25 additions & 23 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -1024,34 +1024,36 @@ func (qr *driverRows) fetch(allowEOF bool) error {
return nil
}

var qresp queryResponse
err := qr.executeFetchRequest(&qresp)
if err != nil {
return err
}

qr.rowindex = 0
qr.data = qresp.Data
qr.nextURI = qresp.NextURI
for qr.nextURI != "" {
var qresp queryResponse
err := qr.executeFetchRequest(&qresp)
if err != nil {
return err
}

if len(qr.data) == 0 {
qr.rowindex = 0
qr.data = qresp.Data
qr.nextURI = qresp.NextURI
qr.rowsAffected = qresp.UpdateCount
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)
if qr.nextURI != "" {
return qr.fetch(allowEOF)
}
if allowEOF {
qr.err = io.EOF
return qr.err

if len(qr.data) == 0 {
if qr.nextURI != "" {
continue
}
if allowEOF {
qr.err = io.EOF
return qr.err
}
}
}
if qr.columns == nil && len(qresp.Columns) > 0 {
err = qr.initColumns(&qresp)
if err != nil {
return err
if qr.columns == nil && len(qresp.Columns) > 0 {
err = qr.initColumns(&qresp)
if err != nil {
return err
}
}
return nil
}
qr.rowsAffected = qresp.UpdateCount
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)
return nil
}

Expand Down
46 changes: 46 additions & 0 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package trino

import (
"bytes"
"context"
"database/sql"
"encoding/json"
Expand All @@ -23,6 +24,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"runtime/debug"
"sort"
"testing"
"time"
Expand Down Expand Up @@ -782,6 +784,50 @@ func TestQueryFailure(t *testing.T) {
assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err)
}

// This test ensures that the fetch method is not generating stack overflow errors.
// === RUN TestFetchNoStackOverflow
// runtime: goroutine stack exceeds 1000000000-byte limit
// runtime: sp=0x14037b00390 stack=[0x14037b00000, 0x14057b00000]
// fatal error: stack overflow
func TestFetchNoStackOverflow(t *testing.T) {
previousSetting := debug.SetMaxStack(50 * 1024)
defer debug.SetMaxStack(previousSetting)
count := 0
var buf *bytes.Buffer
var ts *httptest.Server
ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if count <= 50 {
if buf == nil {
buf = new(bytes.Buffer)
json.NewEncoder(buf).Encode(&stmtResponse{
NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1",
})
}
w.WriteHeader(http.StatusOK)
w.Write(buf.Bytes())
count++
return
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(&stmtResponse{
Error: stmtError{
ErrorName: "TEST",
},
})
}))

db, err := sql.Open("trino", ts.URL)
require.NoError(t, err)

t.Cleanup(func() {
assert.NoError(t, db.Close())
})

_, err = db.Query("SELECT 1")
assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err)

}

func TestSession(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode.")
Expand Down

0 comments on commit 1de9316

Please sign in to comment.