diff --git a/authentication/user_accounts.go b/authentication/auth_handler.go similarity index 69% rename from authentication/user_accounts.go rename to authentication/auth_handler.go index d0ae07a..f9096d1 100644 --- a/authentication/user_accounts.go +++ b/authentication/auth_handler.go @@ -11,13 +11,11 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "slugquest.com/backend/crud" ) const FRONTEND_HOST string = "localhost:5185" -// TODO: make this more elegant with Gin sessions or something -var Curr_user_id string = "hi" - // Checks if user is authenticated before redirecting to next page func IsAuthenticated(c *gin.Context) { // Auth token: for direct calls to this endpoint @@ -46,6 +44,8 @@ func LoginHandler(auth *Authenticator) gin.HandlerFunc { // Save the state inside the session. session := sessions.Default(c) + session.Clear() + session.Set("state", state) if err := session.Save(); err != nil { c.String(http.StatusInternalServerError, err.Error()) @@ -130,20 +130,79 @@ func CallbackHandler(auth *Authenticator) gin.HandlerFunc { return } - // Extract Auth0's provided user vid - if profile["sub"] == nil { + var userInfoStruct *crud.User = getUserInfo(c) + if userInfoStruct == nil { + c.String(http.StatusInternalServerError, "Couldn't retrieve user profile.") + return + } + + session.Set("user_profile", userInfoStruct) + if err := session.Save(); err != nil { c.String(http.StatusInternalServerError, err.Error()) return } - user_id := profile["sub"].(string)[len("auth0|"):] - Curr_user_id = user_id // Redirect to logged in page. c.Redirect(http.StatusTemporaryRedirect, "http://"+FRONTEND_HOST+"/loggedin") } } -// Displays user profile from the current session +func getUserInfo(c *gin.Context) *crud.User { + session := sessions.Default(c) + + profile, ok := session.Get("profile").(map[string]interface{}) + if !ok || profile == nil { + c.String(http.StatusInternalServerError, "Couldn't retrieve user profile.") + return nil + } + + // No user id? No SlugQuest. + sesUID, ok := profile["sub"].(string) + if !ok { + c.String(http.StatusInternalServerError, "Couldn't resolve user id.") + return nil + } + + sesUsername, ok := profile["name"].(string) + if !ok { + c.String(http.StatusInternalServerError, "Couldn't resolve username.") + return nil + } + + sesPFP, ok := profile["picture"].(string) + if !ok { + c.String(http.StatusInternalServerError, "Couldn't resolve profile picture URL.") + return nil + } + + // Check if user exists in our DB + user, found, err := crud.GetUser(sesUID) + if err != nil { + c.String(http.StatusInternalServerError, "Couldn't fetch user.") + return nil + } + + if !found { + // Need to populate and add a new user + user = crud.User{ + UserID: sesUID, + Username: sesUsername, + Picture: sesPFP, + Points: 0, + BossId: 0, + } + + added, err := crud.AddUser(user) + if err != nil || !added { + c.String(http.StatusInternalServerError, "Couldn't register user into our records.") + return nil + } + } + + return &user +} + +// Sends user profile from the current session as JSON // func UserProfileHandler(c *gin.Context) { // session := sessions.Default(c) // profile := session.Get("profile") diff --git a/authentication/authenticator.go b/authentication/authenticator.go index 5d18c8b..c44b213 100644 --- a/authentication/authenticator.go +++ b/authentication/authenticator.go @@ -17,8 +17,8 @@ type Authenticator struct { oauth2.Config } -// New instantiates the *Authenticator. -func New() (*Authenticator, error) { +// NewAuthenticator instantiates the *Authenticator. +func NewAuthenticator() (*Authenticator, error) { provider, err := oidc.NewProvider( context.Background(), "https://"+os.Getenv("AUTH0_DOMAIN")+"/", diff --git a/crud/db_handler.go b/crud/db_handler.go index fc9ab28..ca9c409 100644 --- a/crud/db_handler.go +++ b/crud/db_handler.go @@ -38,6 +38,14 @@ type TaskPreview struct { IsAllDay bool } +type User struct { + UserID string // Not known to user, do not expose + Username string // Set by user, can be exposed + Picture string // A0 stores their profile pics as URLs + Points int + BossId int +} + var DB *sqlx.DB func LoadDumbData() error { @@ -290,3 +298,103 @@ func GetTaskId(Tid int) (Task, bool, error) { print(counter) return taskit, counter == 1, err } + +// Find user by UserID +func GetUser(uid string) (User, bool, error) { + rows, err := DB.Query("SELECT * FROM UserTable WHERE UserID=?;", uid) + var user User + if err != nil { + fmt.Println(err) + return user, false, err + } + + counter := 0 + for rows.Next() { + counter += 1 + rows.Scan(&user.UserID, &user.Points, &user.BossId) + } + rows.Close() + + return user, counter == 1, err +} + +// Add user into DB +func AddUser(u User) (bool, error) { + tx, err := DB.Beginx() + if err != nil { + fmt.Printf("AddUser(): breaky 1: %v", err) + return false, err + } + defer tx.Rollback() // aborrt transaction if error + + stmt, err := tx.Preparex("INSERT INTO UserTable (UserID, Points, Bossid) VALUES (?, ?, ?)") + if err != nil { + fmt.Printf("AddUser(): breaky 2: %v", err) + return false, err + } + + defer stmt.Close() //defer the closing of SQL statement to ensure it Closes once the function completes + _, err = stmt.Exec(u.UserID, u.Points, u.BossId) + if err != nil { + fmt.Printf("AddUser(): breaky 3: %v", err) + return false, err + } + + tx.Commit() //commit transaction to database + + return true, nil +} + +// Edit a user by supplying new values +func EditUser(u User, uid string) (bool, error) { + tx, err := DB.Beginx() + if err != nil { + return false, err + } + defer tx.Rollback() // aborrt transaction if error + + stmt, err := tx.Preparex(` + UPDATE UserTable + SET UserID = ?, Points = ?, Bossid = ? + WHERE UserID = ? + `) + if err != nil { + return false, err + } + + defer stmt.Close() + + _, err = stmt.Exec(u.UserID, u.Points, u.BossId, uid) + if err != nil { + return false, err + } + + tx.Commit() + + return true, nil +} + +func DeleteUser(uid string) (bool, error) { + tx, err := DB.Beginx() + if err != nil { + return false, err + } + defer tx.Rollback() // aborrt transaction if error + + stmt, err := tx.Preparex("DELETE FROM UserTable WHERE UserID = ?") + if err != nil { + fmt.Println("DeleteUser: breaky 1") + return false, err + } + defer stmt.Close() + + _, err = stmt.Exec(uid) + if err != nil { + fmt.Println("DeleteUser: breaky 2") + return false, err + } + + tx.Commit() + + return true, nil +} diff --git a/main.go b/main.go index a5fb6df..c8f43e8 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "log" envfuncs "github.com/joho/godotenv" @@ -16,12 +15,14 @@ func main() { // Load .env if env_err := envfuncs.Load(); env_err != nil { log.Fatalf("Error loading the .env file: %v", env_err) + return } // Create new authenticator to pass to the router - auth, auth_err := authentication.New() + auth, auth_err := authentication.NewAuthenticator() if auth_err != nil { log.Fatalf("Failed to initialize the authenticator: %v", auth_err) + return } router := CreateRouter(auth) @@ -34,10 +35,11 @@ func main() { dummy_err := crud.LoadDumbData() if dummy_err != nil { log.Fatalf("error loaduing dumb data: %v", dummy_err) + return } utest := testing.RunAllTests() if !utest { - fmt.Println("unit test failure") + log.Fatal("unit test failure") return } diff --git a/router.go b/router.go index ff0e1b6..3f280f5 100644 --- a/router.go +++ b/router.go @@ -35,6 +35,7 @@ func CreateRouter(auth *authentication.Authenticator) *gin.Engine { // To store custom types in our cookies, // we must first register them using gob.Register gob.Register(map[string]interface{}{}) + gob.Register(crud.User{}) // Set up cookie store for the user session store := cookie.NewStore([]byte("secret")) @@ -134,16 +135,14 @@ func deleteTask(c *gin.Context) { // Returns a list of all tasks of the current user func getAllUserTasks(c *gin.Context) { - // TODO: ill be fixing this - // user_id stored as a variable within the session - // uid := c.GetString("user_id") - // log.Printf("found userid = %v", uid) - // if uid == "" { - // log.Println("getAllUserTasks(): couldn't get user_id") - // c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retreive user id"}) - // return - // } - uid := authentication.Curr_user_id + // Retrieve the user_id through the struct stored in the session + session := sessions.Default(c) + userProfile, ok := session.Get("user_profile").(crud.User) + if !ok { + c.String(http.StatusInternalServerError, "Couldn't retreive user's id to display tasks.") + return + } + uid := userProfile.UserID arr, err := crud.GetUserTask(uid) if err != nil { diff --git a/schema.sql b/schema.sql index c9cfeeb..1fddd4d 100644 --- a/schema.sql +++ b/schema.sql @@ -1,5 +1,8 @@ -CREATE TABLE UserTable ( - UserID VARCHAR(255) PRIMARY KEY NOT NULL-- Assuming Auth0 provides a string-based user ID +CREATE TABLE IF NOT EXISTS UserTable ( + UserID VARCHAR(255) PRIMARY KEY NOT NULL, + Points INTEGER NOT NULL, + BossId INTEGER NOT NULL, + FOREIGN KEY (BossId) REFERENCES BossTable(BossID) ); CREATE TABLE TaskTable ( diff --git a/testing/db_unittests.go b/testing/db_unittests.go index 38d71d6..3f40500 100644 --- a/testing/db_unittests.go +++ b/testing/db_unittests.go @@ -8,15 +8,16 @@ import ( . "slugquest.com/backend/crud" ) -var testUserId string = "1111" +var dummyUserID string = "1111" // testing task functions +var testUserID string = "2222" // testing user functions func RunAllTests() bool { - return TestGetUserTask() && TestDeleteTask() && TestEditTask() && TestGetTaskId() + return TestGetUserTask() && TestDeleteTask() && TestEditTask() && TestGetTaskId() && TestAddUser() && TestEditUser() && TestDeleteUser() } func TestDeleteTask() bool { newTask := Task{ - UserID: testUserId, + UserID: dummyUserID, Category: "example", TaskName: "New Task", Description: "Description of the new task", @@ -55,7 +56,7 @@ func TestDeleteTask() bool { } func TestEditTask() bool { newTask := Task{ - UserID: testUserId, + UserID: dummyUserID, TaskID: 3, Category: "example", TaskName: "New Task", @@ -75,7 +76,7 @@ func TestEditTask() bool { editedTask := Task{ TaskID: int(taskID), - UserID: testUserId, + UserID: dummyUserID, Category: "asdf", TaskName: "edited name", Description: "edited description", @@ -105,7 +106,7 @@ func TestEditTask() bool { return true } func TestGetUserTask() bool { - taskl, err := GetUserTask(testUserId) + taskl, err := GetUserTask(dummyUserID) if err != nil { log.Printf("TestGetUserTask(): %v", err) return false @@ -144,3 +145,68 @@ func TestGetTaskId() bool { } return true } + +func TestAddUser() bool { + newUser := User{ + UserID: testUserID, + Username: "sluggo", + Picture: "lol.jpg", + Points: 1, + BossId: 1, + } + + addSuccess, addErr := AddUser(newUser) + if addErr != nil || !addSuccess { + log.Printf("TestAddUser(): couldn't add user") + return false + } + + _, found, _ := GetUser(newUser.UserID) + if !found { + log.Println("TestAddUser(): add failed") + return false + } + + return true +} + +func TestEditUser() bool { + // Original is one inserted in TestAddUser() + editedUser := User{ + UserID: testUserID, + Username: "not in DB, not tested", + Picture: "not in DB, not tested", + Points: 5, + BossId: 10, + } + + editSuccess, editErr := EditUser(editedUser, editedUser.UserID) + if editErr != nil || !editSuccess { + log.Printf("TestEditUser(): error editing user: %v", editErr) + return false + } + + checkE, _, _ := GetUser(editedUser.UserID) + if checkE.Points != 5 || checkE.BossId != 10 { + log.Println("TestEditUser(): edit verfication failed") + return false + } + + return true +} + +func TestDeleteUser() bool { + deleteSuccess, deleteErr := DeleteUser(testUserID) + if deleteErr != nil || !deleteSuccess { + log.Printf("TestDeleteUser(): couldn't delete user") + return false + } + + _, found, _ := GetUser(testUserID) + if found { + log.Println("TestDeleteUser(): delete failed") + return false + } + + return true +}