github.com/jacobsoderblom/buffalo@v0.11.0/middleware/csrf/csrf.go (about)

     1  package csrf
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/subtle"
     6  	"encoding/base64"
     7  	"errors"
     8  	"net/http"
     9  	"net/url"
    10  	"strings"
    11  
    12  	"github.com/gobuffalo/buffalo"
    13  	"github.com/gobuffalo/envy"
    14  	"github.com/markbates/going/defaults"
    15  )
    16  
    17  const (
    18  	// CSRF token length in bytes.
    19  	tokenLength int    = 32
    20  	tokenKey    string = "authenticity_token"
    21  )
    22  
    23  var (
    24  	// The name value used in form fields.
    25  	fieldName = tokenKey
    26  
    27  	// The HTTP request header to inspect
    28  	headerName = "X-CSRF-Token"
    29  
    30  	// Idempotent (safe) methods as defined by RFC7231 section 4.2.2.
    31  	safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"}
    32  	htmlTypes   = []string{"html", "form", "plain", "*/*"}
    33  )
    34  
    35  var (
    36  	// ErrNoReferer is returned when a HTTPS request provides an empty Referer
    37  	// header.
    38  	ErrNoReferer = errors.New("referer not supplied")
    39  	// ErrBadReferer is returned when the scheme & host in the URL do not match
    40  	// the supplied Referer header.
    41  	ErrBadReferer = errors.New("referer invalid")
    42  	// ErrNoToken is returned if no CSRF token is supplied in the request.
    43  	ErrNoToken = errors.New("CSRF token not found in request")
    44  	// ErrBadToken is returned if the CSRF token in the request does not match
    45  	// the token in the session, or is otherwise malformed.
    46  	ErrBadToken = errors.New("CSRF token invalid")
    47  )
    48  
    49  // New enable CSRF protection on routes using this middleware.
    50  // This middleware is adapted from gorilla/csrf
    51  var New = func(next buffalo.Handler) buffalo.Handler {
    52  	return func(c buffalo.Context) error {
    53  		// don't run in test mode
    54  		if envy.Get("GO_ENV", "development") == "test" {
    55  			c.Set(tokenKey, "test")
    56  			return next(c)
    57  		}
    58  
    59  		req := c.Request()
    60  
    61  		ct := defaults.String(req.Header.Get("Content-Type"), req.Header.Get("Accept"))
    62  		// ignore non-html requests
    63  		if ct != "" && !contains(htmlTypes, ct) {
    64  			return next(c)
    65  		}
    66  
    67  		var realToken []byte
    68  		rawRealToken := c.Session().Get(tokenKey)
    69  
    70  		if rawRealToken == nil || len(rawRealToken.([]byte)) != tokenLength {
    71  			// If the token is missing, or the length if the token is wrong,
    72  			// generate a new token.
    73  			realToken, err := generateRandomBytes(tokenLength)
    74  			if err != nil {
    75  				return err
    76  			}
    77  			// Save the new real token in session
    78  			c.Session().Set(tokenKey, realToken)
    79  		} else {
    80  			realToken = rawRealToken.([]byte)
    81  		}
    82  
    83  		// Set masked token in context data, to be available in template
    84  		c.Set(fieldName, mask(realToken, req))
    85  
    86  		// HTTP methods not defined as idempotent ("safe") under RFC7231 require
    87  		// inspection.
    88  		if !contains(safeMethods, req.Method) {
    89  			// Enforce an origin check for HTTPS connections. As per the Django CSRF
    90  			// implementation (https://goo.gl/vKA7GE) the Referer header is almost
    91  			// always present for same-domain HTTP requests.
    92  			if req.URL.Scheme == "https" {
    93  				// Fetch the Referer value. Call the error handler if it's empty or
    94  				// otherwise fails to parse.
    95  				referer, err := url.Parse(req.Referer())
    96  				if err != nil || referer.String() == "" {
    97  					return ErrNoReferer
    98  				}
    99  
   100  				if !sameOrigin(req.URL, referer) {
   101  					return ErrBadReferer
   102  				}
   103  			}
   104  
   105  			// Retrieve the combined token (pad + masked) token and unmask it.
   106  			requestToken := unmask(requestCSRFToken(req))
   107  
   108  			// Missing token
   109  			if requestToken == nil {
   110  				return ErrNoToken
   111  			}
   112  
   113  			// Compare tokens
   114  			if !compareTokens(requestToken, realToken) {
   115  				return ErrBadToken
   116  			}
   117  		}
   118  
   119  		return next(c)
   120  	}
   121  }
   122  
   123  // generateRandomBytes returns securely generated random bytes.
   124  // It will return an error if the system's secure random number generator
   125  // fails to function correctly.
   126  func generateRandomBytes(n int) ([]byte, error) {
   127  	b := make([]byte, n)
   128  	_, err := rand.Read(b)
   129  	// err == nil only if len(b) == n
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  
   134  	return b, nil
   135  }
   136  
   137  // sameOrigin returns true if URLs a and b share the same origin. The same
   138  // origin is defined as host (which includes the port) and scheme.
   139  func sameOrigin(a, b *url.URL) bool {
   140  	return (a.Scheme == b.Scheme && a.Host == b.Host)
   141  }
   142  
   143  // contains is a helper function to check if a string exists in a slice - e.g.
   144  // whether a HTTP method exists in a list of safe methods.
   145  func contains(vals []string, s string) bool {
   146  	s = strings.ToLower(s)
   147  	for _, v := range vals {
   148  		if strings.Contains(s, strings.ToLower(v)) {
   149  			return true
   150  		}
   151  	}
   152  
   153  	return false
   154  }
   155  
   156  // compare securely (constant-time) compares the unmasked token from the request
   157  // against the real token from the session.
   158  func compareTokens(a, b []byte) bool {
   159  	// This is required as subtle.ConstantTimeCompare does not check for equal
   160  	// lengths in Go versions prior to 1.3.
   161  	if len(a) != len(b) {
   162  		return false
   163  	}
   164  
   165  	return subtle.ConstantTimeCompare(a, b) == 1
   166  }
   167  
   168  // xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It
   169  // will return a masked token if the base token is XOR'ed with a one-time-pad.
   170  // An unmasked token will be returned if a masked token is XOR'ed with the
   171  // one-time-pad used to mask it.
   172  func xorToken(a, b []byte) []byte {
   173  	n := len(a)
   174  	if len(b) < n {
   175  		n = len(b)
   176  	}
   177  
   178  	res := make([]byte, n)
   179  
   180  	for i := 0; i < n; i++ {
   181  		res[i] = a[i] ^ b[i]
   182  	}
   183  
   184  	return res
   185  }
   186  
   187  // mask returns a unique-per-request token to mitigate the BREACH attack
   188  // as per http://breachattack.com/#mitigations
   189  //
   190  // The token is generated by XOR'ing a one-time-pad and the base (session) CSRF
   191  // token and returning them together as a 64-byte slice. This effectively
   192  // randomises the token on a per-request basis without breaking multiple browser
   193  // tabs/windows.
   194  func mask(realToken []byte, r *http.Request) string {
   195  	otp, err := generateRandomBytes(tokenLength)
   196  	if err != nil {
   197  		return ""
   198  	}
   199  
   200  	// XOR the OTP with the real token to generate a masked token. Append the
   201  	// OTP to the front of the masked token to allow unmasking in the subsequent
   202  	// request.
   203  	return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...))
   204  }
   205  
   206  // unmask splits the issued token (one-time-pad + masked token) and returns the
   207  // unmasked request token for comparison.
   208  func unmask(issued []byte) []byte {
   209  	// Issued tokens are always masked and combined with the pad.
   210  	if len(issued) != tokenLength*2 {
   211  		return nil
   212  	}
   213  
   214  	// We now know the length of the byte slice.
   215  	otp := issued[tokenLength:]
   216  	masked := issued[:tokenLength]
   217  
   218  	// Unmask the token by XOR'ing it against the OTP used to mask it.
   219  	return xorToken(otp, masked)
   220  }
   221  
   222  // requestCSRFToken gets the CSRF token from either:
   223  // - a HTTP header
   224  // - a form value
   225  // - a multipart form value
   226  func requestCSRFToken(r *http.Request) []byte {
   227  	// 1. Check the HTTP header first.
   228  	issued := r.Header.Get(headerName)
   229  
   230  	// 2. Fall back to the POST (form) value.
   231  	if issued == "" {
   232  		issued = r.PostFormValue(fieldName)
   233  	}
   234  
   235  	// 3. Finally, fall back to the multipart form (if set).
   236  	if issued == "" && r.MultipartForm != nil {
   237  		vals := r.MultipartForm.Value[fieldName]
   238  
   239  		if len(vals) > 0 {
   240  			issued = vals[0]
   241  		}
   242  	}
   243  
   244  	// Decode the "issued" (pad + masked) token sent in the request. Return a
   245  	// nil byte slice on a decoding error (this will fail upstream).
   246  	decoded, err := base64.StdEncoding.DecodeString(issued)
   247  	if err != nil {
   248  		return nil
   249  	}
   250  
   251  	return decoded
   252  }