diff --git a/cmd/exec.go b/cmd/exec.go index 2b24cd48..79970dd4 100644 --- a/cmd/exec.go +++ b/cmd/exec.go @@ -22,6 +22,7 @@ import ( var ( sessionTTL time.Duration assumeRoleTTL time.Duration + assumeRoleARN string ) func mustListProfiles() lib.Profiles { @@ -45,6 +46,7 @@ func init() { RootCmd.AddCommand(execCmd) execCmd.Flags().DurationVarP(&sessionTTL, "session-ttl", "t", time.Hour, "Expiration time for okta role session") execCmd.Flags().DurationVarP(&assumeRoleTTL, "assume-role-ttl", "a", time.Hour, "Expiration time for assumed role") + execCmd.Flags().StringVarP(&assumeRoleARN, "assume-role-arn", "r", "", "Role arn to assume, overrides arn in profile") } func loadDurationFlagFromEnv(cmd *cobra.Command, flagName string, envVar string, val *time.Duration) error { @@ -67,6 +69,21 @@ func loadDurationFlagFromEnv(cmd *cobra.Command, flagName string, envVar string, return nil } +func loadStringFlagFromEnv(cmd *cobra.Command, flagName string, envVar string, val *string) error { + if cmd.Flags().Lookup(flagName).Changed { + return nil + } + + fromEnv, ok := os.LookupEnv(envVar) + if !ok { + return nil + } + + cmd.Flags().Lookup(flagName).Changed = true + *val = fromEnv + return nil +} + func updateDurationFromConfigProfile(profiles lib.Profiles, profile string, val *time.Duration) error { fromProfile, _, err := profiles.GetValue(profile, "assume_role_ttl") if err != nil { @@ -90,6 +107,9 @@ func execPre(cmd *cobra.Command, args []string) { if err := loadDurationFlagFromEnv(cmd, "assume-role-ttl", "AWS_ASSUME_ROLE_TTL", &assumeRoleTTL); err != nil { fmt.Fprintln(os.Stderr, "warning: failed to parse duration from AWS_ASSUME_ROLE_TTL") } + if err := loadStringFlagFromEnv(cmd, "assume-role-arn", "AWS_ASSUME_ROLE_ARN", &assumeRoleARN); err != nil { + fmt.Fprintln(os.Stderr, "warning: failed to parse duration from AWS_ASSUME_ROLE_ARN") + } } func execRun(cmd *cobra.Command, args []string) error { @@ -143,6 +163,7 @@ func execRun(cmd *cobra.Command, args []string) error { Profiles: profiles, SessionDuration: sessionTTL, AssumeRoleDuration: assumeRoleTTL, + AssumeRoleArn: assumeRoleARN, } var allowedBackends []keyring.BackendType diff --git a/lib/provider.go b/lib/provider.go index a5cb2a51..d726176f 100644 --- a/lib/provider.go +++ b/lib/provider.go @@ -36,7 +36,7 @@ type ProviderOptions struct { ExpiryWindow time.Duration Profiles Profiles MFAConfig MFAConfig - + AssumeRoleArn string // if true, use store_singlekritem SessionCache (new) // if false, use store_kritempersession SessionCache (old) SessionCacheSingleItem bool @@ -201,6 +201,8 @@ func (p *Provider) getOktaSessionCookieKey() string { } func (p *Provider) getSamlSessionCreds() (sts.Credentials, error) { + var profileARN string + var ok bool source := sourceProfile(p.profile, p.profiles) oktaAwsSAMLUrl, err := p.getSamlURL() if err != nil { @@ -208,11 +210,16 @@ func (p *Provider) getSamlSessionCreds() (sts.Credentials, error) { } oktaSessionCookieKey := p.getOktaSessionCookieKey() - profileARN, ok := p.profiles[source]["role_arn"] - if !ok { - return sts.Credentials{}, errors.New("Source profile must provide `role_arn`") + // if the assumable role is passed it have it override what is in the profile + if p.AssumeRoleArn != "" { + profileARN = p.AssumeRoleArn + log.Debug("Overriding Assumable role with: ", profileARN) + } else { + profileARN, ok = p.profiles[source]["role_arn"] + if !ok { + return sts.Credentials{}, errors.New("Source profile must provide `role_arn`") + } } - provider := OktaProvider{ MFAConfig: p.ProviderOptions.MFAConfig, Keyring: p.keyring,