diff --git a/README.md b/README.md index a107f6fe..7ebaad01 100644 --- a/README.md +++ b/README.md @@ -141,6 +141,14 @@ role_arn = arn:aws:iam:::role/ assume_role_ttl = 12h ``` +#### Multi-factor Authentication (MFA) configuration + +If you have a single MFA factor configured, that factor will be automatically selected. By default, if you have multiple available MFA factors, then you will be prompted to select which one to use. However, if you have multiple factors and want to specify which factor to use, you can do one of the following: + +* Specify on the command line with `--mfa-provider` and `--mfa-factor-type` +* Specify with environment variables `AWS_OKTA_MFA_PROVIDER` and `AWS_OKTA_MFA_FACTOR_TYPE` +* Specify in your aws config with `mfa_provider` and `mfa_factor_type` + ## Backends We use 99design's keyring package that they use in `aws-vault`. Because of this, you can choose between different pluggable secret storage backends just like in `aws-vault`. You can either set your backend from the command line as a flag, or set the `AWS_OKTA_BACKEND` environment variable. diff --git a/cmd/add.go b/cmd/add.go index b40c29d2..039768b8 100644 --- a/cmd/add.go +++ b/cmd/add.go @@ -82,7 +82,12 @@ func add(cmd *cobra.Command, args []string) error { Domain: oktaDomain, } - if err := creds.Validate(mfaDevice); err != nil { + // Profiles aren't parsed during `add`, but still want + // to centralize the MFA config logic + var dummyProfiles lib.Profiles + updateMfaConfig(cmd, dummyProfiles, "", &mfaConfig) + + if err := creds.Validate(mfaConfig); err != nil { log.Debugf("Failed to validate credentials: %s", err) return ErrFailedToValidateCredentials } diff --git a/cmd/exec.go b/cmd/exec.go index 64f84f52..f3365180 100644 --- a/cmd/exec.go +++ b/cmd/exec.go @@ -116,6 +116,8 @@ func execRun(cmd *cobra.Command, args []string) error { return fmt.Errorf("Profile '%s' not found in your aws config. Use list command to see configured profiles.", profile) } + updateMfaConfig(cmd, profiles, profile, &mfaConfig) + // check for an assume_role_ttl in the profile if we don't have a more explicit one if !cmd.Flags().Lookup("assume-role-ttl").Changed { if err := updateDurationFromConfigProfile(profiles, profile, &assumeRoleTTL); err != nil { @@ -124,7 +126,7 @@ func execRun(cmd *cobra.Command, args []string) error { } opts := lib.ProviderOptions{ - MFADevice: mfaDevice, + MFAConfig: mfaConfig, Profiles: profiles, SessionDuration: sessionTTL, AssumeRoleDuration: assumeRoleTTL, diff --git a/cmd/login.go b/cmd/login.go index fda41a18..04a53836 100644 --- a/cmd/login.go +++ b/cmd/login.go @@ -69,6 +69,8 @@ func loginRun(cmd *cobra.Command, args []string) error { return fmt.Errorf("Profile '%s' not found in your aws config", profile) } + updateMfaConfig(cmd, profiles, profile, &mfaConfig) + // check for an assume_role_ttl in the profile if we don't have a more explicit one if !cmd.Flags().Lookup("assume-role-ttl").Changed { if err := updateDurationFromConfigProfile(profiles, profile, &assumeRoleTTL); err != nil { @@ -77,7 +79,7 @@ func loginRun(cmd *cobra.Command, args []string) error { } opts := lib.ProviderOptions{ - MFADevice: mfaDevice, + MFAConfig: mfaConfig, Profiles: profiles, SessionDuration: sessionTTL, AssumeRoleDuration: assumeRoleTTL, diff --git a/cmd/root.go b/cmd/root.go index 11b9757f..cd374624 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -7,6 +7,7 @@ import ( "github.com/99designs/keyring" analytics "github.com/segmentio/analytics-go" + "github.com/segmentio/aws-okta/lib" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -22,13 +23,13 @@ var ( const ( // keep expected behavior pre-u2f with duo push - DefaultMFADevice = "phone1" + DefaultMFADuoDevice = "phone1" ) // global flags var ( backend string - mfaDevice string + mfaConfig lib.MFAConfig debug bool version string analyticsWriteKey string @@ -72,15 +73,6 @@ func prerun(cmd *cobra.Command, args []string) { } } - if !cmd.Flags().Lookup("mfa-device").Changed { - mfaDeviceFromEnv, ok := os.LookupEnv("AWS_OKTA_MFA_DEVICE") - if ok { - mfaDevice = mfaDeviceFromEnv - } else { - mfaDevice = DefaultMFADevice - } - } - if debug { log.SetLevel(log.DebugLevel) } @@ -111,7 +103,44 @@ func init() { for _, backendType := range keyring.AvailableBackends() { backendsAvailable = append(backendsAvailable, string(backendType)) } - RootCmd.PersistentFlags().StringVarP(&mfaDevice, "mfa-device", "m", "phone1", "Device to use phone1, phone2, u2f or token") + RootCmd.PersistentFlags().StringVarP(&mfaConfig.Provider, "mfa-provider", "", "", "MFA Provider to use (eg DUO, OKTA, GOOGLE)") + RootCmd.PersistentFlags().StringVarP(&mfaConfig.FactorType, "mfa-factor-type", "", "", "MFA Factor Type to use (eg push, token:software:totp)") + RootCmd.PersistentFlags().StringVarP(&mfaConfig.DuoDevice, "mfa-duo-device", "", "phone1", "Device to use phone1, phone2, u2f or token") RootCmd.PersistentFlags().StringVarP(&backend, "backend", "b", "", fmt.Sprintf("Secret backend to use %s", backendsAvailable)) RootCmd.PersistentFlags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging") } + +func updateMfaConfig(cmd *cobra.Command, profiles lib.Profiles, profile string, config *lib.MFAConfig) { + if !cmd.Flags().Lookup("mfa-duo-device").Changed { + mfaDeviceFromEnv, ok := os.LookupEnv("AWS_OKTA_MFA_DUO_DEVICE") + if ok { + config.DuoDevice = mfaDeviceFromEnv + } else { + config.DuoDevice = DefaultMFADuoDevice + } + } + + if !cmd.Flags().Lookup("mfa-provider").Changed { + mfaProvider, ok := os.LookupEnv("AWS_OKTA_MFA_PROVIDER") + if ok { + config.Provider = mfaProvider + } else { + mfaProvider, _, err := profiles.GetValue(profile, "mfa_provider") + if err == nil { + config.Provider = mfaProvider + } + } + } + + if !cmd.Flags().Lookup("mfa-factor-type").Changed { + mfaFactorType, ok := os.LookupEnv("AWS_OKTA_MFA_FACTOR_TYPE") + if ok { + config.FactorType = mfaFactorType + } else { + mfaFactorType, _, err := profiles.GetValue(profile, "mfa_factor_type") + if err == nil { + config.FactorType = mfaFactorType + } + } + } +} diff --git a/lib/okta.go b/lib/okta.go index b7e6cf33..7fcf204f 100644 --- a/lib/okta.go +++ b/lib/okta.go @@ -42,7 +42,6 @@ type OktaClient struct { Password string UserAuth *OktaUserAuthn DuoClient *DuoClient - MFADevice string AccessKeyId string SecretAccessKey string SessionToken string @@ -51,6 +50,13 @@ type OktaClient struct { CookieJar http.CookieJar BaseURL *url.URL Domain string + MFAConfig MFAConfig +} + +type MFAConfig struct { + Provider string // Which MFA provider to use when presented with an MFA challenge + FactorType string // Which of the factor types of the MFA provider to use + DuoDevice string // Which DUO device to use for DUO MFA } type SAMLAssertion struct { @@ -66,9 +72,9 @@ type OktaCreds struct { Domain string } -func (c *OktaCreds) Validate(mfaDevice string) error { +func (c *OktaCreds) Validate(mfaConfig MFAConfig) error { // OktaClient assumes we're doing some AWS SAML calls, but Validate doesn't - o, err := NewOktaClient(*c, "", "", mfaDevice) + o, err := NewOktaClient(*c, "", "", mfaConfig) if err != nil { return err } @@ -92,7 +98,7 @@ func getOktaDomain(region string) (string, error) { return "", fmt.Errorf("invalid region %s", region) } -func NewOktaClient(creds OktaCreds, oktaAwsSAMLUrl string, sessionCookie string, mfaDevice string) (*OktaClient, error) { +func NewOktaClient(creds OktaCreds, oktaAwsSAMLUrl string, sessionCookie string, mfaConfig MFAConfig) (*OktaClient, error) { var domain string // maintain compatibility for deprecated creds.Organization @@ -134,8 +140,8 @@ func NewOktaClient(creds OktaCreds, oktaAwsSAMLUrl string, sessionCookie string, OktaAwsSAMLUrl: oktaAwsSAMLUrl, CookieJar: jar, BaseURL: base, - MFADevice: mfaDevice, Domain: domain, + MFAConfig: mfaConfig, }, nil } @@ -232,25 +238,53 @@ func (o *OktaClient) AuthenticateProfile(profileARN string, duration time.Durati return *samlResp.Credentials, sessionCookie, nil } -func selectMFADevice(factors []OktaUserAuthnFactor) (*OktaUserAuthnFactor, error) { - if len(factors) > 1 { - log.Info("Select a MFA from the following list") - for i, f := range factors { - log.Infof("%d: %s (%s)", i, f.Provider, f.FactorType) - } - i, err := Prompt("Select MFA method", false) - if err != nil { - return nil, err - } - factor, err := strconv.Atoi(i) - if err != nil { - return nil, err +func selectMFADeviceFromConfig(o *OktaClient) (*OktaUserAuthnFactor, error) { + log.Debugf("MFAConfig: %v\n", o.MFAConfig) + if o.MFAConfig.Provider == "" || o.MFAConfig.FactorType == "" { + return nil, nil + } + + for _, f := range o.UserAuth.Embedded.Factors { + log.Debugf("%v\n", f) + if strings.EqualFold(f.Provider, o.MFAConfig.Provider) && strings.EqualFold(f.FactorType, o.MFAConfig.FactorType) { + log.Debugf("Using matching factor \"%v %v\" from config\n", f.Provider, f.FactorType) + return &f, nil } - return &factors[factor], nil + } + + return nil, fmt.Errorf("Failed to select MFA device with Provider = \"%s\", FactorType = \"%s\"", o.MFAConfig.Provider, o.MFAConfig.FactorType) +} + +func (o *OktaClient) selectMFADevice() (*OktaUserAuthnFactor, error) { + factors := o.UserAuth.Embedded.Factors + if len(factors) == 0 { + return nil, errors.New("No available MFA Factors") } else if len(factors) == 1 { return &factors[0], nil } - return nil, errors.New("Failed to select MFA device") + + factor, err := selectMFADeviceFromConfig(o) + if err != nil { + return nil, err + } + + if factor != nil { + return factor, nil + } + + log.Info("Select a MFA from the following list") + for i, f := range factors { + log.Infof("%d: %s (%s)", i, f.Provider, f.FactorType) + } + i, err := Prompt("Select MFA method", false) + if err != nil { + return nil, err + } + factorIdx, err := strconv.Atoi(i) + if err != nil { + return nil, err + } + return &factors[factorIdx], nil } func (o *OktaClient) preChallenge(oktaFactorId, oktaFactorType string) ([]byte, error) { @@ -307,7 +341,7 @@ func (o *OktaClient) postChallenge(payload []byte, oktaFactorProvider string, ok Host: f.Embedded.Verification.Host, Signature: f.Embedded.Verification.Signature, Callback: f.Embedded.Verification.Links.Complete.Href, - Device: o.MFADevice, + Device: o.MFAConfig.DuoDevice, StateToken: o.UserAuth.StateToken, } @@ -355,7 +389,7 @@ func (o *OktaClient) challengeMFA() (err error) { var oktaFactorType string log.Debugf("%s", o.UserAuth.StateToken) - factor, err := selectMFADevice(o.UserAuth.Embedded.Factors) + factor, err := o.selectMFADevice() if err != nil { log.Debug("Failed to select MFA device") return @@ -489,10 +523,10 @@ type OktaProvider struct { ProfileARN string SessionDuration time.Duration OktaAwsSAMLUrl string - MFADevice string // OktaSessionCookieKey represents the name of the session cookie // to be stored in the keyring. OktaSessionCookieKey string + MFAConfig MFAConfig } func (p *OktaProvider) Retrieve() (sts.Credentials, string, error) { @@ -515,7 +549,7 @@ func (p *OktaProvider) Retrieve() (sts.Credentials, string, error) { sessionCookie = string(cookieItem.Data) } - oktaClient, err := NewOktaClient(oktaCreds, p.OktaAwsSAMLUrl, sessionCookie, p.MFADevice) + oktaClient, err := NewOktaClient(oktaCreds, p.OktaAwsSAMLUrl, sessionCookie, p.MFAConfig) if err != nil { return sts.Credentials{}, "", err } @@ -526,9 +560,9 @@ func (p *OktaProvider) Retrieve() (sts.Credentials, string, error) { } newCookieItem := keyring.Item{ - Key: p.OktaSessionCookieKey, - Data: []byte(newSessionCookie), - Label: "okta session cookie", + Key: p.OktaSessionCookieKey, + Data: []byte(newSessionCookie), + Label: "okta session cookie", KeychainNotTrustApplication: false, } diff --git a/lib/provider.go b/lib/provider.go index 217616d7..01d4eb95 100644 --- a/lib/provider.go +++ b/lib/provider.go @@ -26,11 +26,11 @@ const ( ) type ProviderOptions struct { - MFADevice string SessionDuration time.Duration AssumeRoleDuration time.Duration ExpiryWindow time.Duration Profiles Profiles + MFAConfig MFAConfig } func (o ProviderOptions) Validate() error { @@ -167,7 +167,7 @@ func (p *Provider) getSamlSessionCreds() (sts.Credentials, error) { } provider := OktaProvider{ - MFADevice: p.ProviderOptions.MFADevice, + MFAConfig: p.ProviderOptions.MFAConfig, Keyring: p.keyring, ProfileARN: profileARN, SessionDuration: p.SessionDuration, @@ -198,7 +198,7 @@ func (p *Provider) GetSAMLLoginURL() (*url.URL, error) { } provider := OktaProvider{ - MFADevice: p.ProviderOptions.MFADevice, + MFAConfig: p.ProviderOptions.MFAConfig, Keyring: p.keyring, ProfileARN: profileARN, SessionDuration: p.SessionDuration,