github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/web/middlewares/csrf.go (about) 1 package middlewares 2 3 import ( 4 "crypto/subtle" 5 "errors" 6 "net/http" 7 "strings" 8 9 "github.com/cozy/cozy-stack/pkg/utils" 10 "github.com/labstack/echo/v4" 11 "github.com/labstack/echo/v4/middleware" 12 ) 13 14 type ( 15 // CSRFConfig defines the config for CSRF middleware. 16 CSRFConfig struct { 17 // Skipper defines a function to skip middleware. 18 Skipper middleware.Skipper 19 20 // TokenLength is the length of the generated token. 21 TokenLength int `yaml:"token_length"` 22 // Optional. Default value 32. 23 24 // TokenLookup is a string in the form of "<source>:<key>" that is used 25 // to extract token from the request. 26 // Optional. Default value "header:X-CSRF-Token". 27 // Possible values: 28 // - "header:<name>" 29 // - "form:<name>" 30 // - "query:<name>" 31 TokenLookup string `yaml:"token_lookup"` 32 33 // Context key to store generated CSRF token into context. 34 // Optional. Default value "csrf". 35 ContextKey string `yaml:"context_key"` 36 37 // Name of the CSRF cookie. This cookie will store CSRF token. 38 // Optional. Default value "csrf". 39 CookieName string `yaml:"cookie_name"` 40 41 // Domain of the CSRF cookie. 42 // Optional. Default value none. 43 CookieDomain string `yaml:"cookie_domain"` 44 45 // Path of the CSRF cookie. 46 // Optional. Default value none. 47 CookiePath string `yaml:"cookie_path"` 48 49 // Max age (in seconds) of the CSRF cookie. 50 // Optional. Default value 86400 (24hr). 51 CookieMaxAge int `yaml:"cookie_max_age"` 52 53 // Indicates if CSRF cookie is secure. 54 // Optional. Default value false. 55 CookieSecure bool `yaml:"cookie_secure"` 56 57 // Indicates if CSRF cookie is HTTP only. 58 // Optional. Default value false. 59 CookieHTTPOnly bool `yaml:"cookie_http_only"` 60 61 // Indicates the sameSite policy for the CSRF cookie. 62 // Optional. Default value is lax. 63 CookieSameSite http.SameSite `yaml:"cookie_same_site"` 64 } 65 66 // csrfTokenExtractor defines a function that takes `echo.Context` and returns 67 // either a token or an error. 68 csrfTokenExtractor func(echo.Context) (string, error) 69 ) 70 71 var ( 72 // DefaultCSRFConfig is the default CSRF middleware config. 73 DefaultCSRFConfig = CSRFConfig{ 74 Skipper: middleware.DefaultSkipper, 75 TokenLength: 32, 76 TokenLookup: "header:" + echo.HeaderXCSRFToken, 77 ContextKey: "csrf", 78 CookieName: "_csrf", 79 CookieMaxAge: 86400, 80 CookieSameSite: http.SameSiteLaxMode, 81 } 82 ) 83 84 // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. 85 // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery 86 func CSRF() echo.MiddlewareFunc { 87 c := DefaultCSRFConfig 88 return CSRFWithConfig(c) 89 } 90 91 // CSRFWithConfig returns a CSRF middleware with config. 92 // See `CSRF()`. 93 func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { 94 // Defaults 95 if config.Skipper == nil { 96 config.Skipper = DefaultCSRFConfig.Skipper 97 } 98 if config.TokenLength == 0 { 99 config.TokenLength = DefaultCSRFConfig.TokenLength 100 } 101 if config.TokenLookup == "" { 102 config.TokenLookup = DefaultCSRFConfig.TokenLookup 103 } 104 if config.ContextKey == "" { 105 config.ContextKey = DefaultCSRFConfig.ContextKey 106 } 107 if config.CookieName == "" { 108 config.CookieName = DefaultCSRFConfig.CookieName 109 } 110 if config.CookieMaxAge == 0 { 111 config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge 112 } 113 if config.CookieSameSite == 0 { 114 config.CookieSameSite = DefaultCSRFConfig.CookieSameSite 115 } 116 117 // Initialize 118 parts := strings.Split(config.TokenLookup, ":") 119 extractor := csrfTokenFromHeader(parts[1]) 120 switch parts[0] { 121 case "form": 122 extractor = csrfTokenFromForm(parts[1]) 123 case "query": 124 extractor = csrfTokenFromQuery(parts[1]) 125 } 126 127 return func(next echo.HandlerFunc) echo.HandlerFunc { 128 return func(c echo.Context) error { 129 if config.Skipper(c) { 130 return next(c) 131 } 132 133 req := c.Request() 134 k, err := c.Cookie(config.CookieName) 135 token := "" 136 137 // Generate token 138 if err != nil { 139 token = utils.RandomString(config.TokenLength) 140 } else { 141 // Reuse token 142 token = k.Value 143 } 144 145 switch req.Method { 146 case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: 147 default: 148 // Validate token only for requests which are not defined as 'safe' by RFC7231 149 clientToken, err := extractor(c) 150 if err != nil { 151 return echo.NewHTTPError(http.StatusBadRequest, err.Error()) 152 } 153 if !validateCSRFToken(token, clientToken) { 154 return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") 155 } 156 } 157 158 // Set CSRF cookie 159 cookie := new(http.Cookie) 160 cookie.Name = config.CookieName 161 cookie.Value = token 162 if config.CookiePath != "" { 163 cookie.Path = config.CookiePath 164 } 165 if config.CookieDomain != "" { 166 cookie.Domain = config.CookieDomain 167 } 168 cookie.MaxAge = config.CookieMaxAge 169 cookie.Secure = config.CookieSecure 170 cookie.HttpOnly = config.CookieHTTPOnly 171 cookie.SameSite = config.CookieSameSite 172 c.SetCookie(cookie) 173 174 // Store token in the context 175 c.Set(config.ContextKey, token) 176 177 // Protect clients from caching the response 178 c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie) 179 180 return next(c) 181 } 182 } 183 } 184 185 // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the 186 // provided request header. 187 func csrfTokenFromHeader(header string) csrfTokenExtractor { 188 return func(c echo.Context) (string, error) { 189 return c.Request().Header.Get(header), nil 190 } 191 } 192 193 // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the 194 // provided form parameter. 195 func csrfTokenFromForm(param string) csrfTokenExtractor { 196 return func(c echo.Context) (string, error) { 197 token := c.FormValue(param) 198 if token == "" { 199 return "", errors.New("missing csrf token in the form parameter") 200 } 201 return token, nil 202 } 203 } 204 205 // csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the 206 // provided query parameter. 207 func csrfTokenFromQuery(param string) csrfTokenExtractor { 208 return func(c echo.Context) (string, error) { 209 token := c.QueryParam(param) 210 if token == "" { 211 return "", errors.New("missing csrf token in the query string") 212 } 213 return token, nil 214 } 215 } 216 217 func validateCSRFToken(token, clientToken string) bool { 218 return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 219 }