Skip to content

Commit

Permalink
Better retry for when the cached access token has been invalidated
Browse files Browse the repository at this point in the history
outside of okta-aws-cli's control.

Closes #207
Closes #198
  • Loading branch information
monde committed Jul 8, 2024
1 parent d3f3618 commit 771a778
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 96 deletions.
26 changes: 22 additions & 4 deletions cmd/root/web/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/okta/okta-aws-cli/internal/config"
cliFlag "github.com/okta/okta-aws-cli/internal/flag"
"github.com/okta/okta-aws-cli/internal/okta"
"github.com/okta/okta-aws-cli/internal/webssoauth"
)

Expand Down Expand Up @@ -82,16 +83,33 @@ func NewWebCommand() *cobra.Command {
if err != nil {
return err
}

err = cliFlag.CheckRequiredFlags(requiredFlags)
if err != nil {
return err
}

wsa, err := webssoauth.NewWebSSOAuthentication(config)
if err != nil {
return err
for attempt := 1; attempt <= 2; attempt++ {
wsa, err := webssoauth.NewWebSSOAuthentication(config)
if err != nil {
break
}

err = wsa.EstablishIAMCredentials()
if err == nil {
break
}

if apiErr, ok := err.(*okta.APIError); ok {
if apiErr.ErrorType == "invalid_grant" && webssoauth.RemoveCachedAccessToken() {
webssoauth.ConsolePrint(config, "\nCached access token appears to be stale, removing token and retrying device authorization ...\n\n")
continue
}
break
}
}
return wsa.EstablishIAMCredentials()

return err
},
}

Expand Down
83 changes: 81 additions & 2 deletions internal/okta/apierror.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,87 @@

package okta

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"github.com/BurntSushi/toml"
)

const (
// APIErrorMessageBase base API error message
APIErrorMessageBase = "the API returned an unknown error"
// APIErrorMessageWithErrorDescription API error message with description
APIErrorMessageWithErrorDescription = "the API returned an error: %s"
// APIErrorMessageWithErrorSummary API error message with summary
APIErrorMessageWithErrorSummary = "the API returned an error: %s"
// HTTPHeaderWwwAuthenticate Www-Authenticate header
HTTPHeaderWwwAuthenticate = "Www-Authenticate"
)

// APIError Wrapper for Okta API error
type APIError struct {
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ErrorType string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorCode string `json:"errorCode,omitempty"`
ErrorSummary string `json:"errorSummary,omitempty" toml:"error_description"`
ErrorLink string `json:"errorLink,omitempty"`
ErrorID string `json:"errorId,omitempty"`
ErrorCauses []map[string]interface{} `json:"errorCauses,omitempty"`
}

// Error String-ify the Error
func (e *APIError) Error() string {
formattedErr := APIErrorMessageBase
if e.ErrorDescription != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorDescription, e.ErrorDescription)
} else if e.ErrorSummary != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorSummary, e.ErrorSummary)
}
if len(e.ErrorCauses) > 0 {
var causes []string
for _, cause := range e.ErrorCauses {
for key, val := range cause {
causes = append(causes, fmt.Sprintf("%s: %v", key, val))
}
}
formattedErr = fmt.Sprintf("%s. Causes: %s", formattedErr, strings.Join(causes, ", "))
}
return formattedErr
}

