diff --git a/lib/okta.go b/lib/okta.go index 18c1e72d..8661d233 100644 --- a/lib/okta.go +++ b/lib/okta.go @@ -185,7 +185,7 @@ func (o *OktaClient) AuthenticateUser() error { return nil } -func (o *OktaClient) AuthenticateProfile(profileARN string, duration time.Duration) (sts.Credentials, string, error) { +func (o *OktaClient) AuthenticateProfileWithRegion(profileARN string, duration time.Duration, region string) (sts.Credentials, string, error) { // Attempt to reuse session cookie var assertion SAMLAssertion @@ -211,14 +211,18 @@ func (o *OktaClient) AuthenticateProfile(profileARN string, duration time.Durati } // Step 4 : Assume Role with SAML - samlSess := session.Must(session.NewSession()) - var svc *sts.STS - if assertion.Resp.Destination == "https://signin.amazonaws-us-gov.com/saml" { - svc = sts.New(samlSess, aws.NewConfig().WithRegion("us-gov-west-1")) + log.Debug("Step 4: Assume Role with SAML") + var samlSess *session.Session + if region != "" { + log.Debugf("Using region: %s\n", region) + conf := &aws.Config{ + Region: aws.String(region), + } + samlSess = session.Must(session.NewSession(conf)) } else { - svc = sts.New(samlSess) + samlSess = session.Must(session.NewSession()) } - log.Debugf("SAML assertion has destination %s, STS client is configured with endpoint %s\n", assertion.Resp.Destination, svc.Client.ClientInfo.Endpoint) + svc := sts.New(samlSess) samlParams := &sts.AssumeRoleWithSAMLInput{ PrincipalArn: aws.String(principal), @@ -245,6 +249,11 @@ func (o *OktaClient) AuthenticateProfile(profileARN string, duration time.Durati return *samlResp.Credentials, sessionCookie, nil } + +func (o *OktaClient) AuthenticateProfile(profileARN string, duration time.Duration) (sts.Credentials, string, error) { + return o.AuthenticateProfileWithRegion(profileARN, duration, "") +} + func selectMFADeviceFromConfig(o *OktaClient) (*OktaUserAuthnFactor, error) { log.Debugf("MFAConfig: %v\n", o.MFAConfig) if o.MFAConfig.Provider == "" || o.MFAConfig.FactorType == "" { @@ -551,6 +560,7 @@ type OktaProvider struct { // to be stored in the keyring. OktaSessionCookieKey string MFAConfig MFAConfig + AwsRegion string } func (p *OktaProvider) Retrieve() (sts.Credentials, string, error) { @@ -578,7 +588,7 @@ func (p *OktaProvider) Retrieve() (sts.Credentials, string, error) { return sts.Credentials{}, "", err } - creds, newSessionCookie, err := oktaClient.AuthenticateProfile(p.ProfileARN, p.SessionDuration) + creds, newSessionCookie, err := oktaClient.AuthenticateProfileWithRegion(p.ProfileARN, p.SessionDuration, p.AwsRegion) if err != nil { return sts.Credentials{}, "", err } diff --git a/lib/provider.go b/lib/provider.go index 0117e2a6..054232cd 100644 --- a/lib/provider.go +++ b/lib/provider.go @@ -221,6 +221,10 @@ func (p *Provider) getSamlSessionCreds() (sts.Credentials, error) { OktaAwsSAMLUrl: oktaAwsSAMLUrl, OktaSessionCookieKey: oktaSessionCookieKey, } + + if region := p.profiles[source]["region"]; region != "" { + provider.AwsRegion = region + } creds, oktaUsername, err := provider.Retrieve() if err != nil { @@ -253,6 +257,10 @@ func (p *Provider) GetSAMLLoginURL() (*url.URL, error) { OktaSessionCookieKey: oktaSessionCookieKey, } + if region := p.profiles[source]["region"]; region != "" { + provider.AwsRegion = region + } + loginURL, err := provider.GetSAMLLoginURL() if err != nil { return &url.URL{}, err