github.com/machinefi/w3bstream@v1.6.5-rc9.0.20240426031326-b8c7c4876e72/pkg/depends/conf/http/mws/cors.go (about)

     1  package mws
     2  
     3  // https://github.com/gorilla/handlers/blob/master/cors.go
     4  
     5  import (
     6  	"net/http"
     7  	"strconv"
     8  	"strings"
     9  )
    10  
    11  func DefaultCORS() func(http.Handler) http.Handler {
    12  	return CORS(
    13  		AllowedOrigins([]string{"*"}),
    14  		AllowedMethods(append(defaultCorsMethods, []string{
    15  			http.MethodConnect,
    16  			http.MethodPut,
    17  			http.MethodPatch,
    18  			http.MethodDelete,
    19  			http.MethodOptions,
    20  		}...)),
    21  		AllowCredentials(),
    22  		AllowedHeaders(append(defaultCorsHeaders, []string{
    23  			corsRequestMethodHeader,
    24  			corsRequestHeadersHeader,
    25  			"Content-Type",
    26  			"Authorization",
    27  			"X-Request-Id",
    28  			"User-Agent",
    29  		}...)),
    30  		ExposedHeaders([]string{
    31  			"Content-Type",
    32  			"Origin",
    33  			"b3",
    34  			"User-Agent",
    35  			"X-Requested-With",
    36  			"X-Request-Id",
    37  			"X-Meta",
    38  			"X-RateLimit-Limit", // follow https://developer.github.com/v3/rate_limit/
    39  			"X-RateLimit-Remaining",
    40  			"X-RateLimit-Reset",
    41  		}),
    42  		OptionStatusCode(http.StatusNoContent),
    43  	)
    44  }
    45  
    46  // CORSOption represents a functional option for configuring the CORS middleware.
    47  type CORSOption func(*cors) error
    48  
    49  type cors struct {
    50  	h                      http.Handler
    51  	allowedHeaders         []string
    52  	allowedMethods         []string
    53  	allowedOrigins         []string
    54  	allowedOriginValidator OriginValidator
    55  	exposedHeaders         []string
    56  	maxAge                 int
    57  	ignoreOptions          bool
    58  	allowCredentials       bool
    59  	optionStatusCode       int
    60  }
    61  
    62  // OriginValidator takes an origin string and returns whether or not that origin is allowed.
    63  type OriginValidator func(string) bool
    64  
    65  var (
    66  	defaultCorsOptionStatusCode = 200
    67  	defaultCorsMethods          = []string{"GET", "HEAD", "POST"}
    68  	defaultCorsHeaders          = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
    69  	// (WebKit/Safari v9 sends the Origin header by default in AJAX requests)
    70  )
    71  
    72  const (
    73  	corsOptionMethod           string = "OPTIONS"
    74  	corsAllowOriginHeader      string = "Access-Control-Allow-Origin"
    75  	corsExposeHeadersHeader    string = "Access-Control-Expose-Headers"
    76  	corsMaxAgeHeader           string = "Access-Control-Max-Age"
    77  	corsAllowMethodsHeader     string = "Access-Control-Allow-Methods"
    78  	corsAllowHeadersHeader     string = "Access-Control-Allow-Headers"
    79  	corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials"
    80  	corsRequestMethodHeader    string = "Access-Control-Request-Method"
    81  	corsRequestHeadersHeader   string = "Access-Control-Request-Headers"
    82  	corsOriginHeader           string = "Origin"
    83  	corsVaryHeader             string = "Vary"
    84  	corsOriginMatchAll         string = "*"
    85  )
    86  
    87  func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    88  	origin := r.Header.Get(corsOriginHeader)
    89  	if !ch.isOriginAllowed(origin) {
    90  		if r.Method != corsOptionMethod || ch.ignoreOptions {
    91  			ch.h.ServeHTTP(w, r)
    92  		}
    93  		return
    94  	}
    95  
    96  	if r.Method == corsOptionMethod {
    97  		if ch.ignoreOptions {
    98  			ch.h.ServeHTTP(w, r)
    99  			return
   100  		}
   101  
   102  		if _, ok := r.Header[corsRequestMethodHeader]; !ok {
   103  			w.WriteHeader(http.StatusBadRequest)
   104  			return
   105  		}
   106  
   107  		method := r.Header.Get(corsRequestMethodHeader)
   108  		if !ch.isMatch(method, ch.allowedMethods) {
   109  			w.WriteHeader(http.StatusMethodNotAllowed)
   110  			return
   111  		}
   112  
   113  		requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
   114  		allowedHeaders := make([]string, 0)
   115  		for _, v := range requestHeaders {
   116  			canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
   117  			if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
   118  				continue
   119  			}
   120  
   121  			if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
   122  				w.WriteHeader(http.StatusForbidden)
   123  				return
   124  			}
   125  
   126  			allowedHeaders = append(allowedHeaders, canonicalHeader)
   127  		}
   128  
   129  		if len(allowedHeaders) > 0 {
   130  			w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
   131  		}
   132  
   133  		if ch.maxAge > 0 {
   134  			w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
   135  		}
   136  
   137  		if !ch.isMatch(method, defaultCorsMethods) {
   138  			w.Header().Set(corsAllowMethodsHeader, method)
   139  		}
   140  	} else {
   141  		if len(ch.exposedHeaders) > 0 {
   142  			w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
   143  		}
   144  	}
   145  
   146  	if ch.allowCredentials {
   147  		w.Header().Set(corsAllowCredentialsHeader, "true")
   148  	}
   149  
   150  	if len(ch.allowedOrigins) > 1 {
   151  		w.Header().Set(corsVaryHeader, corsOriginHeader)
   152  	}
   153  
   154  	returnOrigin := origin
   155  	if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 {
   156  		returnOrigin = "*"
   157  	} else {
   158  		for _, o := range ch.allowedOrigins {
   159  			// A configuration of * is different than explicitly setting an allowed
   160  			// origin. Returning arbitrary origin headers in an access control allow
   161  			// origin header is unsafe and is not required by any use case.
   162  			if o == corsOriginMatchAll {
   163  				returnOrigin = "*"
   164  				break
   165  			}
   166  		}
   167  	}
   168  	w.Header().Set(corsAllowOriginHeader, returnOrigin)
   169  
   170  	if r.Method == corsOptionMethod {
   171  		w.WriteHeader(ch.optionStatusCode)
   172  		return
   173  	}
   174  	ch.h.ServeHTTP(w, r)
   175  }
   176  
   177  // CORS provides Cross-Origin Resource Sharing middleware.
   178  // Example:
   179  //
   180  //	import (
   181  //	    "net/http"
   182  //
   183  //	    "github.com/gorilla/handlers"
   184  //	    "github.com/gorilla/mux"
   185  //	)
   186  //
   187  //	func main() {
   188  //	    r := mux.NewRouter()
   189  //	    r.HandleFunc("/users", UserEndpoint)
   190  //	    r.HandleFunc("/projects", ProjectEndpoint)
   191  //
   192  //	    // Apply the CORS middleware to our top-level router, with the defaults.
   193  //	    http.ListenAndServe(":8000", handlers.CORS()(r))
   194  //	}
   195  func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
   196  	return func(h http.Handler) http.Handler {
   197  		ch := parseCORSOptions(opts...)
   198  		ch.h = h
   199  		return ch
   200  	}
   201  }
   202  
   203  func parseCORSOptions(opts ...CORSOption) *cors {
   204  	ch := &cors{
   205  		allowedMethods:   defaultCorsMethods,
   206  		allowedHeaders:   defaultCorsHeaders,
   207  		allowedOrigins:   []string{},
   208  		optionStatusCode: defaultCorsOptionStatusCode,
   209  	}
   210  
   211  	for _, option := range opts {
   212  		_ = option(ch)
   213  	}
   214  
   215  	return ch
   216  }
   217  
   218  //
   219  // Functional options for configuring CORS.
   220  //
   221  
   222  // AllowedHeaders adds the provided headers to the list of allowed headers in a
   223  // CORS request.
   224  // This is an append operation so the headers Accept, Accept-Language,
   225  // and Content-Language are always allowed.
   226  // Content-Type must be explicitly declared if accepting Content-Types other than
   227  // application/x-www-form-urlencoded, multipart/form-data, or text/plain.
   228  func AllowedHeaders(headers []string) CORSOption {
   229  	return func(ch *cors) error {
   230  		for _, v := range headers {
   231  			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
   232  			if normalizedHeader == "" {
   233  				continue
   234  			}
   235  
   236  			if !ch.isMatch(normalizedHeader, ch.allowedHeaders) {
   237  				ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader)
   238  			}
   239  		}
   240  
   241  		return nil
   242  	}
   243  }
   244  
   245  // AllowedMethods can be used to explicitly allow methods in the
   246  // Access-Control-Allow-Methods header.
   247  // This is a replacement operation so you must also
   248  // pass GET, HEAD, and POST if you wish to support those methods.
   249  func AllowedMethods(methods []string) CORSOption {
   250  	return func(ch *cors) error {
   251  		ch.allowedMethods = []string{}
   252  		for _, v := range methods {
   253  			normalizedMethod := strings.ToUpper(strings.TrimSpace(v))
   254  			if normalizedMethod == "" {
   255  				continue
   256  			}
   257  
   258  			if !ch.isMatch(normalizedMethod, ch.allowedMethods) {
   259  				ch.allowedMethods = append(ch.allowedMethods, normalizedMethod)
   260  			}
   261  		}
   262  
   263  		return nil
   264  	}
   265  }
   266  
   267  // AllowedOrigins sets the allowed origins for CORS requests, as used in the
   268  // 'Allow-Access-Control-Origin' HTTP header.
   269  // Note: Passing in a []string{"*"} will allow any domain.
   270  func AllowedOrigins(origins []string) CORSOption {
   271  	return func(ch *cors) error {
   272  		for _, v := range origins {
   273  			if v == corsOriginMatchAll {
   274  				ch.allowedOrigins = []string{corsOriginMatchAll}
   275  				return nil
   276  			}
   277  		}
   278  
   279  		ch.allowedOrigins = origins
   280  		return nil
   281  	}
   282  }
   283  
   284  // AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the
   285  // 'Allow-Access-Control-Origin' HTTP header.
   286  func AllowedOriginValidator(fn OriginValidator) CORSOption {
   287  	return func(ch *cors) error {
   288  		ch.allowedOriginValidator = fn
   289  		return nil
   290  	}
   291  }
   292  
   293  // OptionStatusCode sets a custom status code on the OPTIONS requests.
   294  // Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory
   295  // and can be used if you need a custom status code (i.e 204).
   296  //
   297  // More informations on the spec:
   298  // https://fetch.spec.whatwg.org/#cors-preflight-fetch
   299  func OptionStatusCode(code int) CORSOption {
   300  	return func(ch *cors) error {
   301  		ch.optionStatusCode = code
   302  		return nil
   303  	}
   304  }
   305  
   306  // ExposedHeaders can be used to specify headers that are available
   307  // and will not be stripped out by the user-agent.
   308  func ExposedHeaders(headers []string) CORSOption {
   309  	return func(ch *cors) error {
   310  		ch.exposedHeaders = []string{}
   311  		for _, v := range headers {
   312  			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
   313  			if normalizedHeader == "" {
   314  				continue
   315  			}
   316  
   317  			if !ch.isMatch(normalizedHeader, ch.exposedHeaders) {
   318  				ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader)
   319  			}
   320  		}
   321  
   322  		return nil
   323  	}
   324  }
   325  
   326  // MaxAge determines the maximum age (in seconds) between preflight requests. A
   327  // maximum of 10 minutes is allowed. An age above this value will default to 10
   328  // minutes.
   329  func MaxAge(age int) CORSOption {
   330  	return func(ch *cors) error {
   331  		// Maximum of 10 minutes.
   332  		if age > 600 {
   333  			age = 600
   334  		}
   335  
   336  		ch.maxAge = age
   337  		return nil
   338  	}
   339  }
   340  
   341  // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead
   342  // passing them through to the next handler. This is useful when your application
   343  // or framework has a pre-existing mechanism for responding to OPTIONS requests.
   344  func IgnoreOptions() CORSOption {
   345  	return func(ch *cors) error {
   346  		ch.ignoreOptions = true
   347  		return nil
   348  	}
   349  }
   350  
   351  // AllowCredentials can be used to specify that the user agent may pass
   352  // authentication details along with the request.
   353  func AllowCredentials() CORSOption {
   354  	return func(ch *cors) error {
   355  		ch.allowCredentials = true
   356  		return nil
   357  	}
   358  }
   359  
   360  func (ch *cors) isOriginAllowed(origin string) bool {
   361  	if origin == "" {
   362  		return false
   363  	}
   364  
   365  	if ch.allowedOriginValidator != nil {
   366  		return ch.allowedOriginValidator(origin)
   367  	}
   368  
   369  	if len(ch.allowedOrigins) == 0 {
   370  		return true
   371  	}
   372  
   373  	for _, allowedOrigin := range ch.allowedOrigins {
   374  		if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
   375  			return true
   376  		}
   377  	}
   378  
   379  	return false
   380  }
   381  
   382  func (ch *cors) isMatch(needle string, haystack []string) bool {
   383  	for _, v := range haystack {
   384  		if v == needle {
   385  			return true
   386  		}
   387  	}
   388  
   389  	return false
   390  }