diff --git a/cmd/exec.go b/cmd/exec.go index 9bbcce7e..64f84f52 100644 --- a/cmd/exec.go +++ b/cmd/exec.go @@ -54,9 +54,9 @@ func loadDurationFlagFromEnv(cmd *cobra.Command, flagName string, envVar string, return nil } -func updateDurationFromConfigProfile(profiles map[string]map[string]string, profile string, durationName string, val *time.Duration) error { - fromProfile, ok := profiles[profile]["assume_role_ttl"] - if !ok { +func updateDurationFromConfigProfile(profiles lib.Profiles, profile string, val *time.Duration) error { + fromProfile, _, err := profiles.GetValue(profile, "assume_role_ttl") + if err != nil { return nil } @@ -118,7 +118,7 @@ func execRun(cmd *cobra.Command, args []string) error { // check for an assume_role_ttl in the profile if we don't have a more explicit one if !cmd.Flags().Lookup("assume-role-ttl").Changed { - if err := updateDurationFromConfigProfile(profiles, profile, "assume_role_ttl", &assumeRoleTTL); err != nil { + if err := updateDurationFromConfigProfile(profiles, profile, &assumeRoleTTL); err != nil { fmt.Fprintln(os.Stderr, "warning: could not parse duration from profile config") } } diff --git a/cmd/login.go b/cmd/login.go index 9a997a69..31a1607a 100644 --- a/cmd/login.go +++ b/cmd/login.go @@ -70,7 +70,7 @@ func loginRun(cmd *cobra.Command, args []string) error { // check for an assume_role_ttl in the profile if we don't have a more explicit one if !cmd.Flags().Lookup("assume-role-ttl").Changed { - if err := updateDurationFromConfigProfile(profiles, profile, "assume_role_ttl", &assumeRoleTTL); err != nil { + if err := updateDurationFromConfigProfile(profiles, profile, &assumeRoleTTL); err != nil { fmt.Fprintln(os.Stderr, "warning: could not parse duration from profile config") } } diff --git a/lib/config.go b/lib/config.go index 1899ca1e..c60795eb 100644 --- a/lib/config.go +++ b/lib/config.go @@ -12,10 +12,10 @@ import ( "github.com/vaughan0/go-ini" ) -type profiles map[string]map[string]string +type Profiles map[string]map[string]string type config interface { - Parse() (profiles, error) + Parse() (Profiles, error) } type fileConfig struct { @@ -37,7 +37,7 @@ func NewConfigFromEnv() (config, error) { return &fileConfig{file: file}, nil } -func (c *fileConfig) Parse() (profiles, error) { +func (c *fileConfig) Parse() (Profiles, error) { if c.file == "" { return nil, nil } @@ -48,7 +48,7 @@ func (c *fileConfig) Parse() (profiles, error) { return nil, fmt.Errorf("Error parsing config file %q: %v", c.file, err) } - profiles := profiles{"okta": map[string]string{}} + profiles := Profiles{"okta": map[string]string{}} for sectionName, section := range f { profiles[strings.TrimPrefix(sectionName, "profile ")] = section } @@ -57,7 +57,7 @@ func (c *fileConfig) Parse() (profiles, error) { } // sourceProfile returns either the defined source_profile or p if none exists -func sourceProfile(p string, from profiles) string { +func sourceProfile(p string, from Profiles) string { if conf, ok := from[p]; ok { if source := conf["source_profile"]; source != "" { return source @@ -65,3 +65,29 @@ func sourceProfile(p string, from profiles) string { } return p } + +func (p Profiles) GetValue(profile string, config_key string) (string, string, error) { + config_value, ok := p[profile][config_key] + if ok { + return config_value, profile, nil + } + + // Lookup from the `source_profile`, if it exists + profile, ok = p[profile]["source_profile"] + if ok { + config_value, ok := p[profile][config_key] + if ok { + return config_value, profile, nil + } + + } + + // Fallback to `okta` if no profile supplies the value + profile = "okta" + config_value, ok = p[profile][config_key] + if ok { + return config_value, profile, nil + } + + return "", "", fmt.Errorf("Could not find %s in %s, source profile, or okta", config_key, profile) +} diff --git a/lib/config_test.go b/lib/config_test.go new file mode 100644 index 00000000..88cf4738 --- /dev/null +++ b/lib/config_test.go @@ -0,0 +1,95 @@ +package lib + +import "testing" + +func TestGetConfigValue(t *testing.T) { + config_profiles := make(Profiles) + + t.Run("empty profile", func(t *testing.T) { + _, _, found_error := config_profiles.GetValue("profile_a", "config_key") + if found_error == nil { + t.Error("Searching an empty profile set should return an error") + } + }) + + config_profiles["okta"] = map[string]string{ + "key_a": "a", + "key_b": "b", + } + + config_profiles["profile_a"] = map[string]string{ + "key_b": "b-a", + "key_c": "c-a", + "key_d": "d-a", + } + + config_profiles["profile_b"] = map[string]string{ + "source_profile": "profile_a", + "key_d": "d-b", + "key_e": "e-b", + } + + config_profiles["profile_c"] = map[string]string{ + "source_profile": "profile_b", + "key_f": "f-c", + } + + t.Run("missing key", func(t *testing.T) { + _, _, found_error := config_profiles.GetValue("profile_a", "config_key") + if found_error == nil { + t.Error("Searching for a missing key should return an error") + } + }) + + t.Run("fallback to okta", func(t *testing.T) { + found_value, found_profile, found_error := config_profiles.GetValue("profile_a", "key_a") + if found_error != nil { + t.Error("Error when searching for key_a") + } + + if found_profile != "okta" { + t.Error("key_a should have come from `okta`") + } + + if found_value != "a" { + t.Error("The proper value for `key_a` should be `a`") + } + }) + + t.Run("found in current profile", func(t *testing.T) { + found_value, found_profile, found_error := config_profiles.GetValue("profile_b", "key_d") + if found_error != nil { + t.Error("Error when searching for key_d") + } + + if found_profile != "profile_b" { + t.Error("key_d should have come from `profile_b`") + } + + if found_value != "d-b" { + t.Error("The proper value for `key_d` should be `d-b`") + } + }) + + t.Run("traversing from child profile", func(t *testing.T) { + found_value, found_profile, found_error := config_profiles.GetValue("profile_b", "key_a") + if found_error != nil { + t.Error("Error when searching for key_a") + } + + if found_profile != "okta" { + t.Error("key_a should have come from `okta`") + } + + if found_value != "a" { + t.Error("The proper value for `key_a` should be `a`") + } + }) + + t.Run("recursive traversing from child profile", func(t *testing.T) { + _, _, found_error := config_profiles.GetValue("profile_c", "key_c") + if found_error == nil { + t.Error("Recursive searching should not work") + } + }) +} diff --git a/lib/okta.go b/lib/okta.go index bf4bb0aa..a2f0cb32 100644 --- a/lib/okta.go +++ b/lib/okta.go @@ -526,9 +526,9 @@ func (p *OktaProvider) Retrieve() (sts.Credentials, string, error) { } newCookieItem := keyring.Item{ - Key: p.OktaSessionCookieKey, - Data: []byte(newSessionCookie), - Label: "okta session cookie", + Key: p.OktaSessionCookieKey, + Data: []byte(newSessionCookie), + Label: "okta session cookie", KeychainNotTrustApplication: false, } diff --git a/lib/provider.go b/lib/provider.go index fda760c5..7ea98c56 100644 --- a/lib/provider.go +++ b/lib/provider.go @@ -29,7 +29,7 @@ type ProviderOptions struct { SessionDuration time.Duration AssumeRoleDuration time.Duration ExpiryWindow time.Duration - Profiles profiles + Profiles Profiles } func (o ProviderOptions) Validate() error { @@ -65,7 +65,7 @@ type Provider struct { expires time.Time keyring keyring.Keyring sessions *KeyringSessions - profiles profiles + profiles Profiles defaultRoleSessionName string } @@ -134,37 +134,31 @@ func (p *Provider) Retrieve() (credentials.Value, error) { return value, nil } -func (p *Provider) getSamlURL(source string) (string, error) { - haystack := []string{p.profile, source, "okta"} - for _, profile := range haystack { - oktaAwsSAMLUrl, ok := p.profiles[profile]["aws_saml_url"] - if ok { - log.Debugf("Using aws_saml_url from profile: %s", profile) - return oktaAwsSAMLUrl, nil - } +func (p *Provider) getSamlURL() (string, error) { + oktaAwsSAMLUrl, profile, err := p.profiles.GetValue(p.profile, "aws_saml_url") + if err != nil { + log.Debugf("Using aws_saml_url from profile: %s", profile) + return oktaAwsSAMLUrl, nil } return "", errors.New("aws_saml_url missing from ~/.aws/config") } -func (p *Provider) getOktaSessionCookieKey(source string) string { - haystack := []string{p.profile, source, "okta"} - for _, profile := range haystack { - oktaSessionCookieKey, ok := p.profiles[profile]["okta_session_cookie_key"] - if ok { - log.Debugf("Using okta_session_cookie_key from profile: %s", profile) - return oktaSessionCookieKey - } +func (p *Provider) getOktaSessionCookieKey() string { + oktaSessionCookieKey, profile, err := p.profiles.GetValue(p.profile, "okta_session_cookie_key") + if err != nil { + log.Debugf("Using okta_session_cookie_key from profile: %s", profile) + return oktaSessionCookieKey } return "okta-session-cookie" } func (p *Provider) getSamlSessionCreds() (sts.Credentials, error) { source := sourceProfile(p.profile, p.profiles) - oktaAwsSAMLUrl, err := p.getSamlURL(source) + oktaAwsSAMLUrl, err := p.getSamlURL() if err != nil { return sts.Credentials{}, err } - oktaSessionCookieKey := p.getOktaSessionCookieKey(source) + oktaSessionCookieKey := p.getOktaSessionCookieKey() profileARN, ok := p.profiles[source]["role_arn"] if !ok { diff --git a/lib/sessions.go b/lib/sessions.go index 568a8e0a..75106e8c 100644 --- a/lib/sessions.go +++ b/lib/sessions.go @@ -22,10 +22,10 @@ type awsSession struct { type KeyringSessions struct { Keyring keyring.Keyring - Profiles profiles + Profiles Profiles } -func NewKeyringSessions(k keyring.Keyring, p profiles) (*KeyringSessions, error) { +func NewKeyringSessions(k keyring.Keyring, p Profiles) (*KeyringSessions, error) { return &KeyringSessions{ Keyring: k, Profiles: p,