git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/httpx/cors/cors.go (about)

     1  /*
     2  Package cors is a net/http handler for CORS-related requests
     3  as defined by http://www.w3.org/TR/cors/
     4  
     5  You can configure it by passing an option struct to cors.New:
     6  
     7  	c := cors.New(cors.Options{
     8  	    AllowedOrigins:   []string{"foo.com"},
     9  	    AllowedMethods:   []string{http.MethodGet, http.MethodPost, http.MethodDelete},
    10  	    AllowCredentials: true,
    11  	})
    12  
    13  Then insert the handler in the chain:
    14  
    15  	handler = c.Handler(handler)
    16  
    17  See Options documentation for more options.
    18  
    19  The resulting handler is a standard net/http handler.
    20  */
    21  package cors
    22  
    23  import (
    24  	"log"
    25  	"net/http"
    26  	"os"
    27  	"strconv"
    28  	"strings"
    29  )
    30  
    31  // Options is a configuration container to setup the CORS middleware.
    32  type Options struct {
    33  	// AllowedOrigins is a list of origins a cross-domain request can be executed from.
    34  	// If the special "*" value is present in the list, all origins will be allowed.
    35  	// An origin may contain a wildcard (*) to replace 0 or more characters
    36  	// (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty.
    37  	// Only one wildcard can be used per origin.
    38  	// Default value is ["*"]
    39  	AllowedOrigins []string
    40  	// AllowOriginFunc is a custom function to validate the origin. It take the origin
    41  	// as argument and returns true if allowed or false otherwise. If this option is
    42  	// set, the content of AllowedOrigins is ignored.
    43  	AllowOriginFunc func(origin string) bool
    44  	// AllowOriginRequestFunc is a custom function to validate the origin. It takes the HTTP Request object and the origin as
    45  	// argument and returns true if allowed or false otherwise. If this option is set, the content of `AllowedOrigins`
    46  	// and `AllowOriginFunc` is ignored.
    47  	AllowOriginRequestFunc func(r *http.Request, origin string) bool
    48  	// AllowedMethods is a list of methods the client is allowed to use with
    49  	// cross-domain requests. Default value is simple methods (HEAD, GET and POST).
    50  	AllowedMethods []string
    51  	// AllowedHeaders is list of non simple headers the client is allowed to use with
    52  	// cross-domain requests.
    53  	// If the special "*" value is present in the list, all headers will be allowed.
    54  	// Default value is [] but "Origin" is always appended to the list.
    55  	AllowedHeaders []string
    56  	// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
    57  	// API specification
    58  	ExposedHeaders []string
    59  	// MaxAge indicates how long (in seconds) the results of a preflight request
    60  	// can be cached
    61  	MaxAge int
    62  	// AllowCredentials indicates whether the request can include user credentials like
    63  	// cookies, HTTP authentication or client side SSL certificates.
    64  	AllowCredentials bool
    65  	// AllowPrivateNetwork indicates whether to accept cross-origin requests over a
    66  	// private network.
    67  	AllowPrivateNetwork bool
    68  	// OptionsPassthrough instructs preflight to let other potential next handlers to
    69  	// process the OPTIONS method. Turn this on if your application handles OPTIONS.
    70  	OptionsPassthrough bool
    71  	// Provides a status code to use for successful OPTIONS requests.
    72  	// Default value is http.StatusNoContent (204).
    73  	OptionsSuccessStatus int
    74  	// Debugging flag adds additional output to debug server side CORS issues
    75  	Debug bool
    76  }
    77  
    78  // Logger generic interface for logger
    79  type Logger interface {
    80  	Printf(string, ...interface{})
    81  }
    82  
    83  // Cors http handler
    84  type Cors struct {
    85  	// Debug logger
    86  	Log Logger
    87  	// Normalized list of plain allowed origins
    88  	allowedOrigins []string
    89  	// List of allowed origins containing wildcards
    90  	allowedWOrigins []wildcard
    91  	// Optional origin validator function
    92  	allowOriginFunc func(origin string) bool
    93  	// Optional origin validator (with request) function
    94  	allowOriginRequestFunc func(r *http.Request, origin string) bool
    95  	// Normalized list of allowed headers
    96  	allowedHeaders []string
    97  	// Normalized list of allowed methods
    98  	allowedMethods []string
    99  	// Normalized list of exposed headers
   100  	exposedHeaders []string
   101  	maxAge         int
   102  	// Set to true when allowed origins contains a "*"
   103  	allowedOriginsAll bool
   104  	// Set to true when allowed headers contains a "*"
   105  	allowedHeadersAll bool
   106  	// Status code to use for successful OPTIONS requests
   107  	optionsSuccessStatus int
   108  	allowCredentials     bool
   109  	allowPrivateNetwork  bool
   110  	optionPassthrough    bool
   111  }
   112  
   113  // New creates a new Cors handler with the provided options.
   114  func New(options Options) *Cors {
   115  	c := &Cors{
   116  		exposedHeaders:         convert(options.ExposedHeaders, http.CanonicalHeaderKey),
   117  		allowOriginFunc:        options.AllowOriginFunc,
   118  		allowOriginRequestFunc: options.AllowOriginRequestFunc,
   119  		allowCredentials:       options.AllowCredentials,
   120  		allowPrivateNetwork:    options.AllowPrivateNetwork,
   121  		maxAge:                 options.MaxAge,
   122  		optionPassthrough:      options.OptionsPassthrough,
   123  	}
   124  	if options.Debug && c.Log == nil {
   125  		c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags)
   126  	}
   127  
   128  	// Normalize options
   129  	// Note: for origins and methods matching, the spec requires a case-sensitive matching.
   130  	// As it may error prone, we chose to ignore the spec here.
   131  
   132  	// Allowed Origins
   133  	if len(options.AllowedOrigins) == 0 {
   134  		if options.AllowOriginFunc == nil && options.AllowOriginRequestFunc == nil {
   135  			// Default is all origins
   136  			c.allowedOriginsAll = true
   137  		}
   138  	} else {
   139  		c.allowedOrigins = []string{}
   140  		c.allowedWOrigins = []wildcard{}
   141  		for _, origin := range options.AllowedOrigins {
   142  			// Normalize
   143  			origin = strings.ToLower(origin)
   144  			if origin == "*" {
   145  				// If "*" is present in the list, turn the whole list into a match all
   146  				c.allowedOriginsAll = true
   147  				c.allowedOrigins = nil
   148  				c.allowedWOrigins = nil
   149  				break
   150  			} else if i := strings.IndexByte(origin, '*'); i >= 0 {
   151  				// Split the origin in two: start and end string without the *
   152  				w := wildcard{origin[0:i], origin[i+1:]}
   153  				c.allowedWOrigins = append(c.allowedWOrigins, w)
   154  			} else {
   155  				c.allowedOrigins = append(c.allowedOrigins, origin)
   156  			}
   157  		}
   158  	}
   159  
   160  	// Allowed Headers
   161  	if len(options.AllowedHeaders) == 0 {
   162  		// Use sensible defaults
   163  		c.allowedHeaders = []string{"Origin", "Accept", "Content-Type", "X-Requested-With"}
   164  	} else {
   165  		// Origin is always appended as some browsers will always request for this header at preflight
   166  		c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey)
   167  		for _, h := range options.AllowedHeaders {
   168  			if h == "*" {
   169  				c.allowedHeadersAll = true
   170  				c.allowedHeaders = nil
   171  				break
   172  			}
   173  		}
   174  	}
   175  
   176  	// Allowed Methods
   177  	if len(options.AllowedMethods) == 0 {
   178  		// Default is spec's "simple" methods
   179  		c.allowedMethods = []string{http.MethodGet, http.MethodPost, http.MethodHead}
   180  	} else {
   181  		c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper)
   182  	}
   183  
   184  	// Options Success Status Code
   185  	if options.OptionsSuccessStatus == 0 {
   186  		c.optionsSuccessStatus = http.StatusNoContent
   187  	} else {
   188  		c.optionsSuccessStatus = options.OptionsSuccessStatus
   189  	}
   190  
   191  	return c
   192  }
   193  
   194  // Default creates a new Cors handler with default options.
   195  func Default() *Cors {
   196  	return New(Options{})
   197  }
   198  
   199  // AllowAll create a new Cors handler with permissive configuration allowing all
   200  // origins with all standard methods with any header and credentials.
   201  func AllowAll() *Cors {
   202  	return New(Options{
   203  		AllowedOrigins: []string{"*"},
   204  		AllowedMethods: []string{
   205  			http.MethodHead,
   206  			http.MethodGet,
   207  			http.MethodPost,
   208  			http.MethodPut,
   209  			http.MethodPatch,
   210  			http.MethodDelete,
   211  		},
   212  		AllowedHeaders:   []string{"*"},
   213  		AllowCredentials: false,
   214  	})
   215  }
   216  
   217  // Handler apply the CORS specification on the request, and add relevant CORS headers
   218  // as necessary.
   219  func (c *Cors) Handler(h http.Handler) http.Handler {
   220  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   221  		if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
   222  			c.logf("Handler: Preflight request")
   223  			c.handlePreflight(w, r)
   224  			// Preflight requests are standalone and should stop the chain as some other
   225  			// middleware may not handle OPTIONS requests correctly. One typical example
   226  			// is authentication middleware ; OPTIONS requests won't carry authentication
   227  			// headers (see #1)
   228  			if c.optionPassthrough {
   229  				h.ServeHTTP(w, r)
   230  			} else {
   231  				w.WriteHeader(c.optionsSuccessStatus)
   232  			}
   233  		} else {
   234  			c.logf("Handler: Actual request")
   235  			c.handleActualRequest(w, r)
   236  			h.ServeHTTP(w, r)
   237  		}
   238  	})
   239  }
   240  
   241  // HandlerFunc provides Martini compatible handler
   242  func (c *Cors) HandlerFunc(w http.ResponseWriter, r *http.Request) {
   243  	if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
   244  		c.logf("HandlerFunc: Preflight request")
   245  		c.handlePreflight(w, r)
   246  
   247  		w.WriteHeader(c.optionsSuccessStatus)
   248  	} else {
   249  		c.logf("HandlerFunc: Actual request")
   250  		c.handleActualRequest(w, r)
   251  	}
   252  }
   253  
   254  // Negroni compatible interface
   255  func (c *Cors) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
   256  	if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
   257  		c.logf("ServeHTTP: Preflight request")
   258  		c.handlePreflight(w, r)
   259  		// Preflight requests are standalone and should stop the chain as some other
   260  		// middleware may not handle OPTIONS requests correctly. One typical example
   261  		// is authentication middleware ; OPTIONS requests won't carry authentication
   262  		// headers (see #1)
   263  		if c.optionPassthrough {
   264  			next(w, r)
   265  		} else {
   266  			w.WriteHeader(c.optionsSuccessStatus)
   267  		}
   268  	} else {
   269  		c.logf("ServeHTTP: Actual request")
   270  		c.handleActualRequest(w, r)
   271  		next(w, r)
   272  	}
   273  }
   274  
   275  // handlePreflight handles pre-flight CORS requests
   276  func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
   277  	headers := w.Header()
   278  	origin := r.Header.Get("Origin")
   279  
   280  	if r.Method != http.MethodOptions {
   281  		c.logf("  Preflight aborted: %s!=OPTIONS", r.Method)
   282  		return
   283  	}
   284  	// Always set Vary headers
   285  	// see https://git.sr.ht/~pingoo/stdx/cors/issues/10,
   286  	//     https://git.sr.ht/~pingoo/stdx/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
   287  	headers.Add("Vary", "Origin")
   288  	headers.Add("Vary", "Access-Control-Request-Method")
   289  	headers.Add("Vary", "Access-Control-Request-Headers")
   290  	if c.allowPrivateNetwork {
   291  		headers.Add("Vary", "Access-Control-Request-Private-Network")
   292  	}
   293  
   294  	if origin == "" {
   295  		c.logf("  Preflight aborted: empty origin")
   296  		return
   297  	}
   298  	if !c.isOriginAllowed(r, origin) {
   299  		c.logf("  Preflight aborted: origin '%s' not allowed", origin)
   300  		return
   301  	}
   302  
   303  	reqMethod := r.Header.Get("Access-Control-Request-Method")
   304  	if !c.isMethodAllowed(reqMethod) {
   305  		c.logf("  Preflight aborted: method '%s' not allowed", reqMethod)
   306  		return
   307  	}
   308  	reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers"))
   309  	if !c.areHeadersAllowed(reqHeaders) {
   310  		c.logf("  Preflight aborted: headers '%v' not allowed", reqHeaders)
   311  		return
   312  	}
   313  	if c.allowedOriginsAll {
   314  		headers.Set("Access-Control-Allow-Origin", "*")
   315  	} else {
   316  		headers.Set("Access-Control-Allow-Origin", origin)
   317  	}
   318  	// Spec says: Since the list of methods can be unbounded, simply returning the method indicated
   319  	// by Access-Control-Request-Method (if supported) can be enough
   320  	headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod))
   321  	if len(reqHeaders) > 0 {
   322  
   323  		// Spec says: Since the list of headers can be unbounded, simply returning supported headers
   324  		// from Access-Control-Request-Headers can be enough
   325  		headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
   326  	}
   327  	if c.allowCredentials {
   328  		headers.Set("Access-Control-Allow-Credentials", "true")
   329  	}
   330  	if c.allowPrivateNetwork && r.Header.Get("Access-Control-Request-Private-Network") == "true" {
   331  		headers.Set("Access-Control-Allow-Private-Network", "true")
   332  	}
   333  	if c.maxAge > 0 {
   334  		headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge))
   335  	}
   336  	c.logf("  Preflight response headers: %v", headers)
   337  }
   338  
   339  // handleActualRequest handles simple cross-origin requests, actual request or redirects
   340  func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
   341  	headers := w.Header()
   342  	origin := r.Header.Get("Origin")
   343  
   344  	// Always set Vary, see https://git.sr.ht/~pingoo/stdx/cors/issues/10
   345  	headers.Add("Vary", "Origin")
   346  	if origin == "" {
   347  		c.logf("  Actual request no headers added: missing origin")
   348  		return
   349  	}
   350  	if !c.isOriginAllowed(r, origin) {
   351  		c.logf("  Actual request no headers added: origin '%s' not allowed", origin)
   352  		return
   353  	}
   354  
   355  	// Note that spec does define a way to specifically disallow a simple method like GET or
   356  	// POST. Access-Control-Allow-Methods is only used for pre-flight requests and the
   357  	// spec doesn't instruct to check the allowed methods for simple cross-origin requests.
   358  	// We think it's a nice feature to be able to have control on those methods though.
   359  	if !c.isMethodAllowed(r.Method) {
   360  		c.logf("  Actual request no headers added: method '%s' not allowed", r.Method)
   361  
   362  		return
   363  	}
   364  	if c.allowedOriginsAll {
   365  		headers.Set("Access-Control-Allow-Origin", "*")
   366  	} else {
   367  		headers.Set("Access-Control-Allow-Origin", origin)
   368  	}
   369  	if len(c.exposedHeaders) > 0 {
   370  		headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", "))
   371  	}
   372  	if c.allowCredentials {
   373  		headers.Set("Access-Control-Allow-Credentials", "true")
   374  	}
   375  	c.logf("  Actual response added headers: %v", headers)
   376  }
   377  
   378  // convenience method. checks if a logger is set.
   379  func (c *Cors) logf(format string, a ...interface{}) {
   380  	if c.Log != nil {
   381  		c.Log.Printf(format, a...)
   382  	}
   383  }
   384  
   385  // check the Origin of a request. No origin at all is also allowed.
   386  func (c *Cors) OriginAllowed(r *http.Request) bool {
   387  	origin := r.Header.Get("Origin")
   388  	return c.isOriginAllowed(r, origin)
   389  }
   390  
   391  // isOriginAllowed checks if a given origin is allowed to perform cross-domain requests
   392  // on the endpoint
   393  func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool {
   394  	if c.allowOriginRequestFunc != nil {
   395  		return c.allowOriginRequestFunc(r, origin)
   396  	}
   397  	if c.allowOriginFunc != nil {
   398  		return c.allowOriginFunc(origin)
   399  	}
   400  	if c.allowedOriginsAll {
   401  		return true
   402  	}
   403  	origin = strings.ToLower(origin)
   404  	for _, o := range c.allowedOrigins {
   405  		if o == origin {
   406  			return true
   407  		}
   408  	}
   409  	for _, w := range c.allowedWOrigins {
   410  		if w.match(origin) {
   411  			return true
   412  		}
   413  	}
   414  	return false
   415  }
   416  
   417  // isMethodAllowed checks if a given method can be used as part of a cross-domain request
   418  // on the endpoint
   419  func (c *Cors) isMethodAllowed(method string) bool {
   420  	if len(c.allowedMethods) == 0 {
   421  		// If no method allowed, always return false, even for preflight request
   422  		return false
   423  	}
   424  	method = strings.ToUpper(method)
   425  	if method == http.MethodOptions {
   426  		// Always allow preflight requests
   427  		return true
   428  	}
   429  	for _, m := range c.allowedMethods {
   430  		if m == method {
   431  			return true
   432  		}
   433  	}
   434  	return false
   435  }
   436  
   437  // areHeadersAllowed checks if a given list of headers are allowed to used within
   438  // a cross-domain request.
   439  func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool {
   440  	if c.allowedHeadersAll || len(requestedHeaders) == 0 {
   441  		return true
   442  	}
   443  	for _, header := range requestedHeaders {
   444  		header = http.CanonicalHeaderKey(header)
   445  		found := false
   446  		for _, h := range c.allowedHeaders {
   447  			if h == header {
   448  				found = true
   449  				break
   450  			}
   451  		}
   452  		if !found {
   453  			return false
   454  		}
   455  	}
   456  	return true
   457  }