github.com/cheikhshift/buffalo@v0.9.5/middleware/csrf/csrf.go (about) 1 package csrf 2 3 import ( 4 "crypto/rand" 5 "crypto/subtle" 6 "encoding/base64" 7 "errors" 8 "fmt" 9 "net/http" 10 "net/url" 11 "runtime" 12 "strings" 13 14 "github.com/gobuffalo/buffalo" 15 "github.com/gobuffalo/envy" 16 "github.com/markbates/going/defaults" 17 ) 18 19 const ( 20 // CSRF token length in bytes. 21 tokenLength int = 32 22 tokenKey string = "authenticity_token" 23 ) 24 25 var ( 26 // The name value used in form fields. 27 fieldName = tokenKey 28 29 // The HTTP request header to inspect 30 headerName = "X-CSRF-Token" 31 32 // Idempotent (safe) methods as defined by RFC7231 section 4.2.2. 33 safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"} 34 htmlTypes = []string{"html", "form", "plain"} 35 ) 36 37 var ( 38 // ErrNoReferer is returned when a HTTPS request provides an empty Referer 39 // header. 40 ErrNoReferer = errors.New("referer not supplied") 41 // ErrBadReferer is returned when the scheme & host in the URL do not match 42 // the supplied Referer header. 43 ErrBadReferer = errors.New("referer invalid") 44 // ErrNoToken is returned if no CSRF token is supplied in the request. 45 ErrNoToken = errors.New("CSRF token not found in request") 46 // ErrBadToken is returned if the CSRF token in the request does not match 47 // the token in the session, or is otherwise malformed. 48 ErrBadToken = errors.New("CSRF token invalid") 49 ) 50 51 // Middleware is deprecated, and will be removed in v0.10.0. Use csrf.New instead. 52 var Middleware = func(next buffalo.Handler) buffalo.Handler { 53 warningMsg := "csrf.Middleware is deprecated, and will be removed in v0.10.0. Use csrf.New instead." 54 _, file, no, ok := runtime.Caller(1) 55 if ok { 56 warningMsg = fmt.Sprintf("%s Called from %s:%d", warningMsg, file, no) 57 } 58 return New(next) 59 } 60 61 // New enable CSRF protection on routes using this middleware. 62 // This middleware is adapted from gorilla/csrf 63 var New = func(next buffalo.Handler) buffalo.Handler { 64 return func(c buffalo.Context) error { 65 // don't run in test mode 66 if envy.Get("GO_ENV", "development") == "test" { 67 return next(c) 68 } 69 70 req := c.Request() 71 72 ct := defaults.String(req.Header.Get("Content-Type"), req.Header.Get("Accept")) 73 // ignore non-html requests 74 if ct != "" && !contains(htmlTypes, ct) { 75 return next(c) 76 } 77 78 var realToken []byte 79 rawRealToken := c.Session().Get(tokenKey) 80 81 if rawRealToken == nil || len(rawRealToken.([]byte)) != tokenLength { 82 // If the token is missing, or the length if the token is wrong, 83 // generate a new token. 84 realToken, err := generateRandomBytes(tokenLength) 85 if err != nil { 86 return err 87 } 88 // Save the new real token in session 89 c.Session().Set(tokenKey, realToken) 90 } else { 91 realToken = rawRealToken.([]byte) 92 } 93 94 // Set masked token in context data, to be available in template 95 c.Set(fieldName, mask(realToken, req)) 96 97 // HTTP methods not defined as idempotent ("safe") under RFC7231 require 98 // inspection. 99 if !contains(safeMethods, req.Method) { 100 // Enforce an origin check for HTTPS connections. As per the Django CSRF 101 // implementation (https://goo.gl/vKA7GE) the Referer header is almost 102 // always present for same-domain HTTP requests. 103 if req.URL.Scheme == "https" { 104 // Fetch the Referer value. Call the error handler if it's empty or 105 // otherwise fails to parse. 106 referer, err := url.Parse(req.Referer()) 107 if err != nil || referer.String() == "" { 108 return ErrNoReferer 109 } 110 111 if !sameOrigin(req.URL, referer) { 112 return ErrBadReferer 113 } 114 } 115 116 // Retrieve the combined token (pad + masked) token and unmask it. 117 requestToken := unmask(requestCSRFToken(req)) 118 119 // Missing token 120 if requestToken == nil { 121 return ErrNoToken 122 } 123 124 // Compare tokens 125 if !compareTokens(requestToken, realToken) { 126 return ErrBadToken 127 } 128 } 129 130 return next(c) 131 } 132 } 133 134 // generateRandomBytes returns securely generated random bytes. 135 // It will return an error if the system's secure random number generator 136 // fails to function correctly. 137 func generateRandomBytes(n int) ([]byte, error) { 138 b := make([]byte, n) 139 _, err := rand.Read(b) 140 // err == nil only if len(b) == n 141 if err != nil { 142 return nil, err 143 } 144 145 return b, nil 146 } 147 148 // sameOrigin returns true if URLs a and b share the same origin. The same 149 // origin is defined as host (which includes the port) and scheme. 150 func sameOrigin(a, b *url.URL) bool { 151 return (a.Scheme == b.Scheme && a.Host == b.Host) 152 } 153 154 // contains is a helper function to check if a string exists in a slice - e.g. 155 // whether a HTTP method exists in a list of safe methods. 156 func contains(vals []string, s string) bool { 157 s = strings.ToLower(s) 158 for _, v := range vals { 159 if strings.Contains(s, strings.ToLower(v)) { 160 return true 161 } 162 } 163 164 return false 165 } 166 167 // compare securely (constant-time) compares the unmasked token from the request 168 // against the real token from the session. 169 func compareTokens(a, b []byte) bool { 170 // This is required as subtle.ConstantTimeCompare does not check for equal 171 // lengths in Go versions prior to 1.3. 172 if len(a) != len(b) { 173 return false 174 } 175 176 return subtle.ConstantTimeCompare(a, b) == 1 177 } 178 179 // xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It 180 // will return a masked token if the base token is XOR'ed with a one-time-pad. 181 // An unmasked token will be returned if a masked token is XOR'ed with the 182 // one-time-pad used to mask it. 183 func xorToken(a, b []byte) []byte { 184 n := len(a) 185 if len(b) < n { 186 n = len(b) 187 } 188 189 res := make([]byte, n) 190 191 for i := 0; i < n; i++ { 192 res[i] = a[i] ^ b[i] 193 } 194 195 return res 196 } 197 198 // mask returns a unique-per-request token to mitigate the BREACH attack 199 // as per http://breachattack.com/#mitigations 200 // 201 // The token is generated by XOR'ing a one-time-pad and the base (session) CSRF 202 // token and returning them together as a 64-byte slice. This effectively 203 // randomises the token on a per-request basis without breaking multiple browser 204 // tabs/windows. 205 func mask(realToken []byte, r *http.Request) string { 206 otp, err := generateRandomBytes(tokenLength) 207 if err != nil { 208 return "" 209 } 210 211 // XOR the OTP with the real token to generate a masked token. Append the 212 // OTP to the front of the masked token to allow unmasking in the subsequent 213 // request. 214 return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...)) 215 } 216 217 // unmask splits the issued token (one-time-pad + masked token) and returns the 218 // unmasked request token for comparison. 219 func unmask(issued []byte) []byte { 220 // Issued tokens are always masked and combined with the pad. 221 if len(issued) != tokenLength*2 { 222 return nil 223 } 224 225 // We now know the length of the byte slice. 226 otp := issued[tokenLength:] 227 masked := issued[:tokenLength] 228 229 // Unmask the token by XOR'ing it against the OTP used to mask it. 230 return xorToken(otp, masked) 231 } 232 233 // requestCSRFToken gets the CSRF token from either: 234 // - a HTTP header 235 // - a form value 236 // - a multipart form value 237 func requestCSRFToken(r *http.Request) []byte { 238 // 1. Check the HTTP header first. 239 issued := r.Header.Get(headerName) 240 241 // 2. Fall back to the POST (form) value. 242 if issued == "" { 243 issued = r.PostFormValue(fieldName) 244 } 245 246 // 3. Finally, fall back to the multipart form (if set). 247 if issued == "" && r.MultipartForm != nil { 248 vals := r.MultipartForm.Value[fieldName] 249 250 if len(vals) > 0 { 251 issued = vals[0] 252 } 253 } 254 255 // Decode the "issued" (pad + masked) token sent in the request. Return a 256 // nil byte slice on a decoding error (this will fail upstream). 257 decoded, err := base64.StdEncoding.DecodeString(issued) 258 if err != nil { 259 return nil 260 } 261 262 return decoded 263 }