diff --git a/clients/go/admin/token_source_provider.go b/clients/go/admin/token_source_provider.go index 937b6d901..41db678f6 100644 --- a/clients/go/admin/token_source_provider.go +++ b/clients/go/admin/token_source_provider.go @@ -208,12 +208,17 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) { if p.tokenRefreshWindow > 0 { source := p.ccConfig.TokenSource(ctx) + refreshTime := time.Time{} + if token, err := p.tokenCache.GetToken(); err == nil { + refreshTime = token.Expiry.Add(-getRandomDuration(p.tokenRefreshWindow)) + } return &customTokenSource{ ctx: ctx, new: source, mu: sync.Mutex{}, tokenRefreshWindow: p.tokenRefreshWindow, tokenCache: p.tokenCache, + refreshTime: refreshTime, }, nil } return p.ccConfig.TokenSource(ctx), nil @@ -222,10 +227,10 @@ func (p ClientCredentialsTokenSourceProvider) GetTokenSource(ctx context.Context type customTokenSource struct { ctx context.Context new oauth2.TokenSource + tokenRefreshWindow time.Duration mu sync.Mutex // guards everything else refreshTime time.Time failedToRefresh bool - tokenRefreshWindow time.Duration tokenCache cache.TokenCache } @@ -245,7 +250,6 @@ func (s *customTokenSource) fetchTokenFromCache() (*oauth2.Token, bool) { logger.Infof(s.ctx, "cached token refresh window exceeded") return token, true } - logger.Infof(s.ctx, "using cached token") return token, false } diff --git a/clients/go/admin/token_source_provider_test.go b/clients/go/admin/token_source_provider_test.go new file mode 100644 index 000000000..b1858b46c --- /dev/null +++ b/clients/go/admin/token_source_provider_test.go @@ -0,0 +1,311 @@ +package admin + +import ( + "context" + "fmt" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/oauth2" + + tokenCacheMocks "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks" + adminMocks "github.com/flyteorg/flyteidl/clients/go/admin/mocks" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flytestdlib/config" +) + +func TestNewTokenSourceProvider(t *testing.T) { + ctx := context.Background() + tests := []struct { + name string + audienceCfg string + scopesCfg []string + useAudienceFromAdmin bool + clientConfigResponse service.PublicClientAuthConfigResponse + expectedAudience string + expectedScopes []string + expectedCallsPubEndpoint int + }{ + { + name: "audience from client config", + audienceCfg: "clientConfiguredAud", + scopesCfg: []string{"all"}, + clientConfigResponse: service.PublicClientAuthConfigResponse{}, + expectedAudience: "clientConfiguredAud", + expectedScopes: []string{"all"}, + expectedCallsPubEndpoint: 0, + }, + { + name: "audience from public client response", + audienceCfg: "clientConfiguredAud", + useAudienceFromAdmin: true, + scopesCfg: []string{"all"}, + clientConfigResponse: service.PublicClientAuthConfigResponse{Audience: "AdminConfiguredAud", Scopes: []string{}}, + expectedAudience: "AdminConfiguredAud", + expectedScopes: []string{"all"}, + expectedCallsPubEndpoint: 1, + }, + + { + name: "audience from client with useAudience from admin false", + audienceCfg: "clientConfiguredAud", + useAudienceFromAdmin: false, + scopesCfg: []string{"all"}, + clientConfigResponse: service.PublicClientAuthConfigResponse{Audience: "AdminConfiguredAud", Scopes: []string{}}, + expectedAudience: "clientConfiguredAud", + expectedScopes: []string{"all"}, + expectedCallsPubEndpoint: 0, + }, + } + for _, test := range tests { + cfg := GetConfig(ctx) + tokenCache := &tokenCacheMocks.TokenCache{} + metadataClient := &adminMocks.AuthMetadataServiceClient{} + metadataClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{}, nil) + metadataClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&test.clientConfigResponse, nil) + cfg.AuthType = AuthTypeClientSecret + cfg.Audience = test.audienceCfg + cfg.Scopes = test.scopesCfg + cfg.UseAudienceFromAdmin = test.useAudienceFromAdmin + flyteTokenSource, err := NewTokenSourceProvider(ctx, cfg, tokenCache, metadataClient) + assert.True(t, metadataClient.AssertNumberOfCalls(t, "GetPublicClientConfig", test.expectedCallsPubEndpoint)) + assert.NoError(t, err) + assert.NotNil(t, flyteTokenSource) + clientCredSourceProvider, ok := flyteTokenSource.(ClientCredentialsTokenSourceProvider) + assert.True(t, ok) + assert.Equal(t, test.expectedScopes, clientCredSourceProvider.ccConfig.Scopes) + assert.Equal(t, url.Values{audienceKey: {test.expectedAudience}}, clientCredSourceProvider.ccConfig.EndpointParams) + } +} + +func TestCustomTokenSource_GetTokenSource(t *testing.T) { + ctx := context.Background() + cfg := GetConfig(ctx) + cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute} + cfg.ClientSecretLocation = "" + + hourAhead := time.Now().Add(time.Hour) + validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead} + + tests := []struct { + name string + token *oauth2.Token + }{ + { + name: "no token", + token: nil, + }, + { + + name: "valid token", + token: &validToken, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tokenCache := &tokenCacheMocks.TokenCache{} + var tokenErr error = nil + if test.token == nil { + tokenErr = fmt.Errorf("no token") + } + tokenCache.OnGetToken().Return(test.token, tokenErr).Once() + provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") + assert.NoError(t, err) + + source, err := provider.GetTokenSource(ctx) + assert.NoError(t, err) + customSource, ok := source.(*customTokenSource) + assert.True(t, ok) + + if test.token == nil { + assert.Equal(t, time.Time{}, customSource.refreshTime) + } else { + assert.LessOrEqual(t, customSource.refreshTime.Unix(), test.token.Expiry.Unix()) + assert.GreaterOrEqual(t, customSource.refreshTime.Unix(), test.token.Expiry.Add(-cfg.TokenRefreshWindow.Duration).Unix()) + } + }) + } +} + +func TestCustomTokenSource_fetchTokenFromCache(t *testing.T) { + ctx := context.Background() + cfg := GetConfig(ctx) + cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute} + cfg.ClientSecretLocation = "" + + minuteAgo := time.Now().Add(-time.Minute) + hourAhead := time.Now().Add(time.Hour) + invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo} + validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead} + + tests := []struct { + name string + refreshTime time.Time + failedToRefresh bool + token *oauth2.Token + expectToken bool + expectNeedsRefresh bool + }{ + { + name: "no token", + refreshTime: hourAhead, + failedToRefresh: false, + token: nil, + expectToken: false, + expectNeedsRefresh: false, + }, + { + name: "invalid token", + refreshTime: hourAhead, + failedToRefresh: false, + token: &invalidToken, + expectToken: false, + expectNeedsRefresh: false, + }, + { + name: "refresh exceeded", + refreshTime: minuteAgo, + failedToRefresh: false, + token: &validToken, + expectToken: true, + expectNeedsRefresh: true, + }, + { + name: "refresh exceeded failed", + refreshTime: minuteAgo, + failedToRefresh: true, + token: &validToken, + expectToken: true, + expectNeedsRefresh: false, + }, + { + name: "valid token", + refreshTime: hourAhead, + failedToRefresh: false, + token: &validToken, + expectToken: true, + expectNeedsRefresh: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tokenCache := &tokenCacheMocks.TokenCache{} + var tokenErr error = nil + if test.token == nil { + tokenErr = fmt.Errorf("no token") + } + tokenCache.OnGetToken().Return(test.token, tokenErr).Twice() + provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") + assert.NoError(t, err) + source, err := provider.GetTokenSource(ctx) + assert.NoError(t, err) + customSource, ok := source.(*customTokenSource) + assert.True(t, ok) + + customSource.refreshTime = test.refreshTime + customSource.failedToRefresh = test.failedToRefresh + token, needsRefresh := customSource.fetchTokenFromCache() + if test.expectToken { + assert.NotNil(t, token) + } else { + assert.Nil(t, token) + } + assert.Equal(t, test.expectNeedsRefresh, needsRefresh) + }) + } +} + +func TestCustomTokenSource_Token(t *testing.T) { + ctx := context.Background() + cfg := GetConfig(ctx) + cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute} + cfg.ClientSecretLocation = "" + + minuteAgo := time.Now().Add(-time.Minute) + hourAhead := time.Now().Add(time.Hour) + twoHourAhead := time.Now().Add(2 * time.Hour) + invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo} + validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead} + newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead} + + tests := []struct { + name string + refreshTime time.Time + failedToRefresh bool + token *oauth2.Token + newToken *oauth2.Token + expectedToken *oauth2.Token + }{ + { + name: "cached token", + refreshTime: hourAhead, + failedToRefresh: false, + token: &validToken, + newToken: nil, + expectedToken: &validToken, + }, + { + name: "failed refresh still valid", + refreshTime: minuteAgo, + failedToRefresh: false, + token: &validToken, + newToken: nil, + expectedToken: &validToken, + }, + { + name: "failed refresh invalid", + refreshTime: minuteAgo, + failedToRefresh: false, + token: &invalidToken, + newToken: nil, + expectedToken: nil, + }, + { + name: "refresh", + refreshTime: minuteAgo, + failedToRefresh: false, + token: &invalidToken, + newToken: &newToken, + expectedToken: &newToken, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tokenCache := &tokenCacheMocks.TokenCache{} + tokenCache.OnGetToken().Return(test.token, nil).Twice() + provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") + assert.NoError(t, err) + source, err := provider.GetTokenSource(ctx) + assert.NoError(t, err) + customSource, ok := source.(*customTokenSource) + assert.True(t, ok) + + mockSource := &adminMocks.TokenSource{} + if test.newToken != nil { + mockSource.OnToken().Return(test.newToken, nil) + } else { + mockSource.OnToken().Return(nil, fmt.Errorf("refresh token failed")) + } + customSource.new = mockSource + customSource.refreshTime = test.refreshTime + customSource.failedToRefresh = test.failedToRefresh + if test.newToken != nil { + tokenCache.OnSaveToken(test.newToken).Return(nil).Once() + } + token, err := source.Token() + if test.expectedToken != nil { + assert.Equal(t, test.expectedToken, token) + assert.NoError(t, err) + } else { + assert.Nil(t, token) + assert.Error(t, err) + } + }) + } +} diff --git a/clients/go/admin/token_source_test.go b/clients/go/admin/token_source_test.go index 745237759..0e247bfe8 100644 --- a/clients/go/admin/token_source_test.go +++ b/clients/go/admin/token_source_test.go @@ -2,19 +2,10 @@ package admin import ( "context" - "fmt" - "net/url" "testing" - "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "golang.org/x/oauth2" - - tokenCacheMocks "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks" - adminMocks "github.com/flyteorg/flyteidl/clients/go/admin/mocks" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" - "github.com/flyteorg/flytestdlib/config" ) type DummyTestTokenSource struct { @@ -34,246 +25,3 @@ func TestNewTokenSource(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "Bearer abc", metadata["test"]) } - -func TestNewTokenSourceProvider(t *testing.T) { - ctx := context.Background() - tests := []struct { - name string - audienceCfg string - scopesCfg []string - useAudienceFromAdmin bool - clientConfigResponse service.PublicClientAuthConfigResponse - expectedAudience string - expectedScopes []string - expectedCallsPubEndpoint int - }{ - { - name: "audience from client config", - audienceCfg: "clientConfiguredAud", - scopesCfg: []string{"all"}, - clientConfigResponse: service.PublicClientAuthConfigResponse{}, - expectedAudience: "clientConfiguredAud", - expectedScopes: []string{"all"}, - expectedCallsPubEndpoint: 0, - }, - { - name: "audience from public client response", - audienceCfg: "clientConfiguredAud", - useAudienceFromAdmin: true, - scopesCfg: []string{"all"}, - clientConfigResponse: service.PublicClientAuthConfigResponse{Audience: "AdminConfiguredAud", Scopes: []string{}}, - expectedAudience: "AdminConfiguredAud", - expectedScopes: []string{"all"}, - expectedCallsPubEndpoint: 1, - }, - - { - name: "audience from client with useAudience from admin false", - audienceCfg: "clientConfiguredAud", - useAudienceFromAdmin: false, - scopesCfg: []string{"all"}, - clientConfigResponse: service.PublicClientAuthConfigResponse{Audience: "AdminConfiguredAud", Scopes: []string{}}, - expectedAudience: "clientConfiguredAud", - expectedScopes: []string{"all"}, - expectedCallsPubEndpoint: 0, - }, - } - for _, test := range tests { - cfg := GetConfig(ctx) - tokenCache := &tokenCacheMocks.TokenCache{} - metadataClient := &adminMocks.AuthMetadataServiceClient{} - metadataClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{}, nil) - metadataClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&test.clientConfigResponse, nil) - cfg.AuthType = AuthTypeClientSecret - cfg.Audience = test.audienceCfg - cfg.Scopes = test.scopesCfg - cfg.UseAudienceFromAdmin = test.useAudienceFromAdmin - flyteTokenSource, err := NewTokenSourceProvider(ctx, cfg, tokenCache, metadataClient) - assert.True(t, metadataClient.AssertNumberOfCalls(t, "GetPublicClientConfig", test.expectedCallsPubEndpoint)) - assert.NoError(t, err) - assert.NotNil(t, flyteTokenSource) - clientCredSourceProvider, ok := flyteTokenSource.(ClientCredentialsTokenSourceProvider) - assert.True(t, ok) - assert.Equal(t, test.expectedScopes, clientCredSourceProvider.ccConfig.Scopes) - assert.Equal(t, url.Values{audienceKey: {test.expectedAudience}}, clientCredSourceProvider.ccConfig.EndpointParams) - } -} - -func TestCustomTokenSource_fetchTokenFromCache(t *testing.T) { - ctx := context.Background() - cfg := GetConfig(ctx) - cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute} - cfg.ClientSecretLocation = "" - - minuteAgo := time.Now().Add(-time.Minute) - hourAhead := time.Now().Add(time.Hour) - invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo} - validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead} - - tests := []struct { - name string - refreshTime time.Time - failedToRefresh bool - token *oauth2.Token - expectToken bool - expectNeedsRefresh bool - }{ - { - name: "no token", - refreshTime: hourAhead, - failedToRefresh: false, - token: nil, - expectToken: false, - expectNeedsRefresh: false, - }, - { - name: "invalid token", - refreshTime: hourAhead, - failedToRefresh: false, - token: &invalidToken, - expectToken: false, - expectNeedsRefresh: false, - }, - { - name: "refresh exceeded", - refreshTime: minuteAgo, - failedToRefresh: false, - token: &validToken, - expectToken: false, - expectNeedsRefresh: false, - }, - { - name: "refresh exceeded failed", - refreshTime: minuteAgo, - failedToRefresh: true, - token: &validToken, - expectToken: false, - expectNeedsRefresh: false, - }, - { - name: "valid token", - refreshTime: hourAhead, - failedToRefresh: false, - token: &validToken, - expectToken: false, - expectNeedsRefresh: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - tokenCache := &tokenCacheMocks.TokenCache{} - provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") - assert.NoError(t, err) - source, err := provider.GetTokenSource(ctx) - assert.NoError(t, err) - customSource, ok := source.(*customTokenSource) - assert.True(t, ok) - - customSource.refreshTime = test.refreshTime - customSource.failedToRefresh = test.failedToRefresh - var tokenErr error = nil - if test.token == nil { - tokenErr = fmt.Errorf("no token") - } - tokenCache.OnGetToken().Return(nil, tokenErr).Once() - token, needsRefresh := customSource.fetchTokenFromCache() - if test.expectToken { - assert.NotNil(t, token) - } else { - assert.Nil(t, token) - } - assert.Equal(t, test.expectNeedsRefresh, needsRefresh) - }) - } -} - -func TestCustomTokenSource_Token(t *testing.T) { - ctx := context.Background() - cfg := GetConfig(ctx) - cfg.TokenRefreshWindow = config.Duration{Duration: time.Minute} - cfg.ClientSecretLocation = "" - - minuteAgo := time.Now().Add(-time.Minute) - hourAhead := time.Now().Add(time.Hour) - twoHourAhead := time.Now().Add(2 * time.Hour) - invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo} - validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead} - newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead} - - tests := []struct { - name string - refreshTime time.Time - failedToRefresh bool - token *oauth2.Token - newToken *oauth2.Token - expectedToken *oauth2.Token - }{ - { - name: "cached token", - refreshTime: hourAhead, - failedToRefresh: false, - token: &validToken, - newToken: nil, - expectedToken: &validToken, - }, - { - name: "failed refresh still valid", - refreshTime: minuteAgo, - failedToRefresh: false, - token: &validToken, - newToken: nil, - expectedToken: &validToken, - }, - { - name: "failed refresh invalid", - refreshTime: minuteAgo, - failedToRefresh: false, - token: &invalidToken, - newToken: nil, - expectedToken: nil, - }, - { - name: "refresh", - refreshTime: minuteAgo, - failedToRefresh: false, - token: &invalidToken, - newToken: &newToken, - expectedToken: &newToken, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - tokenCache := &tokenCacheMocks.TokenCache{} - provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "") - assert.NoError(t, err) - source, err := provider.GetTokenSource(ctx) - assert.NoError(t, err) - customSource, ok := source.(*customTokenSource) - assert.True(t, ok) - - mockSource := &adminMocks.TokenSource{} - if test.newToken != nil { - mockSource.OnToken().Return(test.newToken, nil) - } else { - mockSource.OnToken().Return(nil, fmt.Errorf("refresh token failed")) - } - customSource.new = mockSource - customSource.refreshTime = test.refreshTime - customSource.failedToRefresh = test.failedToRefresh - tokenCache.OnGetToken().Return(test.token, nil).Once() - if test.newToken != nil { - tokenCache.OnSaveToken(test.newToken).Return(nil).Once() - } - token, err := source.Token() - if test.expectedToken != nil { - assert.Equal(t, test.expectedToken, token) - assert.NoError(t, err) - } else { - assert.Nil(t, token) - assert.Error(t, err) - } - }) - } -}