Skip to content
This repository has been archived by the owner on May 18, 2021. It is now read-only.

Commit

Permalink
Add a general method for retrieving profile configuration values
Browse files Browse the repository at this point in the history
  • Loading branch information
lsowen committed Jan 21, 2019
1 parent 9874247 commit ea9a898
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 35 deletions.
8 changes: 4 additions & 4 deletions cmd/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ func loadDurationFlagFromEnv(cmd *cobra.Command, flagName string, envVar string,
return nil
}

func updateDurationFromConfigProfile(profiles map[string]map[string]string, profile string, durationName string, val *time.Duration) error {
fromProfile, ok := profiles[profile]["assume_role_ttl"]
if !ok {
func updateDurationFromConfigProfile(profiles lib.Profiles, profile string, val *time.Duration) error {
fromProfile, _, err := profiles.GetValue(profile, "assume_role_ttl")
if err != nil {
return nil
}

Expand Down Expand Up @@ -118,7 +118,7 @@ func execRun(cmd *cobra.Command, args []string) error {

// check for an assume_role_ttl in the profile if we don't have a more explicit one
if !cmd.Flags().Lookup("assume-role-ttl").Changed {
if err := updateDurationFromConfigProfile(profiles, profile, "assume_role_ttl", &assumeRoleTTL); err != nil {
if err := updateDurationFromConfigProfile(profiles, profile, &assumeRoleTTL); err != nil {
fmt.Fprintln(os.Stderr, "warning: could not parse duration from profile config")
}
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func loginRun(cmd *cobra.Command, args []string) error {

// check for an assume_role_ttl in the profile if we don't have a more explicit one
if !cmd.Flags().Lookup("assume-role-ttl").Changed {
if err := updateDurationFromConfigProfile(profiles, profile, "assume_role_ttl", &assumeRoleTTL); err != nil {
if err := updateDurationFromConfigProfile(profiles, profile, &assumeRoleTTL); err != nil {
fmt.Fprintln(os.Stderr, "warning: could not parse duration from profile config")
}
}
Expand Down
36 changes: 31 additions & 5 deletions lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ import (
"github.com/vaughan0/go-ini"
)

type profiles map[string]map[string]string
type Profiles map[string]map[string]string

type config interface {
Parse() (profiles, error)
Parse() (Profiles, error)
}

type fileConfig struct {
Expand All @@ -37,7 +37,7 @@ func NewConfigFromEnv() (config, error) {
return &fileConfig{file: file}, nil
}

func (c *fileConfig) Parse() (profiles, error) {
func (c *fileConfig) Parse() (Profiles, error) {
if c.file == "" {
return nil, nil
}
Expand All @@ -48,7 +48,7 @@ func (c *fileConfig) Parse() (profiles, error) {
return nil, fmt.Errorf("Error parsing config file %q: %v", c.file, err)
}

profiles := profiles{"okta": map[string]string{}}
profiles := Profiles{"okta": map[string]string{}}
for sectionName, section := range f {
profiles[strings.TrimPrefix(sectionName, "profile ")] = section
}
Expand All @@ -57,11 +57,37 @@ func (c *fileConfig) Parse() (profiles, error) {
}

// sourceProfile returns either the defined source_profile or p if none exists
func sourceProfile(p string, from profiles) string {
func sourceProfile(p string, from Profiles) string {
if conf, ok := from[p]; ok {
if source := conf["source_profile"]; source != "" {
return source
}
}
return p
}

func (p Profiles) GetValue(profile string, config_key string) (string, string, error) {
config_value, ok := p[profile][config_key]
if ok {
return config_value, profile, nil
}

// Lookup from the `source_profile`, if it exists
profile, ok = p[profile]["source_profile"]
if ok {
config_value, ok := p[profile][config_key]
if ok {
return config_value, profile, nil
}

}

// Fallback to `okta` if no profile supplies the value
profile = "okta"
config_value, ok = p[profile][config_key]
if ok {
return config_value, profile, nil
}

return "", "", fmt.Errorf("Could not find %s in %s, source profile, or okta", config_key, profile)
}
95 changes: 95 additions & 0 deletions lib/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package lib

import "testing"

func TestGetConfigValue(t *testing.T) {
config_profiles := make(Profiles)

t.Run("empty profile", func(t *testing.T) {
_, _, found_error := config_profiles.GetValue("profile_a", "config_key")
if found_error == nil {
t.Error("Searching an empty profile set should return an error")
}
})

config_profiles["okta"] = map[string]string{
"key_a": "a",
"key_b": "b",
}

config_profiles["profile_a"] = map[string]string{
"key_b": "b-a",
"key_c": "c-a",
"key_d": "d-a",
}

config_profiles["profile_b"] = map[string]string{
"source_profile": "profile_a",
"key_d": "d-b",
"key_e": "e-b",
}

config_profiles["profile_c"] = map[string]string{
"source_profile": "profile_b",
"key_f": "f-c",
}

t.Run("missing key", func(t *testing.T) {
_, _, found_error := config_profiles.GetValue("profile_a", "config_key")
if found_error == nil {
t.Error("Searching for a missing key should return an error")
}
})

t.Run("fallback to okta", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValue("profile_a", "key_a")
if found_error != nil {
t.Error("Error when searching for key_a")
}

if found_profile != "okta" {
t.Error("key_a should have come from `okta`")
}

if found_value != "a" {
t.Error("The proper value for `key_a` should be `a`")
}
})

t.Run("found in current profile", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValue("profile_b", "key_d")
if found_error != nil {
t.Error("Error when searching for key_d")
}

if found_profile != "profile_b" {
t.Error("key_d should have come from `profile_b`")
}

if found_value != "d-b" {
t.Error("The proper value for `key_d` should be `d-b`")
}
})

t.Run("traversing from child profile", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValue("profile_b", "key_a")
if found_error != nil {
t.Error("Error when searching for key_a")
}

if found_profile != "okta" {
t.Error("key_a should have come from `okta`")
}

if found_value != "a" {
t.Error("The proper value for `key_a` should be `a`")
}
})

t.Run("recursive traversing from child profile", func(t *testing.T) {
_, _, found_error := config_profiles.GetValue("profile_c", "key_c")
if found_error == nil {
t.Error("Recursive searching should not work")
}
})
}
6 changes: 3 additions & 3 deletions lib/okta.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,9 +526,9 @@ func (p *OktaProvider) Retrieve() (sts.Credentials, string, error) {
}

newCookieItem := keyring.Item{
Key: p.OktaSessionCookieKey,
Data: []byte(newSessionCookie),
Label: "okta session cookie",
Key: p.OktaSessionCookieKey,
Data: []byte(newSessionCookie),
Label: "okta session cookie",
KeychainNotTrustApplication: false,
}

Expand Down
34 changes: 14 additions & 20 deletions lib/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type ProviderOptions struct {
SessionDuration time.Duration
AssumeRoleDuration time.Duration
ExpiryWindow time.Duration
Profiles profiles
Profiles Profiles
}

func (o ProviderOptions) Validate() error {
Expand Down Expand Up @@ -65,7 +65,7 @@ type Provider struct {
expires time.Time
keyring keyring.Keyring
sessions *KeyringSessions
profiles profiles
profiles Profiles
defaultRoleSessionName string
}

Expand Down Expand Up @@ -134,37 +134,31 @@ func (p *Provider) Retrieve() (credentials.Value, error) {
return value, nil
}

func (p *Provider) getSamlURL(source string) (string, error) {
haystack := []string{p.profile, source, "okta"}
for _, profile := range haystack {
oktaAwsSAMLUrl, ok := p.profiles[profile]["aws_saml_url"]
if ok {
log.Debugf("Using aws_saml_url from profile: %s", profile)
return oktaAwsSAMLUrl, nil
}
func (p *Provider) getSamlURL() (string, error) {
oktaAwsSAMLUrl, profile, err := p.profiles.GetValue(p.profile, "aws_saml_url")
if err != nil {
log.Debugf("Using aws_saml_url from profile: %s", profile)
return oktaAwsSAMLUrl, nil
}
return "", errors.New("aws_saml_url missing from ~/.aws/config")
}

func (p *Provider) getOktaSessionCookieKey(source string) string {
haystack := []string{p.profile, source, "okta"}
for _, profile := range haystack {
oktaSessionCookieKey, ok := p.profiles[profile]["okta_session_cookie_key"]
if ok {
log.Debugf("Using okta_session_cookie_key from profile: %s", profile)
return oktaSessionCookieKey
}
func (p *Provider) getOktaSessionCookieKey() string {
oktaSessionCookieKey, profile, err := p.profiles.GetValue(p.profile, "okta_session_cookie_key")
if err != nil {
log.Debugf("Using okta_session_cookie_key from profile: %s", profile)
return oktaSessionCookieKey
}
return "okta-session-cookie"
}

func (p *Provider) getSamlSessionCreds() (sts.Credentials, error) {
source := sourceProfile(p.profile, p.profiles)
oktaAwsSAMLUrl, err := p.getSamlURL(source)
oktaAwsSAMLUrl, err := p.getSamlURL()
if err != nil {
return sts.Credentials{}, err
}
oktaSessionCookieKey := p.getOktaSessionCookieKey(source)
oktaSessionCookieKey := p.getOktaSessionCookieKey()

profileARN, ok := p.profiles[source]["role_arn"]
if !ok {
Expand Down
4 changes: 2 additions & 2 deletions lib/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ type awsSession struct {

type KeyringSessions struct {
Keyring keyring.Keyring
Profiles profiles
Profiles Profiles
}

func NewKeyringSessions(k keyring.Keyring, p profiles) (*KeyringSessions, error) {
func NewKeyringSessions(k keyring.Keyring, p Profiles) (*KeyringSessions, error) {
return &KeyringSessions{
Keyring: k,
Profiles: p,
Expand Down

0 comments on commit ea9a898

Please sign in to comment.