-
Notifications
You must be signed in to change notification settings - Fork 674
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use grpc client interceptors to properly check for auth requirement (#…
…315) * Use grpc client interceptors to properly check for auth requirement Signed-off-by: Haytham Abuelfutuh <[email protected]> * Some refactor and add unit tests Signed-off-by: Haytham Abuelfutuh <[email protected]> * PR Comments Signed-off-by: Haytham Abuelfutuh <[email protected]> * lint Signed-off-by: Haytham Abuelfutuh <[email protected]> * unit tests Signed-off-by: Haytham Abuelfutuh <[email protected]> * Attempt a random port Signed-off-by: Haytham Abuelfutuh <[email protected]> * Listen to localhost only Signed-off-by: Haytham Abuelfutuh <[email protected]> * PR Comments Signed-off-by: Haytham Abuelfutuh <[email protected]> * use chain unary interceptor instead Signed-off-by: Haytham Abuelfutuh <[email protected]> * only log on errors Signed-off-by: Haytham Abuelfutuh <[email protected]> * Attempt to disable error check Signed-off-by: Haytham Abuelfutuh <[email protected]> Signed-off-by: Haytham Abuelfutuh <[email protected]>
- Loading branch information
Showing
7 changed files
with
487 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
package admin | ||
|
||
import ( | ||
"context" | ||
"sync/atomic" | ||
|
||
"google.golang.org/grpc/credentials" | ||
|
||
stdlibAtomic "github.com/flyteorg/flytestdlib/atomic" | ||
) | ||
|
||
// atomicPerRPCCredentials provides a convenience on top of atomic.Value and credentials.PerRPCCredentials to be thread-safe. | ||
type atomicPerRPCCredentials struct { | ||
atomic.Value | ||
} | ||
|
||
func (t *atomicPerRPCCredentials) Store(properties credentials.PerRPCCredentials) { | ||
t.Value.Store(properties) | ||
} | ||
|
||
func (t *atomicPerRPCCredentials) Load() credentials.PerRPCCredentials { | ||
val := t.Value.Load() | ||
if val == nil { | ||
return CustomHeaderTokenSource{} | ||
} | ||
|
||
return val.(credentials.PerRPCCredentials) | ||
} | ||
|
||
func newAtomicPerPRCCredentials() *atomicPerRPCCredentials { | ||
return &atomicPerRPCCredentials{ | ||
Value: atomic.Value{}, | ||
} | ||
} | ||
|
||
// PerRPCCredentialsFuture is a future wrapper for credentials.PerRPCCredentials that can act as one and also be | ||
// materialized later. | ||
type PerRPCCredentialsFuture struct { | ||
perRPCCredentials *atomicPerRPCCredentials | ||
initialized stdlibAtomic.Bool | ||
} | ||
|
||
// GetRequestMetadata gets the authorization metadata as a map using a TokenSource to generate a token | ||
func (ts *PerRPCCredentialsFuture) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { | ||
if ts.initialized.Load() { | ||
tp := ts.perRPCCredentials.Load() | ||
return tp.GetRequestMetadata(ctx, uri...) | ||
} | ||
|
||
return map[string]string{}, nil | ||
} | ||
|
||
// RequireTransportSecurity returns whether this credentials class requires TLS/SSL. OAuth uses Bearer tokens that are | ||
// susceptible to MITM (Man-In-The-Middle) attacks that are mitigated by TLS/SSL. We may return false here to make it | ||
// easier to setup auth. However, in a production environment, TLS for OAuth2 is a requirement. | ||
// see also: https://tools.ietf.org/html/rfc6749#section-3.1 | ||
func (ts *PerRPCCredentialsFuture) RequireTransportSecurity() bool { | ||
if ts.initialized.Load() { | ||
return ts.perRPCCredentials.Load().RequireTransportSecurity() | ||
} | ||
|
||
return false | ||
} | ||
|
||
func (ts *PerRPCCredentialsFuture) Store(tokenSource credentials.PerRPCCredentials) { | ||
ts.perRPCCredentials.Store(tokenSource) | ||
ts.initialized.Store(true) | ||
} | ||
|
||
func (ts *PerRPCCredentialsFuture) Get() credentials.PerRPCCredentials { | ||
return ts.perRPCCredentials.Load() | ||
} | ||
|
||
func (ts *PerRPCCredentialsFuture) IsInitialized() bool { | ||
return ts.initialized.Load() | ||
} | ||
|
||
// NewPerRPCCredentialsFuture initializes a new PerRPCCredentialsFuture that can act as a credentials.PerRPCCredentials | ||
// and can also be resolved in the future. Users of the future can check if it has been initialized before by calling | ||
// PerRPCCredentialsFuture.IsInitialized(). Calling PerRPCCredentialsFuture.Get() multiple times will return | ||
// the same stored object (unless it changed in between calls). Calling PerRPCCredentialsFuture.Store() multiple | ||
// times is supported and will result in overriding the old value atomically. | ||
func NewPerRPCCredentialsFuture() *PerRPCCredentialsFuture { | ||
tokenSource := PerRPCCredentialsFuture{ | ||
perRPCCredentials: newAtomicPerPRCCredentials(), | ||
} | ||
|
||
return &tokenSource | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package admin | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestAtomicPerRPCCredentials(t *testing.T) { | ||
a := atomicPerRPCCredentials{} | ||
assert.True(t, a.Load().RequireTransportSecurity()) | ||
|
||
tokenSource := DummyTestTokenSource{} | ||
chTokenSource := NewCustomHeaderTokenSource(tokenSource, true, "my_custom_header") | ||
a.Store(chTokenSource) | ||
|
||
assert.False(t, a.Load().RequireTransportSecurity()) | ||
} | ||
|
||
func TestNewPerRPCCredentialsFuture(t *testing.T) { | ||
f := NewPerRPCCredentialsFuture() | ||
assert.False(t, f.RequireTransportSecurity()) | ||
assert.Equal(t, CustomHeaderTokenSource{}, f.Get()) | ||
|
||
tokenSource := DummyTestTokenSource{} | ||
chTokenSource := NewCustomHeaderTokenSource(tokenSource, false, "my_custom_header") | ||
f.Store(chTokenSource) | ||
|
||
assert.True(t, f.Get().RequireTransportSecurity()) | ||
assert.True(t, f.RequireTransportSecurity()) | ||
} | ||
|
||
func ExampleNewPerRPCCredentialsFuture() { | ||
f := NewPerRPCCredentialsFuture() | ||
|
||
// Starts uninitialized | ||
fmt.Println("Initialized:", f.IsInitialized()) | ||
|
||
// Implements credentials.PerRPCCredentials so can be used as one | ||
m, err := f.GetRequestMetadata(context.TODO(), "") | ||
fmt.Println("GetRequestMetadata:", m, "Error:", err) | ||
|
||
// Materialize the value later and populate | ||
tokenSource := DummyTestTokenSource{} | ||
f.Store(NewCustomHeaderTokenSource(tokenSource, false, "my_custom_header")) | ||
|
||
// Future calls to credentials.PerRPCCredentials methods will use the new instance | ||
m, err = f.GetRequestMetadata(context.TODO(), "") | ||
fmt.Println("GetRequestMetadata:", m, "Error:", err) | ||
|
||
// Output: | ||
// Initialized: false | ||
// GetRequestMetadata: map[] Error: <nil> | ||
// GetRequestMetadata: map[my_custom_header:Bearer abc] Error: <nil> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
package admin | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
|
||
"github.com/flyteorg/flyteidl/clients/go/admin/cache" | ||
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" | ||
"github.com/flyteorg/flytestdlib/logger" | ||
|
||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/status" | ||
|
||
"google.golang.org/grpc" | ||
) | ||
|
||
// MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server. | ||
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values. | ||
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture) error { | ||
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg) | ||
if err != nil { | ||
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err) | ||
} | ||
|
||
tokenSourceProvider, err := NewTokenSourceProvider(ctx, cfg, tokenCache, authMetadataClient) | ||
if err != nil { | ||
return fmt.Errorf("failed to initialized token source provider. Err: %w", err) | ||
} | ||
|
||
clientMetadata, err := authMetadataClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}) | ||
if err != nil { | ||
return fmt.Errorf("failed to fetch client metadata. Error: %v", err) | ||
} | ||
|
||
tokenSource, err := tokenSourceProvider.GetTokenSource(ctx) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, clientMetadata.AuthorizationMetadataKey) | ||
perRPCCredentials.Store(wrappedTokenSource) | ||
return nil | ||
} | ||
|
||
func shouldAttemptToAuthenticate(errorCode codes.Code) bool { | ||
return errorCode == codes.Unauthenticated | ||
} | ||
|
||
// newAuthInterceptor creates a new grpc.UnaryClientInterceptor that forwards the grpc call and inspects the error. | ||
// It will first invoke the grpc pipeline (to proceed with the request) with no modifications. It's expected for the grpc | ||
// pipeline to already have a grpc.WithPerRPCCredentials() DialOption. If the perRPCCredentials has already been initialized, | ||
// it'll take care of refreshing when tokens expire... etc. | ||
// If the first invocation succeeds (either due to grpc.PerRPCCredentials setting the right tokens or the server not | ||
// requiring authentication), the interceptor will be no-op. | ||
// If the first invocation fails with an auth error, this interceptor will then attempt to establish a token source once | ||
// more. It'll fail hard if it couldn't do so (i.e. it will no longer attempt to send an unauthenticated request). Once | ||
// a token source has been created, it'll invoke the grpc pipeline again, this time the grpc.PerRPCCredentials should | ||
// be able to find and acquire a valid AccessToken to annotate the request with. | ||
func newAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor { | ||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { | ||
err := invoker(ctx, method, req, reply, cc, opts...) | ||
if err != nil { | ||
logger.Debugf(ctx, "Request failed due to [%v]. If it's an unauthenticated error, we will attempt to establish an authenticated context.", err) | ||
|
||
if st, ok := status.FromError(err); ok { | ||
// If the error we receive from executing the request expects | ||
if shouldAttemptToAuthenticate(st.Code()) { | ||
logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code()) | ||
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture) | ||
if newErr != nil { | ||
return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr) | ||
} | ||
|
||
return invoker(ctx, method, req, reply, cc, opts...) | ||
} | ||
} | ||
} | ||
|
||
return err | ||
} | ||
} |
Oops, something went wrong.