diff --git a/cache.go b/cache.go index 65d3eca..ecabd97 100644 --- a/cache.go +++ b/cache.go @@ -3,10 +3,12 @@ package httpcache import ( "bytes" "context" + "crypto/sha1" "encoding/json" "fmt" "hash/crc32" "io" + "io/ioutil" "math" "net/http" "net/url" @@ -431,9 +433,38 @@ func (h *HTTPCache) getBucketIndexForKey(key string) uint32 { // In caddy2, it is automatically add the map by addHTTPVarsToReplacer func getKey(cacheKeyTemplate string, r *http.Request) string { repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) + + // Add contentlength and bodyhash when not added before + if _, ok := repl.Get("http.request.contentlength"); !ok { + repl.Set("http.request.contentlength", r.ContentLength) + repl.Map(func(key string) (interface{}, bool) { + if key == "http.request.bodyhash" { + return bodyHash(r), true + } + return nil, false + }) + } + return repl.ReplaceKnown(cacheKeyTemplate, "") } +// bodyHash calculates a hash value of the request body +func bodyHash(r *http.Request) string { + body, err := ioutil.ReadAll(r.Body) + if err != nil { + return "" + } + + h := sha1.New() + h.Write(body) + bs := h.Sum(nil) + result := fmt.Sprintf("%x", bs) + + r.Body = ioutil.NopCloser(bytes.NewBuffer(body)) + + return result +} + // Get returns the cached response func (h *HTTPCache) Get(key string, request *http.Request, includeStale bool) (*Entry, bool) { b := h.getBucketIndexForKey(key) diff --git a/cache_test.go b/cache_test.go index 5463211..bece14c 100644 --- a/cache_test.go +++ b/cache_test.go @@ -1,6 +1,8 @@ package httpcache import ( + "bytes" + "context" "io/ioutil" "net/http" "net/http/httptest" @@ -8,6 +10,8 @@ import ( "testing" "time" + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/sillygod/cdp-cache/backends" "github.com/stretchr/testify/suite" ) @@ -271,9 +275,32 @@ func (suite *HTTPCacheTestSuite) TearDownSuite() { suite.Nil(err) } +type KeyTestSuite struct { + suite.Suite +} + +func (suite *KeyTestSuite) TestContentLengthInKey() { + body := []byte(`{"search":"my search string"}`) + req := httptest.NewRequest("POST", "/", bytes.NewBuffer(body)) + ctx := context.WithValue(req.Context(), caddy.ReplacerCtxKey, caddyhttp.NewTestReplacer(req)) + req = req.WithContext(ctx) + key := getKey("{http.request.contentlength}", req) + suite.Equal("29", key) +} + +func (suite *KeyTestSuite) TestBodyHashInKey() { + body := []byte(`{"search":"my search string"}`) + req := httptest.NewRequest("POST", "/", bytes.NewBuffer(body)) + ctx := context.WithValue(req.Context(), caddy.ReplacerCtxKey, caddyhttp.NewTestReplacer(req)) + req = req.WithContext(ctx) + key := getKey("{http.request.bodyhash}", req) + suite.Equal("5edeb27ddae03685d04df2ab56ebf11fb9c8a711", key) +} + func TestCacheStatusTestSuite(t *testing.T) { suite.Run(t, new(CacheStatusTestSuite)) suite.Run(t, new(HTTPCacheTestSuite)) suite.Run(t, new(RuleMatcherTestSuite)) suite.Run(t, new(EntryTestSuite)) + suite.Run(t, new(KeyTestSuite)) } diff --git a/caddyfile.go b/caddyfile.go index d6f11b3..7ea3ab6 100644 --- a/caddyfile.go +++ b/caddyfile.go @@ -39,6 +39,7 @@ var ( defaultLockTimeout = time.Duration(5) * time.Minute defaultMaxAge = time.Duration(5) * time.Minute defaultPath = "/tmp/caddy_cache" + defaultMatchMethods = []string{"GET", "HEAD"} defaultCacheType = file defaultcacheBucketsNum = 256 defaultCacheMaxMemorySize = GB // default is 1 GB @@ -56,6 +57,7 @@ const ( keyPath = "path" keyMatchHeader = "match_header" keyMatchPath = "match_path" + keyMatchMethod = "match_methods" keyCacheKey = "cache_key" keyCacheBucketsNum = "cache_bucket_num" keyCacheMaxMemorySize = "cache_max_memory_size" @@ -82,6 +84,7 @@ type Config struct { LockTimeout time.Duration `json:"lock_timeout,omitempty"` RuleMatchersRaws []RuleMatcherRawWithType `json:"rule_matcher_raws,omitempty"` RuleMatchers []RuleMatcher `json:"-"` + MatchMethods []string `json:"match_methods,omitempty"` CacheBucketsNum int `json:"cache_buckets_num,omitempty"` CacheMaxMemorySize int `json:"cache_max_memory_size,omitempty"` Path string `json:"path,omitempty"` @@ -97,6 +100,7 @@ func getDefaultConfig() *Config { LockTimeout: defaultLockTimeout, RuleMatchersRaws: []RuleMatcherRawWithType{}, RuleMatchers: []RuleMatcher{}, + MatchMethods: defaultMatchMethods, CacheBucketsNum: defaultcacheBucketsNum, CacheMaxMemorySize: defaultCacheMaxMemorySize, Path: defaultPath, @@ -215,6 +219,12 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { Data: data, }) + case keyMatchMethod: + if len(args) < 2 { + return d.Err("Invalid usage of match_method in cache config.") + } + config.MatchMethods = append(config.MatchMethods, args...) + case keyCacheKey: if len(args) != 1 { return d.Err(fmt.Sprintf("Invalid usage of %s in cache config.", keyCacheKey)) diff --git a/handler.go b/handler.go index 76a3942..ddc7c4e 100644 --- a/handler.go +++ b/handler.go @@ -93,7 +93,7 @@ func popOrNil(h *Handler, errChan chan error) (err error) { } -func (h *Handler) fetchUpstream(req *http.Request, next caddyhttp.Handler) (*Entry, error) { +func (h *Handler) fetchUpstream(req *http.Request, next caddyhttp.Handler, key string) (*Entry, error) { // Create a new empty response response := NewResponse() @@ -131,7 +131,7 @@ func (h *Handler) fetchUpstream(req *http.Request, next caddyhttp.Handler) (*Ent response.WaitHeaders() // Create a new CacheEntry - return NewEntry(getKey(h.Config.CacheKeyTemplate, req), req, response, h.Config), popOrNil(h, errChan) + return NewEntry(key, req, response, h.Config), popOrNil(h, errChan) } // CaddyModule returns the Caddy module information @@ -329,7 +329,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht }(h, start) - if !shouldUseCache(r) { + if !shouldUseCache(r, h.Config) { h.addStatusHeaderIfConfigured(w, cacheBypass) return next.ServeHTTP(w, r) } @@ -359,7 +359,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht if h.Distributed != nil { // new an entry without fetching the upstream response := NewResponse() - entry := NewEntry(getKey(h.Config.CacheKeyTemplate, r), r, response, h.Config) + entry := NewEntry(key, r, response, h.Config) err := entry.setBackend(r.Context(), h.Config) if err != nil { return caddyhttp.Error(http.StatusInternalServerError, err) @@ -393,7 +393,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht // It should be fetched from upstream and save it in cache t := time.Now() - entry, err := h.fetchUpstream(r, next) + entry, err := h.fetchUpstream(r, next, key) upstreamDuration = time.Since(t) if entry.Response.Code >= 500 { diff --git a/handler_test.go b/handler_test.go index 2c8f388..1a12c48 100644 --- a/handler_test.go +++ b/handler_test.go @@ -96,6 +96,13 @@ func (suite *HandlerProvisionTestSuite) TestProvisionRedisBackend() { type DetermineShouldCacheTestSuite struct { suite.Suite + Config *Config +} + +func (suite *DetermineShouldCacheTestSuite) SetupSuite() { + if suite.Config == nil { + suite.Config = getDefaultConfig() + } } func (suite *DetermineShouldCacheTestSuite) TestWebsocketConnection() { @@ -119,7 +126,7 @@ func (suite *DetermineShouldCacheTestSuite) TestWebsocketConnection() { for _, test := range tests { req := makeRequest("/", test.header) - shouldBeCached := shouldUseCache(req) + shouldBeCached := shouldUseCache(req, suite.Config) suite.Equal(test.shouldBeCached, shouldBeCached) } @@ -127,12 +134,37 @@ func (suite *DetermineShouldCacheTestSuite) TestWebsocketConnection() { func (suite *DetermineShouldCacheTestSuite) TestNonGETOrHeadMethod() { r := httptest.NewRequest("POST", "/", nil) - shouldBeCached := shouldUseCache(r) + shouldBeCached := shouldUseCache(r, suite.Config) + suite.False(shouldBeCached) +} + +type DetermineShouldCachePOSTOnlyTestSuite struct { + suite.Suite + Config *Config +} + +func (suite *DetermineShouldCachePOSTOnlyTestSuite) SetupSuite() { + if suite.Config == nil { + suite.Config = getDefaultConfig() + suite.Config.MatchMethods = []string{"POST"} + } +} + +func (suite *DetermineShouldCachePOSTOnlyTestSuite) TestPOSTMethod() { + r := httptest.NewRequest("POST", "/", nil) + shouldBeCached := shouldUseCache(r, suite.Config) + suite.True(shouldBeCached) +} + +func (suite *DetermineShouldCachePOSTOnlyTestSuite) TestGETMethod() { + r := httptest.NewRequest("GET", "/", nil) + shouldBeCached := shouldUseCache(r, suite.Config) suite.False(shouldBeCached) } func TestCacheKeyTemplatingTestSuite(t *testing.T) { suite.Run(t, new(CacheKeyTemplatingTestSuite)) suite.Run(t, new(DetermineShouldCacheTestSuite)) + suite.Run(t, new(DetermineShouldCachePOSTOnlyTestSuite)) suite.Run(t, new(HandlerProvisionTestSuite)) } diff --git a/readme.org b/readme.org index ce33d49..c5eed6a 100644 --- a/readme.org +++ b/readme.org @@ -119,6 +119,11 @@ *** match_path Only the request's path match the condition will be cached. Ex. =/= means all request need to be cached because all request's path must start with =/= +*** match_methods + By default, only =GET= and =POST= methods are cached. If you would like to cache other methods as well you can configure here which methods should be cached, e.g.: =GET HEAD POST=. + + To be able to distinguish different POST requests, it is advisable to include the body hash in the cache key, e.g.: ={http.request.method} {http.request.host}{http.request.uri.path}?{http.request.uri.query} {http.request.contentlength} {http.request.bodyhash}= + *** default_max_age The cache's expiration time. diff --git a/response.go b/response.go index 88da87c..34c081d 100644 --- a/response.go +++ b/response.go @@ -144,10 +144,16 @@ func (r *Response) WriteHeader(code int) { r.headersChan <- struct{}{} } -func shouldUseCache(req *http.Request) bool { +func shouldUseCache(req *http.Request, config *Config) bool { - if req.Method != "GET" && req.Method != "HEAD" { - // Only cache Get and head request + matchMethod := false + for _, method := range config.MatchMethods { + if method == req.Method { + matchMethod = true + break + } + } + if !matchMethod { return false }