github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/web/middlewares/cors.go (about)

     1  package middlewares
     2  
     3  import (
     4  	"net/http"
     5  	"strconv"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/labstack/echo/v4"
    10  )
    11  
    12  // MaxAgeCORS is used to cache the CORS header for 12 hours
    13  const MaxAgeCORS = "43200"
    14  
    15  // CORSOptions contains different options to create a CORS middleware.
    16  type CORSOptions struct {
    17  	MaxAge         time.Duration
    18  	BlockList      []string
    19  	AllowedMethods []string
    20  }
    21  
    22  // CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
    23  // See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
    24  func CORS(opts CORSOptions) echo.MiddlewareFunc {
    25  	var maxAge string
    26  	if opts.MaxAge != 0 {
    27  		maxAge = strconv.Itoa(int(opts.MaxAge.Seconds()))
    28  	} else {
    29  		maxAge = MaxAgeCORS
    30  	}
    31  
    32  	var allowedMethods []string
    33  	if opts.AllowedMethods == nil {
    34  		allowedMethods = []string{
    35  			echo.GET,
    36  			echo.HEAD,
    37  			echo.PUT,
    38  			echo.PATCH,
    39  			echo.POST,
    40  			echo.DELETE,
    41  		}
    42  	}
    43  
    44  	allowMethods := strings.Join(allowedMethods, ",")
    45  
    46  	return func(next echo.HandlerFunc) echo.HandlerFunc {
    47  		return func(c echo.Context) error {
    48  			req := c.Request()
    49  			res := c.Response()
    50  
    51  			origin := req.Header.Get(echo.HeaderOrigin)
    52  			if origin == "" {
    53  				return next(c)
    54  			}
    55  
    56  			path := c.Path()
    57  			for _, route := range opts.BlockList {
    58  				if strings.HasPrefix(path, route) {
    59  					return next(c)
    60  				}
    61  			}
    62  
    63  			// Simple request
    64  			if req.Method != echo.OPTIONS {
    65  				res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
    66  				res.Header().Set(echo.HeaderAccessControlAllowOrigin, origin)
    67  				res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
    68  				return next(c)
    69  			}
    70  
    71  			// Preflight request
    72  			res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
    73  			res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
    74  			res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
    75  			res.Header().Set(echo.HeaderAccessControlAllowOrigin, origin)
    76  			res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
    77  			res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
    78  
    79  			h := req.Header.Get(echo.HeaderAccessControlRequestHeaders)
    80  			if h != "" {
    81  				res.Header().Set(echo.HeaderAccessControlAllowHeaders, h)
    82  			}
    83  
    84  			res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge)
    85  
    86  			return c.NoContent(http.StatusNoContent)
    87  		}
    88  	}
    89  }