diff --git a/cmd/manager-node/commands/root.go b/cmd/manager-node/commands/root.go index 0ead506a5b..bdf5d9118a 100644 --- a/cmd/manager-node/commands/root.go +++ b/cmd/manager-node/commands/root.go @@ -92,7 +92,7 @@ var rootCmd = &cobra.Command{ }() if mock { - err := m.AddMockData(&manager.MockConfig{ + err := m.AddMockData(manager.MockConfig{ Nodes: mockNodes, MaxTpsPerNode: mockMaxTps, MaxRoutesPerNode: mockMaxRoutes, diff --git a/pkg/manager/config.go b/pkg/manager/config.go index 440a161133..1999b06042 100644 --- a/pkg/manager/config.go +++ b/pkg/manager/config.go @@ -33,8 +33,6 @@ type Config struct { PK cipher.PubKey `json:"public_key"` SK cipher.SecKey `json:"secret_key"` DBPath string `json:"db_path"` - NameRegexp string `json:"username_regexp"` // regular expression for usernames (no check if empty). TODO - PassRegexp string `json:"password_regexp"` // regular expression for passwords (no check of empty). TODO PassSaltLen int `json:"password_salt_len"` // Salt Len for password verification data. Cookies CookieConfig `json:"cookies"` Interfaces InterfaceConfig `json:"interfaces"` @@ -64,8 +62,6 @@ func GenerateLocalConfig() Config { } func (c *Config) FillDefaults() { - c.NameRegexp = `^(admin)$` - c.PassRegexp = `((?=.*\d)(?=.*[a-z])(?=.*[A-Z]).{6,20})` c.PassSaltLen = 16 c.Cookies.FillDefaults() c.Interfaces.FillDefaults() @@ -98,8 +94,8 @@ type CookieConfig struct { } func (c *CookieConfig) FillDefaults() { - c.Path = "/" c.ExpiresDuration = time.Hour * 12 + c.Path = "/" c.Secure = true c.HttpOnly = true c.SameSite = http.SameSiteDefaultMode diff --git a/pkg/manager/node.go b/pkg/manager/node.go index 87e2d55d65..956192b4bc 100644 --- a/pkg/manager/node.go +++ b/pkg/manager/node.go @@ -76,7 +76,7 @@ type MockConfig struct { } // AddMockData adds mock data to Manager Node. -func (m *Node) AddMockData(config *MockConfig) error { +func (m *Node) AddMockData(config MockConfig) error { r := rand.New(rand.NewSource(time.Now().UnixNano())) for i := 0; i < config.Nodes; i++ { pk, client := node.NewMockRPCClient(r, config.MaxTpsPerNode, config.MaxRoutesPerNode) @@ -94,37 +94,40 @@ func (m *Node) ServeHTTP(w http.ResponseWriter, r *http.Request) { mux.Use(middleware.Timeout(time.Second * 30)) mux.Use(middleware.Logger) - mux.Route("/auth", func(r chi.Router) { - r.Post("/create-account", m.users.CreateAccount(m.c.PassSaltLen, m.c.PassRegexp, m.c.NameRegexp)) - r.Post("/login", m.users.Login()) - r.Post("/logout", m.users.Logout()) - }) - mux.Route("/api", func(r chi.Router) { - r.Use(m.users.Authorize) - r.Get("/user/info", m.users.UserInfo()) - r.Post("/user/change-password", m.users.ChangePassword(m.c.PassSaltLen, m.c.PassRegexp)) + r.Group(func(r chi.Router) { + r.Post("/create-account", m.users.CreateAccount(m.c.PassSaltLen)) + r.Post("/login", m.users.Login()) + r.Post("/logout", m.users.Logout()) + }) + + r.Group(func(r chi.Router) { + r.Use(m.users.Authorize) + + r.Get("/user", m.users.UserInfo()) + r.Post("/change-password", m.users.ChangePassword(m.c.PassSaltLen)) - r.Get("/nodes", m.getNodes()) - r.Get("/nodes/{pk}", m.getNode()) + r.Get("/nodes", m.getNodes()) + r.Get("/nodes/{pk}", m.getNode()) - r.Get("/nodes/{pk}/apps", m.getApps()) - r.Get("/nodes/{pk}/apps/{app}", m.getApp()) - r.Put("/nodes/{pk}/apps/{app}", m.putApp()) + r.Get("/nodes/{pk}/apps", m.getApps()) + r.Get("/nodes/{pk}/apps/{app}", m.getApp()) + r.Put("/nodes/{pk}/apps/{app}", m.putApp()) - r.Get("/nodes/{pk}/transport-types", m.getTransportTypes()) + r.Get("/nodes/{pk}/transport-types", m.getTransportTypes()) - r.Get("/nodes/{pk}/transports", m.getTransports()) - r.Post("/nodes/{pk}/transports", m.postTransport()) - r.Get("/nodes/{pk}/transports/{tid}", m.getTransport()) - r.Delete("/nodes/{pk}/transports/{tid}", m.deleteTransport()) + r.Get("/nodes/{pk}/transports", m.getTransports()) + r.Post("/nodes/{pk}/transports", m.postTransport()) + r.Get("/nodes/{pk}/transports/{tid}", m.getTransport()) + r.Delete("/nodes/{pk}/transports/{tid}", m.deleteTransport()) - r.Get("/nodes/{pk}/routes", m.getRoutes()) - r.Post("/nodes/{pk}/routes", m.postRoute()) - r.Get("/nodes/{pk}/routes/{rid}", m.getRoute()) - r.Put("/nodes/{pk}/routes/{rid}", m.putRoute()) - r.Delete("/nodes/{pk}/routes/{rid}", m.deleteRoute()) + r.Get("/nodes/{pk}/routes", m.getRoutes()) + r.Post("/nodes/{pk}/routes", m.postRoute()) + r.Get("/nodes/{pk}/routes/{rid}", m.getRoute()) + r.Put("/nodes/{pk}/routes/{rid}", m.putRoute()) + r.Delete("/nodes/{pk}/routes/{rid}", m.deleteRoute()) + }) }) mux.ServeHTTP(w, r) diff --git a/pkg/manager/node_test.go b/pkg/manager/node_test.go new file mode 100644 index 0000000000..9e43fe6553 --- /dev/null +++ b/pkg/manager/node_test.go @@ -0,0 +1,293 @@ +package manager + +import ( + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TODO(evanlinjin): tests to write: +// - no access to any endpoint without login / signup. + +func TestNewNode(t *testing.T) { + config := makeConfig() + + confDir, err := ioutil.TempDir(os.TempDir(), "SWM") + require.NoError(t, err) + config.DBPath = filepath.Join(confDir, "users.db") + + startNode := func(mock MockConfig) (string, *http.Client, func()) { + node, err := NewNode(config) + require.NoError(t, err) + require.NoError(t, node.AddMockData(mock)) + + srv := httptest.NewTLSServer(node) + node.c.Cookies.Domain = srv.Listener.Addr().String() + + client := srv.Client() + jar, err := cookiejar.New(&cookiejar.Options{}) + require.NoError(t, err) + client.Jar = jar + + return srv.Listener.Addr().String(), client, func() { + srv.Close() + require.NoError(t, os.Remove(config.DBPath)) + } + } + + type TestCase struct { + Method string + URI string + Body io.Reader + RespStatus int + RespBody func(t *testing.T, resp *http.Response) + } + + testCases := func(t *testing.T, addr string, client *http.Client, cases []TestCase) { + for i, tc := range cases { + testTag := fmt.Sprintf("[%d] %s", i, tc.URI) + + req, err := http.NewRequest(tc.Method, "https://"+addr+tc.URI, tc.Body) + require.NoError(t, err, testTag) + + resp, err := client.Do(req) + require.NoError(t, err, testTag) + + assert.Equal(t, tc.RespStatus, resp.StatusCode, testTag) + if tc.RespBody != nil { + tc.RespBody(t, resp) + } + } + } + + t.Run("no_access_without_login", func(t *testing.T) { + addr, client, stop := startNode(MockConfig{5, 10, 10}) + defer stop() + + makeCase := func(method string, uri string, body io.Reader) TestCase { + return TestCase{ + Method: method, + URI: uri, + Body: body, + RespStatus: http.StatusUnauthorized, + RespBody: func(t *testing.T, r *http.Response) { + body, err := decodeErrorBody(r.Body) + assert.NoError(t, err) + assert.Equal(t, ErrBadSession.Error(), body.Error) + }, + } + } + + testCases(t, addr, client, []TestCase{ + makeCase(http.MethodGet, "/api/user", nil), + makeCase(http.MethodPost, "/api/change-password", strings.NewReader(`{"old_password":"old","new_password":"new"}`)), + makeCase(http.MethodGet, "/api/nodes", nil), + }) + }) + + t.Run("only_admin_account_allowed", func(t *testing.T) { + addr, client, stop := startNode(MockConfig{5, 10, 10}) + defer stop() + + testCases(t, addr, client, []TestCase{ + { + Method: http.MethodPost, + URI: "/api/create-account", + Body: strings.NewReader(`{"username":"invalid_user","password":"Secure1234"}`), + RespStatus: http.StatusForbidden, + RespBody: func(t *testing.T, r *http.Response) { + body, err := decodeErrorBody(r.Body) + assert.NoError(t, err) + assert.Equal(t, ErrUserNotCreated.Error(), body.Error) + }, + }, + { + Method: http.MethodPost, + URI: "/api/create-account", + Body: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) + }, + }, + }) + }) + + t.Run("cannot_login_twice", func(t *testing.T) { + addr, client, stop := startNode(MockConfig{5, 10, 10}) + defer stop() + + testCases(t, addr, client, []TestCase{ + { + Method: http.MethodPost, + URI: "/api/create-account", + Body: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) + }, + }, + { + Method: http.MethodPost, + URI: "/api/login", + Body: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) + }, + }, + { + Method: http.MethodPost, + URI: "/api/login", + Body: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), + RespStatus: http.StatusForbidden, + RespBody: func(t *testing.T, r *http.Response) { + body, err := decodeErrorBody(r.Body) + assert.NoError(t, err) + assert.Equal(t, ErrNotLoggedOut.Error(), body.Error) + }, + }, + }) + }) + + t.Run("access_after_login", func(t *testing.T) { + addr, client, stop := startNode(MockConfig{5, 10, 10}) + defer stop() + + testCases(t, addr, client, []TestCase{ + { + Method: http.MethodPost, + URI: "/api/create-account", + Body: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) + }, + }, + { + Method: http.MethodPost, + URI: "/api/login", + Body: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + fmt.Println(r.Cookies()) + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) + }, + }, + { + Method: http.MethodGet, + URI: "/api/user", + RespStatus: http.StatusOK, + }, + { + Method: http.MethodGet, + URI: "/api/nodes", + RespStatus: http.StatusOK, + }, + }) + }) + + t.Run("no_access_after_logout", func(t *testing.T) { + addr, client, stop := startNode(MockConfig{5, 10, 10}) + defer stop() + + testCases(t, addr, client, []TestCase{ + { + Method: http.MethodPost, + URI: "/api/create-account", + Body: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) + }, + }, + { + Method: http.MethodPost, + URI: "/api/login", + Body: strings.NewReader(`{"username":"admin","password":"Secure1234"}`), + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + fmt.Println(r.Cookies()) + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) + }, + }, + { + Method: http.MethodPost, + URI: "/api/logout", + RespStatus: http.StatusOK, + RespBody: func(t *testing.T, r *http.Response) { + fmt.Println(r.Cookies()) + var ok bool + assert.NoError(t, json.NewDecoder(r.Body).Decode(&ok)) + assert.True(t, ok) + }, + }, + { + Method: http.MethodGet, + URI: "/api/user", + RespStatus: http.StatusUnauthorized, + RespBody: func(t *testing.T, r *http.Response) { + body, err := decodeErrorBody(r.Body) + assert.NoError(t, err) + assert.Equal(t, ErrBadSession.Error(), body.Error) + }, + }, + { + Method: http.MethodGet, + URI: "/api/nodes", + RespStatus: http.StatusUnauthorized, + RespBody: func(t *testing.T, r *http.Response) { + body, err := decodeErrorBody(r.Body) + assert.NoError(t, err) + assert.Equal(t, ErrBadSession.Error(), body.Error) + }, + }, + }) + }) + + t.Run("change_password", func(t *testing.T) { + // TODO: + // - Create account. + // - Login. + // - Change Password. + // - Logout. + // - Login with old password (should fail). + // - Login with new password (should succeed). + }) +} + +type ErrorBody struct { + Error string `json:"error"` +} + +func decodeErrorBody(rb io.Reader) (*ErrorBody, error) { + b := new(ErrorBody) + dec := json.NewDecoder(rb) + dec.DisallowUnknownFields() + return b, dec.Decode(b) +} diff --git a/pkg/manager/user.go b/pkg/manager/user.go index 7fe686b5da..cf4d1f37dc 100644 --- a/pkg/manager/user.go +++ b/pkg/manager/user.go @@ -28,27 +28,17 @@ type User struct { PwHash cipher.SHA256 } -func (u *User) SetName(pattern, name string) bool { - if pattern != "" { - ok, err := regexp.MatchString(pattern, name) - catch(err, "invalid username regex:") - if !ok { - return false - } +func (u *User) SetName(name string) bool { + if !UsernameFormatOkay(name) { + return false } u.Name = name return true } -func (u *User) SetPassword(saltLen int, pattern, password string) bool { - if pattern != "" { - ok, err := regexp.MatchString(pattern, password) - if err != nil { - catch(err, "invalid password regex:") - } - if !ok { - return false - } +func (u *User) SetPassword(saltLen int, password string) bool { + if !PasswordFormatOkay(password) { + return false } u.PwSalt = cipher.RandByte(saltLen) u.PwHash = cipher.SumSHA256(append([]byte(password), u.PwSalt...)) @@ -190,3 +180,15 @@ func (s *SingleUserStore) RemoveUser(name string) { func (s *SingleUserStore) allowName(name string) bool { return name == s.username } + +func UsernameFormatOkay(name string) bool { + return regexp.MustCompile(`^[a-z0-9_-]{4,21}$`).MatchString(name) +} + +func PasswordFormatOkay(pass string) bool { + if len(pass) < 6 || len(pass) > 64 { + return false + } + // TODO: implement more advanced password checking. + return true +} diff --git a/pkg/manager/user_manager.go b/pkg/manager/user_manager.go index 4df1fa5363..8c658b0d4e 100644 --- a/pkg/manager/user_manager.go +++ b/pkg/manager/user_manager.go @@ -3,7 +3,6 @@ package manager import ( "context" "errors" - "fmt" "net/http" "sync" "time" @@ -15,7 +14,17 @@ import ( ) const ( - sessionCookieName = "swm_session" + sessionCookieName = "swm-session" +) + +var ( + ErrBadBody = errors.New("ill-formatted request body") + ErrNotLoggedOut = errors.New("not logged out") + ErrBadLogin = errors.New("incorrect username or password") + ErrBadSession = errors.New("session cookie is either non-existent, expired, or ill-formatted") + ErrBadUsernameFormat = errors.New("format of 'username' is not accepted") + ErrBadPasswordFormat = errors.New("format of 'password' is not accepted") + ErrUserNotCreated = errors.New("failed to create new user: username is either already taken, or unaccepted") ) type Session struct { @@ -44,8 +53,8 @@ func NewUserManager(users UserStore, config CookieConfig) *UserManager { func (s *UserManager) Login() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if _, err := r.Cookie(sessionCookieName); err != http.ErrNoCookie { - httputil.WriteJSON(w, r, http.StatusForbidden, errors.New("not logged out")) + if _, _, ok := s.session(r); ok { + httputil.WriteJSON(w, r, http.StatusForbidden, ErrNotLoggedOut) return } var rb struct { @@ -53,18 +62,19 @@ func (s *UserManager) Login() http.HandlerFunc { Password string `json:"password"` } if err := httputil.ReadJSON(r, &rb); err != nil { - httputil.WriteJSON(w, r, http.StatusBadRequest, errors.New("cannot read request body")) + httputil.WriteJSON(w, r, http.StatusBadRequest, ErrBadBody) return } user, ok := s.db.User(rb.Username) if !ok || !user.VerifyPassword(rb.Password) { - httputil.WriteJSON(w, r, http.StatusUnauthorized, errors.New("incorrect username or password")) + httputil.WriteJSON(w, r, http.StatusUnauthorized, ErrBadLogin) return } s.newSession(w, Session{ User: rb.Username, Expiry: time.Now().Add(time.Hour * 12), // TODO: Set default expiry. }) + //http.SetCookie() httputil.WriteJSON(w, r, http.StatusOK, ok) } } @@ -81,9 +91,9 @@ func (s *UserManager) Logout() http.HandlerFunc { func (s *UserManager) Authorize(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user, session, err := s.session(r) - if err != nil { - httputil.WriteJSON(w, r, http.StatusUnauthorized, err) + user, session, ok := s.session(r) + if !ok { + httputil.WriteJSON(w, r, http.StatusUnauthorized, ErrBadSession) return } ctx := r.Context() @@ -93,7 +103,7 @@ func (s *UserManager) Authorize(next http.Handler) http.Handler { }) } -func (s *UserManager) ChangePassword(pwSaltLen int, pwPattern string) http.HandlerFunc { +func (s *UserManager) ChangePassword(pwSaltLen int) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var ( user = r.Context().Value("user").(User) @@ -110,7 +120,7 @@ func (s *UserManager) ChangePassword(pwSaltLen int, pwPattern string) http.Handl httputil.WriteJSON(w, r, http.StatusUnauthorized, errors.New("unauthorised")) return } - if ok := user.SetPassword(pwSaltLen, pwPattern, rb.NewPassword); !ok { + if ok := user.SetPassword(pwSaltLen, rb.NewPassword); !ok { httputil.WriteJSON(w, r, http.StatusBadRequest, errors.New("format of 'new_password' is not accepted")) return } @@ -118,7 +128,7 @@ func (s *UserManager) ChangePassword(pwSaltLen int, pwPattern string) http.Handl } } -func (s *UserManager) CreateAccount(pwSaltLen int, pwPattern, unPattern string) http.HandlerFunc { +func (s *UserManager) CreateAccount(pwSaltLen int) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var rb struct { Username string `json:"username"` @@ -129,16 +139,16 @@ func (s *UserManager) CreateAccount(pwSaltLen int, pwPattern, unPattern string) return } var user User - if ok := user.SetName(unPattern, rb.Username); !ok { - httputil.WriteJSON(w, r, http.StatusBadRequest, errors.New("format of 'username' is not accepted")) + if ok := user.SetName(rb.Username); !ok { + httputil.WriteJSON(w, r, http.StatusBadRequest, ErrBadUsernameFormat) return } - if ok := user.SetPassword(pwSaltLen, pwPattern, rb.Password); !ok { - httputil.WriteJSON(w, r, http.StatusBadRequest, errors.New("format of 'password' is not accepted")) + if ok := user.SetPassword(pwSaltLen, rb.Password); !ok { + httputil.WriteJSON(w, r, http.StatusBadRequest, ErrBadPasswordFormat) return } if ok := s.db.AddUser(user); !ok { - httputil.WriteJSON(w, r, http.StatusForbidden, fmt.Errorf("failed to create user of username '%s'", user.Name)) + httputil.WriteJSON(w, r, http.StatusForbidden, ErrUserNotCreated) return } httputil.WriteJSON(w, r, http.StatusOK, true) @@ -212,32 +222,33 @@ func (s *UserManager) delSession(w http.ResponseWriter, r *http.Request) error { return nil } -func (s *UserManager) session(r *http.Request) (User, Session, error) { +func (s *UserManager) session(r *http.Request) (User, Session, bool) { cookie, err := r.Cookie(sessionCookieName) if err != nil { - return User{}, Session{}, err + return User{}, Session{}, false } var sid uuid.UUID if err := s.crypto.Decode(sessionCookieName, cookie.Value, &sid); err != nil { - return User{}, Session{}, err + log.WithError(err).Warn("failed to decode session cookie value") + return User{}, Session{}, false } s.mu.RLock() session, ok := s.sessions[sid] s.mu.RUnlock() if !ok { - return User{}, Session{}, errors.New("invalid session") + return User{}, Session{}, false } user, ok := s.db.User(session.User) if !ok { - return User{}, Session{}, errors.New("invalid session") + return User{}, Session{}, false } if time.Now().After(session.Expiry) { s.mu.Lock() delete(s.sessions, sid) s.mu.Unlock() - return User{}, Session{}, errors.New("invalid session") + return User{}, Session{}, false } - return user, session, nil + return user, session, true } // TODO: getSessionCookie function.