diff --git a/util/dex/dex.go b/util/dex/dex.go index dd239d34abcd5..735b7cbb72976 100644 --- a/util/dex/dex.go +++ b/util/dex/dex.go @@ -17,6 +17,13 @@ import ( var messageRe = regexp.MustCompile(`
(.*)([\s\S]*?)<\/p>`) +func decorateDirector(director func(req *http.Request), target *url.URL) func(req *http.Request) { + return func(req *http.Request) { + director(req) + req.Host = target.Host + } +} + // NewDexHTTPReverseProxy returns a reverse proxy to the Dex server. Dex is assumed to be configured // with the external issuer URL muxed to the same path configured in server.go. In other words, if // Argo CD API server wants to proxy requests at /api/dex, then the dex config yaml issuer URL should @@ -25,6 +32,7 @@ func NewDexHTTPReverseProxy(serverAddr string, baseHRef string) func(writer http target, err := url.Parse(serverAddr) errors.CheckError(err) target.Path = baseHRef + proxy := httputil.NewSingleHostReverseProxy(target) proxy.ModifyResponse = func(resp *http.Response) error { if resp.StatusCode == 500 { @@ -52,6 +60,7 @@ func NewDexHTTPReverseProxy(serverAddr string, baseHRef string) func(writer http } return nil } + proxy.Director = decorateDirector(proxy.Director, target) return func(w http.ResponseWriter, r *http.Request) { proxy.ServeHTTP(w, r) } diff --git a/util/dex/dex_test.go b/util/dex/dex_test.go index 114903902bd59..ffe1fa7862878 100644 --- a/util/dex/dex_test.go +++ b/util/dex/dex_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "strings" "testing" @@ -270,7 +271,9 @@ func Test_GenerateDexConfig(t *testing.T) { func Test_DexReverseProxy(t *testing.T) { t.Run("Good case", func(t *testing.T) { + var host string fakeDex := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + host = req.Host rw.WriteHeader(http.StatusOK) })) defer fakeDex.Close() @@ -278,10 +281,12 @@ func Test_DexReverseProxy(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(NewDexHTTPReverseProxy(fakeDex.URL, "/"))) fmt.Printf("Fake API Server listening on %s\n", server.URL) defer server.Close() + target, _ := url.Parse(fakeDex.URL) resp, err := http.Get(server.URL) assert.NotNil(t, resp) assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, host, target.Host) fmt.Printf("%s\n", resp.Status) })