github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/handlers/cors.go (about) 1 package handlers 2 3 import ( 4 "strconv" 5 "strings" 6 7 http "github.com/hxx258456/ccgo/gmhttp" 8 ) 9 10 // CORSOption represents a functional option for configuring the CORS middleware. 11 type CORSOption func(*cors) error 12 13 type cors struct { 14 h http.Handler 15 allowedHeaders []string 16 allowedMethods []string 17 allowedOrigins []string 18 allowedOriginValidator OriginValidator 19 exposedHeaders []string 20 maxAge int 21 ignoreOptions bool 22 allowCredentials bool 23 optionStatusCode int 24 } 25 26 // OriginValidator takes an origin string and returns whether or not that origin is allowed. 27 type OriginValidator func(string) bool 28 29 var ( 30 defaultCorsOptionStatusCode = 200 31 defaultCorsMethods = []string{"GET", "HEAD", "POST"} 32 defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} 33 // (WebKit/Safari v9 sends the Origin header by default in AJAX requests) 34 ) 35 36 const ( 37 corsOptionMethod string = "OPTIONS" 38 corsAllowOriginHeader string = "Access-Control-Allow-Origin" 39 corsExposeHeadersHeader string = "Access-Control-Expose-Headers" 40 corsMaxAgeHeader string = "Access-Control-Max-Age" 41 corsAllowMethodsHeader string = "Access-Control-Allow-Methods" 42 corsAllowHeadersHeader string = "Access-Control-Allow-Headers" 43 corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials" 44 corsRequestMethodHeader string = "Access-Control-Request-Method" 45 corsRequestHeadersHeader string = "Access-Control-Request-Headers" 46 corsOriginHeader string = "Origin" 47 corsVaryHeader string = "Vary" 48 corsOriginMatchAll string = "*" 49 ) 50 51 func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { 52 origin := r.Header.Get(corsOriginHeader) 53 if !ch.isOriginAllowed(origin) { 54 if r.Method != corsOptionMethod || ch.ignoreOptions { 55 ch.h.ServeHTTP(w, r) 56 } 57 58 return 59 } 60 61 if r.Method == 
corsOptionMethod { 62 if ch.ignoreOptions { 63 ch.h.ServeHTTP(w, r) 64 return 65 } 66 67 if _, ok := r.Header[corsRequestMethodHeader]; !ok { 68 w.WriteHeader(http.StatusBadRequest) 69 return 70 } 71 72 method := r.Header.Get(corsRequestMethodHeader) 73 if !ch.isMatch(method, ch.allowedMethods) { 74 w.WriteHeader(http.StatusMethodNotAllowed) 75 return 76 } 77 78 requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") 79 allowedHeaders := []string{} 80 for _, v := range requestHeaders { 81 canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 82 if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) { 83 continue 84 } 85 86 if !ch.isMatch(canonicalHeader, ch.allowedHeaders) { 87 w.WriteHeader(http.StatusForbidden) 88 return 89 } 90 91 allowedHeaders = append(allowedHeaders, canonicalHeader) 92 } 93 94 if len(allowedHeaders) > 0 { 95 w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) 96 } 97 98 if ch.maxAge > 0 { 99 w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge)) 100 } 101 102 if !ch.isMatch(method, defaultCorsMethods) { 103 w.Header().Set(corsAllowMethodsHeader, method) 104 } 105 } else { 106 if len(ch.exposedHeaders) > 0 { 107 w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) 108 } 109 } 110 111 if ch.allowCredentials { 112 w.Header().Set(corsAllowCredentialsHeader, "true") 113 } 114 115 if len(ch.allowedOrigins) > 1 { 116 w.Header().Set(corsVaryHeader, corsOriginHeader) 117 } 118 119 returnOrigin := origin 120 if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 { 121 returnOrigin = "*" 122 } else { 123 for _, o := range ch.allowedOrigins { 124 // A configuration of * is different than explicitly setting an allowed 125 // origin. Returning arbitrary origin headers in an access control allow 126 // origin header is unsafe and is not required by any use case. 
127 if o == corsOriginMatchAll { 128 returnOrigin = "*" 129 break 130 } 131 } 132 } 133 w.Header().Set(corsAllowOriginHeader, returnOrigin) 134 135 if r.Method == corsOptionMethod { 136 w.WriteHeader(ch.optionStatusCode) 137 return 138 } 139 ch.h.ServeHTTP(w, r) 140 } 141 142 // CORS provides Cross-Origin Resource Sharing middleware. 143 // Example: 144 // 145 // import ( 146 // http "github.com/hxx258456/ccgo/gmhttp" 147 // 148 // "github.com/hxx258456/ccgo/handlers" 149 // "github.com/hxx258456/ccgo/mux" 150 // ) 151 // 152 // func main() { 153 // r := mux.NewRouter() 154 // r.HandleFunc("/users", UserEndpoint) 155 // r.HandleFunc("/projects", ProjectEndpoint) 156 // 157 // // Apply the CORS middleware to our top-level router, with the defaults. 158 // http.ListenAndServe(":8000", handlers.CORS()(r)) 159 // } 160 // 161 func CORS(opts ...CORSOption) func(http.Handler) http.Handler { 162 return func(h http.Handler) http.Handler { 163 ch := parseCORSOptions(opts...) 164 ch.h = h 165 return ch 166 } 167 } 168 169 func parseCORSOptions(opts ...CORSOption) *cors { 170 ch := &cors{ 171 allowedMethods: defaultCorsMethods, 172 allowedHeaders: defaultCorsHeaders, 173 allowedOrigins: []string{}, 174 optionStatusCode: defaultCorsOptionStatusCode, 175 } 176 177 for _, option := range opts { 178 option(ch) 179 } 180 181 return ch 182 } 183 184 // 185 // Functional options for configuring CORS. 186 // 187 188 // AllowedHeaders adds the provided headers to the list of allowed headers in a 189 // CORS request. 190 // This is an append operation so the headers Accept, Accept-Language, 191 // and Content-Language are always allowed. 192 // Content-Type must be explicitly declared if accepting Content-Types other than 193 // application/x-www-form-urlencoded, multipart/form-data, or text/plain. 
func AllowedHeaders(headers []string) CORSOption {
	return func(ch *cors) error {
		for _, v := range headers {
			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
			if normalizedHeader == "" {
				continue
			}

			// Skip names that are already allowed to keep the list duplicate-free.
			if !ch.isMatch(normalizedHeader, ch.allowedHeaders) {
				ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader)
			}
		}

		return nil
	}
}

