github.com/jasonish/buffalo@v0.8.2-0.20170413145823-bacbdd415f1b/middleware/csrf.go (about) 1 package middleware 2 3 import ( 4 "crypto/rand" 5 "crypto/subtle" 6 "encoding/base64" 7 "errors" 8 "net/http" 9 "net/url" 10 11 "github.com/gobuffalo/buffalo" 12 ) 13 14 const ( 15 // CSRF token length in bytes. 16 csrfTokenLength int = 32 17 csrfTokenKey string = "authenticity_token" 18 ) 19 20 var ( 21 // The name value used in form fields. 22 fieldName = csrfTokenKey 23 24 // The HTTP request header to inspect 25 headerName = "X-CSRF-Token" 26 27 // Idempotent (safe) methods as defined by RFC7231 section 4.2.2. 28 safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"} 29 ) 30 31 var ( 32 // ErrNoReferer is returned when a HTTPS request provides an empty Referer 33 // header. 34 ErrNoReferer = errors.New("referer not supplied") 35 // ErrBadReferer is returned when the scheme & host in the URL do not match 36 // the supplied Referer header. 37 ErrBadReferer = errors.New("referer invalid") 38 // ErrNoCSRFToken is returned if no CSRF token is supplied in the request. 39 ErrNoCSRFToken = errors.New("CSRF token not found in request") 40 // ErrBadCSRFToken is returned if the CSRF token in the request does not match 41 // the token in the session, or is otherwise malformed. 42 ErrBadCSRFToken = errors.New("CSRF token invalid") 43 ) 44 45 // CSRF enable CSRF protection on routes using this middleware. 46 // This middleware is adapted from gorilla/csrf 47 var CSRF = func(next buffalo.Handler) buffalo.Handler { 48 return func(c buffalo.Context) error { 49 var realToken []byte 50 rawRealToken := c.Session().Get(csrfTokenKey) 51 52 if rawRealToken == nil || len(rawRealToken.([]byte)) != csrfTokenLength { 53 // If the token is missing, or the length if the token is wrong, 54 // generate a new token. 55 realToken, err := generateRandomBytes(csrfTokenLength) 56 if err != nil { 57 return err 58 } 59 // Save the new real token in session 60 c.Session().Set(csrfTokenKey, realToken) 61 } else { 62 realToken = rawRealToken.([]byte) 63 } 64 65 // Set masked token in context data, to be available in template 66 c.Set(fieldName, mask(realToken, c.Request())) 67 68 // HTTP methods not defined as idempotent ("safe") under RFC7231 require 69 // inspection. 70 if !contains(safeMethods, c.Request().Method) { 71 // Enforce an origin check for HTTPS connections. As per the Django CSRF 72 // implementation (https://goo.gl/vKA7GE) the Referer header is almost 73 // always present for same-domain HTTP requests. 74 if c.Request().URL.Scheme == "https" { 75 // Fetch the Referer value. Call the error handler if it's empty or 76 // otherwise fails to parse. 77 referer, err := url.Parse(c.Request().Referer()) 78 if err != nil || referer.String() == "" { 79 return ErrNoReferer 80 } 81 82 if sameOrigin(c.Request().URL, referer) == false { 83 return ErrBadReferer 84 } 85 } 86 87 // Retrieve the combined token (pad + masked) token and unmask it. 88 requestToken := unmask(requestCSRFToken(c.Request())) 89 90 // Missing token 91 if requestToken == nil { 92 return ErrNoCSRFToken 93 } 94 95 // Compare tokens 96 if !compareTokens(requestToken, realToken) { 97 return ErrBadCSRFToken 98 } 99 } 100 101 return next(c) 102 } 103 } 104 105 // generateRandomBytes returns securely generated random bytes. 106 // It will return an error if the system's secure random number generator 107 // fails to function correctly. 108 func generateRandomBytes(n int) ([]byte, error) { 109 b := make([]byte, n) 110 _, err := rand.Read(b) 111 // err == nil only if len(b) == n 112 if err != nil { 113 return nil, err 114 } 115 116 return b, nil 117 } 118 119 // sameOrigin returns true if URLs a and b share the same origin. The same 120 // origin is defined as host (which includes the port) and scheme. 121 func sameOrigin(a, b *url.URL) bool { 122 return (a.Scheme == b.Scheme && a.Host == b.Host) 123 } 124 125 // contains is a helper function to check if a string exists in a slice - e.g. 126 // whether a HTTP method exists in a list of safe methods. 127 func contains(vals []string, s string) bool { 128 for _, v := range vals { 129 if v == s { 130 return true 131 } 132 } 133 134 return false 135 } 136 137 // compare securely (constant-time) compares the unmasked token from the request 138 // against the real token from the session. 139 func compareTokens(a, b []byte) bool { 140 // This is required as subtle.ConstantTimeCompare does not check for equal 141 // lengths in Go versions prior to 1.3. 142 if len(a) != len(b) { 143 return false 144 } 145 146 return subtle.ConstantTimeCompare(a, b) == 1 147 } 148 149 // xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It 150 // will return a masked token if the base token is XOR'ed with a one-time-pad. 151 // An unmasked token will be returned if a masked token is XOR'ed with the 152 // one-time-pad used to mask it. 153 func xorToken(a, b []byte) []byte { 154 n := len(a) 155 if len(b) < n { 156 n = len(b) 157 } 158 159 res := make([]byte, n) 160 161 for i := 0; i < n; i++ { 162 res[i] = a[i] ^ b[i] 163 } 164 165 return res 166 } 167 168 // mask returns a unique-per-request token to mitigate the BREACH attack 169 // as per http://breachattack.com/#mitigations 170 // 171 // The token is generated by XOR'ing a one-time-pad and the base (session) CSRF 172 // token and returning them together as a 64-byte slice. This effectively 173 // randomises the token on a per-request basis without breaking multiple browser 174 // tabs/windows. 175 func mask(realToken []byte, r *http.Request) string { 176 otp, err := generateRandomBytes(csrfTokenLength) 177 if err != nil { 178 return "" 179 } 180 181 // XOR the OTP with the real token to generate a masked token. Append the 182 // OTP to the front of the masked token to allow unmasking in the subsequent 183 // request. 184 return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...)) 185 } 186 187 // unmask splits the issued token (one-time-pad + masked token) and returns the 188 // unmasked request token for comparison. 189 func unmask(issued []byte) []byte { 190 // Issued tokens are always masked and combined with the pad. 191 if len(issued) != csrfTokenLength*2 { 192 return nil 193 } 194 195 // We now know the length of the byte slice. 196 otp := issued[csrfTokenLength:] 197 masked := issued[:csrfTokenLength] 198 199 // Unmask the token by XOR'ing it against the OTP used to mask it. 200 return xorToken(otp, masked) 201 } 202 203 // requestCSRFToken gets the CSRF token from either: 204 // - a HTTP header 205 // - a form value 206 // - a multipart form value 207 func requestCSRFToken(r *http.Request) []byte { 208 // 1. Check the HTTP header first. 209 issued := r.Header.Get(headerName) 210 211 // 2. Fall back to the POST (form) value. 212 if issued == "" { 213 issued = r.PostFormValue(fieldName) 214 } 215 216 // 3. Finally, fall back to the multipart form (if set). 217 if issued == "" && r.MultipartForm != nil { 218 vals := r.MultipartForm.Value[fieldName] 219 220 if len(vals) > 0 { 221 issued = vals[0] 222 } 223 } 224 225 // Decode the "issued" (pad + masked) token sent in the request. Return a 226 // nil byte slice on a decoding error (this will fail upstream). 227 decoded, err := base64.StdEncoding.DecodeString(issued) 228 if err != nil { 229 return nil 230 } 231 232 return decoded 233 }