github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/web/middlewares/csrf.go (about)

     1  package middlewares
     2  
     3  import (
     4  	"crypto/subtle"
     5  	"errors"
     6  	"net/http"
     7  	"strings"
     8  
     9  	"github.com/cozy/cozy-stack/pkg/utils"
    10  	"github.com/labstack/echo/v4"
    11  	"github.com/labstack/echo/v4/middleware"
    12  )
    13  
    14  type (
    15  	// CSRFConfig defines the config for CSRF middleware.
    16  	CSRFConfig struct {
    17  		// Skipper defines a function to skip middleware.
    18  		Skipper middleware.Skipper
    19  
    20  		// TokenLength is the length of the generated token.
    21  		TokenLength int `yaml:"token_length"`
    22  		// Optional. Default value 32.
    23  
    24  		// TokenLookup is a string in the form of "<source>:<key>" that is used
    25  		// to extract token from the request.
    26  		// Optional. Default value "header:X-CSRF-Token".
    27  		// Possible values:
    28  		// - "header:<name>"
    29  		// - "form:<name>"
    30  		// - "query:<name>"
    31  		TokenLookup string `yaml:"token_lookup"`
    32  
    33  		// Context key to store generated CSRF token into context.
    34  		// Optional. Default value "csrf".
    35  		ContextKey string `yaml:"context_key"`
    36  
    37  		// Name of the CSRF cookie. This cookie will store CSRF token.
    38  		// Optional. Default value "csrf".
    39  		CookieName string `yaml:"cookie_name"`
    40  
    41  		// Domain of the CSRF cookie.
    42  		// Optional. Default value none.
    43  		CookieDomain string `yaml:"cookie_domain"`
    44  
    45  		// Path of the CSRF cookie.
    46  		// Optional. Default value none.
    47  		CookiePath string `yaml:"cookie_path"`
    48  
    49  		// Max age (in seconds) of the CSRF cookie.
    50  		// Optional. Default value 86400 (24hr).
    51  		CookieMaxAge int `yaml:"cookie_max_age"`
    52  
    53  		// Indicates if CSRF cookie is secure.
    54  		// Optional. Default value false.
    55  		CookieSecure bool `yaml:"cookie_secure"`
    56  
    57  		// Indicates if CSRF cookie is HTTP only.
    58  		// Optional. Default value false.
    59  		CookieHTTPOnly bool `yaml:"cookie_http_only"`
    60  
    61  		// Indicates the sameSite policy for the CSRF cookie.
    62  		// Optional. Default value is lax.
    63  		CookieSameSite http.SameSite `yaml:"cookie_same_site"`
    64  	}
    65  
    66  	// csrfTokenExtractor defines a function that takes `echo.Context` and returns
    67  	// either a token or an error.
    68  	csrfTokenExtractor func(echo.Context) (string, error)
    69  )
    70  
    71  var (
    72  	// DefaultCSRFConfig is the default CSRF middleware config.
    73  	DefaultCSRFConfig = CSRFConfig{
    74  		Skipper:        middleware.DefaultSkipper,
    75  		TokenLength:    32,
    76  		TokenLookup:    "header:" + echo.HeaderXCSRFToken,
    77  		ContextKey:     "csrf",
    78  		CookieName:     "_csrf",
    79  		CookieMaxAge:   86400,
    80  		CookieSameSite: http.SameSiteLaxMode,
    81  	}
    82  )
    83  
    84  // CSRF returns a Cross-Site Request Forgery (CSRF) middleware.
    85  // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery
    86  func CSRF() echo.MiddlewareFunc {
    87  	c := DefaultCSRFConfig
    88  	return CSRFWithConfig(c)
    89  }
    90  
    91  // CSRFWithConfig returns a CSRF middleware with config.
    92  // See `CSRF()`.
    93  func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
    94  	// Defaults
    95  	if config.Skipper == nil {
    96  		config.Skipper = DefaultCSRFConfig.Skipper
    97  	}
    98  	if config.TokenLength == 0 {
    99  		config.TokenLength = DefaultCSRFConfig.TokenLength
   100  	}
   101  	if config.TokenLookup == "" {
   102  		config.TokenLookup = DefaultCSRFConfig.TokenLookup
   103  	}
   104  	if config.ContextKey == "" {
   105  		config.ContextKey = DefaultCSRFConfig.ContextKey
   106  	}
   107  	if config.CookieName == "" {
   108  		config.CookieName = DefaultCSRFConfig.CookieName
   109  	}
   110  	if config.CookieMaxAge == 0 {
   111  		config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
   112  	}
   113  	if config.CookieSameSite == 0 {
   114  		config.CookieSameSite = DefaultCSRFConfig.CookieSameSite
   115  	}
   116  
   117  	// Initialize
   118  	parts := strings.Split(config.TokenLookup, ":")
   119  	extractor := csrfTokenFromHeader(parts[1])
   120  	switch parts[0] {
   121  	case "form":
   122  		extractor = csrfTokenFromForm(parts[1])
   123  	case "query":
   124  		extractor = csrfTokenFromQuery(parts[1])
   125  	}
   126  
   127  	return func(next echo.HandlerFunc) echo.HandlerFunc {
   128  		return func(c echo.Context) error {
   129  			if config.Skipper(c) {
   130  				return next(c)
   131  			}
   132  
   133  			req := c.Request()
   134  			k, err := c.Cookie(config.CookieName)
   135  			token := ""
   136  
   137  			// Generate token
   138  			if err != nil {
   139  				token = utils.RandomString(config.TokenLength)
   140  			} else {
   141  				// Reuse token
   142  				token = k.Value
   143  			}
   144  
   145  			switch req.Method {
   146  			case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
   147  			default:
   148  				// Validate token only for requests which are not defined as 'safe' by RFC7231
   149  				clientToken, err := extractor(c)
   150  				if err != nil {
   151  					return echo.NewHTTPError(http.StatusBadRequest, err.Error())
   152  				}
   153  				if !validateCSRFToken(token, clientToken) {
   154  					return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
   155  				}
   156  			}
   157  
   158  			// Set CSRF cookie
   159  			cookie := new(http.Cookie)
   160  			cookie.Name = config.CookieName
   161  			cookie.Value = token
   162  			if config.CookiePath != "" {
   163  				cookie.Path = config.CookiePath
   164  			}
   165  			if config.CookieDomain != "" {
   166  				cookie.Domain = config.CookieDomain
   167  			}
   168  			cookie.MaxAge = config.CookieMaxAge
   169  			cookie.Secure = config.CookieSecure
   170  			cookie.HttpOnly = config.CookieHTTPOnly
   171  			cookie.SameSite = config.CookieSameSite
   172  			c.SetCookie(cookie)
   173  
   174  			// Store token in the context
   175  			c.Set(config.ContextKey, token)
   176  
   177  			// Protect clients from caching the response
   178  			c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
   179  
   180  			return next(c)
   181  		}
   182  	}
   183  }
   184  
   185  // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
   186  // provided request header.
   187  func csrfTokenFromHeader(header string) csrfTokenExtractor {
   188  	return func(c echo.Context) (string, error) {
   189  		return c.Request().Header.Get(header), nil
   190  	}
   191  }
   192  
   193  // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
   194  // provided form parameter.
   195  func csrfTokenFromForm(param string) csrfTokenExtractor {
   196  	return func(c echo.Context) (string, error) {
   197  		token := c.FormValue(param)
   198  		if token == "" {
   199  			return "", errors.New("missing csrf token in the form parameter")
   200  		}
   201  		return token, nil
   202  	}
   203  }
   204  
   205  // csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
   206  // provided query parameter.
   207  func csrfTokenFromQuery(param string) csrfTokenExtractor {
   208  	return func(c echo.Context) (string, error) {
   209  		token := c.QueryParam(param)
   210  		if token == "" {
   211  			return "", errors.New("missing csrf token in the query string")
   212  		}
   213  		return token, nil
   214  	}
   215  }
   216  
   217  func validateCSRFToken(token, clientToken string) bool {
   218  	return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
   219  }