Skip to content

Commit

Permalink
Support HTTP POST requests (#44)
Browse files Browse the repository at this point in the history
* allow caching of other HTTP methods like POST

* add contentlength and bodyhash key template vars

For caching POST requests it's important to distinguish requests between different body contents. This commit adds `http.request.contentlength` and `http.request.bodyhash`. For the body hash it's important that the cache key is calculated before the body has been read (before it's empty). Therefore the key is passed to `fetchUpstream`.
  • Loading branch information
corneliusludmann authored Aug 1, 2021
1 parent b904cb7 commit a00fa59
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 10 deletions.
31 changes: 31 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package httpcache
import (
"bytes"
"context"
"crypto/sha1"
"encoding/json"
"fmt"
"hash/crc32"
"io"
"io/ioutil"
"math"
"net/http"
"net/url"
Expand Down Expand Up @@ -431,9 +433,38 @@ func (h *HTTPCache) getBucketIndexForKey(key string) uint32 {
// In caddy2, it is automatically add the map by addHTTPVarsToReplacer
func getKey(cacheKeyTemplate string, r *http.Request) string {
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)

// Add contentlength and bodyhash when not added before
if _, ok := repl.Get("http.request.contentlength"); !ok {
repl.Set("http.request.contentlength", r.ContentLength)
repl.Map(func(key string) (interface{}, bool) {
if key == "http.request.bodyhash" {
return bodyHash(r), true
}
return nil, false
})
}

return repl.ReplaceKnown(cacheKeyTemplate, "")
}

// bodyHash calculates a hash value of the request body
func bodyHash(r *http.Request) string {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return ""
}

h := sha1.New()
h.Write(body)
bs := h.Sum(nil)
result := fmt.Sprintf("%x", bs)

r.Body = ioutil.NopCloser(bytes.NewBuffer(body))

return result
}

// Get returns the cached response
func (h *HTTPCache) Get(key string, request *http.Request, includeStale bool) (*Entry, bool) {
b := h.getBucketIndexForKey(key)
Expand Down
27 changes: 27 additions & 0 deletions cache_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package httpcache

import (
"bytes"
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/sillygod/cdp-cache/backends"
"github.com/stretchr/testify/suite"
)
Expand Down Expand Up @@ -271,9 +275,32 @@ func (suite *HTTPCacheTestSuite) TearDownSuite() {
suite.Nil(err)
}

type KeyTestSuite struct {
suite.Suite
}

func (suite *KeyTestSuite) TestContentLengthInKey() {
body := []byte(`{"search":"my search string"}`)
req := httptest.NewRequest("POST", "/", bytes.NewBuffer(body))
ctx := context.WithValue(req.Context(), caddy.ReplacerCtxKey, caddyhttp.NewTestReplacer(req))
req = req.WithContext(ctx)
key := getKey("{http.request.contentlength}", req)
suite.Equal("29", key)
}

func (suite *KeyTestSuite) TestBodyHashInKey() {
body := []byte(`{"search":"my search string"}`)
req := httptest.NewRequest("POST", "/", bytes.NewBuffer(body))
ctx := context.WithValue(req.Context(), caddy.ReplacerCtxKey, caddyhttp.NewTestReplacer(req))
req = req.WithContext(ctx)
key := getKey("{http.request.bodyhash}", req)
suite.Equal("5edeb27ddae03685d04df2ab56ebf11fb9c8a711", key)
}

func TestCacheStatusTestSuite(t *testing.T) {
suite.Run(t, new(CacheStatusTestSuite))
suite.Run(t, new(HTTPCacheTestSuite))
suite.Run(t, new(RuleMatcherTestSuite))
suite.Run(t, new(EntryTestSuite))
suite.Run(t, new(KeyTestSuite))
}
10 changes: 10 additions & 0 deletions caddyfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ var (
defaultLockTimeout = time.Duration(5) * time.Minute
defaultMaxAge = time.Duration(5) * time.Minute
defaultPath = "/tmp/caddy_cache"
defaultMatchMethods = []string{"GET", "HEAD"}
defaultCacheType = file
defaultcacheBucketsNum = 256
defaultCacheMaxMemorySize = GB // default is 1 GB
Expand All @@ -56,6 +57,7 @@ const (
keyPath = "path"
keyMatchHeader = "match_header"
keyMatchPath = "match_path"
keyMatchMethod = "match_methods"
keyCacheKey = "cache_key"
keyCacheBucketsNum = "cache_bucket_num"
keyCacheMaxMemorySize = "cache_max_memory_size"
Expand All @@ -82,6 +84,7 @@ type Config struct {
LockTimeout time.Duration `json:"lock_timeout,omitempty"`
RuleMatchersRaws []RuleMatcherRawWithType `json:"rule_matcher_raws,omitempty"`
RuleMatchers []RuleMatcher `json:"-"`
MatchMethods []string `json:"match_methods,omitempty"`
CacheBucketsNum int `json:"cache_buckets_num,omitempty"`
CacheMaxMemorySize int `json:"cache_max_memory_size,omitempty"`
Path string `json:"path,omitempty"`
Expand All @@ -97,6 +100,7 @@ func getDefaultConfig() *Config {
LockTimeout: defaultLockTimeout,
RuleMatchersRaws: []RuleMatcherRawWithType{},
RuleMatchers: []RuleMatcher{},
MatchMethods: defaultMatchMethods,
CacheBucketsNum: defaultcacheBucketsNum,
CacheMaxMemorySize: defaultCacheMaxMemorySize,
Path: defaultPath,
Expand Down Expand Up @@ -215,6 +219,12 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
Data: data,
})

case keyMatchMethod:
if len(args) < 2 {
return d.Err("Invalid usage of match_method in cache config.")
}
config.MatchMethods = append(config.MatchMethods, args...)

case keyCacheKey:
if len(args) != 1 {
return d.Err(fmt.Sprintf("Invalid usage of %s in cache config.", keyCacheKey))
Expand Down
10 changes: 5 additions & 5 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func popOrNil(h *Handler, errChan chan error) (err error) {

}

func (h *Handler) fetchUpstream(req *http.Request, next caddyhttp.Handler) (*Entry, error) {
func (h *Handler) fetchUpstream(req *http.Request, next caddyhttp.Handler, key string) (*Entry, error) {
// Create a new empty response
response := NewResponse()

Expand Down Expand Up @@ -131,7 +131,7 @@ func (h *Handler) fetchUpstream(req *http.Request, next caddyhttp.Handler) (*Ent
response.WaitHeaders()

// Create a new CacheEntry
return NewEntry(getKey(h.Config.CacheKeyTemplate, req), req, response, h.Config), popOrNil(h, errChan)
return NewEntry(key, req, response, h.Config), popOrNil(h, errChan)
}

// CaddyModule returns the Caddy module information
Expand Down Expand Up @@ -329,7 +329,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht

}(h, start)

if !shouldUseCache(r) {
if !shouldUseCache(r, h.Config) {
h.addStatusHeaderIfConfigured(w, cacheBypass)
return next.ServeHTTP(w, r)
}
Expand Down Expand Up @@ -359,7 +359,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
if h.Distributed != nil {
// new an entry without fetching the upstream
response := NewResponse()
entry := NewEntry(getKey(h.Config.CacheKeyTemplate, r), r, response, h.Config)
entry := NewEntry(key, r, response, h.Config)
err := entry.setBackend(r.Context(), h.Config)
if err != nil {
return caddyhttp.Error(http.StatusInternalServerError, err)
Expand Down Expand Up @@ -393,7 +393,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
// It should be fetched from upstream and save it in cache

t := time.Now()
entry, err := h.fetchUpstream(r, next)
entry, err := h.fetchUpstream(r, next, key)
upstreamDuration = time.Since(t)

if entry.Response.Code >= 500 {
Expand Down
36 changes: 34 additions & 2 deletions handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ func (suite *HandlerProvisionTestSuite) TestProvisionRedisBackend() {

type DetermineShouldCacheTestSuite struct {
suite.Suite
Config *Config
}

func (suite *DetermineShouldCacheTestSuite) SetupSuite() {
if suite.Config == nil {
suite.Config = getDefaultConfig()
}
}

func (suite *DetermineShouldCacheTestSuite) TestWebsocketConnection() {
Expand All @@ -119,20 +126,45 @@ func (suite *DetermineShouldCacheTestSuite) TestWebsocketConnection() {

for _, test := range tests {
req := makeRequest("/", test.header)
shouldBeCached := shouldUseCache(req)
shouldBeCached := shouldUseCache(req, suite.Config)
suite.Equal(test.shouldBeCached, shouldBeCached)
}

}

func (suite *DetermineShouldCacheTestSuite) TestNonGETOrHeadMethod() {
r := httptest.NewRequest("POST", "/", nil)
shouldBeCached := shouldUseCache(r)
shouldBeCached := shouldUseCache(r, suite.Config)
suite.False(shouldBeCached)
}

type DetermineShouldCachePOSTOnlyTestSuite struct {
suite.Suite
Config *Config
}

func (suite *DetermineShouldCachePOSTOnlyTestSuite) SetupSuite() {
if suite.Config == nil {
suite.Config = getDefaultConfig()
suite.Config.MatchMethods = []string{"POST"}
}
}

func (suite *DetermineShouldCachePOSTOnlyTestSuite) TestPOSTMethod() {
r := httptest.NewRequest("POST", "/", nil)
shouldBeCached := shouldUseCache(r, suite.Config)
suite.True(shouldBeCached)
}

func (suite *DetermineShouldCachePOSTOnlyTestSuite) TestGETMethod() {
r := httptest.NewRequest("GET", "/", nil)
shouldBeCached := shouldUseCache(r, suite.Config)
suite.False(shouldBeCached)
}

func TestCacheKeyTemplatingTestSuite(t *testing.T) {
suite.Run(t, new(CacheKeyTemplatingTestSuite))
suite.Run(t, new(DetermineShouldCacheTestSuite))
suite.Run(t, new(DetermineShouldCachePOSTOnlyTestSuite))
suite.Run(t, new(HandlerProvisionTestSuite))
}
5 changes: 5 additions & 0 deletions readme.org
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@
*** match_path
Only the request's path match the condition will be cached. Ex. =/= means all request need to be cached because all request's path must start with =/=

*** match_methods
By default, only =GET= and =POST= methods are cached. If you would like to cache other methods as well you can configure here which methods should be cached, e.g.: =GET HEAD POST=.

To be able to distinguish different POST requests, it is advisable to include the body hash in the cache key, e.g.: ={http.request.method} {http.request.host}{http.request.uri.path}?{http.request.uri.query} {http.request.contentlength} {http.request.bodyhash}=

*** default_max_age
The cache's expiration time.

Expand Down
12 changes: 9 additions & 3 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,16 @@ func (r *Response) WriteHeader(code int) {
r.headersChan <- struct{}{}
}

func shouldUseCache(req *http.Request) bool {
func shouldUseCache(req *http.Request, config *Config) bool {

if req.Method != "GET" && req.Method != "HEAD" {
// Only cache Get and head request
matchMethod := false
for _, method := range config.MatchMethods {
if method == req.Method {
matchMethod = true
break
}
}
if !matchMethod {
return false
}

Expand Down

0 comments on commit a00fa59

Please sign in to comment.