github.com/aacfactory/fns@v1.2.86-0.20240310083819-80d667fc0a17/transports/middlewares/cors/cors.go (about) 1 /* 2 * Copyright 2023 Wang Min Xiang 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 * 16 */ 17 18 package cors 19 20 import ( 21 "bytes" 22 "github.com/aacfactory/errors" 23 "github.com/aacfactory/fns/commons/bytex" 24 "github.com/aacfactory/fns/commons/wildcard" 25 "github.com/aacfactory/fns/transports" 26 "net/http" 27 "slices" 28 "strconv" 29 "strings" 30 ) 31 32 func New() transports.Middleware { 33 return &corsMiddleware{} 34 } 35 36 type corsMiddleware struct { 37 allowedOrigins [][]byte 38 allowedWOrigins []*wildcard.Wildcard 39 allowedOriginsAll bool 40 allowedHeaders [][]byte 41 allowedHeadersAll bool 42 allowedMethods [][]byte 43 exposedHeaders [][]byte 44 maxAge int 45 allowCredentials bool 46 allowPrivateNetwork bool 47 preflightVary [][]byte 48 handler transports.Handler 49 } 50 51 func (c *corsMiddleware) Name() string { 52 return "cors" 53 } 54 55 func (c *corsMiddleware) Construct(options transports.MiddlewareOptions) (err error) { 56 config := Config{} 57 err = options.Config.As(&config) 58 if err != nil { 59 err = errors.Warning("fns: build cors middleware failed").WithCause(err) 60 return 61 } 62 allowedOrigins := make([][]byte, 0, 1) 63 allowedWOrigins := make([]*wildcard.Wildcard, 0, 1) 64 allowedOriginsAll := false 65 if config.AllowedHeaders == nil { 66 config.AllowedHeaders = make([]string, 0, 1) 67 } 68 if len(config.AllowedHeaders) == 0 || config.AllowedHeaders[0] != "*" { 69 defaultAllowedHeaders := []string{ 70 string(transports.OriginHeaderName), string(transports.AcceptHeaderName), string(transports.ContentTypeHeaderName), 71 string(transports.AcceptEncodingHeaderName), 72 string(transports.XRequestedWithHeaderName), 73 string(transports.ConnectionHeaderName), string(transports.UpgradeHeaderName), 74 string(transports.XForwardedForHeaderName), string(transports.TrueClientIpHeaderName), string(transports.XRealIpHeaderName), 75 string(transports.DeviceIpHeaderName), string(transports.DeviceIdHeaderName), 76 string(transports.RequestIdHeaderName), 77 string(transports.RequestTimeoutHeaderName), string(transports.RequestVersionsHeaderName), 78 string(transports.CacheControlHeaderIfNonMatch), string(transports.CacheControlHeaderName), 79 string(transports.SignatureHeaderName), 80 } 81 for _, header := range defaultAllowedHeaders { 82 if !slices.Contains(config.AllowedHeaders, header) { 83 config.AllowedHeaders = append(config.AllowedHeaders, header) 84 } 85 } 86 } 87 if len(config.AllowedOrigins) == 0 { 88 config.AllowedOrigins = []string{"*"} 89 } 90 for _, origin := range config.AllowedOrigins { 91 origin = strings.ToLower(origin) 92 if origin == "*" { 93 allowedOriginsAll = true 94 allowedOrigins = nil 95 allowedWOrigins = nil 96 break 97 } else if i := strings.IndexByte(origin, '*'); i >= 0 { 98 w := wildcard.New(bytex.FromString(origin)) 99 allowedWOrigins = append(allowedWOrigins, w) 100 } else { 101 allowedOrigins = append(allowedOrigins, bytex.FromString(origin)) 102 } 103 } 104 allowedHeadersAll := false 105 allowedHeaders := make([][]byte, 0, 1) 106 for _, header := range config.AllowedHeaders { 107 allowedHeaders = append(allowedHeaders, bytex.FromString(header)) 108 } 109 allowedHeaders = convert(allowedHeaders, http.CanonicalHeaderKey) 110 for _, h := range config.AllowedHeaders { 111 if h == "*" { 112 allowedHeadersAll = true 113 allowedHeaders = nil 114 break 115 } 116 } 117 118 exposedHeaders := make([][]byte, 0, 1) 119 if config.ExposedHeaders == nil { 120 config.ExposedHeaders = make([]string, 0, 1) 121 } 122 defaultExposedHeaders := []string{ 123 string(transports.VaryHeaderName), 124 string(transports.DeviceIdHeaderName), 125 string(transports.EndpointIdHeaderName), string(transports.EndpointVersionHeaderName), 126 string(transports.ContentEncodingHeaderName), 127 string(transports.RequestIdHeaderName), string(transports.HandleLatencyHeaderName), 128 string(transports.CacheControlHeaderName), string(transports.ETagHeaderName), string(transports.ClearSiteDataHeaderName), string(transports.AgeHeaderName), 129 string(transports.ResponseRetryAfterHeaderName), string(transports.SignatureHeaderName), 130 string(transports.DeprecatedHeaderName), 131 } 132 for _, header := range defaultExposedHeaders { 133 if !slices.Contains(config.ExposedHeaders, header) { 134 config.ExposedHeaders = append(config.ExposedHeaders, header) 135 } 136 } 137 for _, header := range config.ExposedHeaders { 138 exposedHeaders = append(exposedHeaders, bytex.FromString(header)) 139 } 140 exposedHeaders = convert(exposedHeaders, http.CanonicalHeaderKey) 141 142 c.allowedOrigins = allowedOrigins 143 c.allowedWOrigins = allowedWOrigins 144 c.allowedOriginsAll = allowedOriginsAll 145 c.allowedHeaders = allowedHeaders 146 c.allowedHeadersAll = allowedHeadersAll 147 c.allowedMethods = [][]byte{methodGet, methodPost, methodHead} 148 c.exposedHeaders = exposedHeaders 149 c.maxAge = config.MaxAge 150 c.allowCredentials = config.AllowCredentials 151 c.allowPrivateNetwork = config.AllowPrivateNetwork 152 153 if c.allowPrivateNetwork { 154 c.preflightVary = [][]byte{[]byte("Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network")} 155 } else { 156 c.preflightVary = [][]byte{[]byte("Origin, Access-Control-Request-Method, Access-Control-Request-Headers")} 157 } 158 return 159 } 160 161 func (c *corsMiddleware) Handler(next transports.Handler) transports.Handler { 162 c.handler = next 163 return c 164 } 165 166 func (c *corsMiddleware) Close() (err error) { 167 return 168 } 169 170 func (c *corsMiddleware) Handle(w transports.ResponseWriter, r transports.Request) { 171 if bytes.Equal(r.Method(), methodOptions) && len(r.Header().Get(accessControlRequestMethodHeader)) > 0 { 172 c.handlePreflight(w, r) 173 w.SetStatus(http.StatusNoContent) 174 } else { 175 c.handleActualRequest(w, r) 176 c.handler.Handle(w, r) 177 } 178 } 179 180 func (c *corsMiddleware) handlePreflight(w transports.ResponseWriter, r transports.Request) { 181 headers := w.Header() 182 origin := r.Header().Get(originHeader) 183 184 if !bytes.Equal(r.Method(), methodOptions) { 185 return 186 } 187 188 if vary := headers.Get(varyHeader); len(vary) > 0 { 189 headers.Add(varyHeader, c.preflightVary[0]) 190 } else { 191 for _, preflightVary := range c.preflightVary { 192 headers.Add(varyHeader, preflightVary) 193 } 194 } 195 196 if len(origin) == 0 { 197 return 198 } 199 if !c.isOriginAllowed(origin) { 200 return 201 } 202 203 reqMethod := r.Header().Get(accessControlRequestMethodHeader) 204 if !c.isMethodAllowed(reqMethod) { 205 return 206 } 207 reqHeadersRaw := r.Header().Values(accessControlRequestHeadersHeader) 208 reqHeaders, reqHeadersEdited := parseHeaderList(reqHeadersRaw) 209 if !c.areHeadersAllowed(reqHeaders) { 210 return 211 } 212 if c.allowedOriginsAll { 213 headers.Set(accessControlAllowOriginHeader, all) 214 } else { 215 origins := w.Header().Values(originHeader) 216 for _, ori := range origins { 217 headers.Add(accessControlAllowOriginHeader, ori) 218 } 219 } 220 headers.Set(accessControlAllowMethodsHeader, bytes.ToUpper(reqMethod)) 221 if len(reqHeaders) > 0 { 222 if reqHeadersEdited || len(reqHeaders) != len(reqHeadersRaw) { 223 headers.Set(accessControlAllowHeadersHeader, bytes.Join(reqHeaders, joinBytes)) 224 } else { 225 for _, raw := range reqHeadersRaw { 226 headers.Add(accessControlAllowHeadersHeader, raw) 227 } 228 } 229 } 230 if c.allowCredentials { 231 headers.Set(accessControlAllowCredentialsHeader, trueBytes) 232 } 233 234 if c.allowPrivateNetwork && bytes.Equal(r.Header().Get(accessControlRequestPrivateNetworkHeader), trueBytes) { 235 headers.Set(accessControlAllowPrivateNetworkHeader, trueBytes) 236 } 237 238 if c.maxAge > 0 { 239 headers.Set(accessControlMaxAgeHeader, bytex.FromString(strconv.Itoa(c.maxAge))) 240 } 241 } 242 243 func (c *corsMiddleware) handleActualRequest(w transports.ResponseWriter, r transports.Request) { 244 headers := w.Header() 245 origin := r.Header().Get(originHeader) 246 247 if len(origin) == 0 { 248 return 249 } 250 if !c.isOriginAllowed(origin) { 251 return 252 } 253 254 if !c.isMethodAllowed(r.Method()) { 255 return 256 } 257 if c.allowedOriginsAll { 258 headers.Set(accessControlAllowOriginHeader, all) 259 } else { 260 origins := w.Header().Values(originHeader) 261 for _, ori := range origins { 262 headers.Add(accessControlAllowOriginHeader, ori) 263 } 264 } 265 if len(c.exposedHeaders) > 0 { 266 for _, exposedHeader := range c.exposedHeaders { 267 headers.Add(accessControlExposeHeadersHeader, exposedHeader) 268 } 269 } 270 if c.allowCredentials { 271 headers.Set(accessControlAllowCredentialsHeader, trueBytes) 272 } 273 } 274 275 func (c *corsMiddleware) isOriginAllowed(origin []byte) bool { 276 if c.allowedOriginsAll { 277 return true 278 } 279 origin = bytes.ToLower(origin) 280 for _, o := range c.allowedOrigins { 281 if bytes.Equal(o, origin) { 282 return true 283 } 284 } 285 for _, w := range c.allowedWOrigins { 286 if w.Match(origin) { 287 return true 288 } 289 } 290 return false 291 } 292 293 func (c *corsMiddleware) isMethodAllowed(method []byte) bool { 294 if len(c.allowedMethods) == 0 { 295 return false 296 } 297 ms := bytes.ToUpper(method) 298 if bytes.Equal(ms, methodOptions) { 299 return true 300 } 301 for _, m := range c.allowedMethods { 302 if bytes.Equal(ms, m) { 303 return true 304 } 305 } 306 return false 307 } 308 309 func (c *corsMiddleware) areHeadersAllowed(requestedHeaders [][]byte) bool { 310 if c.allowedHeadersAll || len(requestedHeaders) == 0 { 311 return true 312 } 313 for _, header := range requestedHeaders { 314 hs := bytex.FromString(http.CanonicalHeaderKey(bytex.ToString(header))) 315 found := false 316 for _, h := range c.allowedHeaders { 317 if bytes.Equal(hs, h) { 318 found = true 319 break 320 } 321 if bytes.Index(hs, transports.UserHeaderNamePrefix) == 0 { 322 found = true 323 break 324 } 325 } 326 if !found { 327 return false 328 } 329 } 330 return true 331 }