Skip to content

Commit

Permalink
PK-4: Add userID in r.Context after successful login. close #4
Browse files Browse the repository at this point in the history
  • Loading branch information
egregors committed Jul 26, 2024
1 parent 1d04070 commit 39a8d1f
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 32 deletions.
1 change: 1 addition & 0 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ jobs:
uses: golangci/golangci-lint-action@v3
with:
version: latest
config: .golangci.yml

- name: install goveralls
run: |
Expand Down
37 changes: 22 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Used in project:
![Static Badge](https://img.shields.io/badge/Go_WebAuthn-v0.10.2-green)
![Static Badge](https://img.shields.io/badge/TS%5CJS%20SimpleWebAuthn-v10.0.0-green)

Actual versions:
Actual versions:
![GitHub Release](https://img.shields.io/github/v/release/go-webauthn/webauthn?label=Go%20WebAuthn)
![GitHub Release](https://img.shields.io/github/v/release/MasterKale/SimpleWebAuthn?label=TS%5CJS%20SimpleWebAuthn)

Expand All @@ -65,13 +65,17 @@ To add a passkey service to your application, you need to do two things:
#### Implement the `UserStore` and `SessionStore` interfaces

```go
package passkey

import "github.com/go-webauthn/webauthn/webauthn"

type User interface {
webauthn.User
PutCredential(webauthn.Credential)
}

type UserStore interface {
GetOrCreateUser(userName string) User
GetOrCreateUser(UserID string) User
SaveUser(User)
}

Expand Down Expand Up @@ -169,40 +173,43 @@ This will start the example application on http://localhost:8080.
The library provides a middleware function that can be used to protect routes that require authentication.

```go
func Auth(sessionStore SessionStore, onSuccess, onFail http.HandlerFunc) func (next http.Handler) http.Handler {
Auth(sessionStore SessionStore, userIDKey string, onSuccess, onFail http.HandlerFunc) func (next http.Handler) http.Handler
```

It takes two callback functions that are called when the user is authenticated or not.
It takes key for context and two callback functions that are called when the user is authenticated or not.
You can use the context key to retrieve the authenticated userID from the request context
with `passkey.UserFromContext`.

`passkey` contains a helper function:

| Helper | Description |
|------------------------------|-------------------------------------------------------------------------|
| Unauthorized | Returns a 401 Unauthorized response when the user is not authenticated. |
| RedirectUnauthorized(target) | Redirects the user to a given URL when they are not authenticated. |
| UserFromContext | Get userID from context |

You can use it to protect routes that require authentication:

```go
package main

import (
"net/url"
"net/url"

"github.com/egregors/passkey"
"github.com/egregors/passkey"
)

func main() {
// ...
withAuth := passkey.Auth(
storage,
nil,
passkey.RedirectUnauthorized(url.URL{Path: "/"}),
)

mux.Handle("/private", withAuth(privateMux))
}
// ...
withAuth := passkey.Auth(
storage,
"pkUser",
nil,
passkey.RedirectUnauthorized(url.URL{Path: "/"}),
)

mux.Handle("/private", withAuth(privateMux))
}
```

## Development
Expand Down
30 changes: 28 additions & 2 deletions _example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"fmt"
"html/template"
"net/http"
"net/url"
"time"
Expand Down Expand Up @@ -43,12 +44,37 @@ func main() {

privateMux := http.NewServeMux()
privateMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// render html from web/private.html
http.ServeFile(w, r, "./_example/web/private.html")
// get the userID from the request context
userID, ok := passkey.UserFromContext(r.Context(), "pkUser")
if !ok {
http.Error(w, "No user found", http.StatusUnauthorized)

return
}

pageData := struct {
UserID string
}{
UserID: userID,
}

tmpl, err := template.ParseFiles("./_example/web/private.html")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)

return
}

if err := tmpl.Execute(w, pageData); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)

return
}
})

withAuth := passkey.Auth(
storage,
"pkUser",
nil,
passkey.RedirectUnauthorized(url.URL{Path: "/"}),
)
Expand Down
2 changes: 1 addition & 1 deletion _example/web/private.html
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
</nav>

<div class="container d-flex justify-content-center align-items-center vh-100">
<h1>Private</h1>
<h1>Hi, {{ .UserID }}!</h1>
</div>

<script src="script.js"></script>
Expand Down
4 changes: 3 additions & 1 deletion handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,10 @@ func (p *Passkey) finishLogin(w http.ResponseWriter, r *http.Request) {
return
}

// FIXME: we reuse the webauthn.SessionData struct, but it's not a good idea probably
p.sessionStore.SaveSession(t, &webauthn.SessionData{
Expires: time.Now().Add(time.Hour),
UserID: session.UserID,
Expires: time.Now().Add(p.cfg.SessionMaxAge),
})
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Expand Down
2 changes: 1 addition & 1 deletion ifaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type User interface {
}

type UserStore interface {
GetOrCreateUser(userName string) User
GetOrCreateUser(userID string) User
SaveUser(User)
}

Expand Down
28 changes: 25 additions & 3 deletions middleware.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
package passkey

import (
"context"
"net/http"
"net/url"
"time"
)

// Auth implements a middleware handler for adding passkey http auth to a route.
// It checks if the request has a valid session cookie and if the session is still valid.
// If the session is valid, the onSuccess handler is called and the next handler is executed.
// If the session is invalid, the onFail handler is called and the next handler is not executed.
func Auth(sessionStore SessionStore, onSuccess, onFail http.HandlerFunc) func(next http.Handler) http.Handler {
// If the session is valid:
// - `UserID` will be added to the request context;
// - `onSuccess` handler is called and the next handler is executed.
//
// Otherwise:
// - `onFail` handler is called and the next handler is not executed.
func Auth(sessionStore SessionStore, userIDKey string, onSuccess, onFail http.HandlerFunc) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sid, err := r.Cookie(sessionCookieName)
Expand All @@ -33,6 +38,10 @@ func Auth(sessionStore SessionStore, onSuccess, onFail http.HandlerFunc) func(ne
return
}

ctx := r.Context()
ctx = context.WithValue(ctx, userIDKey, string(session.UserID))
r = r.WithContext(ctx)

exec(onSuccess, w, r)
next.ServeHTTP(w, r)
})
Expand All @@ -56,3 +65,16 @@ func RedirectUnauthorized(target url.URL) http.HandlerFunc {
http.Redirect(w, r, target.String(), http.StatusSeeOther)
}
}

// UserFromContext returns the user ID from the request context. If the userID is not found, it returns an empty string.
func UserFromContext(ctx context.Context, pkUserKey string) (string, bool) {
if ctx.Value(pkUserKey) == nil {
return "", false
}

if id, ok := ctx.Value(pkUserKey).(string); ok && id != "" {
return id, true
}

return "", false
}
55 changes: 54 additions & 1 deletion middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package passkey

import (
"context"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -169,7 +171,12 @@ func TestAuth(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sessionStore := tt.args.sessionStore()
handler := Auth(sessionStore, tt.args.onSuccess, tt.args.onFail)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler := Auth(
sessionStore,
"pkUserKey",
tt.args.onSuccess,
tt.args.onFail,
)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
resp := httptest.NewRecorder()
Expand All @@ -185,3 +192,49 @@ func TestAuth(t *testing.T) {
})
}
}

func TestUserFromContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
pkUserKey string
wantVal string
wantOk bool
}{
{
name: "empty context",
ctx: context.Background(),
pkUserKey: "pkUserKey",
wantVal: "",
wantOk: false,
},
{
name: "missing key",
ctx: context.WithValue(context.Background(), "otherKey", "value"),
pkUserKey: "pkUserKey",
wantVal: "",
wantOk: false,
},
{
name: "empty value",
ctx: context.WithValue(context.Background(), "pkUserKey", ""),
pkUserKey: "pkUserKey",
wantVal: "",
wantOk: false,
},
{
name: "valid value",
ctx: context.WithValue(context.Background(), "pkUserKey", "value"),
pkUserKey: "pkUserKey",
wantVal: "value",
wantOk: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotVal, gotOk := UserFromContext(tt.ctx, tt.pkUserKey)
assert.Equalf(t, tt.wantVal, gotVal, "UserFromContext(%v, %v)", tt.ctx, tt.pkUserKey)
assert.Equalf(t, tt.wantOk, gotOk, "UserFromContext(%v, %v)", tt.ctx, tt.pkUserKey)
})
}
}
16 changes: 8 additions & 8 deletions mock_UserStore.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 39a8d1f

Please sign in to comment.