diff --git a/trino/trino.go b/trino/trino.go index e87d91a..3fe4a28 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -100,6 +100,9 @@ var ( // ErrInvalidResponseType indicates that the server returned an invalid type definition. ErrInvalidResponseType = errors.New("trino: server response contains an invalid type") + + // ErrInvalidProgressCallbackHeader indicates that server did not get valid headers for progress callback + ErrInvalidProgressCallbackHeader = errors.New("trino: both " + trinoProgressCallbackParam + " and " + trinoProgressCallbackPeriodParam + " must be set when using progress callback") ) const ( @@ -121,6 +124,9 @@ const ( trinoSetRoleHeader = trinoHeaderPrefix + `Set-Role` trinoExtraCredentialHeader = trinoHeaderPrefix + `Extra-Credential` + trinoProgressCallbackParam = trinoHeaderPrefix + `Progress-Callback` + trinoProgressCallbackPeriodParam = trinoHeaderPrefix + `Progress-Callback-Period` + trinoAddedPrepareHeader = trinoHeaderPrefix + `Added-Prepare` trinoDeallocatedPrepareHeader = trinoHeaderPrefix + `Deallocated-Prepare` @@ -232,12 +238,14 @@ func (c *Config) FormatDSN() (string, error) { // Conn is a Trino connection. type Conn struct { - baseURL string - auth *url.Userinfo - httpClient http.Client - httpHeaders http.Header - kerberosClient client.Client - kerberosEnabled bool + baseURL string + auth *url.Userinfo + httpClient http.Client + httpHeaders http.Header + kerberosClient client.Client + kerberosEnabled bool + progressUpdater ProgressUpdater + progressUpdaterPeriod queryProgressCallbackPeriod } var ( @@ -541,9 +549,11 @@ func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed { } type driverStmt struct { - conn *Conn - query string - user string + conn *Conn + query string + user string + statsCh chan QueryProgressInfo + doneCh chan struct{} } var ( @@ -553,7 +563,14 @@ var ( _ driver.NamedValueChecker = &driverStmt{} ) +// Close closes statement just before releasing connection func (st *driverStmt) Close() error { + if st.doneCh != nil { + close(st.doneCh) + } + if st.statsCh != nil { + <-st.statsCh + } return nil } @@ -576,11 +593,14 @@ func (st *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) queryID: sr.ID, nextURI: sr.NextURI, rowsAffected: sr.UpdateCount, + statsCh: st.statsCh, + doneCh: st.doneCh, } // consume all results, if there are any for err == nil { err = rows.fetch(true) } + if err != nil && err != io.EOF { return nil, err } @@ -595,6 +615,13 @@ func (st *driverStmt) CheckNamedValue(arg *driver.NamedValue) error { if reflect.TypeOf(arg.Value).Kind() == reflect.Slice { return nil } + + if arg.Name == trinoProgressCallbackParam { + return nil + } + if arg.Name == trinoProgressCallbackPeriodParam { + return nil + } return driver.ErrSkip } @@ -609,19 +636,20 @@ type stmtResponse struct { } type stmtStats struct { - State string `json:"state"` - Scheduled bool `json:"scheduled"` - Nodes int `json:"nodes"` - TotalSplits int `json:"totalSplits"` - QueuesSplits int `json:"queuedSplits"` - RunningSplits int `json:"runningSplits"` - CompletedSplits int `json:"completedSplits"` - UserTimeMillis int `json:"userTimeMillis"` - CPUTimeMillis int `json:"cpuTimeMillis"` - WallTimeMillis int `json:"wallTimeMillis"` - ProcessedRows int `json:"processedRows"` - ProcessedBytes int `json:"processedBytes"` - RootStage stmtStage `json:"rootStage"` + State string `json:"state"` + Scheduled bool `json:"scheduled"` + Nodes int `json:"nodes"` + TotalSplits int `json:"totalSplits"` + QueuesSplits int `json:"queuedSplits"` + RunningSplits int `json:"runningSplits"` + CompletedSplits int `json:"completedSplits"` + UserTimeMillis int `json:"userTimeMillis"` + CPUTimeMillis int `json:"cpuTimeMillis"` + WallTimeMillis int `json:"wallTimeMillis"` + ProcessedRows int `json:"processedRows"` + ProcessedBytes int `json:"processedBytes"` + RootStage stmtStage `json:"rootStage"` + ProgressPercentage float32 `json:"progressPercentage"` } type stmtError struct { @@ -678,6 +706,8 @@ func (st *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue stmt: st, queryID: sr.ID, nextURI: sr.NextURI, + statsCh: st.statsCh, + doneCh: st.doneCh, } if err = rows.fetch(false); err != nil { return nil, err @@ -693,6 +723,15 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt hs = make(http.Header) var ss []string for _, arg := range args { + if arg.Name == trinoProgressCallbackParam { + st.conn.progressUpdater = arg.Value.(ProgressUpdater) + continue + } + if arg.Name == trinoProgressCallbackPeriodParam { + st.conn.progressUpdaterPeriod.Period = arg.Value.(time.Duration) + continue + } + s, err := Serial(arg.Value) if err != nil { return nil, err @@ -716,6 +755,9 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt ss = append(ss, s) } } + if (st.conn.progressUpdater != nil && st.conn.progressUpdaterPeriod.Period == 0) || (st.conn.progressUpdater == nil && st.conn.progressUpdaterPeriod.Period > 0) { + return nil, ErrInvalidProgressCallbackHeader + } if len(ss) > 0 { query = "EXECUTE " + preparedStatementName + " USING " + strings.Join(ss, ", ") } @@ -739,6 +781,38 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt if err != nil { return nil, fmt.Errorf("trino: %v", err) } + + if st.conn.progressUpdater != nil { + st.statsCh = make(chan QueryProgressInfo) + st.doneCh = make(chan struct{}) + + // progress updater go func + go func() { + for { + select { + case stats := <-st.statsCh: + st.conn.progressUpdater.Update(stats) + case <-st.doneCh: + close(st.statsCh) + return + } + } + }() + + // initial progress callback call + srStats := QueryProgressInfo{ + QueryId: sr.ID, + QueryStats: sr.Stats, + } + select { + case st.statsCh <- srStats: + default: + // ignore when can't send stats + } + st.conn.progressUpdaterPeriod.LastCallbackTime = time.Now() + st.conn.progressUpdaterPeriod.LastQueryState = sr.Stats.State + } + return &sr, handleResponseError(resp.StatusCode, sr.Error) } @@ -754,6 +828,9 @@ type driverRows struct { coltype []*typeConverter data []queryData rowsAffected int64 + + statsCh chan QueryProgressInfo + doneCh chan struct{} } var _ driver.Rows = &driverRows{} @@ -976,7 +1053,9 @@ func (qr *driverRows) fetch(allowEOF bool) error { qr.rowindex = 0 qr.data = qresp.Data qr.nextURI = qresp.NextURI + if len(qr.data) == 0 { + qr.scheduleProgressUpdate(qresp.ID, qresp.Stats) if qr.nextURI != "" { return qr.fetch(allowEOF) } @@ -998,6 +1077,7 @@ func (qr *driverRows) fetch(allowEOF bool) error { } } qr.rowsAffected = qresp.UpdateCount + qr.scheduleProgressUpdate(qresp.ID, qresp.Stats) return nil } @@ -1043,6 +1123,33 @@ func (qr *driverRows) initColumns(qresp *queryResponse) error { return nil } +func (qr *driverRows) scheduleProgressUpdate(id string, stats stmtStats) { + if qr.stmt.conn.progressUpdater == nil { + return + } + + qrStats := QueryProgressInfo{ + QueryId: id, + QueryStats: stats, + } + currentTime := time.Now() + diff := currentTime.Sub(qr.stmt.conn.progressUpdaterPeriod.LastCallbackTime) + period := qr.stmt.conn.progressUpdaterPeriod.Period + + // Check if period has not passed yet AND if query state did not change + if diff < period && qr.stmt.conn.progressUpdaterPeriod.LastQueryState == qrStats.QueryStats.State { + return + } + + select { + case qr.statsCh <- qrStats: + default: + // ignore when can't send stats + } + qr.stmt.conn.progressUpdaterPeriod.LastCallbackTime = currentTime + qr.stmt.conn.progressUpdaterPeriod.LastQueryState = qrStats.QueryStats.State +} + type typeConverter struct { typeName string parsedType []string @@ -1943,3 +2050,19 @@ func (s *NullSlice3Map) Scan(value interface{}) error { s.Valid = true return nil } + +type QueryProgressInfo struct { + QueryId string + QueryStats stmtStats +} + +type queryProgressCallbackPeriod struct { + Period time.Duration + LastCallbackTime time.Time + LastQueryState string +} + +type ProgressUpdater interface { + // Update the query progress, immediately when the query starts, when receiving data, and once when the query is finished. + Update(QueryProgressInfo) +} diff --git a/trino/trino_test.go b/trino/trino_test.go index 5dcb9fe..7856da1 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -23,6 +23,7 @@ import ( "net/http" "net/http/httptest" "reflect" + "sort" "testing" "time" @@ -249,6 +250,101 @@ func TestQueryForUsername(t *testing.T) { } } +type TestQueryProgressCallback struct { + statusMap map[time.Time]string +} + +func (qpc *TestQueryProgressCallback) Update(qpi QueryProgressInfo) { + qpc.statusMap[time.Now()] = qpi.QueryStats.State +} + +func TestQueryProgressWithCallback(t *testing.T) { + if testing.Short() { + t.Skip("Skipping test in short mode.") + } + c := &Config{ + ServerURI: *integrationServerFlag, + SessionProperties: map[string]string{"query_priority": "1"}, + } + + dsn, err := c.FormatDSN() + require.NoError(t, err) + + db, err := sql.Open("trino", dsn) + require.NoError(t, err) + + t.Cleanup(func() { + assert.NoError(t, db.Close()) + }) + + callback := &TestQueryProgressCallback{} + + _, err = db.Query("SELECT 2", sql.Named("X-Trino-Progress-Callback", callback)) + assert.EqualError(t, err, ErrInvalidProgressCallbackHeader.Error(), "unexpected error") +} + +func TestQueryProgressWithCallbackPeriod(t *testing.T) { + if testing.Short() { + t.Skip("Skipping test in short mode.") + } + c := &Config{ + ServerURI: *integrationServerFlag, + SessionProperties: map[string]string{"query_priority": "1"}, + } + + dsn, err := c.FormatDSN() + require.NoError(t, err) + + db, err := sql.Open("trino", dsn) + require.NoError(t, err) + + t.Cleanup(func() { + assert.NoError(t, db.Close()) + }) + + statusMap := make(map[time.Time]string) + progressUpdater := &TestQueryProgressCallback{ + statusMap: statusMap, + } + progressUpdaterPeriod, err := time.ParseDuration("1ms") + + rows, err := db.Query("SELECT 2", + sql.Named("X-Trino-Progress-Callback", progressUpdater), + sql.Named("X-Trino-Progress-Callback-Period", progressUpdaterPeriod), + ) + require.NoError(t, err, "Failed executing query") + assert.NotNil(t, rows) + + for rows.Next() { + var ts string + require.NoError(t, rows.Scan(&ts), "Failed scanning query result") + + assert.Equal(t, "2", ts, "Expected value does not equal result value") + } + + if err = rows.Err(); err != nil { + t.Fatal(err) + } + if err = rows.Close(); err != nil { + t.Fatal(err) + } + + // sort time in order to calculate interval + var keys []time.Time + for k := range statusMap { + keys = append(keys, k) + } + sort.Slice(keys, func(i, j int) bool { + return keys[i].Before(keys[j]) + }) + + for i, k := range keys { + if i > 0 { + assert.GreaterOrEqual(t, k.Sub(keys[i-1]), progressUpdaterPeriod) + } + } +} + func TestQueryColumns(t *testing.T) { c := &Config{ ServerURI: *integrationServerFlag,