Skip to content

Commit

Permalink
Auth/prevent lookup per call (flyteorg#5686)
Browse files Browse the repository at this point in the history
* save values

Signed-off-by: Yee Hing Tong <[email protected]>

* move things up

Signed-off-by: Yee Hing Tong <[email protected]>

* tests

Signed-off-by: Yee Hing Tong <[email protected]>

* unit test

Signed-off-by: Yee Hing Tong <[email protected]>

* imports for client test

Signed-off-by: Yee Hing Tong <[email protected]>

* more test

Signed-off-by: Yee Hing Tong <[email protected]>

* don't test admin connection

Signed-off-by: Yee Hing Tong <[email protected]>

* disable client for config

Signed-off-by: Yee Hing Tong <[email protected]>

* make generate

Signed-off-by: Yee Hing Tong <[email protected]>

* hide behind a once

Signed-off-by: Yee Hing Tong <[email protected]>

* typo

Signed-off-by: Yee Hing Tong <[email protected]>

* reset client builder test

Signed-off-by: Yee Hing Tong <[email protected]>

* reset client test

Signed-off-by: Yee Hing Tong <[email protected]>

* revert propeller

Signed-off-by: Yee Hing Tong <[email protected]>

* delay invocation even further

Signed-off-by: Yee Hing Tong <[email protected]>

---------

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Bugra Gedik <[email protected]>
  • Loading branch information
wild-endeavor authored and bgedik committed Sep 12, 2024
1 parent f02a8c7 commit 86859d1
Show file tree
Hide file tree
Showing 13 changed files with 238 additions and 105 deletions.
10 changes: 7 additions & 3 deletions flytectl/cmd/configuration/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,13 @@ func CreateConfigCommand() *cobra.Command {
configCmd := viper.GetConfigCommand()

getResourcesFuncs := map[string]cmdcore.CommandEntry{
"init": {CmdFunc: configInitFunc, Aliases: []string{""}, ProjectDomainNotRequired: true,
Short: initCmdShort,
Long: initCmdLong, PFlagProvider: initConfig.DefaultConfig},
"init": {
CmdFunc: configInitFunc,
Aliases: []string{""},
ProjectDomainNotRequired: true,
DisableFlyteClient: true,
Short: initCmdShort,
Long: initCmdLong, PFlagProvider: initConfig.DefaultConfig},
}

configCmd.Flags().BoolVar(&initConfig.DefaultConfig.Force, "force", false, "Force to overwrite the default config file without confirmation")
Expand Down
4 changes: 2 additions & 2 deletions flytectl/cmd/core/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestGenerateCommandFunc(t *testing.T) {
adminCfg.Endpoint = config.URL{URL: url.URL{Host: "dummyHost"}}
adminCfg.AuthType = admin.AuthTypePkce
rootCmd := &cobra.Command{}
cmdEntry := CommandEntry{CmdFunc: testCommandFunc, ProjectDomainNotRequired: true}
cmdEntry := CommandEntry{CmdFunc: testCommandFunc, ProjectDomainNotRequired: true, DisableFlyteClient: true}
fn := generateCommandFunc(cmdEntry)
assert.Nil(t, fn(rootCmd, []string{}))
})
Expand All @@ -30,7 +30,7 @@ func TestGenerateCommandFunc(t *testing.T) {
adminCfg := admin.GetConfig(context.Background())
adminCfg.Endpoint = config.URL{URL: url.URL{Host: ""}}
rootCmd := &cobra.Command{}
cmdEntry := CommandEntry{CmdFunc: testCommandFunc, ProjectDomainNotRequired: true}
cmdEntry := CommandEntry{CmdFunc: testCommandFunc, ProjectDomainNotRequired: true, DisableFlyteClient: true}
fn := generateCommandFunc(cmdEntry)
assert.Nil(t, fn(rootCmd, []string{}))
})
Expand Down
108 changes: 80 additions & 28 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"sync"

"golang.org/x/oauth2"
"google.golang.org/grpc"
Expand All @@ -20,33 +21,10 @@ const ProxyAuthorizationHeader = "proxy-authorization"

// MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server.
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache,
perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture)
if err != nil {
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}

tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, authMetadataClient)
if err != nil {
return fmt.Errorf("failed to initialized token source provider. Err: %w", err)
}

authorizationMetadataKey := cfg.AuthorizationHeader
if len(authorizationMetadataKey) == 0 {
clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey
}

tokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
if err != nil {
return fmt.Errorf("failed to get token source. Error: %w", err)
}
func MaterializeCredentials(tokenSource oauth2.TokenSource, cfg *Config, authorizationMetadataKey string,
perRPCCredentials *PerRPCCredentialsFuture) error {

_, err = tokenSource.Token()
_, err := tokenSource.Token()
if err != nil {
return fmt.Errorf("failed to issue token. Error: %w", err)
}
Expand Down Expand Up @@ -127,6 +105,60 @@ func setHTTPClientContext(ctx context.Context, cfg *Config, proxyCredentialsFutu
return context.WithValue(ctx, oauth2.HTTPClient, httpClient)
}

type OauthMetadataProvider struct {
authorizationMetadataKey string
tokenSource oauth2.TokenSource
once sync.Once
}

func (o *OauthMetadataProvider) getTokenSourceAndMetadata(cfg *Config, tokenCache cache.TokenCache, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
ctx := context.Background()

authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture)
if err != nil {
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}

tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, authMetadataClient)
if err != nil {
return fmt.Errorf("failed to initialize token source provider. Err: %w", err)
}

authorizationMetadataKey := cfg.AuthorizationHeader
if len(authorizationMetadataKey) == 0 {
clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{})
if err != nil {
return fmt.Errorf("failed to fetch client metadata. Error: %v", err)
}
authorizationMetadataKey = clientMetadata.AuthorizationMetadataKey
}

tokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
if err != nil {
return fmt.Errorf("failed to get token source. Error: %w", err)
}

o.authorizationMetadataKey = authorizationMetadataKey
o.tokenSource = tokenSource

return nil
}

func (o *OauthMetadataProvider) GetOauthMetadata(cfg *Config, tokenCache cache.TokenCache, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
// Ensure loadTokenRelated() is only executed once
var err error
o.once.Do(func() {
err = o.getTokenSourceAndMetadata(cfg, tokenCache, proxyCredentialsFuture)
if err != nil {
logger.Errorf(context.Background(), "Failed to load token related config. Error: %v", err)
}
})
if err != nil {
return err
}
return nil
}

// NewAuthInterceptor creates a new grpc.UnaryClientInterceptor that forwards the grpc call and inspects the error.
// It will first invoke the grpc pipeline (to proceed with the request) with no modifications. It's expected for the grpc
// pipeline to already have a grpc.WithPerRPCCredentials() DialOption. If the perRPCCredentials has already been initialized,
Expand All @@ -138,13 +170,26 @@ func setHTTPClientContext(ctx context.Context, cfg *Config, proxyCredentialsFutu
// a token source has been created, it'll invoke the grpc pipeline again, this time the grpc.PerRPCCredentials should
// be able to find and acquire a valid AccessToken to annotate the request with.
func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {

oauthMetadataProvider := OauthMetadataProvider{
once: sync.Once{},
}

return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {

ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture)

// If there is already a token in the cache (e.g. key-ring), we should use it immediately...
t, _ := tokenCache.GetToken()
if t != nil {
err := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
err := oauthMetadataProvider.GetOauthMetadata(cfg, tokenCache, proxyCredentialsFuture)
if err != nil {
return err
}
authorizationMetadataKey := oauthMetadataProvider.authorizationMetadataKey
tokenSource := oauthMetadataProvider.tokenSource

err = MaterializeCredentials(tokenSource, cfg, authorizationMetadataKey, credentialsFuture)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
}
Expand All @@ -157,6 +202,13 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
if st, ok := status.FromError(err); ok {
// If the error we receive from executing the request expects
if shouldAttemptToAuthenticate(st.Code()) {
err := oauthMetadataProvider.GetOauthMetadata(cfg, tokenCache, proxyCredentialsFuture)
if err != nil {
return err
}
authorizationMetadataKey := oauthMetadataProvider.authorizationMetadataKey
tokenSource := oauthMetadataProvider.tokenSource

err = func() error {
if !tokenCache.TryLock() {
tokenCache.CondWait()
Expand All @@ -171,7 +223,7 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
}

logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code())
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
newErr := MaterializeCredentials(tokenSource, cfg, authorizationMetadataKey, credentialsFuture)
if newErr != nil {
errString := fmt.Sprintf("authentication error! Original Error: %v, Auth Error: %v", err, newErr)
logger.Errorf(ctx, errString)
Expand Down
88 changes: 78 additions & 10 deletions flyteidl/clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -141,11 +142,34 @@ func Test_newAuthInterceptor(t *testing.T) {
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
t.Run("Other Error", func(t *testing.T) {
ctx := context.Background()
httpPort := rand.IntnRange(10000, 60000)
grpcPort := rand.IntnRange(10000, 60000)
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{
AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", httpPort),
TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort),
JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", httpPort),
}, nil)

m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{
Scopes: []string{"all"},
}, nil)

s := newAuthMetadataServer(t, grpcPort, httpPort, m)
assert.NoError(t, s.Start(ctx))
defer s.Close()
u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort))
assert.NoError(t, err)
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
mockTokenCache := &mocks.TokenCache{}
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
interceptor := NewAuthInterceptor(&Config{}, mockTokenCache, f, p)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
}, mockTokenCache, f, p)
otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Canceled, "").Err()
}
Expand Down Expand Up @@ -209,6 +233,14 @@ func Test_newAuthInterceptor(t *testing.T) {
httpPort := rand.IntnRange(10000, 60000)
grpcPort := rand.IntnRange(10000, 60000)
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{
AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", httpPort),
TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort),
JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", httpPort),
}, nil)
m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{
Scopes: []string{"all"},
}, nil)
s := newAuthMetadataServer(t, grpcPort, httpPort, m)
ctx := context.Background()
assert.NoError(t, s.Start(ctx))
Expand Down Expand Up @@ -283,12 +315,13 @@ func Test_newAuthInterceptor(t *testing.T) {
})
}

