From d453ee408fd8e3a7e7859e6ad1f940f5ebc6fd18 Mon Sep 17 00:00:00 2001 From: Mark Bates Date: Sat, 25 Aug 2018 09:41:22 -0400 Subject: [PATCH] uses a top level middleware to more easily catch application errors (#1248) * uses a top level middleware to more easily catch application errors * fixes broken tests * custom error handlers don't have access to context variables set in middleware fixes #1250 * improved the test a lot * fixes gofmt issues * uses a top level middleware to more easily catch application errors * fixes broken tests * custom error handlers don't have access to context variables set in middleware fixes #1250 * improved the test a lot * fixed fmt * fix fmt --- app.go | 26 ++++++++++++++++---------- errors.go | 30 ++++++++++++++++++++++++++++++ errors_test.go | 28 ++++++++++++++++++++++++++++ route_info.go | 25 ++++--------------------- router_test.go | 10 +++++----- 5 files changed, 83 insertions(+), 36 deletions(-) diff --git a/app.go b/app.go index 2f87dd12c..0366dc290 100644 --- a/app.go +++ b/app.go @@ -15,14 +15,15 @@ import ( type App struct { Options // Middleware returns the current MiddlewareStack for the App/Group. - Middleware *MiddlewareStack - ErrorHandlers ErrorHandlers - router *mux.Router - moot *sync.Mutex - routes RouteList - root *App - children []*App - filepaths []string + Middleware *MiddlewareStack + ErrorHandlers ErrorHandlers + ErrorMiddleware MiddlewareFunc + router *mux.Router + moot *sync.Mutex + routes RouteList + root *App + children []*App + filepaths []string } // New returns a new instance of App and adds some sane, and useful, defaults. @@ -31,8 +32,7 @@ func New(opts Options) *App { opts = optionsWithDefaults(opts) a := &App{ - Options: opts, - Middleware: newMiddlewareStack(), + Options: opts, ErrorHandlers: ErrorHandlers{ 404: defaultErrorHandler, 500: defaultErrorHandler, @@ -43,6 +43,12 @@ func New(opts Options) *App { children: []*App{}, } + dem := a.defaultErrorMiddleware + if a.ErrorMiddleware != nil { + dem = a.ErrorMiddleware + } + a.Middleware = newMiddlewareStack(dem) + notFoundHandler := func(errorf string, code int) http.HandlerFunc { return func(res http.ResponseWriter, req *http.Request) { c := a.newContext(RouteInfo{}, res, req) diff --git a/errors.go b/errors.go index a5efb85cc..3f97e6a3d 100644 --- a/errors.go +++ b/errors.go @@ -1,6 +1,7 @@ package buffalo import ( + "database/sql" "encoding/json" "fmt" "net/http" @@ -74,6 +75,35 @@ func (a *App) PanicHandler(next Handler) Handler { } } +func (a *App) defaultErrorMiddleware(next Handler) Handler { + return func(c Context) error { + err := next(c) + if err == nil { + return nil + } + status := 500 + // unpack root cause and check for HTTPError + cause := errors.Cause(err) + switch cause { + case sql.ErrNoRows: + status = 404 + default: + if h, ok := cause.(HTTPError); ok { + status = h.Status + } + } + eh := a.ErrorHandlers.Get(status) + err = eh(status, err, c) + if err != nil { + // things have really hit the fan if we're here!! + a.Logger.Error(err) + c.Response().WriteHeader(500) + c.Response().Write([]byte(err.Error())) + } + return nil + } +} + func productionErrorResponseFor(status int) []byte { if status == http.StatusNotFound { return []byte(prodNotFoundTmpl) diff --git a/errors_test.go b/errors_test.go index 658be47cb..cd9f9c07e 100644 --- a/errors_test.go +++ b/errors_test.go @@ -56,3 +56,31 @@ func Test_PanicHandler(t *testing.T) { }) } } + +func Test_defaultErrorMiddleware(t *testing.T) { + r := require.New(t) + app := New(Options{}) + var x string + var ok bool + app.ErrorHandlers[422] = func(code int, err error, c Context) error { + x, ok = c.Value("T").(string) + c.Response().WriteHeader(code) + c.Response().Write([]byte(err.Error())) + return nil + } + app.Use(func(next Handler) Handler { + return func(c Context) error { + c.Set("T", "t") + return c.Error(422, errors.New("boom")) + } + }) + app.GET("/", func(c Context) error { + return nil + }) + + w := httptest.New(app) + res := w.HTML("/").Get() + r.Equal(422, res.Code) + r.True(ok) + r.Equal("t", x) +} diff --git a/route_info.go b/route_info.go index e728724dc..e7e27cafe 100644 --- a/route_info.go +++ b/route_info.go @@ -1,11 +1,9 @@ package buffalo import ( - "database/sql" "net/http" gcontext "github.com/gorilla/context" - "github.com/pkg/errors" ) func (info RouteInfo) ServeHTTP(res http.ResponseWriter, req *http.Request) { @@ -19,24 +17,9 @@ func (info RouteInfo) ServeHTTP(res http.ResponseWriter, req *http.Request) { err := a.Middleware.handler(info)(c) if err != nil { - status := 500 - // unpack root cause and check for HTTPError - cause := errors.Cause(err) - switch cause { - case sql.ErrNoRows: - status = 404 - default: - if h, ok := cause.(HTTPError); ok { - status = h.Status - } - } - eh := a.ErrorHandlers.Get(status) - err = eh(status, err, c) - if err != nil { - // things have really hit the fan if we're here!! - a.Logger.Error(err) - c.Response().WriteHeader(500) - c.Response().Write([]byte(err.Error())) - } + // things have really hit the fan if we're here!! + a.Logger.Error(err) + c.Response().WriteHeader(500) + c.Response().Write([]byte(err.Error())) } } diff --git a/router_test.go b/router_test.go index f4b0d9484..fd11480bc 100644 --- a/router_test.go +++ b/router_test.go @@ -299,15 +299,15 @@ func Test_Router_Group_Middleware(t *testing.T) { a := testApp() a.Use(func(h Handler) Handler { return h }) - r.Len(a.Middleware.stack, 4) + r.Len(a.Middleware.stack, 5) g := a.Group("/api/v1") - r.Len(a.Middleware.stack, 4) - r.Len(g.Middleware.stack, 4) + r.Len(a.Middleware.stack, 5) + r.Len(g.Middleware.stack, 5) g.Use(func(h Handler) Handler { return h }) - r.Len(a.Middleware.stack, 4) - r.Len(g.Middleware.stack, 5) + r.Len(a.Middleware.stack, 5) + r.Len(g.Middleware.stack, 6) } func Test_Router_Redirect(t *testing.T) {