github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/cors/cors.go (about)

     1  package cors
     2  
     3  import (
     4  	"net/http"
     5  	"strings"
     6  
     7  	"github.com/sohaha/zlsgo/znet"
     8  	"github.com/sohaha/zlsgo/zstring"
     9  )
    10  
    11  type (
    12  	// Config cors configuration
    13  	Config struct {
    14  		CustomHandler Handler
    15  		methods       string
    16  		credentials   string
    17  		headers       string
    18  		exposeHeaders string
    19  		Domains       []string
    20  		Methods       []string
    21  		Credentials   []string
    22  		Headers       []string
    23  		ExposeHeaders []string
    24  	}
    25  	Handler func(conf *Config, c *znet.Context)
    26  )
    27  
    28  const (
    29  	DefaultHeaders = "Origin,No-Cache,X-Requested-With,If-Modified-Since,Pragma,Last-Modified,Cache-Control,Expires,Content-Type,Access-Control-Allow-Origin,Authorization"
    30  )
    31  
    32  func Default() znet.HandlerFunc {
    33  	return New(&Config{})
    34  }
    35  
    36  func NewAllowHeaders() (addAllowHeader func(header string), handler znet.HandlerFunc) {
    37  	conf := &Config{}
    38  	handler = New(conf)
    39  
    40  	return func(header string) {
    41  		conf.headers = conf.headers + ", " + header
    42  	}, handler
    43  }
    44  
    45  func New(conf *Config) znet.HandlerFunc {
    46  	if len(conf.Methods) == 0 {
    47  		conf.Methods = []string{
    48  			http.MethodGet,
    49  			http.MethodHead,
    50  			http.MethodPost,
    51  			http.MethodPut,
    52  			http.MethodPatch,
    53  			http.MethodDelete,
    54  			http.MethodConnect,
    55  			http.MethodOptions,
    56  			http.MethodTrace,
    57  		}
    58  	}
    59  	conf.methods = strings.Join(conf.Methods, ", ")
    60  	if len(conf.Credentials) == 0 {
    61  		conf.Credentials = []string{"true"}
    62  	}
    63  	conf.credentials = strings.Join(conf.Credentials, ", ")
    64  	if len(conf.Headers) != 0 {
    65  		conf.headers = strings.Join(conf.Headers, ", ")
    66  	} else {
    67  		conf.headers = DefaultHeaders
    68  	}
    69  
    70  	if len(conf.ExposeHeaders) > 0 {
    71  		conf.exposeHeaders = strings.Join(conf.ExposeHeaders, ", ")
    72  	}
    73  
    74  	return func(c *znet.Context) {
    75  		if applyCors(c, conf) {
    76  			c.Next()
    77  		}
    78  	}
    79  }
    80  
    81  func applyCors(c *znet.Context, conf *Config) bool {
    82  	origin := c.GetHeader("Origin")
    83  	if len(origin) == 0 {
    84  		return true
    85  	}
    86  
    87  	domains := conf.Domains
    88  	if len(domains) > 0 {
    89  		adopt := false
    90  		for k := range domains {
    91  			if adopt = zstring.Match(origin, domains[k]); adopt {
    92  				break
    93  			}
    94  		}
    95  		if !adopt {
    96  			c.Abort(http.StatusForbidden)
    97  			return false
    98  		}
    99  	}
   100  
   101  	c.SetHeader("Access-Control-Allow-Methods", conf.methods)
   102  	c.SetHeader("Access-Control-Allow-Credentials", conf.credentials)
   103  	c.SetHeader("Access-Control-Allow-Headers", conf.headers)
   104  	if conf.exposeHeaders != "" {
   105  		c.SetHeader("Access-Control-Expose-Headers", conf.exposeHeaders)
   106  	}
   107  	c.SetHeader("Access-Control-Allow-Origin", origin)
   108  	if conf.CustomHandler != nil {
   109  		conf.CustomHandler(conf, c)
   110  	}
   111  
   112  	if c.Request.Method == "OPTIONS" {
   113  		c.Abort(http.StatusNoContent)
   114  		return false
   115  	}
   116  
   117  	return true
   118  }