github.com/mayra-cabrera/buffalo@v0.9.4-0.20170814145312-66d2e7772f11/middleware/csrf/csrf.go (about)

     1  package csrf
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/subtle"
     6  	"encoding/base64"
     7  	"errors"
     8  	"net/http"
     9  	"net/url"
    10  	"strings"
    11  
    12  	"github.com/gobuffalo/buffalo"
    13  )
    14  
    15  const (
    16  	// CSRF token length in bytes.
    17  	tokenLength int    = 32
    18  	tokenKey    string = "authenticity_token"
    19  )
    20  
    21  var (
    22  	// The name value used in form fields.
    23  	fieldName = tokenKey
    24  
    25  	// The HTTP request header to inspect
    26  	headerName = "X-CSRF-Token"
    27  
    28  	// Idempotent (safe) methods as defined by RFC7231 section 4.2.2.
    29  	safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"}
    30  	htmlTypes   = []string{"html", "form", "plain"}
    31  )
    32  
    33  var (
    34  	// ErrNoReferer is returned when a HTTPS request provides an empty Referer
    35  	// header.
    36  	ErrNoReferer = errors.New("referer not supplied")
    37  	// ErrBadReferer is returned when the scheme & host in the URL do not match
    38  	// the supplied Referer header.
    39  	ErrBadReferer = errors.New("referer invalid")
    40  	// ErrNoToken is returned if no CSRF token is supplied in the request.
    41  	ErrNoToken = errors.New("CSRF token not found in request")
    42  	// ErrBadToken is returned if the CSRF token in the request does not match
    43  	// the token in the session, or is otherwise malformed.
    44  	ErrBadToken = errors.New("CSRF token invalid")
    45  )
    46  
    47  // Middleware enable CSRF protection on routes using this middleware.
    48  // This middleware is adapted from gorilla/csrf
    49  var Middleware = func(next buffalo.Handler) buffalo.Handler {
    50  	return func(c buffalo.Context) error {
    51  		req := c.Request()
    52  
    53  		ct := req.Header.Get("Content-Type")
    54  		// ignore non-html requests
    55  		if ct != "" && !contains(htmlTypes, ct) {
    56  			return next(c)
    57  		}
    58  
    59  		var realToken []byte
    60  		rawRealToken := c.Session().Get(tokenKey)
    61  
    62  		if rawRealToken == nil || len(rawRealToken.([]byte)) != tokenLength {
    63  			// If the token is missing, or the length if the token is wrong,
    64  			// generate a new token.
    65  			realToken, err := generateRandomBytes(tokenLength)
    66  			if err != nil {
    67  				return err
    68  			}
    69  			// Save the new real token in session
    70  			c.Session().Set(tokenKey, realToken)
    71  		} else {
    72  			realToken = rawRealToken.([]byte)
    73  		}
    74  
    75  		// Set masked token in context data, to be available in template
    76  		c.Set(fieldName, mask(realToken, req))
    77  
    78  		// HTTP methods not defined as idempotent ("safe") under RFC7231 require
    79  		// inspection.
    80  		if !contains(safeMethods, req.Method) {
    81  			// Enforce an origin check for HTTPS connections. As per the Django CSRF
    82  			// implementation (https://goo.gl/vKA7GE) the Referer header is almost
    83  			// always present for same-domain HTTP requests.
    84  			if req.URL.Scheme == "https" {
    85  				// Fetch the Referer value. Call the error handler if it's empty or
    86  				// otherwise fails to parse.
    87  				referer, err := url.Parse(req.Referer())
    88  				if err != nil || referer.String() == "" {
    89  					return ErrNoReferer
    90  				}
    91  
    92  				if !sameOrigin(req.URL, referer) {
    93  					return ErrBadReferer
    94  				}
    95  			}
    96  
    97  			// Retrieve the combined token (pad + masked) token and unmask it.
    98  			requestToken := unmask(requestCSRFToken(req))
    99  
   100  			// Missing token
   101  			if requestToken == nil {
   102  				return ErrNoToken
   103  			}
   104  
   105  			// Compare tokens
   106  			if !compareTokens(requestToken, realToken) {
   107  				return ErrBadToken
   108  			}
   109  		}
   110  
   111  		return next(c)
   112  	}
   113  }
   114  
   115  // generateRandomBytes returns securely generated random bytes.
   116  // It will return an error if the system's secure random number generator
   117  // fails to function correctly.
   118  func generateRandomBytes(n int) ([]byte, error) {
   119  	b := make([]byte, n)
   120  	_, err := rand.Read(b)
   121  	// err == nil only if len(b) == n
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	return b, nil
   127  }
   128  
   129  // sameOrigin returns true if URLs a and b share the same origin. The same
   130  // origin is defined as host (which includes the port) and scheme.
   131  func sameOrigin(a, b *url.URL) bool {
   132  	return (a.Scheme == b.Scheme && a.Host == b.Host)
   133  }
   134  
   135  // contains is a helper function to check if a string exists in a slice - e.g.
   136  // whether a HTTP method exists in a list of safe methods.
   137  func contains(vals []string, s string) bool {
   138  	s = strings.ToLower(s)
   139  	for _, v := range vals {
   140  		if strings.Contains(s, strings.ToLower(v)) {
   141  			return true
   142  		}
   143  	}
   144  
   145  	return false
   146  }
   147  
   148  // compare securely (constant-time) compares the unmasked token from the request
   149  // against the real token from the session.
   150  func compareTokens(a, b []byte) bool {
   151  	// This is required as subtle.ConstantTimeCompare does not check for equal
   152  	// lengths in Go versions prior to 1.3.
   153  	if len(a) != len(b) {
   154  		return false
   155  	}
   156  
   157  	return subtle.ConstantTimeCompare(a, b) == 1
   158  }
   159  
   160  // xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It
   161  // will return a masked token if the base token is XOR'ed with a one-time-pad.
   162  // An unmasked token will be returned if a masked token is XOR'ed with the
   163  // one-time-pad used to mask it.
   164  func xorToken(a, b []byte) []byte {
   165  	n := len(a)
   166  	if len(b) < n {
   167  		n = len(b)
   168  	}
   169  
   170  	res := make([]byte, n)
   171  
   172  	for i := 0; i < n; i++ {
   173  		res[i] = a[i] ^ b[i]
   174  	}
   175  
   176  	return res
   177  }
   178  
   179  // mask returns a unique-per-request token to mitigate the BREACH attack
   180  // as per http://breachattack.com/#mitigations
   181  //
   182  // The token is generated by XOR'ing a one-time-pad and the base (session) CSRF
   183  // token and returning them together as a 64-byte slice. This effectively
   184  // randomises the token on a per-request basis without breaking multiple browser
   185  // tabs/windows.
   186  func mask(realToken []byte, r *http.Request) string {
   187  	otp, err := generateRandomBytes(tokenLength)
   188  	if err != nil {
   189  		return ""
   190  	}
   191  
   192  	// XOR the OTP with the real token to generate a masked token. Append the
   193  	// OTP to the front of the masked token to allow unmasking in the subsequent
   194  	// request.
   195  	return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...))
   196  }
   197  
   198  // unmask splits the issued token (one-time-pad + masked token) and returns the
   199  // unmasked request token for comparison.
   200  func unmask(issued []byte) []byte {
   201  	// Issued tokens are always masked and combined with the pad.
   202  	if len(issued) != tokenLength*2 {
   203  		return nil
   204  	}
   205  
   206  	// We now know the length of the byte slice.
   207  	otp := issued[tokenLength:]
   208  	masked := issued[:tokenLength]
   209  
   210  	// Unmask the token by XOR'ing it against the OTP used to mask it.
   211  	return xorToken(otp, masked)
   212  }
   213  
   214  // requestCSRFToken gets the CSRF token from either:
   215  // - a HTTP header
   216  // - a form value
   217  // - a multipart form value
   218  func requestCSRFToken(r *http.Request) []byte {
   219  	// 1. Check the HTTP header first.
   220  	issued := r.Header.Get(headerName)
   221  
   222  	// 2. Fall back to the POST (form) value.
   223  	if issued == "" {
   224  		issued = r.PostFormValue(fieldName)
   225  	}
   226  
   227  	// 3. Finally, fall back to the multipart form (if set).
   228  	if issued == "" && r.MultipartForm != nil {
   229  		vals := r.MultipartForm.Value[fieldName]
   230  
   231  		if len(vals) > 0 {
   232  			issued = vals[0]
   233  		}
   234  	}
   235  
   236  	// Decode the "issued" (pad + masked) token sent in the request. Return a
   237  	// nil byte slice on a decoding error (this will fail upstream).
   238  	decoded, err := base64.StdEncoding.DecodeString(issued)
   239  	if err != nil {
   240  		return nil
   241  	}
   242  
   243  	return decoded
   244  }