github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/handlers/cors.go (about)

     1  package handlers
     2  
     3  import (
     4  	"strconv"
     5  	"strings"
     6  
     7  	http "github.com/hxx258456/ccgo/gmhttp"
     8  )
     9  
    10  // CORSOption represents a functional option for configuring the CORS middleware.
    11  type CORSOption func(*cors) error
    12  
    13  type cors struct {
    14  	h                      http.Handler
    15  	allowedHeaders         []string
    16  	allowedMethods         []string
    17  	allowedOrigins         []string
    18  	allowedOriginValidator OriginValidator
    19  	exposedHeaders         []string
    20  	maxAge                 int
    21  	ignoreOptions          bool
    22  	allowCredentials       bool
    23  	optionStatusCode       int
    24  }
    25  
    26  // OriginValidator takes an origin string and returns whether or not that origin is allowed.
    27  type OriginValidator func(string) bool
    28  
    29  var (
    30  	defaultCorsOptionStatusCode = 200
    31  	defaultCorsMethods          = []string{"GET", "HEAD", "POST"}
    32  	defaultCorsHeaders          = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
    33  	// (WebKit/Safari v9 sends the Origin header by default in AJAX requests)
    34  )
    35  
    36  const (
    37  	corsOptionMethod           string = "OPTIONS"
    38  	corsAllowOriginHeader      string = "Access-Control-Allow-Origin"
    39  	corsExposeHeadersHeader    string = "Access-Control-Expose-Headers"
    40  	corsMaxAgeHeader           string = "Access-Control-Max-Age"
    41  	corsAllowMethodsHeader     string = "Access-Control-Allow-Methods"
    42  	corsAllowHeadersHeader     string = "Access-Control-Allow-Headers"
    43  	corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials"
    44  	corsRequestMethodHeader    string = "Access-Control-Request-Method"
    45  	corsRequestHeadersHeader   string = "Access-Control-Request-Headers"
    46  	corsOriginHeader           string = "Origin"
    47  	corsVaryHeader             string = "Vary"
    48  	corsOriginMatchAll         string = "*"
    49  )
    50  
    51  func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    52  	origin := r.Header.Get(corsOriginHeader)
    53  	if !ch.isOriginAllowed(origin) {
    54  		if r.Method != corsOptionMethod || ch.ignoreOptions {
    55  			ch.h.ServeHTTP(w, r)
    56  		}
    57  
    58  		return
    59  	}
    60  
    61  	if r.Method == corsOptionMethod {
    62  		if ch.ignoreOptions {
    63  			ch.h.ServeHTTP(w, r)
    64  			return
    65  		}
    66  
    67  		if _, ok := r.Header[corsRequestMethodHeader]; !ok {
    68  			w.WriteHeader(http.StatusBadRequest)
    69  			return
    70  		}
    71  
    72  		method := r.Header.Get(corsRequestMethodHeader)
    73  		if !ch.isMatch(method, ch.allowedMethods) {
    74  			w.WriteHeader(http.StatusMethodNotAllowed)
    75  			return
    76  		}
    77  
    78  		requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
    79  		allowedHeaders := []string{}
    80  		for _, v := range requestHeaders {
    81  			canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
    82  			if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
    83  				continue
    84  			}
    85  
    86  			if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
    87  				w.WriteHeader(http.StatusForbidden)
    88  				return
    89  			}
    90  
    91  			allowedHeaders = append(allowedHeaders, canonicalHeader)
    92  		}
    93  
    94  		if len(allowedHeaders) > 0 {
    95  			w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
    96  		}
    97  
    98  		if ch.maxAge > 0 {
    99  			w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
   100  		}
   101  
   102  		if !ch.isMatch(method, defaultCorsMethods) {
   103  			w.Header().Set(corsAllowMethodsHeader, method)
   104  		}
   105  	} else {
   106  		if len(ch.exposedHeaders) > 0 {
   107  			w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
   108  		}
   109  	}
   110  
   111  	if ch.allowCredentials {
   112  		w.Header().Set(corsAllowCredentialsHeader, "true")
   113  	}
   114  
   115  	if len(ch.allowedOrigins) > 1 {
   116  		w.Header().Set(corsVaryHeader, corsOriginHeader)
   117  	}
   118  
   119  	returnOrigin := origin
   120  	if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 {
   121  		returnOrigin = "*"
   122  	} else {
   123  		for _, o := range ch.allowedOrigins {
   124  			// A configuration of * is different than explicitly setting an allowed
   125  			// origin. Returning arbitrary origin headers in an access control allow
   126  			// origin header is unsafe and is not required by any use case.
   127  			if o == corsOriginMatchAll {
   128  				returnOrigin = "*"
   129  				break
   130  			}
   131  		}
   132  	}
   133  	w.Header().Set(corsAllowOriginHeader, returnOrigin)
   134  
   135  	if r.Method == corsOptionMethod {
   136  		w.WriteHeader(ch.optionStatusCode)
   137  		return
   138  	}
   139  	ch.h.ServeHTTP(w, r)
   140  }
   141  
   142  // CORS provides Cross-Origin Resource Sharing middleware.
   143  // Example:
   144  //
   145  //  import (
   146  //      http "github.com/hxx258456/ccgo/gmhttp"
   147  //
   148  //      "github.com/hxx258456/ccgo/handlers"
   149  //      "github.com/hxx258456/ccgo/mux"
   150  //  )
   151  //
   152  //  func main() {
   153  //      r := mux.NewRouter()
   154  //      r.HandleFunc("/users", UserEndpoint)
   155  //      r.HandleFunc("/projects", ProjectEndpoint)
   156  //
   157  //      // Apply the CORS middleware to our top-level router, with the defaults.
   158  //      http.ListenAndServe(":8000", handlers.CORS()(r))
   159  //  }
   160  //
   161  func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
   162  	return func(h http.Handler) http.Handler {
   163  		ch := parseCORSOptions(opts...)
   164  		ch.h = h
   165  		return ch
   166  	}
   167  }
   168  
   169  func parseCORSOptions(opts ...CORSOption) *cors {
   170  	ch := &cors{
   171  		allowedMethods:   defaultCorsMethods,
   172  		allowedHeaders:   defaultCorsHeaders,
   173  		allowedOrigins:   []string{},
   174  		optionStatusCode: defaultCorsOptionStatusCode,
   175  	}
   176  
   177  	for _, option := range opts {
   178  		option(ch)
   179  	}
   180  
   181  	return ch
   182  }
   183  
   184  //
   185  // Functional options for configuring CORS.
   186  //
   187  
   188  // AllowedHeaders adds the provided headers to the list of allowed headers in a
   189  // CORS request.
   190  // This is an append operation so the headers Accept, Accept-Language,
   191  // and Content-Language are always allowed.
   192  // Content-Type must be explicitly declared if accepting Content-Types other than
   193  // application/x-www-form-urlencoded, multipart/form-data, or text/plain.
   194  func AllowedHeaders(headers []string) CORSOption {
   195  	return func(ch *cors) error {
   196  		for _, v := range headers {
   197  			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
   198  			if normalizedHeader == "" {
   199  				continue
   200  			}
   201  
   202  			if !ch.isMatch(normalizedHeader, ch.allowedHeaders) {
   203  				ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader)
   204  			}
   205  		}
   206  
   207  		return nil
   208  	}
   209  }
   210  
   211  // AllowedMethods can be used to explicitly allow methods in the
   212  // Access-Control-Allow-Methods header.
   213  // This is a replacement operation so you must also
   214  // pass GET, HEAD, and POST if you wish to support those methods.
   215  func AllowedMethods(methods []string) CORSOption {
   216  	return func(ch *cors) error {
   217  		ch.allowedMethods = []string{}
   218  		for _, v := range methods {
   219  			normalizedMethod := strings.ToUpper(strings.TrimSpace(v))
   220  			if normalizedMethod == "" {
   221  				continue
   222  			}
   223  
   224  			if !ch.isMatch(normalizedMethod, ch.allowedMethods) {
   225  				ch.allowedMethods = append(ch.allowedMethods, normalizedMethod)
   226  			}
   227  		}
   228  
   229  		return nil
   230  	}
   231  }
   232  
   233  // AllowedOrigins sets the allowed origins for CORS requests, as used in the
   234  // 'Allow-Access-Control-Origin' HTTP header.
   235  // Note: Passing in a []string{"*"} will allow any domain.
   236  func AllowedOrigins(origins []string) CORSOption {
   237  	return func(ch *cors) error {
   238  		for _, v := range origins {
   239  			if v == corsOriginMatchAll {
   240  				ch.allowedOrigins = []string{corsOriginMatchAll}
   241  				return nil
   242  			}
   243  		}
   244  
   245  		ch.allowedOrigins = origins
   246  		return nil
   247  	}
   248  }
   249  
   250  // AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the
   251  // 'Allow-Access-Control-Origin' HTTP header.
   252  func AllowedOriginValidator(fn OriginValidator) CORSOption {
   253  	return func(ch *cors) error {
   254  		ch.allowedOriginValidator = fn
   255  		return nil
   256  	}
   257  }
   258  
   259  // OptionStatusCode sets a custom status code on the OPTIONS requests.
   260  // Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory
   261  // and can be used if you need a custom status code (i.e 204).
   262  //
   263  // More informations on the spec:
   264  // https://fetch.spec.whatwg.org/#cors-preflight-fetch
   265  func OptionStatusCode(code int) CORSOption {
   266  	return func(ch *cors) error {
   267  		ch.optionStatusCode = code
   268  		return nil
   269  	}
   270  }
   271  
   272  // ExposedHeaders can be used to specify headers that are available
   273  // and will not be stripped out by the user-agent.
   274  func ExposedHeaders(headers []string) CORSOption {
   275  	return func(ch *cors) error {
   276  		ch.exposedHeaders = []string{}
   277  		for _, v := range headers {
   278  			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
   279  			if normalizedHeader == "" {
   280  				continue
   281  			}
   282  
   283  			if !ch.isMatch(normalizedHeader, ch.exposedHeaders) {
   284  				ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader)
   285  			}
   286  		}
   287  
   288  		return nil
   289  	}
   290  }
   291  
   292  // MaxAge determines the maximum age (in seconds) between preflight requests. A
   293  // maximum of 10 minutes is allowed. An age above this value will default to 10
   294  // minutes.
   295  func MaxAge(age int) CORSOption {
   296  	return func(ch *cors) error {
   297  		// Maximum of 10 minutes.
   298  		if age > 600 {
   299  			age = 600
   300  		}
   301  
   302  		ch.maxAge = age
   303  		return nil
   304  	}
   305  }
   306  
   307  // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead
   308  // passing them through to the next handler. This is useful when your application
   309  // or framework has a pre-existing mechanism for responding to OPTIONS requests.
   310  func IgnoreOptions() CORSOption {
   311  	return func(ch *cors) error {
   312  		ch.ignoreOptions = true
   313  		return nil
   314  	}
   315  }
   316  
   317  // AllowCredentials can be used to specify that the user agent may pass
   318  // authentication details along with the request.
   319  func AllowCredentials() CORSOption {
   320  	return func(ch *cors) error {
   321  		ch.allowCredentials = true
   322  		return nil
   323  	}
   324  }
   325  
   326  func (ch *cors) isOriginAllowed(origin string) bool {
   327  	if origin == "" {
   328  		return false
   329  	}
   330  
   331  	if ch.allowedOriginValidator != nil {
   332  		return ch.allowedOriginValidator(origin)
   333  	}
   334  
   335  	if len(ch.allowedOrigins) == 0 {
   336  		return true
   337  	}
   338  
   339  	for _, allowedOrigin := range ch.allowedOrigins {
   340  		if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
   341  			return true
   342  		}
   343  	}
   344  
   345  	return false
   346  }
   347  
   348  func (ch *cors) isMatch(needle string, haystack []string) bool {
   349  	for _, v := range haystack {
   350  		if v == needle {
   351  			return true
   352  		}
   353  	}
   354  
   355  	return false
   356  }