github.com/avenga/couper@v1.12.2/handler/middleware/cors.go (about) 1 package middleware 2 3 import ( 4 "math" 5 "net/http" 6 "strconv" 7 "strings" 8 "time" 9 10 "github.com/avenga/couper/config" 11 "github.com/avenga/couper/errors" 12 "github.com/avenga/couper/internal/seetie" 13 ) 14 15 var _ http.Handler = &CORS{} 16 17 type CORS struct { 18 options *CORSOptions 19 nextHandler http.Handler 20 } 21 22 type CORSOptions struct { 23 AllowedOrigins []string 24 AllowCredentials bool 25 MaxAge string 26 methodAllowed methodAllowedFunc 27 } 28 29 func NewCORSOptions(cors *config.CORS, methodAllowed methodAllowedFunc) (*CORSOptions, error) { 30 if cors == nil { 31 return nil, nil 32 } 33 34 var corsMaxAge string 35 if cors.MaxAge != "" { 36 dur, err := time.ParseDuration(cors.MaxAge) 37 if err != nil { 38 return nil, errors.Configuration.With(err).Message("cors max_age") 39 } 40 corsMaxAge = strconv.Itoa(int(math.Floor(dur.Seconds()))) 41 } 42 43 allowedOrigins := seetie.ValueToStringSlice(cors.AllowedOrigins) 44 45 for i, a := range allowedOrigins { 46 allowedOrigins[i] = strings.ToLower(a) 47 } 48 49 return &CORSOptions{ 50 AllowedOrigins: allowedOrigins, 51 AllowCredentials: cors.AllowCredentials, 52 MaxAge: corsMaxAge, 53 methodAllowed: methodAllowed, 54 }, nil 55 } 56 57 func (c *CORSOptions) AllowsOrigin(origin string) bool { 58 if c == nil { 59 return false 60 } 61 62 for _, a := range c.AllowedOrigins { 63 if a == strings.ToLower(origin) || a == "*" { 64 return true 65 } 66 } 67 68 return false 69 } 70 71 func NewCORSHandler(opts *CORSOptions, nextHandler http.Handler) http.Handler { 72 if opts == nil { 73 return nextHandler 74 } 75 return &CORS{ 76 options: opts, 77 nextHandler: nextHandler, 78 } 79 } 80 81 func (c *CORS) ServeNextHTTP(rw http.ResponseWriter, nextHandler http.Handler, req *http.Request) { 82 c.setCorsRespHeaders(rw.Header(), req) 83 84 if c.isCorsPreflightRequest(req) { 85 rw.WriteHeader(http.StatusNoContent) 86 return 87 } 88 89 nextHandler.ServeHTTP(rw, req) 90 } 91 92 func (c *CORS) ServeHTTP(rw http.ResponseWriter, req *http.Request) { 93 c.ServeNextHTTP(rw, c.nextHandler, req) 94 } 95 96 func (c *CORS) isCorsPreflightRequest(req *http.Request) bool { 97 return req.Method == http.MethodOptions && 98 (req.Header.Get("Access-Control-Request-Method") != "" || 99 req.Header.Get("Access-Control-Request-Headers") != "") 100 } 101 102 func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) { 103 // see https://fetch.spec.whatwg.org/#http-responses 104 allowSpecificOrigin := false 105 if c.options.AllowsOrigin("*") && !c.options.AllowCredentials { 106 headers.Set("Access-Control-Allow-Origin", "*") 107 } else { 108 headers.Add("Vary", "Origin") 109 allowSpecificOrigin = true 110 } 111 112 if !c.isCorsRequest(req) { 113 return 114 } 115 116 requestOrigin := req.Header.Get("Origin") 117 if !c.options.AllowsOrigin(requestOrigin) { 118 return 119 } 120 121 if allowSpecificOrigin { 122 headers.Set("Access-Control-Allow-Origin", requestOrigin) 123 } 124 125 if c.options.AllowCredentials { 126 headers.Set("Access-Control-Allow-Credentials", "true") 127 } 128 129 if c.isCorsPreflightRequest(req) { 130 // Reflect request header value 131 acrm := req.Header.Get("Access-Control-Request-Method") 132 if acrm != "" { 133 if c.options.methodAllowed == nil || c.options.methodAllowed(acrm) { 134 headers.Set("Access-Control-Allow-Methods", acrm) 135 } 136 headers.Add("Vary", "Access-Control-Request-Method") 137 } 138 // Reflect request header value 139 acrh := req.Header.Get("Access-Control-Request-Headers") 140 if acrh != "" { 141 headers.Set("Access-Control-Allow-Headers", acrh) 142 headers.Add("Vary", "Access-Control-Request-Headers") 143 } 144 if c.options.MaxAge != "" { 145 headers.Set("Access-Control-Max-Age", c.options.MaxAge) 146 } 147 } 148 } 149 150 func (c *CORS) isCorsRequest(req *http.Request) bool { 151 return req.Header.Get("Origin") != "" 152 } 153 154 func (c *CORS) Child() http.Handler { 155 return c.nextHandler 156 }