From 9991cd7db2abc92f84af27043ac4dc3cecfdd95a Mon Sep 17 00:00:00 2001 From: Haytham Abuelfutuh Date: Fri, 9 Sep 2022 12:44:47 -0700 Subject: [PATCH] Use grpc client interceptors to properly check for auth requirement (#315) * Use grpc client interceptors to properly check for auth requirement Signed-off-by: Haytham Abuelfutuh * Some refactor and add unit tests Signed-off-by: Haytham Abuelfutuh * PR Comments Signed-off-by: Haytham Abuelfutuh * lint Signed-off-by: Haytham Abuelfutuh * unit tests Signed-off-by: Haytham Abuelfutuh * Attempt a random port Signed-off-by: Haytham Abuelfutuh * Listen to localhost only Signed-off-by: Haytham Abuelfutuh * PR Comments Signed-off-by: Haytham Abuelfutuh * use chain unary interceptor instead Signed-off-by: Haytham Abuelfutuh * only log on errors Signed-off-by: Haytham Abuelfutuh * Attempt to disable error check Signed-off-by: Haytham Abuelfutuh Signed-off-by: Haytham Abuelfutuh --- clients/go/admin/atomic_credentials.go | 89 ++++++++ clients/go/admin/atomic_credentials_test.go | 57 +++++ clients/go/admin/auth_interceptor.go | 81 +++++++ clients/go/admin/auth_interceptor_test.go | 235 ++++++++++++++++++++ clients/go/admin/client.go | 50 +---- clients/go/admin/client_test.go | 2 +- clients/go/admin/token_source.go | 25 ++- 7 files changed, 487 insertions(+), 52 deletions(-) create mode 100644 clients/go/admin/atomic_credentials.go create mode 100644 clients/go/admin/atomic_credentials_test.go create mode 100644 clients/go/admin/auth_interceptor.go create mode 100644 clients/go/admin/auth_interceptor_test.go diff --git a/clients/go/admin/atomic_credentials.go b/clients/go/admin/atomic_credentials.go new file mode 100644 index 0000000000..ca2a4f1776 --- /dev/null +++ b/clients/go/admin/atomic_credentials.go @@ -0,0 +1,89 @@ +package admin + +import ( + "context" + "sync/atomic" + + "google.golang.org/grpc/credentials" + + stdlibAtomic "github.com/flyteorg/flytestdlib/atomic" +) + +// atomicPerRPCCredentials provides a convenience on top of atomic.Value and credentials.PerRPCCredentials to be thread-safe. +type atomicPerRPCCredentials struct { + atomic.Value +} + +func (t *atomicPerRPCCredentials) Store(properties credentials.PerRPCCredentials) { + t.Value.Store(properties) +} + +func (t *atomicPerRPCCredentials) Load() credentials.PerRPCCredentials { + val := t.Value.Load() + if val == nil { + return CustomHeaderTokenSource{} + } + + return val.(credentials.PerRPCCredentials) +} + +func newAtomicPerPRCCredentials() *atomicPerRPCCredentials { + return &atomicPerRPCCredentials{ + Value: atomic.Value{}, + } +} + +// PerRPCCredentialsFuture is a future wrapper for credentials.PerRPCCredentials that can act as one and also be +// materialized later. +type PerRPCCredentialsFuture struct { + perRPCCredentials *atomicPerRPCCredentials + initialized stdlibAtomic.Bool +} + +// GetRequestMetadata gets the authorization metadata as a map using a TokenSource to generate a token +func (ts *PerRPCCredentialsFuture) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + if ts.initialized.Load() { + tp := ts.perRPCCredentials.Load() + return tp.GetRequestMetadata(ctx, uri...) + } + + return map[string]string{}, nil +} + +// RequireTransportSecurity returns whether this credentials class requires TLS/SSL. OAuth uses Bearer tokens that are +// susceptible to MITM (Man-In-The-Middle) attacks that are mitigated by TLS/SSL. We may return false here to make it +// easier to setup auth. However, in a production environment, TLS for OAuth2 is a requirement. +// see also: https://tools.ietf.org/html/rfc6749#section-3.1 +func (ts *PerRPCCredentialsFuture) RequireTransportSecurity() bool { + if ts.initialized.Load() { + return ts.perRPCCredentials.Load().RequireTransportSecurity() + } + + return false +} + +func (ts *PerRPCCredentialsFuture) Store(tokenSource credentials.PerRPCCredentials) { + ts.perRPCCredentials.Store(tokenSource) + ts.initialized.Store(true) +} + +func (ts *PerRPCCredentialsFuture) Get() credentials.PerRPCCredentials { + return ts.perRPCCredentials.Load() +} + +func (ts *PerRPCCredentialsFuture) IsInitialized() bool { + return ts.initialized.Load() +} + +// NewPerRPCCredentialsFuture initializes a new PerRPCCredentialsFuture that can act as a credentials.PerRPCCredentials +// and can also be resolved in the future. Users of the future can check if it has been initialized before by calling +// PerRPCCredentialsFuture.IsInitialized(). Calling PerRPCCredentialsFuture.Get() multiple times will return +// the same stored object (unless it changed in between calls). Calling PerRPCCredentialsFuture.Store() multiple +// times is supported and will result in overriding the old value atomically. +func NewPerRPCCredentialsFuture() *PerRPCCredentialsFuture { + tokenSource := PerRPCCredentialsFuture{ + perRPCCredentials: newAtomicPerPRCCredentials(), + } + + return &tokenSource +} diff --git a/clients/go/admin/atomic_credentials_test.go b/clients/go/admin/atomic_credentials_test.go new file mode 100644 index 0000000000..426835d163 --- /dev/null +++ b/clients/go/admin/atomic_credentials_test.go @@ -0,0 +1,57 @@ +package admin + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAtomicPerRPCCredentials(t *testing.T) { + a := atomicPerRPCCredentials{} + assert.True(t, a.Load().RequireTransportSecurity()) + + tokenSource := DummyTestTokenSource{} + chTokenSource := NewCustomHeaderTokenSource(tokenSource, true, "my_custom_header") + a.Store(chTokenSource) + + assert.False(t, a.Load().RequireTransportSecurity()) +} + +func TestNewPerRPCCredentialsFuture(t *testing.T) { + f := NewPerRPCCredentialsFuture() + assert.False(t, f.RequireTransportSecurity()) + assert.Equal(t, CustomHeaderTokenSource{}, f.Get()) + + tokenSource := DummyTestTokenSource{} + chTokenSource := NewCustomHeaderTokenSource(tokenSource, false, "my_custom_header") + f.Store(chTokenSource) + + assert.True(t, f.Get().RequireTransportSecurity()) + assert.True(t, f.RequireTransportSecurity()) +} + +func ExampleNewPerRPCCredentialsFuture() { + f := NewPerRPCCredentialsFuture() + + // Starts uninitialized + fmt.Println("Initialized:", f.IsInitialized()) + + // Implements credentials.PerRPCCredentials so can be used as one + m, err := f.GetRequestMetadata(context.TODO(), "") + fmt.Println("GetRequestMetadata:", m, "Error:", err) + + // Materialize the value later and populate + tokenSource := DummyTestTokenSource{} + f.Store(NewCustomHeaderTokenSource(tokenSource, false, "my_custom_header")) + + // Future calls to credentials.PerRPCCredentials methods will use the new instance + m, err = f.GetRequestMetadata(context.TODO(), "") + fmt.Println("GetRequestMetadata:", m, "Error:", err) + + // Output: + // Initialized: false + // GetRequestMetadata: map[] Error: + // GetRequestMetadata: map[my_custom_header:Bearer abc] Error: +} diff --git a/clients/go/admin/auth_interceptor.go b/clients/go/admin/auth_interceptor.go new file mode 100644 index 0000000000..69c522850b --- /dev/null +++ b/clients/go/admin/auth_interceptor.go @@ -0,0 +1,81 @@ +package admin + +import ( + "context" + "fmt" + + "github.com/flyteorg/flyteidl/clients/go/admin/cache" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flytestdlib/logger" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "google.golang.org/grpc" +) + +// MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server. +// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values. +func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture) error { + authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg) + if err != nil { + return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err) + } + + tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, authMetadataClient) + if err != nil { + return fmt.Errorf("failed to initialized token source provider. Err: %w", err) + } + + clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}) + if err != nil { + return fmt.Errorf("failed to fetch client metadata. Error: %v", err) + } + + tokenSource, err := tokenSourceProvider.GetTokenSource(ctx) + if err != nil { + return err + } + + wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, clientMetadata.AuthorizationMetadataKey) + perRPCCredentials.Store(wrappedTokenSource) + return nil +} + +func shouldAttemptToAuthenticate(errorCode codes.Code) bool { + return errorCode == codes.Unauthenticated +} + +// newAuthInterceptor creates a new grpc.UnaryClientInterceptor that forwards the grpc call and inspects the error. +// It will first invoke the grpc pipeline (to proceed with the request) with no modifications. It's expected for the grpc +// pipeline to already have a grpc.WithPerRPCCredentials() DialOption. If the perRPCCredentials has already been initialized, +// it'll take care of refreshing when tokens expire... etc. +// If the first invocation succeeds (either due to grpc.PerRPCCredentials setting the right tokens or the server not +// requiring authentication), the interceptor will be no-op. +// If the first invocation fails with an auth error, this interceptor will then attempt to establish a token source once +// more. It'll fail hard if it couldn't do so (i.e. it will no longer attempt to send an unauthenticated request). Once +// a token source has been created, it'll invoke the grpc pipeline again, this time the grpc.PerRPCCredentials should +// be able to find and acquire a valid AccessToken to annotate the request with. +func newAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + err := invoker(ctx, method, req, reply, cc, opts...) + if err != nil { + logger.Debugf(ctx, "Request failed due to [%v]. If it's an unauthenticated error, we will attempt to establish an authenticated context.", err) + + if st, ok := status.FromError(err); ok { + // If the error we receive from executing the request expects + if shouldAttemptToAuthenticate(st.Code()) { + logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code()) + newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture) + if newErr != nil { + return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr) + } + + return invoker(ctx, method, req, reply, cc, opts...) + } + } + } + + return err + } +} diff --git a/clients/go/admin/auth_interceptor_test.go b/clients/go/admin/auth_interceptor_test.go new file mode 100644 index 0000000000..b877793b5e --- /dev/null +++ b/clients/go/admin/auth_interceptor_test.go @@ -0,0 +1,235 @@ +package admin + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + + "github.com/flyteorg/flytestdlib/logger" + + "k8s.io/apimachinery/pkg/util/rand" + + mocks2 "github.com/flyteorg/flyteidl/clients/go/admin/mocks" + "github.com/stretchr/testify/mock" + + service2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flytestdlib/config" + + "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// authMetadataServer is a fake AuthMetadataServer that takes in an AuthMetadataServer implementation (usually one +// initialized through mockery) and starts a local server that uses it to respond to grpc requests. +type authMetadataServer struct { + s *httptest.Server + t testing.TB + port int + grpcServer *grpc.Server + netListener net.Listener + impl service2.AuthMetadataServiceServer + lck *sync.RWMutex +} + +func (s authMetadataServer) GetOAuth2Metadata(ctx context.Context, in *service2.OAuth2MetadataRequest) (*service2.OAuth2MetadataResponse, error) { + return s.impl.GetOAuth2Metadata(ctx, in) +} + +func (s authMetadataServer) GetPublicClientConfig(ctx context.Context, in *service2.PublicClientAuthConfigRequest) (*service2.PublicClientAuthConfigResponse, error) { + return s.impl.GetPublicClientConfig(ctx, in) +} + +func (s authMetadataServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var issuer string + switch r.URL.Path { + case "/.well-known/oauth-authorization-server": + w.Header().Set("Content-Type", "application/json") + _, err := io.WriteString(w, strings.ReplaceAll(`{ + "issuer": "https://dev-14186422.okta.com", + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "jwks_uri": "https://example.com/keys", + "id_token_signing_alg_values_supported": ["RS256"] + }`, "ISSUER", issuer)) + if !assert.NoError(s.t, err) { + s.t.FailNow() + } + + return + } + + http.NotFound(w, r) +} + +func (s *authMetadataServer) Start(_ context.Context) error { + s.lck.Lock() + defer s.lck.Unlock() + + /***** Set up the server serving channelz service. *****/ + lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", s.port)) + if err != nil { + return fmt.Errorf("failed to listen on port [%v]: %w", s.port, err) + } + + grpcS := grpc.NewServer() + service2.RegisterAuthMetadataServiceServer(grpcS, s) + go func() { + _ = grpcS.Serve(lis) + //assert.NoError(s.t, err) + }() + + s.grpcServer = grpcS + s.netListener = lis + + s.s = httptest.NewServer(s) + + return nil +} + +func (s *authMetadataServer) Close() { + s.lck.RLock() + defer s.lck.RUnlock() + + s.grpcServer.Stop() + s.s.Close() +} + +func newAuthMetadataServer(t testing.TB, port int, impl service2.AuthMetadataServiceServer) *authMetadataServer { + return &authMetadataServer{ + port: port, + t: t, + impl: impl, + lck: &sync.RWMutex{}, + } +} + +func Test_newAuthInterceptor(t *testing.T) { + t.Run("Other Error", func(t *testing.T) { + f := NewPerRPCCredentialsFuture() + interceptor := newAuthInterceptor(&Config{}, &mocks.TokenCache{}, f) + otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + return status.New(codes.Canceled, "").Err() + } + + assert.Error(t, interceptor(context.Background(), "POST", nil, nil, nil, otherError)) + }) + + t.Run("Unauthenticated first time, succeed the second time", func(t *testing.T) { + assert.NoError(t, logger.SetConfig(&logger.Config{ + Level: logger.DebugLevel, + })) + + port := rand.IntnRange(10000, 60000) + m := &mocks2.AuthMetadataServiceServer{} + m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service2.OAuth2MetadataResponse{ + AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port), + TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port), + JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port), + }, nil) + m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service2.PublicClientAuthConfigResponse{ + Scopes: []string{"all"}, + }, nil) + s := newAuthMetadataServer(t, port, m) + ctx := context.Background() + assert.NoError(t, s.Start(ctx)) + defer s.Close() + + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + assert.NoError(t, err) + + f := NewPerRPCCredentialsFuture() + interceptor := newAuthInterceptor(&Config{ + Endpoint: config.URL{URL: *u}, + UseInsecureConnection: true, + AuthType: AuthTypeClientSecret, + }, &mocks.TokenCache{}, f) + unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + return status.New(codes.Unauthenticated, "").Err() + } + + err = interceptor(ctx, "POST", nil, nil, nil, unauthenticated) + assert.Error(t, err) + assert.Truef(t, f.IsInitialized(), "PerRPCCredentialFuture should be initialized") + assert.False(t, f.Get().RequireTransportSecurity(), "Insecure should be true leading to RequireTLS false") + }) + + t.Run("Already authenticated", func(t *testing.T) { + assert.NoError(t, logger.SetConfig(&logger.Config{ + Level: logger.DebugLevel, + })) + + port := rand.IntnRange(10000, 60000) + m := &mocks2.AuthMetadataServiceServer{} + s := newAuthMetadataServer(t, port, m) + ctx := context.Background() + assert.NoError(t, s.Start(ctx)) + defer s.Close() + + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + assert.NoError(t, err) + + f := NewPerRPCCredentialsFuture() + interceptor := newAuthInterceptor(&Config{ + Endpoint: config.URL{URL: *u}, + UseInsecureConnection: true, + AuthType: AuthTypeClientSecret, + }, &mocks.TokenCache{}, f) + authenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + return nil + } + + err = interceptor(ctx, "POST", nil, nil, nil, authenticated) + assert.NoError(t, err) + assert.Falsef(t, f.IsInitialized(), "PerRPCCredentialFuture should not need to be initialized") + }) + + t.Run("Other error, doesn't authenticate", func(t *testing.T) { + assert.NoError(t, logger.SetConfig(&logger.Config{ + Level: logger.DebugLevel, + })) + + port := rand.IntnRange(10000, 60000) + m := &mocks2.AuthMetadataServiceServer{} + m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service2.OAuth2MetadataResponse{ + AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port), + TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port), + JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port), + }, nil) + m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service2.PublicClientAuthConfigResponse{ + Scopes: []string{"all"}, + }, nil) + + s := newAuthMetadataServer(t, port, m) + ctx := context.Background() + assert.NoError(t, s.Start(ctx)) + defer s.Close() + + u, err := url.Parse(fmt.Sprintf("dns:///localhost:%d", port)) + assert.NoError(t, err) + + f := NewPerRPCCredentialsFuture() + interceptor := newAuthInterceptor(&Config{ + Endpoint: config.URL{URL: *u}, + UseInsecureConnection: true, + AuthType: AuthTypeClientSecret, + }, &mocks.TokenCache{}, f) + unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + return status.New(codes.Aborted, "").Err() + } + + err = interceptor(ctx, "POST", nil, nil, nil, unauthenticated) + assert.Error(t, err) + assert.Falsef(t, f.IsInitialized(), "PerRPCCredentialFuture should not be initialized") + }) +} diff --git a/clients/go/admin/client.go b/clients/go/admin/client.go index 383c4c0699..fb0bfc642f 100644 --- a/clients/go/admin/client.go +++ b/clients/go/admin/client.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" - grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware" grpcRetry "github.com/grpc-ecosystem/go-grpc-middleware/retry" grpcPrometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "google.golang.org/grpc" @@ -30,11 +29,6 @@ type Clientset struct { healthServiceClient grpc_health_v1.HealthClient identityServiceClient service.IdentityServiceClient dataProxyServiceClient service.DataProxyServiceClient - authOpt grpc.DialOption -} - -func (c Clientset) AuthOpt() grpc.DialOption { - return c.authOpt } // AdminClient retrieves the AdminServiceClient @@ -75,16 +69,11 @@ func GetAdditionalAdminClientConfigOptions(cfg *Config) []grpc.DialOption { timeoutDialOption := grpcRetry.WithPerRetryTimeout(cfg.PerRetryTimeout.Duration) maxRetriesOption := grpcRetry.WithMax(uint(cfg.MaxRetries)) - retryInterceptor := grpcRetry.UnaryClientInterceptor(timeoutDialOption, maxRetriesOption) - finalUnaryInterceptor := grpcMiddleware.ChainUnaryClient( - grpcPrometheus.UnaryClientInterceptor, - retryInterceptor, - ) // We only make unary calls in this client, no streaming calls. We can add a streaming interceptor if admin // ever has those endpoints - opts = append(opts, grpc.WithUnaryInterceptor(finalUnaryInterceptor)) + opts = append(opts, grpc.WithChainUnaryInterceptor(grpcPrometheus.UnaryClientInterceptor, retryInterceptor)) return opts } @@ -102,13 +91,13 @@ func getAuthenticationDialOption(ctx context.Context, cfg *Config, tokenSourcePr return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err) } - tSource, err := tokenSourceProvider.GetTokenSource(ctx) + tokenSource, err := tokenSourceProvider.GetTokenSource(ctx) if err != nil { return nil, err } - oauthTokenSource := NewCustomHeaderTokenSource(tSource, cfg.UseInsecureConnection, clientMetadata.AuthorizationMetadataKey) - return grpc.WithPerRPCCredentials(oauthTokenSource), nil + wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, clientMetadata.AuthorizationMetadataKey) + return grpc.WithPerRPCCredentials(wrappedTokenSource), nil } // InitializeAuthMetadataClient creates a new anonymously Auth Metadata Service client. @@ -116,7 +105,7 @@ func InitializeAuthMetadataClient(ctx context.Context, cfg *Config) (client serv // Create an unauthenticated connection to fetch AuthMetadata authMetadataConnection, err := NewAdminConnection(ctx, cfg) if err != nil { - return nil, fmt.Errorf("failed to initialize admin connection. Error: %w", err) + return nil, fmt.Errorf("failed to initialized admin connection. Error: %w", err) } return service.NewAuthMetadataServiceClient(authMetadataConnection), nil @@ -163,7 +152,7 @@ func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOptio func InitializeAdminClient(ctx context.Context, cfg *Config, opts ...grpc.DialOption) service.AdminServiceClient { set, err := initializeClients(ctx, cfg, nil, opts...) if err != nil { - logger.Panicf(ctx, "Failed to initialize client. Error: %v", err) + logger.Panicf(ctx, "Failed to initialized client. Error: %v", err) return nil } @@ -173,24 +162,10 @@ func InitializeAdminClient(ctx context.Context, cfg *Config, opts ...grpc.DialOp // initializeClients creates an AdminClient, AuthServiceClient and IdentityServiceClient with a shared Admin connection // for the process. Note that if called with different cfg/dialoptions, it will not refresh the connection. func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, opts ...grpc.DialOption) (*Clientset, error) { - authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg) - if err != nil { - logger.Panicf(ctx, "failed to initialize Auth Metadata Client. Error: %v", err) - } - - tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, authMetadataClient) - if err != nil { - logger.Errorf(ctx, "failed to initialize token source provider. Err: %s", err.Error()) - } - - authOpt, err := getAuthenticationDialOption(ctx, cfg, tokenSourceProvider, authMetadataClient) - if err != nil { - logger.Warnf(ctx, "Starting an unauthenticated client because: %v", err) - } - - if authOpt != nil { - opts = append(opts, authOpt) - } + credentialsFuture := NewPerRPCCredentialsFuture() + opts = append(opts, + grpc.WithChainUnaryInterceptor(newAuthInterceptor(cfg, tokenCache, credentialsFuture)), + grpc.WithPerRPCCredentials(credentialsFuture)) if cfg.DefaultServiceConfig != "" { opts = append(opts, grpc.WithDefaultServiceConfig(cfg.DefaultServiceConfig)) @@ -198,7 +173,7 @@ func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenC adminConnection, err := NewAdminConnection(ctx, cfg, opts...) if err != nil { - logger.Panicf(ctx, "failed to initialize Admin connection. Err: %s", err.Error()) + logger.Panicf(ctx, "failed to initialized Admin connection. Err: %s", err.Error()) } var cs Clientset @@ -207,9 +182,6 @@ func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenC cs.identityServiceClient = service.NewIdentityServiceClient(adminConnection) cs.healthServiceClient = grpc_health_v1.NewHealthClient(adminConnection) cs.dataProxyServiceClient = service.NewDataProxyServiceClient(adminConnection) - if authOpt != nil { - cs.authOpt = authOpt - } return &cs, nil } diff --git a/clients/go/admin/client_test.go b/clients/go/admin/client_test.go index 8f41fc248f..469ebaea35 100644 --- a/clients/go/admin/client_test.go +++ b/clients/go/admin/client_test.go @@ -319,7 +319,7 @@ func ExampleClientSetBuilder() { // See AuthType for a list of supported authentication types. clientSet, err := NewClientsetBuilder().WithConfig(GetConfig(ctx)).Build(ctx) if err != nil { - logger.Fatalf(ctx, "failed to initialize clientSet from config. Error: %v", err) + logger.Fatalf(ctx, "failed to initialized clientSet from config. Error: %v", err) } // Access and use the desired client: diff --git a/clients/go/admin/token_source.go b/clients/go/admin/token_source.go index 604824836b..33610e2877 100644 --- a/clients/go/admin/token_source.go +++ b/clients/go/admin/token_source.go @@ -6,36 +6,37 @@ import ( "golang.org/x/oauth2" ) -// This class is here because we cannot use the normal "github.com/grpc/grpc-go/credentials/oauth" package to satisfy +// CustomHeaderTokenSource class is here because we cannot use the normal "github.com/grpc/grpc-go/credentials/oauth" package to satisfy // the credentials.PerRPCCredentials interface. This is because we want to be able to support a different 'header' // when passing the token in the gRPC call's metadata. The default is filled in in the constructor if none is supplied. type CustomHeaderTokenSource struct { - oauth2.TokenSource + tokenSource oauth2.TokenSource customHeader string insecure bool } const DefaultAuthorizationHeader = "authorization" +// RequireTransportSecurity returns whether this credentials class requires TLS/SSL. OAuth uses Bearer tokens that are +// susceptible to MITM (Man-In-The-Middle) attacks that are mitigated by TLS/SSL. We may return false here to make it +// easier to setup auth. However, in a production environment, TLS for OAuth2 is a requirement. +// see also: https://tools.ietf.org/html/rfc6749#section-3.1 +func (ts CustomHeaderTokenSource) RequireTransportSecurity() bool { + return !ts.insecure +} + // GetRequestMetadata gets the authorization metadata as a map using a TokenSource to generate a token func (ts CustomHeaderTokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - token, err := ts.Token() + token, err := ts.tokenSource.Token() if err != nil { return nil, err } + return map[string]string{ ts.customHeader: token.Type() + " " + token.AccessToken, }, nil } -// RequireTransportSecurity returns whether this credentials class requires TLS/SSL. OAuth uses Bearer tokens that are -// susceptible to MITM (Man-In-The-Middle) attacks that are mitigated by TLS/SSL. We may return false here to make it -// easier to setup auth. However, in a production environment, TLS for OAuth2 is a requirement. -// see also: https://tools.ietf.org/html/rfc6749#section-3.1 -func (ts CustomHeaderTokenSource) RequireTransportSecurity() bool { - return !ts.insecure -} - func NewCustomHeaderTokenSource(source oauth2.TokenSource, insecure bool, customHeader string) CustomHeaderTokenSource { header := DefaultAuthorizationHeader if customHeader != "" { @@ -43,7 +44,7 @@ func NewCustomHeaderTokenSource(source oauth2.TokenSource, insecure bool, custom } return CustomHeaderTokenSource{ - TokenSource: source, + tokenSource: source, customHeader: header, insecure: insecure, }