Skip to content

Commit

Permalink
reverseproxy: Rewrite requests and responses for websocket over http2 (
Browse files Browse the repository at this point in the history
…#6567)

* reverse proxy: rewrite requests and responses for websocket over http2

* delete protocol pseudo-header

* modify cloned requests

* set request variable to track if it's a h2 websocket

* use request bodu

* rewrite request body

* use WebSocket instead of Websocket in the headers

* use logger check for zap loggers

* fix lint
  • Loading branch information
WeidiDeng authored Dec 6, 2024
1 parent a1751ad commit 9c0c71e
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 13 deletions.
19 changes: 19 additions & 0 deletions modules/caddyhttp/reverseproxy/reverseproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package reverseproxy
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -394,6 +396,23 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
return caddyhttp.Error(http.StatusInternalServerError,
fmt.Errorf("preparing request for upstream round-trip: %v", err))
}
// websocket over http2, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
// TODO: once we can reliably detect backend support this, it can be removed for those backends
if r.ProtoMajor == 2 && r.Method == http.MethodConnect && r.Header.Get(":protocol") != "" {
clonedReq.Header.Del(":protocol")
// keep the body for later use. http1.1 upgrade uses http.NoBody
caddyhttp.SetVar(clonedReq.Context(), "h2_websocket_body", clonedReq.Body)
clonedReq.Body = http.NoBody
clonedReq.Method = http.MethodGet
clonedReq.Header.Set("Upgrade", r.Header.Get(":protocol"))
clonedReq.Header.Set("Connection", "Upgrade")
key := make([]byte, 16)
_, randErr := rand.Read(key)
if randErr != nil {
return randErr
}
clonedReq.Header["Sec-WebSocket-Key"] = []string{base64.StdEncoding.EncodeToString(key)}
}

// we will need the original headers and Host value if
// header operations are configured; this is so that each
Expand Down
82 changes: 69 additions & 13 deletions modules/caddyhttp/reverseproxy/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package reverseproxy

import (
"bufio"
"context"
"errors"
"fmt"
Expand All @@ -33,8 +34,29 @@ import (
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/net/http/httpguts"

"github.com/caddyserver/caddy/v2/modules/caddyhttp"
)

type h2ReadWriteCloser struct {
io.ReadCloser
http.ResponseWriter
}

func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) {
n, err = rwc.ResponseWriter.Write(p)
if err != nil {
return 0, err
}

//nolint:bodyclose
err = http.NewResponseController(rwc.ResponseWriter).Flush()
if err != nil {
return 0, err
}
return n, nil
}

func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
Expand Down Expand Up @@ -67,24 +89,58 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup,
// like the rest of handler chain.
copyHeader(rw.Header(), res.Header)
normalizeWebsocketHeaders(rw.Header())
rw.WriteHeader(res.StatusCode)

logger.Debug("upgrading connection")
var (
conn io.ReadWriteCloser
brw *bufio.ReadWriter
)
// websocket over http2, assuming backend doesn't support this, the request will be modified to http1.1 upgrade
// TODO: once we can reliably detect backend support this, it can be removed for those backends
if body, ok := caddyhttp.GetVar(req.Context(), "h2_websocket_body").(io.ReadCloser); ok {
req.Body = body
rw.Header().Del("Upgrade")
rw.Header().Del("Connection")
delete(rw.Header(), "Sec-WebSocket-Accept")
rw.WriteHeader(http.StatusOK)

if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
c.Write(zap.Int("http_version", 2))
}

//nolint:bodyclose
conn, brw, hijackErr := http.NewResponseController(rw).Hijack()
if errors.Is(hijackErr, http.ErrNotSupported) {
if c := logger.Check(zapcore.ErrorLevel, "can't switch protocols using non-Hijacker ResponseWriter"); c != nil {
c.Write(zap.String("type", fmt.Sprintf("%T", rw)))
//nolint:bodyclose
flushErr := http.NewResponseController(rw).Flush()
if flushErr != nil {
if c := h.logger.Check(zap.ErrorLevel, "failed to flush http2 websocket response"); c != nil {
c.Write(zap.Error(flushErr))
}
return
}
return
}
conn = h2ReadWriteCloser{req.Body, rw}
// bufio is not needed, use minimal buffer
brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1))
} else {
rw.WriteHeader(res.StatusCode)

if hijackErr != nil {
if c := logger.Check(zapcore.ErrorLevel, "hijack failed on protocol switch"); c != nil {
c.Write(zap.Error(hijackErr))
if c := logger.Check(zap.DebugLevel, "upgrading connection"); c != nil {
c.Write(zap.Int("http_version", req.ProtoMajor))
}

var hijackErr error
//nolint:bodyclose
conn, brw, hijackErr = http.NewResponseController(rw).Hijack()
if errors.Is(hijackErr, http.ErrNotSupported) {
if c := h.logger.Check(zap.ErrorLevel, "can't switch protocols using non-Hijacker ResponseWriter"); c != nil {
c.Write(zap.String("type", fmt.Sprintf("%T", rw)))
}
return
}

if hijackErr != nil {
if c := h.logger.Check(zap.ErrorLevel, "hijack failed on protocol switch"); c != nil {
c.Write(zap.Error(hijackErr))
}
return
}
return
}

// adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5
Expand Down

0 comments on commit 9c0c71e

Please sign in to comment.