From ddbd37061cafde67a35d5eee5469418af5d567a6 Mon Sep 17 00:00:00 2001 From: Yonghwan SO Date: Thu, 21 Apr 2022 23:48:37 +0900 Subject: [PATCH] hardening DefaultContext to make it panic-free --- default_context.go | 16 +++++++- default_context_test.go | 81 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/default_context.go b/default_context.go index 3f2eec99d..e88acbf4b 100644 --- a/default_context.go +++ b/default_context.go @@ -23,6 +23,9 @@ import ( var _ Context = &DefaultContext{} var _ context.Context = &DefaultContext{} +// TODO(sio4): #road-to-v1 - make DefaultContext private +// and only allow it to be generated by App.newContext() or any similar. + // DefaultContext is, as its name implies, a default // implementation of the Context interface. type DefaultContext struct { @@ -67,16 +70,22 @@ func (d *DefaultContext) Param(key string) string { // Set a value onto the Context. Any value set onto the Context // will be automatically available in templates. func (d *DefaultContext) Set(key string, value interface{}) { + if d.data == nil { + d.data = &sync.Map{} + } d.data.Store(key, value) } // Value that has previously stored on the context. func (d *DefaultContext) Value(key interface{}) interface{} { - if k, ok := key.(string); ok { + if k, ok := key.(string); ok && d.data != nil { if v, ok := d.data.Load(k); ok { return v } } + if d.Context == nil { + return nil + } return d.Context.Value(key) } @@ -235,6 +244,11 @@ func (d *DefaultContext) Redirect(status int, url string, args ...interface{}) e // Data contains all the values set through Get/Set. func (d *DefaultContext) Data() map[string]interface{} { m := map[string]interface{}{} + + if d.data == nil { + return m + } + d.data.Range(func(k, v interface{}) bool { s, ok := k.(string) if !ok { diff --git a/default_context_test.go b/default_context_test.go index 2823828cc..203781641 100644 --- a/default_context_test.go +++ b/default_context_test.go @@ -143,6 +143,15 @@ func Test_DefaultContext_GetSet(t *testing.T) { r.Equal("Mark", c.Value("name").(string)) } +func Test_DefaultContext_Set_Unconfigured(t *testing.T) { + r := require.New(t) + c := DefaultContext{} + + c.Set("name", "Yonghwan") + r.NotNil(c.Value("name")) + r.Equal("Yonghwan", c.Value("name").(string)) +} + func Test_DefaultContext_Value(t *testing.T) { r := require.New(t) c := basicContext() @@ -151,7 +160,12 @@ func Test_DefaultContext_Value(t *testing.T) { c.Set("name", "Mark") r.NotNil(c.Value("name")) r.Equal("Mark", c.Value("name").(string)) - r.Equal("Mark", c.Value("name").(string)) +} + +func Test_DefaultContext_Value_Unconfigured(t *testing.T) { + r := require.New(t) + c := DefaultContext{} + r.Nil(c.Value("name")) } func Test_DefaultContext_Render(t *testing.T) { @@ -301,3 +315,68 @@ func Test_DefaultContext_Bind_JSON(t *testing.T) { r.Equal("Mark", user.FirstName) } + +func Test_DefaultContext_Data(t *testing.T) { + r := require.New(t) + c := basicContext() + + r.EqualValues(map[string]interface{}{}, c.Data()) +} + +func Test_DefaultContext_Data_Unconfigured(t *testing.T) { + r := require.New(t) + c := DefaultContext{} + + r.EqualValues(map[string]interface{}{}, c.Data()) +} + +func Test_DefaultContext_String(t *testing.T) { + r := require.New(t) + c := basicContext() + c.Set("name", "Buffalo") + c.Set("language", "go") + + r.EqualValues("language: go\n\nname: Buffalo", c.String()) +} + +func Test_DefaultContext_String_EmptyData(t *testing.T) { + r := require.New(t) + c := basicContext() + r.EqualValues("", c.String()) +} + +func Test_DefaultContext_String_EmptyData_Unconfigured(t *testing.T) { + r := require.New(t) + c := DefaultContext{} + + r.EqualValues("", c.String()) +} + +func Test_DefaultContext_MarshalJSON(t *testing.T) { + r := require.New(t) + c := basicContext() + c.Set("name", "Buffalo") + c.Set("language", "go") + + jb, err := c.MarshalJSON() + r.NoError(err) + r.EqualValues(`{"language":"go","name":"Buffalo"}`, string(jb)) +} + +func Test_DefaultContext_MarshalJSON_EmptyData(t *testing.T) { + r := require.New(t) + c := basicContext() + + jb, err := c.MarshalJSON() + r.NoError(err) + r.EqualValues(`{}`, string(jb)) +} + +func Test_DefaultContext_MarshalJSON_EmptyData_Unconfigured(t *testing.T) { + r := require.New(t) + c := DefaultContext{} + + jb, err := c.MarshalJSON() + r.NoError(err) + r.EqualValues(`{}`, string(jb)) +}