Skip to content

Commit

Permalink
[FLYTE-486] Support selecting IDP based on the query parameter (#4838) (
Browse files Browse the repository at this point in the history
#61)

* Added config option for IDPQuery parameter

---------

Signed-off-by: pmahindrakar-oss <[email protected]>
  • Loading branch information
pmahindrakar-oss authored Feb 6, 2024
1 parent 071939a commit 651ac73
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 10 deletions.
1 change: 1 addition & 0 deletions flyteadmin/auth/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ type UserAuthConfig struct {
CookieHashKeySecretName string `json:"cookieHashKeySecretName" pflag:",OPTIONAL: Secret name to use for cookie hash key."`
CookieBlockKeySecretName string `json:"cookieBlockKeySecretName" pflag:",OPTIONAL: Secret name to use for cookie block key."`
CookieSetting CookieSettings `json:"cookieSetting" pflag:", settings used by cookies created for user auth"`
IDPQueryParameter string `json:"idpQueryParameter" pflag:", idp query parameter used for selecting a particular IDP for doing user authentication. Eg: for Okta passing idp=<IDP-ID> forces the authentication to happen with IDP-ID"`
}

//go:generate enumer --type=SameSite --trimprefix=SameSite -json
Expand Down
1 change: 1 addition & 0 deletions flyteadmin/auth/config/config_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions flyteadmin/auth/config/config_flags_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 19 additions & 2 deletions flyteadmin/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"

Expand Down Expand Up @@ -139,7 +140,7 @@ func GetLoginHandler(ctx context.Context, authCtx interfaces.AuthenticationConte

state := HashCsrfState(csrfToken)
logger.Debugf(ctx, "Setting CSRF state cookie to %s and state to %s\n", csrfToken, state)
url := authCtx.OAuth2ClientConfig(GetPublicURL(ctx, request, authCtx.Options())).AuthCodeURL(state)
urlString := authCtx.OAuth2ClientConfig(GetPublicURL(ctx, request, authCtx.Options())).AuthCodeURL(state)
queryParams := request.URL.Query()
if !GetRedirectURLAllowed(ctx, queryParams.Get(RedirectURLParameter), authCtx.Options()) {
logger.Infof(ctx, "unauthorized redirect URI")
Expand All @@ -154,7 +155,23 @@ func GetLoginHandler(ctx context.Context, authCtx interfaces.AuthenticationConte
logger.Errorf(ctx, "Was not able to create a redirect cookie")
}
}
http.Redirect(writer, request, url, http.StatusTemporaryRedirect)

idpURL, err := url.Parse(urlString)
if err != nil {
logger.Errorf(ctx, "failed to parse url %q: %v", urlString, err)
writer.WriteHeader(http.StatusInternalServerError)
}

// Add the IDPQueryParameter to the URL if it is present in the request
idpQueryParam := authCtx.Options().UserAuth.IDPQueryParameter
if len(idpQueryParam) > 0 && queryParams.Get(idpQueryParam) != "" {
logger.Infof(ctx, "Adding IDP Query Parameter to the URL")
query := idpURL.Query() // Gets a copy of query parameters
query.Add(idpQueryParam, queryParams.Get(idpQueryParam))
// Updates the rawquery with the new query parameters
idpURL.RawQuery = query.Encode()
}
http.Redirect(writer, request, idpURL.String(), http.StatusTemporaryRedirect)
}
}

Expand Down
49 changes: 41 additions & 8 deletions flyteadmin/auth/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,49 @@ func TestGetLoginHandler(t *testing.T) {
Scopes: []string{"openid", "other"},
}
mockAuthCtx := mocks.AuthenticationContext{}
mockAuthCtx.OnOptions().Return(&config.Config{})
mockAuthCtx.OnOptions().Return(&config.Config{
UserAuth: config.UserAuthConfig{
IDPQueryParameter: "idp",
},
})
mockAuthCtx.OnOAuth2ClientConfigMatch(mock.Anything).Return(&dummyOAuth2Config)
handler := GetLoginHandler(ctx, &mockAuthCtx)
req, err := http.NewRequest("GET", "/login", nil)
assert.NoError(t, err)
w := httptest.NewRecorder()
handler(w, req)
assert.Equal(t, 307, w.Code)
assert.True(t, strings.Contains(w.Header().Get("Location"), "response_type=code&scope=openid+other"))
assert.True(t, strings.Contains(w.Header().Get("Set-Cookie"), "flyte_csrf_state="))

type test struct {
name string
url string
expectedStatusCode int
expectedLocation string
expectedSetCookie string
}
tests := []test{
{
name: "no idp parameter",
url: "/login",
expectedStatusCode: 307,
expectedLocation: "response_type=code&scope=openid+other",
expectedSetCookie: "flyte_csrf_state=",
},
{
name: "with idp parameter config",
url: "/login?idp=dummyIDP",
expectedStatusCode: 307,
expectedLocation: "dummyIDP",
expectedSetCookie: "flyte_csrf_state=",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest("GET", tt.url, nil)
assert.NoError(t, err)
w := httptest.NewRecorder()
handler(w, req)
assert.Equal(t, tt.expectedStatusCode, w.Code)
assert.True(t, strings.Contains(w.Header().Get("Location"), tt.expectedLocation))
assert.True(t, strings.Contains(w.Header().Get("Set-Cookie"), tt.expectedSetCookie))
})
}
}

func TestGetLogoutHandler(t *testing.T) {
Expand Down

0 comments on commit 651ac73

Please sign in to comment.