github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/gorilla/handlers/cors.go (about) 1 package handlers 2 3 import ( 4 "github.com/hellobchain/newcryptosm/http" 5 "strconv" 6 "strings" 7 ) 8 9 // CORSOption represents a functional option for configuring the CORS middleware. 10 type CORSOption func(*cors) error 11 12 type cors struct { 13 h http.Handler 14 allowedHeaders []string 15 allowedMethods []string 16 allowedOrigins []string 17 allowedOriginValidator OriginValidator 18 exposedHeaders []string 19 maxAge int 20 ignoreOptions bool 21 allowCredentials bool 22 } 23 24 // OriginValidator takes an origin string and returns whether or not that origin is allowed. 25 type OriginValidator func(string) bool 26 27 var ( 28 defaultCorsMethods = []string{"GET", "HEAD", "POST"} 29 defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} 30 // (WebKit/Safari v9 sends the Origin header by default in AJAX requests) 31 ) 32 33 const ( 34 corsOptionMethod string = "OPTIONS" 35 corsAllowOriginHeader string = "Access-Control-Allow-Origin" 36 corsExposeHeadersHeader string = "Access-Control-Expose-Headers" 37 corsMaxAgeHeader string = "Access-Control-Max-Age" 38 corsAllowMethodsHeader string = "Access-Control-Allow-Methods" 39 corsAllowHeadersHeader string = "Access-Control-Allow-Headers" 40 corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials" 41 corsRequestMethodHeader string = "Access-Control-Request-Method" 42 corsRequestHeadersHeader string = "Access-Control-Request-Headers" 43 corsOriginHeader string = "Origin" 44 corsVaryHeader string = "Vary" 45 corsOriginMatchAll string = "*" 46 ) 47 48 func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { 49 origin := r.Header.Get(corsOriginHeader) 50 if !ch.isOriginAllowed(origin) { 51 if r.Method != corsOptionMethod || ch.ignoreOptions { 52 ch.h.ServeHTTP(w, r) 53 } 54 55 return 56 } 57 58 if r.Method == corsOptionMethod { 59 if ch.ignoreOptions { 60 ch.h.ServeHTTP(w, r) 61 return 62 } 63 64 if _, ok := r.Header[corsRequestMethodHeader]; !ok { 65 w.WriteHeader(http.StatusBadRequest) 66 return 67 } 68 69 method := r.Header.Get(corsRequestMethodHeader) 70 if !ch.isMatch(method, ch.allowedMethods) { 71 w.WriteHeader(http.StatusMethodNotAllowed) 72 return 73 } 74 75 requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") 76 allowedHeaders := []string{} 77 for _, v := range requestHeaders { 78 canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 79 if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) { 80 continue 81 } 82 83 if !ch.isMatch(canonicalHeader, ch.allowedHeaders) { 84 w.WriteHeader(http.StatusForbidden) 85 return 86 } 87 88 allowedHeaders = append(allowedHeaders, canonicalHeader) 89 } 90 91 if len(allowedHeaders) > 0 { 92 w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) 93 } 94 95 if ch.maxAge > 0 { 96 w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge)) 97 } 98 99 if !ch.isMatch(method, defaultCorsMethods) { 100 w.Header().Set(corsAllowMethodsHeader, method) 101 } 102 } else { 103 if len(ch.exposedHeaders) > 0 { 104 w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) 105 } 106 } 107 108 if ch.allowCredentials { 109 w.Header().Set(corsAllowCredentialsHeader, "true") 110 } 111 112 if len(ch.allowedOrigins) > 1 { 113 w.Header().Set(corsVaryHeader, corsOriginHeader) 114 } 115 116 returnOrigin := origin 117 if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 { 118 returnOrigin = "*" 119 } else { 120 for _, o := range ch.allowedOrigins { 121 // A configuration of * is different than explicitly setting an allowed 122 // origin. Returning arbitrary origin headers in an access control allow 123 // origin header is unsafe and is not required by any use case. 124 if o == corsOriginMatchAll { 125 returnOrigin = "*" 126 break 127 } 128 } 129 } 130 w.Header().Set(corsAllowOriginHeader, returnOrigin) 131 132 if r.Method == corsOptionMethod { 133 return 134 } 135 ch.h.ServeHTTP(w, r) 136 } 137 138 // CORS provides Cross-Origin Resource Sharing middleware. 139 // Example: 140 // 141 // import ( 142 // "github.com/hellobchain/newcryptosm/http" 143 // 144 // "github.com/hellobchain/third_party/gorilla/handlers" 145 // "github.com/hellobchain/third_party/gorilla/mux" 146 // ) 147 // 148 // func main() { 149 // r := mux.NewRouter() 150 // r.HandleFunc("/users", UserEndpoint) 151 // r.HandleFunc("/projects", ProjectEndpoint) 152 // 153 // // Apply the CORS middleware to our top-level router, with the defaults. 154 // http.ListenAndServe(":8000", handlers.CORS()(r)) 155 // } 156 // 157 func CORS(opts ...CORSOption) func(http.Handler) http.Handler { 158 return func(h http.Handler) http.Handler { 159 ch := parseCORSOptions(opts...) 160 ch.h = h 161 return ch 162 } 163 } 164 165 func parseCORSOptions(opts ...CORSOption) *cors { 166 ch := &cors{ 167 allowedMethods: defaultCorsMethods, 168 allowedHeaders: defaultCorsHeaders, 169 allowedOrigins: []string{}, 170 } 171 172 for _, option := range opts { 173 option(ch) 174 } 175 176 return ch 177 } 178 179 // 180 // Functional options for configuring CORS. 181 // 182 183 // AllowedHeaders adds the provided headers to the list of allowed headers in a 184 // CORS request. 185 // This is an append operation so the headers Accept, Accept-Language, 186 // and Content-Language are always allowed. 187 // Content-Type must be explicitly declared if accepting Content-Types other than 188 // application/x-www-form-urlencoded, multipart/form-data, or text/plain. 189 func AllowedHeaders(headers []string) CORSOption { 190 return func(ch *cors) error { 191 for _, v := range headers { 192 normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 193 if normalizedHeader == "" { 194 continue 195 } 196 197 if !ch.isMatch(normalizedHeader, ch.allowedHeaders) { 198 ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader) 199 } 200 } 201 202 return nil 203 } 204 } 205 206 // AllowedMethods can be used to explicitly allow methods in the 207 // Access-Control-Allow-Methods header. 208 // This is a replacement operation so you must also 209 // pass GET, HEAD, and POST if you wish to support those methods. 210 func AllowedMethods(methods []string) CORSOption { 211 return func(ch *cors) error { 212 ch.allowedMethods = []string{} 213 for _, v := range methods { 214 normalizedMethod := strings.ToUpper(strings.TrimSpace(v)) 215 if normalizedMethod == "" { 216 continue 217 } 218 219 if !ch.isMatch(normalizedMethod, ch.allowedMethods) { 220 ch.allowedMethods = append(ch.allowedMethods, normalizedMethod) 221 } 222 } 223 224 return nil 225 } 226 } 227 228 // AllowedOrigins sets the allowed origins for CORS requests, as used in the 229 // 'Allow-Access-Control-Origin' HTTP header. 230 // Note: Passing in a []string{"*"} will allow any domain. 231 func AllowedOrigins(origins []string) CORSOption { 232 return func(ch *cors) error { 233 for _, v := range origins { 234 if v == corsOriginMatchAll { 235 ch.allowedOrigins = []string{corsOriginMatchAll} 236 return nil 237 } 238 } 239 240 ch.allowedOrigins = origins 241 return nil 242 } 243 } 244 245 // AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the 246 // 'Allow-Access-Control-Origin' HTTP header. 247 func AllowedOriginValidator(fn OriginValidator) CORSOption { 248 return func(ch *cors) error { 249 ch.allowedOriginValidator = fn 250 return nil 251 } 252 } 253 254 // ExposeHeaders can be used to specify headers that are available 255 // and will not be stripped out by the user-agent. 256 func ExposedHeaders(headers []string) CORSOption { 257 return func(ch *cors) error { 258 ch.exposedHeaders = []string{} 259 for _, v := range headers { 260 normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 261 if normalizedHeader == "" { 262 continue 263 } 264 265 if !ch.isMatch(normalizedHeader, ch.exposedHeaders) { 266 ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader) 267 } 268 } 269 270 return nil 271 } 272 } 273 274 // MaxAge determines the maximum age (in seconds) between preflight requests. A 275 // maximum of 10 minutes is allowed. An age above this value will default to 10 276 // minutes. 277 func MaxAge(age int) CORSOption { 278 return func(ch *cors) error { 279 // Maximum of 10 minutes. 280 if age > 600 { 281 age = 600 282 } 283 284 ch.maxAge = age 285 return nil 286 } 287 } 288 289 // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead 290 // passing them through to the next handler. This is useful when your application 291 // or framework has a pre-existing mechanism for responding to OPTIONS requests. 292 func IgnoreOptions() CORSOption { 293 return func(ch *cors) error { 294 ch.ignoreOptions = true 295 return nil 296 } 297 } 298 299 // AllowCredentials can be used to specify that the user agent may pass 300 // authentication details along with the request. 301 func AllowCredentials() CORSOption { 302 return func(ch *cors) error { 303 ch.allowCredentials = true 304 return nil 305 } 306 } 307 308 func (ch *cors) isOriginAllowed(origin string) bool { 309 if origin == "" { 310 return false 311 } 312 313 if ch.allowedOriginValidator != nil { 314 return ch.allowedOriginValidator(origin) 315 } 316 317 if len(ch.allowedOrigins) == 0 { 318 return true 319 } 320 321 for _, allowedOrigin := range ch.allowedOrigins { 322 if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll { 323 return true 324 } 325 } 326 327 return false 328 } 329 330 func (ch *cors) isMatch(needle string, haystack []string) bool { 331 for _, v := range haystack { 332 if v == needle { 333 return true 334 } 335 } 336 337 return false 338 }