Skip to content

Commit

Permalink
Adding device authorization oauth2 flow (#313)
Browse files Browse the repository at this point in the history
* Added config skip opening browser for pkce auth

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* added docs

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* increased the default browser session timeout to 2min

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Adding device flow idl changes

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* Adding device flow orchestration

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* lint fixes

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* nit

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* fixes

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* refactor and feedback

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* nit

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* test fixes

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* more test fixes

Signed-off-by: Prafulla Mahindrakar <[email protected]>

* feedback

Signed-off-by: Prafulla Mahindrakar <[email protected]>

Signed-off-by: Prafulla Mahindrakar <[email protected]>
  • Loading branch information
pmahindrakar-oss authored Sep 2, 2022
1 parent 7da6209 commit 997b290
Show file tree
Hide file tree
Showing 39 changed files with 1,303 additions and 353 deletions.
7 changes: 4 additions & 3 deletions flyteidl/clients/go/admin/authtype_enumer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package pkce
package cache

import "golang.org/x/oauth2"

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package pkce
package cache

import (
"fmt"
Expand Down
14 changes: 7 additions & 7 deletions flyteidl/clients/go/admin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ import (
"errors"
"fmt"

"github.com/flyteorg/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/logger"

grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpcRetry "github.com/grpc-ecosystem/go-grpc-middleware/retry"
grpcPrometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/health/grpc_health_v1"

"github.com/flyteorg/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/logger"
)

// IDE "Go Generate File". This will create a mocks/AdminServiceClient.go file
Expand Down Expand Up @@ -172,7 +172,7 @@ func InitializeAdminClient(ctx context.Context, cfg *Config, opts ...grpc.DialOp

// initializeClients creates an AdminClient, AuthServiceClient and IdentityServiceClient with a shared Admin connection
// for the process. Note that if called with different cfg/dialoptions, it will not refresh the connection.
func initializeClients(ctx context.Context, cfg *Config, tokenCache pkce.TokenCache, opts ...grpc.DialOption) (*Clientset, error) {
func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, opts ...grpc.DialOption) (*Clientset, error) {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg)
if err != nil {
logger.Panicf(ctx, "failed to initialize Auth Metadata Client. Error: %v", err)
Expand Down Expand Up @@ -215,7 +215,7 @@ func initializeClients(ctx context.Context, cfg *Config, tokenCache pkce.TokenCa
}

// Deprecated: Please use NewClientsetBuilder() instead.
func InitializeAdminClientFromConfig(ctx context.Context, tokenCache pkce.TokenCache, opts ...grpc.DialOption) (service.AdminServiceClient, error) {
func InitializeAdminClientFromConfig(ctx context.Context, tokenCache cache.TokenCache, opts ...grpc.DialOption) (service.AdminServiceClient, error) {
clientSet, err := initializeClients(ctx, GetConfig(ctx), tokenCache, opts...)
if err != nil {
return nil, err
Expand Down
8 changes: 4 additions & 4 deletions flyteidl/clients/go/admin/client_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import (

"google.golang.org/grpc"

"github.com/flyteorg/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyteidl/clients/go/admin/cache"
)

// ClientsetBuilder is used to build the clientset. This allows custom token cache implementations to be plugged in.
type ClientsetBuilder struct {
config *Config
tokenCache pkce.TokenCache
tokenCache cache.TokenCache
opts []grpc.DialOption
}

Expand All @@ -27,7 +27,7 @@ func (cb *ClientsetBuilder) WithConfig(config *Config) *ClientsetBuilder {
}

// WithTokenCache allows pluggable token cache implemetations. eg; flytectl uses keyring as tokenCache
func (cb *ClientsetBuilder) WithTokenCache(tokenCache pkce.TokenCache) *ClientsetBuilder {
func (cb *ClientsetBuilder) WithTokenCache(tokenCache cache.TokenCache) *ClientsetBuilder {
cb.tokenCache = tokenCache
return cb
}
Expand All @@ -40,7 +40,7 @@ func (cb *ClientsetBuilder) WithDialOptions(opts ...grpc.DialOption) *ClientsetB
// Build the clientset using the current state of the ClientsetBuilder
func (cb *ClientsetBuilder) Build(ctx context.Context) (*Clientset, error) {
if cb.tokenCache == nil {
cb.tokenCache = &pkce.TokenCacheInMemoryProvider{}
cb.tokenCache = &cache.TokenCacheInMemoryProvider{}
}

if cb.config == nil {
Expand Down
7 changes: 4 additions & 3 deletions flyteidl/clients/go/admin/client_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ import (
"reflect"
"testing"

"github.com/flyteorg/flyteidl/clients/go/admin/pkce"
"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyteidl/clients/go/admin/cache"
)

func TestClientsetBuilder_Build(t *testing.T) {
cb := NewClientsetBuilder().WithConfig(&Config{
UseInsecureConnection: true,
}).WithTokenCache(&pkce.TokenCacheInMemoryProvider{})
}).WithTokenCache(&cache.TokenCacheInMemoryProvider{})
_, err := cb.Build(context.Background())
assert.NoError(t, err)
assert.True(t, reflect.TypeOf(cb.tokenCache) == reflect.TypeOf(&pkce.TokenCacheInMemoryProvider{}))
assert.True(t, reflect.TypeOf(cb.tokenCache) == reflect.TypeOf(&cache.TokenCacheInMemoryProvider{}))
}
38 changes: 26 additions & 12 deletions flyteidl/clients/go/admin/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,24 @@ import (
"io/ioutil"
"net/http"
"net/url"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
_ "google.golang.org/grpc/balancer/roundrobin" //nolint

"github.com/flyteorg/flyteidl/clients/go/admin/cache"
cachemocks "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks"
"github.com/flyteorg/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyteidl/clients/go/admin/oauth"
"github.com/flyteorg/flyteidl/clients/go/admin/pkce"
pkcemocks "github.com/flyteorg/flyteidl/clients/go/admin/pkce/mocks"
"github.com/flyteorg/flyteidl/clients/go/admin/tokenorchestrator"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flytestdlib/config"
"github.com/flyteorg/flytestdlib/logger"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
_ "google.golang.org/grpc/balancer/roundrobin" //nolint
)

func TestInitializeAndGetAdminClient(t *testing.T) {
Expand Down Expand Up @@ -193,13 +197,13 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
RedirectUri: "http://localhost:54545/callback",
}
http.DefaultServeMux = http.NewServeMux()
plan, _ := ioutil.ReadFile("pkce/testdata/token.json")
plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData.Expiry = time.Now().Add(time.Minute)
t.Run("cache hit", func(t *testing.T) {
mockTokenCache := new(pkcemocks.TokenCache)
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
Expand All @@ -213,7 +217,7 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
})
tokenData.Expiry = time.Now().Add(-time.Minute)
t.Run("cache miss auth failure", func(t *testing.T) {
mockTokenCache := new(pkcemocks.TokenCache)
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
Expand Down Expand Up @@ -244,16 +248,26 @@ func Test_getPkceAuthTokenSource(t *testing.T) {
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)

t.Run("cached token expired", func(t *testing.T) {
plan, _ := ioutil.ReadFile("pkce/testdata/token.json")
plan, _ := ioutil.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)

// populate the cache
tokenCache := &pkce.TokenCacheInMemoryProvider{}
tokenCache := &cache.TokenCacheInMemoryProvider{}
assert.NoError(t, tokenCache.SaveToken(&tokenData))

orchestrator, err := pkce.NewTokenOrchestrator(ctx, pkce.Config{}, tokenCache, mockAuthClient)
baseOrchestrator := tokenorchestrator.BaseTokenOrchestrator{
ClientConfig: &oauth.Config{
Config: &oauth2.Config{
RedirectURL: "http://localhost:8089/redirect",
Scopes: []string{"code", "all"},
},
},
TokenCache: tokenCache,
}

orchestrator, err := pkce.NewTokenOrchestrator(baseOrchestrator, pkce.Config{})
assert.NoError(t, err)

http.DefaultServeMux = http.NewServeMux()
Expand Down
19 changes: 14 additions & 5 deletions flyteidl/clients/go/admin/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
"path/filepath"
"time"

"github.com/flyteorg/flyteidl/clients/go/admin/deviceflow"
"github.com/flyteorg/flyteidl/clients/go/admin/pkce"

"github.com/flyteorg/flytestdlib/config"
"github.com/flyteorg/flytestdlib/logger"
)
Expand All @@ -27,12 +27,14 @@ var DefaultClientSecretLocation = filepath.Join(string(filepath.Separator), "etc
type AuthType uint8

const (
// Chooses Client Secret OAuth2 protocol (ref: https://tools.ietf.org/html/rfc6749#section-4.4)
// AuthTypeClientSecret Chooses Client Secret OAuth2 protocol (ref: https://tools.ietf.org/html/rfc6749#section-4.4)
AuthTypeClientSecret AuthType = iota
// Chooses Proof Key Code Exchange OAuth2 extension protocol (ref: https://tools.ietf.org/html/rfc7636)
// AuthTypePkce Chooses Proof Key Code Exchange OAuth2 extension protocol (ref: https://tools.ietf.org/html/rfc7636)
AuthTypePkce
// Chooses an external authentication process
// AuthTypeExternalCommand Chooses an external authentication process
AuthTypeExternalCommand
// AuthTypeDeviceFlow Uses device flow to authenticate in a constrained environment with no access to browser
AuthTypeDeviceFlow
)

type Config struct {
Expand Down Expand Up @@ -67,6 +69,8 @@ type Config struct {

PkceConfig pkce.Config `json:"pkceConfig" pflag:",Config for Pkce authentication flow."`

DeviceFlowConfig deviceflow.Config `json:"deviceFlowConfig" pflag:",Config for Device authentication flow."`

Command []string `json:"command" pflag:",Command for external authentication token generation"`

// Set the gRPC service config formatted as a json string https://github.com/grpc/grpc/blob/master/doc/service_config.md
Expand All @@ -86,7 +90,12 @@ var (
ClientSecretLocation: DefaultClientSecretLocation,
PkceConfig: pkce.Config{
TokenRefreshGracePeriod: config.Duration{Duration: 5 * time.Minute},
BrowserSessionTimeout: config.Duration{Duration: 15 * time.Second},
BrowserSessionTimeout: config.Duration{Duration: 2 * time.Minute},
},
DeviceFlowConfig: deviceflow.Config{
TokenRefreshGracePeriod: config.Duration{Duration: 5 * time.Minute},
Timeout: config.Duration{Duration: 10 * time.Minute},
PollInterval: config.Duration{Duration: 5 * time.Second},
},
TokenRefreshWindow: config.Duration{Duration: 0},
}
Expand Down
7 changes: 5 additions & 2 deletions flyteidl/clients/go/admin/config_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 42 additions & 0 deletions flyteidl/clients/go/admin/config_flags_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions flyteidl/clients/go/admin/deviceflow/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package deviceflow

import "github.com/flyteorg/flytestdlib/config"

// Config defines settings used for Device orchestration flow.
type Config struct {
TokenRefreshGracePeriod config.Duration `json:"refreshTime" pflag:",grace period from the token expiry after which it would refresh the token."`
Timeout config.Duration `json:"timeout" pflag:",amount of time the device flow should complete or else it will be cancelled."`
PollInterval config.Duration `json:"pollInterval" pflag:",amount of time the device flow would poll the token endpoint if auth server doesn't return a polling interval. Okta and google IDP do return an interval'"`
}
42 changes: 42 additions & 0 deletions flyteidl/clients/go/admin/deviceflow/payload.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package deviceflow

import "golang.org/x/oauth2"

// DeviceAuthorizationRequest sent to authorization server directly from the client app
type DeviceAuthorizationRequest struct {
// ClientID is the client identifier issued to the client during the registration process of OAuth app with the authorization server
ClientID string `json:"client_id"`
// Scope is the scope parameter of the access request
Scope string `json:"scope"`
}

// DeviceAuthorizationResponse contains the information that the end user would use to authorize the app requesting the
// resource access.
type DeviceAuthorizationResponse struct {
// DeviceCode unique device code generated by the authorization server.
DeviceCode string `json:"device_code"`
// UserCode unique code generated for the user to enter on another device
UserCode string `json:"user_code"`
// VerificationURI url endpoint of the authorization server which host the device and app verification
VerificationURI string `json:"verification_uri"`
// VerificationURIComplete url endpoint of the authorization server which host the device and app verification along with user code
VerificationURIComplete string `json:"verification_uri_complete"`
// ExpiresIn lifetime in seconds of the "device_code" and "user_code"
ExpiresIn int64 `json:"expires_in"`
// Interval minimum amount of time in secs the client app should wait between polling requests to the token endpoint.
Interval int64 `json:"interval"`
}

type DeviceAccessTokenRequest struct {
// ClientID is the client identifier issued to the client during the registration process of OAuth app with the authorization server
ClientID string `json:"client_id"`
// DeviceCode unique device code generated by the authorization server.
DeviceCode string `json:"device_code"`
// Value MUST be set to "urn:ietf:params:oauth:grant-type:device_code"
GrantType string `json:"grant_type"`
}

type DeviceAccessTokenResponse struct {
oauth2.Token
Error string `json:"error"`
}
Loading

0 comments on commit 997b290

Please sign in to comment.