gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/mux/middleware_test.go (about) 1 package mux 2 3 import ( 4 "bytes" 5 "testing" 6 7 http "gitee.com/ks-custle/core-gm/gmhttp" 8 ) 9 10 type testMiddleware struct { 11 timesCalled uint 12 } 13 14 func (tm *testMiddleware) Middleware(h http.Handler) http.Handler { 15 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 16 tm.timesCalled++ 17 h.ServeHTTP(w, r) 18 }) 19 } 20 21 func dummyHandler(w http.ResponseWriter, r *http.Request) {} 22 23 func TestMiddlewareAdd(t *testing.T) { 24 router := NewRouter() 25 router.HandleFunc("/", dummyHandler).Methods("GET") 26 27 mw := &testMiddleware{} 28 29 router.useInterface(mw) 30 if len(router.middlewares) != 1 || router.middlewares[0] != mw { 31 t.Fatal("Middleware interface was not added correctly") 32 } 33 34 router.Use(mw.Middleware) 35 if len(router.middlewares) != 2 { 36 t.Fatal("Middleware method was not added correctly") 37 } 38 39 banalMw := func(handler http.Handler) http.Handler { 40 return handler 41 } 42 router.Use(banalMw) 43 if len(router.middlewares) != 3 { 44 t.Fatal("Middleware function was not added correctly") 45 } 46 } 47 48 func TestMiddleware(t *testing.T) { 49 router := NewRouter() 50 router.HandleFunc("/", dummyHandler).Methods("GET") 51 52 mw := &testMiddleware{} 53 router.useInterface(mw) 54 55 rw := NewRecorder() 56 req := newRequest("GET", "/") 57 58 t.Run("regular middleware call", func(t *testing.T) { 59 router.ServeHTTP(rw, req) 60 if mw.timesCalled != 1 { 61 t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) 62 } 63 }) 64 65 t.Run("not called for 404", func(t *testing.T) { 66 req = newRequest("GET", "/not/found") 67 router.ServeHTTP(rw, req) 68 if mw.timesCalled != 1 { 69 t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) 70 } 71 }) 72 73 t.Run("not called for method mismatch", func(t *testing.T) { 74 req = newRequest("POST", "/") 75 router.ServeHTTP(rw, req) 76 if mw.timesCalled != 1 { 77 t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) 78 } 79 }) 80 81 t.Run("regular call using function middleware", func(t *testing.T) { 82 router.Use(mw.Middleware) 83 req = newRequest("GET", "/") 84 router.ServeHTTP(rw, req) 85 if mw.timesCalled != 3 { 86 t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled) 87 } 88 }) 89 } 90 91 func TestMiddlewareSubrouter(t *testing.T) { 92 router := NewRouter() 93 router.HandleFunc("/", dummyHandler).Methods("GET") 94 95 subrouter := router.PathPrefix("/sub").Subrouter() 96 subrouter.HandleFunc("/x", dummyHandler).Methods("GET") 97 98 mw := &testMiddleware{} 99 subrouter.useInterface(mw) 100 101 rw := NewRecorder() 102 req := newRequest("GET", "/") 103 104 t.Run("not called for route outside subrouter", func(t *testing.T) { 105 router.ServeHTTP(rw, req) 106 if mw.timesCalled != 0 { 107 t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled) 108 } 109 }) 110 111 t.Run("not called for subrouter root 404", func(t *testing.T) { 112 req = newRequest("GET", "/sub/") 113 router.ServeHTTP(rw, req) 114 if mw.timesCalled != 0 { 115 t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled) 116 } 117 }) 118 119 t.Run("called once for route inside subrouter", func(t *testing.T) { 120 req = newRequest("GET", "/sub/x") 121 router.ServeHTTP(rw, req) 122 if mw.timesCalled != 1 { 123 t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) 124 } 125 }) 126 127 t.Run("not called for 404 inside subrouter", func(t *testing.T) { 128 req = newRequest("GET", "/sub/not/found") 129 router.ServeHTTP(rw, req) 130 if mw.timesCalled != 1 { 131 t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled) 132 } 133 }) 134 135 t.Run("middleware added to router", func(t *testing.T) { 136 router.useInterface(mw) 137 138 t.Run("called once for route outside subrouter", func(t *testing.T) { 139 req = newRequest("GET", "/") 140 router.ServeHTTP(rw, req) 141 if mw.timesCalled != 2 { 142 t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled) 143 } 144 }) 145 146 t.Run("called twice for route inside subrouter", func(t *testing.T) { 147 req = newRequest("GET", "/sub/x") 148 router.ServeHTTP(rw, req) 149 if mw.timesCalled != 4 { 150 t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled) 151 } 152 }) 153 }) 154 } 155 156 func TestMiddlewareExecution(t *testing.T) { 157 mwStr := []byte("Middleware\n") 158 handlerStr := []byte("Logic\n") 159 160 router := NewRouter() 161 router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { 162 w.Write(handlerStr) 163 }) 164 165 t.Run("responds normally without middleware", func(t *testing.T) { 166 rw := NewRecorder() 167 req := newRequest("GET", "/") 168 169 router.ServeHTTP(rw, req) 170 171 if !bytes.Equal(rw.Body.Bytes(), handlerStr) { 172 t.Fatal("Handler response is not what it should be") 173 } 174 }) 175 176 t.Run("responds with handler and middleware response", func(t *testing.T) { 177 rw := NewRecorder() 178 req := newRequest("GET", "/") 179 180 router.Use(func(h http.Handler) http.Handler { 181 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 182 w.Write(mwStr) 183 h.ServeHTTP(w, r) 184 }) 185 }) 186 187 router.ServeHTTP(rw, req) 188 if !bytes.Equal(rw.Body.Bytes(), append(mwStr, handlerStr...)) { 189 t.Fatal("Middleware + handler response is not what it should be") 190 } 191 }) 192 } 193 194 func TestMiddlewareNotFound(t *testing.T) { 195 mwStr := []byte("Middleware\n") 196 handlerStr := []byte("Logic\n") 197 198 router := NewRouter() 199 router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { 200 w.Write(handlerStr) 201 }) 202 router.Use(func(h http.Handler) http.Handler { 203 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 204 w.Write(mwStr) 205 h.ServeHTTP(w, r) 206 }) 207 }) 208 209 // Test not found call with default handler 210 t.Run("not called", func(t *testing.T) { 211 rw := NewRecorder() 212 req := newRequest("GET", "/notfound") 213 214 router.ServeHTTP(rw, req) 215 if bytes.Contains(rw.Body.Bytes(), mwStr) { 216 t.Fatal("Middleware was called for a 404") 217 } 218 }) 219 220 t.Run("not called with custom not found handler", func(t *testing.T) { 221 rw := NewRecorder() 222 req := newRequest("GET", "/notfound") 223 224 router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 225 rw.Write([]byte("Custom 404 handler")) 226 }) 227 router.ServeHTTP(rw, req) 228 229 if bytes.Contains(rw.Body.Bytes(), mwStr) { 230 t.Fatal("Middleware was called for a custom 404") 231 } 232 }) 233 } 234 235 func TestMiddlewareMethodMismatch(t *testing.T) { 236 mwStr := []byte("Middleware\n") 237 handlerStr := []byte("Logic\n") 238 239 router := NewRouter() 240 router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { 241 w.Write(handlerStr) 242 }).Methods("GET") 243 244 router.Use(func(h http.Handler) http.Handler { 245 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 246 w.Write(mwStr) 247 h.ServeHTTP(w, r) 248 }) 249 }) 250 251 t.Run("not called", func(t *testing.T) { 252 rw := NewRecorder() 253 req := newRequest("POST", "/") 254 255 router.ServeHTTP(rw, req) 256 if bytes.Contains(rw.Body.Bytes(), mwStr) { 257 t.Fatal("Middleware was called for a method mismatch") 258 } 259 }) 260 261 t.Run("not called with custom method not allowed handler", func(t *testing.T) { 262 rw := NewRecorder() 263 req := newRequest("POST", "/") 264 265 router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 266 rw.Write([]byte("Method not allowed")) 267 }) 268 router.ServeHTTP(rw, req) 269 270 if bytes.Contains(rw.Body.Bytes(), mwStr) { 271 t.Fatal("Middleware was called for a method mismatch") 272 } 273 }) 274 } 275 276 func TestMiddlewareNotFoundSubrouter(t *testing.T) { 277 mwStr := []byte("Middleware\n") 278 handlerStr := []byte("Logic\n") 279 280 router := NewRouter() 281 router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { 282 w.Write(handlerStr) 283 }) 284 285 subrouter := router.PathPrefix("/sub/").Subrouter() 286 subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { 287 w.Write(handlerStr) 288 }) 289 290 router.Use(func(h http.Handler) http.Handler { 291 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 292 w.Write(mwStr) 293 h.ServeHTTP(w, r) 294 }) 295 }) 296 297 t.Run("not called", func(t *testing.T) { 298 rw := NewRecorder() 299 req := newRequest("GET", "/sub/notfound") 300 301 router.ServeHTTP(rw, req) 302 if bytes.Contains(rw.Body.Bytes(), mwStr) { 303 t.Fatal("Middleware was called for a 404") 304 } 305 }) 306 307 t.Run("not called with custom not found handler", func(t *testing.T) { 308 rw := NewRecorder() 309 req := newRequest("GET", "/sub/notfound") 310 311 subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 312 rw.Write([]byte("Custom 404 handler")) 313 }) 314 router.ServeHTTP(rw, req) 315 316 if bytes.Contains(rw.Body.Bytes(), mwStr) { 317 t.Fatal("Middleware was called for a custom 404") 318 } 319 }) 320 } 321 322 func TestMiddlewareMethodMismatchSubrouter(t *testing.T) { 323 mwStr := []byte("Middleware\n") 324 handlerStr := []byte("Logic\n") 325 326 router := NewRouter() 327 router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { 328 w.Write(handlerStr) 329 }) 330 331 subrouter := router.PathPrefix("/sub/").Subrouter() 332 subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) { 333 w.Write(handlerStr) 334 }).Methods("GET") 335 336 router.Use(func(h http.Handler) http.Handler { 337 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 338 w.Write(mwStr) 339 h.ServeHTTP(w, r) 340 }) 341 }) 342 343 t.Run("not called", func(t *testing.T) { 344 rw := NewRecorder() 345 req := newRequest("POST", "/sub/") 346 347 router.ServeHTTP(rw, req) 348 if bytes.Contains(rw.Body.Bytes(), mwStr) { 349 t.Fatal("Middleware was called for a method mismatch") 350 } 351 }) 352 353 t.Run("not called with custom method not allowed handler", func(t *testing.T) { 354 rw := NewRecorder() 355 req := newRequest("POST", "/sub/") 356 357 router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 358 rw.Write([]byte("Method not allowed")) 359 }) 360 router.ServeHTTP(rw, req) 361 362 if bytes.Contains(rw.Body.Bytes(), mwStr) { 363 t.Fatal("Middleware was called for a method mismatch") 364 } 365 }) 366 } 367 368 func TestCORSMethodMiddleware(t *testing.T) { 369 testCases := []struct { 370 name string 371 registerRoutes func(r *Router) 372 requestHeader http.Header 373 requestMethod string 374 requestPath string 375 expectedAccessControlAllowMethodsHeader string 376 expectedResponse string 377 }{ 378 { 379 name: "does not set without OPTIONS matcher", 380 registerRoutes: func(r *Router) { 381 r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch) 382 }, 383 requestMethod: "GET", 384 requestPath: "/foo", 385 expectedAccessControlAllowMethodsHeader: "", 386 expectedResponse: "a", 387 }, 388 { 389 name: "sets on non OPTIONS", 390 registerRoutes: func(r *Router) { 391 r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch) 392 r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions) 393 }, 394 requestMethod: "GET", 395 requestPath: "/foo", 396 expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS", 397 expectedResponse: "a", 398 }, 399 { 400 name: "sets without preflight headers", 401 registerRoutes: func(r *Router) { 402 r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch) 403 r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions) 404 }, 405 requestMethod: "OPTIONS", 406 requestPath: "/foo", 407 expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS", 408 expectedResponse: "b", 409 }, 410 { 411 name: "does not set on error", 412 registerRoutes: func(r *Router) { 413 r.HandleFunc("/foo", stringHandler("a")) 414 }, 415 requestMethod: "OPTIONS", 416 requestPath: "/foo", 417 expectedAccessControlAllowMethodsHeader: "", 418 expectedResponse: "a", 419 }, 420 { 421 name: "sets header on valid preflight", 422 registerRoutes: func(r *Router) { 423 r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch) 424 r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions) 425 }, 426 requestMethod: "OPTIONS", 427 requestPath: "/foo", 428 requestHeader: http.Header{ 429 "Access-Control-Request-Method": []string{"GET"}, 430 "Access-Control-Request-Headers": []string{"Authorization"}, 431 "Origin": []string{"http://example.com"}, 432 }, 433 expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS", 434 expectedResponse: "b", 435 }, 436 { 437 name: "does not set methods from unmatching routes", 438 registerRoutes: func(r *Router) { 439 r.HandleFunc("/foo", stringHandler("c")).Methods(http.MethodDelete) 440 r.HandleFunc("/foo/bar", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch) 441 r.HandleFunc("/foo/bar", stringHandler("b")).Methods(http.MethodOptions) 442 }, 443 requestMethod: "OPTIONS", 444 requestPath: "/foo/bar", 445 requestHeader: http.Header{ 446 "Access-Control-Request-Method": []string{"GET"}, 447 "Access-Control-Request-Headers": []string{"Authorization"}, 448 "Origin": []string{"http://example.com"}, 449 }, 450 expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS", 451 expectedResponse: "b", 452 }, 453 } 454 455 for _, tt := range testCases { 456 t.Run(tt.name, func(t *testing.T) { 457 router := NewRouter() 458 459 tt.registerRoutes(router) 460 461 router.Use(CORSMethodMiddleware(router)) 462 463 rw := NewRecorder() 464 req := newRequest(tt.requestMethod, tt.requestPath) 465 req.Header = tt.requestHeader 466 467 router.ServeHTTP(rw, req) 468 469 actualMethodsHeader := rw.Header().Get("Access-Control-Allow-Methods") 470 if actualMethodsHeader != tt.expectedAccessControlAllowMethodsHeader { 471 t.Fatalf("Expected Access-Control-Allow-Methods to equal %s but got %s", tt.expectedAccessControlAllowMethodsHeader, actualMethodsHeader) 472 } 473 474 actualResponse := rw.Body.String() 475 if actualResponse != tt.expectedResponse { 476 t.Fatalf("Expected response to equal %s but got %s", tt.expectedResponse, actualResponse) 477 } 478 }) 479 } 480 } 481 482 func TestCORSMethodMiddlewareSubrouter(t *testing.T) { 483 router := NewRouter().StrictSlash(true) 484 485 subrouter := router.PathPrefix("/test").Subrouter() 486 subrouter.HandleFunc("/hello", stringHandler("a")).Methods(http.MethodGet, http.MethodOptions, http.MethodPost) 487 subrouter.HandleFunc("/hello/{name}", stringHandler("b")).Methods(http.MethodGet, http.MethodOptions) 488 489 subrouter.Use(CORSMethodMiddleware(subrouter)) 490 491 rw := NewRecorder() 492 req := newRequest("GET", "/test/hello/asdf") 493 router.ServeHTTP(rw, req) 494 495 actualMethods := rw.Header().Get("Access-Control-Allow-Methods") 496 expectedMethods := "GET,OPTIONS" 497 if actualMethods != expectedMethods { 498 t.Fatalf("expected methods %q but got: %q", expectedMethods, actualMethods) 499 } 500 } 501 502 func TestMiddlewareOnMultiSubrouter(t *testing.T) { 503 first := "first" 504 second := "second" 505 notFound := "404 not found" 506 507 router := NewRouter() 508 firstSubRouter := router.PathPrefix("/").Subrouter() 509 secondSubRouter := router.PathPrefix("/").Subrouter() 510 511 router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 512 rw.Write([]byte(notFound)) 513 }) 514 515 firstSubRouter.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) { 516 517 }) 518 519 secondSubRouter.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) { 520 521 }) 522 523 firstSubRouter.Use(func(h http.Handler) http.Handler { 524 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 525 w.Write([]byte(first)) 526 h.ServeHTTP(w, r) 527 }) 528 }) 529 530 secondSubRouter.Use(func(h http.Handler) http.Handler { 531 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 532 w.Write([]byte(second)) 533 h.ServeHTTP(w, r) 534 }) 535 }) 536 537 t.Run("/first uses first middleware", func(t *testing.T) { 538 rw := NewRecorder() 539 req := newRequest("GET", "/first") 540 541 router.ServeHTTP(rw, req) 542 if rw.Body.String() != first { 543 t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", first, rw.Body.String()) 544 } 545 }) 546 547 t.Run("/second uses second middleware", func(t *testing.T) { 548 rw := NewRecorder() 549 req := newRequest("GET", "/second") 550 551 router.ServeHTTP(rw, req) 552 if rw.Body.String() != second { 553 t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", second, rw.Body.String()) 554 } 555 }) 556 557 t.Run("uses not found handler", func(t *testing.T) { 558 rw := NewRecorder() 559 req := newRequest("GET", "/second/not-exist") 560 561 router.ServeHTTP(rw, req) 562 if rw.Body.String() != notFound { 563 t.Fatalf("Notfound handler did not run: expected %s for not-exist, (got %s)", notFound, rw.Body.String()) 564 } 565 }) 566 }