gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/handlers/cors_test.go (about) 1 package handlers 2 3 import ( 4 http "gitee.com/ks-custle/core-gm/gmhttp" 5 "gitee.com/ks-custle/core-gm/gmhttp/httptest" 6 "strings" 7 "testing" 8 ) 9 10 func TestDefaultCORSHandlerReturnsOk(t *testing.T) { 11 r := newRequest("GET", "http://www.example.com/") 12 rr := httptest.NewRecorder() 13 14 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 15 16 CORS()(testHandler).ServeHTTP(rr, r) 17 18 if got, want := rr.Code, http.StatusOK; got != want { 19 t.Fatalf("bad status: got %v want %v", got, want) 20 } 21 } 22 23 func TestDefaultCORSHandlerReturnsOkWithOrigin(t *testing.T) { 24 r := newRequest("GET", "http://www.example.com/") 25 r.Header.Set("Origin", r.URL.String()) 26 27 rr := httptest.NewRecorder() 28 29 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 30 31 CORS()(testHandler).ServeHTTP(rr, r) 32 33 if got, want := rr.Code, http.StatusOK; got != want { 34 t.Fatalf("bad status: got %v want %v", got, want) 35 } 36 } 37 38 func TestCORSHandlerIgnoreOptionsFallsThrough(t *testing.T) { 39 r := newRequest("OPTIONS", "http://www.example.com/") 40 r.Header.Set("Origin", r.URL.String()) 41 42 rr := httptest.NewRecorder() 43 44 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 45 w.WriteHeader(http.StatusTeapot) 46 }) 47 48 CORS(IgnoreOptions())(testHandler).ServeHTTP(rr, r) 49 50 if got, want := rr.Code, http.StatusTeapot; got != want { 51 t.Fatalf("bad status: got %v want %v", got, want) 52 } 53 } 54 55 func TestCORSHandlerSetsExposedHeaders(t *testing.T) { 56 // Test default configuration. 57 r := newRequest("GET", "http://www.example.com/") 58 r.Header.Set("Origin", r.URL.String()) 59 60 rr := httptest.NewRecorder() 61 62 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 63 64 CORS(ExposedHeaders([]string{"X-CORS-TEST"}))(testHandler).ServeHTTP(rr, r) 65 66 if got, want := rr.Code, http.StatusOK; got != want { 67 t.Fatalf("bad status: got %v want %v", got, want) 68 } 69 70 header := rr.HeaderMap.Get(corsExposeHeadersHeader) 71 if got, want := header, "X-Cors-Test"; got != want { 72 t.Fatalf("bad header: expected %q header, got empty header for method.", want) 73 } 74 } 75 76 func TestCORSHandlerUnsetRequestMethodForPreflightBadRequest(t *testing.T) { 77 r := newRequest("OPTIONS", "http://www.example.com/") 78 r.Header.Set("Origin", r.URL.String()) 79 80 rr := httptest.NewRecorder() 81 82 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 83 84 CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r) 85 86 if got, want := rr.Code, http.StatusBadRequest; got != want { 87 t.Fatalf("bad status: got %v want %v", got, want) 88 } 89 } 90 91 func TestCORSHandlerInvalidRequestMethodForPreflightMethodNotAllowed(t *testing.T) { 92 r := newRequest("OPTIONS", "http://www.example.com/") 93 r.Header.Set("Origin", r.URL.String()) 94 r.Header.Set(corsRequestMethodHeader, "DELETE") 95 96 rr := httptest.NewRecorder() 97 98 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 99 100 CORS()(testHandler).ServeHTTP(rr, r) 101 102 if got, want := rr.Code, http.StatusMethodNotAllowed; got != want { 103 t.Fatalf("bad status: got %v want %v", got, want) 104 } 105 } 106 107 func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { 108 r := newRequest("OPTIONS", "http://www.example.com/") 109 r.Header.Set("Origin", r.URL.String()) 110 r.Header.Set(corsRequestMethodHeader, "GET") 111 112 rr := httptest.NewRecorder() 113 114 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 115 t.Fatal("Options request must not be passed to next handler") 116 }) 117 118 CORS()(testHandler).ServeHTTP(rr, r) 119 120 if got, want := rr.Code, http.StatusOK; got != want { 121 t.Fatalf("bad status: got %v want %v", got, want) 122 } 123 } 124 125 func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) { 126 statusCode := http.StatusNoContent 127 r := newRequest("OPTIONS", "http://www.example.com/") 128 r.Header.Set("Origin", r.URL.String()) 129 r.Header.Set(corsRequestMethodHeader, "GET") 130 131 rr := httptest.NewRecorder() 132 133 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 134 t.Fatal("Options request must not be passed to next handler") 135 }) 136 137 CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r) 138 139 if got, want := rr.Code, statusCode; got != want { 140 t.Fatalf("bad status: got %v want %v", got, want) 141 } 142 } 143 144 func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) { 145 r := newRequest("OPTIONS", "http://www.example.com/") 146 r.Header.Set("Origin", r.URL.String()) 147 r.Header.Set(corsRequestMethodHeader, "GET") 148 149 rr := httptest.NewRecorder() 150 151 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 152 t.Fatal("Options request must not be passed to next handler") 153 }) 154 155 CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r) 156 157 if got, want := rr.Code, http.StatusOK; got != want { 158 t.Fatalf("bad status: got %v want %v", got, want) 159 } 160 } 161 162 func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) { 163 r := newRequest("OPTIONS", "http://www.example.com/") 164 r.Header.Set("Origin", r.URL.String()) 165 r.Header.Set(corsRequestMethodHeader, "DELETE") 166 167 rr := httptest.NewRecorder() 168 169 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 170 171 CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r) 172 173 if got, want := rr.Code, http.StatusOK; got != want { 174 t.Fatalf("bad status: got %v want %v", got, want) 175 } 176 177 header := rr.HeaderMap.Get(corsAllowMethodsHeader) 178 if got, want := header, "DELETE"; got != want { 179 t.Fatalf("bad header: expected %q method header, got %q header.", want, got) 180 } 181 } 182 183 func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) { 184 for _, method := range defaultCorsMethods { 185 r := newRequest("OPTIONS", "http://www.example.com/") 186 r.Header.Set("Origin", r.URL.String()) 187 r.Header.Set(corsRequestMethodHeader, method) 188 189 rr := httptest.NewRecorder() 190 191 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 192 193 CORS()(testHandler).ServeHTTP(rr, r) 194 195 if got, want := rr.Code, http.StatusOK; got != want { 196 t.Fatalf("bad status: got %v want %v", got, want) 197 } 198 199 header := rr.HeaderMap.Get(corsAllowMethodsHeader) 200 if got, want := header, ""; got != want { 201 t.Fatalf("bad header: expected %q method header, got %q.", want, got) 202 } 203 } 204 } 205 206 func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) { 207 for _, simpleHeader := range defaultCorsHeaders { 208 r := newRequest("OPTIONS", "http://www.example.com/") 209 r.Header.Set("Origin", r.URL.String()) 210 r.Header.Set(corsRequestMethodHeader, "GET") 211 r.Header.Set(corsRequestHeadersHeader, simpleHeader) 212 213 rr := httptest.NewRecorder() 214 215 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 216 217 CORS()(testHandler).ServeHTTP(rr, r) 218 219 if got, want := rr.Code, http.StatusOK; got != want { 220 t.Fatalf("bad status: got %v want %v", got, want) 221 } 222 223 header := rr.HeaderMap.Get(corsAllowHeadersHeader) 224 if got, want := header, ""; got != want { 225 t.Fatalf("bad header: expected %q header, got %q.", want, got) 226 } 227 } 228 } 229 230 func TestCORSHandlerAllowedHeaderForPreflight(t *testing.T) { 231 r := newRequest("OPTIONS", "http://www.example.com/") 232 r.Header.Set("Origin", r.URL.String()) 233 r.Header.Set(corsRequestMethodHeader, "POST") 234 r.Header.Set(corsRequestHeadersHeader, "Content-Type") 235 236 rr := httptest.NewRecorder() 237 238 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 239 240 CORS(AllowedHeaders([]string{"Content-Type"}))(testHandler).ServeHTTP(rr, r) 241 242 if got, want := rr.Code, http.StatusOK; got != want { 243 t.Fatalf("bad status: got %v want %v", got, want) 244 } 245 246 header := rr.HeaderMap.Get(corsAllowHeadersHeader) 247 if got, want := header, "Content-Type"; got != want { 248 t.Fatalf("bad header: expected %q header, got %q header.", want, got) 249 } 250 } 251 252 func TestCORSHandlerInvalidHeaderForPreflightForbidden(t *testing.T) { 253 r := newRequest("OPTIONS", "http://www.example.com/") 254 r.Header.Set("Origin", r.URL.String()) 255 r.Header.Set(corsRequestMethodHeader, "POST") 256 r.Header.Set(corsRequestHeadersHeader, "Content-Type") 257 258 rr := httptest.NewRecorder() 259 260 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 261 262 CORS()(testHandler).ServeHTTP(rr, r) 263 264 if got, want := rr.Code, http.StatusForbidden; got != want { 265 t.Fatalf("bad status: got %v want %v", got, want) 266 } 267 } 268 269 func TestCORSHandlerMaxAgeForPreflight(t *testing.T) { 270 r := newRequest("OPTIONS", "http://www.example.com/") 271 r.Header.Set("Origin", r.URL.String()) 272 r.Header.Set(corsRequestMethodHeader, "POST") 273 274 rr := httptest.NewRecorder() 275 276 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 277 278 CORS(MaxAge(3500))(testHandler).ServeHTTP(rr, r) 279 280 if got, want := rr.Code, http.StatusOK; got != want { 281 t.Fatalf("bad status: got %v want %v", got, want) 282 } 283 284 header := rr.HeaderMap.Get(corsMaxAgeHeader) 285 if got, want := header, "600"; got != want { 286 t.Fatalf("bad header: expected %q to be %q, got %q.", corsMaxAgeHeader, want, got) 287 } 288 } 289 290 func TestCORSHandlerAllowedCredentials(t *testing.T) { 291 r := newRequest("GET", "http://www.example.com/") 292 r.Header.Set("Origin", r.URL.String()) 293 294 rr := httptest.NewRecorder() 295 296 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 297 298 CORS(AllowCredentials())(testHandler).ServeHTTP(rr, r) 299 300 if status := rr.Code; status != http.StatusOK { 301 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 302 } 303 304 header := rr.HeaderMap.Get(corsAllowCredentialsHeader) 305 if got, want := header, "true"; got != want { 306 t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowCredentialsHeader, want, got) 307 } 308 } 309 310 func TestCORSHandlerMultipleAllowOriginsSetsVaryHeader(t *testing.T) { 311 r := newRequest("GET", "http://www.example.com/") 312 r.Header.Set("Origin", r.URL.String()) 313 314 rr := httptest.NewRecorder() 315 316 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 317 318 CORS(AllowedOrigins([]string{r.URL.String(), "http://google.com"}))(testHandler).ServeHTTP(rr, r) 319 320 if status := rr.Code; status != http.StatusOK { 321 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 322 } 323 324 header := rr.HeaderMap.Get(corsVaryHeader) 325 if got, want := header, corsOriginHeader; got != want { 326 t.Fatalf("bad header: expected %s to be %q, got %q.", corsVaryHeader, want, got) 327 } 328 } 329 330 func TestCORSWithMultipleHandlers(t *testing.T) { 331 var lastHandledBy string 332 corsMiddleware := CORS() 333 334 testHandler1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 335 lastHandledBy = "testHandler1" 336 }) 337 testHandler2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 338 lastHandledBy = "testHandler2" 339 }) 340 341 r1 := newRequest("GET", "http://www.example.com/") 342 rr1 := httptest.NewRecorder() 343 handler1 := corsMiddleware(testHandler1) 344 345 corsMiddleware(testHandler2) 346 347 handler1.ServeHTTP(rr1, r1) 348 if lastHandledBy != "testHandler1" { 349 t.Fatalf("bad CORS() registration: Handler served should be Handler registered") 350 } 351 } 352 353 func TestCORSOriginValidatorWithImplicitStar(t *testing.T) { 354 r := newRequest("GET", "http://a.example.com") 355 r.Header.Set("Origin", r.URL.String()) 356 rr := httptest.NewRecorder() 357 358 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 359 360 originValidator := func(origin string) bool { 361 if strings.HasSuffix(origin, ".example.com") { 362 return true 363 } 364 return false 365 } 366 367 CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r) 368 header := rr.HeaderMap.Get(corsAllowOriginHeader) 369 if got, want := header, r.URL.String(); got != want { 370 t.Fatalf("bad header: expected %s to be %q, got %q.", corsAllowOriginHeader, want, got) 371 } 372 } 373 374 func TestCORSOriginValidatorWithExplicitStar(t *testing.T) { 375 r := newRequest("GET", "http://a.example.com") 376 r.Header.Set("Origin", r.URL.String()) 377 rr := httptest.NewRecorder() 378 379 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 380 381 originValidator := func(origin string) bool { 382 if strings.HasSuffix(origin, ".example.com") { 383 return true 384 } 385 return false 386 } 387 388 CORS( 389 AllowedOriginValidator(originValidator), 390 AllowedOrigins([]string{"*"}), 391 )(testHandler).ServeHTTP(rr, r) 392 header := rr.HeaderMap.Get(corsAllowOriginHeader) 393 if got, want := header, "*"; got != want { 394 t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got) 395 } 396 } 397 398 func TestCORSAllowStar(t *testing.T) { 399 r := newRequest("GET", "http://a.example.com") 400 r.Header.Set("Origin", r.URL.String()) 401 rr := httptest.NewRecorder() 402 403 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 404 405 CORS()(testHandler).ServeHTTP(rr, r) 406 header := rr.HeaderMap.Get(corsAllowOriginHeader) 407 if got, want := header, "*"; got != want { 408 t.Fatalf("bad header: expected %q to be %q, got %q.", corsAllowOriginHeader, want, got) 409 } 410 }