Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
Signed-off-by: bcmmbaga <[email protected]>
  • Loading branch information
bcmmbaga committed Sep 24, 2024
1 parent d9f612d commit f903fce
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 55 deletions.
39 changes: 23 additions & 16 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ import (
cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"

"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
Expand All @@ -41,6 +36,10 @@ import (
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)

const (
Expand Down Expand Up @@ -1255,30 +1254,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil
}

// GetAccountIDByUserOrAccountID looks for an account by user or accountID, if no account is provided and
// userID doesn't have an account associated with it, one account is created
// domain is used to create a new account if no account is found
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
// If an accountID is provided, it checks if the account exists and returns it.
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
// If the user doesn't have an account, it creates one using the provided domain.
// Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
if accountID != "" {
_, _, err := am.Store.GetAccountDomainAndCategory(ctx, accountID)
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
if err != nil {
return "", err
}
if !exists {
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
}
return accountID, nil
} else if userID != "" {
}

if userID != "" {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
if err != nil {
return "", status.Errorf(status.NotFound, "account not found using user id: %s", userID)
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
}

err = am.addAccountIDToIDPAppMeta(ctx, userID, account)
if err != nil {
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
return "", err
}

return account.Id, nil
}

return "", status.Errorf(status.NotFound, "no valid user or account Id provided")
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
}

func isNil(i idp.Manager) bool {
Expand Down Expand Up @@ -1808,6 +1814,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return nil
}

// TODO: Remove GetAccount after refactoring account peer's update
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

Expand Down Expand Up @@ -1907,7 +1914,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}

domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, claims.AccountId)
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
if err != nil {
return "", err
}
Expand All @@ -1923,7 +1930,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)

// We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, claims.Domain)
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
if err != nil {
// if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err)
Expand Down
19 changes: 14 additions & 5 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@ import (
"sync"
"time"

"github.com/rs/xid"
log "github.com/sirupsen/logrus"

nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/util"
)
Expand Down Expand Up @@ -958,11 +957,11 @@ func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}

func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ string) (string, error) {
func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
}

func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID string) (string, string, error) {
func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
s.mux.Lock()
defer s.mux.Unlock()

Expand All @@ -973,3 +972,13 @@ func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, accountID str

return account.Domain, account.DomainCategory, nil
}

// AccountExists checks whether an account exists by the given ID.
func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
_, exists := s.Accounts[id]
return exists, nil
}

func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Account) error {
return nil
}
66 changes: 55 additions & 11 deletions management/server/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
}

func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
accountID, err := s.GetAccountIDByPrivateDomain(ctx, domain)
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if err != nil {
return nil, err
}
Expand All @@ -409,11 +409,12 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
return s.GetAccount(ctx, accountID)
}

func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) {
var account Account

result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory)
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory,
).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
Expand All @@ -422,7 +423,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain strin
return "", status.Errorf(status.Internal, "issue getting account from store")
}

return account.Id, nil
return accountID, nil
}

func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
Expand Down Expand Up @@ -671,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
}

func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var user User
var accountID string
result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
Expand Down Expand Up @@ -1035,10 +1035,54 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
}
}

// UpdateAccount updates an existing account's domain, DNS settings, and settings fields.
func (s *SqlStore) UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error {
updates := make(map[string]interface{})

if account.Domain != "" {
updates["domain"] = account.Domain
}

if account.DNSSettings.DisabledManagementGroups != nil {
updates["dns_settings"] = account.DNSSettings
}

if account.Settings != nil {
updates["settings"] = account.Settings
}

if len(updates) == 0 {
return nil
}

result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where("id = ?", account.Id).Updates(updates)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to update account: %v", result.Error)
}

if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account not found")
}

return nil
}

// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var count int64
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, id).Count(&count)
if result.Error != nil {
return false, result.Error
}
return count > 0, nil
}

// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error) {
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
var account Account
result := s.db.WithContext(ctx).Model(&Account{}).Select("domain", "domain_category").
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
Expand Down
57 changes: 34 additions & 23 deletions management/server/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ const (
type Store interface {
GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error)
GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error)
DeleteAccount(ctx context.Context, account *Account) error
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
Expand All @@ -49,45 +49,56 @@ type Store interface {
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error)
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
UpdateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error

GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
SaveAccount(ctx context.Context, account *Account) error
SaveUsers(accountID string, users map[string]*User) error
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error

GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
SaveGroups(accountID string, groups map[string]*nbgroup.Group) error

GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)

GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error

GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error

GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)

GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error

// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func()
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error

// Close should close the store persisting all unsaved data.
Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
}

Expand Down

0 comments on commit f903fce

Please sign in to comment.