diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..9e1e72f --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,42 @@ +# Contributing + +Before making any changes to this project, please initiate discussions for the +proposed changes that do not yet have an issue associated with them. Your +collaboration is greatly appreciated! + +## Labels + +Please make use of the available labels when creating issues or pull requests: + +- `enhancement`: New feature or request +- `bug`: Something isn't working +- `documentation`: Improvements or additions to documentation +- `help wanted`: Extra attention is needed +- `question`: Further information is requested + +As we work through issues or pull requests, they may be additionally labeled +with: + +- `duplicate`: This issue or pull request already exists +- `good first issue`: Great for newcomers +- `invalid`: This doesn't seem right +- `wontfix`: This will not be worked on + +## Pull Requests + +Pull requests should be made against the `main` branch. All pull requests that +contain a feature or fix are mandatory to have unit tests. Your PR is only to be +merged if you adhere to this flow. + +## Security Vulnerabilities + +If you discovery a security vulnerability within this project, please send an +email to `syntaqx@gmail.com`. All security vulnerabilities will be promptly +addressed. + +## Contribute + +If you want to say thank you and/or support the active development of the +project: + +- Add a [GitHub Star](https://github.com/syntaqx/cookie/stargazers) to the project. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..70168d8 --- /dev/null +++ b/Makefile @@ -0,0 +1,3 @@ +cover: + go test -coverprofile=coverage.out ./... + go tool cover -html=coverage.out diff --git a/README.md b/README.md index d558e47..679c398 100644 --- a/README.md +++ b/README.md @@ -9,201 +9,151 @@ Cookies, but with structs, for happiness. -## Usage +## Overview -```go -import ( - "github.com/syntaqx/cookie" -) +`cookie` is a Go package designed to make handling HTTP cookies simple and +robust, and simplifying the process of parsing them into your structs. It +supports standard data types, custom data types, and signed cookies to ensure +data integrity. -... - -type MyCookies struct { - Debug bool `cookie:"DEBUG"` -} +## Features -... +- **Easy to use**: Simple API for managing cookies in your web applications. +- **Struct-based cookie values**: Easily get cookies into your structs. +- **Custom type support**: Extend cookie parsing with your own data types. +- **Signed cookies**: Ensure the integrity of your cookies with HMAC signatures. +- **No external dependencies**: Just pure standard library goodness. -var cookies Cookies -err := cookie.PopulateFromCookies(r, &cookies) -if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return -} +## Installation -fmt.Println(cookies.Debug) +```bash +go get github.com/syntaqx/cookie ``` -## Helper Methods - -### Get +## Basic Usage -For when you just want the value of the cookie: +The `cookie` package provides a `DefaultManager` that can be used to plug and +play into your existing applications: ```go -debug, err := cookie.Get(r, "DEBUG") -if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return -} +cookie.Get(r, "DEBUG") +cookie.GetSigned(r, "Access-Token") +cookie.Set(w, "DEBUG", "true", cookie.Options{}) +cookie.Set(w, "Access-Token", "token_value", cookie.Options{Signed: true}) +cookie.SetSigned(w, "Access-Token", "token_value") ``` -### Set - -While it's very easy to set Cookies in Go, often times you'll be setting -multiple cookies with the same options: +Or Populate a struct: ```go -options := &cookie.Options{ - Domain: "example.com", - Expires: time.Now().Add(24 * time.Hour), - MaxAge: 86400, - Secure: true, - HttpOnly: true, - SameSite: http.SameSiteStrictMode, +type RequestCookies struct { + Theme string `cookie:"THEME"` + Debug bool `cookie:"DEBUG,unsigned"` + AccessToken string `cookie:"Access-Token,signed"` } -cookie.Set(w, "debug", "true", options) -cookie.Set(w, "theme", "default", options) +var c RequestCookies +cookie.PopulateFromCookies(r, &c) ``` -### Remove +In order to sign cookies however, you must provide a signing key: ```go -cookie.Remove(w, "debug") +signingKey := []byte("super-secret-key") +cookie.DefaultManager = cookie.NewManager( + cookie.WithSigningKey(signingKey), +) ``` -## Signed Cookies +> [!TIP] +> Cookies are stored in plaintext by default (unsigned). A signed cookie is used +> to ensure the cookie value has not been tampered with. This is done by +> creating a [HMAC][] signature of the cookie value using a secret key. Then, +> when the cookie is read, the signature is verified to ensure the cookie value +> has not been modified. +> +> It is still recommended that sensitive data not be stored in cookies, and that +> HTTPS be used to prevent cookie [replay attacks][]. -By default, cookies are stored in plaintext. +## Advanced Usage: Manager -Cookies can be signed to ensure their value has not been tampered with. This -works by creating a [HMAC](https://en.wikipedia.org/wiki/HMAC) of the value -(current cookie), and base64 encoding it. When the cookie gets read, it -recalculates the signature and makes sure that it matches the signature attached -to it. - -It is still recommended that sensitive data not be stored in cookies, and that -HTTPS be used to prevent cookie -[replay attacks](https://en.wikipedia.org/wiki/Replay_attack). - -If you want to sign your cookies, this can be accomplished by: - -### `SetSigned` - -If you want to set a signed cookie, you can use the `SetSigned` helper method: +For more advanced usage, you can create a `Manager` to handle your cookies, +rather than relying on the `DefaultManager`: ```go -cookie.SetSigned(w, "user_id", "123") +manager := cookie.NewManager() ``` -Alternatively, you can pass `Signed` to the options when setting a cookie: +You can optionally provide a signing key for signed cookies: ```go -cookie.Set(w, "user_id", "123", &cookie.Options{ - Signed: true, -}) -``` - -These are functionally identical. - -### `GetSigned` - -If you want to get a signed cookie, you can use the `GetSigned` helper method: - -```go -userID, err := cookie.GetSigned(r, "user_id") -if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return -} +signingKey := []byte("super-secret-key") +manager := cookie.NewManager( + cookie.WithSigningKey(signingKey), +) ``` -### Reading Signed Cookies - -To read signed cookies into your struct, you can use the `signed` tag: - -```go -type User struct { - ID uuid.UUID `cookie:"user_id,signed"` -} -``` +[HMAC]: https://en.wikipedia.org/wiki/HMAC +[replay attacks]: https://en.wikipedia.org/wiki/Replay_attack -### Signing Key +### Setting Cookies -By default, the signing key is set to `[]byte(cookie.DefaultSigningKey)`. You -should change this signing key for your application by assigning the -`cookie.SigningKey` variable to a secret value of your own: +Use the `Set` method to set cookies. You can specify options such as path, +domain, expiration, and whether the cookie should be signed. ```go -cookie.SigningKey = []byte("my-secret-key") +err := manager.Set(w, "DEBUG", "true", cookie.Options{}) +err := manager.Set(w, "Access-Token", "token_value", cookie.Options{Signed: true}) ``` -## Default Options +### Getting Cookies -You can set default options for all cookies by assigning the -`cookie.DefaultOptions` variable: +Use the Get method to retrieve unsigned cookies and GetSigned for signed cookies. ```go -cookie.DefaultOptions = &cookie.Options{ - Domain: "example.com", - Expires: time.Now().Add(24 * time.Hour), - MaxAge: 86400, - Secure: true, - HttpOnly: true, - SameSite: http.SameSiteStrictMode, -} +value, err := manager.Get(r, "DEBUG") +value, err := manager.GetSigned(r, "Access-Token") ``` -These options will be used as the defaults for cookies that do not strictly -override them, allowing you to only set the values you care about. - -### Signed by Default +### Populating Structs from Cookies -If you want all cookies to be signed by default, you can set the `Signed` field -in the `cookie.DefaultOptions`: +Use `PopulateFromCookies` to populate a struct with cookie values. The struct +fields should be tagged with the cookie names. ```go -cookie.DefaultOptions = &cookie.Options{ - Signed: true, +type RequestCookies struct { + Theme string `cookie:"THEME"` + Debug bool `cookie:"DEBUG,unsigned"` + AccessToken string `cookie:"Access-Token,signed"` + NotRequired string `cookie:"NOT_REQUIRED,omitempty"` } -``` -Which will now sign all cookies by default when using the `Set` method. You can -still override this by passing `Signed: false` to the options when setting a -cookie. - -```go -cookie.Set(w, "debug", "true", &cookie.Options{ - Signed: false, -}) +var c RequestCookies +err := manager.PopulateFromCookies(r, &c) ``` -This will require the use of the `GetSigned` method to retrieve cookie values. +> [!TIP] +> By default, the `PopulateFromCookies` method will return an error if a +> required cookie is missing. You can use the `omitempty` tag to make a field +> optional. -```go -debug, err := cookie.GetSigned(r, "debug") -if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return -} -``` +### Supporting Custom Types -When defaulting to signed cookies, unsigned cookies can still be populated by -using the `unsigned` tag in the struct field: +To support custom types, register a custom handler with the Manager. ```go -type MyCookies struct { - Debug bool `cookie:"debug,unsigned"` -} -``` +import ( + "reflect" + "github.com/gofrs/uuid/v5" + "github.com/syntaqx/manager" +) -Or retrieved using the `Get` method, which always retrieves the plaintext value: +... -```go -debug, err := cookie.Get(r, "debug") -if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return -} +manager := cookie.NewManager( + cookie.WithSigningKey(signingKey), + cookie.WithCustomHandler(reflect.TypeOf(uuid.UUID{}), func(value string) (interface{}, error) { + return uuid.FromString(value) + }), +) ``` diff --git a/_example/main.go b/_example/main.go index 215caba..0bece38 100644 --- a/_example/main.go +++ b/_example/main.go @@ -5,70 +5,72 @@ import ( "net/http" "time" - "github.com/gofrs/uuid/v5" "github.com/syntaqx/cookie" ) -type RequestCookies struct { - ApplicationID uuid.UUID `cookie:"Application-ID"` - Theme string `cookie:"THEME"` - Debug bool `cookie:"DEBUG"` - AccessToken string `cookie:"Access-Token,signed"` - UserID int `cookie:"User-ID,signed"` - IsAdmin bool `cookie:"Is-Admin,signed"` - Permissions []string `cookie:"Permissions,signed"` - ExpiresAt time.Time `cookie:"Expires-At,signed"` +var defaultCookieOptions = cookie.Options{ + HttpOnly: true, +} + +var signedCookieOptions = cookie.Options{ + HttpOnly: true, + Secure: true, + Signed: true, } func handler(w http.ResponseWriter, r *http.Request) { - // If none of the cookies are set, we'll set them and refresh the page - // so the rest of the demo functions. - _, err := cookie.Get(r, "Application-ID") + _, err := cookie.Get(r, "DEBUG") if err != nil { setDemoCookies(w) http.Redirect(w, r, "/", http.StatusSeeOther) return } - // Populate struct from cookies - var req RequestCookies - err = cookie.PopulateFromCookies(r, &req) - if err != nil { + type RequestCookies struct { + Theme string `cookie:"THEME"` + Debug bool `cookie:"DEBUG,unsigned"` + AccessToken string `cookie:"Access-Token,signed"` + UserID int `cookie:"User-ID,signed"` + IsAdmin bool `cookie:"Is-Admin,signed"` + Permissions []string `cookie:"Permissions,signed"` + Friends []int `cookie:"Friends,unsigned"` + ExpiresAt time.Time `cookie:"Expires-At,signed"` + NotExists string `cookie:"Does-Not-Exist,omitempty"` + } + + var c RequestCookies + if err := cookie.PopulateFromCookies(r, &c); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - // Dump the struct as a response - fmt.Fprintf(w, "RequestCookies: %+v", req) + fmt.Fprintf(w, "RequestCookies: %v\n", c) } func setDemoCookies(w http.ResponseWriter) { - // Set cookies - cookie.Set(w, "Application-ID", uuid.Must(uuid.NewV7()).String(), nil) - cookie.Set(w, "THEME", "default", nil) - cookie.Set(w, "DEBUG", "true", nil) - - secureOptions := &cookie.Options{ - Path: "/", - Expires: time.Now().Add(24 * time.Hour), - HttpOnly: true, - Secure: true, - } - - // Set signed cookies - cookie.SetSigned(w, "Access-Token", "some-access-token", secureOptions) - cookie.SetSigned(w, "User-ID", "123", secureOptions) - cookie.SetSigned(w, "Is-Admin", "true", secureOptions) - cookie.SetSigned(w, "Permissions", "read,write,execute", secureOptions) - cookie.SetSigned(w, "Expires-At", time.Now().Add(24*time.Hour).Format(time.RFC3339), secureOptions) + cookie.Set(w, "DEBUG", "true", defaultCookieOptions) + cookie.Set(w, "THEME", "dark", defaultCookieOptions) + cookie.Set(w, "Access-Token", "token_value", signedCookieOptions) + cookie.Set(w, "User-ID", "12345", signedCookieOptions) + cookie.Set(w, "Is-Admin", "true", signedCookieOptions) + cookie.Set(w, "Permissions", "read,write,execute", signedCookieOptions) + cookie.Set(w, "Friends", "1,2,3,4,5", defaultCookieOptions) + cookie.Set(w, "Expires-At", time.Now().Add(24*time.Hour).Format(time.RFC3339), signedCookieOptions) } func main() { - cookie.DefaultOptions = &cookie.Options{ - Path: "/", - Expires: time.Now().Add(24 * time.Hour), - HttpOnly: true, - } + // Create a new cookie manager with a signing key. + manager := cookie.NewManager( + cookie.WithSigningKey([]byte("super-secret-key")), + ) + + // Set the default manager to the one we just created. This allows us to use + // the default package functions without having to pass the manager. + // + // This is optional, as you can create a new manager and pass it through to + // the functions that require it, potentially allowing you to have different + // managers with different options. + cookie.DefaultManager = manager http.HandleFunc("/", handler) diff --git a/cookie.go b/cookie.go index 00cd6b7..b6b3918 100644 --- a/cookie.go +++ b/cookie.go @@ -1,25 +1,15 @@ package cookie import ( - "crypto/hmac" - "crypto/sha256" "encoding/base64" - "errors" "net/http" "reflect" - "strconv" "strings" "time" - - "github.com/gofrs/uuid/v5" -) - -const ( - CookieTag = "cookie" - DefaultSigningKey = "default-signing-key" ) -// Options contains the options for a cookie. +// Options represent the options for an HTTP cookie as sent in the Set-Cookie +// header of an HTTP response or the Cookie header of an HTTP request. type Options struct { Path string Domain string @@ -31,66 +21,42 @@ type Options struct { Signed bool } -var ( - // SigningKey is the key used to sign cookies. - SigningKey = []byte(DefaultSigningKey) - - // DefaultOptions are the default options for cookies. - DefaultOptions = &Options{ - Path: "/", - Domain: "", - Expires: time.Time{}, - MaxAge: 0, - Secure: false, - HttpOnly: false, - SameSite: http.SameSiteDefaultMode, - Signed: false, - } -) - -var ( - // ErrInvalidSignedFormat is returned when a signed cookie is not in the correct format. - ErrInvalidSignedFormat = errors.New("cookie: invalid signed cookie format") +// Manager handles cookie operations. +type Manager struct { + signingKey []byte + customHandlers map[reflect.Type]CustomTypeHandler +} - // ErrInvalidSignature is returned when a signed cookie has an invalid signature. - ErrInvalidSignature = errors.New("cookie: invalid cookie signature") -) +// Option is a function type for configuring the Manager. +type Option func(*Manager) -// ErrUnsupportedType is returned when a field type is not supported. -type ErrUnsupportedType struct { - Type reflect.Type +// WithSigningKey sets the signing key for the Manager. +func WithSigningKey(key []byte) Option { + return func(m *Manager) { + m.signingKey = key + } } -// Error returns the error message. -func (e *ErrUnsupportedType) Error() string { - return "cookie: unsupported type: " + e.Type.String() +// WithCustomHandler registers a custom type handler for the Manager. +func WithCustomHandler(typ reflect.Type, handler CustomTypeHandler) Option { + return func(m *Manager) { + m.customHandlers[typ] = handler + } } -// Set sets a cookie with the given name, value, and options. -func Set(w http.ResponseWriter, name, value string, options *Options) { - mergedOptions := mergeOptions(options, DefaultOptions) - cookie := &http.Cookie{ - Name: name, - Value: value, - Path: mergedOptions.Path, - Domain: mergedOptions.Domain, - Expires: mergedOptions.Expires, - MaxAge: mergedOptions.MaxAge, - Secure: mergedOptions.Secure, - HttpOnly: mergedOptions.HttpOnly, - SameSite: mergedOptions.SameSite, +// NewManager creates a new Manager with the given options. +func NewManager(opts ...Option) *Manager { + m := &Manager{ + customHandlers: make(map[reflect.Type]CustomTypeHandler), } - - if mergedOptions.Signed { - signature := generateHMAC(value) - cookie.Value = base64.URLEncoding.EncodeToString([]byte(value)) + "|" + signature + for _, opt := range opts { + opt(m) } - - http.SetCookie(w, cookie) + return m } -// Get retrieves the plaintext value of a cookie with the given name. -func Get(r *http.Request, name string) (string, error) { +// Get retrieves an unsigned cooke value. +func (m *Manager) Get(r *http.Request, name string) (string, error) { cookie, err := r.Cookie(name) if err != nil { return "", err @@ -98,190 +64,88 @@ func Get(r *http.Request, name string) (string, error) { return cookie.Value, nil } -// SetSigned sets a signed cookie with the given name, value, and options. -func SetSigned(w http.ResponseWriter, name, value string, options *Options) { - if options == nil { - options = &Options{} - } - - options.Signed = true - Set(w, name, value, options) -} - -// GetSigned retrieves the value of a signed cookie with the given name. -func GetSigned(r *http.Request, name string) (string, error) { - signedValue, err := Get(r, name) +// GetSigned retrieves a signed cookie value. +func (m *Manager) GetSigned(r *http.Request, name string) (string, error) { + value, err := m.Get(r, name) if err != nil { return "", err } - parts := strings.SplitN(signedValue, "|", 2) + parts := strings.Split(value, "|") if len(parts) != 2 { - return "", ErrInvalidSignedFormat + return "", ErrInvalidSignedCookieFormat } - value, err := base64.URLEncoding.DecodeString(parts[0]) + data, signature := parts[0], parts[1] + dataBytes, err := base64.URLEncoding.DecodeString(data) if err != nil { return "", err } - - signature, err := base64.URLEncoding.DecodeString(parts[1]) + signatureBytes, err := base64.URLEncoding.DecodeString(signature) if err != nil { return "", err } - h := hmac.New(sha256.New, SigningKey) - h.Write(value) - expectedSignature := h.Sum(nil) - - if !hmac.Equal(signature, expectedSignature) { - return "", ErrInvalidSignature + if verify([]byte(data), signatureBytes, m.signingKey) { + return string(dataBytes), nil } - - return string(value), nil + return "", ErrInvalidCookieSignature } -// Remove removes a cookie by setting its MaxAge to -1. -func Remove(w http.ResponseWriter, name string) { - cookie := &http.Cookie{ - Name: name, - Value: "", - Path: "/", - MaxAge: -1, +// Set sets the value of a cookie. +func (m *Manager) Set(w http.ResponseWriter, name, value string, opts ...Options) error { + var o Options + if len(opts) > 0 { + o = opts[0] } - http.SetCookie(w, cookie) -} -// PopulateFromCookies populates the fields of a struct based on cookie tags. -func PopulateFromCookies(r *http.Request, dest interface{}) error { - val := reflect.ValueOf(dest).Elem() - typ := val.Type() - - for i := 0; i < val.NumField(); i++ { - field := val.Field(i) - fieldType := typ.Field(i) - tag := fieldType.Tag.Get(CookieTag) - tagParts := strings.Split(tag, ",") - - if tagParts[0] == "" { - continue - } - - var cookie string - var err error - isSigned := DefaultOptions.Signed - - for _, part := range tagParts[1:] { - if part == "signed" { - isSigned = true - } else if part == "unsigned" { - isSigned = false - } - } - - if isSigned { - cookie, err = GetSigned(r, tagParts[0]) - } else { - cookie, err = Get(r, tagParts[0]) - } - - if err != nil { - return err - } + if o.Signed && m.signingKey != nil { + value = signCookieValue(value, m.signingKey) + } - switch field.Kind() { - case reflect.String: - field.SetString(cookie) - case reflect.Int: - intVal, err := strconv.Atoi(cookie) - if err != nil { - return err - } - field.SetInt(int64(intVal)) - case reflect.Bool: - boolVal, err := strconv.ParseBool(cookie) - if err != nil { - return err - } - field.SetBool(boolVal) - case reflect.Slice: - switch fieldType.Type.Elem().Kind() { - case reflect.String: - field.Set(reflect.ValueOf(strings.Split(cookie, ","))) - case reflect.Int: - intStrings := strings.Split(cookie, ",") - intSlice := make([]int, len(intStrings)) - for i, s := range intStrings { - intVal, err := strconv.Atoi(s) - if err != nil { - return err - } - intSlice[i] = intVal - } - field.Set(reflect.ValueOf(intSlice)) - default: - return &ErrUnsupportedType{fieldType.Type} - } - case reflect.Array: - if fieldType.Type == reflect.TypeOf(uuid.UUID{}) { - uid, err := uuid.FromString(cookie) - if err != nil { - return err - } - field.Set(reflect.ValueOf(uid)) - } - case reflect.Struct: - if fieldType.Type == reflect.TypeOf(time.Time{}) { - timeVal, err := time.Parse(time.RFC3339, cookie) - if err != nil { - return err - } - field.Set(reflect.ValueOf(timeVal)) - } - default: - return &ErrUnsupportedType{fieldType.Type} - } + cookie := &http.Cookie{ + Name: name, + Value: value, + Path: o.Path, + Domain: o.Domain, + Expires: o.Expires, + MaxAge: o.MaxAge, + Secure: o.Secure, + HttpOnly: o.HttpOnly, + SameSite: o.SameSite, } - return nil -} -func generateHMAC(value string) string { - h := hmac.New(sha256.New, SigningKey) - h.Write([]byte(value)) - return base64.URLEncoding.EncodeToString(h.Sum(nil)) + http.SetCookie(w, cookie) + return nil } -func mergeOptions(provided, defaults *Options) *Options { - if provided == nil { - return defaults +// SetSigned sets a signed value of a cookie. +func (m *Manager) SetSigned(w http.ResponseWriter, name, value string, opts ...Options) error { + var o Options + if len(opts) > 0 { + o = opts[0] } + o.Signed = true + return m.Set(w, name, value, o) +} - merged := *defaults - - if provided.Path != "" { - merged.Path = provided.Path - } - if provided.Domain != "" { - merged.Domain = provided.Domain - } - if !provided.Expires.IsZero() { - merged.Expires = provided.Expires - } - if provided.MaxAge != 0 { - merged.MaxAge = provided.MaxAge - } - if provided.Secure { - merged.Secure = provided.Secure - } - if provided.HttpOnly { - merged.HttpOnly = provided.HttpOnly +// Remove removes a cookie from the response. +func (m *Manager) Remove(w http.ResponseWriter, name string, opts ...Options) error { + var o Options + if len(opts) > 0 { + o = opts[0] } - if provided.SameSite != http.SameSiteDefaultMode { - merged.SameSite = provided.SameSite - } - if provided.Signed { - merged.Signed = provided.Signed + cookie := &http.Cookie{ + Name: name, + Value: "", + Path: o.Path, + Domain: o.Domain, + Expires: time.Unix(0, 0), + MaxAge: -1, + Secure: o.Secure, + HttpOnly: o.HttpOnly, + SameSite: o.SameSite, } - - return &merged + http.SetCookie(w, cookie) + return nil } diff --git a/cookie_test.go b/cookie_test.go index 89f66eb..08bc2c8 100644 --- a/cookie_test.go +++ b/cookie_test.go @@ -4,536 +4,295 @@ import ( "encoding/base64" "net/http" "net/http/httptest" - "reflect" - "strconv" - "strings" "testing" - "time" - - "github.com/gofrs/uuid/v5" ) -func TestSet(t *testing.T) { - _, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatal(err) - } - w := httptest.NewRecorder() - - name := "myCookie" - value := "myValue" - - options := &Options{ - Path: "/", - Domain: "example.com", - Expires: time.Now().Add(24 * time.Hour), - MaxAge: 3600, - Secure: true, - HttpOnly: true, - SameSite: http.SameSiteStrictMode, - } - - Set(w, name, value, options) - - // Get the response cookies - cookies := w.Result().Cookies() - - // Check if the cookie was set correctly - if len(cookies) != 1 { - t.Errorf("Expected 1 cookie, got %d", len(cookies)) - } - cookie := cookies[0] - if cookie.Name != name { - t.Errorf("Expected cookie name %s, got %s", name, cookie.Name) - } - if cookie.Value != value { - t.Errorf("Expected cookie value %s, got %s", value, cookie.Value) - } - if cookie.Path != options.Path { - t.Errorf("Expected cookie path %s, got %s", options.Path, cookie.Path) - } - if cookie.Domain != options.Domain { - t.Errorf("Expected cookie domain %s, got %s", options.Domain, cookie.Domain) - } - if cookie.MaxAge != options.MaxAge { - t.Errorf("Expected cookie max age %d, got %d", options.MaxAge, cookie.MaxAge) - } - if cookie.Secure != options.Secure { - t.Errorf("Expected cookie secure %t, got %t", options.Secure, cookie.Secure) - } - if cookie.HttpOnly != options.HttpOnly { - t.Errorf("Expected cookie HttpOnly %t, got %t", options.HttpOnly, cookie.HttpOnly) - } - if cookie.SameSite != options.SameSite { - t.Errorf("Expected cookie SameSite %d, got %d", options.SameSite, cookie.SameSite) - } -} - -func TestSet_WithoutOptions(t *testing.T) { - _, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatal(err) - } - w := httptest.NewRecorder() - - name := "myCookie" - value := "myValue" - - Set(w, name, value, nil) - - // Get the response cookies - cookies := w.Result().Cookies() - - // Check if the cookie was set correctly - if len(cookies) != 1 { - t.Errorf("Expected 1 cookie, got %d", len(cookies)) - } -} - -func TestSetSigned(t *testing.T) { - _, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatal(err) - } - w := httptest.NewRecorder() - - name := "myCookie" - value := "myValue" +var unsignedManager = NewManager() +var signedManager = NewManager(WithSigningKey([]byte("super-secret-key"))) - SetSigned(w, name, value, nil) +func TestManager_Get(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/", nil) - // Get the response cookies - cookies := w.Result().Cookies() - - // Check if the cookie was set correctly - if len(cookies) != 1 { - t.Errorf("Expected 1 cookie, got %d", len(cookies)) - } - cookie := cookies[0] - if cookie.Name != name { - t.Errorf("Expected cookie name %s, got %s", name, cookie.Name) - } - - // Check if the cookie value is signed - parts := strings.Split(cookie.Value, "|") - if len(parts) != 2 { - t.Errorf("Expected signed cookie value, got %s", cookie.Value) - } -} - -func TestGet(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) cookieName := "myCookie" cookieValue := "myValue" - cookie := &http.Cookie{ - Name: cookieName, - Value: cookieValue, - } - r.AddCookie(cookie) - value, err := Get(r, cookieName) + r.AddCookie(&http.Cookie{Name: cookieName, Value: cookieValue}) + + value, err := unsignedManager.Get(r, cookieName) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Errorf("Unexpected error: %v", err) } if value != cookieValue { - t.Errorf("Expected cookie value %s, got %s", cookieValue, value) + t.Errorf("Expected value '%s', but got '%s'", cookieValue, value) } } -func TestGetNonexistentCookie(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - cookieName := "nonexistentCookie" +func TestManager_GetError(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/", nil) - _, err := Get(r, cookieName) + cookieName := "myCookie" + + _, err := unsignedManager.Get(r, cookieName) if err == nil { - t.Error("Expected error, got nil") + t.Error("Expected error, but got nil") } } -func TestGetSigned(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) +func TestManager_GetSigned(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/", nil) + cookieName := "myCookie" - cookieValue := "myValue" - signature := generateHMAC(cookieValue) - signedValue := base64.URLEncoding.EncodeToString([]byte(cookieValue)) + "|" + signature + expectedValue := "myValue" - cookie := &http.Cookie{ - Name: cookieName, - Value: signedValue, - } + cookieValue := signCookieValue(expectedValue, signedManager.signingKey) - r.AddCookie(cookie) + r.AddCookie(&http.Cookie{Name: cookieName, Value: cookieValue}) - value, err := GetSigned(r, cookieName) + value, err := signedManager.GetSigned(r, cookieName) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Errorf("Unexpected error: %v", err) } - if value != cookieValue { - t.Errorf("Expected cookie value %s, got %s", cookieValue, value) + if value != expectedValue { + t.Errorf("Expected value '%s', but got '%s'", expectedValue, value) } } -func TestGetSignedNonexistentCookie(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - cookieName := "nonexistentCookie" +func TestManager_GetSignedError(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/", nil) + + cookieName := "myCookie" - _, err := GetSigned(r, cookieName) + _, err := signedManager.GetSigned(r, cookieName) if err == nil { - t.Error("Expected error, got nil") + t.Error("Expected error, but got nil") } } -func TestGetSignedInvalidSignedFormat(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - cookieName := "myCookie" - cookieValue := "myValue" +func TestManager_GetSignedInvalidFormat(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/", nil) - cookie := &http.Cookie{ - Name: cookieName, - Value: cookieValue, - } + cookieName := "myCookie" + cookieValue := "invalidFormat" - r.AddCookie(cookie) + r.AddCookie(&http.Cookie{Name: cookieName, Value: cookieValue}) - _, err := GetSigned(r, cookieName) + _, err := signedManager.GetSigned(r, cookieName) if err == nil { - t.Error("Expected error, got nil") + t.Error("Expected error, but got nil") } - if err != ErrInvalidSignedFormat { - t.Errorf("Expected error ErrInvalidSignedFormat, got %v", err) + if err != ErrInvalidSignedCookieFormat { + t.Errorf("Expected error '%v', but got '%v'", ErrInvalidSignedCookieFormat, err) } } -func TestGetSignedInvalidValue(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) +func TestManager_GetSignedInvalidSignature(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/", nil) + cookieName := "myCookie" - cookieValue := "myValue" - signedValue := cookieValue + "|invalid" + expectedValue := "myValue" - cookie := &http.Cookie{ - Name: cookieName, - Value: signedValue, - } + data := base64.URLEncoding.EncodeToString([]byte(expectedValue)) + signature := base64.URLEncoding.EncodeToString(sign([]byte("invalidData"), signedManager.signingKey)) + cookieValue := data + "|" + signature - r.AddCookie(cookie) + r.AddCookie(&http.Cookie{Name: cookieName, Value: cookieValue}) - _, err := GetSigned(r, cookieName) + _, err := signedManager.GetSigned(r, cookieName) if err == nil { - t.Error("Expected error, got nil") + t.Error("Expected error, but got nil") + } + + if err != ErrInvalidCookieSignature { + t.Errorf("Expected error '%v', but got '%v'", ErrInvalidCookieSignature, err) } } -func TestGetSignedInvalidBase64(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - cookieName := "myCookie" - cookieValue := "myValue" - signedValue := base64.URLEncoding.EncodeToString([]byte(cookieValue)) + "|invalid" +func TestManager_GetSigned_Base64DataDecodeError(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/", nil) - cookie := &http.Cookie{ - Name: cookieName, - Value: signedValue, - } + cookieName := "myCookie" + cookieValue := "invalidBase64|invalidBase64" - r.AddCookie(cookie) + r.AddCookie(&http.Cookie{Name: cookieName, Value: cookieValue}) - _, err := GetSigned(r, cookieName) + _, err := signedManager.GetSigned(r, cookieName) if err == nil { - t.Error("Expected error, got nil") + t.Error("Expected error, but got nil") + } + + expectedError := "illegal base64 data at input byte 12" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%v'", expectedError, err) } } -func TestGetSignedInvalidSignature(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - cookieName := "myCookie" - cookieValue := "myValue" - signature := generateHMAC("invalid") - signedValue := base64.URLEncoding.EncodeToString([]byte(cookieValue)) + "|" + signature +func TestManager_GetSigned_Base64SignatureDecodeError(t *testing.T) { + r, _ := http.NewRequest(http.MethodGet, "/", nil) - cookie := &http.Cookie{ - Name: cookieName, - Value: signedValue, - } + cookieName := "myCookie" + cookieValue := "ZXhhbXBsZQ==|invalidBase64" - r.AddCookie(cookie) + r.AddCookie(&http.Cookie{Name: cookieName, Value: cookieValue}) - _, err := GetSigned(r, cookieName) + _, err := signedManager.GetSigned(r, cookieName) if err == nil { - t.Error("Expected error, got nil") + t.Error("Expected error, but got nil") } - if err != ErrInvalidSignature { - t.Errorf("Expected error ErrInvalidSignature, got %v", err) + expectedError := "illegal base64 data at input byte 12" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%v'", expectedError, err) } } -func TestRemove(t *testing.T) { +func TestManager_Set(t *testing.T) { w := httptest.NewRecorder() - name := "myCookie" - Remove(w, name) - // Get the response cookies - cookies := w.Result().Cookies() - // Check if the cookie was removed correctly - if len(cookies) != 1 { - t.Errorf("Expected 1 cookie, got %d", len(cookies)) - } - cookie := cookies[0] - if cookie.Name != name { - t.Errorf("Expected cookie name %s, got %s", name, cookie.Name) - } - if cookie.Value != "" { - t.Errorf("Expected cookie value to be empty, got %s", cookie.Value) - } - if cookie.Path != "/" { - t.Errorf("Expected cookie path /, got %s", cookie.Path) - } - if cookie.MaxAge != -1 { - t.Errorf("Expected cookie max age -1, got %d", cookie.MaxAge) - } -} -func TestPopulateFromCookies(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - cookies := map[string]string{ - "myCookie": "myValue", - "myIntCookie": "123", - "myBoolCookie": "true", - "mySliceCookie": "val1,val2,val3", - "myIntSliceCookie": "1,2,3", - "myUUIDCookie": uuid.Must(uuid.NewV4()).String(), - "myTimeCookie": time.Now().Format(time.RFC3339), - "unsignedCookie": "unsignedValue", - } - for name, value := range cookies { - r.AddCookie(&http.Cookie{ - Name: name, - Value: value, - }) - } - - r.AddCookie(&http.Cookie{ - Name: "signedCookie", - Value: base64.URLEncoding.EncodeToString([]byte("signedValue")) + "|" + generateHMAC("signedValue"), - }) + cookieName := "myCookie" + cookieValue := "myValue" - type MyStruct struct { - StringField string `cookie:"myCookie"` - IntField int `cookie:"myIntCookie"` - BoolField bool `cookie:"myBoolCookie"` - StringSlice []string `cookie:"mySliceCookie"` - IntSlice []int `cookie:"myIntSliceCookie"` - UUIDField uuid.UUID `cookie:"myUUIDCookie"` - TimeField time.Time `cookie:"myTimeCookie"` - UnsignedCookie string `cookie:"unsignedCookie,unsigned"` - SignedCookie string `cookie:"signedCookie,signed"` - Unsupported complex64 `cookie:""` - } - - dest := &MyStruct{} - err := PopulateFromCookies(r, dest) + err := unsignedManager.Set(w, cookieName, cookieValue) if err != nil { - t.Fatalf("Unexpected error: %v", err) + t.Errorf("Unexpected error: %v", err) } - if dest.StringField != cookies["myCookie"] { - t.Errorf("Expected StringField %s, got %s", cookies["myCookie"], dest.StringField) - } + cookies := w.Result().Cookies() - expectedInt, _ := strconv.Atoi(cookies["myIntCookie"]) - if dest.IntField != expectedInt { - t.Errorf("Expected IntField %d, got %d", expectedInt, dest.IntField) + if len(cookies) != 1 { + t.Errorf("Expected 1 cookie, but got %d", len(cookies)) } - expectedBool, _ := strconv.ParseBool(cookies["myBoolCookie"]) - if dest.BoolField != expectedBool { - t.Errorf("Expected BoolField %t, got %t", expectedBool, dest.BoolField) + cookie := cookies[0] + if cookie.Name != cookieName { + t.Errorf("Expected cookie name '%s', but got '%s'", cookieName, cookie.Name) } - expectedStringSlice := strings.Split(cookies["mySliceCookie"], ",") - if !reflect.DeepEqual(dest.StringSlice, expectedStringSlice) { - t.Errorf("Expected StringSlice %v, got %v", expectedStringSlice, dest.StringSlice) + if cookie.Value != cookieValue { + t.Errorf("Expected cookie value '%s', but got '%s'", cookieValue, cookie.Value) } +} - intStrings := strings.Split(cookies["myIntSliceCookie"], ",") - expectedIntSlice := make([]int, len(intStrings)) - for i, s := range intStrings { - expectedIntSlice[i], _ = strconv.Atoi(s) - } - if !reflect.DeepEqual(dest.IntSlice, expectedIntSlice) { - t.Errorf("Expected IntSlice %v, got %v", expectedIntSlice, dest.IntSlice) - } +func TestManager_Set_Signed(t *testing.T) { + w := httptest.NewRecorder() - expectedUUID, _ := uuid.FromString(cookies["myUUIDCookie"]) - if dest.UUIDField != expectedUUID { - t.Errorf("Expected UUIDField %s, got %s", expectedUUID, dest.UUIDField) - } + cookieName := "myCookie" + expectedValue := "myValue" - expectedTime, _ := time.Parse(time.RFC3339, cookies["myTimeCookie"]) - if !dest.TimeField.Equal(expectedTime) { - t.Errorf("Expected TimeField %v, got %v", expectedTime, dest.TimeField) + err := signedManager.Set(w, cookieName, expectedValue, Options{Signed: true}) + if err != nil { + t.Errorf("Unexpected error: %v", err) } - expectedUnsignedValue := "unsignedValue" - if dest.UnsignedCookie != expectedUnsignedValue { - t.Errorf("Expected UnsignedCookie %s, got %s", expectedUnsignedValue, dest.UnsignedCookie) - } + cookies := w.Result().Cookies() - expectedSignedValue := "signedValue" - if dest.SignedCookie != expectedSignedValue { - t.Errorf("Expected SignedCookie %s, got %s", expectedSignedValue, dest.SignedCookie) + if len(cookies) != 1 { + t.Errorf("Expected 1 cookie, but got %d", len(cookies)) } -} -func TestPopulateFromCookies_InvalidIntValue(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.AddCookie(&http.Cookie{ - Name: "myIntCookie", - Value: "invalid", - }) - - type MyStruct struct { - IntField int `cookie:"myIntCookie"` + cookie := cookies[0] + if cookie.Name != cookieName { + t.Errorf("Expected cookie name '%s', but got '%s'", cookieName, cookie.Name) } - dest := &MyStruct{} - err := PopulateFromCookies(r, dest) - if err == nil { - t.Error("Expected error, got nil") + expectedCookieValue := signCookieValue(expectedValue, signedManager.signingKey) + + if cookie.Value != expectedCookieValue { + t.Errorf("Expected cookie value '%s', but got '%s'", expectedCookieValue, cookie.Value) } } -func TestPopulateFromCookies_InvalidBoolValue(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.AddCookie(&http.Cookie{ - Name: "myBoolCookie", - Value: "invalid", - }) +func TestManager_SetSigned(t *testing.T) { + w := httptest.NewRecorder() - type MyStruct struct { - BoolField bool `cookie:"myBoolCookie"` - } + cookieName := "myCookie" + expectedValue := "myValue" - dest := &MyStruct{} - err := PopulateFromCookies(r, dest) - if err == nil { - t.Error("Expected error, got nil") + err := signedManager.SetSigned(w, cookieName, expectedValue, Options{}) + if err != nil { + t.Errorf("Unexpected error: %v", err) } -} -func TestPopulateFromCookies_InvalidIntSliceValue(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.AddCookie(&http.Cookie{ - Name: "myIntSliceCookie", - Value: "1,2,invalid", - }) + cookies := w.Result().Cookies() - type MyStruct struct { - IntSlice []int `cookie:"myIntSliceCookie"` + if len(cookies) != 1 { + t.Errorf("Expected 1 cookie, but got %d", len(cookies)) } - dest := &MyStruct{} - err := PopulateFromCookies(r, dest) - if err == nil { - t.Error("Expected error, got nil") + cookie := cookies[0] + if cookie.Name != cookieName { + t.Errorf("Expected cookie name '%s', but got '%s'", cookieName, cookie.Name) } -} -func TestPopulateFromCookies_UnexpectedSliceType(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.AddCookie(&http.Cookie{ - Name: "mySliceCookie", - Value: "val1,val2,val3", - }) + expectedCookieValue := signCookieValue(expectedValue, signedManager.signingKey) - type MyStruct struct { - StringSlice []bool `cookie:"mySliceCookie"` + if cookie.Value != expectedCookieValue { + t.Errorf("Expected cookie value '%s', but got '%s'", expectedCookieValue, cookie.Value) } +} - dest := &MyStruct{} - err := PopulateFromCookies(r, dest) - if err == nil { - t.Error("Expected error, got nil") - } +func TestManager_Remove(t *testing.T) { + w := httptest.NewRecorder() + + cookieName := "myCookie" - if _, ok := err.(*ErrUnsupportedType); !ok { - t.Errorf("Expected error of type ErrUnsupportedType, got %T", err) + err := unsignedManager.Remove(w, cookieName) + if err != nil { + t.Errorf("Unexpected error: %v", err) } -} -func TestPopulateFromCookies_InvalidUUIDValue(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.AddCookie(&http.Cookie{ - Name: "myUUIDCookie", - Value: "invalid", - }) + cookies := w.Result().Cookies() - type MyStruct struct { - UUIDField uuid.UUID `cookie:"myUUIDCookie"` + if len(cookies) != 1 { + t.Errorf("Expected 1 cookie, but got %d", len(cookies)) } - dest := &MyStruct{} - err := PopulateFromCookies(r, dest) - if err == nil { - t.Error("Expected error, got nil") + cookie := cookies[0] + if cookie.Name != cookieName { + t.Errorf("Expected cookie name '%s', but got '%s'", cookieName, cookie.Name) } -} -func TestPopulateFromCookies_InvalidTimeValue(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.AddCookie(&http.Cookie{ - Name: "myTimeCookie", - Value: "invalid", - }) + if cookie.Value != "" { + t.Errorf("Expected empty cookie value, but got '%s'", cookie.Value) + } - type MyStruct struct { - TimeField time.Time `cookie:"myTimeCookie"` + if cookie.Expires.Unix() != 0 { + t.Errorf("Expected cookie to be expired, but it expires at %v", cookie.Expires) } - dest := &MyStruct{} - err := PopulateFromCookies(r, dest) - if err == nil { - t.Error("Expected error, got nil") + if cookie.MaxAge != -1 { + t.Errorf("Expected cookie to be expired, but it has MaxAge %d", cookie.MaxAge) } } -func TestPopulateFromCookies_NotFound(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - type MyStruct struct { - StringField string `cookie:"myCookie"` - } +func TestManager_RemoveWithOptions(t *testing.T) { + w := httptest.NewRecorder() - dest := &MyStruct{} - err := PopulateFromCookies(r, dest) - if err != http.ErrNoCookie { - t.Errorf("Expected error ErrNoCookie, got %v", err) - } -} + cookieName := "myCookie" + path := "/path" -func TestPopulateFromCookies_UnsupportedType(t *testing.T) { - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.AddCookie(&http.Cookie{ - Name: "myCookie", - Value: "myValue", + err := unsignedManager.Remove(w, cookieName, Options{ + Path: path, }) - type MyStruct struct { - Unsupported complex64 `cookie:"myCookie"` + if err != nil { + t.Errorf("Unexpected error: %v", err) } - dest := &MyStruct{} - err := PopulateFromCookies(r, dest) - if err == nil { - t.Error("Expected error, got nil") + cookies := w.Result().Cookies() + + if len(cookies) != 1 { + t.Errorf("Expected 1 cookie, but got %d", len(cookies)) } - if _, ok := err.(*ErrUnsupportedType); !ok { - t.Errorf("Expected error of type ErrUnsupportedType, got %T", err) + cookie := cookies[0] + if cookie.Name != cookieName { + t.Errorf("Expected cookie name '%s', but got '%s'", cookieName, cookie.Name) } - expected := "cookie: unsupported type: complex64" - if err.Error() != expected { - t.Errorf("Expected error message %s, got %s", expected, err.Error()) + if cookie.Path != path { + t.Errorf("Expected cookie path '%s', but got '%s'", path, cookie.Path) } } diff --git a/custom.go b/custom.go new file mode 100644 index 0000000..6f3ee18 --- /dev/null +++ b/custom.go @@ -0,0 +1,4 @@ +package cookie + +// CustomTypeHandler defines a function type for custom type handling. +type CustomTypeHandler func(string) (interface{}, error) diff --git a/custom_test.go b/custom_test.go new file mode 100644 index 0000000..4808645 --- /dev/null +++ b/custom_test.go @@ -0,0 +1,67 @@ +package cookie + +import ( + "errors" + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +type CustomType struct { +} + +func CustomTypeFromString(value string) (CustomType, error) { + return CustomType{}, nil +} + +func CustomTypeErrorMaker(value string) (CustomType, error) { + return CustomType{}, errors.New("just a big ol fail") +} + +func TestWithCustomHandler(t *testing.T) { + manager := NewManager( + WithCustomHandler(reflect.TypeOf(CustomType{}), func(value string) (interface{}, error) { + return CustomTypeFromString(value) + }), + ) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "test"}) + + type MyStruct struct { + Field CustomType `cookie:"cookie"` + } + + dest := &MyStruct{} + err := manager.PopulateFromCookies(req, dest) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected := CustomType{} + if dest.Field != expected { + t.Errorf("Expected value '%s', but got '%s'", expected, dest.Field) + } +} + +func TestWithCustomHandler_HandlerErr(t *testing.T) { + manager := NewManager( + WithCustomHandler(reflect.TypeOf(CustomType{}), func(value string) (interface{}, error) { + return CustomTypeErrorMaker(value) + }), + ) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "test"}) + + type MyStruct struct { + Field CustomType `cookie:"cookie"` + } + + dest := &MyStruct{} + err := manager.PopulateFromCookies(req, dest) + if err == nil { + t.Error("Expected error, but got nil") + } +} diff --git a/default.go b/default.go new file mode 100644 index 0000000..1871212 --- /dev/null +++ b/default.go @@ -0,0 +1,36 @@ +package cookie + +import "net/http" + +// DefaultManager is the default cookie manager exposed by this package. +var DefaultManager = NewManager() + +// Get retrieves an unsigned cooke value. +func Get(r *http.Request, name string) (string, error) { + return DefaultManager.Get(r, name) +} + +// GetSigned retrieves a signed cookie value. +func GetSigned(r *http.Request, name string) (string, error) { + return DefaultManager.GetSigned(r, name) +} + +// Set sets the value of a cookie. +func Set(w http.ResponseWriter, name, value string, opts ...Options) error { + return DefaultManager.Set(w, name, value, opts...) +} + +// SetSigned sets a signed value of a cookie. +func SetSigned(w http.ResponseWriter, name, value string, opts ...Options) error { + return DefaultManager.SetSigned(w, name, value, opts...) +} + +// Remove removes a cookie from the response. +func Remove(w http.ResponseWriter, name string) error { + return DefaultManager.Remove(w, name) +} + +// PopulateFromCookies populates a struct with cookie values. +func PopulateFromCookies(r *http.Request, dest interface{}) error { + return DefaultManager.PopulateFromCookies(r, dest) +} diff --git a/default_test.go b/default_test.go new file mode 100644 index 0000000..8c436c4 --- /dev/null +++ b/default_test.go @@ -0,0 +1,182 @@ +package cookie + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestGet(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + + req.AddCookie(&http.Cookie{Name: "cookieName", Value: "expectedValue"}) + + value, err := Get(req, "cookieName") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expectedValue := "expectedValue" + if value != expectedValue { + t.Errorf("Expected value %s, but got %s", expectedValue, value) + } +} + +func TestGetSigned(t *testing.T) { + DefaultManager = signedManager + + r, _ := http.NewRequest(http.MethodGet, "/", nil) + + cookieName := "myCookie" + expectedValue := "myValue" + + cookieValue := signCookieValue(expectedValue, signedManager.signingKey) + + r.AddCookie(&http.Cookie{Name: cookieName, Value: cookieValue}) + + value, err := GetSigned(r, cookieName) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if value != expectedValue { + t.Errorf("Expected value '%s', but got '%s'", expectedValue, value) + } +} + +func TestSet(t *testing.T) { + _, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + + err = Set(rr, "cookieName", "cookieValue") + + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + cookie := rr.Result().Cookies()[0] + if cookie.Name != "cookieName" { + t.Errorf("Expected cookie name %s, but got %s", "cookieName", cookie.Name) + } + if cookie.Value != "cookieValue" { + t.Errorf("Expected cookie value %s, but got %s", "cookieValue", cookie.Value) + } +} + +func TestSet_Signed(t *testing.T) { + DefaultManager = signedManager + + w := httptest.NewRecorder() + + cookieName := "myCookie" + expectedValue := "myValue" + + err := Set(w, cookieName, expectedValue, Options{Signed: true}) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + cookies := w.Result().Cookies() + + if len(cookies) != 1 { + t.Errorf("Expected 1 cookie, but got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != cookieName { + t.Errorf("Expected cookie name '%s', but got '%s'", cookieName, cookie.Name) + } + + expectedCookieValue := signCookieValue(expectedValue, signedManager.signingKey) + if cookie.Value != expectedCookieValue { + t.Errorf("Expected cookie value '%s', but got '%s'", expectedCookieValue, cookie.Value) + } +} + +func TestSetSigned(t *testing.T) { + DefaultManager = signedManager + + w := httptest.NewRecorder() + + cookieName := "myCookie" + expectedValue := "myValue" + + err := SetSigned(w, cookieName, expectedValue, Options{}) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + cookies := w.Result().Cookies() + + if len(cookies) != 1 { + t.Errorf("Expected 1 cookie, but got %d", len(cookies)) + } + + cookie := cookies[0] + if cookie.Name != cookieName { + t.Errorf("Expected cookie name '%s', but got '%s'", cookieName, cookie.Name) + } + + expectedCookieValue := signCookieValue(expectedValue, signedManager.signingKey) + if cookie.Value != expectedCookieValue { + t.Errorf("Expected cookie value '%s', but got '%s'", expectedCookieValue, cookie.Value) + } +} + +func TestRemove(t *testing.T) { + // Create a mock request + _, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + + // Create a mock response recorder + rr := httptest.NewRecorder() + + // Call the Remove function + err = Remove(rr, "cookieName") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Check if the cookie was set in the response + cookie := rr.Result().Cookies()[0] + if cookie.Name != "cookieName" { + t.Errorf("Expected cookie name %s, but got %s", "cookieName", cookie.Name) + } + if cookie.Value != "" { + t.Errorf("Expected cookie value %s, but got %s", "", cookie.Value) + } +} + +func TestPopulateFromCookies(t *testing.T) { + DefaultManager = signedManager + value := "test" + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "cookie1", Value: value}) + + type MyStruct struct { + Default string `cookie:"cookie1"` + } + + dest := &MyStruct{} + err := PopulateFromCookies(req, dest) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected := &MyStruct{ + Default: value, + } + if dest.Default != expected.Default { + t.Errorf("Expected value '%s', but got '%s'", expected.Default, dest.Default) + } +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..8d17557 --- /dev/null +++ b/errors.go @@ -0,0 +1,25 @@ +package cookie + +import ( + "errors" + "reflect" +) + +// ErrInvalidSignedCookieFormat is returned when the format of a signed cookie is invalid. +var ErrInvalidSignedCookieFormat = errors.New("invalid signed cookie format") + +// ErrInvalidCookieSignature is returned when the signature of a signed cookie is invalid. +var ErrInvalidCookieSignature = errors.New("invalid cookie signature") + +// ErrNonNilPointerRequired is returned when the destination parameter must be a non-nil pointer. +var ErrNonNilPointerRequired = errors.New("dest must be a non-nil pointer") + +// ErrUnsupportedType is returned when a field type is not supported. +type ErrUnsupportedType struct { + Type reflect.Type +} + +// Error returns the error message. +func (e *ErrUnsupportedType) Error() string { + return "cookie: unsupported type: " + e.Type.String() +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..b7d3983 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,15 @@ +package cookie + +import ( + "reflect" + "testing" +) + +func TestErrUnsupportedType_Error(t *testing.T) { + err := &ErrUnsupportedType{Type: reflect.TypeOf(0)} + expected := "cookie: unsupported type: int" + + if err.Error() != expected { + t.Errorf("Expected error message '%s', but got '%s'", expected, err.Error()) + } +} diff --git a/go.mod b/go.mod index aeb11d3..2826916 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,3 @@ module github.com/syntaqx/cookie -go 1.22.3 - -require github.com/gofrs/uuid/v5 v5.2.0 +go 1.22 diff --git a/go.sum b/go.sum index 831892f..e69de29 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +0,0 @@ -github.com/gofrs/uuid/v5 v5.2.0 h1:qw1GMx6/y8vhVsx626ImfKMuS5CvJmhIKKtuyvfajMM= -github.com/gofrs/uuid/v5 v5.2.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= diff --git a/populate.go b/populate.go new file mode 100644 index 0000000..ca6e891 --- /dev/null +++ b/populate.go @@ -0,0 +1,138 @@ +package cookie + +import ( + "net/http" + "reflect" + "strconv" + "strings" + "time" +) + +// PopulateFromCookies populates a struct with cookie values. +func (m *Manager) PopulateFromCookies(r *http.Request, dest interface{}) error { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr || v.IsNil() { + return ErrNonNilPointerRequired + } + v = v.Elem() + + t := v.Type() + for i := 0; i < v.NumField(); i++ { + field := t.Field(i) + tag := field.Tag.Get("cookie") + if tag == "" { + continue + } + + parts := strings.Split(tag, ",") + name := parts[0] + signed := false + unsigned := false + omitempty := false + + for _, part := range parts[1:] { + if part == "signed" { + signed = true + } else if part == "unsigned" { + unsigned = true + } else if part == "omitempty" { + omitempty = true + } + } + + var value string + var err error + if signed && !unsigned { + value, err = m.GetSigned(r, name) + } else { + value, err = m.Get(r, name) + } + if err != nil { + if err == http.ErrNoCookie && omitempty { + continue + } + return err + } + + fieldVal := v.Field(i) + + // TODO: Is this necessary? How can I test it? + // if !fieldVal.CanSet() { + // continue + // } + + err = m.setFieldValue(fieldVal, value) + if err != nil { + return err + } + } + return nil +} + +// setFieldValue sets the value of a struct field based on its type. +func (m *Manager) setFieldValue(fieldVal reflect.Value, value string) error { + if handler, ok := m.customHandlers[fieldVal.Type()]; ok { + customValue, err := handler(value) + if err != nil { + return err + } + fieldVal.Set(reflect.ValueOf(customValue)) + return nil + } + + switch fieldVal.Kind() { + case reflect.Bool: + boolVal, err := strconv.ParseBool(value) + if err != nil { + return err + } + fieldVal.SetBool(boolVal) + case reflect.String: + fieldVal.SetString(value) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + intVal, err := strconv.ParseInt(value, 10, fieldVal.Type().Bits()) + if err != nil { + return err + } + fieldVal.SetInt(intVal) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + uintVal, err := strconv.ParseUint(value, 10, fieldVal.Type().Bits()) + if err != nil { + return err + } + fieldVal.SetUint(uintVal) + case reflect.Float32, reflect.Float64: + floatVal, err := strconv.ParseFloat(value, fieldVal.Type().Bits()) + if err != nil { + return err + } + fieldVal.SetFloat(floatVal) + case reflect.Slice: + switch fieldVal.Type().Elem().Kind() { + case reflect.String: + fieldVal.Set(reflect.ValueOf(strings.Split(value, ","))) + case reflect.Int: + strSlice := strings.Split(value, ",") + intSlice := make([]int, len(strSlice)) + for i, str := range strSlice { + intVal, err := strconv.Atoi(str) + if err != nil { + return err + } + intSlice[i] = intVal + } + fieldVal.Set(reflect.ValueOf(intSlice)) + } + case reflect.Struct: + if fieldVal.Type() == reflect.TypeOf(time.Time{}) { + timeVal, err := time.Parse(time.RFC3339, value) + if err != nil { + return err + } + fieldVal.Set(reflect.ValueOf(timeVal)) + } + default: + return &ErrUnsupportedType{Type: fieldVal.Type()} + } + return nil +} diff --git a/populate_test.go b/populate_test.go new file mode 100644 index 0000000..056386c --- /dev/null +++ b/populate_test.go @@ -0,0 +1,235 @@ +package cookie + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" +) + +func TestManager_PopulateFromCookies(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + + value := "test" + signedValue := signCookieValue(value, signedManager.signingKey) + + cookies := []*http.Cookie{ + {Name: "cookie1", Value: value}, + {Name: "cookie2", Value: value}, + {Name: "cookie3", Value: string(signedValue)}, + {Name: "cookie4", Value: "true"}, + {Name: "cookie5", Value: "123"}, + {Name: "cookie6", Value: "123.45"}, + {Name: "cookie7", Value: "a,b,c"}, + {Name: "cookie8", Value: "1,2,3"}, + {Name: "cookie9", Value: "2021-01-02T15:04:05Z"}, + } + + for _, cookie := range cookies { + req.AddCookie(cookie) + } + + type MyStruct struct { + UntaggedField string + Default string `cookie:"cookie1"` + Unsigned string `cookie:"cookie2,unsigned"` + Signed string `cookie:"cookie3,signed"` + Boolean bool `cookie:"cookie4"` + Integer int `cookie:"cookie5"` + UInteger uint `cookie:"cookie5"` + Float float64 `cookie:"cookie6"` + StringSlice []string `cookie:"cookie7"` + IntSlice []int `cookie:"cookie8"` + Timestamp time.Time `cookie:"cookie9"` + NotSet string `cookie:"cookie10,omitempty"` + } + + dest := &MyStruct{} + err := signedManager.PopulateFromCookies(req, dest) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected := &MyStruct{ + Default: value, + Unsigned: value, + Signed: value, + Boolean: true, + Integer: 123, + UInteger: 123, + Float: 123.45, + StringSlice: []string{"a", "b", "c"}, + IntSlice: []int{1, 2, 3}, + Timestamp: time.Date(2021, 1, 2, 15, 4, 5, 0, time.UTC), + NotSet: "", + } + if !reflect.DeepEqual(dest, expected) { + t.Errorf("Unexpected result. Got: %v, want: %v", dest, expected) + } +} + +func TestPopulateFromCookies_NonNilPointerRequired(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + + var dest *struct{} + err := unsignedManager.PopulateFromCookies(req, dest) + if err != ErrNonNilPointerRequired { + t.Errorf("Unexpected error: %v", err) + } +} + +func TestPopulateFromCookies_ErrNoCookie(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + + type MyStruct struct { + Field string `cookie:"cookie"` + } + + dest := &MyStruct{} + err := unsignedManager.PopulateFromCookies(req, dest) + if err != http.ErrNoCookie { + t.Errorf("Unexpected error: %v", err) + } +} + +func TestPopulateFromCookies_ErrUnsupportedType(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + + req.AddCookie(&http.Cookie{Name: "cookie", Value: "test"}) + + type MyStruct struct { + Field complex128 `cookie:"cookie"` + } + + dest := &MyStruct{} + err := unsignedManager.PopulateFromCookies(req, dest) + if err == nil { + t.Error("Expected error, but got nil") + } + + expectedError := "cookie: unsupported type: complex128" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%v'", expectedError, err) + } +} + +func TestPopulateFromCookies_InvalidBoolean(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "invalid"}) + + type MyStruct struct { + Field bool `cookie:"cookie"` + } + + dest := &MyStruct{} + err := unsignedManager.PopulateFromCookies(req, dest) + if err == nil { + t.Error("Expected error, but got nil") + } + + expectedError := "strconv.ParseBool: parsing \"invalid\": invalid syntax" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%v'", expectedError, err) + } +} + +func TestPopulateFromCookies_InvalidInteger(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "invalid"}) + + type MyStruct struct { + Field int `cookie:"cookie"` + } + + dest := &MyStruct{} + err := unsignedManager.PopulateFromCookies(req, dest) + if err == nil { + t.Error("Expected error, but got nil") + } + + expectedError := "strconv.ParseInt: parsing \"invalid\": invalid syntax" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%v'", expectedError, err) + } +} + +func TestPopulateFromCookies_InvalidUnsignedInteger(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "-1"}) + + type MyStruct struct { + Field uint `cookie:"cookie"` + } + + dest := &MyStruct{} + err := unsignedManager.PopulateFromCookies(req, dest) + if err == nil { + t.Error("Expected error, but got nil") + } + + expectedError := "strconv.ParseUint: parsing \"-1\": invalid syntax" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%v'", expectedError, err) + } +} + +func TestPopulateFromCookies_InvalidFloat(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "invalid"}) + + type MyStruct struct { + Field float64 `cookie:"cookie"` + } + + dest := &MyStruct{} + err := unsignedManager.PopulateFromCookies(req, dest) + if err == nil { + t.Error("Expected error, but got nil") + } + + expectedError := "strconv.ParseFloat: parsing \"invalid\": invalid syntax" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%v'", expectedError, err) + } +} + +func TestPopulateFromCookies_InvalidIntSlice(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "invalid"}) + + type MyStruct struct { + Field []int `cookie:"cookie"` + } + + dest := &MyStruct{} + err := unsignedManager.PopulateFromCookies(req, dest) + if err == nil { + t.Error("Expected error, but got nil") + } + + expectedError := "strconv.Atoi: parsing \"invalid\": invalid syntax" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%v'", expectedError, err) + } +} + +func TestPopulateFromCookies_InvalidTimestamp(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.AddCookie(&http.Cookie{Name: "cookie", Value: "invalid"}) + + type MyStruct struct { + Field time.Time `cookie:"cookie"` + } + + dest := &MyStruct{} + err := unsignedManager.PopulateFromCookies(req, dest) + if err == nil { + t.Error("Expected error, but got nil") + } + + expectedError := "parsing time \"invalid\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"invalid\" as \"2006\"" + if err.Error() != expectedError { + t.Errorf("Expected error '%s', but got '%v'", expectedError, err) + } +} diff --git a/signed.go b/signed.go new file mode 100644 index 0000000..bdd0183 --- /dev/null +++ b/signed.go @@ -0,0 +1,27 @@ +package cookie + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" +) + +// sign generates a HMAC signature for the given data using the provided key. +func sign(data, key []byte) []byte { + h := hmac.New(sha256.New, key) + h.Write(data) + return h.Sum(nil) +} + +// verify checks the HMAC signature of the given data using the provided key. +func verify(data, signature, key []byte) bool { + expectedSignature := sign(data, key) + return hmac.Equal(expectedSignature, signature) +} + +// signCookieValue signs a cookie value using the provided key. +func signCookieValue(value string, key []byte) string { + data := base64.URLEncoding.EncodeToString([]byte(value)) + signature := base64.URLEncoding.EncodeToString(sign([]byte(data), key)) + return data + "|" + signature +} diff --git a/signed_test.go b/signed_test.go new file mode 100644 index 0000000..5dce02d --- /dev/null +++ b/signed_test.go @@ -0,0 +1,26 @@ +package cookie + +import ( + "crypto/hmac" + "testing" +) + +func TestSignVerify(t *testing.T) { + data := []byte("example data") + key := []byte("example key") + + expectedSignature := []byte{ + 143, 44, 153, 63, 34, 126, 71, 71, 60, 146, 137, 245, 195, 249, 153, 4, + 171, 247, 130, 233, 162, 23, 163, 57, 160, 123, 76, 145, 124, 34, 222, 55, + } + + signature := sign(data, key) + + if !hmac.Equal(signature, expectedSignature) { + t.Error("sign failed to generate the expected signature") + } + + if !verify(data, signature, key) { + t.Error("verify failed to validate the signature") + } +}