github.com/astaxie/beego@v1.12.3/plugins/cors/cors.go (about) 1 // Copyright 2014 beego Author. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package cors provides handlers to enable CORS support. 16 // Usage 17 // import ( 18 // "github.com/astaxie/beego" 19 // "github.com/astaxie/beego/plugins/cors" 20 // ) 21 // 22 // func main() { 23 // // CORS for https://foo.* origins, allowing: 24 // // - PUT and PATCH methods 25 // // - Origin header 26 // // - Credentials share 27 // beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{ 28 // AllowOrigins: []string{"https://*.foo.com"}, 29 // AllowMethods: []string{"PUT", "PATCH"}, 30 // AllowHeaders: []string{"Origin"}, 31 // ExposeHeaders: []string{"Content-Length"}, 32 // AllowCredentials: true, 33 // })) 34 // beego.Run() 35 // } 36 package cors 37 38 import ( 39 "net/http" 40 "regexp" 41 "strconv" 42 "strings" 43 "time" 44 45 "github.com/astaxie/beego" 46 "github.com/astaxie/beego/context" 47 ) 48 49 const ( 50 headerAllowOrigin = "Access-Control-Allow-Origin" 51 headerAllowCredentials = "Access-Control-Allow-Credentials" 52 headerAllowHeaders = "Access-Control-Allow-Headers" 53 headerAllowMethods = "Access-Control-Allow-Methods" 54 headerExposeHeaders = "Access-Control-Expose-Headers" 55 headerMaxAge = "Access-Control-Max-Age" 56 57 headerOrigin = "Origin" 58 headerRequestMethod = "Access-Control-Request-Method" 59 headerRequestHeaders = "Access-Control-Request-Headers" 60 ) 61 62 var ( 63 defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"} 64 // Regex patterns are generated from AllowOrigins. These are used and generated internally. 65 allowOriginPatterns = []string{} 66 ) 67 68 // Options represents Access Control options. 69 type Options struct { 70 // If set, all origins are allowed. 71 AllowAllOrigins bool 72 // A list of allowed origins. Wild cards and FQDNs are supported. 73 AllowOrigins []string 74 // If set, allows to share auth credentials such as cookies. 75 AllowCredentials bool 76 // A list of allowed HTTP methods. 77 AllowMethods []string 78 // A list of allowed HTTP headers. 79 AllowHeaders []string 80 // A list of exposed HTTP headers. 81 ExposeHeaders []string 82 // Max age of the CORS headers. 83 MaxAge time.Duration 84 } 85 86 // Header converts options into CORS headers. 87 func (o *Options) Header(origin string) (headers map[string]string) { 88 headers = make(map[string]string) 89 // if origin is not allowed, don't extend the headers 90 // with CORS headers. 91 if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { 92 return 93 } 94 95 // add allow origin 96 if o.AllowAllOrigins { 97 headers[headerAllowOrigin] = "*" 98 } else { 99 headers[headerAllowOrigin] = origin 100 } 101 102 // add allow credentials 103 headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) 104 105 // add allow methods 106 if len(o.AllowMethods) > 0 { 107 headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") 108 } 109 110 // add allow headers 111 if len(o.AllowHeaders) > 0 { 112 headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",") 113 } 114 115 // add exposed header 116 if len(o.ExposeHeaders) > 0 { 117 headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") 118 } 119 // add a max age header 120 if o.MaxAge > time.Duration(0) { 121 headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) 122 } 123 return 124 } 125 126 // PreflightHeader converts options into CORS headers for a preflight response. 127 func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) { 128 headers = make(map[string]string) 129 if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) { 130 return 131 } 132 // verify if requested method is allowed 133 for _, method := range o.AllowMethods { 134 if method == rMethod { 135 headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",") 136 break 137 } 138 } 139 140 // verify if requested headers are allowed 141 var allowed []string 142 for _, rHeader := range strings.Split(rHeaders, ",") { 143 rHeader = strings.TrimSpace(rHeader) 144 lookupLoop: 145 for _, allowedHeader := range o.AllowHeaders { 146 if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) { 147 allowed = append(allowed, rHeader) 148 break lookupLoop 149 } 150 } 151 } 152 153 headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials) 154 // add allow origin 155 if o.AllowAllOrigins { 156 headers[headerAllowOrigin] = "*" 157 } else { 158 headers[headerAllowOrigin] = origin 159 } 160 161 // add allowed headers 162 if len(allowed) > 0 { 163 headers[headerAllowHeaders] = strings.Join(allowed, ",") 164 } 165 166 // add exposed headers 167 if len(o.ExposeHeaders) > 0 { 168 headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",") 169 } 170 // add a max age header 171 if o.MaxAge > time.Duration(0) { 172 headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10) 173 } 174 return 175 } 176 177 // IsOriginAllowed looks up if the origin matches one of the patterns 178 // generated from Options.AllowOrigins patterns. 179 func (o *Options) IsOriginAllowed(origin string) (allowed bool) { 180 for _, pattern := range allowOriginPatterns { 181 allowed, _ = regexp.MatchString(pattern, origin) 182 if allowed { 183 return 184 } 185 } 186 return 187 } 188 189 // Allow enables CORS for requests those match the provided options. 190 func Allow(opts *Options) beego.FilterFunc { 191 // Allow default headers if nothing is specified. 192 if len(opts.AllowHeaders) == 0 { 193 opts.AllowHeaders = defaultAllowHeaders 194 } 195 196 for _, origin := range opts.AllowOrigins { 197 pattern := regexp.QuoteMeta(origin) 198 pattern = strings.Replace(pattern, "\\*", ".*", -1) 199 pattern = strings.Replace(pattern, "\\?", ".", -1) 200 allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$") 201 } 202 203 return func(ctx *context.Context) { 204 var ( 205 origin = ctx.Input.Header(headerOrigin) 206 requestedMethod = ctx.Input.Header(headerRequestMethod) 207 requestedHeaders = ctx.Input.Header(headerRequestHeaders) 208 // additional headers to be added 209 // to the response. 210 headers map[string]string 211 ) 212 213 if ctx.Input.Method() == "OPTIONS" && 214 (requestedMethod != "" || requestedHeaders != "") { 215 headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders) 216 for key, value := range headers { 217 ctx.Output.Header(key, value) 218 } 219 ctx.ResponseWriter.WriteHeader(http.StatusOK) 220 return 221 } 222 headers = opts.Header(origin) 223 224 for key, value := range headers { 225 ctx.Output.Header(key, value) 226 } 227 } 228 }