-
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathcsrf.go
249 lines (209 loc) · 6.97 KB
/
csrf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
package csrf
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"errors"
"net/http"
"net/url"
"strings"
"github.com/gobuffalo/buffalo"
"github.com/gobuffalo/envy"
)
const (
	// tokenLength is the size, in bytes, of the raw per-session CSRF token.
	// Tokens handed to clients are twice this length before base64 encoding
	// (one-time pad followed by the masked token).
	tokenLength int = 32
	// tokenKey is the key under which the raw token is kept in the session.
	tokenKey string = "authenticity_token"
)

var (
	// fieldName is the form field (and context key) that carries the masked
	// token.
	fieldName = tokenKey
	// headerName is the HTTP request header inspected for the masked token.
	headerName = "X-CSRF-Token"
	// safeMethods are the "safe" methods defined by RFC 7231 section 4.2.1.
	// Requests using them skip token verification. Note that the idempotent
	// methods of section 4.2.2 additionally include PUT and DELETE, which are
	// state-changing and must NOT be exempted here.
	safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"}
)

var (
	// ErrNoReferer is returned when an HTTPS request provides an empty (or
	// unparseable) Referer header.
	ErrNoReferer = errors.New("referer not supplied")
	// ErrBadReferer is returned when the scheme & host of the request do not
	// match the supplied Referer header.
	ErrBadReferer = errors.New("referer invalid")
	// ErrNoToken is returned if no CSRF token is supplied in the request.
	ErrNoToken = errors.New("CSRF token not found in request")
	// ErrBadToken is returned if the CSRF token in the request does not match
	// the token in the session, or is otherwise malformed.
	ErrBadToken = errors.New("CSRF token invalid")
)
// New enables CSRF protection on routes using this middleware.
// This middleware is adapted from gorilla/csrf.
//
// Outside of test mode the returned handler:
//   - loads the raw token from the session, generating and storing a new
//     random one when it is missing, not a []byte, or of the wrong length;
//   - stores a freshly masked copy of that token in the context under
//     fieldName so templates can embed it;
//   - for methods not in safeMethods, enforces a Referer check on HTTPS
//     requests and then requires a matching token in the request header,
//     form, or multipart form, responding 403 otherwise.
//
// When GO_ENV is "test" (read once, when New is called) all checks are
// skipped and the context value tokenKey is set to the literal "test".
var New = func(next buffalo.Handler) buffalo.Handler {
	// don't run the actual middleware in test mode
	if envy.Get("GO_ENV", "development") == "test" {
		return func(c buffalo.Context) error {
			c.Logger().Warn("csrf middleware is running in test mode")
			c.Set(tokenKey, "test")
			return next(c)
		}
	}

	return func(c buffalo.Context) error {
		req := c.Request()

		// Checked type assertion: a session value of some other type must not
		// panic the handler; it is treated as missing and replaced.
		realToken, ok := c.Session().Get(tokenKey).([]byte)
		if !ok || len(realToken) != tokenLength {
			fresh, err := generateRandomBytes(tokenLength)
			if err != nil {
				return err
			}
			realToken = fresh
			// Save the new real token in session
			c.Session().Set(tokenKey, realToken)
		}

		// Set masked token in context data, to be available in template
		c.Set(fieldName, mask(realToken, req))

		// Only methods outside the RFC 7231 section 4.2.1 "safe" set can
		// change state, so only they require inspection.
		if contains(safeMethods, req.Method) {
			return next(c)
		}

		// Enforce an origin check for HTTPS connections. As per the Django CSRF
		// implementation (https://goo.gl/vKA7GE) the Referer header is almost
		// always present for same-domain HTTPS requests.
		if origin, isHTTPS := requestOrigin(req); isHTTPS {
			// Fetch the Referer value; reject if it's empty or fails to parse.
			referer, err := url.Parse(req.Referer())
			if err != nil || referer.String() == "" {
				return c.Error(http.StatusForbidden, ErrNoReferer)
			}
			if !sameOrigin(origin, referer) {
				return c.Error(http.StatusForbidden, ErrBadReferer)
			}
		}

		// Retrieve the combined (pad + masked) token and unmask it.
		requestToken := unmask(requestCSRFToken(req))
		if requestToken == nil {
			return c.Error(http.StatusForbidden, ErrNoToken)
		}
		if !compareTokens(requestToken, realToken) {
			return c.Error(http.StatusForbidden, ErrBadToken)
		}
		return next(c)
	}
}

