github.com/mayra-cabrera/buffalo@v0.9.4-0.20170814145312-66d2e7772f11/middleware/csrf/csrf.go (about) 1 package csrf 2 3 import ( 4 "crypto/rand" 5 "crypto/subtle" 6 "encoding/base64" 7 "errors" 8 "net/http" 9 "net/url" 10 "strings" 11 12 "github.com/gobuffalo/buffalo" 13 ) 14 15 const ( 16 // CSRF token length in bytes. 17 tokenLength int = 32 18 tokenKey string = "authenticity_token" 19 ) 20 21 var ( 22 // The name value used in form fields. 23 fieldName = tokenKey 24 25 // The HTTP request header to inspect 26 headerName = "X-CSRF-Token" 27 28 // Idempotent (safe) methods as defined by RFC7231 section 4.2.2. 29 safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"} 30 htmlTypes = []string{"html", "form", "plain"} 31 ) 32 33 var ( 34 // ErrNoReferer is returned when a HTTPS request provides an empty Referer 35 // header. 36 ErrNoReferer = errors.New("referer not supplied") 37 // ErrBadReferer is returned when the scheme & host in the URL do not match 38 // the supplied Referer header. 39 ErrBadReferer = errors.New("referer invalid") 40 // ErrNoToken is returned if no CSRF token is supplied in the request. 41 ErrNoToken = errors.New("CSRF token not found in request") 42 // ErrBadToken is returned if the CSRF token in the request does not match 43 // the token in the session, or is otherwise malformed. 44 ErrBadToken = errors.New("CSRF token invalid") 45 ) 46 47 // Middleware enable CSRF protection on routes using this middleware. 
48 // This middleware is adapted from gorilla/csrf 49 var Middleware = func(next buffalo.Handler) buffalo.Handler { 50 return func(c buffalo.Context) error { 51 req := c.Request() 52 53 ct := req.Header.Get("Content-Type") 54 // ignore non-html requests 55 if ct != "" && !contains(htmlTypes, ct) { 56 return next(c) 57 } 58 59 var realToken []byte 60 rawRealToken := c.Session().Get(tokenKey) 61 62 if rawRealToken == nil || len(rawRealToken.([]byte)) != tokenLength { 63 // If the token is missing, or the length if the token is wrong, 64 // generate a new token. 65 realToken, err := generateRandomBytes(tokenLength) 66 if err != nil { 67 return err 68 } 69 // Save the new real token in session 70 c.Session().Set(tokenKey, realToken) 71 } else { 72 realToken = rawRealToken.([]byte) 73 } 74 75 // Set masked token in context data, to be available in template 76 c.Set(fieldName, mask(realToken, req)) 77 78 // HTTP methods not defined as idempotent ("safe") under RFC7231 require 79 // inspection. 80 if !contains(safeMethods, req.Method) { 81 // Enforce an origin check for HTTPS connections. As per the Django CSRF 82 // implementation (https://goo.gl/vKA7GE) the Referer header is almost 83 // always present for same-domain HTTP requests. 84 if req.URL.Scheme == "https" { 85 // Fetch the Referer value. Call the error handler if it's empty or 86 // otherwise fails to parse. 87 referer, err := url.Parse(req.Referer()) 88 if err != nil || referer.String() == "" { 89 return ErrNoReferer 90 } 91 92 if !sameOrigin(req.URL, referer) { 93 return ErrBadReferer 94 } 95 } 96 97 // Retrieve the combined token (pad + masked) token and unmask it. 98 requestToken := unmask(requestCSRFToken(req)) 99 100 // Missing token 101 if requestToken == nil { 102 return ErrNoToken 103 } 104 105 // Compare tokens 106 if !compareTokens(requestToken, realToken) { 107 return ErrBadToken 108 } 109 } 110 111 return next(c) 112 } 113 } 114 115 // generateRandomBytes returns securely generated random bytes. 
// It will return an error if the system's secure random number generator
// fails to function correctly.
func generateRandomBytes(n int) ([]byte, error) {
	b := make([]byte, n)
	// crypto/rand.Read returns err == nil only if it filled all of b.
	if _, err := rand.Read(b); err != nil {
		return nil, err
	}

	return b, nil
}

// sameOrigin returns true if URLs a and b share the same origin. The same
// origin is defined as host (which includes the port) and scheme.
func sameOrigin(a, b *url.URL) bool {
	return a.Scheme == b.Scheme && a.Host == b.Host
}

// contains reports whether s contains any element of vals as a substring,
// ignoring case. Substring matching lets a Content-Type such as
// "text/html; charset=utf-8" match "html"; it is also used to check whether a
// HTTP method is in a list of safe methods.
// NOTE(review): for methods, substring matching also treats non-standard
// methods such as "XGET" as safe.
func contains(vals []string, s string) bool {
	s = strings.ToLower(s)
	for _, v := range vals {
		if strings.Contains(s, strings.ToLower(v)) {
			return true
		}
	}

	return false
}

// compareTokens securely (in constant time) compares the unmasked token from
// the request against the real token from the session.
func compareTokens(a, b []byte) bool {
	// Explicit length check: subtle.ConstantTimeCompare did not check for equal
	// lengths in older Go versions (prior to 1.3, per the original note).
	if len(a) != len(b) {
		return false
	}

	return subtle.ConstantTimeCompare(a, b) == 1
}

// xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It
// will return a masked token if the base token is XOR'ed with a one-time-pad.
// An unmasked token will be returned if a masked token is XOR'ed with the
// one-time-pad used to mask it.
164 func xorToken(a, b []byte) []byte { 165 n := len(a) 166 if len(b) < n { 167 n = len(b) 168 } 169 170 res := make([]byte, n) 171 172 for i := 0; i < n; i++ { 173 res[i] = a[i] ^ b[i] 174 } 175 176 return res 177 } 178 179 // mask returns a unique-per-request token to mitigate the BREACH attack 180 // as per http://breachattack.com/#mitigations 181 // 182 // The token is generated by XOR'ing a one-time-pad and the base (session) CSRF 183 // token and returning them together as a 64-byte slice. This effectively 184 // randomises the token on a per-request basis without breaking multiple browser 185 // tabs/windows. 186 func mask(realToken []byte, r *http.Request) string { 187 otp, err := generateRandomBytes(tokenLength) 188 if err != nil { 189 return "" 190 } 191 192 // XOR the OTP with the real token to generate a masked token. Append the 193 // OTP to the front of the masked token to allow unmasking in the subsequent 194 // request. 195 return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...)) 196 } 197 198 // unmask splits the issued token (one-time-pad + masked token) and returns the 199 // unmasked request token for comparison. 200 func unmask(issued []byte) []byte { 201 // Issued tokens are always masked and combined with the pad. 202 if len(issued) != tokenLength*2 { 203 return nil 204 } 205 206 // We now know the length of the byte slice. 207 otp := issued[tokenLength:] 208 masked := issued[:tokenLength] 209 210 // Unmask the token by XOR'ing it against the OTP used to mask it. 211 return xorToken(otp, masked) 212 } 213 214 // requestCSRFToken gets the CSRF token from either: 215 // - a HTTP header 216 // - a form value 217 // - a multipart form value 218 func requestCSRFToken(r *http.Request) []byte { 219 // 1. Check the HTTP header first. 220 issued := r.Header.Get(headerName) 221 222 // 2. Fall back to the POST (form) value. 223 if issued == "" { 224 issued = r.PostFormValue(fieldName) 225 } 226 227 // 3. 
Finally, fall back to the multipart form (if set). 228 if issued == "" && r.MultipartForm != nil { 229 vals := r.MultipartForm.Value[fieldName] 230 231 if len(vals) > 0 { 232 issued = vals[0] 233 } 234 } 235 236 // Decode the "issued" (pad + masked) token sent in the request. Return a 237 // nil byte slice on a decoding error (this will fail upstream). 238 decoded, err := base64.StdEncoding.DecodeString(issued) 239 if err != nil { 240 return nil 241 } 242 243 return decoded 244 }