Skip to content

Commit

Permalink
add progress callback
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaho12 authored and losipiuk committed Aug 26, 2022
1 parent 33ced97 commit eaed16c
Show file tree
Hide file tree
Showing 2 changed files with 241 additions and 22 deletions.
167 changes: 145 additions & 22 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ var (

// ErrInvalidResponseType indicates that the server returned an invalid type definition.
ErrInvalidResponseType = errors.New("trino: server response contains an invalid type")

// ErrInvalidProgressCallbackHeader indicates that server did not get valid headers for progress callback
ErrInvalidProgressCallbackHeader = errors.New("trino: both " + trinoProgressCallbackParam + " and " + trinoProgressCallbackPeriodParam + " must be set when using progress callback")
)

const (
Expand All @@ -121,6 +124,9 @@ const (
trinoSetRoleHeader = trinoHeaderPrefix + `Set-Role`
trinoExtraCredentialHeader = trinoHeaderPrefix + `Extra-Credential`

trinoProgressCallbackParam = trinoHeaderPrefix + `Progress-Callback`
trinoProgressCallbackPeriodParam = trinoHeaderPrefix + `Progress-Callback-Period`

trinoAddedPrepareHeader = trinoHeaderPrefix + `Added-Prepare`
trinoDeallocatedPrepareHeader = trinoHeaderPrefix + `Deallocated-Prepare`

Expand Down Expand Up @@ -232,12 +238,14 @@ func (c *Config) FormatDSN() (string, error) {

// Conn is a Trino connection.
type Conn struct {
baseURL string
auth *url.Userinfo
httpClient http.Client
httpHeaders http.Header
kerberosClient client.Client
kerberosEnabled bool
baseURL string
auth *url.Userinfo
httpClient http.Client
httpHeaders http.Header
kerberosClient client.Client
kerberosEnabled bool
progressUpdater ProgressUpdater
progressUpdaterPeriod queryProgressCallbackPeriod
}

var (
Expand Down Expand Up @@ -541,9 +549,11 @@ func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed {
}

type driverStmt struct {
conn *Conn
query string
user string
conn *Conn
query string
user string
statsCh chan QueryProgressInfo
doneCh chan struct{}
}

var (
Expand All @@ -553,7 +563,14 @@ var (
_ driver.NamedValueChecker = &driverStmt{}
)

// Close closes statement just before releasing connection
func (st *driverStmt) Close() error {
if st.doneCh != nil {
close(st.doneCh)
}
if st.statsCh != nil {
<-st.statsCh
}
return nil
}

Expand All @@ -576,11 +593,14 @@ func (st *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue)
queryID: sr.ID,
nextURI: sr.NextURI,
rowsAffected: sr.UpdateCount,
statsCh: st.statsCh,
doneCh: st.doneCh,
}
// consume all results, if there are any
for err == nil {
err = rows.fetch(true)
}

if err != nil && err != io.EOF {
return nil, err
}
Expand All @@ -595,6 +615,13 @@ func (st *driverStmt) CheckNamedValue(arg *driver.NamedValue) error {
if reflect.TypeOf(arg.Value).Kind() == reflect.Slice {
return nil
}

if arg.Name == trinoProgressCallbackParam {
return nil
}
if arg.Name == trinoProgressCallbackPeriodParam {
return nil
}
return driver.ErrSkip
}

Expand All @@ -609,19 +636,20 @@ type stmtResponse struct {
}

type stmtStats struct {
State string `json:"state"`
Scheduled bool `json:"scheduled"`
Nodes int `json:"nodes"`
TotalSplits int `json:"totalSplits"`
QueuesSplits int `json:"queuedSplits"`
RunningSplits int `json:"runningSplits"`
CompletedSplits int `json:"completedSplits"`
UserTimeMillis int `json:"userTimeMillis"`
CPUTimeMillis int `json:"cpuTimeMillis"`
WallTimeMillis int `json:"wallTimeMillis"`
ProcessedRows int `json:"processedRows"`
ProcessedBytes int `json:"processedBytes"`
RootStage stmtStage `json:"rootStage"`
State string `json:"state"`
Scheduled bool `json:"scheduled"`
Nodes int `json:"nodes"`
TotalSplits int `json:"totalSplits"`
QueuesSplits int `json:"queuedSplits"`
RunningSplits int `json:"runningSplits"`
CompletedSplits int `json:"completedSplits"`
UserTimeMillis int `json:"userTimeMillis"`
CPUTimeMillis int `json:"cpuTimeMillis"`
WallTimeMillis int `json:"wallTimeMillis"`
ProcessedRows int `json:"processedRows"`
ProcessedBytes int `json:"processedBytes"`
RootStage stmtStage `json:"rootStage"`
ProgressPercentage float32 `json:"progressPercentage"`
}

type stmtError struct {
Expand Down Expand Up @@ -678,6 +706,8 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue
stmt: st,
queryID: sr.ID,
nextURI: sr.NextURI,
statsCh: st.statsCh,
doneCh: st.doneCh,
}
if err = rows.fetch(false); err != nil {
return nil, err
Expand All @@ -693,6 +723,15 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
hs = make(http.Header)
var ss []string
for _, arg := range args {
if arg.Name == trinoProgressCallbackParam {
st.conn.progressUpdater = arg.Value.(ProgressUpdater)
continue
}
if arg.Name == trinoProgressCallbackPeriodParam {
st.conn.progressUpdaterPeriod.Period = arg.Value.(time.Duration)
continue
}

s, err := Serial(arg.Value)
if err != nil {
return nil, err
Expand All @@ -716,6 +755,9 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
ss = append(ss, s)
}
}
if (st.conn.progressUpdater != nil && st.conn.progressUpdaterPeriod.Period == 0) || (st.conn.progressUpdater == nil && st.conn.progressUpdaterPeriod.Period > 0) {
return nil, ErrInvalidProgressCallbackHeader
}
if len(ss) > 0 {
query = "EXECUTE " + preparedStatementName + " USING " + strings.Join(ss, ", ")
}
Expand All @@ -739,6 +781,38 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
if err != nil {
return nil, fmt.Errorf("trino: %v", err)
}

if st.conn.progressUpdater != nil {
st.statsCh = make(chan QueryProgressInfo)
st.doneCh = make(chan struct{})

// progress updater go func
go func() {
for {
select {
case stats := <-st.statsCh:
st.conn.progressUpdater.Update(stats)
case <-st.doneCh:
close(st.statsCh)
return
}
}
}()

// initial progress callback call
srStats := QueryProgressInfo{
QueryId: sr.ID,
QueryStats: sr.Stats,
}
select {
case st.statsCh <- srStats:
default:
// ignore when can't send stats
}
st.conn.progressUpdaterPeriod.LastCallbackTime = time.Now()
st.conn.progressUpdaterPeriod.LastQueryState = sr.Stats.State
}

return &sr, handleResponseError(resp.StatusCode, sr.Error)
}

Expand All @@ -754,6 +828,9 @@ type driverRows struct {
coltype []*typeConverter
data []queryData
rowsAffected int64

statsCh chan QueryProgressInfo
doneCh chan struct{}
}

var _ driver.Rows = &driverRows{}
Expand Down Expand Up @@ -976,7 +1053,9 @@ func (qr *driverRows) fetch(allowEOF bool) error {
qr.rowindex = 0
qr.data = qresp.Data
qr.nextURI = qresp.NextURI

if len(qr.data) == 0 {
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)
if qr.nextURI != "" {
return qr.fetch(allowEOF)
}
Expand All @@ -998,6 +1077,7 @@ func (qr *driverRows) fetch(allowEOF bool) error {
}
}
qr.rowsAffected = qresp.UpdateCount
qr.scheduleProgressUpdate(qresp.ID, qresp.Stats)
return nil
}

Expand Down Expand Up @@ -1043,6 +1123,33 @@ func (qr *driverRows) initColumns(qresp *queryResponse) error {
return nil
}

func (qr *driverRows) scheduleProgressUpdate(id string, stats stmtStats) {
if qr.stmt.conn.progressUpdater == nil {
return
}

qrStats := QueryProgressInfo{
QueryId: id,
QueryStats: stats,
}
currentTime := time.Now()
diff := currentTime.Sub(qr.stmt.conn.progressUpdaterPeriod.LastCallbackTime)
period := qr.stmt.conn.progressUpdaterPeriod.Period

// Check if period has not passed yet AND if query state did not change
if diff < period && qr.stmt.conn.progressUpdaterPeriod.LastQueryState == qrStats.QueryStats.State {
return
}

select {
case qr.statsCh <- qrStats:
default:
// ignore when can't send stats
}
qr.stmt.conn.progressUpdaterPeriod.LastCallbackTime = currentTime
qr.stmt.conn.progressUpdaterPeriod.LastQueryState = qrStats.QueryStats.State
}

type typeConverter struct {
typeName string
parsedType []string
Expand Down Expand Up @@ -1943,3 +2050,19 @@ func (s *NullSlice3Map) Scan(value interface{}) error {
s.Valid = true
return nil
}

type QueryProgressInfo struct {
QueryId string
QueryStats stmtStats
}

type queryProgressCallbackPeriod struct {
Period time.Duration
LastCallbackTime time.Time
LastQueryState string
}

type ProgressUpdater interface {
// Update the query progress, immediately when the query starts, when receiving data, and once when the query is finished.
Update(QueryProgressInfo)
}
96 changes: 96 additions & 0 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"sort"
"testing"
"time"

Expand Down Expand Up @@ -249,6 +250,101 @@ func TestQueryForUsername(t *testing.T) {
}
}

type TestQueryProgressCallback struct {
statusMap map[time.Time]string
}

func (qpc *TestQueryProgressCallback) Update(qpi QueryProgressInfo) {
qpc.statusMap[time.Now()] = qpi.QueryStats.State
}

func TestQueryProgressWithCallback(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode.")
}
c := &Config{
ServerURI: *integrationServerFlag,
SessionProperties: map[string]string{"query_priority": "1"},
}

dsn, err := c.FormatDSN()
require.NoError(t, err)

db, err := sql.Open("trino", dsn)
require.NoError(t, err)

t.Cleanup(func() {
assert.NoError(t, db.Close())
})

callback := &TestQueryProgressCallback{}

_, err = db.Query("SELECT 2", sql.Named("X-Trino-Progress-Callback", callback))
assert.EqualError(t, err, ErrInvalidProgressCallbackHeader.Error(), "unexpected error")
}

func TestQueryProgressWithCallbackPeriod(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode.")
}
c := &Config{
ServerURI: *integrationServerFlag,
SessionProperties: map[string]string{"query_priority": "1"},
}

dsn, err := c.FormatDSN()
require.NoError(t, err)

db, err := sql.Open("trino", dsn)
require.NoError(t, err)

t.Cleanup(func() {
assert.NoError(t, db.Close())
})

statusMap := make(map[time.Time]string)
progressUpdater := &TestQueryProgressCallback{
statusMap: statusMap,
}
progressUpdaterPeriod, err := time.ParseDuration("1ms")

rows, err := db.Query("SELECT 2",
sql.Named("X-Trino-Progress-Callback", progressUpdater),
sql.Named("X-Trino-Progress-Callback-Period", progressUpdaterPeriod),
)
require.NoError(t, err, "Failed executing query")
assert.NotNil(t, rows)

for rows.Next() {
var ts string
require.NoError(t, rows.Scan(&ts), "Failed scanning query result")

assert.Equal(t, "2", ts, "Expected value does not equal result value")
}

if err = rows.Err(); err != nil {
t.Fatal(err)
}
if err = rows.Close(); err != nil {
t.Fatal(err)
}

// sort time in order to calculate interval
var keys []time.Time
for k := range statusMap {
keys = append(keys, k)
}
sort.Slice(keys, func(i, j int) bool {
return keys[i].Before(keys[j])
})

for i, k := range keys {
if i > 0 {
assert.GreaterOrEqual(t, k.Sub(keys[i-1]), progressUpdaterPeriod)
}
}
}

func TestQueryColumns(t *testing.T) {
c := &Config{
ServerURI: *integrationServerFlag,
Expand Down

0 comments on commit eaed16c

Please sign in to comment.