From 6d14c2cc7b8981bf541fce8371123dbacc88b764 Mon Sep 17 00:00:00 2001 From: Vivian <2908189+vivianho@users.noreply.github.com> Date: Tue, 2 Apr 2019 10:37:28 -0700 Subject: [PATCH] feat: Add --domain and --username for aws-okta add (#137) --- cmd/add.go | 68 +++++++++++++++++++++++++++++++++++------------------ lib/okta.go | 9 +++---- 2 files changed, 50 insertions(+), 27 deletions(-) diff --git a/cmd/add.go b/cmd/add.go index 039768b8..f9935faf 100644 --- a/cmd/add.go +++ b/cmd/add.go @@ -12,6 +12,12 @@ import ( "github.com/spf13/cobra" ) +var ( + organization string + oktaDomain string + oktaRegion string +) + // addCmd represents the add command var addCmd = &cobra.Command{ Use: "add", @@ -21,6 +27,8 @@ var addCmd = &cobra.Command{ func init() { RootCmd.AddCommand(addCmd) + addCmd.Flags().StringVarP(&oktaDomain, "domain", "", "", "Okta domain (e.g. .okta.com)") + addCmd.Flags().StringVarP(&username, "username", "", "", "Okta username") } func add(cmd *cobra.Command, args []string) error { @@ -45,30 +53,44 @@ func add(cmd *cobra.Command, args []string) error { }) } - // Ask username password from prompt - organization, err := lib.Prompt("Okta organization", false) - if err != nil { - return err + // Ask Okta organization details if not given in command line argument + if oktaDomain == "" { + organization, err = lib.Prompt("Okta organization", false) + if err != nil { + return err + } + + oktaRegion, err = lib.Prompt("Okta region ([us], emea, preview)", false) + if err != nil { + return err + } + if oktaRegion == "" { + oktaRegion = "us" + } + + tld, err := lib.GetOktaDomain(oktaRegion) + if err != nil { + return err + } + defaultOktaDomain := fmt.Sprintf("%s.%s", organization, tld) + + oktaDomain, err = lib.Prompt("Okta domain ["+defaultOktaDomain+"]", false) + if err != nil { + return err + } + if oktaDomain == "" { + oktaDomain = defaultOktaDomain + } } - oktaRegion, err := lib.Prompt("Okta region ([us], emea, preview)", false) - if err != nil { - return err - } - if oktaRegion == "" { - oktaRegion = "us" - } - - oktaDomain, err := lib.Prompt("Okta domain ["+oktaRegion+".okta.com]", false) - if err != nil { - return err - } - - username, err := lib.Prompt("Okta username", false) - if err != nil { - return err + if username == "" { + username, err = lib.Prompt("Okta username", false) + if err != nil { + return err + } } + // Ask for password from prompt password, err := lib.Prompt("Okta password", true) if err != nil { return err @@ -98,9 +120,9 @@ func add(cmd *cobra.Command, args []string) error { } item := keyring.Item{ - Key: "okta-creds", - Data: encoded, - Label: "okta credentials", + Key: "okta-creds", + Data: encoded, + Label: "okta credentials", KeychainNotTrustApplication: false, } diff --git a/lib/okta.go b/lib/okta.go index 943a4902..6485150c 100644 --- a/lib/okta.go +++ b/lib/okta.go @@ -86,7 +86,7 @@ func (c *OktaCreds) Validate(mfaConfig MFAConfig) error { return nil } -func getOktaDomain(region string) (string, error) { +func GetOktaDomain(region string) (string, error) { switch region { case "us": return OktaServerUs, nil @@ -131,6 +131,7 @@ func NewOktaClient(creds OktaCreds, oktaAwsSAMLUrl string, sessionCookie string, }, }) } + log.Debug("domain: " + domain) return &OktaClient{ // Setting Organization for backwards compatibility @@ -561,9 +562,9 @@ func (p *OktaProvider) Retrieve() (sts.Credentials, string, error) { } newCookieItem := keyring.Item{ - Key: p.OktaSessionCookieKey, - Data: []byte(newSessionCookie), - Label: "okta session cookie", + Key: p.OktaSessionCookieKey, + Data: []byte(newSessionCookie), + Label: "okta session cookie", KeychainNotTrustApplication: false, }