Skip to content

Commit

Permalink
feat(api): refactor authentication drivers + manage github merge comm…
Browse files Browse the repository at this point in the history
…it (#6439)
  • Loading branch information
sguiheux authored Feb 13, 2023
1 parent 84655f8 commit 9e332b3
Show file tree
Hide file tree
Showing 66 changed files with 2,410 additions and 1,038 deletions.
176 changes: 103 additions & 73 deletions engine/api/api.go

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions engine/api/api_routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,9 @@ func (api *API) InitRouter() {
r.Handle("/config/cdn", ScopeNone(), r.GET(api.configCDNHandler))
r.Handle("/config/api", ScopeNone(), r.GET(api.configAPIHandler))

r.Handle("/link/driver", ScopeNone(), r.GET(api.getLinkDriversHandler))
r.Handle("/link/{consumerType}/ask", Scope(sdk.AuthConsumerScopeUser), r.POST(api.postAskLinkExternalUserWithCDSHandler))
r.Handle("/link/{consumerType}", Scope(sdk.AuthConsumerScopeUser), r.POST(api.postLinkExternalUserWithCDSHandler))
// Users
r.Handle("/user", Scope(sdk.AuthConsumerScopeUser), r.GET(api.getUsersHandler))
r.Handle("/user/favorite", Scope(sdk.AuthConsumerScopeUser), r.POST(api.postUserFavoriteHandler))
Expand All @@ -380,6 +383,7 @@ func (api *API) InitRouter() {
r.Handle("/user/{permUsernamePublic}", Scope(sdk.AuthConsumerScopeUser), r.GET(api.getUserHandler), r.PUT(api.putUserHandler), r.DELETE(api.deleteUserHandler))
r.Handle("/user/{permUsernamePublic}/group", Scope(sdk.AuthConsumerScopeUser), r.GET(api.getUserGroupsHandler))
r.Handle("/user/{permUsername}/contact", Scope(sdk.AuthConsumerScopeUser), r.GET(api.getUserContactsHandler))
r.Handle("/user/{permUsername}/link", Scope(sdk.AuthConsumerScopeUser), r.GET(api.getUserLinksHandler))
r.Handle("/user/{permUsername}/auth/consumer", Scope(sdk.AuthConsumerScopeAccessToken), r.GET(api.getConsumersByUserHandler), r.POST(api.postConsumerByUserHandler))
r.Handle("/user/{permUsername}/auth/consumer/{permConsumerID}", Scope(sdk.AuthConsumerScopeAccessToken), r.DELETE(api.deleteConsumerByUserHandler))
r.Handle("/user/{permUsername}/auth/consumer/{permConsumerID}/regen", Scope(sdk.AuthConsumerScopeAccessToken), r.POST(api.postConsumerRegenByUserHandler))
Expand Down
4 changes: 4 additions & 0 deletions engine/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"context"
"github.com/ovh/cds/engine/api/link"
"github.com/ovh/cds/engine/api/organization"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -43,6 +44,9 @@ func newTestAPI(t *testing.T, bootstrapFunc ...test.Bootstrapf) (*API, *test.Fak
api.AuthenticationDrivers[sdk.ConsumerBuiltin] = builtin.NewDriver()
api.AuthenticationDrivers[sdk.ConsumerTest] = authdrivertest.NewDriver(t)
api.AuthenticationDrivers[sdk.ConsumerTest2] = authdrivertest.NewDriver(t)

api.LinkDrivers = make(map[sdk.AuthConsumerType]link.LinkDriver)
api.LinkDrivers[sdk.ConsumerGithub] = authdrivertest.NewDriver(t)
api.GoRoutines = sdk.NewGoRoutines(context.TODO())

// Reset organization
Expand Down
19 changes: 10 additions & 9 deletions engine/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ func (api *API) getAuthAskSigninHandler() service.Handler {
if !consumerType.IsValid() {
return sdk.WithStack(sdk.ErrNotFound)
}
driver, ok := api.AuthenticationDrivers[consumerType]
authDriver, ok := api.AuthenticationDrivers[consumerType]
if !ok {
return sdk.WithStack(sdk.ErrNotFound)
}

driverRedirect, ok := driver.(sdk.AuthDriverWithRedirect)
driverRedirect, ok := authDriver.GetDriver().(sdk.DriverWithRedirect)
if !ok {
return nil
}
Expand Down Expand Up @@ -99,30 +99,31 @@ func (api *API) postAuthSigninHandler() service.Handler {
if !consumerType.IsValid() {
return sdk.WithStack(sdk.ErrNotFound)
}
driver, ok := api.AuthenticationDrivers[consumerType]
authDriver, ok := api.AuthenticationDrivers[consumerType]
if !ok {
return sdk.WithStack(sdk.ErrNotFound)
}
signInDriver := authDriver.GetDriver().(sdk.DriverWithSignInRequest)

// Extract and validate signin request
var req sdk.AuthConsumerSigninRequest
if err := service.UnmarshalBody(r, &req); err != nil {
return err
}
if err := driver.CheckSigninRequest(req); err != nil {
if err := signInDriver.CheckSigninRequest(req); err != nil {
return err
}

// Extract and validate signin state
switch x := driver.(type) {
case sdk.AuthDriverWithSigninStateToken:
switch x := authDriver.GetDriver().(type) {
case sdk.DriverWithSigninStateToken:
if err := x.CheckSigninStateToken(req); err != nil {
return err
}
}

// Convert code to external user info
userInfo, err := driver.GetUserInfo(ctx, req)
userInfo, err := authDriver.GetUserInfo(ctx, req)
if err != nil {
return err
}
Expand Down Expand Up @@ -206,7 +207,7 @@ func (api *API) postAuthSigninHandler() service.Handler {
} else {
// We can't find any user with the same email address
// So we will do signup for a new user from the data got from the auth driver
if driver.GetManifest().SignupDisabled {
if authDriver.GetManifest().SignupDisabled {
return sdk.WithStack(sdk.ErrSignupDisabled)
}

Expand Down Expand Up @@ -273,7 +274,7 @@ func (api *API) postAuthSigninHandler() service.Handler {
}

// Generate a new session for consumer
sessionDuration := driver.GetSessionDuration()
sessionDuration := authDriver.GetSessionDuration()
var session *sdk.AuthSession
if userInfo.MFA {
trackSudo(ctx, w)
Expand Down
13 changes: 8 additions & 5 deletions engine/api/auth_builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"time"

"github.com/ovh/cds/engine/api/authentication"
"github.com/ovh/cds/engine/api/authentication/builtin"
"github.com/ovh/cds/engine/api/driver/builtin"
"github.com/ovh/cds/engine/api/user"
"github.com/ovh/cds/engine/service"
"github.com/ovh/cds/sdk"
Expand All @@ -19,7 +19,7 @@ import (
func (api *API) postAuthBuiltinSigninHandler() service.Handler {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
// Get the consumer builtin driver
driver, ok := api.AuthenticationDrivers[sdk.ConsumerBuiltin]
authDriver, ok := api.AuthenticationDrivers[sdk.ConsumerBuiltin]
if !ok {
return sdk.WithStack(sdk.ErrForbidden)
}
Expand All @@ -29,12 +29,15 @@ func (api *API) postAuthBuiltinSigninHandler() service.Handler {
if err := service.UnmarshalBody(r, &req); err != nil {
return sdk.NewError(sdk.ErrForbidden, err)
}
if err := driver.CheckSigninRequest(req); err != nil {

driv := authDriver.GetDriver().(sdk.DriverWithSignInRequest)

if err := driv.CheckSigninRequest(req); err != nil {
return sdk.NewError(sdk.ErrForbidden, err)
}

// Convert code to external user info
userInfo, err := driver.GetUserInfo(ctx, req)
userInfo, err := authDriver.GetUserInfo(ctx, req)
if err != nil {
return sdk.NewError(sdk.ErrForbidden, err)
}
Expand Down Expand Up @@ -101,7 +104,7 @@ func (api *API) postAuthBuiltinSigninHandler() service.Handler {
}

// Generate a new session for consumer
session, err := authentication.NewSession(ctx, tx, &consumer.AuthConsumer, driver.GetSessionDuration())
session, err := authentication.NewSession(ctx, tx, &consumer.AuthConsumer, authDriver.GetSessionDuration())
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion engine/api/auth_consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"context"
builtindriver "github.com/ovh/cds/engine/api/driver/builtin"
"net/http"
"time"

Expand Down Expand Up @@ -197,7 +198,7 @@ func (api *API) postConsumerRegenByUserHandler() service.Handler {
return err
}

jws, err := builtin.NewSigninConsumerToken(consumer) // Regen a new jws (signin token)
jws, err := builtindriver.NewSigninConsumerToken(consumer) // Regen a new jws (signin token)
if err != nil {
return err
}
Expand Down
34 changes: 18 additions & 16 deletions engine/api/auth_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"context"
localdriver "github.com/ovh/cds/engine/api/driver/local"
"net/http"
"time"

Expand All @@ -20,12 +21,12 @@ import (
// postAuthLocalSignupHandler creates a new registration that need to be verified to create a new user.
func (api *API) postAuthLocalSignupHandler() service.Handler {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
driver, okDriver := api.AuthenticationDrivers[sdk.ConsumerLocal]
if !okDriver || driver.GetManifest().SignupDisabled {
authDriver, okDriver := api.AuthenticationDrivers[sdk.ConsumerLocal]
if !okDriver || authDriver.GetManifest().SignupDisabled {
return sdk.WithStack(sdk.ErrSignupDisabled)
}

localDriver := driver.(*local.AuthDriver)
localDriver := authDriver.GetDriver().(*localdriver.LocalDriver)

// Extract and validate signup request
var reqData sdk.AuthConsumerSigninRequest
Expand Down Expand Up @@ -187,17 +188,18 @@ func initBuiltinConsumersFromStartupConfig(ctx context.Context, tx gorpmapper.Sq
// postAuthLocalSigninHandler returns a new session for an existing local consumer.
func (api *API) postAuthLocalSigninHandler() service.Handler {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
driver, okDriver := api.AuthenticationDrivers[sdk.ConsumerLocal]
authDriver, okDriver := api.AuthenticationDrivers[sdk.ConsumerLocal]
if !okDriver {
return sdk.WithStack(sdk.ErrForbidden)
}
localDriver := authDriver.GetDriver().(*localdriver.LocalDriver)

// Extract and validate signup request
var reqData sdk.AuthConsumerSigninRequest
if err := service.UnmarshalBody(r, &reqData); err != nil {
return err
}
if err := driver.CheckSigninRequest(reqData); err != nil {
if err := localDriver.CheckSigninRequest(reqData); err != nil {
return err
}

Expand All @@ -217,7 +219,7 @@ func (api *API) postAuthLocalSigninHandler() service.Handler {
return sdk.NewErrorWithStack(err, sdk.ErrUnauthorized)
}

userInfo, err := driver.GetUserInfo(ctx, reqData)
userInfo, err := authDriver.GetUserInfo(ctx, reqData)
if err != nil {
return err
}
Expand Down Expand Up @@ -245,7 +247,7 @@ func (api *API) postAuthLocalSigninHandler() service.Handler {
}

// Generate a new session for consumer
session, err := authentication.NewSession(ctx, tx, &consumer.AuthConsumer, driver.GetSessionDuration())
session, err := authentication.NewSession(ctx, tx, &consumer.AuthConsumer, authDriver.GetSessionDuration())
if err != nil {
return err
}
Expand Down Expand Up @@ -283,11 +285,11 @@ func (api *API) postAuthLocalSigninHandler() service.Handler {

func (api *API) postAuthLocalVerifyHandler() service.Handler {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
driver, okDriver := api.AuthenticationDrivers[sdk.ConsumerLocal]
authDriver, okDriver := api.AuthenticationDrivers[sdk.ConsumerLocal]
if !okDriver {
return sdk.WithStack(sdk.ErrForbidden)
}
localDriver := driver.(*local.AuthDriver)
localDriver := authDriver.GetDriver().(*localdriver.LocalDriver)

var reqData sdk.AuthConsumerSigninRequest
var tokenInQueryString = QueryString(r, "token")
Expand Down Expand Up @@ -366,7 +368,7 @@ func (api *API) postAuthLocalVerifyHandler() service.Handler {
}
}

userInfo, err := driver.GetUserInfo(ctx, reqData)
userInfo, err := authDriver.GetUserInfo(ctx, reqData)
if err != nil {
return err
}
Expand Down Expand Up @@ -395,7 +397,7 @@ func (api *API) postAuthLocalVerifyHandler() service.Handler {
}

// Generate a new session for consumer
session, err := authentication.NewSession(ctx, tx, &consumer.AuthConsumer, driver.GetSessionDuration())
session, err := authentication.NewSession(ctx, tx, &consumer.AuthConsumer, authDriver.GetSessionDuration())
if err != nil {
return err
}
Expand Down Expand Up @@ -433,12 +435,12 @@ func (api *API) postAuthLocalVerifyHandler() service.Handler {

func (api *API) postAuthLocalAskResetHandler() service.Handler {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
driver, okDriver := api.AuthenticationDrivers[sdk.ConsumerLocal]
authDriver, okDriver := api.AuthenticationDrivers[sdk.ConsumerLocal]
if !okDriver {
return sdk.WithStack(sdk.ErrForbidden)
}

localDriver := driver.(*local.AuthDriver)
localDriver := authDriver.GetDriver().(*localdriver.LocalDriver)

var email string

Expand Down Expand Up @@ -511,12 +513,12 @@ func (api *API) postAuthLocalAskResetHandler() service.Handler {

func (api *API) postAuthLocalResetHandler() service.Handler {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
driver, okDriver := api.AuthenticationDrivers[sdk.ConsumerLocal]
authDriver, okDriver := api.AuthenticationDrivers[sdk.ConsumerLocal]
if !okDriver {
return sdk.WithStack(sdk.ErrForbidden)
}

localDriver := driver.(*local.AuthDriver)
localDriver := authDriver.GetDriver().(*localdriver.LocalDriver)

var reqData sdk.AuthConsumerSigninRequest
if err := service.UnmarshalBody(r, &reqData); err != nil {
Expand Down Expand Up @@ -572,7 +574,7 @@ func (api *API) postAuthLocalResetHandler() service.Handler {
}

// Generate a new session for consumer
session, err := authentication.NewSession(ctx, tx, &consumer.AuthConsumer, driver.GetSessionDuration())
session, err := authentication.NewSession(ctx, tx, &consumer.AuthConsumer, authDriver.GetSessionDuration())
if err != nil {
return err
}
Expand Down
7 changes: 4 additions & 3 deletions engine/api/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"context"
"encoding/json"
corpssodriver "github.com/ovh/cds/engine/api/driver/corpsso"
"github.com/ovh/cds/engine/api/organization"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -319,7 +320,7 @@ func Test_postAuthSigninHandler_WithCorporateSSO(t *testing.T) {
api, db, _ := newTestAPI(t)
api.Config.Auth.AllowedOrganizations = []string{"planet-express"}

var cfg corpsso.Config
var cfg corpssodriver.SSOConfig
cfg.Request.Keys.RequestSigningKey = AuthKey
cfg.Request.RedirectMethod = "POST"
cfg.Request.RedirectURL = "https://lolcat.local/sso/jwt"
Expand Down Expand Up @@ -356,7 +357,7 @@ func Test_postAuthSigninHandler_WithCorporateSSO(t *testing.T) {
requestedJWS = redirectInfo.Body["request"]
var data = sdk.AuthConsumerSigninRequest{}
data["state"] = requestedJWS
err := api.AuthenticationDrivers[sdk.ConsumerCorporateSSO].(sdk.AuthDriverWithSigninStateToken).CheckSigninStateToken(data)
err := api.AuthenticationDrivers[sdk.ConsumerCorporateSSO].GetDriver().(sdk.DriverWithSigninStateToken).CheckSigninStateToken(data)
require.NoError(t, err)
})

Expand Down Expand Up @@ -405,7 +406,7 @@ func Test_postAuthSigninHandler_WithCorporateSSO(t *testing.T) {
}

func generateToken(t *testing.T, username string) string {
ssoToken := corpsso.IssuedToken{
ssoToken := corpssodriver.IssuedToken{
RemoteUser: username,
RemoteUsername: strings.Title(username),
Email: username + "@planet-express.futurama",
Expand Down
Loading

0 comments on commit 9e332b3

Please sign in to comment.