gobot.io/x/gobot@v1.16.0/api/cors.go (about)

     1  package api
     2  
     3  import (
     4  	"net/http"
     5  	"regexp"
     6  	"strings"
     7  )
     8  
// CORS represents CORS configuration used to build the response headers
// written by the handler that AllowRequestsFrom returns.
type CORS struct {
	// AllowOrigins lists the accepted origins. An entry may contain the
	// wildcards "*" (zero or more characters) and "?" (exactly one
	// character); generatePatterns turns the entries into regular expressions.
	AllowOrigins []string

	// AllowHeaders is joined with "," by AllowedHeaders.
	AllowHeaders []string

	// AllowMethods is joined with "," by AllowedMethods.
	AllowMethods []string

	// ContentType is the Content-Type value written for allowed origins.
	ContentType string

	// allowOriginPatterns holds the anchored regular-expression sources
	// derived from AllowOrigins by generatePatterns; isOriginAllowed matches
	// request origins against them.
	allowOriginPatterns []string
}
    17  
    18  // AllowRequestsFrom returns handler to verify that requests come from allowedOrigins
    19  func AllowRequestsFrom(allowedOrigins ...string) http.HandlerFunc {
    20  	c := &CORS{
    21  		AllowOrigins: allowedOrigins,
    22  		AllowMethods: []string{"GET", "POST"},
    23  		AllowHeaders: []string{"Origin", "Content-Type"},
    24  		ContentType:  "application/json; charset=utf-8",
    25  	}
    26  
    27  	c.generatePatterns()
    28  
    29  	return func(w http.ResponseWriter, req *http.Request) {
    30  		origin := req.Header.Get("Origin")
    31  		if c.isOriginAllowed(origin) {
    32  			w.Header().Set("Access-Control-Allow-Origin", origin)
    33  			w.Header().Set("Access-Control-Allow-Headers", c.AllowedHeaders())
    34  			w.Header().Set("Access-Control-Allow-Methods", c.AllowedMethods())
    35  			w.Header().Set("Content-Type", c.ContentType)
    36  		}
    37  	}
    38  }
    39  
    40  // isOriginAllowed returns true if origin matches an allowed origin pattern.
    41  func (c *CORS) isOriginAllowed(origin string) (allowed bool) {
    42  	for _, allowedOriginPattern := range c.allowOriginPatterns {
    43  		allowed, _ = regexp.MatchString(allowedOriginPattern, origin)
    44  		if allowed {
    45  			return
    46  		}
    47  	}
    48  	return
    49  }
    50  
    51  // generatePatterns generates regex expression for AllowOrigins
    52  func (c *CORS) generatePatterns() {
    53  	if c.AllowOrigins != nil {
    54  		for _, origin := range c.AllowOrigins {
    55  			pattern := regexp.QuoteMeta(origin)
    56  			pattern = strings.Replace(pattern, "\\*", ".*", -1)
    57  			pattern = strings.Replace(pattern, "\\?", ".", -1)
    58  			c.allowOriginPatterns = append(c.allowOriginPatterns, "^"+pattern+"$")
    59  		}
    60  	}
    61  }
    62  
    63  // AllowedHeaders returns allowed headers in a string
    64  func (c *CORS) AllowedHeaders() string {
    65  	return strings.Join(c.AllowHeaders, ",")
    66  }
    67  
    68  // AllowedMethods returns allowed http methods in a string
    69  func (c *CORS) AllowedMethods() string {
    70  	return strings.Join(c.AllowMethods, ",")
    71  }