Skip to content

Commit

Permalink
Use TokenCache in ClientCredentialsTokenSourceProvider (#377)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewwdye authored Mar 7, 2023
1 parent a86c783 commit 60ba864
Show file tree
Hide file tree
Showing 3 changed files with 296 additions and 25 deletions.
54 changes: 54 additions & 0 deletions flyteidl/clients/go/admin/mocks/TokenSource.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

85 changes: 60 additions & 25 deletions flyteidl/clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ import (
"github.com/flyteorg/flytestdlib/logger"
)

//go:generate mockery -name TokenSource
type TokenSource interface {
Token() (*oauth2.Token, error)
}

const (
audienceKey = "audience"
)
Expand Down Expand Up @@ -68,7 +73,7 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T
}
}

tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL, audienceValue)
tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL, tokenCache, audienceValue)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -163,10 +168,12 @@ func GetPKCEAuthTokenSource(ctx context.Context, pkceTokenOrchestrator pkce.Toke

type ClientCredentialsTokenSourceProvider struct {
ccConfig clientcredentials.Config
TokenRefreshWindow time.Duration
tokenRefreshWindow time.Duration
tokenCache cache.TokenCache
}

func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string, audience string) (TokenSourceProvider, error) {
func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string,
tokenCache cache.TokenCache, audience string) (TokenSourceProvider, error) {
var secret string
if len(cfg.ClientSecretEnvVar) > 0 {
secret = os.Getenv(cfg.ClientSecretEnvVar)
Expand All @@ -183,6 +190,9 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s
endpointParams = url.Values{audienceKey: {audience}}
}
secret = strings.TrimSpace(secret)
if tokenCache == nil {
tokenCache = &cache.TokenCacheInMemoryProvider{}
}
return ClientCredentialsTokenSourceProvider{
ccConfig: clientcredentials.Config{
ClientID: cfg.ClientID,
Expand All @@ -191,56 +201,81 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s
Scopes: scopes,
EndpointParams: endpointParams,
},
TokenRefreshWindow: cfg.TokenRefreshWindow.Duration}, nil
tokenRefreshWindow: cfg.TokenRefreshWindow.Duration,
tokenCache: tokenCache}, nil
}

func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
if p.TokenRefreshWindow > 0 {
if p.tokenRefreshWindow > 0 {
source := p.ccConfig.TokenSource(ctx)
return &customTokenSource{
ctx: ctx,
new: source,
mu: sync.Mutex{},
t: nil,
tokenRefreshWindow: p.TokenRefreshWindow,
tokenRefreshWindow: p.tokenRefreshWindow,
tokenCache: p.tokenCache,
}, nil
}
return p.ccConfig.TokenSource(ctx), nil
}

type customTokenSource struct {
ctx context.Context
new oauth2.TokenSource
mu sync.Mutex // guards everything else
t *oauth2.Token
refreshTime time.Time
failedToRefresh bool
tokenRefreshWindow time.Duration
tokenCache cache.TokenCache
}

// fetchTokenFromCache returns the cached token if available, and a bool indicating if we should try to refresh it.
// This function is not thread safe and should be called with the lock held.
func (s *customTokenSource) fetchTokenFromCache() (*oauth2.Token, bool) {
token, err := s.tokenCache.GetToken()
if err != nil {
logger.Infof(s.ctx, "no token found in cache")
return nil, false
}
if !token.Valid() {
logger.Infof(s.ctx, "cached token invalid")
return nil, false
}
if time.Now().After(s.refreshTime) && !s.failedToRefresh {
logger.Infof(s.ctx, "cached token refresh window exceeded")
return token, true
}
logger.Infof(s.ctx, "using cached token")
return token, false
}

func (s *customTokenSource) Token() (*oauth2.Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.t.Valid() {
if time.Now().After(s.refreshTime) && !s.failedToRefresh {
t, err := s.new.Token()
if err != nil {
s.failedToRefresh = true // don't try to refresh again before expiry
return s.t, nil
}
s.t = t
s.refreshTime = s.t.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow))
s.failedToRefresh = false
return s.t, nil
}
return s.t, nil

cachedToken, needsRefresh := s.fetchTokenFromCache()
if cachedToken != nil && !needsRefresh {
return cachedToken, nil
}
t, err := s.new.Token()

token, err := s.new.Token()
if err != nil {
if needsRefresh {
logger.Warnf(s.ctx, "failed to refresh token, using last cached token until expired")
s.failedToRefresh = true
return cachedToken, nil
}
logger.Errorf(s.ctx, "failed to refresh token")
return nil, err
}
s.t = t
logger.Infof(s.ctx, "refreshed token")
err = s.tokenCache.SaveToken(token)
if err != nil {
logger.Warnf(s.ctx, "failed to cache token, using anyway")
}
s.failedToRefresh = false
s.refreshTime = s.t.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow))
return t, nil
s.refreshTime = token.Expiry.Add(-getRandomDuration(s.tokenRefreshWindow))
return token, nil
}

// Get random duration between 0 and maxDuration
Expand Down
Loading

0 comments on commit 60ba864

Please sign in to comment.