Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Pass along raw claims in identity context #447

Merged
merged 10 commits into from
Jun 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion auth/authzserver/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{}
scopes.Insert(auth.ScopeAll)
}

return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo), nil
return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo, claimsRaw), nil
}

// NewProvider creates a new OAuth2 Provider that is able to do OAuth 2-legged and 3-legged flows. It'll lookup
Expand Down
14 changes: 13 additions & 1 deletion auth/identity_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ var (
emptyIdentityContext = IdentityContext{}
)

type claimsType = map[string]interface{}

// IdentityContext is an abstract entity to enclose the authenticated identity of the user/app. Both gRPC and HTTP
// servers have interceptors to set the IdentityContext on the context.Context.
// To retrieve the current IdentityContext call auth.IdentityContextFromContext(ctx).
Expand All @@ -25,6 +27,8 @@ type IdentityContext struct {
userInfo *service.UserInfoResponse
// Set to pointer just to keep this struct go-simple to support equal operator
scopes *sets.String
// Raw JWT token from the IDP. Set to a pointer to support the equal operator for this struct.
claims *claimsType
}

func (c IdentityContext) Audience() string {
Expand Down Expand Up @@ -59,6 +63,13 @@ func (c IdentityContext) Scopes() sets.String {
return sets.NewString()
}

func (c IdentityContext) Claims() map[string]interface{} {
if c.claims != nil {
return *c.claims
}
return make(map[string]interface{})
}

func (c IdentityContext) WithContext(ctx context.Context) context.Context {
return context.WithValue(ctx, ContextKeyIdentityContext, c)
}
Expand All @@ -68,7 +79,7 @@ func (c IdentityContext) AuthenticatedAt() time.Time {
}

// NewIdentityContext creates a new IdentityContext.
func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes sets.String, userInfo *service.UserInfoResponse) IdentityContext {
func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes sets.String, userInfo *service.UserInfoResponse, claims map[string]interface{}) IdentityContext {
// For some reason, google IdP returns a subject in the ID Token but an empty subject in the /user_info endpoint
if userInfo == nil {
userInfo = &service.UserInfoResponse{}
Expand All @@ -85,6 +96,7 @@ func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Tim
userInfo: userInfo,
authenticatedAt: authenticatedAt,
scopes: &scopes,
claims: &claims,
}
}

Expand Down
21 changes: 21 additions & 0 deletions auth/identity_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package auth

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestGetClaims(t *testing.T) {
noClaims := map[string]interface{}(nil)
noClaimsCtx := NewIdentityContext("", "", "", time.Now(), nil, nil, nil)
assert.EqualValues(t, noClaims, noClaimsCtx.Claims())

claims := map[string]interface{}{
"groups": []string{"g1", "g2"},
"something": "else",
}
withClaimsCtx := NewIdentityContext("", "", "", time.Now(), nil, nil, claims)
assert.EqualValues(t, claims, withClaimsCtx.Claims())
}
2 changes: 2 additions & 0 deletions auth/interfaces/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ type IdentityContext interface {
UserInfo() *service.UserInfoResponse
AuthenticatedAt() time.Time
Scopes() sets.String
// Returns the full set of claims in the JWT token provided by the IDP.
Claims() map[string]interface{}

IsEmpty() bool
WithContext(ctx context.Context) context.Context
Expand Down
34 changes: 34 additions & 0 deletions auth/interfaces/mocks/identity_context.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ func ParseIDTokenAndValidate(ctx context.Context, clientID, rawIDToken string, p

return idToken, flyteErr
}

return idToken, nil
}

Expand Down Expand Up @@ -130,8 +129,12 @@ func IdentityContextFromIDTokenToken(ctx context.Context, tokenStr, clientID str
if err != nil {
return nil, err
}
var claims map[string]interface{}
if err := idToken.Claims(&claims); err != nil {
logger.Infof(ctx, "Failed to unmarshal claims from id token, err: %v", err)
}

// TODO: Document why automatically specify "all" scope
return NewIdentityContext(idToken.Audience[0], idToken.Subject, "", idToken.IssuedAt,
sets.NewString(ScopeAll), userInfo), nil
sets.NewString(ScopeAll), userInfo, claims), nil
}
4 changes: 2 additions & 2 deletions pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ func TestCreateExecution(t *testing.T) {
request.Spec.RawOutputDataConfig = &admin.RawOutputDataConfig{OutputLocationPrefix: rawOutput}
request.Spec.ClusterAssignment = &clusterAssignment

identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil)
identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
ctx := identity.WithContext(context.Background())
response, err := execManager.CreateExecution(ctx, request, requestedAt)
assert.Nil(t, err)
Expand Down Expand Up @@ -2834,7 +2834,7 @@ func TestTerminateExecution(t *testing.T) {
r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &mockExecutor)
execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})

identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil)
identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
ctx := identity.WithContext(context.Background())
resp, err := execManager.TerminateExecution(ctx, admin.ExecutionTerminateRequest{
Id: &core.WorkflowExecutionIdentifier{
Expand Down