diff --git a/app.go b/app.go index aec94b3a6..804061e43 100644 --- a/app.go +++ b/app.go @@ -75,7 +75,6 @@ func New(opts Options) *App { } a.Use(a.PanicHandler) a.Use(RequestLogger) - a.Use(sessionSaver) return a } diff --git a/default_context.go b/default_context.go index e648948a0..0aad92404 100644 --- a/default_context.go +++ b/default_context.go @@ -139,6 +139,9 @@ func (d *DefaultContext) Render(status int, rr render.Renderer) error { if d.Session() != nil { d.Flash().Clear() d.Flash().persist(d.Session()) + if err := d.Session().Save(); err != nil { + return HTTPError{Status: http.StatusInternalServerError, Cause: err} + } } d.Response().Header().Set("Content-Type", rr.ContentType()) @@ -191,7 +194,13 @@ var mapType = reflect.ValueOf(map[string]interface{}{}).Type() // Redirect a request with the given status to the given URL. func (d *DefaultContext) Redirect(status int, url string, args ...interface{}) error { - d.Flash().persist(d.Session()) + if d.Session() != nil { + d.Flash().Clear() + d.Flash().persist(d.Session()) + if err := d.Session().Save(); err != nil { + return HTTPError{Status: http.StatusInternalServerError, Cause: err} + } + } if strings.HasSuffix(url, "Path()") { if len(args) > 1 { diff --git a/flash.go b/flash.go index 7e5a773c9..ecf08d176 100644 --- a/flash.go +++ b/flash.go @@ -39,7 +39,6 @@ func (f Flash) Add(key, value string) { func (f Flash) persist(session *Session) { b, _ := json.Marshal(f.data) session.Set(flashKey, b) - session.Save() } //newFlash creates a new Flash and loads the session data inside its data. diff --git a/request_logger.go b/request_logger.go index 0ceaf786e..17f2be188 100644 --- a/request_logger.go +++ b/request_logger.go @@ -42,7 +42,6 @@ func RequestLoggerFunc(h Handler) Handler { } irid = rs c.Session().Set("requestor_id", irid) - c.Session().Save() } rid := irid.(string) + "-" + rs diff --git a/route_info.go b/route_info.go index 8bb849ce0..b5e20d512 100644 --- a/route_info.go +++ b/route_info.go @@ -103,7 +103,6 @@ func (ri RouteInfo) ServeHTTP(res http.ResponseWriter, req *http.Request) { events.EmitPayload(EvtRouteStarted, payload) err := a.Middleware.handler(ri)(c) - c.Flash().persist(c.Session()) if err != nil { status := http.StatusInternalServerError diff --git a/session.go b/session.go index 9bfec880a..f49ea60bf 100644 --- a/session.go +++ b/session.go @@ -63,13 +63,3 @@ func (a *App) getSession(r *http.Request, w http.ResponseWriter) *Session { res: w, } } - -func sessionSaver(next Handler) Handler { - return func(c Context) error { - err := next(c) - if err != nil { - return err - } - return c.Session().Save() - } -} diff --git a/session_test.go b/session_test.go new file mode 100644 index 000000000..b69d241a4 --- /dev/null +++ b/session_test.go @@ -0,0 +1,64 @@ +package buffalo + +import ( + "fmt" + "net/http" + "strings" + "testing" + + "github.com/gobuffalo/buffalo/render" + "github.com/gobuffalo/httptest" + + "github.com/stretchr/testify/require" +) + +func Test_Session_SingleCookie(t *testing.T) { + r := require.New(t) + + sessionName := "_test_session" + a := New(Options{SessionName: sessionName}) + rr := render.New(render.Options{}) + + a.GET("/", func(c Context) error { + return c.Render(http.StatusCreated, rr.String("")) + }) + + w := httptest.New(a) + res := w.HTML("/").Get() + + var sessionCookies []string + for _, c := range res.Header().Values("Set-Cookie") { + if strings.HasPrefix(c, sessionName) { + sessionCookies = append(sessionCookies, c) + } + } + + r.Equal(1, len(sessionCookies)) +} + +func Test_Session_CustomValue(t *testing.T) { + r := require.New(t) + + a := New(Options{}) + rr := render.New(render.Options{}) + + // Root path sets a custom session value + a.GET("/", func(c Context) error { + c.Session().Set("example", "test") + return c.Render(http.StatusCreated, rr.String("")) + }) + // /session path prints custom session value as response + a.GET("/session", func(c Context) error { + sessionValue := c.Session().Get("example") + return c.Render(http.StatusCreated, rr.String(fmt.Sprintf("%s", sessionValue))) + }) + + w := httptest.New(a) + _ = w.HTML("/").Get() + + // Create second request that should contain the cookie from the first response + reqGetSession := w.HTML("/session") + resGetSession := reqGetSession.Get() + + r.Equal(resGetSession.Body.String(), "test") +}