diff --git a/errors.go b/errors.go index 4bfb667d6..f98595419 100644 --- a/errors.go +++ b/errors.go @@ -112,14 +112,12 @@ func (a *App) defaultErrorMiddleware(next Handler) Handler { } status := http.StatusInternalServerError // unpack root err and check for HTTPError - switch { - case errors.Is(err, sql.ErrNoRows): + if errors.Is(err, sql.ErrNoRows) { status = http.StatusNotFound - default: - var h HTTPError - if errors.As(err, &h) { - status = h.Status - } + } + var h HTTPError + if errors.As(err, &h) { + status = h.Status } payload := events.Payload{ "context": c, diff --git a/route_info_test.go b/route_info_test.go index 4475c3840..b24039cd1 100644 --- a/route_info_test.go +++ b/route_info_test.go @@ -2,6 +2,7 @@ package buffalo import ( "database/sql" + "fmt" "net/http" "testing" @@ -22,6 +23,18 @@ func Test_RouteInfo_ServeHTTP_SQL_Error(t *testing.T) { return sql.ErrNoRows }) + app.GET("/bad-2", func(c Context) error { + return sql.ErrTxDone + }) + + app.GET("/gone-unwrap", func(c Context) error { + return c.Error(http.StatusGone, sql.ErrTxDone) + }) + + app.GET("/gone-wrap", func(c Context) error { + return c.Error(http.StatusGone, fmt.Errorf("some error wrapping here: %w", sql.ErrNoRows)) + }) + w := httptest.New(app) res := w.HTML("/good").Get() @@ -29,4 +42,10 @@ func Test_RouteInfo_ServeHTTP_SQL_Error(t *testing.T) { res = w.HTML("/bad").Get() r.Equal(http.StatusNotFound, res.Code) + + res = w.HTML("/gone-wrap").Get() + r.Equal(http.StatusGone, res.Code) + + res = w.HTML("/gone-unwrap").Get() + r.Equal(http.StatusGone, res.Code) }