diff --git a/v2/appengine.go b/v2/appengine.go index 02e1f24e..1941d36d 100644 --- a/v2/appengine.go +++ b/v2/appengine.go @@ -54,6 +54,9 @@ func Main() { internal.Main() } +// Middleware wraps an http handler so that it can make GAE API calls +var Middleware func(http.Handler) http.Handler = internal.Middleware + // IsDevAppServer reports whether the App Engine app is running in the // development App Server. func IsDevAppServer() bool { diff --git a/v2/internal/api.go b/v2/internal/api.go index 9bf67ad6..41b8e25c 100644 --- a/v2/internal/api.go +++ b/v2/internal/api.go @@ -84,53 +84,63 @@ func apiURL(ctx netcontext.Context) *url.URL { } } -func handleHTTP(w http.ResponseWriter, r *http.Request) { - c := &context{ - req: r, - outHeader: w.Header(), - } - r = r.WithContext(withContext(r.Context(), c)) - c.req = r - - // Patch up RemoteAddr so it looks reasonable. - if addr := r.Header.Get(userIPHeader); addr != "" { - r.RemoteAddr = addr - } else if addr = r.Header.Get(remoteAddrHeader); addr != "" { - r.RemoteAddr = addr - } else { - // Should not normally reach here, but pick a sensible default anyway. - r.RemoteAddr = "127.0.0.1" - } - // The address in the headers will most likely be of these forms: - // 123.123.123.123 - // 2001:db8::1 - // net/http.Request.RemoteAddr is specified to be in "IP:port" form. - if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { - // Assume the remote address is only a host; add a default port. - r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80") - } - - executeRequestSafely(c, r) - c.outHeader = nil // make sure header changes aren't respected any more - - // Avoid nil Write call if c.Write is never called. - if c.outCode != 0 { - w.WriteHeader(c.outCode) - } - if c.outBody != nil { - w.Write(c.outBody) - } +// Middleware wraps an http handler so that it can make GAE API calls +func Middleware(next http.Handler) http.Handler { + return handleHTTPMiddleware(executeRequestSafelyMiddleware(next)) } -func executeRequestSafely(c *context, r *http.Request) { - defer func() { - if x := recover(); x != nil { - logf(c, 4, "%s", renderPanic(x)) // 4 == critical - c.outCode = 500 +func handleHTTPMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c := &context{ + req: r, + outHeader: w.Header(), + } + r = r.WithContext(withContext(r.Context(), c)) + c.req = r + + // Patch up RemoteAddr so it looks reasonable. + if addr := r.Header.Get(userIPHeader); addr != "" { + r.RemoteAddr = addr + } else if addr = r.Header.Get(remoteAddrHeader); addr != "" { + r.RemoteAddr = addr + } else { + // Should not normally reach here, but pick a sensible default anyway. + r.RemoteAddr = "127.0.0.1" + } + // The address in the headers will most likely be of these forms: + // 123.123.123.123 + // 2001:db8::1 + // net/http.Request.RemoteAddr is specified to be in "IP:port" form. + if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { + // Assume the remote address is only a host; add a default port. + r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80") } - }() - http.DefaultServeMux.ServeHTTP(c, r) + next.ServeHTTP(c, r) + c.outHeader = nil // make sure header changes aren't respected any more + + // Avoid nil Write call if c.Write is never called. + if c.outCode != 0 { + w.WriteHeader(c.outCode) + } + if c.outBody != nil { + w.Write(c.outBody) + } + }) +} + +func executeRequestSafelyMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if x := recover(); x != nil { + c := w.(*context) + logf(c, 4, "%s", renderPanic(x)) // 4 == critical + c.outCode = 500 + } + }() + + next.ServeHTTP(w, r) + }) } func renderPanic(x interface{}) string { diff --git a/v2/internal/api_test.go b/v2/internal/api_test.go index 67785b35..e073c9b7 100644 --- a/v2/internal/api_test.go +++ b/v2/internal/api_test.go @@ -292,7 +292,7 @@ func TestRemoteAddr(t *testing.T) { Header: tc.headers, Body: ioutil.NopCloser(bytes.NewReader(nil)), } - handleHTTP(httptest.NewRecorder(), r) + Middleware(http.DefaultServeMux).ServeHTTP(httptest.NewRecorder(), r) if addr != tc.addr { t.Errorf("Header %v, got %q, want %q", tc.headers, addr, tc.addr) } @@ -309,7 +309,7 @@ func TestPanickingHandler(t *testing.T) { Body: ioutil.NopCloser(bytes.NewReader(nil)), } rec := httptest.NewRecorder() - handleHTTP(rec, r) + Middleware(http.DefaultServeMux).ServeHTTP(rec, r) if rec.Code != 500 { t.Errorf("Panicking handler returned HTTP %d, want HTTP %d", rec.Code, 500) } diff --git a/v2/internal/main.go b/v2/internal/main.go index 7adb4490..0abb8c62 100644 --- a/v2/internal/main.go +++ b/v2/internal/main.go @@ -30,7 +30,7 @@ func Main() { if IsDevAppServer() { host = "127.0.0.1" } - if err := http.ListenAndServe(host+":"+port, http.HandlerFunc(handleHTTP)); err != nil { + if err := http.ListenAndServe(host+":"+port, Middleware(http.DefaultServeMux)); err != nil { log.Fatalf("http.ListenAndServe: %v", err) } }