github.com/adharshmk96/stk@v1.2.3/pkg/middleware/cors.go (about)

     1  package middleware
     2  
     3  import (
     4  	"net/http"
     5  	"strings"
     6  
     7  	"github.com/adharshmk96/stk/gsk"
     8  )
     9  
    10  var (
    11  	defaultAllowMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "PATCH"}
    12  	// "POST, GET, OPTIONS, PUT, DELETE, PATCH"
    13  	defaultAllowHeaders = []string{"Accept", "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization"}
    14  	// "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization"
    15  	defaultAllowCredentials = "true"
    16  )
    17  
    18  const (
    19  	defaultCORSOrigin = "same-origin"
    20  
    21  	AccessControlAllowOrigin      = "Access-Control-Allow-Origin"
    22  	AccessControlAllowMethods     = "Access-Control-Allow-Methods"
    23  	AccessControlAllowHeaders     = "Access-Control-Allow-Headers"
    24  	AccessControlAllowCredentials = "Access-Control-Allow-Credentials"
    25  )
    26  
    27  type CORSConfig struct {
    28  	AllowedOrigins []string
    29  	AllowedMethods []string
    30  	AllowedHeaders []string
    31  	AllowAll       bool
    32  }
    33  
    34  func CORS(config ...CORSConfig) gsk.Middleware {
    35  	var corsConfig CORSConfig
    36  	if len(config) > 0 {
    37  		corsConfig = config[0]
    38  	} else {
    39  		corsConfig = CORSConfig{
    40  			AllowedOrigins: []string{defaultCORSOrigin},
    41  			AllowedMethods: defaultAllowMethods,
    42  			AllowedHeaders: defaultAllowHeaders,
    43  			AllowAll:       false,
    44  		}
    45  	}
    46  
    47  	allowedMethods := strings.Join(defaultAllowMethods, ", ")
    48  	if len(corsConfig.AllowedMethods) != 0 {
    49  		allowedMethods = strings.Join(corsConfig.AllowedMethods, ", ")
    50  	}
    51  	allowedHeaders := strings.Join(defaultAllowHeaders, ", ")
    52  	if len(corsConfig.AllowedHeaders) != 0 {
    53  		allowedHeaders = strings.Join(corsConfig.AllowedHeaders, ", ")
    54  	}
    55  
    56  	return func(next gsk.HandlerFunc) gsk.HandlerFunc {
    57  		return func(c *gsk.Context) {
    58  			allowedOrigins := getAllowedOrigins(corsConfig.AllowedOrigins)
    59  
    60  			origin := c.Origin()
    61  			// Check if the origin is in the allowedOrigins list
    62  			isAllowed := false
    63  			for _, allowedOrigin := range allowedOrigins {
    64  				if allowedOrigin == "same-origin" || allowedOrigin == "*" || origin == allowedOrigin {
    65  					isAllowed = true
    66  					break
    67  				}
    68  			}
    69  
    70  			if !corsConfig.AllowAll && !isAllowed {
    71  				c.Status(http.StatusForbidden)
    72  				c.SetHeader("Content-Type", "text/plain")
    73  				c.RawResponse([]byte("Forbidden"))
    74  				return
    75  			}
    76  
    77  			// Set CORS headers
    78  			headers := c.Writer.Header()
    79  
    80  			headers.Set(AccessControlAllowOrigin, origin)
    81  			headers.Set(AccessControlAllowMethods, allowedMethods)
    82  			headers.Set(AccessControlAllowHeaders, allowedHeaders)
    83  			headers.Set(AccessControlAllowCredentials, defaultAllowCredentials)
    84  
    85  			next(c)
    86  
    87  		}
    88  	}
    89  }
    90  
    91  func getAllowedOrigins(origins []string) []string {
    92  	var allowedOrigins []string
    93  
    94  	if len(origins) == 0 {
    95  		allowedOrigins = []string{defaultCORSOrigin}
    96  	} else {
    97  		allowedOrigins = origins
    98  	}
    99  	return allowedOrigins
   100  }