code.vegaprotocol.io/vega@v0.79.0/libs/http/cors.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package http
    17  
    18  import (
    19  	"net/http"
    20  	"strings"
    21  
    22  	"github.com/rs/cors"
    23  )
    24  
    25  // CORSConfig represents the configuration for CORS.
    26  type CORSConfig struct {
    27  	AllowedOrigins []string `description:"Allowed origins for CORS"                 long:"allowed-origins"`
    28  	MaxAge         int      `description:"Max age (in seconds) for preflight cache" long:"max-age"`
    29  }
    30  
    31  func CORSOptions(config CORSConfig) cors.Options {
    32  	return cors.Options{
    33  		AllowOriginFunc: AllowedOrigin(config.AllowedOrigins),
    34  		AllowedMethods: []string{
    35  			http.MethodHead,
    36  			http.MethodGet,
    37  			http.MethodPost,
    38  			http.MethodPut,
    39  			http.MethodPatch,
    40  			http.MethodDelete,
    41  		},
    42  		AllowedHeaders:   []string{"*"},
    43  		ExposedHeaders:   []string{"*"},
    44  		MaxAge:           config.MaxAge,
    45  		AllowCredentials: false,
    46  	}
    47  }
    48  
    49  func AllowedOrigin(allowedOrigins []string) func(origin string) bool {
    50  	trimScheme := func(origin string) string {
    51  		return strings.TrimPrefix(strings.TrimPrefix(origin, "https://"), "http://")
    52  	}
    53  	return func(origin string) bool {
    54  		if len(allowedOrigins) == 0 || allowedOrigins[0] == "*" {
    55  			return true
    56  		}
    57  		for _, allowedOrigin := range allowedOrigins {
    58  			if allowedOrigin == origin || trimScheme(allowedOrigin) == trimScheme(origin) {
    59  				return true
    60  			}
    61  		}
    62  		return false
    63  	}
    64  }