github.com/corylanou/buffalo@v0.8.0/middleware/csrf.go (about) 1 package middleware 2 3 import ( 4 "crypto/rand" 5 "crypto/subtle" 6 "encoding/base64" 7 "errors" 8 "github.com/gobuffalo/buffalo" 9 "net/http" 10 "net/url" 11 ) 12 13 const ( 14 // CSRF token length in bytes. 15 csrfTokenLength int = 32 16 csrfTokenKey string = "authenticity_token" 17 ) 18 19 var ( 20 // The name value used in form fields. 21 fieldName = csrfTokenKey 22 23 // The HTTP request header to inspect 24 headerName = "X-CSRF-Token" 25 26 // Idempotent (safe) methods as defined by RFC7231 section 4.2.2. 27 safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"} 28 ) 29 30 var ( 31 // ErrNoReferer is returned when a HTTPS request provides an empty Referer 32 // header. 33 ErrNoReferer = errors.New("referer not supplied") 34 // ErrBadReferer is returned when the scheme & host in the URL do not match 35 // the supplied Referer header. 36 ErrBadReferer = errors.New("referer invalid") 37 // ErrNoCSRFToken is returned if no CSRF token is supplied in the request. 38 ErrNoCSRFToken = errors.New("CSRF token not found in request") 39 // ErrBadCSRFToken is returned if the CSRF token in the request does not match 40 // the token in the session, or is otherwise malformed. 41 ErrBadCSRFToken = errors.New("CSRF token invalid") 42 ) 43 44 // CSRF enable CSRF protection on routes using this middleware. 45 // This middleware is adapted from gorilla/csrf 46 func CSRF(next buffalo.Handler) buffalo.Handler { 47 return func(c buffalo.Context) error { 48 var realToken []byte 49 rawRealToken := c.Session().Get(csrfTokenKey) 50 51 if rawRealToken == nil || len(rawRealToken.([]byte)) != csrfTokenLength { 52 // If the token is missing, or the length if the token is wrong, 53 // generate a new token. 54 realToken, err := generateRandomBytes(csrfTokenLength) 55 if err != nil { 56 return err 57 } 58 // Save the new real token in session 59 c.Session().Set(csrfTokenKey, realToken) 60 } else { 61 realToken = rawRealToken.([]byte) 62 } 63 64 // Set masked token in context data, to be available in template 65 c.Set(fieldName, mask(realToken, c.Request())) 66 67 // HTTP methods not defined as idempotent ("safe") under RFC7231 require 68 // inspection. 69 if !contains(safeMethods, c.Request().Method) { 70 // Enforce an origin check for HTTPS connections. As per the Django CSRF 71 // implementation (https://goo.gl/vKA7GE) the Referer header is almost 72 // always present for same-domain HTTP requests. 73 if c.Request().URL.Scheme == "https" { 74 // Fetch the Referer value. Call the error handler if it's empty or 75 // otherwise fails to parse. 76 referer, err := url.Parse(c.Request().Referer()) 77 if err != nil || referer.String() == "" { 78 return ErrNoReferer 79 } 80 81 if sameOrigin(c.Request().URL, referer) == false { 82 return ErrBadReferer 83 } 84 } 85 86 // Retrieve the combined token (pad + masked) token and unmask it. 87 requestToken := unmask(requestCSRFToken(c.Request())) 88 89 // Missing token 90 if requestToken == nil { 91 return ErrNoCSRFToken 92 } 93 94 // Compare tokens 95 if !compareTokens(requestToken, realToken) { 96 return ErrBadCSRFToken 97 } 98 } 99 100 return next(c) 101 } 102 } 103 104 // generateRandomBytes returns securely generated random bytes. 105 // It will return an error if the system's secure random number generator 106 // fails to function correctly. 107 func generateRandomBytes(n int) ([]byte, error) { 108 b := make([]byte, n) 109 _, err := rand.Read(b) 110 // err == nil only if len(b) == n 111 if err != nil { 112 return nil, err 113 } 114 115 return b, nil 116 } 117 118 // sameOrigin returns true if URLs a and b share the same origin. The same 119 // origin is defined as host (which includes the port) and scheme. 120 func sameOrigin(a, b *url.URL) bool { 121 return (a.Scheme == b.Scheme && a.Host == b.Host) 122 } 123 124 // contains is a helper function to check if a string exists in a slice - e.g. 125 // whether a HTTP method exists in a list of safe methods. 126 func contains(vals []string, s string) bool { 127 for _, v := range vals { 128 if v == s { 129 return true 130 } 131 } 132 133 return false 134 } 135 136 // compare securely (constant-time) compares the unmasked token from the request 137 // against the real token from the session. 138 func compareTokens(a, b []byte) bool { 139 // This is required as subtle.ConstantTimeCompare does not check for equal 140 // lengths in Go versions prior to 1.3. 141 if len(a) != len(b) { 142 return false 143 } 144 145 return subtle.ConstantTimeCompare(a, b) == 1 146 } 147 148 // xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It 149 // will return a masked token if the base token is XOR'ed with a one-time-pad. 150 // An unmasked token will be returned if a masked token is XOR'ed with the 151 // one-time-pad used to mask it. 152 func xorToken(a, b []byte) []byte { 153 n := len(a) 154 if len(b) < n { 155 n = len(b) 156 } 157 158 res := make([]byte, n) 159 160 for i := 0; i < n; i++ { 161 res[i] = a[i] ^ b[i] 162 } 163 164 return res 165 } 166 167 // mask returns a unique-per-request token to mitigate the BREACH attack 168 // as per http://breachattack.com/#mitigations 169 // 170 // The token is generated by XOR'ing a one-time-pad and the base (session) CSRF 171 // token and returning them together as a 64-byte slice. This effectively 172 // randomises the token on a per-request basis without breaking multiple browser 173 // tabs/windows. 174 func mask(realToken []byte, r *http.Request) string { 175 otp, err := generateRandomBytes(csrfTokenLength) 176 if err != nil { 177 return "" 178 } 179 180 // XOR the OTP with the real token to generate a masked token. Append the 181 // OTP to the front of the masked token to allow unmasking in the subsequent 182 // request. 183 return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...)) 184 } 185 186 // unmask splits the issued token (one-time-pad + masked token) and returns the 187 // unmasked request token for comparison. 188 func unmask(issued []byte) []byte { 189 // Issued tokens are always masked and combined with the pad. 190 if len(issued) != csrfTokenLength*2 { 191 return nil 192 } 193 194 // We now know the length of the byte slice. 195 otp := issued[csrfTokenLength:] 196 masked := issued[:csrfTokenLength] 197 198 // Unmask the token by XOR'ing it against the OTP used to mask it. 199 return xorToken(otp, masked) 200 } 201 202 // requestCSRFToken gets the CSRF token from either: 203 // - a HTTP header 204 // - a form value 205 // - a multipart form value 206 func requestCSRFToken(r *http.Request) []byte { 207 // 1. Check the HTTP header first. 208 issued := r.Header.Get(headerName) 209 210 // 2. Fall back to the POST (form) value. 211 if issued == "" { 212 issued = r.PostFormValue(fieldName) 213 } 214 215 // 3. Finally, fall back to the multipart form (if set). 216 if issued == "" && r.MultipartForm != nil { 217 vals := r.MultipartForm.Value[fieldName] 218 219 if len(vals) > 0 { 220 issued = vals[0] 221 } 222 } 223 224 // Decode the "issued" (pad + masked) token sent in the request. Return a 225 // nil byte slice on a decoding error (this will fail upstream). 226 decoded, err := base64.StdEncoding.DecodeString(issued) 227 if err != nil { 228 return nil 229 } 230 231 return decoded 232 }