// AllowedMethods can be used to explicitly allow methods in the
// Access-Control-Allow-Methods header.
// This is a replacement operation, so you must also pass GET, HEAD, and POST
// if you wish to support those methods.
// Method names are trimmed and upper-cased; blank entries and duplicates are
// dropped.
func AllowedMethods(methods []string) CORSOption {
	return func(ch *cors) error {
		ch.allowedMethods = []string{}
		for _, v := range methods {
			normalizedMethod := strings.ToUpper(strings.TrimSpace(v))
			if normalizedMethod == "" {
				continue
			}

			if !ch.isMatch(normalizedMethod, ch.allowedMethods) {
				ch.allowedMethods = append(ch.allowedMethods, normalizedMethod)
			}
		}

		return nil
	}
}

// AllowedOrigins sets the allowed origins for CORS requests, as used in the
// 'Access-Control-Allow-Origin' HTTP header.
// Note: Passing in a []string{"*"} will allow any domain. If "*" appears
// anywhere in origins, all other entries are discarded.
// The slice is copied, so later changes to it by the caller do not affect the
// middleware.
func AllowedOrigins(origins []string) CORSOption {
	return func(ch *cors) error {
		for _, v := range origins {
			if v == corsOriginMatchAll {
				ch.allowedOrigins = []string{corsOriginMatchAll}
				return nil
			}
		}

		// Copy so the caller cannot alter the policy by mutating origins later.
		ch.allowedOrigins = append([]string(nil), origins...)
		return nil
	}
}

// AllowedOriginValidator sets a function for evaluating allowed origins in CORS
// requests, represented by the 'Access-Control-Allow-Origin' HTTP header.
// When set, the validator alone decides whether an origin is allowed.
// AllowedOrigins then only determines whether "*" or the request's origin is
// sent back.
252 func AllowedOriginValidator(fn OriginValidator) CORSOption { 253 return func(ch *cors) error { 254 ch.allowedOriginValidator = fn 255 return nil 256 } 257 } 258 259 // OptionStatusCode sets a custom status code on the OPTIONS requests. 260 // Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory 261 // and can be used if you need a custom status code (i.e 204). 262 // 263 // More informations on the spec: 264 // https://fetch.spec.whatwg.org/#cors-preflight-fetch 265 func OptionStatusCode(code int) CORSOption { 266 return func(ch *cors) error { 267 ch.optionStatusCode = code 268 return nil 269 } 270 } 271 272 // ExposedHeaders can be used to specify headers that are available 273 // and will not be stripped out by the user-agent. 274 func ExposedHeaders(headers []string) CORSOption { 275 return func(ch *cors) error { 276 ch.exposedHeaders = []string{} 277 for _, v := range headers { 278 normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 279 if normalizedHeader == "" { 280 continue 281 } 282 283 if !ch.isMatch(normalizedHeader, ch.exposedHeaders) { 284 ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader) 285 } 286 } 287 288 return nil 289 } 290 } 291 292 // MaxAge determines the maximum age (in seconds) between preflight requests. A 293 // maximum of 10 minutes is allowed. An age above this value will default to 10 294 // minutes. 295 func MaxAge(age int) CORSOption { 296 return func(ch *cors) error { 297 // Maximum of 10 minutes. 298 if age > 600 { 299 age = 600 300 } 301 302 ch.maxAge = age 303 return nil 304 } 305 } 306 307 // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead 308 // passing them through to the next handler. This is useful when your application 309 // or framework has a pre-existing mechanism for responding to OPTIONS requests. 
310 func IgnoreOptions() CORSOption { 311 return func(ch *cors) error { 312 ch.ignoreOptions = true 313 return nil 314 } 315 } 316 317 // AllowCredentials can be used to specify that the user agent may pass 318 // authentication details along with the request. 319 func AllowCredentials() CORSOption { 320 return func(ch *cors) error { 321 ch.allowCredentials = true 322 return nil 323 } 324 } 325 326 func (ch *cors) isOriginAllowed(origin string) bool { 327 if origin == "" { 328 return false 329 } 330 331 if ch.allowedOriginValidator != nil { 332 return ch.allowedOriginValidator(origin) 333 } 334 335 if len(ch.allowedOrigins) == 0 { 336 return true 337 } 338 339 for _, allowedOrigin := range ch.allowedOrigins { 340 if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll { 341 return true 342 } 343 } 344 345 return false 346 } 347 348 func (ch *cors) isMatch(needle string, haystack []string) bool { 349 for _, v := range haystack { 350 if v == needle { 351 return true 352 } 353 } 354 355 return false 356 }