github.com/adharshmk96/stk@v1.2.3/pkg/middleware/cors.go (about) 1 package middleware 2 3 import ( 4 "net/http" 5 "strings" 6 7 "github.com/adharshmk96/stk/gsk" 8 ) 9 10 var ( 11 defaultAllowMethods = []string{"POST", "GET", "OPTIONS", "PUT", "DELETE", "PATCH"} 12 // "POST, GET, OPTIONS, PUT, DELETE, PATCH" 13 defaultAllowHeaders = []string{"Accept", "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization"} 14 // "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization" 15 defaultAllowCredentials = "true" 16 ) 17 18 const ( 19 defaultCORSOrigin = "same-origin" 20 21 AccessControlAllowOrigin = "Access-Control-Allow-Origin" 22 AccessControlAllowMethods = "Access-Control-Allow-Methods" 23 AccessControlAllowHeaders = "Access-Control-Allow-Headers" 24 AccessControlAllowCredentials = "Access-Control-Allow-Credentials" 25 ) 26 27 type CORSConfig struct { 28 AllowedOrigins []string 29 AllowedMethods []string 30 AllowedHeaders []string 31 AllowAll bool 32 } 33 34 func CORS(config ...CORSConfig) gsk.Middleware { 35 var corsConfig CORSConfig 36 if len(config) > 0 { 37 corsConfig = config[0] 38 } else { 39 corsConfig = CORSConfig{ 40 AllowedOrigins: []string{defaultCORSOrigin}, 41 AllowedMethods: defaultAllowMethods, 42 AllowedHeaders: defaultAllowHeaders, 43 AllowAll: false, 44 } 45 } 46 47 allowedMethods := strings.Join(defaultAllowMethods, ", ") 48 if len(corsConfig.AllowedMethods) != 0 { 49 allowedMethods = strings.Join(corsConfig.AllowedMethods, ", ") 50 } 51 allowedHeaders := strings.Join(defaultAllowHeaders, ", ") 52 if len(corsConfig.AllowedHeaders) != 0 { 53 allowedHeaders = strings.Join(corsConfig.AllowedHeaders, ", ") 54 } 55 56 return func(next gsk.HandlerFunc) gsk.HandlerFunc { 57 return func(c *gsk.Context) { 58 allowedOrigins := getAllowedOrigins(corsConfig.AllowedOrigins) 59 60 origin := c.Origin() 61 // Check if the origin is in the allowedOrigins list 62 isAllowed := false 63 for _, allowedOrigin := range allowedOrigins { 64 if allowedOrigin == "same-origin" || allowedOrigin == "*" || origin == allowedOrigin { 65 isAllowed = true 66 break 67 } 68 } 69 70 if !corsConfig.AllowAll && !isAllowed { 71 c.Status(http.StatusForbidden) 72 c.SetHeader("Content-Type", "text/plain") 73 c.RawResponse([]byte("Forbidden")) 74 return 75 } 76 77 // Set CORS headers 78 headers := c.Writer.Header() 79 80 headers.Set(AccessControlAllowOrigin, origin) 81 headers.Set(AccessControlAllowMethods, allowedMethods) 82 headers.Set(AccessControlAllowHeaders, allowedHeaders) 83 headers.Set(AccessControlAllowCredentials, defaultAllowCredentials) 84 85 next(c) 86 87 } 88 } 89 } 90 91 func getAllowedOrigins(origins []string) []string { 92 var allowedOrigins []string 93 94 if len(origins) == 0 { 95 allowedOrigins = []string{defaultCORSOrigin} 96 } else { 97 allowedOrigins = origins 98 } 99 return allowedOrigins 100 }