Skip to content
This repository has been archived by the owner on May 18, 2021. It is now read-only.

Commit

Permalink
Remove recursive traversal, export Profiles, and change to `GetValu…
Browse files Browse the repository at this point in the history
…e()`
  • Loading branch information
lsowen committed Jan 18, 2019
1 parent 7f49240 commit f80bc5e
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 32 deletions.
4 changes: 2 additions & 2 deletions cmd/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ func loadDurationFlagFromEnv(cmd *cobra.Command, flagName string, envVar string,
return nil
}

func updateDurationFromConfigProfile(profiles map[string]map[string]string, profile string, val *time.Duration) error {
fromProfile, _, err := lib.GetValueFromProfile(profile, profiles, "assume_role_ttl")
func updateDurationFromConfigProfile(profiles lib.Profiles, profile string, val *time.Duration) error {
fromProfile, _, err := profiles.GetValue(profile, "assume_role_ttl")
if err != nil {
return nil
}
Expand Down
14 changes: 7 additions & 7 deletions lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ import (
"github.com/vaughan0/go-ini"
)

type profiles map[string]map[string]string
type Profiles map[string]map[string]string

type config interface {
Parse() (profiles, error)
Parse() (Profiles, error)
}

type fileConfig struct {
Expand All @@ -37,7 +37,7 @@ func NewConfigFromEnv() (config, error) {
return &fileConfig{file: file}, nil
}

func (c *fileConfig) Parse() (profiles, error) {
func (c *fileConfig) Parse() (Profiles, error) {
if c.file == "" {
return nil, nil
}
Expand All @@ -48,7 +48,7 @@ func (c *fileConfig) Parse() (profiles, error) {
return nil, fmt.Errorf("Error parsing config file %q: %v", c.file, err)
}

profiles := profiles{"okta": map[string]string{}}
profiles := Profiles{"okta": map[string]string{}}
for sectionName, section := range f {
profiles[strings.TrimPrefix(sectionName, "profile ")] = section
}
Expand All @@ -57,7 +57,7 @@ func (c *fileConfig) Parse() (profiles, error) {
}

// sourceProfile returns either the defined source_profile or p if none exists
func sourceProfile(p string, from profiles) string {
func sourceProfile(p string, from Profiles) string {
if conf, ok := from[p]; ok {
if source := conf["source_profile"]; source != "" {
return source
Expand All @@ -66,8 +66,8 @@ func sourceProfile(p string, from profiles) string {
return p
}

func (p profiles) GetValueFromProfile(profile string, config_key string) (string, string, error) {
for {
func (p Profiles) GetValue(profile string, config_key string) (string, string, error) {
for i := 0; i < 2; i++ {
config_value, ok := p[profile][config_key]
if ok {
return config_value, profile, nil
Expand Down
26 changes: 9 additions & 17 deletions lib/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package lib
import "testing"

func TestGetConfigValue(t *testing.T) {
config_profiles := make(profiles)
config_profiles := make(Profiles)

t.Run("empty profile", func(t *testing.T) {
_, _, found_error := config_profiles.GetValueFromProfile("profile_a", "config_key")
_, _, found_error := config_profiles.GetValue("profile_a", "config_key")
if found_error == nil {
t.Error("Searching an empty profile set should return an error")
}
Expand Down Expand Up @@ -35,14 +35,14 @@ func TestGetConfigValue(t *testing.T) {
}

t.Run("missing key", func(t *testing.T) {
_, _, found_error := config_profiles.GetValueFromProfile("profile_a", "config_key")
_, _, found_error := config_profiles.GetValue("profile_a", "config_key")
if found_error == nil {
t.Error("Searching for a missing key should return an error")
}
})

t.Run("fallback to okta", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValueFromProfile("profile_a", "key_a")
found_value, found_profile, found_error := config_profiles.GetValue("profile_a", "key_a")
if found_error != nil {
t.Error("Error when searching for key_a")
}
Expand All @@ -57,7 +57,7 @@ func TestGetConfigValue(t *testing.T) {
})

t.Run("found in current profile", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValueFromProfile("profile_b", "key_d")
found_value, found_profile, found_error := config_profiles.GetValue("profile_b", "key_d")
if found_error != nil {
t.Error("Error when searching for key_d")
}
Expand All @@ -72,7 +72,7 @@ func TestGetConfigValue(t *testing.T) {
})

t.Run("traversing from child profile", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValueFromProfile("profile_b", "key_a")
found_value, found_profile, found_error := config_profiles.GetValue("profile_b", "key_a")
if found_error != nil {
t.Error("Error when searching for key_a")
}
Expand All @@ -87,17 +87,9 @@ func TestGetConfigValue(t *testing.T) {
})

t.Run("recursive traversing from child profile", func(t *testing.T) {
found_value, found_profile, found_error := config_profiles.GetValueFromProfile("profile_c", "key_c")
if found_error != nil {
t.Error("Error when searching for key_c")
}

if found_profile != "profile_a" {
t.Error("key_c should have come from `profile_a`")
}

if found_value != "c-a" {
t.Error("The proper value for `key_c` should be `c-a`")
_, _, found_error := config_profiles.GetValue("profile_c", "key_c")
if found_error == nil {
t.Error("Recursive searching should not work")
}
})
}
8 changes: 4 additions & 4 deletions lib/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type ProviderOptions struct {
SessionDuration time.Duration
AssumeRoleDuration time.Duration
ExpiryWindow time.Duration
Profiles profiles
Profiles Profiles
}

func (o ProviderOptions) Validate() error {
Expand Down Expand Up @@ -65,7 +65,7 @@ type Provider struct {
expires time.Time
keyring keyring.Keyring
sessions *KeyringSessions
profiles profiles
profiles Profiles
defaultRoleSessionName string
}

Expand Down Expand Up @@ -135,7 +135,7 @@ func (p *Provider) Retrieve() (credentials.Value, error) {
}

func (p *Provider) getSamlURL() (string, error) {
oktaAwsSAMLUrl, profile, err := p.profiles.GetValueFromProfile(p.profile, "aws_saml_url")
oktaAwsSAMLUrl, profile, err := p.profiles.GetValue(p.profile, "aws_saml_url")
if err != nil {
log.Debugf("Using aws_saml_url from profile: %s", profile)
return oktaAwsSAMLUrl, nil
Expand All @@ -144,7 +144,7 @@ func (p *Provider) getSamlURL() (string, error) {
}

func (p *Provider) getOktaSessionCookieKey() string {
oktaSessionCookieKey, profile, err := p.profiles.GetValueFromProfile(p.profile, "okta_session_cookie_key")
oktaSessionCookieKey, profile, err := p.profiles.GetValue(p.profile, "okta_session_cookie_key")
if err != nil {
log.Debugf("Using okta_session_cookie_key from profile: %s", profile)
return oktaSessionCookieKey
Expand Down
4 changes: 2 additions & 2 deletions lib/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ type awsSession struct {

type KeyringSessions struct {
Keyring keyring.Keyring
Profiles profiles
Profiles Profiles
}

func NewKeyringSessions(k keyring.Keyring, p profiles) (*KeyringSessions, error) {
func NewKeyringSessions(k keyring.Keyring, p Profiles) (*KeyringSessions, error) {
return &KeyringSessions{
Keyring: k,
Profiles: p,
Expand Down

0 comments on commit f80bc5e

Please sign in to comment.