diff --git a/auth/authzserver/provider.go b/auth/authzserver/provider.go index 1aa88ac1e..2695897bf 100644 --- a/auth/authzserver/provider.go +++ b/auth/authzserver/provider.go @@ -175,7 +175,7 @@ func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{} scopes.Insert(auth.ScopeAll) } - return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo), nil + return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo, claimsRaw), nil } // NewProvider creates a new OAuth2 Provider that is able to do OAuth 2-legged and 3-legged flows. It'll lookup diff --git a/auth/identity_context.go b/auth/identity_context.go index c045a3619..eafa0dcf0 100644 --- a/auth/identity_context.go +++ b/auth/identity_context.go @@ -13,6 +13,8 @@ var ( emptyIdentityContext = IdentityContext{} ) +type claimsType = map[string]interface{} + // IdentityContext is an abstract entity to enclose the authenticated identity of the user/app. Both gRPC and HTTP // servers have interceptors to set the IdentityContext on the context.Context. // To retrieve the current IdentityContext call auth.IdentityContextFromContext(ctx). @@ -25,6 +27,8 @@ type IdentityContext struct { userInfo *service.UserInfoResponse // Set to pointer just to keep this struct go-simple to support equal operator scopes *sets.String + // Raw JWT token from the IDP. Set to a pointer to support the equal operator for this struct. + claims *claimsType } func (c IdentityContext) Audience() string { @@ -59,6 +63,13 @@ func (c IdentityContext) Scopes() sets.String { return sets.NewString() } +func (c IdentityContext) Claims() map[string]interface{} { + if c.claims != nil { + return *c.claims + } + return make(map[string]interface{}) +} + func (c IdentityContext) WithContext(ctx context.Context) context.Context { return context.WithValue(ctx, ContextKeyIdentityContext, c) } @@ -68,7 +79,7 @@ func (c IdentityContext) AuthenticatedAt() time.Time { } // NewIdentityContext creates a new IdentityContext. -func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes sets.String, userInfo *service.UserInfoResponse) IdentityContext { +func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Time, scopes sets.String, userInfo *service.UserInfoResponse, claims map[string]interface{}) IdentityContext { // For some reason, google IdP returns a subject in the ID Token but an empty subject in the /user_info endpoint if userInfo == nil { userInfo = &service.UserInfoResponse{} @@ -85,6 +96,7 @@ func NewIdentityContext(audience, userID, appID string, authenticatedAt time.Tim userInfo: userInfo, authenticatedAt: authenticatedAt, scopes: &scopes, + claims: &claims, } } diff --git a/auth/identity_context_test.go b/auth/identity_context_test.go new file mode 100644 index 000000000..0cda160f3 --- /dev/null +++ b/auth/identity_context_test.go @@ -0,0 +1,21 @@ +package auth + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGetClaims(t *testing.T) { + noClaims := map[string]interface{}(nil) + noClaimsCtx := NewIdentityContext("", "", "", time.Now(), nil, nil, nil) + assert.EqualValues(t, noClaims, noClaimsCtx.Claims()) + + claims := map[string]interface{}{ + "groups": []string{"g1", "g2"}, + "something": "else", + } + withClaimsCtx := NewIdentityContext("", "", "", time.Now(), nil, nil, claims) + assert.EqualValues(t, claims, withClaimsCtx.Claims()) +} diff --git a/auth/interfaces/context.go b/auth/interfaces/context.go index 35d0aed4b..e872342bb 100644 --- a/auth/interfaces/context.go +++ b/auth/interfaces/context.go @@ -64,6 +64,8 @@ type IdentityContext interface { UserInfo() *service.UserInfoResponse AuthenticatedAt() time.Time Scopes() sets.String + // Returns the full set of claims in the JWT token provided by the IDP. + Claims() map[string]interface{} IsEmpty() bool WithContext(ctx context.Context) context.Context diff --git a/auth/interfaces/mocks/identity_context.go b/auth/interfaces/mocks/identity_context.go index e1302901d..0458ea841 100644 --- a/auth/interfaces/mocks/identity_context.go +++ b/auth/interfaces/mocks/identity_context.go @@ -115,6 +115,40 @@ func (_m *IdentityContext) AuthenticatedAt() time.Time { return r0 } +type IdentityContext_Claims struct { + *mock.Call +} + +func (_m IdentityContext_Claims) Return(_a0 map[string]interface{}) *IdentityContext_Claims { + return &IdentityContext_Claims{Call: _m.Call.Return(_a0)} +} + +func (_m *IdentityContext) OnClaims() *IdentityContext_Claims { + c_call := _m.On("Claims") + return &IdentityContext_Claims{Call: c_call} +} + +func (_m *IdentityContext) OnClaimsMatch(matchers ...interface{}) *IdentityContext_Claims { + c_call := _m.On("Claims", matchers...) + return &IdentityContext_Claims{Call: c_call} +} + +// Claims provides a mock function with given fields: +func (_m *IdentityContext) Claims() map[string]interface{} { + ret := _m.Called() + + var r0 map[string]interface{} + if rf, ok := ret.Get(0).(func() map[string]interface{}); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]interface{}) + } + } + + return r0 +} + type IdentityContext_IsEmpty struct { *mock.Call } diff --git a/auth/token.go b/auth/token.go index 9433ea6ff..5acdd3203 100644 --- a/auth/token.go +++ b/auth/token.go @@ -70,7 +70,6 @@ func ParseIDTokenAndValidate(ctx context.Context, clientID, rawIDToken string, p return idToken, flyteErr } - return idToken, nil } @@ -130,8 +129,12 @@ func IdentityContextFromIDTokenToken(ctx context.Context, tokenStr, clientID str if err != nil { return nil, err } + var claims map[string]interface{} + if err := idToken.Claims(&claims); err != nil { + logger.Infof(ctx, "Failed to unmarshal claims from id token, err: %v", err) + } // TODO: Document why automatically specify "all" scope return NewIdentityContext(idToken.Audience[0], idToken.Subject, "", idToken.IssuedAt, - sets.NewString(ScopeAll), userInfo), nil + sets.NewString(ScopeAll), userInfo, claims), nil } diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index 1d2ee0427..21098b663 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -367,7 +367,7 @@ func TestCreateExecution(t *testing.T) { request.Spec.RawOutputDataConfig = &admin.RawOutputDataConfig{OutputLocationPrefix: rawOutput} request.Spec.ClusterAssignment = &clusterAssignment - identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil) + identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil) ctx := identity.WithContext(context.Background()) response, err := execManager.CreateExecution(ctx, request, requestedAt) assert.Nil(t, err) @@ -2834,7 +2834,7 @@ func TestTerminateExecution(t *testing.T) { r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &mockExecutor) execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) - identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil) + identity := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil) ctx := identity.WithContext(context.Background()) resp, err := execManager.TerminateExecution(ctx, admin.ExecutionTerminateRequest{ Id: &core.WorkflowExecutionIdentifier{