// Source: gobot.io/x/gobot@v1.16.0/api/cors.go

package api

import (
	"net/http"
	"regexp"
	"strings"
)

// CORS holds the Cross-Origin Resource Sharing settings applied by the
// handler that AllowRequestsFrom returns.
type CORS struct {
	// AllowOrigins lists the accepted request origins. In each entry "*"
	// matches any run of characters and "?" matches any single character
	// (other than a newline); everything else must match literally.
	AllowOrigins []string

	// AllowHeaders is joined with "," into Access-Control-Allow-Headers.
	AllowHeaders []string

	// AllowMethods is joined with "," into Access-Control-Allow-Methods.
	AllowMethods []string

	// ContentType is written as the Content-Type header of responses to
	// requests coming from an allowed origin.
	ContentType string

	// allowOriginPatterns holds the anchored regular expressions built from
	// AllowOrigins by generatePatterns.
	allowOriginPatterns []string

	// allowOriginRegexps holds allowOriginPatterns compiled once by
	// generatePatterns, so matching a request does not recompile them.
	allowOriginRegexps []*regexp.Regexp
}

// AllowRequestsFrom returns a handler that, when the request's Origin header
// matches one of allowedOrigins, sets Access-Control-Allow-Origin (echoing
// the request origin), Access-Control-Allow-Headers ("Origin,Content-Type"),
// Access-Control-Allow-Methods ("GET,POST") and a JSON Content-Type.
// Entries in allowedOrigins may use "*" and "?" wildcards; "*" alone accepts
// every origin.
func AllowRequestsFrom(allowedOrigins ...string) http.HandlerFunc {
	c := &CORS{
		AllowOrigins: allowedOrigins,
		AllowMethods: []string{"GET", "POST"},
		AllowHeaders: []string{"Origin", "Content-Type"},
		ContentType:  "application/json; charset=utf-8",
	}

	c.generatePatterns()

	// c is never modified after this point, so build the joined header
	// values once instead of on every request.
	allowHeaders := c.AllowedHeaders()
	allowMethods := c.AllowedMethods()

	return func(w http.ResponseWriter, req *http.Request) {
		// Whether the CORS headers are present depends on the request's
		// Origin, so caches must not reuse this response for other origins.
		w.Header().Add("Vary", "Origin")

		origin := req.Header.Get("Origin")
		if c.isOriginAllowed(origin) {
			w.Header().Set("Access-Control-Allow-Origin", origin)
			w.Header().Set("Access-Control-Allow-Headers", allowHeaders)
			w.Header().Set("Access-Control-Allow-Methods", allowMethods)
			w.Header().Set("Content-Type", c.ContentType)
		}
	}
}

// isOriginAllowed reports whether origin fully matches one of the patterns
// compiled by generatePatterns. It reports false until generatePatterns has
// been called.
func (c *CORS) isOriginAllowed(origin string) bool {
	for _, re := range c.allowOriginRegexps {
		if re.MatchString(origin) {
			return true
		}
	}
	return false
}

// generatePatterns turns every AllowOrigins entry into an anchored regular
// expression ("*" becomes ".*", "?" becomes ".", all other characters are
// matched literally), stores it in allowOriginPatterns and compiles it into
// allowOriginRegexps. Calling it again rebuilds both from the current
// AllowOrigins instead of appending duplicates.
func (c *CORS) generatePatterns() {
	c.allowOriginPatterns = nil
	c.allowOriginRegexps = nil
	for _, origin := range c.AllowOrigins {
		// QuoteMeta escapes every metacharacter, so the only regex syntax
		// left after the two replacements is the wildcard translation.
		pattern := regexp.QuoteMeta(origin)
		pattern = strings.Replace(pattern, "\\*", ".*", -1)
		pattern = strings.Replace(pattern, "\\?", ".", -1)
		pattern = "^" + pattern + "$"
		c.allowOriginPatterns = append(c.allowOriginPatterns, pattern)

		re, err := regexp.Compile(pattern)
		if err != nil {
			// Only possible for an origin that is not valid UTF-8; such an
			// entry can never match, so it is skipped rather than panicking.
			continue
		}
		c.allowOriginRegexps = append(c.allowOriginRegexps, re)
	}
}

// AllowedHeaders returns AllowHeaders joined with "," for use as the
// Access-Control-Allow-Headers value.
func (c *CORS) AllowedHeaders() string {
	return strings.Join(c.AllowHeaders, ",")
}

// AllowedMethods returns AllowMethods joined with "," for use as the
// Access-Control-Allow-Methods value.
func (c *CORS) AllowedMethods() string {
	return strings.Join(c.AllowMethods, ",")
}