From cf832ba6399ca533fad2a2abeaca0863ac731a7c Mon Sep 17 00:00:00 2001 From: pmahindrakar-oss Date: Thu, 12 Jan 2023 08:06:10 -0800 Subject: [PATCH] Added UseAudienceFromAdmin property to force pull audience from admin config. Default is false and expects clients to pass it Signed-off-by: pmahindrakar-oss --- clients/go/admin/config.go | 1 + clients/go/admin/config_flags.go | 1 + clients/go/admin/config_flags_test.go | 14 +++++++ clients/go/admin/token_source_provider.go | 6 +-- clients/go/admin/token_source_test.go | 47 ++++++++++++----------- 5 files changed, 42 insertions(+), 27 deletions(-) diff --git a/clients/go/admin/config.go b/clients/go/admin/config.go index d3bd31359..dd0652606 100644 --- a/clients/go/admin/config.go +++ b/clients/go/admin/config.go @@ -53,6 +53,7 @@ type Config struct { ClientSecretLocation string `json:"clientSecretLocation" pflag:",File containing the client secret"` ClientSecretEnvVar string `json:"clientSecretEnvVar" pflag:",Environment variable containing the client secret"` Scopes []string `json:"scopes" pflag:",List of scopes to request"` + UseAudienceFromAdmin bool `json:"useAudienceFromAdmin" pflag:",Use Audience configured from admins public endpoint config."` Audience string `json:"audience" pflag:",Audience to use when initiating OAuth2 authorization requests."` // There are two ways to get the token URL. If the authorization server url is provided, the client will try to use RFC 8414 to diff --git a/clients/go/admin/config_flags.go b/clients/go/admin/config_flags.go index dbe234d48..53a6a4421 100755 --- a/clients/go/admin/config_flags.go +++ b/clients/go/admin/config_flags.go @@ -64,6 +64,7 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "clientSecretLocation"), defaultConfig.ClientSecretLocation, "File containing the client secret") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "clientSecretEnvVar"), defaultConfig.ClientSecretEnvVar, "Environment variable containing the client secret") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "scopes"), defaultConfig.Scopes, "List of scopes to request") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "useAudienceFromAdmin"), defaultConfig.UseAudienceFromAdmin, "Use Audience configured from admins public endpoint config.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "audience"), defaultConfig.Audience, "Audience to use when initiating OAuth2 authorization requests.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationServerUrl"), defaultConfig.DeprecatedAuthorizationServerURL, "This is the URL to your IdP's authorization server. It'll default to Endpoint") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "tokenUrl"), defaultConfig.TokenURL, "OPTIONAL: Your IdP's token endpoint. It'll be discovered from flyte admin's OAuth Metadata endpoint if not provided.") diff --git a/clients/go/admin/config_flags_test.go b/clients/go/admin/config_flags_test.go index 0c37a3aeb..bdcec55f6 100755 --- a/clients/go/admin/config_flags_test.go +++ b/clients/go/admin/config_flags_test.go @@ -295,6 +295,20 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_useAudienceFromAdmin", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("useAudienceFromAdmin", testValue) + if vBool, err := cmdFlags.GetBool("useAudienceFromAdmin"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.UseAudienceFromAdmin) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_audience", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/clients/go/admin/token_source_provider.go b/clients/go/admin/token_source_provider.go index 4eac8e01a..19ed10be9 100644 --- a/clients/go/admin/token_source_provider.go +++ b/clients/go/admin/token_source_provider.go @@ -53,7 +53,7 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T scopes := cfg.Scopes audienceValue := cfg.Audience - if len(scopes) == 0 || len(audienceValue) == 0 { + if len(scopes) == 0 || cfg.UseAudienceFromAdmin { publicClientConfig, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}) if err != nil { return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err) @@ -63,9 +63,7 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T scopes = publicClientConfig.Scopes } // Update audience from publicClientConfig - if len(audienceValue) == 0 { - audienceValue = publicClientConfig.Audience - } + audienceValue = publicClientConfig.Audience } tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL, audienceValue) diff --git a/clients/go/admin/token_source_test.go b/clients/go/admin/token_source_test.go index 9318e8103..2e860da15 100644 --- a/clients/go/admin/token_source_test.go +++ b/clients/go/admin/token_source_test.go @@ -38,44 +38,45 @@ func TestNewTokenSourceProvider(t *testing.T) { name string audienceCfg string scopesCfg []string + useAudienceFromAdmin bool clientConfigResponse service.PublicClientAuthConfigResponse expectedAudience string expectedScopes []string }{ { name: "audience from client config", - audienceCfg: "aud", + audienceCfg: "clientConfiguredAud", scopesCfg: []string{"all"}, clientConfigResponse: service.PublicClientAuthConfigResponse{}, - expectedAudience: "aud", + expectedAudience: "clientConfiguredAud", expectedScopes: []string{"all"}, }, { name: "audience from public client response", - audienceCfg: "", - scopesCfg: []string{}, - clientConfigResponse: service.PublicClientAuthConfigResponse{Audience: "aud", Scopes: []string{"all"}}, - expectedAudience: "aud", + audienceCfg: "clientConfiguredAud", + useAudienceFromAdmin: true, + scopesCfg: []string{"all"}, + clientConfigResponse: service.PublicClientAuthConfigResponse{Audience: "AdminConfiguredAud", Scopes: []string{}}, + expectedAudience: "AdminConfiguredAud", expectedScopes: []string{"all"}, }, } for _, test := range tests { - t.Run("audience from client config", func(t *testing.T) { - cfg := GetConfig(ctx) - tokenCache := &tokenCacheMocks.TokenCache{} - metadataClient := &adminMocks.AuthMetadataServiceClient{} - metadataClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{}, nil) - metadataClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&test.clientConfigResponse, nil) - cfg.AuthType = AuthTypeClientSecret - cfg.Audience = test.audienceCfg - cfg.Scopes = test.scopesCfg - flyteTokenSource, err := NewTokenSourceProvider(ctx, cfg, tokenCache, metadataClient) - assert.NoError(t, err) - assert.NotNil(t, flyteTokenSource) - clientCredSourceProvider, ok := flyteTokenSource.(ClientCredentialsTokenSourceProvider) - assert.True(t, ok) - assert.Equal(t, test.expectedScopes, clientCredSourceProvider.ccConfig.Scopes) - assert.Equal(t, url.Values{audienceKey: {test.expectedAudience}}, clientCredSourceProvider.ccConfig.EndpointParams) - }) + cfg := GetConfig(ctx) + tokenCache := &tokenCacheMocks.TokenCache{} + metadataClient := &adminMocks.AuthMetadataServiceClient{} + metadataClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{}, nil) + metadataClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&test.clientConfigResponse, nil) + cfg.AuthType = AuthTypeClientSecret + cfg.Audience = test.audienceCfg + cfg.Scopes = test.scopesCfg + cfg.UseAudienceFromAdmin = test.useAudienceFromAdmin + flyteTokenSource, err := NewTokenSourceProvider(ctx, cfg, tokenCache, metadataClient) + assert.NoError(t, err) + assert.NotNil(t, flyteTokenSource) + clientCredSourceProvider, ok := flyteTokenSource.(ClientCredentialsTokenSourceProvider) + assert.True(t, ok) + assert.Equal(t, test.expectedScopes, clientCredSourceProvider.ccConfig.Scopes) + assert.Equal(t, url.Values{audienceKey: {test.expectedAudience}}, clientCredSourceProvider.ccConfig.EndpointParams) } }