github.com/gofiber/fiber/v2@v2.47.0/middleware/cors/cors_test.go (about) 1 package cors 2 3 import ( 4 "net/http/httptest" 5 "strings" 6 "testing" 7 8 "github.com/gofiber/fiber/v2" 9 "github.com/gofiber/fiber/v2/utils" 10 11 "github.com/valyala/fasthttp" 12 ) 13 14 func Test_CORS_Defaults(t *testing.T) { 15 t.Parallel() 16 app := fiber.New() 17 app.Use(New()) 18 19 testDefaultOrEmptyConfig(t, app) 20 } 21 22 func Test_CORS_Empty_Config(t *testing.T) { 23 t.Parallel() 24 app := fiber.New() 25 app.Use(New(Config{})) 26 27 testDefaultOrEmptyConfig(t, app) 28 } 29 30 func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { 31 t.Helper() 32 33 h := app.Handler() 34 35 // Test default GET response headers 36 ctx := &fasthttp.RequestCtx{} 37 ctx.Request.Header.SetMethod(fiber.MethodGet) 38 h(ctx) 39 40 utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) 41 utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) 42 utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders))) 43 44 // Test default OPTIONS (preflight) response headers 45 ctx = &fasthttp.RequestCtx{} 46 ctx.Request.Header.SetMethod(fiber.MethodOptions) 47 h(ctx) 48 49 utils.AssertEqual(t, "GET,POST,HEAD,PUT,DELETE,PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods))) 50 utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders))) 51 utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) 52 } 53 54 // go test -run -v Test_CORS_Wildcard 55 func Test_CORS_Wildcard(t *testing.T) { 56 t.Parallel() 57 // New fiber instance 58 app := fiber.New() 59 // OPTIONS (preflight) response headers when AllowOrigins is * 60 app.Use(New(Config{ 61 AllowOrigins: "*", 62 AllowCredentials: true, 63 MaxAge: 3600, 64 ExposeHeaders: "X-Request-ID", 65 AllowHeaders: "Authentication", 66 })) 67 // Get handler pointer 68 handler := app.Handler() 69 70 // Make request 71 ctx := &fasthttp.RequestCtx{} 72 ctx.Request.SetRequestURI("/") 73 ctx.Request.Header.Set(fiber.HeaderOrigin, "localhost") 74 ctx.Request.Header.SetMethod(fiber.MethodOptions) 75 76 // Perform request 77 handler(ctx) 78 79 // Check result 80 utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) 81 utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) 82 utils.AssertEqual(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) 83 utils.AssertEqual(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders))) 84 85 // Test non OPTIONS (preflight) response headers 86 ctx = &fasthttp.RequestCtx{} 87 ctx.Request.Header.SetMethod(fiber.MethodGet) 88 handler(ctx) 89 90 utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) 91 utils.AssertEqual(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders))) 92 } 93 94 // go test -run -v Test_CORS_Subdomain 95 func Test_CORS_Subdomain(t *testing.T) { 96 t.Parallel() 97 // New fiber instance 98 app := fiber.New() 99 // OPTIONS (preflight) response headers when AllowOrigins is set to a subdomain 100 app.Use("/", New(Config{AllowOrigins: "http://*.example.com"})) 101 102 // Get handler pointer 103 handler := app.Handler() 104 105 // Make request with disallowed origin 106 ctx := &fasthttp.RequestCtx{} 107 ctx.Request.SetRequestURI("/") 108 ctx.Request.Header.SetMethod(fiber.MethodOptions) 109 ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") 110 111 // Perform request 112 handler(ctx) 113 114 // Allow-Origin header should be "" because http://google.com does not satisfy http://*.example.com 115 utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) 116 117 ctx.Request.Reset() 118 ctx.Response.Reset() 119 120 // Make request with allowed origin 121 ctx.Request.SetRequestURI("/") 122 ctx.Request.Header.SetMethod(fiber.MethodOptions) 123 ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com") 124 125 handler(ctx) 126 127 utils.AssertEqual(t, "http://test.example.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) 128 } 129 130 func Test_CORS_AllowOriginScheme(t *testing.T) { 131 t.Parallel() 132 tests := []struct { 133 reqOrigin, pattern string 134 shouldAllowOrigin bool 135 }{ 136 { 137 pattern: "http://example.com", 138 reqOrigin: "http://example.com", 139 shouldAllowOrigin: true, 140 }, 141 { 142 pattern: "https://example.com", 143 reqOrigin: "https://example.com", 144 shouldAllowOrigin: true, 145 }, 146 { 147 pattern: "http://example.com", 148 reqOrigin: "https://example.com", 149 shouldAllowOrigin: false, 150 }, 151 { 152 pattern: "http://*.example.com", 153 reqOrigin: "http://aaa.example.com", 154 shouldAllowOrigin: true, 155 }, 156 { 157 pattern: "http://*.example.com", 158 reqOrigin: "http://bbb.aaa.example.com", 159 shouldAllowOrigin: true, 160 }, 161 { 162 pattern: "http://*.aaa.example.com", 163 reqOrigin: "http://bbb.aaa.example.com", 164 shouldAllowOrigin: true, 165 }, 166 { 167 pattern: "http://*.example.com:8080", 168 reqOrigin: "http://aaa.example.com:8080", 169 shouldAllowOrigin: true, 170 }, 171 { 172 pattern: "http://example.com", 173 reqOrigin: "http://gofiber.com", 174 shouldAllowOrigin: false, 175 }, 176 { 177 pattern: "http://*.aaa.example.com", 178 reqOrigin: "http://ccc.bbb.example.com", 179 shouldAllowOrigin: false, 180 }, 181 { 182 pattern: "http://*.example.com", 183 reqOrigin: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ 184 .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ 185 .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ 186 .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`, 187 shouldAllowOrigin: false, 188 }, 189 { 190 pattern: "http://example.com", 191 reqOrigin: "http://ccc.bbb.example.com", 192 shouldAllowOrigin: false, 193 }, 194 { 195 pattern: "https://*--aaa.bbb.com", 196 reqOrigin: "https://prod-preview--aaa.bbb.com", 197 shouldAllowOrigin: false, 198 }, 199 { 200 pattern: "http://*.example.com", 201 reqOrigin: "http://ccc.bbb.example.com", 202 shouldAllowOrigin: true, 203 }, 204 { 205 pattern: "http://foo.[a-z]*.example.com", 206 reqOrigin: "http://ccc.bbb.example.com", 207 shouldAllowOrigin: false, 208 }, 209 } 210 211 for _, tt := range tests { 212 app := fiber.New() 213 app.Use("/", New(Config{AllowOrigins: tt.pattern})) 214 215 handler := app.Handler() 216 217 ctx := &fasthttp.RequestCtx{} 218 ctx.Request.SetRequestURI("/") 219 ctx.Request.Header.SetMethod(fiber.MethodOptions) 220 ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin) 221 222 handler(ctx) 223 224 if tt.shouldAllowOrigin { 225 utils.AssertEqual(t, tt.reqOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) 226 } else { 227 utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) 228 } 229 } 230 } 231 232 // go test -run Test_CORS_Next 233 func Test_CORS_Next(t *testing.T) { 234 t.Parallel() 235 app := fiber.New() 236 app.Use(New(Config{ 237 Next: func(_ *fiber.Ctx) bool { 238 return true 239 }, 240 })) 241 242 resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) 243 utils.AssertEqual(t, nil, err) 244 utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) 245 } 246 247 func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { 248 t.Parallel() 249 // New fiber instance 250 app := fiber.New() 251 app.Use("/", New(Config{ 252 AllowOrigins: "http://example-1.com", 253 AllowOriginsFunc: func(origin string) bool { 254 return strings.Contains(origin, "example-2") 255 }, 256 })) 257 258 // Get handler pointer 259 handler := app.Handler() 260 261 // Make request with disallowed origin 262 ctx := &fasthttp.RequestCtx{} 263 ctx.Request.SetRequestURI("/") 264 ctx.Request.Header.SetMethod(fiber.MethodOptions) 265 ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") 266 267 // Perform request 268 handler(ctx) 269 270 // Allow-Origin header should be "" because http://google.com does not satisfy http://example-1.com or 'strings.Contains(origin, "example-2")' 271 utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) 272 273 ctx.Request.Reset() 274 ctx.Response.Reset() 275 276 // Make request with allowed origin 277 ctx.Request.SetRequestURI("/") 278 ctx.Request.Header.SetMethod(fiber.MethodOptions) 279 ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com") 280 281 handler(ctx) 282 283 utils.AssertEqual(t, "http://example-1.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) 284 285 ctx.Request.Reset() 286 ctx.Response.Reset() 287 288 // Make request with allowed origin 289 ctx.Request.SetRequestURI("/") 290 ctx.Request.Header.SetMethod(fiber.MethodOptions) 291 ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com") 292 293 handler(ctx) 294 295 utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) 296 } 297 298 func Test_CORS_AllowOriginsFunc(t *testing.T) { 299 t.Parallel() 300 // New fiber instance 301 app := fiber.New() 302 app.Use("/", New(Config{ 303 AllowOriginsFunc: func(origin string) bool { 304 return strings.Contains(origin, "example-2") 305 }, 306 })) 307 308 // Get handler pointer 309 handler := app.Handler() 310 311 // Make request with disallowed origin 312 ctx := &fasthttp.RequestCtx{} 313 ctx.Request.SetRequestURI("/") 314 ctx.Request.Header.SetMethod(fiber.MethodOptions) 315 ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") 316 317 // Perform request 318 handler(ctx) 319 320 // Allow-Origin header should be "*" because http://google.com does not satisfy 'strings.Contains(origin, "example-2")' 321 // and AllowOrigins has not been set so the default "*" is used 322 utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) 323 324 ctx.Request.Reset() 325 ctx.Response.Reset() 326 327 // Make request with allowed origin 328 ctx.Request.SetRequestURI("/") 329 ctx.Request.Header.SetMethod(fiber.MethodOptions) 330 ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com") 331 332 handler(ctx) 333 334 // Allow-Origin header should be "http://example-2.com" 335 utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) 336 }