// NewAPIError Constructor for Okta API error, will return nil if the response
// is not an error.
func NewAPIError(resp *http.Response) error {
statusCode := resp.StatusCode
if statusCode >= http.StatusOK && statusCode < http.StatusBadRequest {
return nil
}
e := APIError{}
if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) &&
strings.Contains(resp.Header.Get(HTTPHeaderWwwAuthenticate), "Bearer") {
for _, v := range strings.Split(resp.Header.Get(HTTPHeaderWwwAuthenticate), ", ") {
if strings.Contains(v, "error_description") {
_, err := toml.Decode(v, &e)
if err != nil {
e.ErrorSummary = "unauthorized"
}
return &e
}
}
}
bodyBytes, _ := io.ReadAll(resp.Body)
copyBodyBytes := make([]byte, len(bodyBytes))
copy(copyBodyBytes, bodyBytes)
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
_ = json.NewDecoder(bytes.NewReader(copyBodyBytes)).Decode(&e)
if statusCode == http.StatusInternalServerError {
e.ErrorSummary += fmt.Sprintf(", x-okta-request-id=%s", resp.Header.Get("x-okta-request-id"))
}
return &e
}
84 changes: 18 additions & 66 deletions internal/paginator/paginator.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,37 @@
/*
* Copyright (c) 2024-Present, Okta, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package paginator

import (
"bytes"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"

"github.com/BurntSushi/toml"
"github.com/okta/okta-aws-cli/internal/okta"
)

const (
// HTTPHeaderWwwAuthenticate Www-Authenticate header
HTTPHeaderWwwAuthenticate = "Www-Authenticate"
// APIErrorMessageBase base API error message
APIErrorMessageBase = "the API returned an unknown error"
// APIErrorMessageWithErrorDescription API error message with description
APIErrorMessageWithErrorDescription = "the API returned an error: %s"
// APIErrorMessageWithErrorSummary API error message with summary
Expand Down Expand Up @@ -136,7 +149,7 @@ func newPaginateResponse(r *http.Response, pgntr *Paginator) *PaginateResponse {
func buildPaginateResponse(resp *http.Response, pgntr *Paginator, v interface{}) (*PaginateResponse, error) {
ct := resp.Header.Get("Content-Type")
response := newPaginateResponse(resp, pgntr)
err := checkResponseForError(resp)
err := okta.NewAPIError(resp)
if err != nil {
return response, err
}
Expand Down Expand Up @@ -167,64 +180,3 @@ func buildPaginateResponse(resp *http.Response, pgntr *Paginator, v interface{})
}
return response, nil
}

func checkResponseForError(resp *http.Response) error {
statusCode := resp.StatusCode
if statusCode >= http.StatusOK && statusCode < http.StatusBadRequest {
return nil
}
e := Error{}
if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) &&
strings.Contains(resp.Header.Get(HTTPHeaderWwwAuthenticate), "Bearer") {
for _, v := range strings.Split(resp.Header.Get(HTTPHeaderWwwAuthenticate), ", ") {
if strings.Contains(v, "error_description") {
_, err := toml.Decode(v, &e)
if err != nil {
e.ErrorSummary = "unauthorized"
}
return &e
}
}
}
bodyBytes, _ := io.ReadAll(resp.Body)
copyBodyBytes := make([]byte, len(bodyBytes))
copy(copyBodyBytes, bodyBytes)
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
_ = json.NewDecoder(bytes.NewReader(copyBodyBytes)).Decode(&e)
if statusCode == http.StatusInternalServerError {
e.ErrorSummary += fmt.Sprintf(", x-okta-request-id=%s", resp.Header.Get("x-okta-request-id"))
}
return &e
}

// Error A struct for marshalling Okta's API error response bodies
type Error struct {
ErrorMessage string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorCode string `json:"errorCode,omitempty"`
ErrorSummary string `json:"errorSummary,omitempty" toml:"error_description"`
ErrorLink string `json:"errorLink,omitempty"`
ErrorID string `json:"errorId,omitempty"`
ErrorCauses []map[string]interface{} `json:"errorCauses,omitempty"`
}

// Error String-ify the Error
func (e *Error) Error() string {
formattedErr := APIErrorMessageBase
if e.ErrorDescription != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorDescription, e.ErrorDescription)
} else if e.ErrorSummary != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorSummary, e.ErrorSummary)
}
if len(e.ErrorCauses) > 0 {
var causes []string
for _, cause := range e.ErrorCauses {
for key, val := range cause {
causes = append(causes, fmt.Sprintf("%s: %v", key, val))
}
}
formattedErr = fmt.Sprintf("%s. Causes: %s", formattedErr, strings.Join(causes, ", "))
}
return formattedErr
}
64 changes: 40 additions & 24 deletions internal/webssoauth/webssoauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (
"net/url"
"os"
osexec "os/exec"
"os/user"
"path/filepath"
"regexp"
"strings"
Expand Down Expand Up @@ -723,16 +722,9 @@ func (w *WebSSOAuthentication) fetchSSOWebToken(clientID, awsFedAppID string, at
return nil, err
}

if resp.StatusCode != http.StatusOK {
baseErrStr := "fetching SSO web token received API response %q"

var apiErr okta.APIError
err = json.NewDecoder(resp.Body).Decode(&apiErr)
if err != nil {
return nil, fmt.Errorf(baseErrStr, resp.Status)
}

return nil, fmt.Errorf(baseErrStr+okta.AccessTokenErrorFormat, resp.Status, apiErr.Error, apiErr.ErrorDescription)
err = okta.NewAPIError(resp)
if err != nil {
return nil, err
}

token = &okta.AccessToken{}
Expand Down Expand Up @@ -956,8 +948,8 @@ func (w *WebSSOAuthentication) accessToken(deviceAuth *okta.DeviceAuthorization)
if err != nil {
return backoff.Permanent(fmt.Errorf("fetching access token polling received unexpected API error body %q", string(bodyBytes)))
}
if apiErr.Error != "authorization_pending" {
return backoff.Permanent(fmt.Errorf("fetching access token polling received unexpected API polling error %q - %q", apiErr.Error, apiErr.ErrorDescription))
if apiErr.ErrorType != "authorization_pending" {
return backoff.Permanent(fmt.Errorf("fetching access token polling received unexpected API polling error %q - %q", apiErr.ErrorType, apiErr.ErrorDescription))
}

return errors.New("continue polling")
Expand Down Expand Up @@ -1120,15 +1112,37 @@ func (w *WebSSOAuthentication) isClassicOrg() bool {
return false
}

// cachedAccessTokenPath Path to the cached access token in $HOME/.okta/awscli-access-token.json
func cachedAccessTokenPath() (string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(homeDir, dotOktaDir, tokenFileName), nil
}

// RemoveCachedAccessToken Remove cached access token if it exists. Returns true
// if the file exists was reremoved, swallows errors otherwise.
func RemoveCachedAccessToken() bool {
accessTokenPath, err := cachedAccessTokenPath()
if err != nil {
return false
}
if os.Remove(accessTokenPath) != nil {
return false
}

return true
}

// cachedAccessToken will returned the cached access token if it exists and is
// not expired.
func (w *WebSSOAuthentication) cachedAccessToken() (at *okta.AccessToken) {
homeDir, err := os.UserHomeDir()
accessTokenPath, err := cachedAccessTokenPath()
if err != nil {
return
}
configPath := filepath.Join(homeDir, dotOktaDir, tokenFileName)
atJSON, err := os.ReadFile(configPath)
atJSON, err := os.ReadFile(accessTokenPath)
if err != nil {
return
}
Expand Down Expand Up @@ -1158,15 +1172,12 @@ func (w *WebSSOAuthentication) cacheAccessToken(at *okta.AccessToken) {
return
}

cUser, err := user.Current()
homeDir, err := os.UserHomeDir()
if err != nil {
return
}
if cUser.HomeDir == "" {
return
}

oktaDir := filepath.Join(cUser.HomeDir, dotOktaDir)
oktaDir := filepath.Join(homeDir, dotOktaDir)
// noop if dir exists
err = os.MkdirAll(oktaDir, 0o700)
if err != nil {
Expand All @@ -1178,18 +1189,23 @@ func (w *WebSSOAuthentication) cacheAccessToken(at *okta.AccessToken) {
return
}

configPath := filepath.Join(cUser.HomeDir, dotOktaDir, tokenFileName)
configPath := filepath.Join(homeDir, dotOktaDir, tokenFileName)
_ = os.WriteFile(configPath, atJSON, 0o600)
}

func (w *WebSSOAuthentication) consolePrint(format string, a ...any) {
if w.config.IsProcessCredentialsFormat() {
// ConsolePrint printf formatted warning messages.
func ConsolePrint(config *config.Config, format string, a ...any) {
if config.IsProcessCredentialsFormat() {
return
}

fmt.Fprintf(os.Stderr, format, a...)
}

func (w *WebSSOAuthentication) consolePrint(format string, a ...any) {
ConsolePrint(w.config, format, a...)
}

// fetchAllAWSCredentialsWithSAMLRole Gets all AWS Credentials with an STS Assume Role with SAML AWS API call.
func (w *WebSSOAuthentication) fetchAllAWSCredentialsWithSAMLRole(idpRolesMap map[string][]string, assertion, region string) <-chan *oaws.CredentialContainer {
ccch := make(chan *oaws.CredentialContainer)
Expand Down

0 comments on commit 771a778

Please sign in to comment.