Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[management] Remove redundant get account calls in GetAccountFromToken #2615

Merged
merged 29 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
258b30c
refactor access control middleware and user access by JWT groups
bcmmbaga Sep 16, 2024
3cf1b02
refactor jwt groups extractor
bcmmbaga Sep 17, 2024
e5d55d3
refactor handlers to get account when necessary
bcmmbaga Sep 17, 2024
ccab3b4
refactor getAccountFromToken
bcmmbaga Sep 18, 2024
720d36a
refactor getAccountWithAuthorizationClaims
bcmmbaga Sep 18, 2024
a4c4158
Merge branch 'main' into refactor-get-account-by-token
bcmmbaga Sep 18, 2024
021fc8f
fix merge
bcmmbaga Sep 18, 2024
f60a423
revert handles change
bcmmbaga Sep 18, 2024
8f9c54f
remove GetUserByID from account manager
bcmmbaga Sep 18, 2024
9631cb4
fix tests
bcmmbaga Sep 18, 2024
4d9bb7e
refactor getAccountWithAuthorizationClaims to return account id
bcmmbaga Sep 20, 2024
26dd045
Merge branch 'main' into refactor-get-account-by-token
bcmmbaga Sep 20, 2024
8f98add
refactor handlers to use GetAccountIDFromToken
bcmmbaga Sep 22, 2024
7601a17
fix tests
bcmmbaga Sep 22, 2024
d9f612d
remove locks
bcmmbaga Sep 23, 2024
2884038
refactor
bcmmbaga Sep 24, 2024
1ffe89d
add GetGroupByName from store
bcmmbaga Sep 24, 2024
7561706
add GetGroupByID from store and refactor
bcmmbaga Sep 24, 2024
eab8564
Refactor retrieval of policy and posture checks
bcmmbaga Sep 24, 2024
d14b855
Refactor user permissions and retrieves PAT
bcmmbaga Sep 24, 2024
16174f0
Refactor route, setupkey, nameserver and dns to get record(s) from store
bcmmbaga Sep 25, 2024
41b212f
Refactor store
bcmmbaga Sep 25, 2024
b815393
fix lint
bcmmbaga Sep 25, 2024
c384874
fix tests
bcmmbaga Sep 25, 2024
dc82c2d
fix add missing policy source posture checks
bcmmbaga Sep 26, 2024
871595d
Merge branch 'main' into refactor-get-account-by-token
bcmmbaga Sep 26, 2024
4575ae2
add store lock
bcmmbaga Sep 26, 2024
b1b2b0a
fix tests
bcmmbaga Sep 26, 2024
e90d9ce
add get account
bcmmbaga Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,24 @@ func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Accou
func (s *FileStore) GetGroupByID(_ context.Context, _, _ string) (*nbgroup.Group, error) {
return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
}

func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
}

func (s *FileStore) GetAccountPolicies(_ context.Context, _ string) ([]*Policy, error) {
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
}

func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")

}

func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
}

func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
}
9 changes: 1 addition & 8 deletions management/server/http/policies_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package http
import (
"encoding/json"
"net/http"
"slices"
"strconv"

"github.com/gorilla/mux"
Expand Down Expand Up @@ -84,18 +83,12 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
return
}

account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}

policyIdx := slices.IndexFunc(account.Policies, func(policy *server.Policy) bool { return policy.ID == policyID })
if policyIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return
}

h.savePolicy(w, r, accountID, userID, policyID)
}

Expand Down
36 changes: 7 additions & 29 deletions management/server/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,30 +315,16 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,

// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}

user, err := account.FindUser(userID)
if err != nil {
return nil, err
}

if !(user.HasAdminPower() || user.IsServiceUser) {
if (!user.HasAdminPower() && !user.IsServiceUser) || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}

for _, policy := range account.Policies {
if policy.ID == policyID {
return policy, nil
}
}

return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID)
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
}

// SavePolicy in the store
Expand Down Expand Up @@ -400,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po

// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}

user, err := account.FindUser(userID)
if err != nil {
return nil, err
}

if !(user.HasAdminPower() || user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies")
if (!user.HasAdminPower() && !user.IsServiceUser) || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}

return account.Policies, nil
return am.Store.GetAccountPolicies(ctx, accountID)
}

func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
Expand Down
34 changes: 6 additions & 28 deletions management/server/posture_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,16 @@ const (
)

func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}

user, err := account.FindUser(userID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}

if !user.HasAdminPower() {
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}

for _, postureChecks := range account.PostureChecks {
if postureChecks.ID == postureChecksID {
return postureChecks, nil
}
}

return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID)
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
}

func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
Expand Down Expand Up @@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
}

func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}

user, err := account.FindUser(userID)
if err != nil {
return nil, err
}

if !user.HasAdminPower() {
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}

return account.PostureChecks, nil
return am.Store.GetAccountPostureChecks(ctx, accountID)
}

func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
Expand Down
52 changes: 49 additions & 3 deletions management/server/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&user, idQueryCondition, userID)
Preload(clause.Associations).First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
Expand Down Expand Up @@ -1095,7 +1095,8 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength

func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Where(accountAndIDQueryCondition, accountID, groupID).First(&group)
result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Preload(clause.Associations).
Where(accountAndIDQueryCondition, accountID, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
Expand All @@ -1109,7 +1110,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string)
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbgroup.Group{}).
Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group)
Preload(clause.Associations).Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
Expand All @@ -1118,3 +1119,48 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
}
return &group, nil
}

func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) {
pascal-fischer marked this conversation as resolved.
Show resolved Hide resolved
var policies []*Policy
result := s.db.WithContext(ctx).Model(&Policy{}).Where(accountIDCondition, accountID).
Preload(clause.Associations).Find(&policies)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
}
return policies, nil
}

func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
var policy *Policy
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Policy{}).
Preload(clause.Associations).Where(accountAndIDQueryCondition, accountID, policyID).First(&policy)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "posture checks not found")
}
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
}
return policy, nil
}

func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
pascal-fischer marked this conversation as resolved.
Show resolved Hide resolved
var postureChecks []*posture.Checks
result := s.db.WithContext(ctx).Model(&posture.Checks{}).Where(accountIDCondition, accountID).Find(&postureChecks)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
}
return postureChecks, nil
}

func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
var postureCheck *posture.Checks
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&posture.Checks{}).
Where(accountAndIDQueryCondition, accountID, postureCheckID).First(&postureCheck)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "posture checks not found")
}
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
}
return postureCheck, nil
}
5 changes: 5 additions & 0 deletions management/server/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,12 @@ type Store interface {
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error

GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)

GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)

GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
Expand Down