diff --git a/trino/trino.go b/trino/trino.go index 3fe4a28..ef97e44 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -1023,29 +1023,9 @@ func (qr *driverRows) fetch(allowEOF bool) error { } return nil } - hs := make(http.Header) - hs.Add(trinoUserHeader, qr.stmt.user) - req, err := qr.stmt.conn.newRequest("GET", qr.nextURI, nil, hs) - if err != nil { - return err - } - resp, err := qr.stmt.conn.roundTrip(qr.ctx, req) - if err != nil { - if qr.ctx.Err() == context.Canceled { - qr.Close() - return err - } - return err - } - defer resp.Body.Close() + var qresp queryResponse - d := json.NewDecoder(resp.Body) - d.UseNumber() - err = d.Decode(&qresp) - if err != nil { - return fmt.Errorf("trino: %v", err) - } - err = handleResponseError(resp.StatusCode, qresp.Error) + err := qr.executeFetchRequest(&qresp) if err != nil { return err } @@ -1081,6 +1061,32 @@ func (qr *driverRows) fetch(allowEOF bool) error { return nil } +func (qr *driverRows) executeFetchRequest(qresp *queryResponse) error { + hs := make(http.Header) + hs.Add(trinoUserHeader, qr.stmt.user) + req, err := qr.stmt.conn.newRequest("GET", qr.nextURI, nil, hs) + if err != nil { + return err + } + resp, err := qr.stmt.conn.roundTrip(qr.ctx, req) + if err != nil { + if qr.ctx.Err() == context.Canceled { + qr.Close() + return err + } + return err + } + defer resp.Body.Close() + + d := json.NewDecoder(resp.Body) + d.UseNumber() + err = d.Decode(&qresp) + if err != nil { + return fmt.Errorf("trino: %v", err) + } + return handleResponseError(resp.StatusCode, qresp.Error) +} + func unmarshalArguments(signature *typeSignature) error { for i, argument := range signature.Arguments { var payload interface{}