Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport: Discard the buffer when empty after http connect handshake #7424

Merged
merged 4 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions internal/transport/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
}
return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
}

return &bufConn{Conn: conn, r: r}, nil
// The buffer could contain extra bytes from the target server, so we can't
// discard it. However, in many cases where the server waits for the client
// to send the first message (e.g. when TLS is being used), the buffer will
// be empty, so we can avoid the overhead of reading through this buffer.
if r.Buffered() != 0 {
return &bufConn{Conn: conn, r: r}, nil
}
return conn, nil
}

// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy
Expand Down
90 changes: 73 additions & 17 deletions internal/transport/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package transport

import (
"bufio"
"bytes"
"context"
"encoding/base64"
"fmt"
Expand Down Expand Up @@ -58,7 +59,7 @@ type proxyServer struct {
requestCheck func(*http.Request) error
}

func (p *proxyServer) run() {
func (p *proxyServer) run(waitForServerHello bool) {
in, err := p.lis.Accept()
if err != nil {
return
Expand All @@ -83,8 +84,26 @@ func (p *proxyServer) run() {
p.t.Errorf("failed to dial to server: %v", err)
return
}
out.SetDeadline(time.Now().Add(defaultTestTimeout))
resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"}
resp.Write(p.in)
var buf bytes.Buffer
resp.Write(&buf)
if waitForServerHello {
// Batch the first message from the server with the http connect
// response. This is done to test the cases in which the grpc client has
// the response to the connect request and proxied packets from the
// destination server when it reads the transport.
b := make([]byte, 50)
bytesRead, err := out.Read(b)
if err != nil {
p.t.Errorf("Got error while reading server hello: %v", err)
in.Close()
out.Close()
return
}
buf.Write(b[0:bytesRead])
}
p.in.Write(buf.Bytes())
p.out = out
go io.Copy(p.in, p.out)
go io.Copy(p.out, p.in)
Expand All @@ -100,17 +119,23 @@ func (p *proxyServer) stop() {
}
}

func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) {
type testArgs struct {
proxyURLModify func(*url.URL) *url.URL
proxyReqCheck func(*http.Request) error
serverMessage []byte
}

func testHTTPConnect(t *testing.T, args testArgs) {
plis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
p := &proxyServer{
t: t,
lis: plis,
requestCheck: proxyReqCheck,
requestCheck: args.proxyReqCheck,
}
go p.run()
go p.run(len(args.serverMessage) > 0)
defer p.stop()

blis, err := net.Listen("tcp", "localhost:0")
Expand All @@ -128,13 +153,14 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy
return
}
defer in.Close()
in.Write(args.serverMessage)
in.Read(recvBuf)
done <- nil
}()

// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
return args.proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
}
defer overwrite(hpfe)()

Expand All @@ -143,47 +169,76 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy
defer cancel()
c, err := proxyDial(ctx, blis.Addr().String(), "test")
if err != nil {
t.Fatalf("http connect Dial failed: %v", err)
t.Fatalf("HTTP connect Dial failed: %v", err)
}
defer c.Close()
c.SetDeadline(time.Now().Add(defaultTestTimeout))

// Send msg on the connection.
c.Write(msg)
if err := <-done; err != nil {
t.Fatalf("failed to accept: %v", err)
t.Fatalf("Failed to accept: %v", err)
}

// Check received msg.
if string(recvBuf) != string(msg) {
t.Fatalf("received msg: %v, want %v", recvBuf, msg)
t.Fatalf("Received msg: %v, want %v", recvBuf, msg)
}

if len(args.serverMessage) > 0 {
gotServerMessage := make([]byte, len(args.serverMessage))
if _, err := c.Read(gotServerMessage); err != nil {
t.Errorf("Got error while reading message from server: %v", err)
return
}
if string(gotServerMessage) != string(args.serverMessage) {
t.Errorf("Message from server: %v, want %v", gotServerMessage, args.serverMessage)
}
}
}

func (s) TestHTTPConnect(t *testing.T) {
testHTTPConnect(t,
func(in *url.URL) *url.URL {
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
return in
},
func(req *http.Request) error {
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
return nil
},
)
}
testHTTPConnect(t, args)
}

func (s) TestHTTPConnectWithServerHello(t *testing.T) {
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
return in
},
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
return nil
},
serverMessage: []byte("server-hello"),
}
testHTTPConnect(t, args)
}

func (s) TestHTTPConnectBasicAuth(t *testing.T) {
const (
user = "notAUser"
password = "notAPassword"
)
testHTTPConnect(t,
func(in *url.URL) *url.URL {
args := testArgs{
proxyURLModify: func(in *url.URL) *url.URL {
in.User = url.UserPassword(user, password)
return in
},
func(req *http.Request) error {
proxyReqCheck: func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
Expand All @@ -195,7 +250,8 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) {
}
return nil
},
)
}
testHTTPConnect(t, args)
}

func (s) TestMapAddressEnv(t *testing.T) {
Expand Down