github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/cors/cors.go (about) 1 package cors 2 3 import ( 4 "net/http" 5 "strings" 6 7 "github.com/sohaha/zlsgo/znet" 8 "github.com/sohaha/zlsgo/zstring" 9 ) 10 11 type ( 12 // Config cors configuration 13 Config struct { 14 CustomHandler Handler 15 methods string 16 credentials string 17 headers string 18 exposeHeaders string 19 Domains []string 20 Methods []string 21 Credentials []string 22 Headers []string 23 ExposeHeaders []string 24 } 25 Handler func(conf *Config, c *znet.Context) 26 ) 27 28 const ( 29 DefaultHeaders = "Origin,No-Cache,X-Requested-With,If-Modified-Since,Pragma,Last-Modified,Cache-Control,Expires,Content-Type,Access-Control-Allow-Origin,Authorization" 30 ) 31 32 func Default() znet.HandlerFunc { 33 return New(&Config{}) 34 } 35 36 func NewAllowHeaders() (addAllowHeader func(header string), handler znet.HandlerFunc) { 37 conf := &Config{} 38 handler = New(conf) 39 40 return func(header string) { 41 conf.headers = conf.headers + ", " + header 42 }, handler 43 } 44 45 func New(conf *Config) znet.HandlerFunc { 46 if len(conf.Methods) == 0 { 47 conf.Methods = []string{ 48 http.MethodGet, 49 http.MethodHead, 50 http.MethodPost, 51 http.MethodPut, 52 http.MethodPatch, 53 http.MethodDelete, 54 http.MethodConnect, 55 http.MethodOptions, 56 http.MethodTrace, 57 } 58 } 59 conf.methods = strings.Join(conf.Methods, ", ") 60 if len(conf.Credentials) == 0 { 61 conf.Credentials = []string{"true"} 62 } 63 conf.credentials = strings.Join(conf.Credentials, ", ") 64 if len(conf.Headers) != 0 { 65 conf.headers = strings.Join(conf.Headers, ", ") 66 } else { 67 conf.headers = DefaultHeaders 68 } 69 70 if len(conf.ExposeHeaders) > 0 { 71 conf.exposeHeaders = strings.Join(conf.ExposeHeaders, ", ") 72 } 73 74 return func(c *znet.Context) { 75 if applyCors(c, conf) { 76 c.Next() 77 } 78 } 79 } 80 81 func applyCors(c *znet.Context, conf *Config) bool { 82 origin := c.GetHeader("Origin") 83 if len(origin) == 0 { 84 return true 85 } 86 87 domains := conf.Domains 88 if len(domains) > 0 { 89 adopt := false 90 for k := range domains { 91 if adopt = zstring.Match(origin, domains[k]); adopt { 92 break 93 } 94 } 95 if !adopt { 96 c.Abort(http.StatusForbidden) 97 return false 98 } 99 } 100 101 c.SetHeader("Access-Control-Allow-Methods", conf.methods) 102 c.SetHeader("Access-Control-Allow-Credentials", conf.credentials) 103 c.SetHeader("Access-Control-Allow-Headers", conf.headers) 104 if conf.exposeHeaders != "" { 105 c.SetHeader("Access-Control-Expose-Headers", conf.exposeHeaders) 106 } 107 c.SetHeader("Access-Control-Allow-Origin", origin) 108 if conf.CustomHandler != nil { 109 conf.CustomHandler(conf, c) 110 } 111 112 if c.Request.Method == "OPTIONS" { 113 c.Abort(http.StatusNoContent) 114 return false 115 } 116 117 return true 118 }