gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/handlers/cors.go (about) 1 package handlers 2 3 import ( 4 "strconv" 5 "strings" 6 7 http "gitee.com/ks-custle/core-gm/gmhttp" 8 ) 9 10 // CORSOption represents a functional option for configuring the CORS middleware. 11 type CORSOption func(*cors) error 12 13 type cors struct { 14 h http.Handler 15 allowedHeaders []string 16 allowedMethods []string 17 allowedOrigins []string 18 allowedOriginValidator OriginValidator 19 exposedHeaders []string 20 maxAge int 21 ignoreOptions bool 22 allowCredentials bool 23 optionStatusCode int 24 } 25 26 // OriginValidator takes an origin string and returns whether or not that origin is allowed. 27 type OriginValidator func(string) bool 28 29 var ( 30 defaultCorsOptionStatusCode = 200 31 defaultCorsMethods = []string{"GET", "HEAD", "POST"} 32 defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} 33 // (WebKit/Safari v9 sends the Origin header by default in AJAX requests) 34 ) 35 36 const ( 37 corsOptionMethod string = "OPTIONS" 38 corsAllowOriginHeader string = "Access-Control-Allow-Origin" 39 corsExposeHeadersHeader string = "Access-Control-Expose-Headers" 40 corsMaxAgeHeader string = "Access-Control-Max-Age" 41 corsAllowMethodsHeader string = "Access-Control-Allow-Methods" 42 corsAllowHeadersHeader string = "Access-Control-Allow-Headers" 43 corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials" 44 corsRequestMethodHeader string = "Access-Control-Request-Method" 45 corsRequestHeadersHeader string = "Access-Control-Request-Headers" 46 corsOriginHeader string = "Origin" 47 corsVaryHeader string = "Vary" 48 corsOriginMatchAll string = "*" 49 ) 50 51 func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { 52 origin := r.Header.Get(corsOriginHeader) 53 if !ch.isOriginAllowed(origin) { 54 if r.Method != corsOptionMethod || ch.ignoreOptions { 55 ch.h.ServeHTTP(w, r) 56 } 57 58 return 59 } 60 61 if r.Method == corsOptionMethod { 62 if ch.ignoreOptions { 63 ch.h.ServeHTTP(w, r) 64 return 65 } 66 67 if _, ok := r.Header[corsRequestMethodHeader]; !ok { 68 w.WriteHeader(http.StatusBadRequest) 69 return 70 } 71 72 method := r.Header.Get(corsRequestMethodHeader) 73 if !ch.isMatch(method, ch.allowedMethods) { 74 w.WriteHeader(http.StatusMethodNotAllowed) 75 return 76 } 77 78 requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") 79 allowedHeaders := []string{} 80 for _, v := range requestHeaders { 81 canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 82 if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) { 83 continue 84 } 85 86 if !ch.isMatch(canonicalHeader, ch.allowedHeaders) { 87 w.WriteHeader(http.StatusForbidden) 88 return 89 } 90 91 allowedHeaders = append(allowedHeaders, canonicalHeader) 92 } 93 94 if len(allowedHeaders) > 0 { 95 w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) 96 } 97 98 if ch.maxAge > 0 { 99 w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge)) 100 } 101 102 if !ch.isMatch(method, defaultCorsMethods) { 103 w.Header().Set(corsAllowMethodsHeader, method) 104 } 105 } else { 106 if len(ch.exposedHeaders) > 0 { 107 w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) 108 } 109 } 110 111 if ch.allowCredentials { 112 w.Header().Set(corsAllowCredentialsHeader, "true") 113 } 114 115 if len(ch.allowedOrigins) > 1 { 116 w.Header().Set(corsVaryHeader, corsOriginHeader) 117 } 118 119 returnOrigin := origin 120 if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 { 121 returnOrigin = "*" 122 } else { 123 for _, o := range ch.allowedOrigins { 124 // A configuration of * is different than explicitly setting an allowed 125 // origin. Returning arbitrary origin headers in an access control allow 126 // origin header is unsafe and is not required by any use case. 127 if o == corsOriginMatchAll { 128 returnOrigin = "*" 129 break 130 } 131 } 132 } 133 w.Header().Set(corsAllowOriginHeader, returnOrigin) 134 135 if r.Method == corsOptionMethod { 136 w.WriteHeader(ch.optionStatusCode) 137 return 138 } 139 ch.h.ServeHTTP(w, r) 140 } 141 142 // CORS provides Cross-Origin Resource Sharing middleware. 143 // Example: 144 // 145 // import ( 146 // http "gitee.com/ks-custle/core-gm/gmhttp" 147 // 148 // "gitee.com/ks-custle/core-gm/handlers" 149 // "gitee.com/ks-custle/core-gm/mux" 150 // ) 151 // 152 // func main() { 153 // r := mux.NewRouter() 154 // r.HandleFunc("/users", UserEndpoint) 155 // r.HandleFunc("/projects", ProjectEndpoint) 156 // 157 // // Apply the CORS middleware to our top-level router, with the defaults. 158 // http.ListenAndServe(":8000", handlers.CORS()(r)) 159 // } 160 func CORS(opts ...CORSOption) func(http.Handler) http.Handler { 161 return func(h http.Handler) http.Handler { 162 ch := parseCORSOptions(opts...) 163 ch.h = h 164 return ch 165 } 166 } 167 168 func parseCORSOptions(opts ...CORSOption) *cors { 169 ch := &cors{ 170 allowedMethods: defaultCorsMethods, 171 allowedHeaders: defaultCorsHeaders, 172 allowedOrigins: []string{}, 173 optionStatusCode: defaultCorsOptionStatusCode, 174 } 175 176 for _, option := range opts { 177 option(ch) 178 } 179 180 return ch 181 } 182 183 // 184 // Functional options for configuring CORS. 185 // 186 187 // AllowedHeaders adds the provided headers to the list of allowed headers in a 188 // CORS request. 189 // This is an append operation so the headers Accept, Accept-Language, 190 // and Content-Language are always allowed. 191 // Content-Type must be explicitly declared if accepting Content-Types other than 192 // application/x-www-form-urlencoded, multipart/form-data, or text/plain. 193 func AllowedHeaders(headers []string) CORSOption { 194 return func(ch *cors) error { 195 for _, v := range headers { 196 normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 197 if normalizedHeader == "" { 198 continue 199 } 200 201 if !ch.isMatch(normalizedHeader, ch.allowedHeaders) { 202 ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader) 203 } 204 } 205 206 return nil 207 } 208 } 209 210 // AllowedMethods can be used to explicitly allow methods in the 211 // Access-Control-Allow-Methods header. 212 // This is a replacement operation so you must also 213 // pass GET, HEAD, and POST if you wish to support those methods. 214 func AllowedMethods(methods []string) CORSOption { 215 return func(ch *cors) error { 216 ch.allowedMethods = []string{} 217 for _, v := range methods { 218 normalizedMethod := strings.ToUpper(strings.TrimSpace(v)) 219 if normalizedMethod == "" { 220 continue 221 } 222 223 if !ch.isMatch(normalizedMethod, ch.allowedMethods) { 224 ch.allowedMethods = append(ch.allowedMethods, normalizedMethod) 225 } 226 } 227 228 return nil 229 } 230 } 231 232 // AllowedOrigins sets the allowed origins for CORS requests, as used in the 233 // 'Allow-Access-Control-Origin' HTTP header. 234 // Note: Passing in a []string{"*"} will allow any domain. 235 func AllowedOrigins(origins []string) CORSOption { 236 return func(ch *cors) error { 237 for _, v := range origins { 238 if v == corsOriginMatchAll { 239 ch.allowedOrigins = []string{corsOriginMatchAll} 240 return nil 241 } 242 } 243 244 ch.allowedOrigins = origins 245 return nil 246 } 247 } 248 249 // AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the 250 // 'Allow-Access-Control-Origin' HTTP header. 251 func AllowedOriginValidator(fn OriginValidator) CORSOption { 252 return func(ch *cors) error { 253 ch.allowedOriginValidator = fn 254 return nil 255 } 256 } 257 258 // OptionStatusCode sets a custom status code on the OPTIONS requests. 259 // Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory 260 // and can be used if you need a custom status code (i.e 204). 261 // 262 // More informations on the spec: 263 // https://fetch.spec.whatwg.org/#cors-preflight-fetch 264 func OptionStatusCode(code int) CORSOption { 265 return func(ch *cors) error { 266 ch.optionStatusCode = code 267 return nil 268 } 269 } 270 271 // ExposedHeaders can be used to specify headers that are available 272 // and will not be stripped out by the user-agent. 273 func ExposedHeaders(headers []string) CORSOption { 274 return func(ch *cors) error { 275 ch.exposedHeaders = []string{} 276 for _, v := range headers { 277 normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 278 if normalizedHeader == "" { 279 continue 280 } 281 282 if !ch.isMatch(normalizedHeader, ch.exposedHeaders) { 283 ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader) 284 } 285 } 286 287 return nil 288 } 289 } 290 291 // MaxAge determines the maximum age (in seconds) between preflight requests. A 292 // maximum of 10 minutes is allowed. An age above this value will default to 10 293 // minutes. 294 func MaxAge(age int) CORSOption { 295 return func(ch *cors) error { 296 // Maximum of 10 minutes. 297 if age > 600 { 298 age = 600 299 } 300 301 ch.maxAge = age 302 return nil 303 } 304 } 305 306 // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead 307 // passing them through to the next handler. This is useful when your application 308 // or framework has a pre-existing mechanism for responding to OPTIONS requests. 309 func IgnoreOptions() CORSOption { 310 return func(ch *cors) error { 311 ch.ignoreOptions = true 312 return nil 313 } 314 } 315 316 // AllowCredentials can be used to specify that the user agent may pass 317 // authentication details along with the request. 318 func AllowCredentials() CORSOption { 319 return func(ch *cors) error { 320 ch.allowCredentials = true 321 return nil 322 } 323 } 324 325 func (ch *cors) isOriginAllowed(origin string) bool { 326 if origin == "" { 327 return false 328 } 329 330 if ch.allowedOriginValidator != nil { 331 return ch.allowedOriginValidator(origin) 332 } 333 334 if len(ch.allowedOrigins) == 0 { 335 return true 336 } 337 338 for _, allowedOrigin := range ch.allowedOrigins { 339 if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll { 340 return true 341 } 342 } 343 344 return false 345 } 346 347 func (ch *cors) isMatch(needle string, haystack []string) bool { 348 for _, v := range haystack { 349 if v == needle { 350 return true 351 } 352 } 353 354 return false 355 }