diff --git a/ws_test.go b/ws_test.go index c2b78d3ac..4dec114c6 100644 --- a/ws_test.go +++ b/ws_test.go @@ -1113,8 +1113,7 @@ func TestWSNoDeadlockOnAuthFailure(t *testing.T) { } func TestWSProxyPath(t *testing.T) { - const proxyPath = "/proxy1" - var proxyCalled bool + const proxyPath = "proxy1" // Listen to a random port l, err := net.Listen("tcp", ":0") @@ -1125,22 +1124,37 @@ func TestWSProxyPath(t *testing.T) { proxyPort := l.Addr().(*net.TCPAddr).Port + ch := make(chan struct{}, 1) proxySrv := &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proxyCalled = r.URL.Path == proxyPath + if r.URL.Path == "/"+proxyPath { + ch <- struct{}{} + } }), } defer proxySrv.Shutdown(context.Background()) go proxySrv.Serve(l) - opt := testWSGetDefaultOptions(t, false) - s := RunServerWithOptions(opt) - defer s.Shutdown() - - url := fmt.Sprintf("ws://127.0.0.1:%d", proxyPort) - Connect(url, ProxyPath(proxyPath)) - - if !proxyCalled { - t.Fatal("Proxy haven't been called") + for _, test := range []struct { + name string + path string + }{ + {"without slash", proxyPath}, + {"with slash", "/" + proxyPath}, + } { + t.Run(test.name, func(t *testing.T) { + url := fmt.Sprintf("ws://127.0.0.1:%d", proxyPort) + nc, err := Connect(url, ProxyPath(test.path)) + if err == nil { + nc.Close() + t.Fatal("Did not expect to connect") + } + select { + case <-ch: + // OK: + case <-time.After(time.Second): + t.Fatal("Proxy was not reached") + } + }) } }