Skip to content

Commit

Permalink
fix: Support string type for the script field in Route (#1289)
Browse files Browse the repository at this point in the history
* Support string type for script field in Route

Signed-off-by: imjoey <[email protected]>

* Add validating lua code when create/update routes

also improve the test case in unittest and e2e

Signed-off-by: imjoey <[email protected]>

* typo fix and style format

Signed-off-by: imjoey <[email protected]>

* Improve testcases

Signed-off-by: imjoey <[email protected]>

* Addtional check the Script via log in APISIX

Signed-off-by: imjoey <[email protected]>

* ngx.log print log in error.log, instead of access.log

Signed-off-by: imjoey <[email protected]>

* Use ngx.WARN instead of INFO to enable output

Signed-off-by: imjoey <[email protected]>
  • Loading branch information
imjoey authored Jan 15, 2021
1 parent 5f22326 commit b9a0227
Show file tree
Hide file tree
Showing 6 changed files with 389 additions and 16 deletions.
50 changes: 36 additions & 14 deletions api/internal/handler/route/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ import (
"github.com/shiningrush/droplet"
"github.com/shiningrush/droplet/data"
"github.com/shiningrush/droplet/wrapper"
"github.com/yuin/gopher-lua"
wgin "github.com/shiningrush/droplet/wrapper/gin"
lua "github.com/yuin/gopher-lua"

"github.com/apisix/manager-api/internal/conf"
"github.com/apisix/manager-api/internal/core/entity"
"github.com/apisix/manager-api/internal/core/store"
Expand Down Expand Up @@ -327,12 +328,25 @@ func (h *Handler) Create(c droplet.Context) (interface{}, error) {
script := &entity.Script{}
script.ID = utils.InterfaceToString(input.ID)
script.Script = input.Script
//to lua

var err error
input.Script, err = generateLuaCode(input.Script.(map[string]interface{}))
if err != nil {
return nil, err
// Explicitly to lua if input script is of the map type, otherwise
// it will always represent a piece of lua code of the string type.
if scriptConf, ok := input.Script.(map[string]interface{}); ok {
// For lua code of map type, syntax validation is done by
// the generateLuaCode function
input.Script, err = generateLuaCode(scriptConf)
if err != nil {
return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest}, err
}
} else {
// For lua code of string type, use utility func to syntax validation
err = utils.ValidateLuaCode(input.Script.(string))
if err != nil {
return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest}, err
}
}

//save original conf
if err = h.scriptStore.Create(c.Context(), script); err != nil {
return nil, err
Expand Down Expand Up @@ -392,17 +406,25 @@ func (h *Handler) Update(c droplet.Context) (interface{}, error) {
script := &entity.Script{}
script.ID = input.ID
script.Script = input.Script
//to lua

var err error
scriptConf, ok := input.Script.(map[string]interface{})
if !ok {
return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest},
fmt.Errorf("invalid `script`")
}
input.Route.Script, err = generateLuaCode(scriptConf)
if err != nil {
return &data.SpecCodeResponse{StatusCode: http.StatusInternalServerError}, err
// Explicitly to lua if input script is of the map type, otherwise
// it will always represent a piece of lua code of the string type.
if scriptConf, ok := input.Script.(map[string]interface{}); ok {
// For lua code of map type, syntax validation is done by
// the generateLuaCode function
input.Route.Script, err = generateLuaCode(scriptConf)
if err != nil {
return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest}, err
}
} else {
// For lua code of string type, use utility func to syntax validation
err = utils.ValidateLuaCode(input.Script.(string))
if err != nil {
return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest}, err
}
}

//save original conf
if err = h.scriptStore.Update(c.Context(), script, true); err != nil {
//if not exists, create
Expand Down
149 changes: 147 additions & 2 deletions api/internal/handler/route/route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ func TestRoute(t *testing.T) {
dataPage = retPage.(*store.ListOutput)
assert.Equal(t, len(dataPage.Rows), 1)

//sleep
//sleep
time.Sleep(time.Duration(100) * time.Millisecond)

// list search and status not match
Expand Down Expand Up @@ -1197,7 +1197,7 @@ func TestRoute(t *testing.T) {
assert.Nil(t, err)
}

func Test_Route_With_Script(t *testing.T) {
func Test_Route_With_Script_Dag2lua(t *testing.T) {
// init
err := storage.InitETCDClient(conf.ETCDConfig)
assert.Nil(t, err)
Expand Down Expand Up @@ -1349,3 +1349,148 @@ func Test_Route_With_Script(t *testing.T) {
_, err = handler.BatchDelete(ctx)
assert.Nil(t, err)
}

func Test_Route_With_Script_Luacode(t *testing.T) {
// init
err := storage.InitETCDClient(conf.ETCDConfig)
assert.Nil(t, err)
err = store.InitStores()
assert.Nil(t, err)

handler := &Handler{
routeStore: store.GetStore(store.HubKeyRoute),
svcStore: store.GetStore(store.HubKeyService),
upstreamStore: store.GetStore(store.HubKeyUpstream),
scriptStore: store.GetStore(store.HubKeyScript),
}
assert.NotNil(t, handler)

// create with script of valid lua syntax
ctx := droplet.NewContext()
route := &entity.Route{}
reqBody := `{
"id": "1",
"uri": "/index.html",
"upstream": {
"type": "roundrobin",
"nodes": [{
"host": "www.a.com",
"port": 80,
"weight": 1
}]
},
"script": "local _M = {} \n function _M.access(api_ctx) \n ngx.log(ngx.WARN,\"hit access phase\") \n end \nreturn _M"
}`
err = json.Unmarshal([]byte(reqBody), route)
assert.Nil(t, err)
ctx.SetInput(route)
_, err = handler.Create(ctx)
assert.Nil(t, err)

// sleep
time.Sleep(time.Duration(20) * time.Millisecond)

// get
input := &GetInput{}
input.ID = "1"
ctx.SetInput(input)
ret, err := handler.Get(ctx)
stored := ret.(*entity.Route)
assert.Nil(t, err)
assert.Equal(t, stored.ID, route.ID)
assert.Equal(t, "local _M = {} \n function _M.access(api_ctx) \n ngx.log(ngx.WARN,\"hit access phase\") \n end \nreturn _M", stored.Script)

// update via empty script
route2 := &UpdateInput{}
route2.ID = "1"
reqBody = `{
"id": "1",
"uri": "/index.html",
"enable_websocket": true,
"upstream": {
"type": "roundrobin",
"nodes": [{
"host": "www.a.com",
"port": 80,
"weight": 1
}]
}
}`

err = json.Unmarshal([]byte(reqBody), route2)
assert.Nil(t, err)
ctx.SetInput(route2)
_, err = handler.Update(ctx)
assert.Nil(t, err)

//sleep
time.Sleep(time.Duration(100) * time.Millisecond)

//get, script should be nil
input = &GetInput{}
input.ID = "1"
ctx.SetInput(input)
ret, err = handler.Get(ctx)
stored = ret.(*entity.Route)
assert.Nil(t, err)
assert.Equal(t, stored.ID, route.ID)
assert.Nil(t, stored.Script)

// 2nd update via invalid script
input3 := &UpdateInput{}
input3.ID = "1"
reqBody = `{
"id": "1",
"uri": "/index.html",
"enable_websocket": true,
"upstream": {
"type": "roundrobin",
"nodes": [{
"host": "www.a.com",
"port": 80,
"weight": 1
}]
},
"script": "local _M = {} \n function _M.access(api_ctx) \n ngx.log(ngx.WARN,\"hit access phase\")"
}`

err = json.Unmarshal([]byte(reqBody), input3)
assert.Nil(t, err)
ctx.SetInput(input3)
_, err = handler.Update(ctx)
// err should NOT be nil
assert.NotNil(t, err)

// delete test data
inputDel := &BatchDelete{}
reqBody = `{"ids": "1"}`
err = json.Unmarshal([]byte(reqBody), inputDel)
assert.Nil(t, err)
ctx.SetInput(inputDel)
_, err = handler.BatchDelete(ctx)
assert.Nil(t, err)

// 2nd create with script of invalid lua syntax
ctx = droplet.NewContext()
route = &entity.Route{}
reqBody = `{
"id": "1",
"uri": "/index.html",
"upstream": {
"type": "roundrobin",
"nodes": [{
"host": "www.a.com",
"port": 80,
"weight": 1
}]
},
"script": "local _M = {} \n function _M.access(api_ctx) \n ngx.log(ngx.WARN,\"hit access phase\")"
}`
err = json.Unmarshal([]byte(reqBody), route)
assert.Nil(t, err)
ctx.SetInput(route)
ret, err = handler.Create(ctx)
assert.NotNil(t, err)
assert.EqualError(t, err, "<string> at EOF: syntax error\n")
assert.Equal(t, http.StatusBadRequest, ret.(*data.SpecCodeResponse).StatusCode)
}
8 changes: 8 additions & 0 deletions api/internal/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"strings"

"github.com/sony/sonyflake"
"github.com/yuin/gopher-lua/parse"
)

var _sf *sonyflake.Sonyflake
Expand Down Expand Up @@ -161,3 +162,10 @@ func LabelContains(labels map[string]string, reqLabels map[string]struct{}) bool

return false
}

// ValidateLuaCode validates lua syntax for input code, return nil
// if passed, otherwise a non-nil error will be returned
func ValidateLuaCode(code string) error {
_, err := parse.Parse(strings.NewReader(code), "<string>")
return err
}
11 changes: 11 additions & 0 deletions api/internal/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,14 @@ func TestLabelContains(t *testing.T) {
}
assert.True(t, LabelContains(mp, reqMap))
}

func TestValidateLuaCode(t *testing.T) {
validLuaCode := "local _M = {} \n function _M.access(api_ctx) \n ngx.log(ngx.WARN,\"hit access phase\") \n end \nreturn _M"
err := ValidateLuaCode(validLuaCode)
assert.Nil(t, err)

invalidLuaCode := "local _M = {} \n function _M.access(api_ctx) \n ngx.log(ngx.WARN,\"hit access phase\")"
err = ValidateLuaCode(invalidLuaCode)
assert.NotNil(t, err)
assert.Equal(t, "<string> at EOF: syntax error\n", err.Error())
}
31 changes: 31 additions & 0 deletions api/test/e2e/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,34 @@ func CleanAPISIXErrorLog(t *testing.T) {
}
assert.Nil(t, err)
}

// ReadAPISIXAccessLog reads the access log of APISIX.
func ReadAPISIXAccessLog(t *testing.T) string {
cmd := exec.Command("pwd")
pwdByte, err := cmd.CombinedOutput()
pwd := string(pwdByte)
pwd = strings.Replace(pwd, "\n", "", 1)
pwd = pwd[:strings.Index(pwd, "/e2e")]
bytes, err := ioutil.ReadFile(pwd + "/docker/apisix_logs/access.log")
assert.Nil(t, err)
logContent := string(bytes)

return logContent
}

// CleanAPISIXAccessLog cleans the access log of APISIX.
// It's always recommended to call this function before checking
// its content.
func CleanAPISIXAccessLog(t *testing.T) {
cmd := exec.Command("pwd")
pwdByte, err := cmd.CombinedOutput()
pwd := string(pwdByte)
pwd = strings.Replace(pwd, "\n", "", 1)
pwd = pwd[:strings.Index(pwd, "/e2e")]
cmd = exec.Command("sudo", "echo", " > ", pwd+"/docker/apisix_logs/access.log")
_, err = cmd.CombinedOutput()
if err != nil {
fmt.Println("cmd error:", err.Error())
}
assert.Nil(t, err)
}
Loading

0 comments on commit b9a0227

Please sign in to comment.