git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/httpx/cors/cors.go (about) 1 /* 2 Package cors is net/http handler to handle CORS related requests 3 as defined by http://www.w3.org/TR/cors/ 4 5 You can configure it by passing an option struct to cors.New: 6 7 c := cors.New(cors.Options{ 8 AllowedOrigins: []string{"foo.com"}, 9 AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodDelete}, 10 AllowCredentials: true, 11 }) 12 13 Then insert the handler in the chain: 14 15 handler = c.Handler(handler) 16 17 See Options documentation for more options. 18 19 The resulting handler is a standard net/http handler. 20 */ 21 package cors 22 23 import ( 24 "log" 25 "net/http" 26 "os" 27 "strconv" 28 "strings" 29 ) 30 31 // Options is a configuration container to setup the CORS middleware. 32 type Options struct { 33 // AllowedOrigins is a list of origins a cross-domain request can be executed from. 34 // If the special "*" value is present in the list, all origins will be allowed. 35 // An origin may contain a wildcard (*) to replace 0 or more characters 36 // (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty. 37 // Only one wildcard can be used per origin. 38 // Default value is ["*"] 39 AllowedOrigins []string 40 // AllowOriginFunc is a custom function to validate the origin. It take the origin 41 // as argument and returns true if allowed or false otherwise. If this option is 42 // set, the content of AllowedOrigins is ignored. 43 AllowOriginFunc func(origin string) bool 44 // AllowOriginRequestFunc is a custom function to validate the origin. It takes the HTTP Request object and the origin as 45 // argument and returns true if allowed or false otherwise. If this option is set, the content of `AllowedOrigins` 46 // and `AllowOriginFunc` is ignored. 47 AllowOriginRequestFunc func(r *http.Request, origin string) bool 48 // AllowedMethods is a list of methods the client is allowed to use with 49 // cross-domain requests. Default value is simple methods (HEAD, GET and POST). 50 AllowedMethods []string 51 // AllowedHeaders is list of non simple headers the client is allowed to use with 52 // cross-domain requests. 53 // If the special "*" value is present in the list, all headers will be allowed. 54 // Default value is [] but "Origin" is always appended to the list. 55 AllowedHeaders []string 56 // ExposedHeaders indicates which headers are safe to expose to the API of a CORS 57 // API specification 58 ExposedHeaders []string 59 // MaxAge indicates how long (in seconds) the results of a preflight request 60 // can be cached 61 MaxAge int 62 // AllowCredentials indicates whether the request can include user credentials like 63 // cookies, HTTP authentication or client side SSL certificates. 64 AllowCredentials bool 65 // AllowPrivateNetwork indicates whether to accept cross-origin requests over a 66 // private network. 67 AllowPrivateNetwork bool 68 // OptionsPassthrough instructs preflight to let other potential next handlers to 69 // process the OPTIONS method. Turn this on if your application handles OPTIONS. 70 OptionsPassthrough bool 71 // Provides a status code to use for successful OPTIONS requests. 72 // Default value is http.StatusNoContent (204). 73 OptionsSuccessStatus int 74 // Debugging flag adds additional output to debug server side CORS issues 75 Debug bool 76 } 77 78 // Logger generic interface for logger 79 type Logger interface { 80 Printf(string, ...interface{}) 81 } 82 83 // Cors http handler 84 type Cors struct { 85 // Debug logger 86 Log Logger 87 // Normalized list of plain allowed origins 88 allowedOrigins []string 89 // List of allowed origins containing wildcards 90 allowedWOrigins []wildcard 91 // Optional origin validator function 92 allowOriginFunc func(origin string) bool 93 // Optional origin validator (with request) function 94 allowOriginRequestFunc func(r *http.Request, origin string) bool 95 // Normalized list of allowed headers 96 allowedHeaders []string 97 // Normalized list of allowed methods 98 allowedMethods []string 99 // Normalized list of exposed headers 100 exposedHeaders []string 101 maxAge int 102 // Set to true when allowed origins contains a "*" 103 allowedOriginsAll bool 104 // Set to true when allowed headers contains a "*" 105 allowedHeadersAll bool 106 // Status code to use for successful OPTIONS requests 107 optionsSuccessStatus int 108 allowCredentials bool 109 allowPrivateNetwork bool 110 optionPassthrough bool 111 } 112 113 // New creates a new Cors handler with the provided options. 114 func New(options Options) *Cors { 115 c := &Cors{ 116 exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey), 117 allowOriginFunc: options.AllowOriginFunc, 118 allowOriginRequestFunc: options.AllowOriginRequestFunc, 119 allowCredentials: options.AllowCredentials, 120 allowPrivateNetwork: options.AllowPrivateNetwork, 121 maxAge: options.MaxAge, 122 optionPassthrough: options.OptionsPassthrough, 123 } 124 if options.Debug && c.Log == nil { 125 c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags) 126 } 127 128 // Normalize options 129 // Note: for origins and methods matching, the spec requires a case-sensitive matching. 130 // As it may error prone, we chose to ignore the spec here. 131 132 // Allowed Origins 133 if len(options.AllowedOrigins) == 0 { 134 if options.AllowOriginFunc == nil && options.AllowOriginRequestFunc == nil { 135 // Default is all origins 136 c.allowedOriginsAll = true 137 } 138 } else { 139 c.allowedOrigins = []string{} 140 c.allowedWOrigins = []wildcard{} 141 for _, origin := range options.AllowedOrigins { 142 // Normalize 143 origin = strings.ToLower(origin) 144 if origin == "*" { 145 // If "*" is present in the list, turn the whole list into a match all 146 c.allowedOriginsAll = true 147 c.allowedOrigins = nil 148 c.allowedWOrigins = nil 149 break 150 } else if i := strings.IndexByte(origin, '*'); i >= 0 { 151 // Split the origin in two: start and end string without the * 152 w := wildcard{origin[0:i], origin[i+1:]} 153 c.allowedWOrigins = append(c.allowedWOrigins, w) 154 } else { 155 c.allowedOrigins = append(c.allowedOrigins, origin) 156 } 157 } 158 } 159 160 // Allowed Headers 161 if len(options.AllowedHeaders) == 0 { 162 // Use sensible defaults 163 c.allowedHeaders = []string{"Origin", "Accept", "Content-Type", "X-Requested-With"} 164 } else { 165 // Origin is always appended as some browsers will always request for this header at preflight 166 c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey) 167 for _, h := range options.AllowedHeaders { 168 if h == "*" { 169 c.allowedHeadersAll = true 170 c.allowedHeaders = nil 171 break 172 } 173 } 174 } 175 176 // Allowed Methods 177 if len(options.AllowedMethods) == 0 { 178 // Default is spec's "simple" methods 179 c.allowedMethods = []string{http.MethodGet, http.MethodPost, http.MethodHead} 180 } else { 181 c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper) 182 } 183 184 // Options Success Status Code 185 if options.OptionsSuccessStatus == 0 { 186 c.optionsSuccessStatus = http.StatusNoContent 187 } else { 188 c.optionsSuccessStatus = options.OptionsSuccessStatus 189 } 190 191 return c 192 } 193 194 // Default creates a new Cors handler with default options. 195 func Default() *Cors { 196 return New(Options{}) 197 } 198 199 // AllowAll create a new Cors handler with permissive configuration allowing all 200 // origins with all standard methods with any header and credentials. 201 func AllowAll() *Cors { 202 return New(Options{ 203 AllowedOrigins: []string{"*"}, 204 AllowedMethods: []string{ 205 http.MethodHead, 206 http.MethodGet, 207 http.MethodPost, 208 http.MethodPut, 209 http.MethodPatch, 210 http.MethodDelete, 211 }, 212 AllowedHeaders: []string{"*"}, 213 AllowCredentials: false, 214 }) 215 } 216 217 // Handler apply the CORS specification on the request, and add relevant CORS headers 218 // as necessary. 219 func (c *Cors) Handler(h http.Handler) http.Handler { 220 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 221 if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { 222 c.logf("Handler: Preflight request") 223 c.handlePreflight(w, r) 224 // Preflight requests are standalone and should stop the chain as some other 225 // middleware may not handle OPTIONS requests correctly. One typical example 226 // is authentication middleware ; OPTIONS requests won't carry authentication 227 // headers (see #1) 228 if c.optionPassthrough { 229 h.ServeHTTP(w, r) 230 } else { 231 w.WriteHeader(c.optionsSuccessStatus) 232 } 233 } else { 234 c.logf("Handler: Actual request") 235 c.handleActualRequest(w, r) 236 h.ServeHTTP(w, r) 237 } 238 }) 239 } 240 241 // HandlerFunc provides Martini compatible handler 242 func (c *Cors) HandlerFunc(w http.ResponseWriter, r *http.Request) { 243 if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { 244 c.logf("HandlerFunc: Preflight request") 245 c.handlePreflight(w, r) 246 247 w.WriteHeader(c.optionsSuccessStatus) 248 } else { 249 c.logf("HandlerFunc: Actual request") 250 c.handleActualRequest(w, r) 251 } 252 } 253 254 // Negroni compatible interface 255 func (c *Cors) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { 256 if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { 257 c.logf("ServeHTTP: Preflight request") 258 c.handlePreflight(w, r) 259 // Preflight requests are standalone and should stop the chain as some other 260 // middleware may not handle OPTIONS requests correctly. One typical example 261 // is authentication middleware ; OPTIONS requests won't carry authentication 262 // headers (see #1) 263 if c.optionPassthrough { 264 next(w, r) 265 } else { 266 w.WriteHeader(c.optionsSuccessStatus) 267 } 268 } else { 269 c.logf("ServeHTTP: Actual request") 270 c.handleActualRequest(w, r) 271 next(w, r) 272 } 273 } 274 275 // handlePreflight handles pre-flight CORS requests 276 func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { 277 headers := w.Header() 278 origin := r.Header.Get("Origin") 279 280 if r.Method != http.MethodOptions { 281 c.logf(" Preflight aborted: %s!=OPTIONS", r.Method) 282 return 283 } 284 // Always set Vary headers 285 // see https://git.sr.ht/~pingoo/stdx/cors/issues/10, 286 // https://git.sr.ht/~pingoo/stdx/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 287 headers.Add("Vary", "Origin") 288 headers.Add("Vary", "Access-Control-Request-Method") 289 headers.Add("Vary", "Access-Control-Request-Headers") 290 if c.allowPrivateNetwork { 291 headers.Add("Vary", "Access-Control-Request-Private-Network") 292 } 293 294 if origin == "" { 295 c.logf(" Preflight aborted: empty origin") 296 return 297 } 298 if !c.isOriginAllowed(r, origin) { 299 c.logf(" Preflight aborted: origin '%s' not allowed", origin) 300 return 301 } 302 303 reqMethod := r.Header.Get("Access-Control-Request-Method") 304 if !c.isMethodAllowed(reqMethod) { 305 c.logf(" Preflight aborted: method '%s' not allowed", reqMethod) 306 return 307 } 308 reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers")) 309 if !c.areHeadersAllowed(reqHeaders) { 310 c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders) 311 return 312 } 313 if c.allowedOriginsAll { 314 headers.Set("Access-Control-Allow-Origin", "*") 315 } else { 316 headers.Set("Access-Control-Allow-Origin", origin) 317 } 318 // Spec says: Since the list of methods can be unbounded, simply returning the method indicated 319 // by Access-Control-Request-Method (if supported) can be enough 320 headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod)) 321 if len(reqHeaders) > 0 { 322 323 // Spec says: Since the list of headers can be unbounded, simply returning supported headers 324 // from Access-Control-Request-Headers can be enough 325 headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", ")) 326 } 327 if c.allowCredentials { 328 headers.Set("Access-Control-Allow-Credentials", "true") 329 } 330 if c.allowPrivateNetwork && r.Header.Get("Access-Control-Request-Private-Network") == "true" { 331 headers.Set("Access-Control-Allow-Private-Network", "true") 332 } 333 if c.maxAge > 0 { 334 headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge)) 335 } 336 c.logf(" Preflight response headers: %v", headers) 337 } 338 339 // handleActualRequest handles simple cross-origin requests, actual request or redirects 340 func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { 341 headers := w.Header() 342 origin := r.Header.Get("Origin") 343 344 // Always set Vary, see https://git.sr.ht/~pingoo/stdx/cors/issues/10 345 headers.Add("Vary", "Origin") 346 if origin == "" { 347 c.logf(" Actual request no headers added: missing origin") 348 return 349 } 350 if !c.isOriginAllowed(r, origin) { 351 c.logf(" Actual request no headers added: origin '%s' not allowed", origin) 352 return 353 } 354 355 // Note that spec does define a way to specifically disallow a simple method like GET or 356 // POST. Access-Control-Allow-Methods is only used for pre-flight requests and the 357 // spec doesn't instruct to check the allowed methods for simple cross-origin requests. 358 // We think it's a nice feature to be able to have control on those methods though. 359 if !c.isMethodAllowed(r.Method) { 360 c.logf(" Actual request no headers added: method '%s' not allowed", r.Method) 361 362 return 363 } 364 if c.allowedOriginsAll { 365 headers.Set("Access-Control-Allow-Origin", "*") 366 } else { 367 headers.Set("Access-Control-Allow-Origin", origin) 368 } 369 if len(c.exposedHeaders) > 0 { 370 headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", ")) 371 } 372 if c.allowCredentials { 373 headers.Set("Access-Control-Allow-Credentials", "true") 374 } 375 c.logf(" Actual response added headers: %v", headers) 376 } 377 378 // convenience method. checks if a logger is set. 379 func (c *Cors) logf(format string, a ...interface{}) { 380 if c.Log != nil { 381 c.Log.Printf(format, a...) 382 } 383 } 384 385 // check the Origin of a request. No origin at all is also allowed. 386 func (c *Cors) OriginAllowed(r *http.Request) bool { 387 origin := r.Header.Get("Origin") 388 return c.isOriginAllowed(r, origin) 389 } 390 391 // isOriginAllowed checks if a given origin is allowed to perform cross-domain requests 392 // on the endpoint 393 func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool { 394 if c.allowOriginRequestFunc != nil { 395 return c.allowOriginRequestFunc(r, origin) 396 } 397 if c.allowOriginFunc != nil { 398 return c.allowOriginFunc(origin) 399 } 400 if c.allowedOriginsAll { 401 return true 402 } 403 origin = strings.ToLower(origin) 404 for _, o := range c.allowedOrigins { 405 if o == origin { 406 return true 407 } 408 } 409 for _, w := range c.allowedWOrigins { 410 if w.match(origin) { 411 return true 412 } 413 } 414 return false 415 } 416 417 // isMethodAllowed checks if a given method can be used as part of a cross-domain request 418 // on the endpoint 419 func (c *Cors) isMethodAllowed(method string) bool { 420 if len(c.allowedMethods) == 0 { 421 // If no method allowed, always return false, even for preflight request 422 return false 423 } 424 method = strings.ToUpper(method) 425 if method == http.MethodOptions { 426 // Always allow preflight requests 427 return true 428 } 429 for _, m := range c.allowedMethods { 430 if m == method { 431 return true 432 } 433 } 434 return false 435 } 436 437 // areHeadersAllowed checks if a given list of headers are allowed to used within 438 // a cross-domain request. 439 func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool { 440 if c.allowedHeadersAll || len(requestedHeaders) == 0 { 441 return true 442 } 443 for _, header := range requestedHeaders { 444 header = http.CanonicalHeaderKey(header) 445 found := false 446 for _, h := range c.allowedHeaders { 447 if h == header { 448 found = true 449 break 450 } 451 } 452 if !found { 453 return false 454 } 455 } 456 return true 457 }