github.com/angenalZZZ/gofunc@v0.0.0-20210507121333-48ff1be3917b/http/fast/middleware/cors/cors.go (about) 1 package cors 2 3 import ( 4 "github.com/angenalZZZ/gofunc/http/fast" 5 "strconv" 6 "strings" 7 ) 8 9 // Config defines the config for cors middleware 10 type Config struct { 11 // Filter defines a function to skip middleware. 12 // Optional. Default: nil 13 Filter func(*fast.Ctx) bool 14 // Optional. Default value []string{"*"}. 15 AllowOrigins []string 16 // Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH" 17 AllowMethods string 18 // Optional. Default value "". 19 AllowHeaders string 20 // Optional. Default value false. 21 AllowCredentials bool 22 // Optional. Default value "". 23 ExposeHeaders string 24 // Optional. Default value 0. 25 MaxAge int64 26 // X-XSS-Protection... 27 X bool 28 } 29 30 // New middleware. 31 // cfg := cors.Config{ 32 // AllowHeaders: "authorization, origin, content-type, accept", 33 // MaxAge: 86400, 34 // X: true, 35 // } 36 // app.Use(cors.New(cfg)) 37 func New(config ...Config) func(*fast.Ctx) { 38 // Init config 39 var cfg Config 40 if len(config) > 0 { 41 cfg = config[0] 42 } 43 if len(cfg.AllowOrigins) == 0 { 44 cfg.AllowOrigins = []string{"*"} 45 } 46 if cfg.AllowMethods == "" { 47 cfg.AllowMethods = "GET,POST,HEAD,PUT,DELETE,PATCH" 48 } 49 // Middleware function 50 return func(c *fast.Ctx) { 51 // Filter request to skip middleware 52 if cfg.Filter != nil && cfg.Filter(c) { 53 c.Next() 54 return 55 } 56 origin := c.GetHeader("Origin") 57 allowOrigin := "" 58 // Check allowed origins 59 for _, o := range cfg.AllowOrigins { 60 if o == "*" && cfg.AllowCredentials { 61 allowOrigin = origin 62 break 63 } 64 if o == "*" || o == origin { 65 allowOrigin = o 66 break 67 } 68 if matchSubDomain(origin, o) { 69 allowOrigin = origin 70 break 71 } 72 } 73 // Simple request 74 if c.Method() != "OPTIONS" { 75 c.Vary("Origin") 76 c.SetHeader("Access-Control-Allow-Origin", allowOrigin) 77 78 if cfg.AllowCredentials { 79 c.SetHeader("Access-Control-Allow-Credentials", "true") 80 } 81 if cfg.ExposeHeaders != "" { 82 c.SetHeader("Access-Control-Expose-Headers", cfg.ExposeHeaders) 83 } 84 if cfg.X { 85 c.XSSProtection() 86 } 87 c.Next() 88 return 89 } 90 // Preflight request 91 c.Vary("Origin") 92 c.Vary("Access-Control-Request-Method") 93 c.Vary("Access-Control-Request-Headers") 94 c.SetHeader("Access-Control-Allow-Origin", allowOrigin) 95 c.SetHeader("Access-Control-Allow-Methods", cfg.AllowMethods) 96 97 if cfg.AllowCredentials { 98 c.SetHeader("Access-Control-Allow-Credentials", "true") 99 } 100 if cfg.AllowHeaders != "" { 101 c.SetHeader("Access-Control-Allow-Headers", cfg.AllowHeaders) 102 } else { 103 h := c.GetHeader("Access-Control-Request-Headers") 104 if h != "" { 105 c.SetHeader("Access-Control-Allow-Headers", h) 106 } 107 } 108 if cfg.MaxAge > 0 { 109 c.SetHeader("Access-Control-Max-Age", strconv.FormatInt(cfg.MaxAge, 10)) 110 } 111 if cfg.X { 112 c.XSSProtection() 113 } 114 c.SendStatus(204) // No Content 115 } 116 } 117 118 // find domain 119 func matchScheme(domain, pattern string) bool { 120 i := strings.Index(domain, ":") 121 p := strings.Index(pattern, ":") 122 return i != -1 && p != -1 && domain[:i] == pattern[:p] 123 } 124 125 // compares authority with wildcard 126 func matchSubDomain(domain, pattern string) bool { 127 if !matchScheme(domain, pattern) { 128 return false 129 } 130 i := strings.Index(domain, "://") 131 p := strings.Index(pattern, "://") 132 if i == -1 || p == -1 { 133 return false 134 } 135 domAuth := domain[i+3:] 136 // to avoid long loop by invalid long domain 137 if len(domAuth) > 253 { 138 return false 139 } 140 patAuth := pattern[p+3:] 141 domComp := strings.Split(domAuth, ".") 142 patComp := strings.Split(patAuth, ".") 143 for i := len(domComp)/2 - 1; i >= 0; i-- { 144 opp := len(domComp) - 1 - i 145 domComp[i], domComp[opp] = domComp[opp], domComp[i] 146 } 147 for i := len(patComp)/2 - 1; i >= 0; i-- { 148 opp := len(patComp) - 1 - i 149 patComp[i], patComp[opp] = patComp[opp], patComp[i] 150 } 151 for i, v := range domComp { 152 if len(patComp) <= i { 153 return false 154 } 155 p := patComp[i] 156 if p == "*" { 157 return true 158 } 159 if p != v { 160 return false 161 } 162 } 163 return false 164 }