// requestOrigin returns the scheme+host origin of req and whether req arrived
// over HTTPS.
//
// For server requests net/http fills in only Path/RawQuery of req.URL, so
// req.URL.Scheme is normally empty and a Scheme-only test never triggers the
// Referer check. A direct TLS connection (req.TLS != nil) is therefore also
// treated as HTTPS, with the Host header as the origin host. Requests whose
// URL is already absolute https keep using req.URL as before.
// NOTE(review): behind a TLS-terminating proxy req.TLS is nil, so the check
// still does not run there (unchanged from prior behavior).
func requestOrigin(req *http.Request) (*url.URL, bool) {
	if req.URL.Scheme == "https" {
		return req.URL, true
	}
	if req.TLS != nil {
		return &url.URL{Scheme: "https", Host: req.Host}, true
	}
	return nil, false
}
// generateRandomBytes fills and returns a new n-byte slice from crypto/rand.
// A non-nil error is returned (with a nil slice) when the system's secure
// random source cannot supply n bytes.
func generateRandomBytes(n int) ([]byte, error) {
	buf := make([]byte, n)
	// rand.Read returns a nil error only when len(buf) bytes were written.
	if _, readErr := rand.Read(buf); readErr != nil {
		return nil, readErr
	}
	return buf, nil
}
// sameOrigin reports whether a and b share an origin: both the scheme and the
// host (url.URL.Host includes any explicit port) must be identical.
func sameOrigin(a, b *url.URL) bool {
	if a.Scheme != b.Scheme {
		return false
	}
	return a.Host == b.Host
}
// contains reports whether s equals one of vals, ignoring case (Unicode case
// folding via strings.EqualFold) - e.g. whether an HTTP method appears in the
// list of safe methods.
//
// The whole value must match. A substring test would let a method such as
// "TARGET" (which contains "get") or "OPTIONSX" be treated as safe and skip
// CSRF verification. Case-insensitivity is kept for backward compatibility,
// although RFC 7231 section 4.1 defines method names as case-sensitive.
func contains(vals []string, s string) bool {
	for _, v := range vals {
		if strings.EqualFold(v, s) {
			return true
		}
	}
	return false
}
// compareTokens reports, in constant time, whether the unmasked token taken
// from the request equals the real token stored in the session.
func compareTokens(a, b []byte) bool {
	// Guard lengths explicitly: subtle.ConstantTimeCompare did not check for
	// equal lengths in Go versions prior to 1.3.
	sameLength := len(a) == len(b)
	return sameLength && subtle.ConstantTimeCompare(a, b) == 1
}
// xorToken returns a new slice whose i-th byte is a[i] ^ b[i]. Its length is
// that of the shorter input. XOR'ing the base token with a one-time pad yields
// a masked token, and XOR'ing that masked token with the same pad recovers
// the base token, which is how per-request CSRF tokens are produced and
// verified.
func xorToken(a, b []byte) []byte {
	size := len(a)
	if len(b) < size {
		size = len(b)
	}
	out := make([]byte, size)
	for i := range out {
		out[i] = a[i] ^ b[i]
	}
	return out
}
// mask produces a per-request encoding of realToken to mitigate the BREACH
// attack (http://breachattack.com/#mitigations).
//
// A fresh random one-time pad of tokenLength bytes is drawn and realToken is
// XOR'ed with it. The pad followed by the masked bytes is returned as a
// base64 (StdEncoding) string. Each response carries a different string while
// the same session token keeps validating, so multiple browser tabs/windows
// keep working. The request argument is not used. If the random source
// fails, the empty string is returned.
func mask(realToken []byte, r *http.Request) string {
	pad, err := generateRandomBytes(tokenLength)
	if err != nil {
		return ""
	}
	masked := xorToken(pad, realToken)
	// Prefix the pad so the token can be unmasked from the next request.
	payload := make([]byte, 0, len(pad)+len(masked))
	payload = append(payload, pad...)
	payload = append(payload, masked...)
	return base64.StdEncoding.EncodeToString(payload)
}
// unmask reverses mask on a decoded client token. It splits the input into the
// one-time pad (first tokenLength bytes) and the masked half, then XORs them
// to recover the real token for comparison. Any input that is not exactly
// 2*tokenLength bytes long, including nil, yields nil.
func unmask(issued []byte) []byte {
	if len(issued) == 2*tokenLength {
		pad, masked := issued[:tokenLength], issued[tokenLength:]
		return xorToken(pad, masked)
	}
	return nil
}
// requestCSRFToken extracts the base64-encoded (pad + masked) CSRF token from
// r and returns its decoded bytes. Sources are tried in priority order, and
// the first non-empty value wins:
//   - the headerName HTTP header
//   - the fieldName POST form value
//   - the fieldName value of an already-parsed multipart form
//
// A value that is not valid base64 yields nil, which fails verification
// upstream.
func requestCSRFToken(r *http.Request) []byte {
	issued := r.Header.Get(headerName)
	// PostFormValue parses the request body, so it is consulted only when
	// the header is absent.
	if issued == "" {
		issued = r.PostFormValue(fieldName)
	}
	if issued == "" {
		issued = multipartFormValue(r, fieldName)
	}

	if decoded, err := base64.StdEncoding.DecodeString(issued); err == nil {
		return decoded
	}
	return nil
}

// multipartFormValue returns the first value stored under key in r's parsed
// multipart form, or "" when no multipart form is set or the key is absent.
func multipartFormValue(r *http.Request, key string) string {
	if r.MultipartForm == nil {
		return ""
	}
	if vals := r.MultipartForm.Value[key]; len(vals) > 0 {
		return vals[0]
	}
	return ""
}