github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/handlers/cors_test.go (about) 1 package handlers 2 3 import ( 4 "strings" 5 "testing" 6 7 http "github.com/hxx258456/ccgo/gmhttp" 8 "github.com/hxx258456/ccgo/gmhttp/httptest" 9 ) 10 11 func TestDefaultCORSHandlerReturnsOk(t *testing.T) { 12 r := newRequest("GET", "http://www.example.com/") 13 rr := httptest.NewRecorder() 14 15 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 16 17 CORS()(testHandler).ServeHTTP(rr, r) 18 19 if got, want := rr.Code, http.StatusOK; got != want { 20 t.Fatalf("bad status: got %v want %v", got, want) 21 } 22 } 23 24 func TestDefaultCORSHandlerReturnsOkWithOrigin(t *testing.T) { 25 r := newRequest("GET", "http://www.example.com/") 26 r.Header.Set("Origin", r.URL.String()) 27 28 rr := httptest.NewRecorder() 29 30 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 31 32 CORS()(testHandler).ServeHTTP(rr, r) 33 34 if got, want := rr.Code, http.StatusOK; got != want { 35 t.Fatalf("bad status: got %v want %v", got, want) 36 } 37 } 38 39 func TestCORSHandlerIgnoreOptionsFallsThrough(t *testing.T) { 40 r := newRequest("OPTIONS", "http://www.example.com/") 41 r.Header.Set("Origin", r.URL.String()) 42 43 rr := httptest.NewRecorder() 44 45 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 46 w.WriteHeader(http.StatusTeapot) 47 }) 48 49 CORS(IgnoreOptions())(testHandler).ServeHTTP(rr, r) 50 51 if got, want := rr.Code, http.StatusTeapot; got != want { 52 t.Fatalf("bad status: got %v want %v", got, want) 53 } 54 } 55 56 func TestCORSHandlerSetsExposedHeaders(t *testing.T) { 57 // Test default configuration. 58 r := newRequest("GET", "http://www.example.com/") 59 r.Header.Set("Origin", r.URL.String()) 60 61 rr := httptest.NewRecorder() 62 63 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 64 65 CORS(ExposedHeaders([]string{"X-CORS-TEST"}))(testHandler).ServeHTTP(rr, r) 66 67 if got, want := rr.Code, http.StatusOK; got != want { 68 t.Fatalf("bad status: got %v want %v", got, want) 69 } 70 71 header := rr.HeaderMap.Get(corsExposeHeadersHeader) 72 if got, want := header, "X-Cors-Test"; got != want { 73 t.Fatalf("bad header: expected %q header, got empty header for method.", want) 74 } 75 } 76 77 func TestCORSHandlerUnsetRequestMethodForPreflightBadRequest(t *testing.T) { 78 r := newRequest("OPTIONS", "http://www.example.com/") 79 r.Header.Set("Origin", r.URL.String()) 80 81 rr := httptest.NewRecorder() 82 83 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 84 85 CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r) 86 87 if got, want := rr.Code, http.StatusBadRequest; got != want { 88 t.Fatalf("bad status: got %v want %v", got, want) 89 } 90 } 91 92 func TestCORSHandlerInvalidRequestMethodForPreflightMethodNotAllowed(t *testing.T) { 93 r := newRequest("OPTIONS", "http://www.example.com/") 94 r.Header.Set("Origin", r.URL.String()) 95 r.Header.Set(corsRequestMethodHeader, "DELETE") 96 97 rr := httptest.NewRecorder() 98 99 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 100 101 CORS()(testHandler).ServeHTTP(rr, r) 102 103 if got, want := rr.Code, http.StatusMethodNotAllowed; got != want { 104 t.Fatalf("bad status: got %v want %v", got, want) 105 } 106 } 107 108 func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { 109 r := newRequest("OPTIONS", "http://www.example.com/") 110 r.Header.Set("Origin", r.URL.String()) 111 r.Header.Set(corsRequestMethodHeader, "GET") 112 113 rr := httptest.NewRecorder() 114 115 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 116 t.Fatal("Options request must not be passed to next handler") 117 }) 118 119 CORS()(testHandler).ServeHTTP(rr, r) 120 121 if got, want := rr.Code, http.StatusOK; got != want { 122 t.Fatalf("bad status: got %v want %v", got, want) 123 } 124 } 125 126 func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) { 127 statusCode := http.StatusNoContent 128 r := newRequest("OPTIONS", "http://www.example.com/") 129 r.Header.Set("Origin", r.URL.String()) 130 r.Header.Set(corsRequestMethodHeader, "GET") 131 132 rr := httptest.NewRecorder() 133 134 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 135 t.Fatal("Options request must not be passed to next handler") 136 }) 137 138 CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r) 139 140 if got, want := rr.Code, statusCode; got != want { 141 t.Fatalf("bad status: got %v want %v", got, want) 142 } 143 } 144 145 func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) { 146 r := newRequest("OPTIONS", "http://www.example.com/") 147 r.Header.Set("Origin", r.URL.String()) 148 r.Header.Set(corsRequestMethodHeader, "GET") 149 150 rr := httptest.NewRecorder() 151 152 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 153 t.Fatal("Options request must not be passed to next handler") 154 }) 155 156 CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r) 157 158 if got, want := rr.Code, http.StatusOK; got != want { 159 t.Fatalf("bad status: got %v want %v", got, want) 160 } 161 } 162 163 func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) { 164 r := newRequest("OPTIONS", "http://www.example.com/") 165 r.Header.Set("Origin", r.URL.String()) 166 r.Header.Set(corsRequestMethodHeader, "DELETE") 167 168 rr := httptest.NewRecorder() 169 170 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 171 172 CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r) 173 174 if got, want := rr.Code, http.StatusOK; got != want { 175 t.Fatalf("bad status: got %v want %v", got, want) 176 } 177 178 header := rr.HeaderMap.Get(corsAllowMethodsHeader) 179 if got, want := header, "DELETE"; got != want { 180 t.Fatalf("bad header: expected %q method header, got %q header.", want, got) 181 } 182 } 183 184 func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) { 185 for _, method := range defaultCorsMethods { 186 r := newRequest("OPTIONS", "http://www.example.com/") 187 r.Header.Set("Origin", r.URL.String()) 188 r.Header.Set(corsRequestMethodHeader, method) 189 190 rr := httptest.NewRecorder() 191 192 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 193 194 CORS()(testHandler).ServeHTTP(rr, r) 195 196 if got, want := rr.Code, http.StatusOK; got != want { 197 t.Fatalf("bad status: got %v want %v", got, want) 198 } 199 200 header := rr.HeaderMap.Get(corsAllowMethodsHeader) 201 if got, want := header, ""; got != want { 202 t.Fatalf("bad header: expected %q method header, got %q.", want, got) 203 } 204 } 205 } 206 207 func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) { 208 for _, simpleHeader := range defaultCorsHeaders { 209 r := newRequest("OPTIONS", "http://www.example.com/") 210 r.Header.Set("Origin", r.URL.String()) 211 r.Header.Set(corsRequestMethodHeader, "GET") 212 r.Header.Set(corsRequestHeadersHeader, simpleHeader) 213 214 rr := httptest.NewRecorder() 215 216 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 217 218 CORS()(testHandler).ServeHTTP(rr, r) 219 220 if got, want := rr.Code, http.StatusOK; got != want { 221 t.Fatalf("bad status: got %v want %v", got, want) 222 } 223 224 header := rr.HeaderMap.Get(corsAllowHeadersHeader) 225 if got, want := header, ""; got != want { 226 t.Fatalf("bad header: expected %q header, got %q.", want, got) 227 } 228 } 229 } 230 231 func TestCORSHandlerAllowedHeaderForPreflight(t *testing.T) { 232 r := newRequest("OPTIONS", "http://www.example.com/") 233 r.Header.Set("Origin", r.URL.String()) 234 r.Header.Set(corsRequestMethodHeader, "POST") 235 r.Header.Set(corsRequestHeadersHeader, "Content-Type") 236 237 rr := httptest.NewRecorder() 238 239 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 240 241 CORS(AllowedHeaders([]string{"Content-Type"}))(testHandler).ServeHTTP(rr, r) 242 243 if got, want := rr.Code, http.StatusOK; got != want { 244 t.Fatalf("bad status: got %v want %v", got, want) 245 } 246 247 header := rr.HeaderMap.Get(corsAllowHeadersHeader) 248 if got, want := header, "Content-Type"; got != want { 249 t.Fatalf("bad header: expected %q header, got %q header.", want, got) 250 } 251 } 252 253 func TestCORSHandlerInvalidHeaderForPreflightForbidden(t *testing.T) { 254 r := newRequest("OPTIONS", "http://www.example.com/") 255 r.Header.Set("Origin", r.URL.String()) 256 r.Header.Set(corsRequestMethodHeader, "POST") 257 r.Header.Set(corsRequestHeadersHeader, "Content-Type") 258 259 rr := httptest.NewRecorder() 260 261 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 262 263 CORS()(testHandler).ServeHTTP(rr, r) 264 265 if got, want := rr.Code, http.StatusForbidden; got != want { 266 t.Fatalf("bad status: got %v want %v", got, want) 267 } 268 } 269 270 func TestCORSHandlerMaxAgeForPreflight(t *testing.T) { 271 r := newRequest("OPTIONS", "http://www.example.com/") 272 r.Header.Set("Origin", r.URL.String()) 273 r.Header.Set(corsRequestMethodHeader, "POST") 274 275 rr := httptest.NewRecorder() 276 277 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 278 279 CORS(MaxAge(3500))(testHandler).ServeHTTP(rr, r) 280 281 if got, want := rr.Code, http.StatusOK; got != want { 282 t.Fatalf("bad status: got %v want %v", got, want) 283 } 284 285 header := rr.HeaderMap.Get(corsMaxAgeHeader) 286 if got, want := header, "600"; got != want { 287 t.Fatalf("bad header: expected %q to be %q, got %q.", corsMaxAgeHeader, want, got) 288 } 289 } 290 291 func TestCORSHandlerAllowedCredentials(t *testing.T) { 292 r := newRequest("GET", "http://www.example.com/") 293 r.Header.Set("Origin", r.URL.String()) 294 295 rr := httptest.NewRecorder() 296 297 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 298 299 CORS(AllowCredentials())(testHandler).ServeHTTP(rr, r) 300 301 if status := rr.Code; status != http.StatusOK { 302 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 303 } 304 305 header := rr.HeaderMap.Get(corsAllowCredentialsHeader) 306 if got, want := header, "true"; got != want { 307 t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowCredentialsHeader, want, got) 308 } 309 } 310 311 func TestCORSHandlerMultipleAllowOriginsSetsVaryHeader(t *testing.T) { 312 r := newRequest("GET", "http://www.example.com/") 313 r.Header.Set("Origin", r.URL.String()) 314 315 rr := httptest.NewRecorder() 316 317 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 318 319 CORS(AllowedOrigins([]string{r.URL.String(), "http://google.com"}))(testHandler).ServeHTTP(rr, r) 320 321 if status := rr.Code; status != http.StatusOK { 322 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 323 } 324 325 header := rr.HeaderMap.Get(corsVaryHeader) 326 if got, want := header, corsOriginHeader; got != want { 327 t.Fatalf("bad header: expected %s to be %q, got %q.", corsVaryHeader, want, got) 328 } 329 } 330 331 func TestCORSWithMultipleHandlers(t *testing.T) { 332 var lastHandledBy string 333 corsMiddleware := CORS() 334 335 testHandler1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 336 lastHandledBy = "testHandler1" 337 }) 338 testHandler2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 339 lastHandledBy = "testHandler2" 340 }) 341 342 r1 := newRequest("GET", "http://www.example.com/") 343 rr1 := httptest.NewRecorder() 344 handler1 := corsMiddleware(testHandler1) 345 346 corsMiddleware(testHandler2) 347 348 handler1.ServeHTTP(rr1, r1) 349 if lastHandledBy != "testHandler1" { 350 t.Fatalf("bad CORS() registration: Handler served should be Handler registered") 351 } 352 } 353 354 func TestCORSOriginValidatorWithImplicitStar(t *testing.T) { 355 r := newRequest("GET", "http://a.example.com") 356 r.Header.Set("Origin", r.URL.String()) 357 rr := httptest.NewRecorder() 358 359 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 360 361 originValidator := func(origin string) bool { 362 if strings.HasSuffix(origin, ".example.com") { 363 return true 364 } 365 return false 366 } 367 368 CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r) 369 header := rr.HeaderMap.Get(corsAllowOriginHeader) 370 if got, want := header, r.URL.String(); got != want { 371 t.Fatalf("bad header: expected %s to be %q, got %q.", corsAllowOriginHeader, want, got) 372 } 373 } 374 375 func TestCORSOriginValidatorWithExplicitStar(t *testing.T) { 376 r := newRequest("GET", "http://a.example.com") 377 r.Header.Set("Origin", r.URL.String()) 378 rr := httptest.NewRecorder() 379 380 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 381 382 originValidator := func(origin string) bool { 383 if strings.HasSuffix(origin, ".example.com") { 384 return true 385 } 386 return false 387 } 388 389 CORS( 390 AllowedOriginValidator(originValidator), 391 AllowedOrigins([]string{"*"}), 392 )(testHandler).ServeHTTP(rr, r) 393 header := rr.HeaderMap.Get(corsAllowOriginHeader) 394 if got, want := header, "*"; got != want { 395 t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got) 396 } 397 } 398 399 func TestCORSAllowStar(t *testing.T) { 400 r := newRequest("GET", "http://a.example.com") 401 r.Header.Set("Origin", r.URL.String()) 402 rr := httptest.NewRecorder() 403 404 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 405 406 CORS()(testHandler).ServeHTTP(rr, r) 407 header := rr.HeaderMap.Get(corsAllowOriginHeader) 408 if got, want := header, "*"; got != want { 409 t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got) 410 } 411 }