github.com/cheikhshift/buffalo@v0.9.5/middleware/csrf/csrf.go (about)

     1  package csrf
     2  
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"runtime"
	"strings"

	"github.com/gobuffalo/buffalo"
	"github.com/gobuffalo/envy"
	"github.com/markbates/going/defaults"
)
    18  
const (
	// CSRF token length in bytes.
	tokenLength int = 32
	// tokenKey is the session key under which the real (unmasked) token is
	// stored. It is also the initial value of fieldName.
	tokenKey string = "authenticity_token"
)
    24  
var (
	// The name value used in form fields. The masked token is also stored in
	// the request context under this name.
	fieldName = tokenKey

	// The HTTP request header to inspect
	headerName = "X-CSRF-Token"

	// Safe methods as defined by RFC 7231 section 4.2.1. Requests using any
	// other method must present a valid CSRF token.
	safeMethods = []string{"GET", "HEAD", "OPTIONS", "TRACE"}
	// htmlTypes lists the substrings looked for in the request's Content-Type
	// header (or, when that is empty, its Accept header). A request whose
	// header value is non-empty and matches none of them bypasses the CSRF
	// checks.
	htmlTypes = []string{"html", "form", "plain"}
)
    36  
var (
	// ErrNoReferer is returned when a request with a non-safe method and an
	// "https" URL scheme carries a Referer header that is empty or cannot be
	// parsed.
	ErrNoReferer = errors.New("referer not supplied")
	// ErrBadReferer is returned when the scheme & host in the URL do not match
	// the supplied Referer header.
	ErrBadReferer = errors.New("referer invalid")
	// ErrNoToken is returned if no usable CSRF token is supplied in the
	// request: the token is absent, is not valid base64, or does not decode
	// to exactly tokenLength*2 bytes.
	ErrNoToken = errors.New("CSRF token not found in request")
	// ErrBadToken is returned if the unmasked CSRF token from the request does
	// not match the token in the session.
	ErrBadToken = errors.New("CSRF token invalid")
)
    50  
    51  // Middleware is deprecated, and will be removed in v0.10.0. Use csrf.New instead.
    52  var Middleware = func(next buffalo.Handler) buffalo.Handler {
    53  	warningMsg := "csrf.Middleware is deprecated, and will be removed in v0.10.0. Use csrf.New instead."
    54  	_, file, no, ok := runtime.Caller(1)
    55  	if ok {
    56  		warningMsg = fmt.Sprintf("%s Called from %s:%d", warningMsg, file, no)
    57  	}
    58  	return New(next)
    59  }
    60  
    61  // New enable CSRF protection on routes using this middleware.
    62  // This middleware is adapted from gorilla/csrf
    63  var New = func(next buffalo.Handler) buffalo.Handler {
    64  	return func(c buffalo.Context) error {
    65  		// don't run in test mode
    66  		if envy.Get("GO_ENV", "development") == "test" {
    67  			return next(c)
    68  		}
    69  
    70  		req := c.Request()
    71  
    72  		ct := defaults.String(req.Header.Get("Content-Type"), req.Header.Get("Accept"))
    73  		// ignore non-html requests
    74  		if ct != "" && !contains(htmlTypes, ct) {
    75  			return next(c)
    76  		}
    77  
    78  		var realToken []byte
    79  		rawRealToken := c.Session().Get(tokenKey)
    80  
    81  		if rawRealToken == nil || len(rawRealToken.([]byte)) != tokenLength {
    82  			// If the token is missing, or the length if the token is wrong,
    83  			// generate a new token.
    84  			realToken, err := generateRandomBytes(tokenLength)
    85  			if err != nil {
    86  				return err
    87  			}
    88  			// Save the new real token in session
    89  			c.Session().Set(tokenKey, realToken)
    90  		} else {
    91  			realToken = rawRealToken.([]byte)
    92  		}
    93  
    94  		// Set masked token in context data, to be available in template
    95  		c.Set(fieldName, mask(realToken, req))
    96  
    97  		// HTTP methods not defined as idempotent ("safe") under RFC7231 require
    98  		// inspection.
    99  		if !contains(safeMethods, req.Method) {
   100  			// Enforce an origin check for HTTPS connections. As per the Django CSRF
   101  			// implementation (https://goo.gl/vKA7GE) the Referer header is almost
   102  			// always present for same-domain HTTP requests.
   103  			if req.URL.Scheme == "https" {
   104  				// Fetch the Referer value. Call the error handler if it's empty or
   105  				// otherwise fails to parse.
   106  				referer, err := url.Parse(req.Referer())
   107  				if err != nil || referer.String() == "" {
   108  					return ErrNoReferer
   109  				}
   110  
   111  				if !sameOrigin(req.URL, referer) {
   112  					return ErrBadReferer
   113  				}
   114  			}
   115  
   116  			// Retrieve the combined token (pad + masked) token and unmask it.
   117  			requestToken := unmask(requestCSRFToken(req))
   118  
   119  			// Missing token
   120  			if requestToken == nil {
   121  				return ErrNoToken
   122  			}
   123  
   124  			// Compare tokens
   125  			if !compareTokens(requestToken, realToken) {
   126  				return ErrBadToken
   127  			}
   128  		}
   129  
   130  		return next(c)
   131  	}
   132  }
   133  
   134  // generateRandomBytes returns securely generated random bytes.
   135  // It will return an error if the system's secure random number generator
   136  // fails to function correctly.
   137  func generateRandomBytes(n int) ([]byte, error) {
   138  	b := make([]byte, n)
   139  	_, err := rand.Read(b)
   140  	// err == nil only if len(b) == n
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	return b, nil
   146  }
   147  
   148  // sameOrigin returns true if URLs a and b share the same origin. The same
   149  // origin is defined as host (which includes the port) and scheme.
   150  func sameOrigin(a, b *url.URL) bool {
   151  	return (a.Scheme == b.Scheme && a.Host == b.Host)
   152  }
   153  
   154  // contains is a helper function to check if a string exists in a slice - e.g.
   155  // whether a HTTP method exists in a list of safe methods.
   156  func contains(vals []string, s string) bool {
   157  	s = strings.ToLower(s)
   158  	for _, v := range vals {
   159  		if strings.Contains(s, strings.ToLower(v)) {
   160  			return true
   161  		}
   162  	}
   163  
   164  	return false
   165  }
   166  
   167  // compare securely (constant-time) compares the unmasked token from the request
   168  // against the real token from the session.
   169  func compareTokens(a, b []byte) bool {
   170  	// This is required as subtle.ConstantTimeCompare does not check for equal
   171  	// lengths in Go versions prior to 1.3.
   172  	if len(a) != len(b) {
   173  		return false
   174  	}
   175  
   176  	return subtle.ConstantTimeCompare(a, b) == 1
   177  }
   178  
   179  // xorToken XORs tokens ([]byte) to provide unique-per-request CSRF tokens. It
   180  // will return a masked token if the base token is XOR'ed with a one-time-pad.
   181  // An unmasked token will be returned if a masked token is XOR'ed with the
   182  // one-time-pad used to mask it.
   183  func xorToken(a, b []byte) []byte {
   184  	n := len(a)
   185  	if len(b) < n {
   186  		n = len(b)
   187  	}
   188  
   189  	res := make([]byte, n)
   190  
   191  	for i := 0; i < n; i++ {
   192  		res[i] = a[i] ^ b[i]
   193  	}
   194  
   195  	return res
   196  }
   197  
   198  // mask returns a unique-per-request token to mitigate the BREACH attack
   199  // as per http://breachattack.com/#mitigations
   200  //
   201  // The token is generated by XOR'ing a one-time-pad and the base (session) CSRF
   202  // token and returning them together as a 64-byte slice. This effectively
   203  // randomises the token on a per-request basis without breaking multiple browser
   204  // tabs/windows.
   205  func mask(realToken []byte, r *http.Request) string {
   206  	otp, err := generateRandomBytes(tokenLength)
   207  	if err != nil {
   208  		return ""
   209  	}
   210  
   211  	// XOR the OTP with the real token to generate a masked token. Append the
   212  	// OTP to the front of the masked token to allow unmasking in the subsequent
   213  	// request.
   214  	return base64.StdEncoding.EncodeToString(append(otp, xorToken(otp, realToken)...))
   215  }
   216  
   217  // unmask splits the issued token (one-time-pad + masked token) and returns the
   218  // unmasked request token for comparison.
   219  func unmask(issued []byte) []byte {
   220  	// Issued tokens are always masked and combined with the pad.
   221  	if len(issued) != tokenLength*2 {
   222  		return nil
   223  	}
   224  
   225  	// We now know the length of the byte slice.
   226  	otp := issued[tokenLength:]
   227  	masked := issued[:tokenLength]
   228  
   229  	// Unmask the token by XOR'ing it against the OTP used to mask it.
   230  	return xorToken(otp, masked)
   231  }
   232  
   233  // requestCSRFToken gets the CSRF token from either:
   234  // - a HTTP header
   235  // - a form value
   236  // - a multipart form value
   237  func requestCSRFToken(r *http.Request) []byte {
   238  	// 1. Check the HTTP header first.
   239  	issued := r.Header.Get(headerName)
   240  
   241  	// 2. Fall back to the POST (form) value.
   242  	if issued == "" {
   243  		issued = r.PostFormValue(fieldName)
   244  	}
   245  
   246  	// 3. Finally, fall back to the multipart form (if set).
   247  	if issued == "" && r.MultipartForm != nil {
   248  		vals := r.MultipartForm.Value[fieldName]
   249  
   250  		if len(vals) > 0 {
   251  			issued = vals[0]
   252  		}
   253  	}
   254  
   255  	// Decode the "issued" (pad + masked) token sent in the request. Return a
   256  	// nil byte slice on a decoding error (this will fail upstream).
   257  	decoded, err := base64.StdEncoding.DecodeString(issued)
   258  	if err != nil {
   259  		return nil
   260  	}
   261  
   262  	return decoded
   263  }