Skip to content

Commit

Permalink
Load OAuth credential on startup instead of request processing (#4442)
Browse files Browse the repository at this point in the history
  • Loading branch information
iamrodrigo authored Sep 7, 2021
1 parent 7ca1886 commit 7aca829
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 31 deletions.
2 changes: 1 addition & 1 deletion common/authorization/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"github.com/uber/cadence/common/log"
)

func NewAuthorizer(authorization config.Authorization, logger log.Logger, domainCache cache.DomainCache) Authorizer {
func NewAuthorizer(authorization config.Authorization, logger log.Logger, domainCache cache.DomainCache) (Authorizer, error) {
switch true {
case authorization.OAuthAuthorizer.Enable:
return NewOAuthAuthorizer(authorization.OAuthAuthorizer, logger, domainCache)
Expand Down
12 changes: 8 additions & 4 deletions common/authorization/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/cristalhq/jwt/v3"
"github.com/stretchr/testify/suite"

"github.com/uber/cadence/common"
"github.com/uber/cadence/common/config"
"github.com/uber/cadence/common/log"
"github.com/uber/cadence/common/log/loggerimpl"
Expand Down Expand Up @@ -63,7 +64,7 @@ func cfgOAuth() config.Authorization {
Enable: true,
JwtCredentials: config.JwtCredentials{
Algorithm: jwt.RS256.String(),
PublicKey: "public",
PublicKey: "../../config/credentials/keytest.pub",
},
MaxJwtTTL: 12345,
},
Expand All @@ -72,16 +73,19 @@ func cfgOAuth() config.Authorization {

func (s *factorySuite) TestFactoryNoopAuthorizer() {
cfgOAuthVar := cfgOAuth()
publicKey, _ := common.LoadRSAPublicKey(cfgOAuthVar.OAuthAuthorizer.JwtCredentials.PublicKey)
var tests = []struct {
cfg config.Authorization
expected Authorizer
err error
}{
{cfgNoop(), &nopAuthority{}},
{cfgOAuthVar, &oauthAuthority{authorizationCfg: cfgOAuthVar.OAuthAuthorizer, log: s.logger}},
{cfgNoop(), &nopAuthority{}, nil},
{cfgOAuthVar, &oauthAuthority{authorizationCfg: cfgOAuthVar.OAuthAuthorizer, log: s.logger, publicKey: publicKey}, nil},
}

for _, test := range tests {
authorizer := NewAuthorizer(test.cfg, s.logger, nil)
authorizer, err := NewAuthorizer(test.cfg, s.logger, nil)
s.Equal(authorizer, test.expected)
s.Equal(err, test.err)
}
}
4 changes: 2 additions & 2 deletions common/authorization/nopAuthorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import "context"
type nopAuthority struct{}

// NewNopAuthorizer creates a no-op authority
func NewNopAuthorizer() Authorizer {
return &nopAuthority{}
func NewNopAuthorizer() (Authorizer, error) {
return &nopAuthority{}, nil
}

func (a *nopAuthority) Authorize(
Expand Down
18 changes: 11 additions & 7 deletions common/authorization/oauthAuthorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package authorization

import (
"context"
"crypto/rsa"
"encoding/json"
"fmt"
"strings"
Expand All @@ -41,6 +42,7 @@ type oauthAuthority struct {
authorizationCfg config.OAuthAuthorizer
domainCache cache.DomainCache
log log.Logger
publicKey *rsa.PublicKey
}

type JWTClaims struct {
Expand All @@ -59,12 +61,17 @@ func NewOAuthAuthorizer(
authorizationCfg config.OAuthAuthorizer,
log log.Logger,
domainCache cache.DomainCache,
) Authorizer {
) (Authorizer, error) {
publicKey, err := common.LoadRSAPublicKey(authorizationCfg.JwtCredentials.PublicKey)
if err != nil {
return nil, err
}
return &oauthAuthority{
authorizationCfg: authorizationCfg,
domainCache: domainCache,
log: log,
}
publicKey: publicKey,
}, nil
}

// Authorize defines the logic to verify get claims from token
Expand Down Expand Up @@ -109,12 +116,9 @@ func (a *oauthAuthority) Authorize(
}

func (a *oauthAuthority) getVerifier() (jwt.Verifier, error) {
publicKey, err := common.LoadRSAPublicKey(a.authorizationCfg.JwtCredentials.PublicKey)
if err != nil {
return nil, err
}

algorithm := jwt.Algorithm(a.authorizationCfg.JwtCredentials.Algorithm)
verifier, err := jwt.NewVerifierRS(algorithm, publicKey)
verifier, err := jwt.NewVerifierRS(algorithm, a.publicKey)
if err != nil {
return nil, err
}
Expand Down
35 changes: 22 additions & 13 deletions common/authorization/oauthAutorizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ func (s *oauthSuite) TearDownTest() {

func (s *oauthSuite) TestCorrectPayload() {
s.domainCache.EXPECT().GetDomain(s.att.DomainName).Return(s.domainEntry, nil).Times(1)
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
authorizer, err := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.NoError(err)
result, err := authorizer.Authorize(s.ctx, &s.att)
s.NoError(err)
s.Equal(result.Decision, DecisionAllow)
Expand All @@ -132,7 +133,8 @@ func (s *oauthSuite) TestItIsAdmin() {
Headers: transport.NewHeaders().With(common.AuthorizationTokenHeaderName, token),
})
s.NoError(err)
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
authorizer, err := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.NoError(err)
result, err := authorizer.Authorize(ctx, &s.att)
s.NoError(err)
s.Equal(result.Decision, DecisionAllow)
Expand All @@ -145,7 +147,8 @@ func (s *oauthSuite) TestEmptyToken() {
Headers: transport.NewHeaders().With(common.AuthorizationTokenHeaderName, ""),
})
s.NoError(err)
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
authorizer, err := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.NoError(err)
s.logger.On("Debug", "request is not authorized", mock.MatchedBy(func(t []tag.Tag) bool {
return fmt.Sprintf("%v", t[0].Field().Interface) == "token is not set in header"
}))
Expand All @@ -155,31 +158,33 @@ func (s *oauthSuite) TestEmptyToken() {

func (s *oauthSuite) TestGetDomainError() {
s.domainCache.EXPECT().GetDomain(s.att.DomainName).Return(nil, fmt.Errorf("error")).Times(1)
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
authorizer, err := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.NoError(err)
result, err := authorizer.Authorize(s.ctx, &s.att)
s.Equal(result.Decision, DecisionDeny)
s.EqualError(err, "error")
}

func (s *oauthSuite) TestIncorrectPublicKey() {
s.cfg.JwtCredentials.PublicKey = "incorrectPublicKey"
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
result, err := authorizer.Authorize(s.ctx, &s.att)
authorizer, err := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.Equal(authorizer, nil)
s.EqualError(err, "invalid public key path incorrectPublicKey")
s.Equal(result.Decision, DecisionDeny)
}

func (s *oauthSuite) TestIncorrectAlgorithm() {
s.cfg.JwtCredentials.Algorithm = "SHA256"
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
authorizer, err := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.NoError(err)
result, err := authorizer.Authorize(s.ctx, &s.att)
s.EqualError(err, "jwt: algorithm is not supported")
s.Equal(result.Decision, DecisionDeny)
}

func (s *oauthSuite) TestMaxTTLLargerInToken() {
s.cfg.MaxJwtTTL = 1
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
authorizer, err := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.NoError(err)
s.logger.On("Debug", "request is not authorized", mock.MatchedBy(func(t []tag.Tag) bool {
return fmt.Sprintf("%v", t[0].Field().Interface) == "TTL in token is larger than MaxTTL allowed"
}))
Expand All @@ -194,7 +199,8 @@ func (s *oauthSuite) TestIncorrectToken() {
Headers: transport.NewHeaders().With(common.AuthorizationTokenHeaderName, "test"),
})
s.NoError(err)
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
authorizer, err := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.NoError(err)
s.logger.On("Debug", "request is not authorized", mock.MatchedBy(func(t []tag.Tag) bool {
return fmt.Sprintf("%v", t[0].Field().Interface) == "jwt: token format is not valid"
}))
Expand All @@ -211,7 +217,8 @@ func (s *oauthSuite) TestIatExpiredToken() {
Headers: transport.NewHeaders().With(common.AuthorizationTokenHeaderName, token),
})
s.NoError(err)
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
authorizer, err := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.NoError(err)
s.logger.On("Debug", "request is not authorized", mock.MatchedBy(func(t []tag.Tag) bool {
return fmt.Sprintf("%v", t[0].Field().Interface) == "JWT has expired"
}))
Expand All @@ -223,7 +230,8 @@ func (s *oauthSuite) TestDifferentGroup() {
s.domainEntry.GetInfo().Data[common.DomainDataKeyForReadGroups] = "AdifferentGroup"
s.domainCache.EXPECT().GetDomain(s.att.DomainName).Return(s.domainEntry, nil).Times(1)
s.att.Permission = PermissionWrite
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
authorizer, err := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.NoError(err)
s.logger.On("Debug", "request is not authorized", mock.MatchedBy(func(t []tag.Tag) bool {
return fmt.Sprintf("%v", t[0].Field().Interface) == "token doesn't have the right permission, jwt groups: [a b c], allowed groups: []"
}))
Expand All @@ -234,7 +242,8 @@ func (s *oauthSuite) TestDifferentGroup() {
func (s *oauthSuite) TestIncorrectPermission() {
s.domainCache.EXPECT().GetDomain(s.att.DomainName).Return(s.domainEntry, nil).Times(1)
s.att.Permission = Permission(15)
authorizer := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
authorizer, err := NewOAuthAuthorizer(s.cfg, s.logger, s.domainCache)
s.NoError(err)
s.logger.On("Debug", "request is not authorized", mock.MatchedBy(func(t []tag.Tag) bool {
return fmt.Sprintf("%v", t[0].Field().Interface) == "token doesn't have permission for 15 API"
}))
Expand Down
7 changes: 5 additions & 2 deletions host/onebox.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,12 @@ func (c *cadenceImpl) startFrontend(hosts map[string][]string, startWG *sync.Wai
params.ArchiverProvider = c.archiverProvider
params.ESConfig = c.esConfig
params.ESClient = c.esClient
params.Authorizer = authorization.NewAuthorizer(c.authorizationConfig, params.Logger, nil)

var err error
authorizer, err := authorization.NewAuthorizer(c.authorizationConfig, params.Logger, nil)
if err != nil {
c.logger.Fatal("Unable to create authorizer", tag.Error(err))
}
params.Authorizer = authorizer
params.PersistenceConfig, err = copyPersistenceConfig(c.persistenceConfig)
if err != nil {
c.logger.Fatal("Failed to copy persistence config for frontend", tag.Error(err))
Expand Down
7 changes: 6 additions & 1 deletion service/frontend/accessControlledAdminHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/uber/cadence/common/authorization"
"github.com/uber/cadence/common/config"
"github.com/uber/cadence/common/log/tag"
"github.com/uber/cadence/common/resource"
"github.com/uber/cadence/common/types"
)
Expand All @@ -42,7 +43,11 @@ var _ AdminHandler = (*AccessControlledWorkflowAdminHandler)(nil)
// NewAccessControlledHandlerImpl creates frontend handler with authentication support
func NewAccessControlledAdminHandlerImpl(adminHandler AdminHandler, resource resource.Resource, authorizer authorization.Authorizer, cfg config.Authorization) *AccessControlledWorkflowAdminHandler {
if authorizer == nil {
authorizer = authorization.NewAuthorizer(cfg, resource.GetLogger(), resource.GetDomainCache())
var err error
authorizer, err = authorization.NewAuthorizer(cfg, resource.GetLogger(), resource.GetDomainCache())
if err != nil {
resource.GetLogger().Fatal("Error when initiating the Authorizer", tag.Error(err))
}
}
return &AccessControlledWorkflowAdminHandler{
Resource: resource,
Expand Down
7 changes: 6 additions & 1 deletion service/frontend/accessControlledHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/uber/cadence/common/authorization"
"github.com/uber/cadence/common/config"
"github.com/uber/cadence/common/log/tag"
"github.com/uber/cadence/common/metrics"
"github.com/uber/cadence/common/resource"
"github.com/uber/cadence/common/types"
Expand All @@ -45,7 +46,11 @@ var _ Handler = (*AccessControlledWorkflowHandler)(nil)
// NewAccessControlledHandlerImpl creates frontend handler with authentication support
func NewAccessControlledHandlerImpl(wfHandler Handler, resource resource.Resource, authorizer authorization.Authorizer, cfg config.Authorization) *AccessControlledWorkflowHandler {
if authorizer == nil {
authorizer = authorization.NewAuthorizer(cfg, resource.GetLogger(), resource.GetDomainCache())
var err error
authorizer, err = authorization.NewAuthorizer(cfg, resource.GetLogger(), resource.GetDomainCache())
if err != nil {
resource.GetLogger().Fatal("Error when initiating the Authorizer", tag.Error(err))
}
}
return &AccessControlledWorkflowHandler{
Resource: resource,
Expand Down

0 comments on commit 7aca829

Please sign in to comment.