Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reset cached access token on invalid_grant #220

Merged
merged 6 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# Changelog

## 2.3.0 (July 12, 2024)

### ENHANCEMENTS

* New command `okta-aws-cli list-profiles` helper to inspect profiles in okta.yaml [#222](https://github.com/okta/okta-aws-cli/pull/222), thanks [@pmgalea](https://github.com/pmgalea)!
* GH releases publish Windows artifact to Chocolatey [#215](https://github.com/okta/okta-aws-cli/pull/215), thanks [@monde](https://github.com/monde)!
* Better retry for when the cached access token has been invalidated outside of okta-aws-cli's control. [#220](https://github.com/okta/okta-aws-cli/pull/220), thanks [@monde](https://github.com/monde)!
* Print a warning at first run if otka.yaml is malformed. [#220](https://github.com/okta/okta-aws-cli/pull/220), thanks [@monde](https://github.com/monde)!

### BUG FIXES

* Correct "default" profile flaw introduced in 2.2.0 release [#220](https://github.com/okta/okta-aws-cli/pull/220), thanks [@monde](https://github.com/monde)!
* Continue polling instead of exit on a 400 "slow_down" API error [#220](https://github.com/okta/okta-aws-cli/pull/220), thanks [@monde](https://github.com/monde)!

## 2.2.0 (July 3, 2024)

### ENHANCEMENTS
Expand Down
6 changes: 2 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ build: fmtcheck
go build -o $(GOBIN)/okta-aws-cli cmd/okta-aws-cli/main.go

clean:
go clean -cache -testcache ./...

clean-all:
go clean -cache -testcache -modcache ./...
rm -fr dist/
go clean -testcache

fmt: tools # Format the code
@$(GOFMT) -l -w .
Expand Down
2 changes: 1 addition & 1 deletion cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func init() {
{
Name: config.ProfileFlag,
Short: "p",
Value: "default",
Value: "",
Usage: "AWS Profile",
EnvVar: config.ProfileEnvVar,
},
Expand Down
34 changes: 29 additions & 5 deletions cmd/root/web/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/okta/okta-aws-cli/internal/config"
cliFlag "github.com/okta/okta-aws-cli/internal/flag"
"github.com/okta/okta-aws-cli/internal/okta"
"github.com/okta/okta-aws-cli/internal/webssoauth"
)

Expand Down Expand Up @@ -78,20 +79,43 @@ func NewWebCommand() *cobra.Command {
Use: "web",
Short: "Human oriented authentication and device authorization",
RunE: func(cmd *cobra.Command, args []string) error {
config, err := config.EvaluateSettings()
cfg, err := config.EvaluateSettings()
if err != nil {
return err
}
err = cliFlag.CheckRequiredFlags(requiredFlags)

// Warn if there is an issue with okta.yaml
_, err = config.OktaConfig()
if err != nil {
return err
webssoauth.ConsolePrint(cfg, "WARNING: issue with %s file. Run `okta-aws-cli debug` command for additional diagnosis.\nError: %+v\n", config.OktaYaml, err)
}

wsa, err := webssoauth.NewWebSSOAuthentication(config)
err = cliFlag.CheckRequiredFlags(requiredFlags)
if err != nil {
return err
}
return wsa.EstablishIAMCredentials()

for attempt := 1; attempt <= 2; attempt++ {
wsa, err := webssoauth.NewWebSSOAuthentication(cfg)
if err != nil {
break
}

err = wsa.EstablishIAMCredentials()
if err == nil {
break
}

if apiErr, ok := err.(*okta.APIError); ok {
if apiErr.ErrorType == "invalid_grant" && webssoauth.RemoveCachedAccessToken() {
webssoauth.ConsolePrint(cfg, "\nCached access token appears to be stale, removing token and retrying device authorization ...\n\n")
continue
}
break
}
}

return err
},
}

Expand Down
21 changes: 11 additions & 10 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func init() {

const (
// Version app version
Version = "2.2.0"
Version = "2.3.0"

////////////////////////////////////////////////////////////
// FORMATS
Expand Down Expand Up @@ -420,7 +420,17 @@ func readConfig() (Attributes, error) {
}
}

// config loading order
// 1) command line flags 2) environment variables, 3) .env file
awsProfile := viper.GetString(ProfileFlag)
// mimic AWS CLI behavior, if profile value is not set by flag check
// the ENV VAR, else set to "default"
if awsProfile == "" {
awsProfile = viper.GetString(downCase(ProfileEnvVar))
}
if awsProfile == "" {
awsProfile = "default"
}

attrs := Attributes{
AllProfiles: viper.GetBool(getFlagNameFromProfile(awsProfile, AllProfilesFlag)),
Expand Down Expand Up @@ -454,15 +464,6 @@ func readConfig() (Attributes, error) {
attrs.Format = EnvVarFormat
}

// mimic AWS CLI behavior, if profile value is not set by flag check
// the ENV VAR, else set to "default"
if attrs.Profile == "" {
attrs.Profile = viper.GetString(downCase(ProfileEnvVar))
}
if attrs.Profile == "" {
attrs.Profile = "default"
}

// Viper binds ENV VARs to a lower snake version, set the configs with them
// if they haven't already been set by cli flag binding.
if attrs.OrgDomain == "" {
Expand Down
83 changes: 81 additions & 2 deletions internal/okta/apierror.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,87 @@

package okta

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"github.com/BurntSushi/toml"
)

const (
// APIErrorMessageBase base API error message
APIErrorMessageBase = "the API returned an unknown error"
// APIErrorMessageWithErrorDescription API error message with description
APIErrorMessageWithErrorDescription = "the API returned an error: %s"
// APIErrorMessageWithErrorSummary API error message with summary
APIErrorMessageWithErrorSummary = "the API returned an error: %s"
// HTTPHeaderWwwAuthenticate Www-Authenticate header
HTTPHeaderWwwAuthenticate = "Www-Authenticate"
)

// APIError Wrapper for Okta API error
type APIError struct {
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ErrorType string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorCode string `json:"errorCode,omitempty"`
ErrorSummary string `json:"errorSummary,omitempty" toml:"error_description"`
ErrorLink string `json:"errorLink,omitempty"`
ErrorID string `json:"errorId,omitempty"`
ErrorCauses []map[string]interface{} `json:"errorCauses,omitempty"`
}

// Error String-ify the Error
func (e *APIError) Error() string {
formattedErr := APIErrorMessageBase
if e.ErrorDescription != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorDescription, e.ErrorDescription)
} else if e.ErrorSummary != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorSummary, e.ErrorSummary)
}
if len(e.ErrorCauses) > 0 {
var causes []string
for _, cause := range e.ErrorCauses {
for key, val := range cause {
causes = append(causes, fmt.Sprintf("%s: %v", key, val))
}
}
formattedErr = fmt.Sprintf("%s. Causes: %s", formattedErr, strings.Join(causes, ", "))
}
return formattedErr
}

// NewAPIError Constructor for Okta API error, will return nil if the response
// is not an error.
func NewAPIError(resp *http.Response) error {
statusCode := resp.StatusCode
if statusCode >= http.StatusOK && statusCode < http.StatusBadRequest {
return nil
}
e := APIError{}
if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) &&
strings.Contains(resp.Header.Get(HTTPHeaderWwwAuthenticate), "Bearer") {
for _, v := range strings.Split(resp.Header.Get(HTTPHeaderWwwAuthenticate), ", ") {
if strings.Contains(v, "error_description") {
_, err := toml.Decode(v, &e)
if err != nil {
e.ErrorSummary = "unauthorized"
}
return &e
}
}
}
bodyBytes, _ := io.ReadAll(resp.Body)
copyBodyBytes := make([]byte, len(bodyBytes))
copy(copyBodyBytes, bodyBytes)
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
_ = json.NewDecoder(bytes.NewReader(copyBodyBytes)).Decode(&e)
if statusCode == http.StatusInternalServerError {
e.ErrorSummary += fmt.Sprintf(", x-okta-request-id=%s", resp.Header.Get("x-okta-request-id"))
}
return &e
}
84 changes: 18 additions & 66 deletions internal/paginator/paginator.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,37 @@
/*
* Copyright (c) 2024-Present, Okta, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package paginator

import (
"bytes"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"

"github.com/BurntSushi/toml"
"github.com/okta/okta-aws-cli/internal/okta"
)

const (
// HTTPHeaderWwwAuthenticate Www-Authenticate header
HTTPHeaderWwwAuthenticate = "Www-Authenticate"
// APIErrorMessageBase base API error message
APIErrorMessageBase = "the API returned an unknown error"
// APIErrorMessageWithErrorDescription API error message with description
APIErrorMessageWithErrorDescription = "the API returned an error: %s"
// APIErrorMessageWithErrorSummary API error message with summary
Expand Down Expand Up @@ -136,7 +149,7 @@ func newPaginateResponse(r *http.Response, pgntr *Paginator) *PaginateResponse {
func buildPaginateResponse(resp *http.Response, pgntr *Paginator, v interface{}) (*PaginateResponse, error) {
ct := resp.Header.Get("Content-Type")
response := newPaginateResponse(resp, pgntr)
err := checkResponseForError(resp)
err := okta.NewAPIError(resp)
if err != nil {
return response, err
}
Expand Down Expand Up @@ -167,64 +180,3 @@ func buildPaginateResponse(resp *http.Response, pgntr *Paginator, v interface{})
}
return response, nil
}

func checkResponseForError(resp *http.Response) error {
statusCode := resp.StatusCode
if statusCode >= http.StatusOK && statusCode < http.StatusBadRequest {
return nil
}
e := Error{}
if (statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden) &&
strings.Contains(resp.Header.Get(HTTPHeaderWwwAuthenticate), "Bearer") {
for _, v := range strings.Split(resp.Header.Get(HTTPHeaderWwwAuthenticate), ", ") {
if strings.Contains(v, "error_description") {
_, err := toml.Decode(v, &e)
if err != nil {
e.ErrorSummary = "unauthorized"
}
return &e
}
}
}
bodyBytes, _ := io.ReadAll(resp.Body)
copyBodyBytes := make([]byte, len(bodyBytes))
copy(copyBodyBytes, bodyBytes)
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
_ = json.NewDecoder(bytes.NewReader(copyBodyBytes)).Decode(&e)
if statusCode == http.StatusInternalServerError {
e.ErrorSummary += fmt.Sprintf(", x-okta-request-id=%s", resp.Header.Get("x-okta-request-id"))
}
return &e
}

// Error A struct for marshalling Okta's API error response bodies
type Error struct {
ErrorMessage string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorCode string `json:"errorCode,omitempty"`
ErrorSummary string `json:"errorSummary,omitempty" toml:"error_description"`
ErrorLink string `json:"errorLink,omitempty"`
ErrorID string `json:"errorId,omitempty"`
ErrorCauses []map[string]interface{} `json:"errorCauses,omitempty"`
}

// Error String-ify the Error
func (e *Error) Error() string {
formattedErr := APIErrorMessageBase
if e.ErrorDescription != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorDescription, e.ErrorDescription)
} else if e.ErrorSummary != "" {
formattedErr = fmt.Sprintf(APIErrorMessageWithErrorSummary, e.ErrorSummary)
}
if len(e.ErrorCauses) > 0 {
var causes []string
for _, cause := range e.ErrorCauses {
for key, val := range cause {
causes = append(causes, fmt.Sprintf("%s: %v", key, val))
}
}
formattedErr = fmt.Sprintf("%s. Causes: %s", formattedErr, strings.Join(causes, ", "))
}
return formattedErr
}
Loading
Loading