github.com/lingyao2333/mo-zero@v1.4.1/rest/internal/cors/handlers.go (about) 1 package cors 2 3 import ( 4 "net/http" 5 "strings" 6 7 "github.com/lingyao2333/mo-zero/rest/internal/response" 8 ) 9 10 const ( 11 allowOrigin = "Access-Control-Allow-Origin" 12 allOrigins = "*" 13 allowMethods = "Access-Control-Allow-Methods" 14 allowHeaders = "Access-Control-Allow-Headers" 15 allowCredentials = "Access-Control-Allow-Credentials" 16 exposeHeaders = "Access-Control-Expose-Headers" 17 requestMethod = "Access-Control-Request-Method" 18 requestHeaders = "Access-Control-Request-Headers" 19 allowHeadersVal = "Content-Type, Origin, X-CSRF-Token, Authorization, AccessToken, Token, Range" 20 exposeHeadersVal = "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers" 21 methods = "GET, HEAD, POST, PATCH, PUT, DELETE" 22 allowTrue = "true" 23 maxAgeHeader = "Access-Control-Max-Age" 24 maxAgeHeaderVal = "86400" 25 varyHeader = "Vary" 26 originHeader = "Origin" 27 ) 28 29 // NotAllowedHandler handles cross domain not allowed requests. 30 // At most one origin can be specified, other origins are ignored if given, default to be *. 31 func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler { 32 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 33 gw := response.NewHeaderOnceResponseWriter(w) 34 checkAndSetHeaders(gw, r, origins) 35 if fn != nil { 36 fn(gw) 37 } 38 39 if r.Method == http.MethodOptions { 40 gw.WriteHeader(http.StatusNoContent) 41 } else { 42 gw.WriteHeader(http.StatusNotFound) 43 } 44 }) 45 } 46 47 // Middleware returns a middleware that adds CORS headers to the response. 48 func Middleware(fn func(w http.Header), origins ...string) func(http.HandlerFunc) http.HandlerFunc { 49 return func(next http.HandlerFunc) http.HandlerFunc { 50 return func(w http.ResponseWriter, r *http.Request) { 51 checkAndSetHeaders(w, r, origins) 52 if fn != nil { 53 fn(w.Header()) 54 } 55 56 if r.Method == http.MethodOptions { 57 w.WriteHeader(http.StatusNoContent) 58 } else { 59 next(w, r) 60 } 61 } 62 } 63 } 64 65 func checkAndSetHeaders(w http.ResponseWriter, r *http.Request, origins []string) { 66 setVaryHeaders(w, r) 67 68 if len(origins) == 0 { 69 setHeader(w, allOrigins) 70 return 71 } 72 73 origin := r.Header.Get(originHeader) 74 if isOriginAllowed(origins, origin) { 75 setHeader(w, origin) 76 } 77 } 78 79 func isOriginAllowed(allows []string, origin string) bool { 80 for _, o := range allows { 81 if o == allOrigins { 82 return true 83 } 84 85 if strings.HasSuffix(origin, o) { 86 return true 87 } 88 } 89 90 return false 91 } 92 93 func setHeader(w http.ResponseWriter, origin string) { 94 header := w.Header() 95 header.Set(allowOrigin, origin) 96 header.Set(allowMethods, methods) 97 header.Set(allowHeaders, allowHeadersVal) 98 header.Set(exposeHeaders, exposeHeadersVal) 99 if origin != allOrigins { 100 header.Set(allowCredentials, allowTrue) 101 } 102 header.Set(maxAgeHeader, maxAgeHeaderVal) 103 } 104 105 func setVaryHeaders(w http.ResponseWriter, r *http.Request) { 106 header := w.Header() 107 header.Add(varyHeader, originHeader) 108 if r.Method == http.MethodOptions { 109 header.Add(varyHeader, requestMethod) 110 header.Add(varyHeader, requestHeaders) 111 } 112 }