diff --git a/relay/endpoints_php_test.go b/relay/endpoints_php_test.go index 21322dfe..5f1fcd27 100644 --- a/relay/endpoints_php_test.go +++ b/relay/endpoints_php_test.go @@ -68,6 +68,7 @@ func TestEndpointsPHPPolling(t *testing.T) { result, _ := st.DoRequest(r, p.relay) assert.Equal(t, http.StatusNotModified, result.StatusCode) + assert.NotEmpty(t, result.Header.Get("Expires")) }) t.Run("query with different ETag is cached", func(t *testing.T) { @@ -76,6 +77,7 @@ func TestEndpointsPHPPolling(t *testing.T) { result, _ := st.DoRequest(r, p.relay) assert.Equal(t, http.StatusOK, result.StatusCode) + assert.NotEmpty(t, result.Header.Get("Expires")) }) } } diff --git a/relay/relay_endpoints.go b/relay/relay_endpoints.go index 29edb3e9..3eb8a8bc 100644 --- a/relay/relay_endpoints.go +++ b/relay/relay_endpoints.go @@ -297,15 +297,6 @@ func pollFlagOrSegment(clientContext relayenv.EnvContext, kind ldstoretypes.Data func writeCacheableJSONResponse(w http.ResponseWriter, req *http.Request, clientContext relayenv.EnvContext, bytes []byte, etagValue string) { - etag := fmt.Sprintf("relay-%s", etagValue) // just to make it extra clear that these are relay-specific etags - if cachedEtag := req.Header.Get("If-None-Match"); cachedEtag != "" { - if cachedEtag == etag { - w.WriteHeader(http.StatusNotModified) - return - } - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Etag", etag) ttl := clientContext.GetTTL() if ttl > 0 { w.Header().Set("Vary", "Authorization") @@ -315,7 +306,17 @@ func writeCacheableJSONResponse(w http.ResponseWriter, req *http.Request, client // HTTP cache in front of ld-relay, multiple clients hitting the cache at different times // will all see the same expiration time. } + + etag := fmt.Sprintf("relay-%s", etagValue) // just to make it extra clear that these are relay-specific etags + if cachedEtag := req.Header.Get("If-None-Match"); cachedEtag == etag { + w.WriteHeader(http.StatusNotModified) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Etag", etag) w.WriteHeader(http.StatusOK) + _, _ = w.Write(bytes) }