func TestMaterializeCredentials(t *testing.T) {
func TestNewAuthInterceptorAndMaterialize(t *testing.T) {
t.Run("No oauth2 metadata endpoint or Public client config lookup", func(t *testing.T) {
httpPort := rand.IntnRange(10000, 60000)
grpcPort := rand.IntnRange(10000, 60000)
fakeToken := &oauth2.Token{}
c := &mocks.TokenCache{}
c.OnGetTokenMatch().Return(nil, nil)
c.OnGetTokenMatch().Return(fakeToken, nil)
c.OnSaveTokenMatch(mock.Anything).Return(nil)
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata"))
Expand All @@ -304,22 +337,30 @@ func TestMaterializeCredentials(t *testing.T) {
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

err = MaterializeCredentials(ctx, &Config{
cfg := &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
TokenURL: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort),
Scopes: []string{"all"},
Audience: "http://localhost:30081",
AuthorizationHeader: "authorization",
}, c, f, p)
}

intercept := NewAuthInterceptor(cfg, c, f, p)
// Invoke Materialize inside the intercept
err = intercept(ctx, "GET", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return nil
})
assert.NoError(t, err)
})

t.Run("Failed to fetch client metadata", func(t *testing.T) {
httpPort := rand.IntnRange(10000, 60000)
grpcPort := rand.IntnRange(10000, 60000)
c := &mocks.TokenCache{}
c.OnGetTokenMatch().Return(nil, nil)
fakeToken := &oauth2.Token{}
c.OnGetTokenMatch().Return(fakeToken, nil)
c.OnSaveTokenMatch(mock.Anything).Return(nil)
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata"))
Expand All @@ -333,17 +374,44 @@ func TestMaterializeCredentials(t *testing.T) {
u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort))
assert.NoError(t, err)

cfg := &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", httpPort),
Scopes: []string{"all"},
}
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
intercept := NewAuthInterceptor(cfg, c, f, p)
err = intercept(ctx, "GET", nil, nil, nil, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return nil
})
assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err")
})
}

func TestSimpleMaterializeCredentials(t *testing.T) {
t.Run("simple materialize", func(t *testing.T) {
httpPort := rand.IntnRange(10000, 60000)
grpcPort := rand.IntnRange(10000, 60000)
u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", grpcPort))
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()

err = MaterializeCredentials(ctx, &Config{
dummySource := DummyTestTokenSource{}

err = MaterializeCredentials(dummySource, &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", httpPort),
TokenURL: fmt.Sprintf("http://localhost:%d/oauth2/token", httpPort),
Scopes: []string{"all"},
}, c, f, p)
assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err")
Audience: "http://localhost:30081",
AuthorizationHeader: "authorization",
}, "authorization", f)
assert.NoError(t, err)
})
}

Expand Down
3 changes: 2 additions & 1 deletion flyteidl/clients/go/admin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,9 @@ func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenC
credentialsFuture := NewPerRPCCredentialsFuture()
proxyCredentialsFuture := NewPerRPCCredentialsFuture()

authInterceptor := NewAuthInterceptor(cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
opts = append(opts,
grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)),
grpc.WithChainUnaryInterceptor(authInterceptor),
grpc.WithPerRPCCredentials(credentialsFuture))

if cfg.DefaultServiceConfig != "" {
Expand Down
Loading

0 comments on commit 86859d1

Please sign in to comment.