Skip to content

Commit

Permalink
Merge branch 'main' into groups-get-account-refactoring
Browse files Browse the repository at this point in the history
# Conflicts:
#	management/server/group.go
  • Loading branch information
bcmmbaga committed Nov 15, 2024
2 parents a4d905f + 44e799c commit 92b9e11
Show file tree
Hide file tree
Showing 14 changed files with 188 additions and 72 deletions.
8 changes: 5 additions & 3 deletions client/firewall/iptables/manager_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}

// persist early to ensure cleanup of chains
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()

return nil
}
Expand Down
8 changes: 5 additions & 3 deletions client/firewall/nftables/manager_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}

// persist early
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()

return nil
}
Expand Down
6 changes: 3 additions & 3 deletions client/internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
return cfg, err
}

Expand All @@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {

// WriteOutConfig write put the prepared config to the given path
func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(path, config)
return util.WriteJson(context.Background(), path, config)
}

// createNewConfig creates a new config generating a new Wireguard key and saving to file
Expand Down Expand Up @@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
}

if updated {
if err := util.WriteJson(input.ConfigPath, config); err != nil {
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err
}
}
Expand Down
10 changes: 7 additions & 3 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}

// don't block
go func() {
if err := s.stateManager.PersistState(ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
}()

if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
Expand Down
7 changes: 5 additions & 2 deletions client/internal/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"


nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
Expand Down Expand Up @@ -171,7 +170,7 @@ type Engine struct {

relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
srWatcher *guard.SRWatcher
}

// Peer is an instance of the Connection Peer
Expand Down Expand Up @@ -641,6 +640,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
}

func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
}

if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
Expand Down
98 changes: 98 additions & 0 deletions client/internal/routemanager/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routemanager

import (
"fmt"
"net/netip"
"testing"
"time"
Expand Down Expand Up @@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "relayed routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "p2p routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "current route with bad score should be changed to route with better score",
statuses: map[route.ID]routerPeerStatus{
Expand Down Expand Up @@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
},
}

// fill the test data with random routes
for _, tc := range testCases {
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}

for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{
Expand Down
20 changes: 5 additions & 15 deletions client/internal/statemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps"

nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/util"
)

// State interface defines the methods that all state types must implement
Expand Down Expand Up @@ -178,25 +179,14 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil
}

ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

done := make(chan error, 1)

start := time.Now()
go func() {
data, err := json.MarshalIndent(m.states, "", " ")
if err != nil {
done <- fmt.Errorf("marshal states: %w", err)
return
}

// nolint:gosec
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
done <- fmt.Errorf("write state file: %w", err)
return
}

done <- nil
done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states)
}()

select {
Expand All @@ -208,7 +198,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
}
}

log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))

clear(m.dirty)

Expand Down
22 changes: 5 additions & 17 deletions client/internal/statemanager/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,20 @@ import (
"os"
"path/filepath"
"runtime"

log "github.com/sirupsen/logrus"
)

// GetDefaultStatePath returns the path to the state file based on the operating system
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
// It returns an empty string if the path cannot be determined.
func GetDefaultStatePath() string {
var path string

switch runtime.GOOS {
case "windows":
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
case "darwin", "linux":
path = "/var/lib/netbird/state.json"
return "/var/lib/netbird/state.json"
case "freebsd", "openbsd", "netbsd", "dragonfly":
path = "/var/db/netbird/state.json"
// ios/android don't need state
default:
return ""
return "/var/db/netbird/state.json"
}

dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
return ""
}
return ""

return path
}
2 changes: 1 addition & 1 deletion management/server/file_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
// It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now()
err := util.WriteJson(file, s)
err := util.WriteJson(context.Background(), file, s)
if err != nil {
return err
}
Expand Down
12 changes: 4 additions & 8 deletions management/server/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ import (
"fmt"
"slices"

nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"

nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"

"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status"
Expand All @@ -27,11 +28,6 @@ func (e *GroupLinkError) Error() string {

// CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}

user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
Expand All @@ -41,7 +37,7 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
return status.NewUserNotPartOfAccountError()
}

if user.IsRegularUser() && settings.RegularUsersViewBlocked {
if user.IsRegularUser() {
return status.NewAdminPermissionError()
}

Expand Down
20 changes: 16 additions & 4 deletions management/server/http/peers_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {

dnsDomain := h.accountManager.GetDNSDomain()

respBody := make([]*api.PeerBatch, 0, len(account.Peers))
for _, peer := range account.Peers {
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}

groupsMap := map[string]*nbgroup.Group{}
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
for _, group := range groups {
groupsMap[group.ID] = group
}

respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)

respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
}
Expand Down Expand Up @@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
}

func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum
groupsInfo := []api.GroupMinimum{}
groupsChecked := make(map[string]struct{})
for _, group := range groups {
_, ok := groupsChecked[group.ID]
Expand Down
Loading

0 comments on commit 92b9e11

Please sign in to comment.