github.com/bscott/buffalo@v0.11.1/middleware/csrf/csrf.go (about) 1 package csrf 2 3 import ( 4 "crypto/rand" 5 "crypto/subtle" 6 "encoding/base64" 7 "errors" 8 "net/http" 9 "net/url" 10 "strings" 11 12 "github.com/gobuffalo/buffalo" 13 "github.com/gobuffalo/envy" 14 "github.com/markbates/going/defaults" 15 ) 16 17 const ( 18 // CSRF token length in bytes. 19 tokenLength int = 32 20 tokenKey string = "authenticity_token" 21 ) 22 23 var ( 24 // The name value used in form fields. 25 fieldName = tokenKey 26 27 // The HTTP request header to inspect 28 headerName = "X-CSRF-Token" 29 30 // Idempotent (safe) methods as defined by RFC7231 section 4.2.2. 31 safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"} 32 htmlTypes = []string{"html", "form", "plain", "*/*"} 33 ) 34 35 var ( 36 // ErrNoReferer is returned when a HTTPS request provides an empty Referer 37 // header. 38 ErrNoReferer = errors.New("referer not supplied") 39 // ErrBadReferer is returned when the scheme & host in the URL do not match 40 // the supplied Referer header. 41 ErrBadReferer = errors.New("referer invalid") 42 // ErrNoToken is returned if no CSRF token is supplied in the request. 43 ErrNoToken = errors.New("CSRF token not found in request") 44 // ErrBadToken is returned if the CSRF token in the request does not match 45 // the token in the session, or is otherwise malformed. 46 ErrBadToken = errors.New("CSRF token invalid") 47 ) 48 49 // New enable CSRF protection on routes using this middleware. 50 // This middleware is adapted from gorilla/csrf 51 var New = func(next buffalo.Handler) buffalo.Handler { 52 return func(c buffalo.Context) error { 53 // don't run in test mode 54 if envy.Get("GO_ENV", "development") == "test" { 55 c.Set(tokenKey, "test") 56 return next(c) 57 } 58 59 req := c.Request() 60 61 ct := defaults.String(req.Header.Get("Content-Type"), req.Header.Get("Accept")) 62 // ignore non-html requests 63 if ct != "" && !contains(htmlTypes, ct) { 64 return next(c) 65 } 66 67 var realToken []byte 68 rawRealToken := c.Session().Get(tokenKey) 69 70 if rawRealToken == nil || len(rawRealToken.([]byte)) != tokenLength { 71 // If the token is missing, or the length if the token is wrong, 72 // generate a new token. 73 realToken, err := generateRandomBytes(tokenLength) 74 if err != nil { 75 return err 76 } 77 // Save the new real token in session 78 c.Session().Set(tokenKey, realToken) 79 } else { 80 realToken = rawRealToken.([]byte) 81 } 82 83 // Set masked token in context data, to be available in template 84 c.Set(fieldName, mask(realToken, req)) 85 86 // HTTP methods not defined as idempotent ("safe") under RFC7231 require 87 // inspection. 88 if !contains(safeMethods, req.Method) { 89 // Enforce an origin check for HTTPS connections. As per the Django CSRF 90 // implementation (https://goo.gl/vKA7GE) the Referer header is almost 91 // always present for same-domain HTTP requests. 92 if req.URL.Scheme == "https" { 93 // Fetch the Referer value. Call the error handler if it's empty or 94 // otherwise fails to parse. 95 referer, err := url.Parse(req.Referer()) 96 if err != nil || referer.String() == "" { 97 return ErrNoReferer 98 } 99 100 if !sameOrigin(req.URL, referer) { 101 return ErrBadReferer 102 } 103 } 104 105 // Retrieve the combined token (pad + masked) token and unmask it. 106 requestToken := unmask(requestCSRFToken(req)) 107 108 // Missing token 109 if requestToken == nil { 110 return ErrNoToken 111 } 112 113 // Compare tokens 114 if !compareTokens(requestToken, realToken) { 115 return ErrBadToken 116 } 117 } 118 119 return next(c) 120 } 121 } 122 123 // generateRandomBytes returns securely generated random bytes. 124 // It will return an error if the system's secure random number generator 125 // fails to function correctly. 126 func generateRandomBytes(n int) ([]byte, error) { 127 b := make([]byte, n) 128 _, err := rand.Read(b) 129 // err == nil only if len(b) == n 130 if err != nil { 131 return nil, err 132 } 133 134 return b, nil 135 } 136 137 // sameOrigin returns true if URLs a and b share the same origin. The same 138 // origin is defined as host (which includes the port) and scheme. 139 func sameOrigin(a, b *url.URL) bool { 140 return (a.Scheme == b.Scheme && a.Host == b.Host) 141 } 142 143 // contains is a helper function to check if a string exists in a slice - e.g. 144 // whether a HTTP method exists in a list of safe methods. 145 func contains(vals []string, s string) bool { 146 s = strings.ToLower(s) 147 for _, v := range vals { 148 if strings.Contains(s, strings.ToLower(v)) { 149 return true 150 } 151 } 152 153 return false 154 } 155 156 // compare securely (constant-time) compares the unmasked token from the request 157 // against the real token from the session. 158 func compareTokens(a, b []byte) bool { 159 // This is required as subtle.ConstantTimeCompare does not check for equal 160 // lengths in Go versions prior to 1.3. 161 if len(a) != len(b) { 162 return false 163 } 164 165 return subtle.ConstantTimeCompare(a, b) == 1 166 } 167 168 // xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It 169 // will return a masked token if the base token is XOR'ed with a one-time-pad. 170 // An unmasked token will be returned if a masked token is XOR'ed with the 171 // one-time-pad used to mask it. 172 func xorToken(a, b []byte) []byte { 173 n := len(a) 174 if len(b) < n { 175 n = len(b) 176 } 177 178 res := make([]byte, n) 179 180 for i := 0; i < n; i++ { 181 res[i] = a[i] ^ b[i] 182 } 183 184 return res 185 } 186 187 // mask returns a unique-per-request token to mitigate the BREACH attack 188 // as per http://breachattack.com/#mitigations 189 // 190 // The token is generated by XOR'ing a one-time-pad and the base (session) CSRF 191 // token and returning them together as a 64-byte slice. This effectively 192 // randomises the token on a per-request basis without breaking multiple browser 193 // tabs/windows. 194 func mask(realToken []byte, r *http.Request) string { 195 otp, err := generateRandomBytes(tokenLength) 196 if err != nil { 197 return "" 198 } 199 200 // XOR the OTP with the real token to generate a masked token. Append the 201 // OTP to the front of the masked token to allow unmasking in the subsequent 202 // request. 203 return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...)) 204 } 205 206 // unmask splits the issued token (one-time-pad + masked token) and returns the 207 // unmasked request token for comparison. 208 func unmask(issued []byte) []byte { 209 // Issued tokens are always masked and combined with the pad. 210 if len(issued) != tokenLength*2 { 211 return nil 212 } 213 214 // We now know the length of the byte slice. 215 otp := issued[tokenLength:] 216 masked := issued[:tokenLength] 217 218 // Unmask the token by XOR'ing it against the OTP used to mask it. 219 return xorToken(otp, masked) 220 } 221 222 // requestCSRFToken gets the CSRF token from either: 223 // - a HTTP header 224 // - a form value 225 // - a multipart form value 226 func requestCSRFToken(r *http.Request) []byte { 227 // 1. Check the HTTP header first. 228 issued := r.Header.Get(headerName) 229 230 // 2. Fall back to the POST (form) value. 231 if issued == "" { 232 issued = r.PostFormValue(fieldName) 233 } 234 235 // 3. Finally, fall back to the multipart form (if set). 236 if issued == "" && r.MultipartForm != nil { 237 vals := r.MultipartForm.Value[fieldName] 238 239 if len(vals) > 0 { 240 issued = vals[0] 241 } 242 } 243 244 // Decode the "issued" (pad + masked) token sent in the request. Return a 245 // nil byte slice on a decoding error (this will fail upstream). 246 decoded, err := base64.StdEncoding.DecodeString(issued) 247 if err != nil { 248 return nil 249 } 250 251 return decoded 252 }