git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/httpx/cors/cors_test.go (about) 1 package cors 2 3 import ( 4 "net/http" 5 "net/http/httptest" 6 "regexp" 7 "strings" 8 "testing" 9 ) 10 11 var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 12 w.Write([]byte("bar")) 13 }) 14 15 var allHeaders = []string{ 16 "Vary", 17 "Access-Control-Allow-Origin", 18 "Access-Control-Allow-Methods", 19 "Access-Control-Allow-Headers", 20 "Access-Control-Allow-Credentials", 21 "Access-Control-Allow-Private-Network", 22 "Access-Control-Max-Age", 23 "Access-Control-Expose-Headers", 24 } 25 26 func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders map[string]string) { 27 for _, name := range allHeaders { 28 got := strings.Join(resHeaders[name], ", ") 29 want := expHeaders[name] 30 if got != want { 31 t.Errorf("Response header %q = %q, want %q", name, got, want) 32 } 33 } 34 } 35 36 func assertResponse(t *testing.T, res *httptest.ResponseRecorder, responseCode int) { 37 if responseCode != res.Code { 38 t.Errorf("assertResponse: expected response code to be %d but got %d. ", responseCode, res.Code) 39 } 40 } 41 42 func TestSpec(t *testing.T) { 43 cases := []struct { 44 name string 45 options Options 46 method string 47 reqHeaders map[string]string 48 resHeaders map[string]string 49 originAllowed bool 50 }{ 51 { 52 "NoConfig", 53 Options{ 54 // Intentionally left blank. 55 }, 56 "GET", 57 map[string]string{}, 58 map[string]string{ 59 "Vary": "Origin", 60 }, 61 true, 62 }, 63 { 64 "MatchAllOrigin", 65 Options{ 66 AllowedOrigins: []string{"*"}, 67 }, 68 "GET", 69 map[string]string{ 70 "Origin": "http://foobar.com", 71 }, 72 map[string]string{ 73 "Vary": "Origin", 74 "Access-Control-Allow-Origin": "*", 75 }, 76 true, 77 }, 78 { 79 "MatchAllOriginWithCredentials", 80 Options{ 81 AllowedOrigins: []string{"*"}, 82 AllowCredentials: true, 83 }, 84 "GET", 85 map[string]string{ 86 "Origin": "http://foobar.com", 87 }, 88 map[string]string{ 89 "Vary": "Origin", 90 "Access-Control-Allow-Origin": "*", 91 "Access-Control-Allow-Credentials": "true", 92 }, 93 true, 94 }, 95 { 96 "AllowedOrigin", 97 Options{ 98 AllowedOrigins: []string{"http://foobar.com"}, 99 }, 100 "GET", 101 map[string]string{ 102 "Origin": "http://foobar.com", 103 }, 104 map[string]string{ 105 "Vary": "Origin", 106 "Access-Control-Allow-Origin": "http://foobar.com", 107 }, 108 true, 109 }, 110 { 111 "WildcardOrigin", 112 Options{ 113 AllowedOrigins: []string{"http://*.bar.com"}, 114 }, 115 "GET", 116 map[string]string{ 117 "Origin": "http://foo.bar.com", 118 }, 119 map[string]string{ 120 "Vary": "Origin", 121 "Access-Control-Allow-Origin": "http://foo.bar.com", 122 }, 123 true, 124 }, 125 { 126 "DisallowedOrigin", 127 Options{ 128 AllowedOrigins: []string{"http://foobar.com"}, 129 }, 130 "GET", 131 map[string]string{ 132 "Origin": "http://barbaz.com", 133 }, 134 map[string]string{ 135 "Vary": "Origin", 136 }, 137 false, 138 }, 139 { 140 "DisallowedWildcardOrigin", 141 Options{ 142 AllowedOrigins: []string{"http://*.bar.com"}, 143 }, 144 "GET", 145 map[string]string{ 146 "Origin": "http://foo.baz.com", 147 }, 148 map[string]string{ 149 "Vary": "Origin", 150 }, 151 false, 152 }, 153 { 154 "AllowedOriginFuncMatch", 155 Options{ 156 AllowOriginFunc: func(o string) bool { 157 return regexp.MustCompile("^http://foo").MatchString(o) 158 }, 159 }, 160 "GET", 161 map[string]string{ 162 "Origin": "http://foobar.com", 163 }, 164 map[string]string{ 165 "Vary": "Origin", 166 "Access-Control-Allow-Origin": "http://foobar.com", 167 }, 168 true, 169 }, 170 { 171 "AllowOriginRequestFuncMatch", 172 Options{ 173 AllowOriginRequestFunc: func(r *http.Request, o string) bool { 174 return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret" 175 }, 176 }, 177 "GET", 178 map[string]string{ 179 "Origin": "http://foobar.com", 180 "Authorization": "secret", 181 }, 182 map[string]string{ 183 "Vary": "Origin", 184 "Access-Control-Allow-Origin": "http://foobar.com", 185 }, 186 true, 187 }, 188 { 189 "AllowOriginRequestFuncNotMatch", 190 Options{ 191 AllowOriginRequestFunc: func(r *http.Request, o string) bool { 192 return regexp.MustCompile("^http://foo").MatchString(o) && r.Header.Get("Authorization") == "secret" 193 }, 194 }, 195 "GET", 196 map[string]string{ 197 "Origin": "http://foobar.com", 198 "Authorization": "not-secret", 199 }, 200 map[string]string{ 201 "Vary": "Origin", 202 }, 203 false, 204 }, 205 { 206 "MaxAge", 207 Options{ 208 AllowedOrigins: []string{"http://example.com/"}, 209 AllowedMethods: []string{"GET"}, 210 MaxAge: 10, 211 }, 212 "OPTIONS", 213 map[string]string{ 214 "Origin": "http://example.com/", 215 "Access-Control-Request-Method": "GET", 216 }, 217 map[string]string{ 218 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 219 "Access-Control-Allow-Origin": "http://example.com/", 220 "Access-Control-Allow-Methods": "GET", 221 "Access-Control-Max-Age": "10", 222 }, 223 true, 224 }, 225 { 226 "AllowedMethod", 227 Options{ 228 AllowedOrigins: []string{"http://foobar.com"}, 229 AllowedMethods: []string{"PUT", "DELETE"}, 230 }, 231 "OPTIONS", 232 map[string]string{ 233 "Origin": "http://foobar.com", 234 "Access-Control-Request-Method": "PUT", 235 }, 236 map[string]string{ 237 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 238 "Access-Control-Allow-Origin": "http://foobar.com", 239 "Access-Control-Allow-Methods": "PUT", 240 }, 241 true, 242 }, 243 { 244 "DisallowedMethod", 245 Options{ 246 AllowedOrigins: []string{"http://foobar.com"}, 247 AllowedMethods: []string{"PUT", "DELETE"}, 248 }, 249 "OPTIONS", 250 map[string]string{ 251 "Origin": "http://foobar.com", 252 "Access-Control-Request-Method": "PATCH", 253 }, 254 map[string]string{ 255 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 256 }, 257 true, 258 }, 259 { 260 "AllowedHeaders", 261 Options{ 262 AllowedOrigins: []string{"http://foobar.com"}, 263 AllowedHeaders: []string{"X-Header-1", "x-header-2"}, 264 }, 265 "OPTIONS", 266 map[string]string{ 267 "Origin": "http://foobar.com", 268 "Access-Control-Request-Method": "GET", 269 "Access-Control-Request-Headers": "X-Header-2, X-HEADER-1", 270 }, 271 map[string]string{ 272 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 273 "Access-Control-Allow-Origin": "http://foobar.com", 274 "Access-Control-Allow-Methods": "GET", 275 "Access-Control-Allow-Headers": "X-Header-2, X-Header-1", 276 }, 277 true, 278 }, 279 { 280 "DefaultAllowedHeaders", 281 Options{ 282 AllowedOrigins: []string{"http://foobar.com"}, 283 AllowedHeaders: []string{}, 284 }, 285 "OPTIONS", 286 map[string]string{ 287 "Origin": "http://foobar.com", 288 "Access-Control-Request-Method": "GET", 289 "Access-Control-Request-Headers": "X-Requested-With", 290 }, 291 map[string]string{ 292 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 293 "Access-Control-Allow-Origin": "http://foobar.com", 294 "Access-Control-Allow-Methods": "GET", 295 "Access-Control-Allow-Headers": "X-Requested-With", 296 }, 297 true, 298 }, 299 { 300 "AllowedWildcardHeader", 301 Options{ 302 AllowedOrigins: []string{"http://foobar.com"}, 303 AllowedHeaders: []string{"*"}, 304 }, 305 "OPTIONS", 306 map[string]string{ 307 "Origin": "http://foobar.com", 308 "Access-Control-Request-Method": "GET", 309 "Access-Control-Request-Headers": "X-Header-2, X-HEADER-1", 310 }, 311 map[string]string{ 312 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 313 "Access-Control-Allow-Origin": "http://foobar.com", 314 "Access-Control-Allow-Methods": "GET", 315 "Access-Control-Allow-Headers": "X-Header-2, X-Header-1", 316 }, 317 true, 318 }, 319 { 320 "DisallowedHeader", 321 Options{ 322 AllowedOrigins: []string{"http://foobar.com"}, 323 AllowedHeaders: []string{"X-Header-1", "x-header-2"}, 324 }, 325 "OPTIONS", 326 map[string]string{ 327 "Origin": "http://foobar.com", 328 "Access-Control-Request-Method": "GET", 329 "Access-Control-Request-Headers": "X-Header-3, X-Header-1", 330 }, 331 map[string]string{ 332 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 333 }, 334 true, 335 }, 336 { 337 "OriginHeader", 338 Options{ 339 AllowedOrigins: []string{"http://foobar.com"}, 340 }, 341 "OPTIONS", 342 map[string]string{ 343 "Origin": "http://foobar.com", 344 "Access-Control-Request-Method": "GET", 345 "Access-Control-Request-Headers": "origin", 346 }, 347 map[string]string{ 348 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 349 "Access-Control-Allow-Origin": "http://foobar.com", 350 "Access-Control-Allow-Methods": "GET", 351 "Access-Control-Allow-Headers": "Origin", 352 }, 353 true, 354 }, 355 { 356 "ExposedHeader", 357 Options{ 358 AllowedOrigins: []string{"http://foobar.com"}, 359 ExposedHeaders: []string{"X-Header-1", "x-header-2"}, 360 }, 361 "GET", 362 map[string]string{ 363 "Origin": "http://foobar.com", 364 }, 365 map[string]string{ 366 "Vary": "Origin", 367 "Access-Control-Allow-Origin": "http://foobar.com", 368 "Access-Control-Expose-Headers": "X-Header-1, X-Header-2", 369 }, 370 true, 371 }, 372 { 373 "AllowedCredentials", 374 Options{ 375 AllowedOrigins: []string{"http://foobar.com"}, 376 AllowCredentials: true, 377 }, 378 "OPTIONS", 379 map[string]string{ 380 "Origin": "http://foobar.com", 381 "Access-Control-Request-Method": "GET", 382 }, 383 map[string]string{ 384 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 385 "Access-Control-Allow-Origin": "http://foobar.com", 386 "Access-Control-Allow-Methods": "GET", 387 "Access-Control-Allow-Credentials": "true", 388 }, 389 true, 390 }, 391 { 392 "AllowedPrivateNetwork", 393 Options{ 394 AllowedOrigins: []string{"http://foobar.com"}, 395 AllowPrivateNetwork: true, 396 }, 397 "OPTIONS", 398 map[string]string{ 399 "Origin": "http://foobar.com", 400 "Access-Control-Request-Method": "GET", 401 "Access-Control-Request-Private-Network": "true", 402 }, 403 map[string]string{ 404 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network", 405 "Access-Control-Allow-Origin": "http://foobar.com", 406 "Access-Control-Allow-Methods": "GET", 407 "Access-Control-Allow-Private-Network": "true", 408 }, 409 true, 410 }, 411 { 412 "DisallowedPrivateNetwork", 413 Options{ 414 AllowedOrigins: []string{"http://foobar.com"}, 415 }, 416 "OPTIONS", 417 map[string]string{ 418 "Origin": "http://foobar.com", 419 "Access-Control-Request-Method": "GET", 420 "Access-Control-Request-PrivateNetwork": "true", 421 }, 422 map[string]string{ 423 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 424 "Access-Control-Allow-Origin": "http://foobar.com", 425 "Access-Control-Allow-Methods": "GET", 426 }, 427 true, 428 }, 429 { 430 "OptionPassthrough", 431 Options{ 432 OptionsPassthrough: true, 433 }, 434 "OPTIONS", 435 map[string]string{ 436 "Origin": "http://foobar.com", 437 "Access-Control-Request-Method": "GET", 438 }, 439 map[string]string{ 440 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 441 "Access-Control-Allow-Origin": "*", 442 "Access-Control-Allow-Methods": "GET", 443 }, 444 true, 445 }, 446 { 447 "NonPreflightOptions", 448 Options{ 449 AllowedOrigins: []string{"http://foobar.com"}, 450 }, 451 "OPTIONS", 452 map[string]string{ 453 "Origin": "http://foobar.com", 454 }, 455 map[string]string{ 456 "Vary": "Origin", 457 "Access-Control-Allow-Origin": "http://foobar.com", 458 }, 459 true, 460 }, 461 } 462 for i := range cases { 463 tc := cases[i] 464 t.Run(tc.name, func(t *testing.T) { 465 s := New(tc.options) 466 467 req, _ := http.NewRequest(tc.method, "http://example.com/foo", nil) 468 for name, value := range tc.reqHeaders { 469 req.Header.Add(name, value) 470 } 471 472 t.Run("OriginAllowed", func(t *testing.T) { 473 if have, want := s.OriginAllowed(req), tc.originAllowed; have != want { 474 t.Errorf("OriginAllowed have: %t want: %t", have, want) 475 } 476 }) 477 478 t.Run("Handler", func(t *testing.T) { 479 res := httptest.NewRecorder() 480 s.Handler(testHandler).ServeHTTP(res, req) 481 assertHeaders(t, res.Header(), tc.resHeaders) 482 }) 483 t.Run("HandlerFunc", func(t *testing.T) { 484 res := httptest.NewRecorder() 485 s.HandlerFunc(res, req) 486 assertHeaders(t, res.Header(), tc.resHeaders) 487 }) 488 t.Run("Negroni", func(t *testing.T) { 489 res := httptest.NewRecorder() 490 s.ServeHTTP(res, req, testHandler) 491 assertHeaders(t, res.Header(), tc.resHeaders) 492 }) 493 494 }) 495 } 496 } 497 498 func TestDebug(t *testing.T) { 499 s := New(Options{ 500 Debug: true, 501 }) 502 503 if s.Log == nil { 504 t.Error("Logger not created when debug=true") 505 } 506 } 507 508 func TestDefault(t *testing.T) { 509 s := Default() 510 if s.Log != nil { 511 t.Error("c.log should be nil when Default") 512 } 513 if !s.allowedOriginsAll { 514 t.Error("c.allowedOriginsAll should be true when Default") 515 } 516 if s.allowedHeaders == nil { 517 t.Error("c.allowedHeaders should be nil when Default") 518 } 519 if s.allowedMethods == nil { 520 t.Error("c.allowedMethods should be nil when Default") 521 } 522 } 523 524 func TestHandlePreflightInvalidOriginAbortion(t *testing.T) { 525 s := New(Options{ 526 AllowedOrigins: []string{"http://foo.com"}, 527 }) 528 res := httptest.NewRecorder() 529 req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil) 530 req.Header.Add("Origin", "http://example.com/") 531 532 s.handlePreflight(res, req) 533 534 assertHeaders(t, res.Header(), map[string]string{ 535 "Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", 536 }) 537 } 538 539 func TestHandlePreflightNoOptionsAbortion(t *testing.T) { 540 s := New(Options{ 541 // Intentionally left blank. 542 }) 543 res := httptest.NewRecorder() 544 req, _ := http.NewRequest("GET", "http://example.com/foo", nil) 545 546 s.handlePreflight(res, req) 547 548 assertHeaders(t, res.Header(), map[string]string{}) 549 } 550 551 func TestHandleActualRequestInvalidOriginAbortion(t *testing.T) { 552 s := New(Options{ 553 AllowedOrigins: []string{"http://foo.com"}, 554 }) 555 res := httptest.NewRecorder() 556 req, _ := http.NewRequest("GET", "http://example.com/foo", nil) 557 req.Header.Add("Origin", "http://example.com/") 558 559 s.handleActualRequest(res, req) 560 561 assertHeaders(t, res.Header(), map[string]string{ 562 "Vary": "Origin", 563 }) 564 } 565 566 func TestHandleActualRequestInvalidMethodAbortion(t *testing.T) { 567 s := New(Options{ 568 AllowedMethods: []string{"POST"}, 569 AllowCredentials: true, 570 }) 571 res := httptest.NewRecorder() 572 req, _ := http.NewRequest("GET", "http://example.com/foo", nil) 573 req.Header.Add("Origin", "http://example.com/") 574 575 s.handleActualRequest(res, req) 576 577 assertHeaders(t, res.Header(), map[string]string{ 578 "Vary": "Origin", 579 }) 580 } 581 582 func TestIsMethodAllowedReturnsFalseWithNoMethods(t *testing.T) { 583 s := New(Options{ 584 // Intentionally left blank. 585 }) 586 s.allowedMethods = []string{} 587 if s.isMethodAllowed("") { 588 t.Error("IsMethodAllowed should return false when c.allowedMethods is nil.") 589 } 590 } 591 592 func TestIsMethodAllowedReturnsTrueWithOptions(t *testing.T) { 593 s := New(Options{ 594 // Intentionally left blank. 595 }) 596 if !s.isMethodAllowed("OPTIONS") { 597 t.Error("IsMethodAllowed should return true when c.allowedMethods is nil.") 598 } 599 } 600 601 func TestOptionsSuccessStatusCodeDefault(t *testing.T) { 602 s := New(Options{ 603 // Intentionally left blank. 604 }) 605 606 req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil) 607 req.Header.Add("Access-Control-Request-Method", "GET") 608 609 t.Run("Handler", func(t *testing.T) { 610 res := httptest.NewRecorder() 611 s.Handler(testHandler).ServeHTTP(res, req) 612 assertResponse(t, res, http.StatusNoContent) 613 }) 614 t.Run("HandlerFunc", func(t *testing.T) { 615 res := httptest.NewRecorder() 616 s.HandlerFunc(res, req) 617 assertResponse(t, res, http.StatusNoContent) 618 }) 619 t.Run("Negroni", func(t *testing.T) { 620 res := httptest.NewRecorder() 621 s.ServeHTTP(res, req, testHandler) 622 assertResponse(t, res, http.StatusNoContent) 623 }) 624 } 625 626 func TestOptionsSuccessStatusCodeOverride(t *testing.T) { 627 s := New(Options{ 628 OptionsSuccessStatus: http.StatusOK, 629 }) 630 631 req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil) 632 req.Header.Add("Access-Control-Request-Method", "GET") 633 634 t.Run("Handler", func(t *testing.T) { 635 res := httptest.NewRecorder() 636 s.Handler(testHandler).ServeHTTP(res, req) 637 assertResponse(t, res, http.StatusOK) 638 }) 639 t.Run("HandlerFunc", func(t *testing.T) { 640 res := httptest.NewRecorder() 641 s.HandlerFunc(res, req) 642 assertResponse(t, res, http.StatusOK) 643 }) 644 t.Run("Negroni", func(t *testing.T) { 645 res := httptest.NewRecorder() 646 s.ServeHTTP(res, req, testHandler) 647 assertResponse(t, res, http.StatusOK) 648 }) 649 } 650 651 func TestCorsAreHeadersAllowed(t *testing.T) { 652 cases := []struct { 653 name string 654 allowedHeaders []string 655 requestedHeaders []string 656 want bool 657 }{ 658 { 659 name: "nil allowedHeaders", 660 allowedHeaders: nil, 661 requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"), 662 want: false, 663 }, 664 { 665 name: "star allowedHeaders", 666 allowedHeaders: []string{"*"}, 667 requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"), 668 want: true, 669 }, 670 { 671 name: "empty reqHeader", 672 allowedHeaders: nil, 673 requestedHeaders: parseHeaderList(""), 674 want: true, 675 }, 676 { 677 name: "match allowedHeaders", 678 allowedHeaders: []string{"Content-Type", "X-PINGOTHER", "X-APP-KEY"}, 679 requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"), 680 want: true, 681 }, 682 { 683 name: "not matched allowedHeaders", 684 allowedHeaders: []string{"X-PINGOTHER"}, 685 requestedHeaders: parseHeaderList("X-API-KEY, Content-Type"), 686 want: false, 687 }, 688 { 689 name: "allowedHeaders should be a superset of requestedHeaders", 690 allowedHeaders: []string{"X-PINGOTHER"}, 691 requestedHeaders: parseHeaderList("X-PINGOTHER, Content-Type"), 692 want: false, 693 }, 694 } 695 696 for _, tt := range cases { 697 tt := tt 698 699 t.Run(tt.name, func(t *testing.T) { 700 c := New(Options{AllowedHeaders: tt.allowedHeaders}) 701 have := c.areHeadersAllowed(tt.requestedHeaders) 702 if have != tt.want { 703 t.Errorf("Cors.areHeadersAllowed() have: %t want: %t", have, tt.want) 704 } 705 }) 706 } 707 }