github.com/gofunct/common@v0.0.0-20190131174352-fd058c7fbf22/pkg/transport/handlers/cors.go (about) 1 package handlers 2 3 import ( 4 "net/http" 5 "strconv" 6 "strings" 7 ) 8 9 // CORSOption represents a functional option for configuring the CORS middleware. 10 type CORSOption func(*cors) error 11 12 type cors struct { 13 h http.Handler 14 allowedHeaders []string 15 allowedMethods []string 16 allowedOrigins []string 17 allowedOriginValidator OriginValidator 18 exposedHeaders []string 19 maxAge int 20 ignoreOptions bool 21 allowCredentials bool 22 optionStatusCode int 23 } 24 25 // OriginValidator takes an origin string and returns whether or not that origin is allowed. 26 type OriginValidator func(string) bool 27 28 var ( 29 defaultCorsOptionStatusCode = 200 30 defaultCorsMethods = []string{"GET", "HEAD", "POST"} 31 defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} 32 // (WebKit/Safari v9 sends the Origin header by default in AJAX requests) 33 ) 34 35 const ( 36 corsOptionMethod string = "OPTIONS" 37 corsAllowOriginHeader string = "Access-Control-Allow-Origin" 38 corsExposeHeadersHeader string = "Access-Control-Expose-Headers" 39 corsMaxAgeHeader string = "Access-Control-Max-Age" 40 corsAllowMethodsHeader string = "Access-Control-Allow-Methods" 41 corsAllowHeadersHeader string = "Access-Control-Allow-Headers" 42 corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials" 43 corsRequestMethodHeader string = "Access-Control-Request-Method" 44 corsRequestHeadersHeader string = "Access-Control-Request-Headers" 45 corsOriginHeader string = "Origin" 46 corsVaryHeader string = "Vary" 47 corsOriginMatchAll string = "*" 48 ) 49 50 func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { 51 origin := r.Header.Get(corsOriginHeader) 52 if !ch.isOriginAllowed(origin) { 53 if r.Method != corsOptionMethod || ch.ignoreOptions { 54 ch.h.ServeHTTP(w, r) 55 } 56 57 return 58 } 59 60 if r.Method == corsOptionMethod { 61 if ch.ignoreOptions { 62 ch.h.ServeHTTP(w, r) 63 return 64 } 65 66 if _, ok := r.Header[corsRequestMethodHeader]; !ok { 67 w.WriteHeader(http.StatusBadRequest) 68 return 69 } 70 71 method := r.Header.Get(corsRequestMethodHeader) 72 if !ch.isMatch(method, ch.allowedMethods) { 73 w.WriteHeader(http.StatusMethodNotAllowed) 74 return 75 } 76 77 requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") 78 allowedHeaders := []string{} 79 for _, v := range requestHeaders { 80 canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 81 if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) { 82 continue 83 } 84 85 if !ch.isMatch(canonicalHeader, ch.allowedHeaders) { 86 w.WriteHeader(http.StatusForbidden) 87 return 88 } 89 90 allowedHeaders = append(allowedHeaders, canonicalHeader) 91 } 92 93 if len(allowedHeaders) > 0 { 94 w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) 95 } 96 97 if ch.maxAge > 0 { 98 w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge)) 99 } 100 101 if !ch.isMatch(method, defaultCorsMethods) { 102 w.Header().Set(corsAllowMethodsHeader, method) 103 } 104 } else { 105 if len(ch.exposedHeaders) > 0 { 106 w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) 107 } 108 } 109 110 if ch.allowCredentials { 111 w.Header().Set(corsAllowCredentialsHeader, "true") 112 } 113 114 if len(ch.allowedOrigins) > 1 { 115 w.Header().Set(corsVaryHeader, corsOriginHeader) 116 } 117 118 returnOrigin := origin 119 if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 { 120 returnOrigin = "*" 121 } else { 122 for _, o := range ch.allowedOrigins { 123 // A configuration of * is different than explicitly setting an allowed 124 // origin. Returning arbitrary origin headers in an access control allow 125 // origin header is unsafe and is not required by any use case. 126 if o == corsOriginMatchAll { 127 returnOrigin = "*" 128 break 129 } 130 } 131 } 132 w.Header().Set(corsAllowOriginHeader, returnOrigin) 133 134 if r.Method == corsOptionMethod { 135 w.WriteHeader(ch.optionStatusCode) 136 return 137 } 138 ch.h.ServeHTTP(w, r) 139 } 140 141 // CORS provides Cross-Origin Resource Sharing middleware. 142 // Example: 143 // 144 // import ( 145 // "net/http" 146 // 147 // "github.com/gorilla/handlers" 148 // "github.com/gorilla/mux" 149 // ) 150 // 151 // func main() { 152 // r := mux.NewRouter() 153 // r.HandleFunc("/users", UserEndpoint) 154 // r.HandleFunc("/projects", ProjectEndpoint) 155 // 156 // // Apply the CORS middleware to our top-level router, with the defaults. 157 // http.ListenAndServe(":8000", handlers.CORS()(r)) 158 // } 159 // 160 func CORS(opts ...CORSOption) func(http.Handler) http.Handler { 161 return func(h http.Handler) http.Handler { 162 ch := parseCORSOptions(opts...) 163 ch.h = h 164 return ch 165 } 166 } 167 168 func parseCORSOptions(opts ...CORSOption) *cors { 169 ch := &cors{ 170 allowedMethods: defaultCorsMethods, 171 allowedHeaders: defaultCorsHeaders, 172 allowedOrigins: []string{}, 173 optionStatusCode: defaultCorsOptionStatusCode, 174 } 175 176 for _, option := range opts { 177 option(ch) 178 } 179 180 return ch 181 } 182 183 // 184 // Functional options for configuring CORS. 185 // 186 187 // AllowedHeaders adds the provided headers to the list of allowed headers in a 188 // CORS request. 189 // This is an append operation so the headers Accept, Accept-Language, 190 // and Content-Language are always allowed. 191 // Content-Type must be explicitly declared if accepting Content-Types other than 192 // application/x-www-form-urlencoded, multipart/form-data, or text/plain. 193 func AllowedHeaders(headers []string) CORSOption { 194 return func(ch *cors) error { 195 for _, v := range headers { 196 normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 197 if normalizedHeader == "" { 198 continue 199 } 200 201 if !ch.isMatch(normalizedHeader, ch.allowedHeaders) { 202 ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader) 203 } 204 } 205 206 return nil 207 } 208 } 209 210 // AllowedMethods can be used to explicitly allow methods in the 211 // Access-Control-Allow-Methods header. 212 // This is a replacement operation so you must also 213 // pass GET, HEAD, and POST if you wish to support those methods. 214 func AllowedMethods(methods []string) CORSOption { 215 return func(ch *cors) error { 216 ch.allowedMethods = []string{} 217 for _, v := range methods { 218 normalizedMethod := strings.ToUpper(strings.TrimSpace(v)) 219 if normalizedMethod == "" { 220 continue 221 } 222 223 if !ch.isMatch(normalizedMethod, ch.allowedMethods) { 224 ch.allowedMethods = append(ch.allowedMethods, normalizedMethod) 225 } 226 } 227 228 return nil 229 } 230 } 231 232 // AllowedOrigins sets the allowed origins for CORS requests, as used in the 233 // 'Allow-Access-Control-Origin' HTTP header. 234 // Note: Passing in a []string{"*"} will allow any domain. 235 func AllowedOrigins(origins []string) CORSOption { 236 return func(ch *cors) error { 237 for _, v := range origins { 238 if v == corsOriginMatchAll { 239 ch.allowedOrigins = []string{corsOriginMatchAll} 240 return nil 241 } 242 } 243 244 ch.allowedOrigins = origins 245 return nil 246 } 247 } 248 249 // AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the 250 // 'Allow-Access-Control-Origin' HTTP header. 251 func AllowedOriginValidator(fn OriginValidator) CORSOption { 252 return func(ch *cors) error { 253 ch.allowedOriginValidator = fn 254 return nil 255 } 256 } 257 258 // OptionStatusCode sets a custom status code on the OPTIONS requests. 259 // Default behaviour sets it to 200 to reflect best practices. This is option is not mandatory 260 // and can be used if you need a custom status code (i.e 204). 261 // 262 // More informations on the spec: 263 // https://fetch.spec.whatwg.org/#cors-preflight-fetch 264 func OptionStatusCode(code int) CORSOption { 265 return func(ch *cors) error { 266 ch.optionStatusCode = code 267 return nil 268 } 269 } 270 271 // ExposedHeaders can be used to specify headers that are available 272 // and will not be stripped out by the user-agent. 273 func ExposedHeaders(headers []string) CORSOption { 274 return func(ch *cors) error { 275 ch.exposedHeaders = []string{} 276 for _, v := range headers { 277 normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) 278 if normalizedHeader == "" { 279 continue 280 } 281 282 if !ch.isMatch(normalizedHeader, ch.exposedHeaders) { 283 ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader) 284 } 285 } 286 287 return nil 288 } 289 } 290 291 // MaxAge determines the maximum age (in seconds) between preflight requests. A 292 // maximum of 10 minutes is allowed. An age above this value will default to 10 293 // minutes. 294 func MaxAge(age int) CORSOption { 295 return func(ch *cors) error { 296 // Maximum of 10 minutes. 297 if age > 600 { 298 age = 600 299 } 300 301 ch.maxAge = age 302 return nil 303 } 304 } 305 306 // IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead 307 // passing them through to the next handler. This is useful when your application 308 // or framework has a pre-existing mechanism for responding to OPTIONS requests. 309 func IgnoreOptions() CORSOption { 310 return func(ch *cors) error { 311 ch.ignoreOptions = true 312 return nil 313 } 314 } 315 316 // AllowCredentials can be used to specify that the user agent may pass 317 // authentication details along with the request. 318 func AllowCredentials() CORSOption { 319 return func(ch *cors) error { 320 ch.allowCredentials = true 321 return nil 322 } 323 } 324 325 func (ch *cors) isOriginAllowed(origin string) bool { 326 if origin == "" { 327 return false 328 } 329 330 if ch.allowedOriginValidator != nil { 331 return ch.allowedOriginValidator(origin) 332 } 333 334 if len(ch.allowedOrigins) == 0 { 335 return true 336 } 337 338 for _, allowedOrigin := range ch.allowedOrigins { 339 if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll { 340 return true 341 } 342 } 343 344 return false 345 } 346 347 func (ch *cors) isMatch(needle string, haystack []string) bool { 348 for _, v := range haystack { 349 if v == needle { 350 return true 351 } 352 } 353 354 return false 355 }