diff --git a/trino/trino.go b/trino/trino.go index 1a5f02c..d8361c7 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -713,27 +713,27 @@ type stmtResponse struct { } type stmtStats struct { - State string `json:"state"` - Scheduled bool `json:"scheduled"` - Nodes int `json:"nodes"` - TotalSplits int `json:"totalSplits"` - QueuesSplits int `json:"queuedSplits"` - RunningSplits int `json:"runningSplits"` - CompletedSplits int `json:"completedSplits"` - UserTimeMillis int `json:"userTimeMillis"` - CPUTimeMillis int64 `json:"cpuTimeMillis"` - WallTimeMillis int64 `json:"wallTimeMillis"` - QueuedTimeMillis int64 `json:"queuedTimeMillis"` - ElapsedTimeMillis int64 `json:"elapsedTimeMillis"` - ProcessedRows int64 `json:"processedRows"` - ProcessedBytes int64 `json:"processedBytes"` - PhysicalInputBytes int64 `json:"physicalInputBytes"` - PhysicalWrittenBytes int64 `json:"physicalWrittenBytes"` - PeakMemoryBytes int64 `json:"peakMemoryBytes"` - SpilledBytes int64 `json:"spilledBytes"` - RootStage stmtStage `json:"rootStage"` - ProgressPercentage float32 `json:"progressPercentage"` - RunningPercentage float32 `json:"runningPercentage"` + State string `json:"state"` + Scheduled bool `json:"scheduled"` + Nodes int `json:"nodes"` + TotalSplits int `json:"totalSplits"` + QueuesSplits int `json:"queuedSplits"` + RunningSplits int `json:"runningSplits"` + CompletedSplits int `json:"completedSplits"` + UserTimeMillis int `json:"userTimeMillis"` + CPUTimeMillis int64 `json:"cpuTimeMillis"` + WallTimeMillis int64 `json:"wallTimeMillis"` + QueuedTimeMillis int64 `json:"queuedTimeMillis"` + ElapsedTimeMillis int64 `json:"elapsedTimeMillis"` + ProcessedRows int64 `json:"processedRows"` + ProcessedBytes int64 `json:"processedBytes"` + PhysicalInputBytes int64 `json:"physicalInputBytes"` + PhysicalWrittenBytes int64 `json:"physicalWrittenBytes"` + PeakMemoryBytes int64 `json:"peakMemoryBytes"` + SpilledBytes int64 `json:"spilledBytes"` + RootStage stmtStage `json:"rootStage"` + ProgressPercentage jsonFloat64 `json:"progressPercentage"` + RunningPercentage jsonFloat64 `json:"runningPercentage"` } type ErrTrino struct { @@ -792,6 +792,28 @@ type stmtStage struct { SubStages []stmtStage `json:"subStages"` } +type jsonFloat64 float64 + +func (f *jsonFloat64) UnmarshalJSON(data []byte) error { + var v float64 + err := json.Unmarshal(data, &v) + if err != nil { + var jsonErr *json.UnmarshalTypeError + if errors.As(err, &jsonErr) { + if f != nil { + *f = 0 + } + return nil + } + return err + } + p := (*float64)(f) + *p = v + return nil +} + +var _ json.Unmarshaler = new(jsonFloat64) + func (st *driverStmt) Query(args []driver.Value) (driver.Rows, error) { return nil, driver.ErrSkip } diff --git a/trino/trino_test.go b/trino/trino_test.go index 0f591c5..502a749 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -250,6 +250,34 @@ func TestRoundTripRetryQueryError(t *testing.T) { assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err) } +func TestRoundTripBogusData(t *testing.T) { + count := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if count == 0 { + count++ + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + // some invalid JSON + w.Write([]byte(`{"stats": {"progressPercentage": ""}}`)) + })) + + t.Cleanup(ts.Close) + + db, err := sql.Open("trino", ts.URL) + require.NoError(t, err) + + t.Cleanup(func() { + assert.NoError(t, db.Close()) + }) + + rows, err := db.Query("SELECT 1") + require.NoError(t, err) + assert.False(t, rows.Next()) + require.NoError(t, rows.Err()) +} + func TestRoundTripCancellation(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) @@ -336,10 +364,12 @@ func TestQueryForUsername(t *testing.T) { } type TestQueryProgressCallback struct { - statusMap map[time.Time]string + progressMap map[time.Time]float64 + statusMap map[time.Time]string } func (qpc *TestQueryProgressCallback) Update(qpi QueryProgressInfo) { + qpc.progressMap[time.Now()] = float64(qpi.QueryStats.ProgressPercentage) qpc.statusMap[time.Now()] = qpi.QueryStats.State } @@ -387,9 +417,11 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) { assert.NoError(t, db.Close()) }) + progressMap := make(map[time.Time]float64) statusMap := make(map[time.Time]string) progressUpdater := &TestQueryProgressCallback{ - statusMap: statusMap, + progressMap: progressMap, + statusMap: statusMap, } progressUpdaterPeriod, err := time.ParseDuration("1ms") require.NoError(t, err) @@ -416,6 +448,8 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) { } // sort time in order to calculate interval + assert.NotEmpty(t, progressMap) + assert.NotEmpty(t, statusMap) var keys []time.Time for k := range statusMap { keys = append(keys, k) @@ -428,6 +462,7 @@ func TestQueryProgressWithCallbackPeriod(t *testing.T) { if i > 0 { assert.GreaterOrEqual(t, k.Sub(keys[i-1]), progressUpdaterPeriod) } + assert.GreaterOrEqual(t, progressMap[k], 0.0) } }