gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/handlers/cors.go (about)

     1  package handlers
     2  
     3  import (
     4  	"strconv"
     5  	"strings"
     6  
     7  	http "gitee.com/ks-custle/core-gm/gmhttp"
     8  )
     9  
    10  // CORSOption represents a functional option for configuring the CORS middleware.
    11  type CORSOption func(*cors) error
    12  
    13  type cors struct {
    14  	h                      http.Handler
    15  	allowedHeaders         []string
    16  	allowedMethods         []string
    17  	allowedOrigins         []string
    18  	allowedOriginValidator OriginValidator
    19  	exposedHeaders         []string
    20  	maxAge                 int
    21  	ignoreOptions          bool
    22  	allowCredentials       bool
    23  	optionStatusCode       int
    24  }
    25  
    26  // OriginValidator takes an origin string and returns whether or not that origin is allowed.
    27  type OriginValidator func(string) bool
    28  
    29  var (
    30  	defaultCorsOptionStatusCode = 200
    31  	defaultCorsMethods          = []string{"GET", "HEAD", "POST"}
    32  	defaultCorsHeaders          = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
    33  	// (WebKit/Safari v9 sends the Origin header by default in AJAX requests)
    34  )
    35  
    36  const (
    37  	corsOptionMethod           string = "OPTIONS"
    38  	corsAllowOriginHeader      string = "Access-Control-Allow-Origin"
    39  	corsExposeHeadersHeader    string = "Access-Control-Expose-Headers"
    40  	corsMaxAgeHeader           string = "Access-Control-Max-Age"
    41  	corsAllowMethodsHeader     string = "Access-Control-Allow-Methods"
    42  	corsAllowHeadersHeader     string = "Access-Control-Allow-Headers"
    43  	corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials"
    44  	corsRequestMethodHeader    string = "Access-Control-Request-Method"
    45  	corsRequestHeadersHeader   string = "Access-Control-Request-Headers"
    46  	corsOriginHeader           string = "Origin"
    47  	corsVaryHeader             string = "Vary"
    48  	corsOriginMatchAll         string = "*"
    49  )
    50  
    51  func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    52  	origin := r.Header.Get(corsOriginHeader)
    53  	if !ch.isOriginAllowed(origin) {
    54  		if r.Method != corsOptionMethod || ch.ignoreOptions {
    55  			ch.h.ServeHTTP(w, r)
    56  		}
    57  
    58  		return
    59  	}
    60  
    61  	if r.Method == corsOptionMethod {
    62  		if ch.ignoreOptions {
    63  			ch.h.ServeHTTP(w, r)
    64  			return
    65  		}
    66  
    67  		if _, ok := r.Header[corsRequestMethodHeader]; !ok {
    68  			w.WriteHeader(http.StatusBadRequest)
    69  			return
    70  		}
    71  
    72  		method := r.Header.Get(corsRequestMethodHeader)
    73  		if !ch.isMatch(method, ch.allowedMethods) {
    74  			w.WriteHeader(http.StatusMethodNotAllowed)
    75  			return
    76  		}
    77  
    78  		requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
    79  		allowedHeaders := []string{}
    80  		for _, v := range requestHeaders {
    81  			canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
    82  			if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
    83  				continue
    84  			}
    85  
    86  			if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
    87  				w.WriteHeader(http.StatusForbidden)
    88  				return
    89  			}
    90  
    91  			allowedHeaders = append(allowedHeaders, canonicalHeader)
    92  		}
    93  
    94  		if len(allowedHeaders) > 0 {
    95  			w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
    96  		}
    97  
    98  		if ch.maxAge > 0 {
    99  			w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
   100  		}
   101  
   102  		if !ch.isMatch(method, defaultCorsMethods) {
   103  			w.Header().Set(corsAllowMethodsHeader, method)
   104  		}
   105  	} else {
   106  		if len(ch.exposedHeaders) > 0 {
   107  			w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
   108  		}
   109  	}
   110  
   111  	if ch.allowCredentials {
   112  		w.Header().Set(corsAllowCredentialsHeader, "true")
   113  	}
   114  
   115  	if len(ch.allowedOrigins) > 1 {
   116  		w.Header().Set(corsVaryHeader, corsOriginHeader)
   117  	}
   118  
   119  	returnOrigin := origin
   120  	if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 {
   121  		returnOrigin = "*"
   122  	} else {
   123  		for _, o := range ch.allowedOrigins {
   124  			// A configuration of * is different than explicitly setting an allowed
   125  			// origin. Returning arbitrary origin headers in an access control allow
   126  			// origin header is unsafe and is not required by any use case.
   127  			if o == corsOriginMatchAll {
   128  				returnOrigin = "*"
   129  				break
   130  			}
   131  		}
   132  	}
   133  	w.Header().Set(corsAllowOriginHeader, returnOrigin)
   134  
   135  	if r.Method == corsOptionMethod {
   136  		w.WriteHeader(ch.optionStatusCode)
   137  		return
   138  	}
   139  	ch.h.ServeHTTP(w, r)
   140  }
   141  
   142  // CORS provides Cross-Origin Resource Sharing middleware.
   143  // Example:
   144  //
   145  //	import (
   146  //	    http "gitee.com/ks-custle/core-gm/gmhttp"
   147  //
   148  //	    "gitee.com/ks-custle/core-gm/handlers"
   149  //	    "gitee.com/ks-custle/core-gm/mux"
   150  //	)
   151  //
   152  //	func main() {
   153  //	    r := mux.NewRouter()
   154  //	    r.HandleFunc("/users", UserEndpoint)
   155  //	    r.HandleFunc("/projects", ProjectEndpoint)
   156  //
   157  //	    // Apply the CORS middleware to our top-level router, with the defaults.
   158  //	    http.ListenAndServe(":8000", handlers.CORS()(r))
   159  //	}
   160  func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
   161  	return func(h http.Handler) http.Handler {
   162  		ch := parseCORSOptions(opts...)
   163  		ch.h = h
   164  		return ch
   165  	}
   166  }
   167  
   168  func parseCORSOptions(opts ...CORSOption) *cors {
   169  	ch := &cors{
   170  		allowedMethods:   defaultCorsMethods,
   171  		allowedHeaders:   defaultCorsHeaders,
   172  		allowedOrigins:   []string{},
   173  		optionStatusCode: defaultCorsOptionStatusCode,
   174  	}
   175  
   176  	for _, option := range opts {
   177  		option(ch)
   178  	}
   179  
   180  	return ch
   181  }
   182  
   183  //
   184  // Functional options for configuring CORS.
   185  //
   186  
   187  // AllowedHeaders adds the provided headers to the list of allowed headers in a
   188  // CORS request.
   189  // This is an append operation so the headers Accept, Accept-Language,
   190  // and Content-Language are always allowed.
   191  // Content-Type must be explicitly declared if accepting Content-Types other than
   192  // application/x-www-form-urlencoded, multipart/form-data, or text/plain.
   193  func AllowedHeaders(headers []string) CORSOption {
   194  	return func(ch *cors) error {
   195  		for _, v := range headers {
   196  			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
   197  			if normalizedHeader == "" {
   198  				continue
   199  			}
   200  
   201  			if !ch.isMatch(normalizedHeader, ch.allowedHeaders) {
   202  				ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader)
   203  			}
   204  		}
   205  
   206  		return nil
   207  	}
   208  }
   209  
   210  // AllowedMethods can be used to explicitly allow methods in the
   211  // Access-Control-Allow-Methods header.
   212  // This is a replacement operation so you must also
   213  // pass GET, HEAD, and POST if you wish to support those methods.
   214  func AllowedMethods(methods []string) CORSOption {
   215  	return func(ch *cors) error {
   216  		ch.allowedMethods = []string{}
   217  		for _, v := range methods {
   218  			normalizedMethod := strings.ToUpper(strings.TrimSpace(v))
   219  			if normalizedMethod == "" {
   220  				continue
   221  			}
   222  
   223  			if !ch.isMatch(normalizedMethod, ch.allowedMethods) {
   224  				ch.allowedMethods = append(ch.allowedMethods, normalizedMethod)
   225  			}
   226  		}
   227  
   228  		return nil
   229  	}
   230  }
   231  
   232  // AllowedOrigins sets the allowed origins for CORS requests, as used in the
   233  // 'Allow-Access-Control-Origin' HTTP header.
   234  // Note: Passing in a []string{"*"} will allow any domain.
   235  func AllowedOrigins(origins []string) CORSOption {
   236  	return func(ch *cors) error {
   237  		for _, v := range origins {
   238  			if v == corsOriginMatchAll {
   239  				ch.allowedOrigins = []string{corsOriginMatchAll}
   240  				return nil
   241  			}
   242  		}
   243  
   244  		ch.allowedOrigins = origins
   245  		return nil
   246  	}
   247  }
   248  
   249  // AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the
   250  // 'Allow-Access-Control-Origin' HTTP header.
   251  func AllowedOriginValidator(fn OriginValidator) CORSOption {
   252  	return func(ch *cors) error {
   253  		ch.allowedOriginValidator = fn
   254  		return nil
   255  	}
   256  }
   257  
   258  // OptionStatusCode sets a custom status code on the OPTIONS requests.
   259  // Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory
   260  // and can be used if you need a custom status code (i.e 204).
   261  //
   262  // More informations on the spec:
   263  // https://fetch.spec.whatwg.org/#cors-preflight-fetch
   264  func OptionStatusCode(code int) CORSOption {
   265  	return func(ch *cors) error {
   266  		ch.optionStatusCode = code
   267  		return nil
   268  	}
   269  }
   270  
   271  // ExposedHeaders can be used to specify headers that are available
   272  // and will not be stripped out by the user-agent.
   273  func ExposedHeaders(headers []string) CORSOption {
   274  	return func(ch *cors) error {
   275  		ch.exposedHeaders = []string{}
   276  		for _, v := range headers {
   277  			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
   278  			if normalizedHeader == "" {
   279  				continue
   280  			}
   281  
   282  			if !ch.isMatch(normalizedHeader, ch.exposedHeaders) {
   283  				ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader)
   284  			}
   285  		}
   286  
   287  		return nil
   288  	}
   289  }
   290  
   291  // MaxAge determines the maximum age (in seconds) between preflight requests. A
   292  // maximum of 10 minutes is allowed. An age above this value will default to 10
   293  // minutes.
   294  func MaxAge(age int) CORSOption {
   295  	return func(ch *cors) error {
   296  		// Maximum of 10 minutes.
   297  		if age > 600 {
   298  			age = 600
   299  		}
   300  
   301  		ch.maxAge = age
   302  		return nil
   303  	}
   304  }
   305  
   306  // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead
   307  // passing them through to the next handler. This is useful when your application
   308  // or framework has a pre-existing mechanism for responding to OPTIONS requests.
   309  func IgnoreOptions() CORSOption {
   310  	return func(ch *cors) error {
   311  		ch.ignoreOptions = true
   312  		return nil
   313  	}
   314  }
   315  
   316  // AllowCredentials can be used to specify that the user agent may pass
   317  // authentication details along with the request.
   318  func AllowCredentials() CORSOption {
   319  	return func(ch *cors) error {
   320  		ch.allowCredentials = true
   321  		return nil
   322  	}
   323  }
   324  
   325  func (ch *cors) isOriginAllowed(origin string) bool {
   326  	if origin == "" {
   327  		return false
   328  	}
   329  
   330  	if ch.allowedOriginValidator != nil {
   331  		return ch.allowedOriginValidator(origin)
   332  	}
   333  
   334  	if len(ch.allowedOrigins) == 0 {
   335  		return true
   336  	}
   337  
   338  	for _, allowedOrigin := range ch.allowedOrigins {
   339  		if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
   340  			return true
   341  		}
   342  	}
   343  
   344  	return false
   345  }
   346  
   347  func (ch *cors) isMatch(needle string, haystack []string) bool {
   348  	for _, v := range haystack {
   349  		if v == needle {
   350  			return true
   351  		}
   352  	}
   353  
   354  	return false
   355  }