github.com/angenalZZZ/gofunc@v0.0.0-20210507121333-48ff1be3917b/http/fast/middleware/cors/cors.go (about)

     1  package cors
     2  
     3  import (
     4  	"github.com/angenalZZZ/gofunc/http/fast"
     5  	"strconv"
     6  	"strings"
     7  )
     8  
     9  // Config defines the config for cors middleware
    10  type Config struct {
    11  	// Filter defines a function to skip middleware.
    12  	// Optional. Default: nil
    13  	Filter func(*fast.Ctx) bool
    14  	// Optional. Default value []string{"*"}.
    15  	AllowOrigins []string
    16  	// Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH"
    17  	AllowMethods string
    18  	// Optional. Default value "".
    19  	AllowHeaders string
    20  	// Optional. Default value false.
    21  	AllowCredentials bool
    22  	// Optional. Default value "".
    23  	ExposeHeaders string
    24  	// Optional. Default value 0.
    25  	MaxAge int64
    26  	// X-XSS-Protection...
    27  	X bool
    28  }
    29  
    30  // New middleware.
    31  //  cfg := cors.Config{
    32  //    AllowHeaders: "authorization, origin, content-type, accept",
    33  //    MaxAge: 86400,
    34  //    X: true,
    35  //  }
    36  // app.Use(cors.New(cfg))
    37  func New(config ...Config) func(*fast.Ctx) {
    38  	// Init config
    39  	var cfg Config
    40  	if len(config) > 0 {
    41  		cfg = config[0]
    42  	}
    43  	if len(cfg.AllowOrigins) == 0 {
    44  		cfg.AllowOrigins = []string{"*"}
    45  	}
    46  	if cfg.AllowMethods == "" {
    47  		cfg.AllowMethods = "GET,POST,HEAD,PUT,DELETE,PATCH"
    48  	}
    49  	// Middleware function
    50  	return func(c *fast.Ctx) {
    51  		// Filter request to skip middleware
    52  		if cfg.Filter != nil && cfg.Filter(c) {
    53  			c.Next()
    54  			return
    55  		}
    56  		origin := c.GetHeader("Origin")
    57  		allowOrigin := ""
    58  		// Check allowed origins
    59  		for _, o := range cfg.AllowOrigins {
    60  			if o == "*" && cfg.AllowCredentials {
    61  				allowOrigin = origin
    62  				break
    63  			}
    64  			if o == "*" || o == origin {
    65  				allowOrigin = o
    66  				break
    67  			}
    68  			if matchSubDomain(origin, o) {
    69  				allowOrigin = origin
    70  				break
    71  			}
    72  		}
    73  		// Simple request
    74  		if c.Method() != "OPTIONS" {
    75  			c.Vary("Origin")
    76  			c.SetHeader("Access-Control-Allow-Origin", allowOrigin)
    77  
    78  			if cfg.AllowCredentials {
    79  				c.SetHeader("Access-Control-Allow-Credentials", "true")
    80  			}
    81  			if cfg.ExposeHeaders != "" {
    82  				c.SetHeader("Access-Control-Expose-Headers", cfg.ExposeHeaders)
    83  			}
    84  			if cfg.X {
    85  				c.XSSProtection()
    86  			}
    87  			c.Next()
    88  			return
    89  		}
    90  		// Preflight request
    91  		c.Vary("Origin")
    92  		c.Vary("Access-Control-Request-Method")
    93  		c.Vary("Access-Control-Request-Headers")
    94  		c.SetHeader("Access-Control-Allow-Origin", allowOrigin)
    95  		c.SetHeader("Access-Control-Allow-Methods", cfg.AllowMethods)
    96  
    97  		if cfg.AllowCredentials {
    98  			c.SetHeader("Access-Control-Allow-Credentials", "true")
    99  		}
   100  		if cfg.AllowHeaders != "" {
   101  			c.SetHeader("Access-Control-Allow-Headers", cfg.AllowHeaders)
   102  		} else {
   103  			h := c.GetHeader("Access-Control-Request-Headers")
   104  			if h != "" {
   105  				c.SetHeader("Access-Control-Allow-Headers", h)
   106  			}
   107  		}
   108  		if cfg.MaxAge > 0 {
   109  			c.SetHeader("Access-Control-Max-Age", strconv.FormatInt(cfg.MaxAge, 10))
   110  		}
   111  		if cfg.X {
   112  			c.XSSProtection()
   113  		}
   114  		c.SendStatus(204) // No Content
   115  	}
   116  }
   117  
   118  // find domain
   119  func matchScheme(domain, pattern string) bool {
   120  	i := strings.Index(domain, ":")
   121  	p := strings.Index(pattern, ":")
   122  	return i != -1 && p != -1 && domain[:i] == pattern[:p]
   123  }
   124  
   125  // compares authority with wildcard
   126  func matchSubDomain(domain, pattern string) bool {
   127  	if !matchScheme(domain, pattern) {
   128  		return false
   129  	}
   130  	i := strings.Index(domain, "://")
   131  	p := strings.Index(pattern, "://")
   132  	if i == -1 || p == -1 {
   133  		return false
   134  	}
   135  	domAuth := domain[i+3:]
   136  	// to avoid long loop by invalid long domain
   137  	if len(domAuth) > 253 {
   138  		return false
   139  	}
   140  	patAuth := pattern[p+3:]
   141  	domComp := strings.Split(domAuth, ".")
   142  	patComp := strings.Split(patAuth, ".")
   143  	for i := len(domComp)/2 - 1; i >= 0; i-- {
   144  		opp := len(domComp) - 1 - i
   145  		domComp[i], domComp[opp] = domComp[opp], domComp[i]
   146  	}
   147  	for i := len(patComp)/2 - 1; i >= 0; i-- {
   148  		opp := len(patComp) - 1 - i
   149  		patComp[i], patComp[opp] = patComp[opp], patComp[i]
   150  	}
   151  	for i, v := range domComp {
   152  		if len(patComp) <= i {
   153  			return false
   154  		}
   155  		p := patComp[i]
   156  		if p == "*" {
   157  			return true
   158  		}
   159  		if p != v {
   160  			return false
   161  		}
   162  	}
   163  	return false
   164  }