diff --git a/api/internal/filter/ip_filter.go b/api/internal/filter/ip_filter.go index de62cf5688..2d07dea552 100644 --- a/api/internal/filter/ip_filter.go +++ b/api/internal/filter/ip_filter.go @@ -19,6 +19,7 @@ package filter import ( "net" "net/http" + "strings" "github.com/gin-gonic/gin" @@ -81,7 +82,10 @@ func checkIP(ipStr string, ips map[string]bool, subnets []*subnet) bool { func IPFilter() gin.HandlerFunc { ips, subnets := generateIPSet(conf.AllowList) return func(c *gin.Context) { - ipStr := c.ClientIP() + var ipStr string + if ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)); err == nil { + ipStr = ip + } if len(conf.AllowList) < 1 { c.Next() diff --git a/api/internal/filter/ip_filter_test.go b/api/internal/filter/ip_filter_test.go index f9de0482b1..dad4da6871 100644 --- a/api/internal/filter/ip_filter_test.go +++ b/api/internal/filter/ip_filter_test.go @@ -17,6 +17,7 @@ package filter import ( + "net/http/httptest" "testing" "github.com/gin-gonic/gin" @@ -55,4 +56,22 @@ func TestIPFilter_Handle(t *testing.T) { }) w = performRequest(r, "GET", "/test") assert.Equal(t, 200, w.Code) + + // should forbidden + conf.AllowList = []string{"127.0.0.1"} + r = gin.New() + r.Use(IPFilter()) + r.GET("/test", func(c *gin.Context) {}) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Forwarded-For", "127.0.0.1") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, 403, w.Code) + + req = httptest.NewRequest("GET", "/test", nil) + req.Header.Set("X-Real-Ip", "127.0.0.1") + w = httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, 403, w.Code) }