github.com/corylanou/buffalo@v0.8.0/middleware/csrf.go (about)

     1  package middleware
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/subtle"
     6  	"encoding/base64"
     7  	"errors"
     8  	"github.com/gobuffalo/buffalo"
     9  	"net/http"
    10  	"net/url"
    11  )
    12  
    13  const (
    14  	// CSRF token length in bytes.
    15  	csrfTokenLength int    = 32
    16  	csrfTokenKey    string = "authenticity_token"
    17  )
    18  
    19  var (
    20  	// The name value used in form fields.
    21  	fieldName = csrfTokenKey
    22  
    23  	// The HTTP request header to inspect
    24  	headerName = "X-CSRF-Token"
    25  
    26  	// Idempotent (safe) methods as defined by RFC7231 section 4.2.2.
    27  	safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"}
    28  )
    29  
    30  var (
    31  	// ErrNoReferer is returned when a HTTPS request provides an empty Referer
    32  	// header.
    33  	ErrNoReferer = errors.New("referer not supplied")
    34  	// ErrBadReferer is returned when the scheme & host in the URL do not match
    35  	// the supplied Referer header.
    36  	ErrBadReferer = errors.New("referer invalid")
    37  	// ErrNoCSRFToken is returned if no CSRF token is supplied in the request.
    38  	ErrNoCSRFToken = errors.New("CSRF token not found in request")
    39  	// ErrBadCSRFToken is returned if the CSRF token in the request does not match
    40  	// the token in the session, or is otherwise malformed.
    41  	ErrBadCSRFToken = errors.New("CSRF token invalid")
    42  )
    43  
    44  // CSRF enable CSRF protection on routes using this middleware.
    45  // This middleware is adapted from gorilla/csrf
    46  func CSRF(next buffalo.Handler) buffalo.Handler {
    47  	return func(c buffalo.Context) error {
    48  		var realToken []byte
    49  		rawRealToken := c.Session().Get(csrfTokenKey)
    50  
    51  		if rawRealToken == nil || len(rawRealToken.([]byte)) != csrfTokenLength {
    52  			// If the token is missing, or the length if the token is wrong,
    53  			// generate a new token.
    54  			realToken, err := generateRandomBytes(csrfTokenLength)
    55  			if err != nil {
    56  				return err
    57  			}
    58  			// Save the new real token in session
    59  			c.Session().Set(csrfTokenKey, realToken)
    60  		} else {
    61  			realToken = rawRealToken.([]byte)
    62  		}
    63  
    64  		// Set masked token in context data, to be available in template
    65  		c.Set(fieldName, mask(realToken, c.Request()))
    66  
    67  		// HTTP methods not defined as idempotent ("safe") under RFC7231 require
    68  		// inspection.
    69  		if !contains(safeMethods, c.Request().Method) {
    70  			// Enforce an origin check for HTTPS connections. As per the Django CSRF
    71  			// implementation (https://goo.gl/vKA7GE) the Referer header is almost
    72  			// always present for same-domain HTTP requests.
    73  			if c.Request().URL.Scheme == "https" {
    74  				// Fetch the Referer value. Call the error handler if it's empty or
    75  				// otherwise fails to parse.
    76  				referer, err := url.Parse(c.Request().Referer())
    77  				if err != nil || referer.String() == "" {
    78  					return ErrNoReferer
    79  				}
    80  
    81  				if sameOrigin(c.Request().URL, referer) == false {
    82  					return ErrBadReferer
    83  				}
    84  			}
    85  
    86  			// Retrieve the combined token (pad + masked) token and unmask it.
    87  			requestToken := unmask(requestCSRFToken(c.Request()))
    88  
    89  			// Missing token
    90  			if requestToken == nil {
    91  				return ErrNoCSRFToken
    92  			}
    93  
    94  			// Compare tokens
    95  			if !compareTokens(requestToken, realToken) {
    96  				return ErrBadCSRFToken
    97  			}
    98  		}
    99  
   100  		return next(c)
   101  	}
   102  }
   103  
   104  // generateRandomBytes returns securely generated random bytes.
   105  // It will return an error if the system's secure random number generator
   106  // fails to function correctly.
   107  func generateRandomBytes(n int) ([]byte, error) {
   108  	b := make([]byte, n)
   109  	_, err := rand.Read(b)
   110  	// err == nil only if len(b) == n
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	return b, nil
   116  }
   117  
   118  // sameOrigin returns true if URLs a and b share the same origin. The same
   119  // origin is defined as host (which includes the port) and scheme.
   120  func sameOrigin(a, b *url.URL) bool {
   121  	return (a.Scheme == b.Scheme && a.Host == b.Host)
   122  }
   123  
   124  // contains is a helper function to check if a string exists in a slice - e.g.
   125  // whether a HTTP method exists in a list of safe methods.
   126  func contains(vals []string, s string) bool {
   127  	for _, v := range vals {
   128  		if v == s {
   129  			return true
   130  		}
   131  	}
   132  
   133  	return false
   134  }
   135  
   136  // compare securely (constant-time) compares the unmasked token from the request
   137  // against the real token from the session.
   138  func compareTokens(a, b []byte) bool {
   139  	// This is required as subtle.ConstantTimeCompare does not check for equal
   140  	// lengths in Go versions prior to 1.3.
   141  	if len(a) != len(b) {
   142  		return false
   143  	}
   144  
   145  	return subtle.ConstantTimeCompare(a, b) == 1
   146  }
   147  
   148  // xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It
   149  // will return a masked token if the base token is XOR'ed with a one-time-pad.
   150  // An unmasked token will be returned if a masked token is XOR'ed with the
   151  // one-time-pad used to mask it.
   152  func xorToken(a, b []byte) []byte {
   153  	n := len(a)
   154  	if len(b) < n {
   155  		n = len(b)
   156  	}
   157  
   158  	res := make([]byte, n)
   159  
   160  	for i := 0; i < n; i++ {
   161  		res[i] = a[i] ^ b[i]
   162  	}
   163  
   164  	return res
   165  }
   166  
   167  // mask returns a unique-per-request token to mitigate the BREACH attack
   168  // as per http://breachattack.com/#mitigations
   169  //
   170  // The token is generated by XOR'ing a one-time-pad and the base (session) CSRF
   171  // token and returning them together as a 64-byte slice. This effectively
   172  // randomises the token on a per-request basis without breaking multiple browser
   173  // tabs/windows.
   174  func mask(realToken []byte, r *http.Request) string {
   175  	otp, err := generateRandomBytes(csrfTokenLength)
   176  	if err != nil {
   177  		return ""
   178  	}
   179  
   180  	// XOR the OTP with the real token to generate a masked token. Append the
   181  	// OTP to the front of the masked token to allow unmasking in the subsequent
   182  	// request.
   183  	return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...))
   184  }
   185  
   186  // unmask splits the issued token (one-time-pad + masked token) and returns the
   187  // unmasked request token for comparison.
   188  func unmask(issued []byte) []byte {
   189  	// Issued tokens are always masked and combined with the pad.
   190  	if len(issued) != csrfTokenLength*2 {
   191  		return nil
   192  	}
   193  
   194  	// We now know the length of the byte slice.
   195  	otp := issued[csrfTokenLength:]
   196  	masked := issued[:csrfTokenLength]
   197  
   198  	// Unmask the token by XOR'ing it against the OTP used to mask it.
   199  	return xorToken(otp, masked)
   200  }
   201  
   202  // requestCSRFToken gets the CSRF token from either:
   203  // - a HTTP header
   204  // - a form value
   205  // - a multipart form value
   206  func requestCSRFToken(r *http.Request) []byte {
   207  	// 1. Check the HTTP header first.
   208  	issued := r.Header.Get(headerName)
   209  
   210  	// 2. Fall back to the POST (form) value.
   211  	if issued == "" {
   212  		issued = r.PostFormValue(fieldName)
   213  	}
   214  
   215  	// 3. Finally, fall back to the multipart form (if set).
   216  	if issued == "" && r.MultipartForm != nil {
   217  		vals := r.MultipartForm.Value[fieldName]
   218  
   219  		if len(vals) > 0 {
   220  			issued = vals[0]
   221  		}
   222  	}
   223  
   224  	// Decode the "issued" (pad + masked) token sent in the request. Return a
   225  	// nil byte slice on a decoding error (this will fail upstream).
   226  	decoded, err := base64.StdEncoding.DecodeString(issued)
   227  	if err != nil {
   228  		return nil
   229  	}
   230  
   231  	return decoded
   232  }