github.com/jasonish/buffalo@v0.8.2-0.20170413145823-bacbdd415f1b/middleware/csrf.go (about)

     1  package middleware
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/subtle"
     6  	"encoding/base64"
     7  	"errors"
     8  	"net/http"
     9  	"net/url"
    10  
    11  	"github.com/gobuffalo/buffalo"
    12  )
    13  
    14  const (
    15  	// CSRF token length in bytes.
    16  	csrfTokenLength int    = 32
    17  	csrfTokenKey    string = "authenticity_token"
    18  )
    19  
    20  var (
    21  	// The name value used in form fields.
    22  	fieldName = csrfTokenKey
    23  
    24  	// The HTTP request header to inspect
    25  	headerName = "X-CSRF-Token"
    26  
    27  	// Idempotent (safe) methods as defined by RFC7231 section 4.2.2.
    28  	safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"}
    29  )
    30  
    31  var (
    32  	// ErrNoReferer is returned when a HTTPS request provides an empty Referer
    33  	// header.
    34  	ErrNoReferer = errors.New("referer not supplied")
    35  	// ErrBadReferer is returned when the scheme & host in the URL do not match
    36  	// the supplied Referer header.
    37  	ErrBadReferer = errors.New("referer invalid")
    38  	// ErrNoCSRFToken is returned if no CSRF token is supplied in the request.
    39  	ErrNoCSRFToken = errors.New("CSRF token not found in request")
    40  	// ErrBadCSRFToken is returned if the CSRF token in the request does not match
    41  	// the token in the session, or is otherwise malformed.
    42  	ErrBadCSRFToken = errors.New("CSRF token invalid")
    43  )
    44  
    45  // CSRF enable CSRF protection on routes using this middleware.
    46  // This middleware is adapted from gorilla/csrf
    47  var CSRF = func(next buffalo.Handler) buffalo.Handler {
    48  	return func(c buffalo.Context) error {
    49  		var realToken []byte
    50  		rawRealToken := c.Session().Get(csrfTokenKey)
    51  
    52  		if rawRealToken == nil || len(rawRealToken.([]byte)) != csrfTokenLength {
    53  			// If the token is missing, or the length if the token is wrong,
    54  			// generate a new token.
    55  			realToken, err := generateRandomBytes(csrfTokenLength)
    56  			if err != nil {
    57  				return err
    58  			}
    59  			// Save the new real token in session
    60  			c.Session().Set(csrfTokenKey, realToken)
    61  		} else {
    62  			realToken = rawRealToken.([]byte)
    63  		}
    64  
    65  		// Set masked token in context data, to be available in template
    66  		c.Set(fieldName, mask(realToken, c.Request()))
    67  
    68  		// HTTP methods not defined as idempotent ("safe") under RFC7231 require
    69  		// inspection.
    70  		if !contains(safeMethods, c.Request().Method) {
    71  			// Enforce an origin check for HTTPS connections. As per the Django CSRF
    72  			// implementation (https://goo.gl/vKA7GE) the Referer header is almost
    73  			// always present for same-domain HTTP requests.
    74  			if c.Request().URL.Scheme == "https" {
    75  				// Fetch the Referer value. Call the error handler if it's empty or
    76  				// otherwise fails to parse.
    77  				referer, err := url.Parse(c.Request().Referer())
    78  				if err != nil || referer.String() == "" {
    79  					return ErrNoReferer
    80  				}
    81  
    82  				if sameOrigin(c.Request().URL, referer) == false {
    83  					return ErrBadReferer
    84  				}
    85  			}
    86  
    87  			// Retrieve the combined token (pad + masked) token and unmask it.
    88  			requestToken := unmask(requestCSRFToken(c.Request()))
    89  
    90  			// Missing token
    91  			if requestToken == nil {
    92  				return ErrNoCSRFToken
    93  			}
    94  
    95  			// Compare tokens
    96  			if !compareTokens(requestToken, realToken) {
    97  				return ErrBadCSRFToken
    98  			}
    99  		}
   100  
   101  		return next(c)
   102  	}
   103  }
   104  
   105  // generateRandomBytes returns securely generated random bytes.
   106  // It will return an error if the system's secure random number generator
   107  // fails to function correctly.
   108  func generateRandomBytes(n int) ([]byte, error) {
   109  	b := make([]byte, n)
   110  	_, err := rand.Read(b)
   111  	// err == nil only if len(b) == n
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	return b, nil
   117  }
   118  
   119  // sameOrigin returns true if URLs a and b share the same origin. The same
   120  // origin is defined as host (which includes the port) and scheme.
   121  func sameOrigin(a, b *url.URL) bool {
   122  	return (a.Scheme == b.Scheme && a.Host == b.Host)
   123  }
   124  
   125  // contains is a helper function to check if a string exists in a slice - e.g.
   126  // whether a HTTP method exists in a list of safe methods.
   127  func contains(vals []string, s string) bool {
   128  	for _, v := range vals {
   129  		if v == s {
   130  			return true
   131  		}
   132  	}
   133  
   134  	return false
   135  }
   136  
   137  // compare securely (constant-time) compares the unmasked token from the request
   138  // against the real token from the session.
   139  func compareTokens(a, b []byte) bool {
   140  	// This is required as subtle.ConstantTimeCompare does not check for equal
   141  	// lengths in Go versions prior to 1.3.
   142  	if len(a) != len(b) {
   143  		return false
   144  	}
   145  
   146  	return subtle.ConstantTimeCompare(a, b) == 1
   147  }
   148  
   149  // xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It
   150  // will return a masked token if the base token is XOR'ed with a one-time-pad.
   151  // An unmasked token will be returned if a masked token is XOR'ed with the
   152  // one-time-pad used to mask it.
   153  func xorToken(a, b []byte) []byte {
   154  	n := len(a)
   155  	if len(b) < n {
   156  		n = len(b)
   157  	}
   158  
   159  	res := make([]byte, n)
   160  
   161  	for i := 0; i < n; i++ {
   162  		res[i] = a[i] ^ b[i]
   163  	}
   164  
   165  	return res
   166  }
   167  
   168  // mask returns a unique-per-request token to mitigate the BREACH attack
   169  // as per http://breachattack.com/#mitigations
   170  //
   171  // The token is generated by XOR'ing a one-time-pad and the base (session) CSRF
   172  // token and returning them together as a 64-byte slice. This effectively
   173  // randomises the token on a per-request basis without breaking multiple browser
   174  // tabs/windows.
   175  func mask(realToken []byte, r *http.Request) string {
   176  	otp, err := generateRandomBytes(csrfTokenLength)
   177  	if err != nil {
   178  		return ""
   179  	}
   180  
   181  	// XOR the OTP with the real token to generate a masked token. Append the
   182  	// OTP to the front of the masked token to allow unmasking in the subsequent
   183  	// request.
   184  	return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...))
   185  }
   186  
   187  // unmask splits the issued token (one-time-pad + masked token) and returns the
   188  // unmasked request token for comparison.
   189  func unmask(issued []byte) []byte {
   190  	// Issued tokens are always masked and combined with the pad.
   191  	if len(issued) != csrfTokenLength*2 {
   192  		return nil
   193  	}
   194  
   195  	// We now know the length of the byte slice.
   196  	otp := issued[csrfTokenLength:]
   197  	masked := issued[:csrfTokenLength]
   198  
   199  	// Unmask the token by XOR'ing it against the OTP used to mask it.
   200  	return xorToken(otp, masked)
   201  }
   202  
   203  // requestCSRFToken gets the CSRF token from either:
   204  // - a HTTP header
   205  // - a form value
   206  // - a multipart form value
   207  func requestCSRFToken(r *http.Request) []byte {
   208  	// 1. Check the HTTP header first.
   209  	issued := r.Header.Get(headerName)
   210  
   211  	// 2. Fall back to the POST (form) value.
   212  	if issued == "" {
   213  		issued = r.PostFormValue(fieldName)
   214  	}
   215  
   216  	// 3. Finally, fall back to the multipart form (if set).
   217  	if issued == "" && r.MultipartForm != nil {
   218  		vals := r.MultipartForm.Value[fieldName]
   219  
   220  		if len(vals) > 0 {
   221  			issued = vals[0]
   222  		}
   223  	}
   224  
   225  	// Decode the "issued" (pad + masked) token sent in the request. Return a
   226  	// nil byte slice on a decoding error (this will fail upstream).
   227  	decoded, err := base64.StdEncoding.DecodeString(issued)
   228  	if err != nil {
   229  		return nil
   230  	}
   231  
   232  	return decoded
   233  }