github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/gorilla/handlers/cors.go (about)

     1  package handlers
     2  
     3  import (
     4  	"github.com/hellobchain/newcryptosm/http"
     5  	"strconv"
     6  	"strings"
     7  )
     8  
     9  // CORSOption represents a functional option for configuring the CORS middleware.
    10  type CORSOption func(*cors) error
    11  
    12  type cors struct {
    13  	h                      http.Handler
    14  	allowedHeaders         []string
    15  	allowedMethods         []string
    16  	allowedOrigins         []string
    17  	allowedOriginValidator OriginValidator
    18  	exposedHeaders         []string
    19  	maxAge                 int
    20  	ignoreOptions          bool
    21  	allowCredentials       bool
    22  }
    23  
    24  // OriginValidator takes an origin string and returns whether or not that origin is allowed.
    25  type OriginValidator func(string) bool
    26  
    27  var (
    28  	defaultCorsMethods = []string{"GET", "HEAD", "POST"}
    29  	defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
    30  	// (WebKit/Safari v9 sends the Origin header by default in AJAX requests)
    31  )
    32  
    33  const (
    34  	corsOptionMethod           string = "OPTIONS"
    35  	corsAllowOriginHeader      string = "Access-Control-Allow-Origin"
    36  	corsExposeHeadersHeader    string = "Access-Control-Expose-Headers"
    37  	corsMaxAgeHeader           string = "Access-Control-Max-Age"
    38  	corsAllowMethodsHeader     string = "Access-Control-Allow-Methods"
    39  	corsAllowHeadersHeader     string = "Access-Control-Allow-Headers"
    40  	corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials"
    41  	corsRequestMethodHeader    string = "Access-Control-Request-Method"
    42  	corsRequestHeadersHeader   string = "Access-Control-Request-Headers"
    43  	corsOriginHeader           string = "Origin"
    44  	corsVaryHeader             string = "Vary"
    45  	corsOriginMatchAll         string = "*"
    46  )
    47  
    48  func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    49  	origin := r.Header.Get(corsOriginHeader)
    50  	if !ch.isOriginAllowed(origin) {
    51  		if r.Method != corsOptionMethod || ch.ignoreOptions {
    52  			ch.h.ServeHTTP(w, r)
    53  		}
    54  
    55  		return
    56  	}
    57  
    58  	if r.Method == corsOptionMethod {
    59  		if ch.ignoreOptions {
    60  			ch.h.ServeHTTP(w, r)
    61  			return
    62  		}
    63  
    64  		if _, ok := r.Header[corsRequestMethodHeader]; !ok {
    65  			w.WriteHeader(http.StatusBadRequest)
    66  			return
    67  		}
    68  
    69  		method := r.Header.Get(corsRequestMethodHeader)
    70  		if !ch.isMatch(method, ch.allowedMethods) {
    71  			w.WriteHeader(http.StatusMethodNotAllowed)
    72  			return
    73  		}
    74  
    75  		requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
    76  		allowedHeaders := []string{}
    77  		for _, v := range requestHeaders {
    78  			canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
    79  			if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
    80  				continue
    81  			}
    82  
    83  			if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
    84  				w.WriteHeader(http.StatusForbidden)
    85  				return
    86  			}
    87  
    88  			allowedHeaders = append(allowedHeaders, canonicalHeader)
    89  		}
    90  
    91  		if len(allowedHeaders) > 0 {
    92  			w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
    93  		}
    94  
    95  		if ch.maxAge > 0 {
    96  			w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
    97  		}
    98  
    99  		if !ch.isMatch(method, defaultCorsMethods) {
   100  			w.Header().Set(corsAllowMethodsHeader, method)
   101  		}
   102  	} else {
   103  		if len(ch.exposedHeaders) > 0 {
   104  			w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
   105  		}
   106  	}
   107  
   108  	if ch.allowCredentials {
   109  		w.Header().Set(corsAllowCredentialsHeader, "true")
   110  	}
   111  
   112  	if len(ch.allowedOrigins) > 1 {
   113  		w.Header().Set(corsVaryHeader, corsOriginHeader)
   114  	}
   115  
   116  	returnOrigin := origin
   117  	if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 {
   118  		returnOrigin = "*"
   119  	} else {
   120  		for _, o := range ch.allowedOrigins {
   121  			// A configuration of * is different than explicitly setting an allowed
   122  			// origin. Returning arbitrary origin headers in an access control allow
   123  			// origin header is unsafe and is not required by any use case.
   124  			if o == corsOriginMatchAll {
   125  				returnOrigin = "*"
   126  				break
   127  			}
   128  		}
   129  	}
   130  	w.Header().Set(corsAllowOriginHeader, returnOrigin)
   131  
   132  	if r.Method == corsOptionMethod {
   133  		return
   134  	}
   135  	ch.h.ServeHTTP(w, r)
   136  }
   137  
   138  // CORS provides Cross-Origin Resource Sharing middleware.
   139  // Example:
   140  //
   141  //  import (
   142  //      "github.com/hellobchain/newcryptosm/http"
   143  //
   144  //      "github.com/hellobchain/third_party/gorilla/handlers"
   145  //      "github.com/hellobchain/third_party/gorilla/mux"
   146  //  )
   147  //
   148  //  func main() {
   149  //      r := mux.NewRouter()
   150  //      r.HandleFunc("/users", UserEndpoint)
   151  //      r.HandleFunc("/projects", ProjectEndpoint)
   152  //
   153  //      // Apply the CORS middleware to our top-level router, with the defaults.
   154  //      http.ListenAndServe(":8000", handlers.CORS()(r))
   155  //  }
   156  //
   157  func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
   158  	return func(h http.Handler) http.Handler {
   159  		ch := parseCORSOptions(opts...)
   160  		ch.h = h
   161  		return ch
   162  	}
   163  }
   164  
   165  func parseCORSOptions(opts ...CORSOption) *cors {
   166  	ch := &cors{
   167  		allowedMethods: defaultCorsMethods,
   168  		allowedHeaders: defaultCorsHeaders,
   169  		allowedOrigins: []string{},
   170  	}
   171  
   172  	for _, option := range opts {
   173  		option(ch)
   174  	}
   175  
   176  	return ch
   177  }
   178  
   179  //
   180  // Functional options for configuring CORS.
   181  //
   182  
   183  // AllowedHeaders adds the provided headers to the list of allowed headers in a
   184  // CORS request.
   185  // This is an append operation so the headers Accept, Accept-Language,
   186  // and Content-Language are always allowed.
   187  // Content-Type must be explicitly declared if accepting Content-Types other than
   188  // application/x-www-form-urlencoded, multipart/form-data, or text/plain.
   189  func AllowedHeaders(headers []string) CORSOption {
   190  	return func(ch *cors) error {
   191  		for _, v := range headers {
   192  			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
   193  			if normalizedHeader == "" {
   194  				continue
   195  			}
   196  
   197  			if !ch.isMatch(normalizedHeader, ch.allowedHeaders) {
   198  				ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader)
   199  			}
   200  		}
   201  
   202  		return nil
   203  	}
   204  }
   205  
   206  // AllowedMethods can be used to explicitly allow methods in the
   207  // Access-Control-Allow-Methods header.
   208  // This is a replacement operation so you must also
   209  // pass GET, HEAD, and POST if you wish to support those methods.
   210  func AllowedMethods(methods []string) CORSOption {
   211  	return func(ch *cors) error {
   212  		ch.allowedMethods = []string{}
   213  		for _, v := range methods {
   214  			normalizedMethod := strings.ToUpper(strings.TrimSpace(v))
   215  			if normalizedMethod == "" {
   216  				continue
   217  			}
   218  
   219  			if !ch.isMatch(normalizedMethod, ch.allowedMethods) {
   220  				ch.allowedMethods = append(ch.allowedMethods, normalizedMethod)
   221  			}
   222  		}
   223  
   224  		return nil
   225  	}
   226  }
   227  
   228  // AllowedOrigins sets the allowed origins for CORS requests, as used in the
   229  // 'Allow-Access-Control-Origin' HTTP header.
   230  // Note: Passing in a []string{"*"} will allow any domain.
   231  func AllowedOrigins(origins []string) CORSOption {
   232  	return func(ch *cors) error {
   233  		for _, v := range origins {
   234  			if v == corsOriginMatchAll {
   235  				ch.allowedOrigins = []string{corsOriginMatchAll}
   236  				return nil
   237  			}
   238  		}
   239  
   240  		ch.allowedOrigins = origins
   241  		return nil
   242  	}
   243  }
   244  
   245  // AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the
   246  // 'Allow-Access-Control-Origin' HTTP header.
   247  func AllowedOriginValidator(fn OriginValidator) CORSOption {
   248  	return func(ch *cors) error {
   249  		ch.allowedOriginValidator = fn
   250  		return nil
   251  	}
   252  }
   253  
   254  // ExposeHeaders can be used to specify headers that are available
   255  // and will not be stripped out by the user-agent.
   256  func ExposedHeaders(headers []string) CORSOption {
   257  	return func(ch *cors) error {
   258  		ch.exposedHeaders = []string{}
   259  		for _, v := range headers {
   260  			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
   261  			if normalizedHeader == "" {
   262  				continue
   263  			}
   264  
   265  			if !ch.isMatch(normalizedHeader, ch.exposedHeaders) {
   266  				ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader)
   267  			}
   268  		}
   269  
   270  		return nil
   271  	}
   272  }
   273  
   274  // MaxAge determines the maximum age (in seconds) between preflight requests. A
   275  // maximum of 10 minutes is allowed. An age above this value will default to 10
   276  // minutes.
   277  func MaxAge(age int) CORSOption {
   278  	return func(ch *cors) error {
   279  		// Maximum of 10 minutes.
   280  		if age > 600 {
   281  			age = 600
   282  		}
   283  
   284  		ch.maxAge = age
   285  		return nil
   286  	}
   287  }
   288  
   289  // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead
   290  // passing them through to the next handler. This is useful when your application
   291  // or framework has a pre-existing mechanism for responding to OPTIONS requests.
   292  func IgnoreOptions() CORSOption {
   293  	return func(ch *cors) error {
   294  		ch.ignoreOptions = true
   295  		return nil
   296  	}
   297  }
   298  
   299  // AllowCredentials can be used to specify that the user agent may pass
   300  // authentication details along with the request.
   301  func AllowCredentials() CORSOption {
   302  	return func(ch *cors) error {
   303  		ch.allowCredentials = true
   304  		return nil
   305  	}
   306  }
   307  
   308  func (ch *cors) isOriginAllowed(origin string) bool {
   309  	if origin == "" {
   310  		return false
   311  	}
   312  
   313  	if ch.allowedOriginValidator != nil {
   314  		return ch.allowedOriginValidator(origin)
   315  	}
   316  
   317  	if len(ch.allowedOrigins) == 0 {
   318  		return true
   319  	}
   320  
   321  	for _, allowedOrigin := range ch.allowedOrigins {
   322  		if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
   323  			return true
   324  		}
   325  	}
   326  
   327  	return false
   328  }
   329  
   330  func (ch *cors) isMatch(needle string, haystack []string) bool {
   331  	for _, v := range haystack {
   332  		if v == needle {
   333  			return true
   334  		}
   335  	}
   336  
   337  	return false
   338  }