diff --git a/cmd/hypervisor/commands/gen-config.go b/cmd/hypervisor/commands/gen-config.go index 4956e04e1..bc69e48b1 100644 --- a/cmd/hypervisor/commands/gen-config.go +++ b/cmd/hypervisor/commands/gen-config.go @@ -10,19 +10,26 @@ import ( "github.com/SkycoinProject/skywire-mainnet/pkg/util/pathutil" ) +// nolint:gochecknoglobals var ( output string replace bool configLocType = pathutil.WorkingDirLoc ) +// nolint:gochecknoinits func init() { + outputUsage := "path of output config file. Uses default of 'type' flag if unspecified." + replaceUsage := "whether to allow rewrite of a file that already exists." + configLocTypeUsage := fmt.Sprintf("config generation mode. Valid values: %v", pathutil.AllConfigLocationTypes()) + rootCmd.AddCommand(genConfigCmd) - genConfigCmd.Flags().StringVarP(&output, "output", "o", "", "path of output config file. Uses default of 'type' flag if unspecified.") - genConfigCmd.Flags().BoolVarP(&replace, "replace", "r", false, "whether to allow rewrite of a file that already exists.") - genConfigCmd.Flags().VarP(&configLocType, "type", "m", fmt.Sprintf("config generation mode. Valid values: %v", pathutil.AllConfigLocationTypes())) + genConfigCmd.Flags().StringVarP(&output, "output", "o", "", outputUsage) + genConfigCmd.Flags().BoolVarP(&replace, "replace", "r", false, replaceUsage) + genConfigCmd.Flags().VarP(&configLocType, "type", "m", configLocTypeUsage) } +// nolint:gochecknoglobals var genConfigCmd = &cobra.Command{ Use: "gen-config", Short: "generates a configuration file", diff --git a/cmd/hypervisor/commands/root.go b/cmd/hypervisor/commands/root.go index 0b3046531..01d65f2df 100644 --- a/cmd/hypervisor/commands/root.go +++ b/cmd/hypervisor/commands/root.go @@ -4,7 +4,6 @@ import ( "fmt" "net" "net/http" - "os" "github.com/SkycoinProject/skycoin/src/util/logging" "github.com/spf13/cobra" @@ -15,6 +14,7 @@ import ( const configEnv = "SW_HYPERVISOR_CONFIG" +// nolint:gochecknoglobals var ( log = logging.MustGetLogger("hypervisor") @@ -26,6 +26,7 @@ var ( mockMaxRoutes int ) +// nolint:gochecknoinits func init() { rootCmd.Flags().StringVarP(&configPath, "config", "c", "./hypervisor-config.json", "hypervisor config path") rootCmd.Flags().BoolVarP(&mock, "mock", "m", false, "whether to run hypervisor with mock data") @@ -35,6 +36,7 @@ func init() { rootCmd.Flags().IntVar(&mockMaxRoutes, "mock-max-routes", 30, "max number of routes per node") } +// nolint:gochecknoglobals var rootCmd = &cobra.Command{ Use: "hypervisor", Short: "Manages Skywire App Nodes", @@ -48,7 +50,8 @@ var rootCmd = &cobra.Command{ if err := config.Parse(configPath); err != nil { log.WithError(err).Fatalln("failed to parse config file") } - fmt.Println(config) + + fmt.Println("Config: \n", config) var ( httpAddr = config.Interfaces.HTTPAddr @@ -95,7 +98,6 @@ var rootCmd = &cobra.Command{ // Execute executes root CLI command. func Execute() { if err := rootCmd.Execute(); err != nil { - fmt.Println(err) - os.Exit(1) + log.Fatal(err) } } diff --git a/pkg/hypervisor/config.go b/pkg/hypervisor/config.go index b49dad3b5..92eaaa2c0 100644 --- a/pkg/hypervisor/config.go +++ b/pkg/hypervisor/config.go @@ -13,6 +13,12 @@ import ( "github.com/SkycoinProject/skywire-mainnet/pkg/util/pathutil" ) +const ( + defaultCookieExpiration = 12 * time.Hour + hashKeyLen = 64 + blockKeyLen = 32 +) + // Key allows a byte slice to be marshaled or unmarshaled from a hex string. type Key []byte @@ -30,28 +36,32 @@ func (hk Key) MarshalText() ([]byte, error) { func (hk *Key) UnmarshalText(text []byte) error { *hk = make([]byte, hex.DecodedLen(len(text))) _, err := hex.Decode(*hk, text) + return err } // Config configures the hypervisor. type Config struct { - PK cipher.PubKey `json:"public_key"` - SK cipher.SecKey `json:"secret_key"` - DBPath string `json:"db_path"` // Path to store database file. - EnableAuth bool `json:"enable_auth"` // Whether to enable user management. Cookies CookieConfig `json:"cookies"` // Configures cookies (for session management). Interfaces InterfaceConfig `json:"interfaces"` // Configures exposed interfaces. + DBPath string `json:"db_path"` // Path to store database file. + EnableAuth bool `json:"enable_auth"` // Whether to enable user management. + PK cipher.PubKey `json:"public_key"` + SK cipher.SecKey `json:"secret_key"` } func makeConfig() Config { var c Config + pk, sk := cipher.GenerateKeyPair() c.PK = pk c.SK = sk c.EnableAuth = true - c.Cookies.HashKey = cipher.RandByte(64) - c.Cookies.BlockKey = cipher.RandByte(32) + c.Cookies.HashKey = cipher.RandByte(hashKeyLen) + c.Cookies.BlockKey = cipher.RandByte(blockKeyLen) + c.FillDefaults() + return c } @@ -61,8 +71,10 @@ func GenerateWorkDirConfig() Config { if err != nil { log.Fatalf("failed to generate WD config: %s", dir) } + c := makeConfig() c.DBPath = filepath.Join(dir, "users.db") + return c } @@ -70,6 +82,7 @@ func GenerateWorkDirConfig() Config { func GenerateHomeConfig() Config { c := makeConfig() c.DBPath = filepath.Join(pathutil.HomeDir(), ".skycoin/hypervisor/users.db") + return c } @@ -77,6 +90,7 @@ func GenerateHomeConfig() Config { func GenerateLocalConfig() Config { c := makeConfig() c.DBPath = "/usr/local/SkycoinProject/hypervisor/users.db" + return c } @@ -92,11 +106,18 @@ func (c *Config) Parse(path string) error { if path, err = filepath.Abs(path); err != nil { return err } + f, err := os.Open(filepath.Clean(path)) if err != nil { return err } - defer func() { catch(f.Close()) }() + + defer func() { + if err := f.Close(); err != nil { + log.Fatalf("Failed to close file %s: %v", f.Name(), err) + } + }() + return json.NewDecoder(f).Decode(c) } @@ -116,7 +137,7 @@ type CookieConfig struct { // FillDefaults fills config with default values. func (c *CookieConfig) FillDefaults() { - c.ExpiresDuration = time.Hour * 12 + c.ExpiresDuration = defaultCookieExpiration c.Path = "/" c.Secure = true c.HTTPOnly = true diff --git a/pkg/hypervisor/hypervisor.go b/pkg/hypervisor/hypervisor.go index 63161dc3f..2f0614c75 100644 --- a/pkg/hypervisor/hypervisor.go +++ b/pkg/hypervisor/hypervisor.go @@ -26,9 +26,18 @@ import ( "github.com/SkycoinProject/skywire-mainnet/pkg/visor" ) -var ( - log = logging.MustGetLogger("hypervisor") +const ( healthTimeout = 5 * time.Second + httpTimeout = 30 * time.Second +) + +const ( + statusStop = iota + statusStart +) + +var ( + log = logging.MustGetLogger("hypervisor") // nolint: gochecknoglobals ) type appNodeConn struct { @@ -50,6 +59,7 @@ func NewNode(config Config) (*Node, error) { if err != nil { return nil, err } + singleUserDB := NewSingleUserStore("admin", boltUserDB) return &Node{ @@ -67,7 +77,9 @@ func (m *Node) ServeRPC(lis net.Listener) error { if err != nil { return err } + addr := conn.RemoteAddr().(*noise.Addr) + m.mu.Lock() m.nodes[addr.PK] = appNodeConn{ Addr: addr, @@ -93,11 +105,13 @@ type MockConfig struct { // AddMockData adds mock data to Node. func (m *Node) AddMockData(config MockConfig) error { r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < config.Nodes; i++ { pk, client, err := visor.NewMockRPCClient(r, config.MaxTpsPerNode, config.MaxRoutesPerNode) if err != nil { return err } + m.mu.Lock() m.nodes[pk] = appNodeConn{ Addr: &noise.Addr{ @@ -108,15 +122,19 @@ func (m *Node) AddMockData(config MockConfig) error { } m.mu.Unlock() } + m.c.EnableAuth = config.EnableAuth + return nil } // ServeHTTP implements http.Handler func (m *Node) ServeHTTP(w http.ResponseWriter, req *http.Request) { r := chi.NewRouter() - r.Use(middleware.Timeout(time.Second * 30)) + + r.Use(middleware.Timeout(httpTimeout)) r.Use(middleware.Logger) + r.Route("/api", func(r chi.Router) { if m.c.EnableAuth { r.Group(func(r chi.Router) { @@ -125,10 +143,12 @@ func (m *Node) ServeHTTP(w http.ResponseWriter, req *http.Request) { r.Post("/logout", m.users.Logout()) }) } + r.Group(func(r chi.Router) { if m.c.EnableAuth { r.Use(m.users.Authorize) } + r.Get("/user", m.users.UserInfo()) r.Post("/change-password", m.users.ChangePassword()) r.Post("/exec/{pk}", m.exec()) @@ -154,6 +174,7 @@ func (m *Node) ServeHTTP(w http.ResponseWriter, req *http.Request) { r.Get("/nodes/{pk}/restart", m.restart()) }) }) + r.ServeHTTP(w, req) } @@ -175,10 +196,12 @@ func (m *Node) getHealth() http.HandlerFunc { resCh := make(chan healthRes) tCh := time.After(healthTimeout) + go func() { hi, err := ctx.RPC.Health() resCh <- healthRes{hi, err} }() + select { case res := <-resCh: if res.err != nil { @@ -187,6 +210,7 @@ func (m *Node) getHealth() http.HandlerFunc { vh.HealthInfo = res.h vh.Status = http.StatusOK } + httputil.WriteJSON(w, r, http.StatusOK, vh) case <-tCh: httputil.WriteJSON(w, r, http.StatusRequestTimeout, &VisorHealth{Status: http.StatusRequestTimeout}) @@ -202,6 +226,7 @@ func (m *Node) getUptime() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } + httputil.WriteJSON(w, r, http.StatusOK, u) }) } @@ -212,6 +237,7 @@ func (m *Node) exec() http.HandlerFunc { var reqBody struct { Command string `json:"command"` } + if err := httputil.ReadJSON(r, &reqBody); err != nil { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return @@ -226,6 +252,7 @@ func (m *Node) exec() http.HandlerFunc { output := struct { Output string `json:"output"` }{string(out)} + httputil.WriteJSON(w, r, http.StatusOK, output) }) } @@ -240,13 +267,16 @@ type summaryResp struct { func (m *Node) getNodes() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var summaries []summaryResp + m.mu.RLock() for pk, c := range m.nodes { summary, err := c.Client.Summary() if err != nil { - log.Printf("failed to obtain summary from AppNode with pk %s. Error: %v", pk, err) + log.Errorf("failed to obtain summary from AppNode with pk %s. Error: %v", pk, err) + summary = &visor.Summary{PubKey: pk} } + summaries = append(summaries, summaryResp{ TCPAddr: c.Addr.Addr.String(), Online: err == nil, @@ -254,6 +284,7 @@ func (m *Node) getNodes() http.HandlerFunc { }) } m.mu.RUnlock() + httputil.WriteJSON(w, r, http.StatusOK, summaries) } } @@ -266,6 +297,7 @@ func (m *Node) getNode() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } + httputil.WriteJSON(w, r, http.StatusOK, summaryResp{ TCPAddr: ctx.Addr.Addr.String(), Summary: summary, @@ -281,6 +313,7 @@ func (m *Node) getApps() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } + httputil.WriteJSON(w, r, http.StatusOK, apps) }) } @@ -292,10 +325,12 @@ func (m *Node) getApp() http.HandlerFunc { }) } +// TODO: simplify +// nolint: funlen,gocognit,godox func (m *Node) putApp() http.HandlerFunc { return m.withCtx(m.appCtx, func(w http.ResponseWriter, r *http.Request, ctx *httpCtx) { var reqBody struct { - Autostart *bool `json:"autostart,omitempty"` + AutoStart *bool `json:"autostart,omitempty"` Status *int `json:"status,omitempty"` Passcode *string `json:"passcode,omitempty"` PK *cipher.PubKey `json:"pk,omitempty"` @@ -306,9 +341,9 @@ func (m *Node) putApp() http.HandlerFunc { return } - if reqBody.Autostart != nil { - if *reqBody.Autostart != ctx.App.AutoStart { - if err := ctx.RPC.SetAutoStart(ctx.App.Name, *reqBody.Autostart); err != nil { + if reqBody.AutoStart != nil { + if *reqBody.AutoStart != ctx.App.AutoStart { + if err := ctx.RPC.SetAutoStart(ctx.App.Name, *reqBody.AutoStart); err != nil { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } @@ -317,19 +352,19 @@ func (m *Node) putApp() http.HandlerFunc { if reqBody.Status != nil { switch *reqBody.Status { - case 0: + case statusStop: if err := ctx.RPC.StopApp(ctx.App.Name); err != nil { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } - case 1: + case statusStart: if err := ctx.RPC.StartApp(ctx.App.Name); err != nil { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } default: - httputil.WriteJSON(w, r, http.StatusBadRequest, - fmt.Errorf("value of 'status' field is %d when expecting 0 or 1", *reqBody.Status)) + errMsg := fmt.Errorf("value of 'status' field is %d when expecting 0 or 1", *reqBody.Status) + httputil.WriteJSON(w, r, http.StatusBadRequest, errMsg) return } } @@ -367,11 +402,12 @@ func (m *Node) appLogsSince() http.HandlerFunc { return m.withCtx(m.appCtx, func(w http.ResponseWriter, r *http.Request, ctx *httpCtx) { since := r.URL.Query().Get("since") - // if time is not parseable or empty default to return all logs + // if time is not parsable or empty default to return all logs t, err := time.Parse(time.RFC3339Nano, since) if err != nil { t = time.Unix(0, 0) } + logs, err := ctx.RPC.LogsSince(t, ctx.App.Name) if err != nil { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) @@ -397,27 +433,27 @@ func (m *Node) getTransportTypes() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } + httputil.WriteJSON(w, r, http.StatusOK, types) }) } func (m *Node) getTransports() http.HandlerFunc { return m.withCtx(m.nodeCtx, func(w http.ResponseWriter, r *http.Request, ctx *httpCtx) { - var ( - qTypes []string - qPKs []cipher.PubKey - qLogs bool - ) - var err error - qTypes = strSliceFromQuery(r, "type", nil) - if qPKs, err = pkSliceFromQuery(r, "pk", nil); err != nil { + qTypes := strSliceFromQuery(r, "type", nil) + + qPKs, err := pkSliceFromQuery(r, "pk", nil) + if err != nil { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } - if qLogs, err = httputil.BoolFromQuery(r, "logs", true); err != nil { + + qLogs, err := httputil.BoolFromQuery(r, "logs", true) + if err != nil { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } + transports, err := ctx.RPC.Transports(qTypes, qPKs, qLogs) if err != nil { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) @@ -430,19 +466,23 @@ func (m *Node) getTransports() http.HandlerFunc { func (m *Node) postTransport() http.HandlerFunc { return m.withCtx(m.nodeCtx, func(w http.ResponseWriter, r *http.Request, ctx *httpCtx) { var reqBody struct { - Remote cipher.PubKey `json:"remote_pk"` TpType string `json:"transport_type"` + Remote cipher.PubKey `json:"remote_pk"` Public bool `json:"public"` } + if err := httputil.ReadJSON(r, &reqBody); err != nil { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } - summary, err := ctx.RPC.AddTransport(reqBody.Remote, reqBody.TpType, reqBody.Public, 30*time.Second) + + const timeout = 30 * time.Second + summary, err := ctx.RPC.AddTransport(reqBody.Remote, reqBody.TpType, reqBody.Public, timeout) if err != nil { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } + httputil.WriteJSON(w, r, http.StatusOK, summary) }) } @@ -459,6 +499,7 @@ func (m *Node) deleteTransport() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } + httputil.WriteJSON(w, r, http.StatusOK, true) }) } @@ -474,9 +515,11 @@ func makeRoutingRuleResp(key routing.RouteID, rule routing.Rule, summary bool) r Key: key, Rule: hex.EncodeToString(rule), } + if summary { resp.Summary = rule.Summary() } + return resp } @@ -487,15 +530,18 @@ func (m *Node) getRoutes() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } + rules, err := ctx.RPC.RoutingRules() if err != nil { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } + resp := make([]routingRuleResp, len(rules)) for i, rule := range rules { resp[i] = makeRoutingRuleResp(rule.KeyRouteID(), rule, qSummary) } + httputil.WriteJSON(w, r, http.StatusOK, resp) }) } @@ -507,6 +553,7 @@ func (m *Node) postRoute() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } + rule, err := summary.ToRule() if err != nil { httputil.WriteJSON(w, r, http.StatusBadRequest, err) @@ -517,6 +564,7 @@ func (m *Node) postRoute() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } + httputil.WriteJSON(w, r, http.StatusOK, makeRoutingRuleResp(rule.KeyRouteID(), rule, true)) }) } @@ -528,11 +576,13 @@ func (m *Node) getRoute() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } + rule, err := ctx.RPC.RoutingRule(ctx.RtKey) if err != nil { httputil.WriteJSON(w, r, http.StatusNotFound, err) return } + httputil.WriteJSON(w, r, http.StatusOK, makeRoutingRuleResp(ctx.RtKey, rule, qSummary)) }) } @@ -544,15 +594,18 @@ func (m *Node) putRoute() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } + rule, err := summary.ToRule() if err != nil { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } + if err := ctx.RPC.SaveRoutingRule(rule); err != nil { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } + httputil.WriteJSON(w, r, http.StatusOK, makeRoutingRuleResp(ctx.RtKey, rule, true)) }) } @@ -563,6 +616,7 @@ func (m *Node) deleteRoute() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusNotFound, err) return } + httputil.WriteJSON(w, r, http.StatusOK, true) }) } @@ -576,6 +630,7 @@ func makeLoopResp(info visor.LoopInfo) loopResp { if len(info.FwdRule) == 0 || len(info.ConsumeRule) == 0 { return loopResp{} } + return loopResp{ RuleConsumeFields: *info.ConsumeRule.Summary().ConsumeFields, FwdRule: *info.FwdRule.Summary().ForwardFields, @@ -589,10 +644,12 @@ func (m *Node) getLoops() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return } + resp := make([]loopResp, len(loops)) for i, l := range loops { resp[i] = makeLoopResp(l) } + httputil.WriteJSON(w, r, http.StatusOK, resp) }) } @@ -617,15 +674,11 @@ func (m *Node) client(pk cipher.PubKey) (*noise.Addr, visor.RPCClient, bool) { m.mu.RLock() conn, ok := m.nodes[pk] m.mu.RUnlock() + return conn.Addr, conn.Client, ok } type httpCtx struct { - // Node - PK cipher.PubKey - Addr *noise.Addr - RPC visor.RPCClient - // App App *visor.AppState @@ -634,6 +687,11 @@ type httpCtx struct { // Route RtKey routing.RouteID + + // Node + PK cipher.PubKey + Addr *noise.Addr + RPC visor.RPCClient } type ( @@ -655,11 +713,13 @@ func (m *Node) nodeCtx(w http.ResponseWriter, r *http.Request) (*httpCtx, bool) httputil.WriteJSON(w, r, http.StatusBadRequest, err) return nil, false } + addr, client, ok := m.client(pk) if !ok { httputil.WriteJSON(w, r, http.StatusNotFound, fmt.Errorf("node of pk '%s' not found", pk)) return nil, false } + return &httpCtx{ PK: pk, Addr: addr, @@ -672,20 +732,25 @@ func (m *Node) appCtx(w http.ResponseWriter, r *http.Request) (*httpCtx, bool) { if !ok { return nil, false } + appName := chi.URLParam(r, "app") + apps, err := ctx.RPC.Apps() if err != nil { httputil.WriteJSON(w, r, http.StatusInternalServerError, err) return nil, false } - for _, app := range apps { - if app.Name == appName { - ctx.App = app + + for _, a := range apps { + if a.Name == appName { + ctx.App = a return ctx, true } } - httputil.WriteJSON(w, r, http.StatusNotFound, - fmt.Errorf("can not find app of name %s from node %s", appName, ctx.PK)) + + errMsg := fmt.Errorf("can not find app of name %s from node %s", appName, ctx.PK) + httputil.WriteJSON(w, r, http.StatusNotFound, errMsg) + return nil, false } @@ -694,41 +759,53 @@ func (m *Node) tpCtx(w http.ResponseWriter, r *http.Request) (*httpCtx, bool) { if !ok { return nil, false } + tid, err := uuidFromParam(r, "tid") if err != nil { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return nil, false } + tp, err := ctx.RPC.Transport(tid) if err != nil { if err.Error() == visor.ErrNotFound.Error() { - httputil.WriteJSON(w, r, http.StatusNotFound, - fmt.Errorf("transport of ID %s is not found", tid)) + errMsg := fmt.Errorf("transport of ID %s is not found", tid) + httputil.WriteJSON(w, r, http.StatusNotFound, errMsg) + return nil, false } + httputil.WriteJSON(w, r, http.StatusInternalServerError, err) + return nil, false } + ctx.Tp = tp + return ctx, true } func (m *Node) routeCtx(w http.ResponseWriter, r *http.Request) (*httpCtx, bool) { - ctx, ok := m.tpCtx(w, r) + ctx, ok := m.nodeCtx(w, r) if !ok { return nil, false } - rid, err := ridFromParam(r, "key") + + rid, err := ridFromParam(r, "rid") if err != nil { httputil.WriteJSON(w, r, http.StatusBadRequest, err) + return nil, false } + ctx.RtKey = rid + return ctx, true } func pkFromParam(r *http.Request, key string) (cipher.PubKey, error) { pk := cipher.PubKey{} err := pk.UnmarshalText([]byte(chi.URLParam(r, key))) + return pk, err } @@ -741,6 +818,7 @@ func ridFromParam(r *http.Request, key string) (routing.RouteID, error) { if err != nil { return 0, errors.New("invalid route ID provided") } + return routing.RouteID(rid), nil } @@ -749,6 +827,7 @@ func strSliceFromQuery(r *http.Request, key string, defaultVal []string) []strin if !ok { return defaultVal } + return slice } @@ -757,23 +836,17 @@ func pkSliceFromQuery(r *http.Request, key string, defaultVal []cipher.PubKey) ( if !ok { return defaultVal, nil } + pks := make([]cipher.PubKey, len(qPKs)) + for i, qPK := range qPKs { pk := cipher.PubKey{} if err := pk.UnmarshalText([]byte(qPK)); err != nil { return nil, err } + pks[i] = pk } - return pks, nil -} -func catch(err error, msgs ...string) { - if err != nil { - if len(msgs) > 0 { - log.Fatalln(append(msgs, err.Error())) - } else { - log.Fatalln(err) - } - } + return pks, nil } diff --git a/pkg/hypervisor/hypervisor_test.go b/pkg/hypervisor/hypervisor_test.go index 1797a9c70..7a1c6ceaa 100644 --- a/pkg/hypervisor/hypervisor_test.go +++ b/pkg/hypervisor/hypervisor_test.go @@ -25,6 +25,7 @@ func TestMain(m *testing.M) { if err != nil { log.Fatal(err) } + logging.SetLevel(lvl) } else { logging.Disable() @@ -33,360 +34,403 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } +// nolint: gosec +const ( + goodPayload = `{"username":"admin","password":"Secure1234!"}` + changePasswordPayload = `{"old_password":"Secure1234!","new_password":"NewSecure1234!"}` + changedPasswordPayload = `{"username":"admin","password":"NewSecure1234!"}` + badCreateAccountPayload = `{"username":"invalid_user","password":"Secure1234!"}` +) + func TestNewNode(t *testing.T) { config := makeConfig() confDir, err := ioutil.TempDir(os.TempDir(), "SWHV") require.NoError(t, err) + config.DBPath = filepath.Join(confDir, "users.db") - defaultMockConfig := func() MockConfig { - return MockConfig{ - Nodes: 5, - MaxTpsPerNode: 10, - MaxRoutesPerNode: 10, - EnableAuth: true, - } + t.Run("no_access_without_login", func(t *testing.T) { + testNodeNoAccessWithoutLogin(t, config) + }) + + t.Run("only_admin_account_allowed", func(t *testing.T) { + testNodeOnlyAdminAccountAllowed(t, config) + }) + + t.Run("cannot_login_twice", func(t *testing.T) { + testNodeCannotLoginTwice(t, config) + }) + + t.Run("access_after_login", func(t *testing.T) { + testNodeAccessAfterLogin(t, config) + }) + + t.Run("no_access_after_logout", func(t *testing.T) { + testNodeNoAccessAfterLogout(t, config) + }) + + t.Run("change_password", func(t *testing.T) { + testNodeChangePassword(t, config) + }) +} + +func makeStartNode(t *testing.T, config Config) (string, *http.Client, func()) { + // nolint: gomnd + defaultMockConfig := MockConfig{ + Nodes: 5, + MaxTpsPerNode: 10, + MaxRoutesPerNode: 10, + EnableAuth: true, } - startNode := func(mock MockConfig) (string, *http.Client, func()) { - node, err := NewNode(config) - require.NoError(t, err) - require.NoError(t, node.AddMockData(mock)) + node, err := NewNode(config) + require.NoError(t, err) + require.NoError(t, node.AddMockData(defaultMockConfig)) - srv := httptest.NewTLSServer(node) - node.c.Cookies.Domain = srv.Listener.Addr().String() + srv := httptest.NewTLSServer(node) + node.c.Cookies.Domain = srv.Listener.Addr().String() - client := srv.Client() - jar, err := cookiejar.New(&cookiejar.Options{}) - require.NoError(t, err) - client.Jar = jar + client := srv.Client() + jar, err := cookiejar.New(&cookiejar.Options{}) + require.NoError(t, err) - return srv.Listener.Addr().String(), client, func() { - srv.Close() - require.NoError(t, os.Remove(config.DBPath)) - } + client.Jar = jar + + return srv.Listener.Addr().String(), client, func() { + srv.Close() + require.NoError(t, os.Remove(config.DBPath)) } +} + +type TestCase struct { + ReqMethod string + ReqURI string + ReqBody io.Reader + ReqMod func(req *http.Request) + RespStatus int + RespBody func(t *testing.T, resp *http.Response) +} - type TestCase struct { - ReqMethod string - ReqURI string - ReqBody io.Reader - ReqMod func(req *http.Request) - RespStatus int - RespBody func(t *testing.T, resp *http.Response) +func testCases(t *testing.T, addr string, client *http.Client, cases []TestCase) { + for i, tc := range cases { + testTag := fmt.Sprintf("[%d] %s", i, tc.ReqURI) + + testCase(t, addr, client, tc, testTag) } +} - testCases := func(t *testing.T, addr string, client *http.Client, cases []TestCase) { - for i, tc := range cases { - testTag := fmt.Sprintf("[%d] %s", i, tc.ReqURI) +func testCase(t *testing.T, addr string, client *http.Client, tc TestCase, testTag string) { + req, err := http.NewRequest(tc.ReqMethod, "https://"+addr+tc.ReqURI, tc.ReqBody) + require.NoError(t, err, testTag) - req, err := http.NewRequest(tc.ReqMethod, "https://"+addr+tc.ReqURI, tc.ReqBody) - require.NoError(t, err, testTag) + if tc.ReqMod != nil { + tc.ReqMod(req) + } - if tc.ReqMod != nil { - tc.ReqMod(req) - } + resp, err := client.Do(req) + if resp != nil { + defer func() { + assert.NoError(t, resp.Body.Close()) + }() + } - resp, err := client.Do(req) - require.NoError(t, err, testTag) + require.NoError(t, err, testTag) + assert.Equal(t, tc.RespStatus, resp.StatusCode, testTag) - assert.Equal(t, tc.RespStatus, resp.StatusCode, testTag) - if tc.RespBody != nil { - tc.RespBody(t, resp) - } - } + if tc.RespBody != nil { + tc.RespBody(t, resp) } +} - t.Run("no_access_without_login", func(t *testing.T) { - addr, client, stop := startNode(defaultMockConfig()) - defer stop() - - makeCase := func(method string, uri string, body io.Reader) TestCase { - return TestCase{ - ReqMethod: method, - ReqURI: uri, - ReqBody: body, - RespStatus: http.StatusUnauthorized, - RespBody: func(t *testing.T, r *http.Response) { - body, err := decodeErrorBody(r.Body) - assert.NoError(t, err) - assert.Equal(t, ErrBadSession.Error(), body.Error) - }, - } +func testNodeNoAccessWithoutLogin(t *testing.T, config Config) { + addr, client, stop := makeStartNode(t, config) + defer stop() + + makeCase := func(method string, uri string, body io.Reader) TestCase { + return TestCase{ + ReqMethod: method, + ReqURI: uri, + ReqBody: body, + RespStatus: http.StatusUnauthorized, + RespBody: func(t *testing.T, r *http.Response) { + body, err := decodeErrorBody(r.Body) + assert.NoError(t, err) + assert.Equal(t, ErrBadSession.Error(), body.Error) + }, } + } - testCases(t, addr, client, []TestCase{ - makeCase(http.MethodGet, "/api/user", nil), - makeCase(http.MethodPost, "/api/change-password", strings.NewReader(`{"old_password":"old","new_password":"new"}`)), - makeCase(http.MethodGet, "/api/nodes", nil), - }) + testCases(t, addr, client, []TestCase{ + makeCase(http.MethodGet, "/api/user", nil), + makeCase(http.MethodPost, "/api/change-password", strings.NewReader(`{"old_password":"old","new_password":"new"}`)), + makeCase(http.MethodGet, "/api/nodes", nil), }) +} - t.Run("only_admin_account_allowed", func(t *testing.T) { - addr, client, stop := startNode(defaultMockConfig()) - defer stop() - - testCases(t, addr, client, []TestCase{ - { - ReqMethod: http.MethodPost, - ReqURI: "/api/create-account", - ReqBody: strings.NewReader(`{"username":"invalid_user","password":"Secure1234"}`), - RespStatus: http.StatusForbidden, - RespBody: func(t *testing.T, r *http.Response) { - body, err := decodeErrorBody(r.Body) - assert.NoError(t, err) - assert.Equal(t, ErrUserNotCreated.Error(), body.Error) - }, +func testNodeOnlyAdminAccountAllowed(t *testing.T, config Config) { + addr, client, stop := makeStartNode(t, config) + defer stop() + + testCases(t, addr, client, []TestCase{ + { + ReqMethod: http.MethodPost, + ReqURI: "/api/create-account", + ReqBody: strings.NewReader(badCreateAccountPayload), + RespStatus: http.StatusForbidden, + RespBody: func(t *testing.T, r *http.Response) { + body, err := decodeErrorBody(r.Body) + assert.NoError(t, err) + assert.Equal(t, ErrNameNotAllowed.Error(), body.Error) }, - { - ReqMethod: http.MethodPost, - ReqURI: "/api/create-account", - ReqBody: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, + }, + { + ReqMethod: http.MethodPost, + ReqURI: "/api/create-account", + ReqBody: strings.NewReader(goodPayload), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - }) + }, }) +} - t.Run("cannot_login_twice", func(t *testing.T) { - addr, client, stop := startNode(defaultMockConfig()) - defer stop() - - testCases(t, addr, client, []TestCase{ - { - ReqMethod: http.MethodPost, - ReqURI: "/api/create-account", - ReqBody: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, +func testNodeCannotLoginTwice(t *testing.T, config Config) { + addr, client, stop := makeStartNode(t, config) + defer stop() + + testCases(t, addr, client, []TestCase{ + { + ReqMethod: http.MethodPost, + ReqURI: "/api/create-account", + ReqBody: strings.NewReader(goodPayload), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - { - ReqMethod: http.MethodPost, - ReqURI: "/api/login", - ReqBody: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, + }, + { + ReqMethod: http.MethodPost, + ReqURI: "/api/login", + ReqBody: strings.NewReader(goodPayload), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - { - ReqMethod: http.MethodPost, - ReqURI: "/api/login", - ReqBody: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), - RespStatus: http.StatusForbidden, - RespBody: func(t *testing.T, r *http.Response) { - body, err := decodeErrorBody(r.Body) - assert.NoError(t, err) - assert.Equal(t, ErrNotLoggedOut.Error(), body.Error) - }, + }, + { + ReqMethod: http.MethodPost, + ReqURI: "/api/login", + ReqBody: strings.NewReader(goodPayload), + RespStatus: http.StatusForbidden, + RespBody: func(t *testing.T, r *http.Response) { + body, err := decodeErrorBody(r.Body) + assert.NoError(t, err) + assert.Equal(t, ErrNotLoggedOut.Error(), body.Error) }, - }) + }, }) +} - t.Run("access_after_login", func(t *testing.T) { - addr, client, stop := startNode(defaultMockConfig()) - defer stop() - - testCases(t, addr, client, []TestCase{ - { - ReqMethod: http.MethodPost, - ReqURI: "/api/create-account", - ReqBody: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, - }, - { - ReqMethod: http.MethodPost, - ReqURI: "/api/login", - ReqBody: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, +func testNodeAccessAfterLogin(t *testing.T, config Config) { + addr, client, stop := makeStartNode(t, config) + defer stop() + + testCases(t, addr, client, []TestCase{ + { + ReqMethod: http.MethodPost, + ReqURI: "/api/create-account", + ReqBody: strings.NewReader(goodPayload), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - { - ReqMethod: http.MethodGet, - ReqURI: "/api/user", - RespStatus: http.StatusOK, + }, + { + ReqMethod: http.MethodPost, + ReqURI: "/api/login", + ReqBody: strings.NewReader(goodPayload), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - { - ReqMethod: http.MethodGet, - ReqURI: "/api/nodes", - RespStatus: http.StatusOK, - }, - }) + }, + { + ReqMethod: http.MethodGet, + ReqURI: "/api/user", + RespStatus: http.StatusOK, + }, + { + ReqMethod: http.MethodGet, + ReqURI: "/api/nodes", + RespStatus: http.StatusOK, + }, }) +} - t.Run("no_access_after_logout", func(t *testing.T) { - addr, client, stop := startNode(defaultMockConfig()) - defer stop() - - testCases(t, addr, client, []TestCase{ - { - ReqMethod: http.MethodPost, - ReqURI: "/api/create-account", - ReqBody: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, +func testNodeNoAccessAfterLogout(t *testing.T, config Config) { + addr, client, stop := makeStartNode(t, config) + defer stop() + + testCases(t, addr, client, []TestCase{ + { + ReqMethod: http.MethodPost, + ReqURI: "/api/create-account", + ReqBody: strings.NewReader(goodPayload), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - { - ReqMethod: http.MethodPost, - ReqURI: "/api/login", - ReqBody: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, + }, + { + ReqMethod: http.MethodPost, + ReqURI: "/api/login", + ReqBody: strings.NewReader(goodPayload), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - { - ReqMethod: http.MethodPost, - ReqURI: "/api/logout", - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, + }, + { + ReqMethod: http.MethodPost, + ReqURI: "/api/logout", + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - { - ReqMethod: http.MethodGet, - ReqURI: "/api/user", - RespStatus: http.StatusUnauthorized, - RespBody: func(t *testing.T, r *http.Response) { - body, err := decodeErrorBody(r.Body) - assert.NoError(t, err) - assert.Equal(t, ErrBadSession.Error(), body.Error) - }, + }, + { + ReqMethod: http.MethodGet, + ReqURI: "/api/user", + RespStatus: http.StatusUnauthorized, + RespBody: func(t *testing.T, r *http.Response) { + body, err := decodeErrorBody(r.Body) + assert.NoError(t, err) + assert.Equal(t, ErrBadSession.Error(), body.Error) }, - { - ReqMethod: http.MethodGet, - ReqURI: "/api/nodes", - RespStatus: http.StatusUnauthorized, - RespBody: func(t *testing.T, r *http.Response) { - body, err := decodeErrorBody(r.Body) - assert.NoError(t, err) - assert.Equal(t, ErrBadSession.Error(), body.Error) - }, + }, + { + ReqMethod: http.MethodGet, + ReqURI: "/api/nodes", + RespStatus: http.StatusUnauthorized, + RespBody: func(t *testing.T, r *http.Response) { + body, err := decodeErrorBody(r.Body) + assert.NoError(t, err) + assert.Equal(t, ErrBadSession.Error(), body.Error) }, - }) + }, }) +} - t.Run("change_password", func(t *testing.T) { - // - Create account. - // - Login. - // - Change Password. - // - Attempt action (should fail). - // - Logout. - // - Login with old password (should fail). - // - Login with new password (should succeed). - - addr, client, stop := startNode(defaultMockConfig()) - defer stop() - - // To emulate an active session. - var cookies []*http.Cookie - - testCases(t, addr, client, []TestCase{ - { - ReqMethod: http.MethodPost, - ReqURI: "/api/create-account", - ReqBody: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, +// - Create account. +// - Login. +// - Change Password. +// - Attempt action (should fail). +// - Logout. +// - Login with old password (should fail). +// - Login with new password (should succeed). +// nolint: funlen +func testNodeChangePassword(t *testing.T, config Config) { + addr, client, stop := makeStartNode(t, config) + defer stop() + + // To emulate an active session. + var cookies []*http.Cookie + + testCases(t, addr, client, []TestCase{ + { + ReqMethod: http.MethodPost, + ReqURI: "/api/create-account", + ReqBody: strings.NewReader(goodPayload), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - { - ReqMethod: http.MethodPost, - ReqURI: "/api/login", - ReqBody: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - cookies = r.Cookies() - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, + }, + { + ReqMethod: http.MethodPost, + ReqURI: "/api/login", + ReqBody: strings.NewReader(goodPayload), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + cookies = r.Cookies() + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - { - ReqMethod: http.MethodPost, - ReqURI: "/api/change-password", - ReqBody: strings.NewReader(`{"old_password":"Secure1234","new_password":"NewSecure1234"}`), - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, + }, + { + ReqMethod: http.MethodPost, + ReqURI: "/api/change-password", + ReqBody: strings.NewReader(changePasswordPayload), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - { - ReqMethod: http.MethodGet, - ReqURI: "/api/nodes", - ReqMod: func(req *http.Request) { - for _, cookie := range cookies { - req.AddCookie(cookie) - } - }, - RespStatus: http.StatusUnauthorized, + }, + { + ReqMethod: http.MethodGet, + ReqURI: "/api/nodes", + ReqMod: func(req *http.Request) { + for _, cookie := range cookies { + req.AddCookie(cookie) + } }, - { - ReqMethod: http.MethodPost, - ReqURI: "/api/logout", - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, + RespStatus: http.StatusUnauthorized, + }, + { + ReqMethod: http.MethodPost, + ReqURI: "/api/logout", + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - { - ReqMethod: http.MethodPost, - ReqURI: "/api/login", - ReqBody: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), - RespStatus: http.StatusUnauthorized, - RespBody: func(t *testing.T, r *http.Response) { - b, err := decodeErrorBody(r.Body) - assert.NoError(t, err) - require.Equal(t, ErrBadLogin.Error(), b.Error) - }, + }, + { + ReqMethod: http.MethodPost, + ReqURI: "/api/login", + ReqBody: strings.NewReader(goodPayload), + RespStatus: http.StatusUnauthorized, + RespBody: func(t *testing.T, r *http.Response) { + b, err := decodeErrorBody(r.Body) + assert.NoError(t, err) + require.Equal(t, ErrBadLogin.Error(), b.Error) }, - { - ReqMethod: http.MethodPost, - ReqURI: "/api/login", - ReqBody: strings.NewReader(`{"username":"admin","password":"NewSecure1234"}`), - RespStatus: http.StatusOK, - RespBody: func(t *testing.T, r *http.Response) { - var ok bool - assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) - assert.True(t, ok) - }, + }, + { + ReqMethod: http.MethodPost, + ReqURI: "/api/login", + ReqBody: strings.NewReader(changedPasswordPayload), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) }, - }) + }, }) } @@ -398,5 +442,6 @@ func decodeErrorBody(rb io.Reader) (*ErrorBody, error) { b := new(ErrorBody) dec := json.NewDecoder(rb) dec.DisallowUnknownFields() + return b, dec.Decode(b) } diff --git a/pkg/hypervisor/user.go b/pkg/hypervisor/user.go index 898eb48ea..063549095 100644 --- a/pkg/hypervisor/user.go +++ b/pkg/hypervisor/user.go @@ -3,10 +3,12 @@ package hypervisor import ( "bytes" "encoding/gob" + "fmt" "os" "path/filepath" "regexp" "time" + "unicode" "github.com/SkycoinProject/dmsg/cipher" "go.etcd.io/bbolt" @@ -16,8 +18,21 @@ const ( boltTimeout = 10 * time.Second boltUserBucketName = "users" passwordSaltLen = 16 + minPasswordLen = 6 + maxPasswordLen = 64 + ownerRW = 0600 + ownerRWX = 0700 ) +// Errors returned by UserStore. +var ( + ErrBadPasswordLen = fmt.Errorf("password length should be between %d and %d chars", minPasswordLen, maxPasswordLen) + ErrSimplePassword = fmt.Errorf("password must have at least one upper, lower, digit and special character") + ErrUserExists = fmt.Errorf("username already exists") + ErrNameNotAllowed = fmt.Errorf("name not allowed") +) + +// nolint: gochecknoinits func init() { gob.Register(User{}) } @@ -31,21 +46,25 @@ type User struct { // SetName checks the provided name, and sets the name if format is valid. func (u *User) SetName(name string) bool { - if !UsernameFormatOkay(name) { + if !checkUsernameFormat(name) { return false } + u.Name = name + return true } // SetPassword checks the provided password, and sets the password if format is valid. -func (u *User) SetPassword(password string) bool { - if !PasswordFormatOkay(password) { - return false +func (u *User) SetPassword(password string) error { + if err := checkPasswordFormat(password); err != nil { + return err } + u.PwSalt = cipher.RandByte(passwordSaltLen) u.PwHash = cipher.SumSHA256(append([]byte(password), u.PwSalt...)) - return true + + return nil } // VerifyPassword verifies the password input with hash and salt. @@ -54,29 +73,31 @@ func (u *User) VerifyPassword(password string) bool { } // Encode encodes the user to bytes. -func (u *User) Encode() []byte { +func (u *User) Encode() ([]byte, error) { var buf bytes.Buffer if err := gob.NewEncoder(&buf).Encode(u); err != nil { - catch(err, "unexpected user encode error:") + return nil, fmt.Errorf("unexpected user encode error: %w", err) } - return buf.Bytes() + + return buf.Bytes(), nil } // DecodeUser decodes the user from bytes. -func DecodeUser(raw []byte) User { +func DecodeUser(raw []byte) (*User, error) { var user User if err := gob.NewDecoder(bytes.NewReader(raw)).Decode(&user); err != nil { - catch(err, "unexpected decode user error:") + return nil, fmt.Errorf("unexpected decode user error: %w", err) } - return user + + return &user, nil } // UserStore stores users. type UserStore interface { - User(name string) (User, bool) - AddUser(user User) bool - SetUser(user User) bool - RemoveUser(name string) + User(name string) (*User, error) + AddUser(user User) error + SetUser(user User) error + RemoveUser(name string) error } // BoltUserStore implements UserStore, storing users in a bbolt database file. @@ -86,130 +107,173 @@ type BoltUserStore struct { // NewBoltUserStore creates a new BoltUserStore. func NewBoltUserStore(path string) (*BoltUserStore, error) { - if err := os.MkdirAll(filepath.Dir(path), os.FileMode(0700)); err != nil { + if err := os.MkdirAll(filepath.Dir(path), os.FileMode(ownerRWX)); err != nil { return nil, err } - db, err := bbolt.Open(path, os.FileMode(0600), &bbolt.Options{Timeout: boltTimeout}) + + db, err := bbolt.Open(path, os.FileMode(ownerRW), &bbolt.Options{Timeout: boltTimeout}) if err != nil { return nil, err } + err = db.Update(func(tx *bbolt.Tx) error { _, err := tx.CreateBucketIfNotExists([]byte(boltUserBucketName)) return err }) + return &BoltUserStore{DB: db}, err } -// User obtains a single user. Returns true if user exists. -func (s *BoltUserStore) User(name string) (user User, ok bool) { - catch(s.View(func(tx *bbolt.Tx) error { +// User obtains a single user. Returns nil if user does not exist. +func (s *BoltUserStore) User(name string) (user *User, err error) { + err = s.View(func(tx *bbolt.Tx) error { users := tx.Bucket([]byte(boltUserBucketName)) rawUser := users.Get([]byte(name)) if rawUser == nil { - ok = false return nil } - user = DecodeUser(rawUser) - ok = true - return nil - })) - return user, ok + + user, err = DecodeUser(rawUser) + return err + }) + + return user, err } -// AddUser adds a new user; ok is true when successful. -func (s *BoltUserStore) AddUser(user User) (ok bool) { - catch(s.Update(func(tx *bbolt.Tx) error { +// AddUser adds a new user. +func (s *BoltUserStore) AddUser(user User) error { + return s.Update(func(tx *bbolt.Tx) error { users := tx.Bucket([]byte(boltUserBucketName)) if users.Get([]byte(user.Name)) != nil { - ok = false - return nil + return ErrUserExists } - ok = true - return users.Put([]byte(user.Name), user.Encode()) - })) - return ok + + encoded, err := user.Encode() + if err != nil { + return err + } + + return users.Put([]byte(user.Name), encoded) + }) } -// SetUser changes an existing user. Returns true on success. -func (s *BoltUserStore) SetUser(user User) (ok bool) { - catch(s.Update(func(tx *bbolt.Tx) error { +// SetUser changes an existing user. +func (s *BoltUserStore) SetUser(user User) error { + return s.Update(func(tx *bbolt.Tx) error { users := tx.Bucket([]byte(boltUserBucketName)) if users.Get([]byte(user.Name)) == nil { - ok = false - return nil + return ErrUserNotFound + } + + encoded, err := user.Encode() + if err != nil { + return err } - ok = true - return users.Put([]byte(user.Name), user.Encode()) - })) - return ok + + return users.Put([]byte(user.Name), encoded) + }) } // RemoveUser removes a user of given username. -func (s *BoltUserStore) RemoveUser(name string) { - catch(s.Update(func(tx *bbolt.Tx) error { +func (s *BoltUserStore) RemoveUser(name string) error { + return s.Update(func(tx *bbolt.Tx) error { return tx.Bucket([]byte(boltUserBucketName)).Delete([]byte(name)) - })) + }) } // SingleUserStore implements UserStore while enforcing only having a single user. type SingleUserStore struct { - username string UserStore + username string } // NewSingleUserStore creates a new SingleUserStore with provided username and UserStore. func NewSingleUserStore(username string, users UserStore) *SingleUserStore { return &SingleUserStore{ - username: username, UserStore: users, + username: username, } } // User gets a user. -func (s *SingleUserStore) User(name string) (User, bool) { - if s.allowName(name) { - return s.UserStore.User(name) +func (s *SingleUserStore) User(name string) (*User, error) { + if !s.allowName(name) { + return nil, ErrNameNotAllowed } - return User{}, false + + return s.UserStore.User(name) } // AddUser adds a new user. -func (s *SingleUserStore) AddUser(user User) bool { - if s.allowName(user.Name) { - return s.UserStore.AddUser(user) +func (s *SingleUserStore) AddUser(user User) error { + if !s.allowName(user.Name) { + return ErrNameNotAllowed } - return false + + return s.UserStore.AddUser(user) } // SetUser sets an existing user. -func (s *SingleUserStore) SetUser(user User) bool { - if s.allowName(user.Name) { - return s.UserStore.SetUser(user) +func (s *SingleUserStore) SetUser(user User) error { + if !s.allowName(user.Name) { + return ErrNameNotAllowed } - return false + + return s.UserStore.SetUser(user) } // RemoveUser removes a user. -func (s *SingleUserStore) RemoveUser(name string) { - if s.allowName(name) { - s.UserStore.RemoveUser(name) +func (s *SingleUserStore) RemoveUser(name string) error { + if !s.allowName(name) { + return ErrNameNotAllowed } + + return s.UserStore.RemoveUser(name) } func (s *SingleUserStore) allowName(name string) bool { return name == s.username } -// UsernameFormatOkay checks if the username format is valid. -func UsernameFormatOkay(name string) bool { +func checkUsernameFormat(name string) bool { return regexp.MustCompile(`^[a-z0-9_-]{4,21}$`).MatchString(name) } -// PasswordFormatOkay checks if the password format is valid. -func PasswordFormatOkay(pass string) bool { - if len(pass) < 6 || len(pass) > 64 { - return false +func checkPasswordFormat(password string) error { + if len(password) < minPasswordLen || len(password) > maxPasswordLen { + return ErrBadPasswordLen } - // TODO: implement more advanced password checking. - return true + + return checkPasswordStrength(password) +} + +func checkPasswordStrength(password string) error { + if len(password) == 0 { + return ErrSimplePassword + } + + passwordClasses := [][]*unicode.RangeTable{ + {unicode.Upper, unicode.Title}, + {unicode.Lower}, + {unicode.Number, unicode.Digit}, + {unicode.Space, unicode.Symbol, unicode.Punct, unicode.Mark}, + } + + seen := make([]bool, len(passwordClasses)) + + for _, r := range password { + for i, class := range passwordClasses { + if unicode.IsOneOf(class, r) { + seen[i] = true + } + } + } + + for _, v := range seen { + if !v { + return ErrSimplePassword + } + } + + return nil } diff --git a/pkg/hypervisor/user_manager.go b/pkg/hypervisor/user_manager.go index f037914a5..eca69a742 100644 --- a/pkg/hypervisor/user_manager.go +++ b/pkg/hypervisor/user_manager.go @@ -3,6 +3,7 @@ package hypervisor import ( "context" "errors" + "fmt" "net/http" "sync" "time" @@ -24,8 +25,6 @@ var ( ErrBadLogin = errors.New("incorrect username or password") ErrBadSession = errors.New("session cookie is either non-existent, expired, or ill-formatted") ErrBadUsernameFormat = errors.New("format of 'username' is not accepted") - ErrBadPasswordFormat = errors.New("format of 'password' is not accepted") - ErrUserNotCreated = errors.New("failed to create new user: username is either already taken, or unaccepted") ErrUserNotFound = errors.New("user is either deleted or not found") ) @@ -71,25 +70,44 @@ func (s *UserManager) Login() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusForbidden, ErrNotLoggedOut) return } + var rb struct { Username string `json:"username"` Password string `json:"password"` } + if err := httputil.ReadJSON(r, &rb); err != nil { httputil.WriteJSON(w, r, http.StatusBadRequest, ErrBadBody) return } - user, ok := s.db.User(rb.Username) - if !ok || !user.VerifyPassword(rb.Password) { + + user, err := s.db.User(rb.Username) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + log.WithError(err).Errorf("Failed to get user %q", rb.Username) + + return + } + + if user == nil || !user.VerifyPassword(rb.Password) { httputil.WriteJSON(w, r, http.StatusUnauthorized, ErrBadLogin) return } - s.newSession(w, Session{ + + session := Session{ User: rb.Username, Expiry: time.Now().Add(s.c.ExpiresDuration), - }) + } + + if err := s.newSession(w, session); err != nil { + log.WithError(err).Errorf("Failed to create a new session") + w.WriteHeader(http.StatusInternalServerError) + + return + } + // http.SetCookie() - httputil.WriteJSON(w, r, http.StatusOK, ok) + httputil.WriteJSON(w, r, http.StatusOK, true) } } @@ -100,6 +118,7 @@ func (s *UserManager) Logout() http.HandlerFunc { httputil.WriteJSON(w, r, http.StatusBadRequest, errors.New("not logged in")) return } + httputil.WriteJSON(w, r, http.StatusOK, true) } } @@ -112,9 +131,11 @@ func (s *UserManager) Authorize(next http.Handler) http.Handler { httputil.WriteJSON(w, r, http.StatusUnauthorized, ErrBadSession) return } + ctx := r.Context() ctx = context.WithValue(ctx, userKey, user) ctx = context.WithValue(ctx, sessionKey, session) + next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -122,29 +143,35 @@ func (s *UserManager) Authorize(next http.Handler) http.Handler { // ChangePassword returns a HandlerFunc for changing the user's password. func (s *UserManager) ChangePassword() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - var ( - user = r.Context().Value(userKey).(User) - ) + var user = r.Context().Value(userKey).(User) + var rb struct { OldPassword string `json:"old_password"` NewPassword string `json:"new_password"` } + if err := httputil.ReadJSON(r, &rb); err != nil { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } + if ok := user.VerifyPassword(rb.OldPassword); !ok { httputil.WriteJSON(w, r, http.StatusUnauthorized, ErrBadLogin) return } - if ok := user.SetPassword(rb.NewPassword); !ok { - httputil.WriteJSON(w, r, http.StatusBadRequest, ErrBadPasswordFormat) + + if err := user.SetPassword(rb.NewPassword); err != nil { + httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } - if ok := s.db.SetUser(user); !ok { - httputil.WriteJSON(w, r, http.StatusForbidden, ErrUserNotFound) + + if err := s.db.SetUser(user); err != nil { + log.WithError(err).Errorf("Failed to update user %q data", user.Name) + w.WriteHeader(http.StatusInternalServerError) + return } + s.delAllSessionsOfUser(user.Name) httputil.WriteJSON(w, r, http.StatusOK, true) } @@ -157,23 +184,35 @@ func (s *UserManager) CreateAccount() http.HandlerFunc { Username string `json:"username"` Password string `json:"password"` } + if err := httputil.ReadJSON(r, &rb); err != nil { httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } + var user User if ok := user.SetName(rb.Username); !ok { httputil.WriteJSON(w, r, http.StatusBadRequest, ErrBadUsernameFormat) return } - if ok := user.SetPassword(rb.Password); !ok { - httputil.WriteJSON(w, r, http.StatusBadRequest, ErrBadPasswordFormat) + + if err := user.SetPassword(rb.Password); err != nil { + httputil.WriteJSON(w, r, http.StatusBadRequest, err) return } - if ok := s.db.AddUser(user); !ok { - httputil.WriteJSON(w, r, http.StatusForbidden, ErrUserNotCreated) + + if err := s.db.AddUser(user); err != nil { + if err == ErrNameNotAllowed { + httputil.WriteJSON(w, r, http.StatusForbidden, ErrNameNotAllowed) + return + } + + log.WithError(err).Errorf("Failed to create user %q account", user.Name) + w.WriteHeader(http.StatusInternalServerError) + return } + httputil.WriteJSON(w, r, http.StatusOK, true) } } @@ -181,19 +220,22 @@ func (s *UserManager) CreateAccount() http.HandlerFunc { // UserInfo returns a HandlerFunc for obtaining user info. func (s *UserManager) UserInfo() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - var ( - user = r.Context().Value(userKey).(User) - session = r.Context().Value(sessionKey).(Session) - ) + user := r.Context().Value(userKey).(User) + session := r.Context().Value(sessionKey).(Session) + var otherSessions []Session + s.mu.RLock() + for _, s := range s.sessions { if s.User == user.Name && s.SID != session.SID { otherSessions = append(otherSessions, s) } } + s.mu.RUnlock() - httputil.WriteJSON(w, r, http.StatusOK, struct { + + resp := struct { Username string `json:"username"` Current Session `json:"current_session"` Sessions []Session `json:"other_sessions"` @@ -201,17 +243,24 @@ func (s *UserManager) UserInfo() http.HandlerFunc { Username: user.Name, Current: session, Sessions: otherSessions, - }) + } + + httputil.WriteJSON(w, r, http.StatusOK, resp) } } -func (s *UserManager) newSession(w http.ResponseWriter, session Session) { +func (s *UserManager) newSession(w http.ResponseWriter, session Session) error { session.SID = uuid.New() + s.mu.Lock() s.sessions[session.SID] = session s.mu.Unlock() + value, err := s.crypto.Encode(sessionCookieName, session.SID) - catch(err) + if err != nil { + return fmt.Errorf("encode SID cookie: %w", err) + } + http.SetCookie(w, &http.Cookie{ Name: sessionCookieName, Value: value, @@ -221,6 +270,8 @@ func (s *UserManager) newSession(w http.ResponseWriter, session Session) { HttpOnly: s.c.HTTPOnly, SameSite: s.c.SameSite, }) + + return nil } func (s *UserManager) delSession(w http.ResponseWriter, r *http.Request) error { @@ -228,13 +279,16 @@ func (s *UserManager) delSession(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + var sid uuid.UUID if err := s.crypto.Decode(sessionCookieName, cookie.Value, &sid); err != nil { return err } + s.mu.Lock() delete(s.sessions, sid) s.mu.Unlock() + http.SetCookie(w, &http.Cookie{ Name: sessionCookieName, Domain: s.c.Domain, @@ -243,16 +297,19 @@ func (s *UserManager) delSession(w http.ResponseWriter, r *http.Request) error { HttpOnly: s.c.HTTPOnly, SameSite: s.c.SameSite, }) + return nil } func (s *UserManager) delAllSessionsOfUser(userName string) { s.mu.Lock() + for sid, session := range s.sessions { if session.User == userName { delete(s.sessions, sid) } } + s.mu.Unlock() } @@ -261,26 +318,38 @@ func (s *UserManager) session(r *http.Request) (User, Session, bool) { if err != nil { return User{}, Session{}, false } + var sid uuid.UUID if err := s.crypto.Decode(sessionCookieName, cookie.Value, &sid); err != nil { - log.WithError(err).Warn("failed to decode session cookie value") + log.WithError(err).Warn("Failed to decode session cookie value") return User{}, Session{}, false } + s.mu.RLock() session, ok := s.sessions[sid] s.mu.RUnlock() + if !ok { return User{}, Session{}, false } - user, ok := s.db.User(session.User) - if !ok { + + user, err := s.db.User(session.User) + if err != nil { + log.WithError(err).Errorf("Failed to fetch user %q data", user.Name) + return User{}, Session{}, false + } + + if user == nil { return User{}, Session{}, false } + if time.Now().After(session.Expiry) { s.mu.Lock() delete(s.sessions, sid) s.mu.Unlock() + return User{}, Session{}, false } - return user, session, true + + return *user, session, true } diff --git a/pkg/hypervisor/user_test.go b/pkg/hypervisor/user_test.go new file mode 100644 index 000000000..1cd5d0bd3 --- /dev/null +++ b/pkg/hypervisor/user_test.go @@ -0,0 +1,75 @@ +package hypervisor + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// nolint: funlen +func Test_checkPasswordFormat(t *testing.T) { + tests := []struct { + name string + password string + err error + }{ + { + name: "Too short", + password: "1", + err: ErrBadPasswordLen, + }, + { + name: "Too Long", + password: strings.Repeat("1", 100), + err: ErrBadPasswordLen, + }, + { + name: "Only digit", + password: strings.Repeat("1", 10), + err: ErrSimplePassword, + }, + { + name: "Only lower", + password: strings.Repeat("a", 10), + err: ErrSimplePassword, + }, + { + name: "Only upper", + password: strings.Repeat("A", 10), + err: ErrSimplePassword, + }, + { + name: "Only special", + password: strings.Repeat("!", 10), + err: ErrSimplePassword, + }, + { + name: "Missing digit", + password: strings.Repeat("Aa!", 4), + err: ErrSimplePassword, + }, + { + name: "Missing lower", + password: strings.Repeat("A1!", 4), + err: ErrSimplePassword, + }, + { + name: "Missing upper", + password: strings.Repeat("a1!", 4), + err: ErrSimplePassword, + }, + { + name: "Missing special", + password: strings.Repeat("Aa1", 4), + err: ErrSimplePassword, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.err, checkPasswordFormat(tt.password)) + }) + } +}