Skip to content

Commit

Permalink
perf(spanner): better error handling (#10734)
Browse files Browse the repository at this point in the history
- perf(spanner): avoid using fmt.Errorf unnecessarily
- perf(spanner): avoid duplicated errors.New in UnmarshalJSON
- fix(spanner): error strings should not be capitalized

Updates #9749
  • Loading branch information
egonelbre authored Sep 25, 2024
1 parent bbe7b9c commit c342f65
Show file tree
Hide file tree
Showing 12 changed files with 43 additions and 35 deletions.
2 changes: 1 addition & 1 deletion spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func validDatabaseName(db string) error {
func parseDatabaseName(db string) (project, instance, database string, err error) {
matches := validDBPattern.FindStringSubmatch(db)
if len(matches) == 0 {
return "", "", "", fmt.Errorf("Failed to parse database name from %q according to pattern %q",
return "", "", "", fmt.Errorf("failed to parse database name from %q according to pattern %q",
db, validDBPattern.String())
}
return matches[1], matches[2], matches[3], nil
Expand Down
4 changes: 2 additions & 2 deletions spanner/client_benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package spanner

import (
"context"
"fmt"
"errors"
"math/rand"
"reflect"
"sync"
Expand Down Expand Up @@ -87,7 +87,7 @@ func createBenchmarkServer(incStep uint64) (server *MockedSpannerInMemTestServer
if uint64(client.idleSessions.idleList.Len()) == client.idleSessions.MinOpened {
return nil
}
return fmt.Errorf("not yet initialized")
return errors.New("not yet initialized")
})
return
}
Expand Down
2 changes: 1 addition & 1 deletion spanner/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1930,7 +1930,7 @@ func TestClient_ReadWriteTransaction_BufferedWriteBeforeSqlStatementWithError(t
// We ignore the error and proceed to commit the transaction.
_, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo))
if err == nil {
return fmt.Errorf("missing expected InvalidArgument error")
return errors.New("missing expected InvalidArgument error")
}
return nil
})
Expand Down
2 changes: 1 addition & 1 deletion spanner/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2840,7 +2840,7 @@ func TestIntegration_StructTypes(t *testing.T) {
return fmt.Errorf("len(rows) = %d; want 1", len(rows))
}
if !rows[0].Valid {
return fmt.Errorf("rows[0] is NULL")
return errors.New("rows[0] is NULL")
}
var i, j int64
if err := rows[0].Row.Columns(&i, &j); err != nil {
Expand Down
4 changes: 2 additions & 2 deletions spanner/row.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,10 +418,10 @@ func (r *Row) ToStructLenient(p interface{}) error {
// err := spanner.SelectAll(row, &singersByPtr, spanner.WithLenient())
func SelectAll(rows rowIterator, destination interface{}, options ...DecodeOptions) error {
if rows == nil {
return fmt.Errorf("rows is nil")
return errors.New("rows is nil")
}
if destination == nil {
return fmt.Errorf("destination is nil")
return errors.New("destination is nil")
}
dstVal := reflect.ValueOf(destination)
if !dstVal.IsValid() || (dstVal.Kind() == reflect.Ptr && dstVal.IsNil()) {
Expand Down
7 changes: 4 additions & 3 deletions spanner/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"container/heap"
"context"
"errors"
"fmt"
"math/rand"
"reflect"
Expand Down Expand Up @@ -279,7 +280,7 @@ func TestTakeFromIdleListChecked(t *testing.T) {
numOpened := uint64(sp.idleList.Len())
sp.mu.Unlock()
if numOpened < sp.SessionPoolConfig.incStep-1 {
return fmt.Errorf("creation not yet finished")
return errors.New("creation not yet finished")
}
return nil
})
Expand Down Expand Up @@ -1900,7 +1901,7 @@ func TestMaintainer_DeletesSessions(t *testing.T) {
sp.mu.Lock()
defer sp.mu.Unlock()
if sp.numOpened > 0 {
return fmt.Errorf("session pool still contains more than 0 sessions")
return errors.New("session pool still contains more than 0 sessions")
}
return nil
})
Expand Down Expand Up @@ -2023,7 +2024,7 @@ func TestSessionCreationIsDistributedOverChannels(t *testing.T) {
numOpened := uint64(sp.idleList.Len())
sp.mu.Unlock()
if numOpened < spc.MinOpened {
return fmt.Errorf("not yet initialized")
return errors.New("not yet initialized")
}
return nil
})
Expand Down
5 changes: 3 additions & 2 deletions spanner/sessionclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package spanner

import (
"context"
"errors"
"fmt"
"sync"
"testing"
Expand Down Expand Up @@ -260,7 +261,7 @@ func TestBatchCreateAndCloseSession(t *testing.T) {
client.idleSessions.mu.Lock()
defer client.idleSessions.mu.Unlock()
if client.idleSessions.multiplexedSession == nil {
return fmt.Errorf("multiplexed session not created yet")
return errors.New("multiplexed session not created yet")
}
return nil
})
Expand Down Expand Up @@ -475,7 +476,7 @@ func TestBatchCreateSessions_ServerExhausted(t *testing.T) {
if isMultiplexEnabled {
waitFor(t, func() error {
if client.idleSessions.multiplexedSession == nil {
return fmt.Errorf("multiplexed session not created yet")
return errors.New("multiplexed session not created yet")
}
return nil
})
Expand Down
7 changes: 4 additions & 3 deletions spanner/spannertest/db_eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package spannertest

import (
"bytes"
"errors"
"fmt"
"regexp"
"strconv"
Expand Down Expand Up @@ -284,7 +285,7 @@ func (ec evalContext) evalArithOp(e spansql.ArithOp) (interface{}, error) {
}
if rhs == 0 {
// TODO: Does real Spanner use a specific error code here?
return nil, fmt.Errorf("divide by zero")
return nil, errors.New("divide by zero")
}
return lhs / rhs, nil
case spansql.Add, spansql.Sub, spansql.Mul:
Expand Down Expand Up @@ -714,7 +715,7 @@ func (ec evalContext) evalExtractExpr(expr spansql.ExtractExpr) (result interfac
return int64(v.Year), nil
}
}
return nil, fmt.Errorf("Extract with part %v not supported", expr.Part)
return nil, fmt.Errorf("extract with part %v not supported", expr.Part)
}

func (ec evalContext) evalAtTimeZoneExpr(expr spansql.AtTimeZoneExpr) (result interface{}, err error) {
Expand Down Expand Up @@ -916,7 +917,7 @@ func (ec evalContext) colInfo(e spansql.Expr) (colInfo, error) {
return colInfo{}, err
}
if ci.Type.Array {
return colInfo{}, fmt.Errorf("can't nest array literals")
return colInfo{}, errors.New("can't nest array literals")
}
ci.Type.Array = true
return ci, nil
Expand Down
7 changes: 4 additions & 3 deletions spanner/spannertest/db_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package spannertest

import (
"errors"
"fmt"
"io"
"sort"
Expand Down Expand Up @@ -445,7 +446,7 @@ func (d *database) evalSelect(sel spansql.Select, qc *queryContext) (si *selIter
// First stage is to identify the data source.
// If there's a FROM then that names a table to use.
if len(sel.From) > 1 {
return nil, fmt.Errorf("selecting with more than one FROM clause not yet supported")
return nil, errors.New("selecting with more than one FROM clause not yet supported")
}
if len(sel.From) == 1 {
var err error
Expand Down Expand Up @@ -751,11 +752,11 @@ func (d *database) evalSelectFrom(qc *queryContext, ec evalContext, sf spansql.S

func newJoinIter(lhs, rhs *rawIter, lhsEC, rhsEC evalContext, sfj spansql.SelectFromJoin) (*joinIter, evalContext, error) {
if sfj.On != nil && len(sfj.Using) > 0 {
return nil, evalContext{}, fmt.Errorf("JOIN may not have both ON and USING clauses")
return nil, evalContext{}, errors.New("JOIN may not have both ON and USING clauses")
}
if sfj.On == nil && len(sfj.Using) == 0 && sfj.Type != spansql.CrossJoin {
// TODO: This isn't correct for joining against a non-table.
return nil, evalContext{}, fmt.Errorf("non-CROSS JOIN must have ON or USING clause")
return nil, evalContext{}, errors.New("non-CROSS JOIN must have ON or USING clause")
}

// Start with the context from the LHS (aliases and params should be the same on both sides).
Expand Down
5 changes: 3 additions & 2 deletions spanner/spannertest/inmem.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ package spannertest
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -629,10 +630,10 @@ func (s *server) StreamingRead(req *spannerpb.ReadRequest, stream spannerpb.Span
}
if len(req.ResumeToken) > 0 {
// This should only happen if we send resume_token ourselves.
return fmt.Errorf("read resumption not supported")
return errors.New("read resumption not supported")
}
if len(req.PartitionToken) > 0 {
return fmt.Errorf("partition restrictions not supported")
return errors.New("partition restrictions not supported")
}

var ri rowIter
Expand Down
30 changes: 16 additions & 14 deletions spanner/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ var (

protoMsgReflectType = reflect.TypeOf((*proto.Message)(nil)).Elem()
protoEnumReflectType = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem()

errPayloadNil = errors.New("payload should not be nil")
)

// UseNumberWithJSONDecoderEncoder specifies whether Cloud Spanner JSON numbers are decoded
Expand Down Expand Up @@ -222,7 +224,7 @@ func (n NullInt64) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullInt64.
func (n *NullInt64) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.Int64 = int64(0)
Expand Down Expand Up @@ -302,7 +304,7 @@ func (n NullString) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullString.
func (n *NullString) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.StringVal = ""
Expand Down Expand Up @@ -387,7 +389,7 @@ func (n NullFloat64) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullFloat64.
func (n *NullFloat64) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.Float64 = float64(0)
Expand Down Expand Up @@ -467,7 +469,7 @@ func (n NullFloat32) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullFloat32.
func (n *NullFloat32) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.Float32 = float32(0)
Expand Down Expand Up @@ -547,7 +549,7 @@ func (n NullBool) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullBool.
func (n *NullBool) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.Bool = false
Expand Down Expand Up @@ -627,7 +629,7 @@ func (n NullTime) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullTime.
func (n *NullTime) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.Time = time.Time{}
Expand Down Expand Up @@ -712,7 +714,7 @@ func (n NullDate) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullDate.
func (n *NullDate) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.Date = civil.Date{}
Expand Down Expand Up @@ -797,7 +799,7 @@ func (n NullNumeric) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullNumeric.
func (n *NullNumeric) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.Numeric = big.Rat{}
Expand Down Expand Up @@ -894,7 +896,7 @@ func (n NullJSON) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullJSON.
func (n *NullJSON) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.Valid = false
Expand Down Expand Up @@ -942,7 +944,7 @@ func (n PGNumeric) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for PGNumeric.
func (n *PGNumeric) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.Numeric = ""
Expand Down Expand Up @@ -989,7 +991,7 @@ func (n NullProtoMessage) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullProtoMessage.
func (n *NullProtoMessage) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.ProtoMessageVal = nil
Expand Down Expand Up @@ -1035,7 +1037,7 @@ func (n NullProtoEnum) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullProtoEnum.
func (n *NullProtoEnum) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.ProtoEnumVal = nil
Expand Down Expand Up @@ -1096,7 +1098,7 @@ func (n PGJsonB) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for PGJsonB.
func (n *PGJsonB) UnmarshalJSON(payload []byte) error {
if payload == nil {
return fmt.Errorf("payload should not be nil")
return errPayloadNil
}
if jsonIsNull(payload) {
n.Valid = false
Expand Down Expand Up @@ -4541,7 +4543,7 @@ func convertCustomTypeValue(sourceType decodableSpannerType, v interface{}) (int
var destination reflect.Value
switch sourceType {
case spannerTypeInvalid:
return nil, fmt.Errorf("cannot encode a value to type spannerTypeInvalid")
return nil, errors.New("cannot encode a value to type spannerTypeInvalid")
case spannerTypeNonNullString:
destination = reflect.Indirect(reflect.New(reflect.TypeOf("")))
case spannerTypeNullString:
Expand Down
3 changes: 2 additions & 1 deletion spanner/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package spanner
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"math"
"math/big"
Expand Down Expand Up @@ -211,7 +212,7 @@ func (c *customArray) DecodeSpanner(val interface{}) error {
}
asSlice := listVal.AsSlice()
if len(asSlice) != 4 {
return fmt.Errorf("failed to decode customArray: expected array of length 4")
return errors.New("failed to decode customArray: expected array of length 4")
}
for i, vI := range asSlice {
vStr, ok := vI.(string)
Expand Down

0 comments on commit c342f65

Please sign in to comment.