github.com/avenga/couper@v1.12.2/handler/middleware/cors.go (about)

     1  package middleware
     2  
     3  import (
     4  	"math"
     5  	"net/http"
     6  	"strconv"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/avenga/couper/config"
    11  	"github.com/avenga/couper/errors"
    12  	"github.com/avenga/couper/internal/seetie"
    13  )
    14  
    15  var _ http.Handler = &CORS{}
    16  
    17  type CORS struct {
    18  	options     *CORSOptions
    19  	nextHandler http.Handler
    20  }
    21  
    22  type CORSOptions struct {
    23  	AllowedOrigins   []string
    24  	AllowCredentials bool
    25  	MaxAge           string
    26  	methodAllowed    methodAllowedFunc
    27  }
    28  
    29  func NewCORSOptions(cors *config.CORS, methodAllowed methodAllowedFunc) (*CORSOptions, error) {
    30  	if cors == nil {
    31  		return nil, nil
    32  	}
    33  
    34  	var corsMaxAge string
    35  	if cors.MaxAge != "" {
    36  		dur, err := time.ParseDuration(cors.MaxAge)
    37  		if err != nil {
    38  			return nil, errors.Configuration.With(err).Message("cors max_age")
    39  		}
    40  		corsMaxAge = strconv.Itoa(int(math.Floor(dur.Seconds())))
    41  	}
    42  
    43  	allowedOrigins := seetie.ValueToStringSlice(cors.AllowedOrigins)
    44  
    45  	for i, a := range allowedOrigins {
    46  		allowedOrigins[i] = strings.ToLower(a)
    47  	}
    48  
    49  	return &CORSOptions{
    50  		AllowedOrigins:   allowedOrigins,
    51  		AllowCredentials: cors.AllowCredentials,
    52  		MaxAge:           corsMaxAge,
    53  		methodAllowed:    methodAllowed,
    54  	}, nil
    55  }
    56  
    57  func (c *CORSOptions) AllowsOrigin(origin string) bool {
    58  	if c == nil {
    59  		return false
    60  	}
    61  
    62  	for _, a := range c.AllowedOrigins {
    63  		if a == strings.ToLower(origin) || a == "*" {
    64  			return true
    65  		}
    66  	}
    67  
    68  	return false
    69  }
    70  
    71  func NewCORSHandler(opts *CORSOptions, nextHandler http.Handler) http.Handler {
    72  	if opts == nil {
    73  		return nextHandler
    74  	}
    75  	return &CORS{
    76  		options:     opts,
    77  		nextHandler: nextHandler,
    78  	}
    79  }
    80  
    81  func (c *CORS) ServeNextHTTP(rw http.ResponseWriter, nextHandler http.Handler, req *http.Request) {
    82  	c.setCorsRespHeaders(rw.Header(), req)
    83  
    84  	if c.isCorsPreflightRequest(req) {
    85  		rw.WriteHeader(http.StatusNoContent)
    86  		return
    87  	}
    88  
    89  	nextHandler.ServeHTTP(rw, req)
    90  }
    91  
    92  func (c *CORS) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
    93  	c.ServeNextHTTP(rw, c.nextHandler, req)
    94  }
    95  
    96  func (c *CORS) isCorsPreflightRequest(req *http.Request) bool {
    97  	return req.Method == http.MethodOptions &&
    98  		(req.Header.Get("Access-Control-Request-Method") != "" ||
    99  			req.Header.Get("Access-Control-Request-Headers") != "")
   100  }
   101  
   102  func (c *CORS) setCorsRespHeaders(headers http.Header, req *http.Request) {
   103  	// see https://fetch.spec.whatwg.org/#http-responses
   104  	allowSpecificOrigin := false
   105  	if c.options.AllowsOrigin("*") && !c.options.AllowCredentials {
   106  		headers.Set("Access-Control-Allow-Origin", "*")
   107  	} else {
   108  		headers.Add("Vary", "Origin")
   109  		allowSpecificOrigin = true
   110  	}
   111  
   112  	if !c.isCorsRequest(req) {
   113  		return
   114  	}
   115  
   116  	requestOrigin := req.Header.Get("Origin")
   117  	if !c.options.AllowsOrigin(requestOrigin) {
   118  		return
   119  	}
   120  
   121  	if allowSpecificOrigin {
   122  		headers.Set("Access-Control-Allow-Origin", requestOrigin)
   123  	}
   124  
   125  	if c.options.AllowCredentials {
   126  		headers.Set("Access-Control-Allow-Credentials", "true")
   127  	}
   128  
   129  	if c.isCorsPreflightRequest(req) {
   130  		// Reflect request header value
   131  		acrm := req.Header.Get("Access-Control-Request-Method")
   132  		if acrm != "" {
   133  			if c.options.methodAllowed == nil || c.options.methodAllowed(acrm) {
   134  				headers.Set("Access-Control-Allow-Methods", acrm)
   135  			}
   136  			headers.Add("Vary", "Access-Control-Request-Method")
   137  		}
   138  		// Reflect request header value
   139  		acrh := req.Header.Get("Access-Control-Request-Headers")
   140  		if acrh != "" {
   141  			headers.Set("Access-Control-Allow-Headers", acrh)
   142  			headers.Add("Vary", "Access-Control-Request-Headers")
   143  		}
   144  		if c.options.MaxAge != "" {
   145  			headers.Set("Access-Control-Max-Age", c.options.MaxAge)
   146  		}
   147  	}
   148  }
   149  
   150  func (c *CORS) isCorsRequest(req *http.Request) bool {
   151  	return req.Header.Get("Origin") != ""
   152  }
   153  
   154  func (c *CORS) Child() http.Handler {
   155  	return c.nextHandler
   156  }