github.com/go-chi/chi@v1.5.5/mux_test.go (about) 1 package chi 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "net" 10 "net/http" 11 "net/http/httptest" 12 "os" 13 "sync" 14 "testing" 15 "time" 16 ) 17 18 func TestMuxBasic(t *testing.T) { 19 var count uint64 20 countermw := func(next http.Handler) http.Handler { 21 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 22 count++ 23 next.ServeHTTP(w, r) 24 }) 25 } 26 27 usermw := func(next http.Handler) http.Handler { 28 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 29 ctx := r.Context() 30 ctx = context.WithValue(ctx, ctxKey{"user"}, "peter") 31 r = r.WithContext(ctx) 32 next.ServeHTTP(w, r) 33 }) 34 } 35 36 exmw := func(next http.Handler) http.Handler { 37 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 38 ctx := context.WithValue(r.Context(), ctxKey{"ex"}, "a") 39 r = r.WithContext(ctx) 40 next.ServeHTTP(w, r) 41 }) 42 } 43 44 logbuf := bytes.NewBufferString("") 45 logmsg := "logmw test" 46 logmw := func(next http.Handler) http.Handler { 47 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 48 logbuf.WriteString(logmsg) 49 next.ServeHTTP(w, r) 50 }) 51 } 52 53 cxindex := func(w http.ResponseWriter, r *http.Request) { 54 ctx := r.Context() 55 user := ctx.Value(ctxKey{"user"}).(string) 56 w.WriteHeader(200) 57 w.Write([]byte(fmt.Sprintf("hi %s", user))) 58 } 59 60 ping := func(w http.ResponseWriter, r *http.Request) { 61 w.WriteHeader(200) 62 w.Write([]byte(".")) 63 } 64 65 headPing := func(w http.ResponseWriter, r *http.Request) { 66 w.Header().Set("X-Ping", "1") 67 w.WriteHeader(200) 68 } 69 70 createPing := func(w http.ResponseWriter, r *http.Request) { 71 // create .... 
72 w.WriteHeader(201) 73 } 74 75 pingAll := func(w http.ResponseWriter, r *http.Request) { 76 w.WriteHeader(200) 77 w.Write([]byte("ping all")) 78 } 79 80 pingAll2 := func(w http.ResponseWriter, r *http.Request) { 81 w.WriteHeader(200) 82 w.Write([]byte("ping all2")) 83 } 84 85 pingOne := func(w http.ResponseWriter, r *http.Request) { 86 idParam := URLParam(r, "id") 87 w.WriteHeader(200) 88 w.Write([]byte(fmt.Sprintf("ping one id: %s", idParam))) 89 } 90 91 pingWoop := func(w http.ResponseWriter, r *http.Request) { 92 w.WriteHeader(200) 93 w.Write([]byte("woop." + URLParam(r, "iidd"))) 94 } 95 96 catchAll := func(w http.ResponseWriter, r *http.Request) { 97 w.WriteHeader(200) 98 w.Write([]byte("catchall")) 99 } 100 101 m := NewRouter() 102 m.Use(countermw) 103 m.Use(usermw) 104 m.Use(exmw) 105 m.Use(logmw) 106 m.Get("/", cxindex) 107 m.Method("GET", "/ping", http.HandlerFunc(ping)) 108 m.MethodFunc("GET", "/pingall", pingAll) 109 m.MethodFunc("get", "/ping/all", pingAll) 110 m.Get("/ping/all2", pingAll2) 111 112 m.Head("/ping", headPing) 113 m.Post("/ping", createPing) 114 m.Get("/ping/{id}", pingWoop) 115 m.Get("/ping/{id}", pingOne) // expected to overwrite to pingOne handler 116 m.Get("/ping/{iidd}/woop", pingWoop) 117 m.HandleFunc("/admin/*", catchAll) 118 // m.Post("/admin/*", catchAll) 119 120 ts := httptest.NewServer(m) 121 defer ts.Close() 122 123 // GET / 124 if _, body := testRequest(t, ts, "GET", "/", nil); body != "hi peter" { 125 t.Fatalf(body) 126 } 127 tlogmsg, _ := logbuf.ReadString(0) 128 if tlogmsg != logmsg { 129 t.Error("expecting log message from middleware:", logmsg) 130 } 131 132 // GET /ping 133 if _, body := testRequest(t, ts, "GET", "/ping", nil); body != "." 
{ 134 t.Fatalf(body) 135 } 136 137 // GET /pingall 138 if _, body := testRequest(t, ts, "GET", "/pingall", nil); body != "ping all" { 139 t.Fatalf(body) 140 } 141 142 // GET /ping/all 143 if _, body := testRequest(t, ts, "GET", "/ping/all", nil); body != "ping all" { 144 t.Fatalf(body) 145 } 146 147 // GET /ping/all2 148 if _, body := testRequest(t, ts, "GET", "/ping/all2", nil); body != "ping all2" { 149 t.Fatalf(body) 150 } 151 152 // GET /ping/123 153 if _, body := testRequest(t, ts, "GET", "/ping/123", nil); body != "ping one id: 123" { 154 t.Fatalf(body) 155 } 156 157 // GET /ping/allan 158 if _, body := testRequest(t, ts, "GET", "/ping/allan", nil); body != "ping one id: allan" { 159 t.Fatalf(body) 160 } 161 162 // GET /ping/1/woop 163 if _, body := testRequest(t, ts, "GET", "/ping/1/woop", nil); body != "woop.1" { 164 t.Fatalf(body) 165 } 166 167 // HEAD /ping 168 resp, err := http.Head(ts.URL + "/ping") 169 if err != nil { 170 t.Fatal(err) 171 } 172 if resp.StatusCode != 200 { 173 t.Error("head failed, should be 200") 174 } 175 if resp.Header.Get("X-Ping") == "" { 176 t.Error("expecting X-Ping header") 177 } 178 179 // GET /admin/catch-this 180 if _, body := testRequest(t, ts, "GET", "/admin/catch-thazzzzz", nil); body != "catchall" { 181 t.Fatalf(body) 182 } 183 184 // POST /admin/catch-this 185 resp, err = http.Post(ts.URL+"/admin/casdfsadfs", "text/plain", bytes.NewReader([]byte{})) 186 if err != nil { 187 t.Fatal(err) 188 } 189 190 body, err := ioutil.ReadAll(resp.Body) 191 if err != nil { 192 t.Fatal(err) 193 } 194 defer resp.Body.Close() 195 196 if resp.StatusCode != 200 { 197 t.Error("POST failed, should be 200") 198 } 199 200 if string(body) != "catchall" { 201 t.Error("expecting response body: 'catchall'") 202 } 203 204 // Custom http method DIE /ping/1/woop 205 if resp, body := testRequest(t, ts, "DIE", "/ping/1/woop", nil); body != "" || resp.StatusCode != 405 { 206 t.Fatalf(fmt.Sprintf("expecting 405 status and empty body, got %d '%s'", 
resp.StatusCode, body)) 207 } 208 } 209 210 func TestMuxMounts(t *testing.T) { 211 r := NewRouter() 212 213 r.Get("/{hash}", func(w http.ResponseWriter, r *http.Request) { 214 v := URLParam(r, "hash") 215 w.Write([]byte(fmt.Sprintf("/%s", v))) 216 }) 217 218 r.Route("/{hash}/share", func(r Router) { 219 r.Get("/", func(w http.ResponseWriter, r *http.Request) { 220 v := URLParam(r, "hash") 221 w.Write([]byte(fmt.Sprintf("/%s/share", v))) 222 }) 223 r.Get("/{network}", func(w http.ResponseWriter, r *http.Request) { 224 v := URLParam(r, "hash") 225 n := URLParam(r, "network") 226 w.Write([]byte(fmt.Sprintf("/%s/share/%s", v, n))) 227 }) 228 }) 229 230 m := NewRouter() 231 m.Mount("/sharing", r) 232 233 ts := httptest.NewServer(m) 234 defer ts.Close() 235 236 if _, body := testRequest(t, ts, "GET", "/sharing/aBc", nil); body != "/aBc" { 237 t.Fatalf(body) 238 } 239 if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share", nil); body != "/aBc/share" { 240 t.Fatalf(body) 241 } 242 if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share/twitter", nil); body != "/aBc/share/twitter" { 243 t.Fatalf(body) 244 } 245 } 246 247 func TestMuxPlain(t *testing.T) { 248 r := NewRouter() 249 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { 250 w.Write([]byte("bye")) 251 }) 252 r.NotFound(func(w http.ResponseWriter, r *http.Request) { 253 w.WriteHeader(404) 254 w.Write([]byte("nothing here")) 255 }) 256 257 ts := httptest.NewServer(r) 258 defer ts.Close() 259 260 if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" { 261 t.Fatalf(body) 262 } 263 if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" { 264 t.Fatalf(body) 265 } 266 } 267 268 func TestMuxEmptyRoutes(t *testing.T) { 269 mux := NewRouter() 270 271 apiRouter := NewRouter() 272 // oops, we forgot to declare any route handlers 273 274 mux.Handle("/api*", apiRouter) 275 276 if _, body := testHandler(t, mux, "GET", "/", nil); body != "404 page not found\n" { 277 
t.Fatalf(body) 278 } 279 280 if _, body := testHandler(t, apiRouter, "GET", "/", nil); body != "404 page not found\n" { 281 t.Fatalf(body) 282 } 283 } 284 285 // Test a mux that routes a trailing slash, see also middleware/strip_test.go 286 // for an example of using a middleware to handle trailing slashes. 287 func TestMuxTrailingSlash(t *testing.T) { 288 r := NewRouter() 289 r.NotFound(func(w http.ResponseWriter, r *http.Request) { 290 w.WriteHeader(404) 291 w.Write([]byte("nothing here")) 292 }) 293 294 subRoutes := NewRouter() 295 indexHandler := func(w http.ResponseWriter, r *http.Request) { 296 accountID := URLParam(r, "accountID") 297 w.Write([]byte(accountID)) 298 } 299 subRoutes.Get("/", indexHandler) 300 301 r.Mount("/accounts/{accountID}", subRoutes) 302 r.Get("/accounts/{accountID}/", indexHandler) 303 304 ts := httptest.NewServer(r) 305 defer ts.Close() 306 307 if _, body := testRequest(t, ts, "GET", "/accounts/admin", nil); body != "admin" { 308 t.Fatalf(body) 309 } 310 if _, body := testRequest(t, ts, "GET", "/accounts/admin/", nil); body != "admin" { 311 t.Fatalf(body) 312 } 313 if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" { 314 t.Fatalf(body) 315 } 316 } 317 318 func TestMuxNestedNotFound(t *testing.T) { 319 r := NewRouter() 320 321 r.Use(func(next http.Handler) http.Handler { 322 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 323 r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw"}, "mw")) 324 next.ServeHTTP(w, r) 325 }) 326 }) 327 328 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { 329 w.Write([]byte("bye")) 330 }) 331 332 r.With(func(next http.Handler) http.Handler { 333 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 334 r = r.WithContext(context.WithValue(r.Context(), ctxKey{"with"}, "with")) 335 next.ServeHTTP(w, r) 336 }) 337 }).NotFound(func(w http.ResponseWriter, r *http.Request) { 338 chkMw := 
r.Context().Value(ctxKey{"mw"}).(string) 339 chkWith := r.Context().Value(ctxKey{"with"}).(string) 340 w.WriteHeader(404) 341 w.Write([]byte(fmt.Sprintf("root 404 %s %s", chkMw, chkWith))) 342 }) 343 344 sr1 := NewRouter() 345 346 sr1.Get("/sub", func(w http.ResponseWriter, r *http.Request) { 347 w.Write([]byte("sub")) 348 }) 349 sr1.Group(func(sr1 Router) { 350 sr1.Use(func(next http.Handler) http.Handler { 351 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 352 r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw2"}, "mw2")) 353 next.ServeHTTP(w, r) 354 }) 355 }) 356 sr1.NotFound(func(w http.ResponseWriter, r *http.Request) { 357 chkMw2 := r.Context().Value(ctxKey{"mw2"}).(string) 358 w.WriteHeader(404) 359 w.Write([]byte(fmt.Sprintf("sub 404 %s", chkMw2))) 360 }) 361 }) 362 363 sr2 := NewRouter() 364 sr2.Get("/sub", func(w http.ResponseWriter, r *http.Request) { 365 w.Write([]byte("sub2")) 366 }) 367 368 r.Mount("/admin1", sr1) 369 r.Mount("/admin2", sr2) 370 371 ts := httptest.NewServer(r) 372 defer ts.Close() 373 374 if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" { 375 t.Fatalf(body) 376 } 377 if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "root 404 mw with" { 378 t.Fatalf(body) 379 } 380 if _, body := testRequest(t, ts, "GET", "/admin1/sub", nil); body != "sub" { 381 t.Fatalf(body) 382 } 383 if _, body := testRequest(t, ts, "GET", "/admin1/nope", nil); body != "sub 404 mw2" { 384 t.Fatalf(body) 385 } 386 if _, body := testRequest(t, ts, "GET", "/admin2/sub", nil); body != "sub2" { 387 t.Fatalf(body) 388 } 389 390 // Not found pages should bubble up to the root. 
391 if _, body := testRequest(t, ts, "GET", "/admin2/nope", nil); body != "root 404 mw with" { 392 t.Fatalf(body) 393 } 394 } 395 396 func TestMuxNestedMethodNotAllowed(t *testing.T) { 397 r := NewRouter() 398 r.Get("/root", func(w http.ResponseWriter, r *http.Request) { 399 w.Write([]byte("root")) 400 }) 401 r.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { 402 w.WriteHeader(405) 403 w.Write([]byte("root 405")) 404 }) 405 406 sr1 := NewRouter() 407 sr1.Get("/sub1", func(w http.ResponseWriter, r *http.Request) { 408 w.Write([]byte("sub1")) 409 }) 410 sr1.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { 411 w.WriteHeader(405) 412 w.Write([]byte("sub1 405")) 413 }) 414 415 sr2 := NewRouter() 416 sr2.Get("/sub2", func(w http.ResponseWriter, r *http.Request) { 417 w.Write([]byte("sub2")) 418 }) 419 420 pathVar := NewRouter() 421 pathVar.Get("/{var}", func(w http.ResponseWriter, r *http.Request) { 422 w.Write([]byte("pv")) 423 }) 424 pathVar.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { 425 w.WriteHeader(405) 426 w.Write([]byte("pv 405")) 427 }) 428 429 r.Mount("/prefix1", sr1) 430 r.Mount("/prefix2", sr2) 431 r.Mount("/pathVar", pathVar) 432 433 ts := httptest.NewServer(r) 434 defer ts.Close() 435 436 if _, body := testRequest(t, ts, "GET", "/root", nil); body != "root" { 437 t.Fatalf(body) 438 } 439 if _, body := testRequest(t, ts, "PUT", "/root", nil); body != "root 405" { 440 t.Fatalf(body) 441 } 442 if _, body := testRequest(t, ts, "GET", "/prefix1/sub1", nil); body != "sub1" { 443 t.Fatalf(body) 444 } 445 if _, body := testRequest(t, ts, "PUT", "/prefix1/sub1", nil); body != "sub1 405" { 446 t.Fatalf(body) 447 } 448 if _, body := testRequest(t, ts, "GET", "/prefix2/sub2", nil); body != "sub2" { 449 t.Fatalf(body) 450 } 451 if _, body := testRequest(t, ts, "PUT", "/prefix2/sub2", nil); body != "root 405" { 452 t.Fatalf(body) 453 } 454 if _, body := testRequest(t, ts, "GET", "/pathVar/myvar", nil); body != "pv" { 455 
t.Fatalf(body) 456 } 457 if _, body := testRequest(t, ts, "DELETE", "/pathVar/myvar", nil); body != "pv 405" { 458 t.Fatalf(body) 459 } 460 } 461 462 func TestMuxComplicatedNotFound(t *testing.T) { 463 decorateRouter := func(r *Mux) { 464 // Root router with groups 465 r.Get("/auth", func(w http.ResponseWriter, r *http.Request) { 466 w.Write([]byte("auth get")) 467 }) 468 r.Route("/public", func(r Router) { 469 r.Get("/", func(w http.ResponseWriter, r *http.Request) { 470 w.Write([]byte("public get")) 471 }) 472 }) 473 474 // sub router with groups 475 sub0 := NewRouter() 476 sub0.Route("/resource", func(r Router) { 477 r.Get("/", func(w http.ResponseWriter, r *http.Request) { 478 w.Write([]byte("private get")) 479 }) 480 }) 481 r.Mount("/private", sub0) 482 483 // sub router with groups 484 sub1 := NewRouter() 485 sub1.Route("/resource", func(r Router) { 486 r.Get("/", func(w http.ResponseWriter, r *http.Request) { 487 w.Write([]byte("private get")) 488 }) 489 }) 490 r.With(func(next http.Handler) http.Handler { return next }).Mount("/private_mw", sub1) 491 } 492 493 testNotFound := func(t *testing.T, r *Mux) { 494 ts := httptest.NewServer(r) 495 defer ts.Close() 496 497 // check that we didn't break correct routes 498 if _, body := testRequest(t, ts, "GET", "/auth", nil); body != "auth get" { 499 t.Fatalf(body) 500 } 501 if _, body := testRequest(t, ts, "GET", "/public", nil); body != "public get" { 502 t.Fatalf(body) 503 } 504 if _, body := testRequest(t, ts, "GET", "/public/", nil); body != "public get" { 505 t.Fatalf(body) 506 } 507 if _, body := testRequest(t, ts, "GET", "/private/resource", nil); body != "private get" { 508 t.Fatalf(body) 509 } 510 // check custom not-found on all levels 511 if _, body := testRequest(t, ts, "GET", "/nope", nil); body != "custom not-found" { 512 t.Fatalf(body) 513 } 514 if _, body := testRequest(t, ts, "GET", "/public/nope", nil); body != "custom not-found" { 515 t.Fatalf(body) 516 } 517 if _, body := testRequest(t, ts, 
"GET", "/private/nope", nil); body != "custom not-found" { 518 t.Fatalf(body) 519 } 520 if _, body := testRequest(t, ts, "GET", "/private/resource/nope", nil); body != "custom not-found" { 521 t.Fatalf(body) 522 } 523 if _, body := testRequest(t, ts, "GET", "/private_mw/nope", nil); body != "custom not-found" { 524 t.Fatalf(body) 525 } 526 if _, body := testRequest(t, ts, "GET", "/private_mw/resource/nope", nil); body != "custom not-found" { 527 t.Fatalf(body) 528 } 529 // check custom not-found on trailing slash routes 530 if _, body := testRequest(t, ts, "GET", "/auth/", nil); body != "custom not-found" { 531 t.Fatalf(body) 532 } 533 } 534 535 t.Run("pre", func(t *testing.T) { 536 r := NewRouter() 537 r.NotFound(func(w http.ResponseWriter, r *http.Request) { 538 w.Write([]byte("custom not-found")) 539 }) 540 decorateRouter(r) 541 testNotFound(t, r) 542 }) 543 544 t.Run("post", func(t *testing.T) { 545 r := NewRouter() 546 decorateRouter(r) 547 r.NotFound(func(w http.ResponseWriter, r *http.Request) { 548 w.Write([]byte("custom not-found")) 549 }) 550 testNotFound(t, r) 551 }) 552 } 553 554 func TestMuxWith(t *testing.T) { 555 var cmwInit1, cmwHandler1 uint64 556 var cmwInit2, cmwHandler2 uint64 557 mw1 := func(next http.Handler) http.Handler { 558 cmwInit1++ 559 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 560 cmwHandler1++ 561 r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline1"}, "yes")) 562 next.ServeHTTP(w, r) 563 }) 564 } 565 mw2 := func(next http.Handler) http.Handler { 566 cmwInit2++ 567 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 568 cmwHandler2++ 569 r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline2"}, "yes")) 570 next.ServeHTTP(w, r) 571 }) 572 } 573 574 r := NewRouter() 575 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { 576 w.Write([]byte("bye")) 577 }) 578 r.With(mw1).With(mw2).Get("/inline", func(w http.ResponseWriter, r *http.Request) { 579 v1 := 
r.Context().Value(ctxKey{"inline1"}).(string) 580 v2 := r.Context().Value(ctxKey{"inline2"}).(string) 581 w.Write([]byte(fmt.Sprintf("inline %s %s", v1, v2))) 582 }) 583 584 ts := httptest.NewServer(r) 585 defer ts.Close() 586 587 if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" { 588 t.Fatalf(body) 589 } 590 if _, body := testRequest(t, ts, "GET", "/inline", nil); body != "inline yes yes" { 591 t.Fatalf(body) 592 } 593 if cmwInit1 != 1 { 594 t.Fatalf("expecting cmwInit1 to be 1, got %d", cmwInit1) 595 } 596 if cmwHandler1 != 1 { 597 t.Fatalf("expecting cmwHandler1 to be 1, got %d", cmwHandler1) 598 } 599 if cmwInit2 != 1 { 600 t.Fatalf("expecting cmwInit2 to be 1, got %d", cmwInit2) 601 } 602 if cmwHandler2 != 1 { 603 t.Fatalf("expecting cmwHandler2 to be 1, got %d", cmwHandler2) 604 } 605 } 606 607 func TestRouterFromMuxWith(t *testing.T) { 608 t.Parallel() 609 610 r := NewRouter() 611 612 with := r.With(func(next http.Handler) http.Handler { 613 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 614 next.ServeHTTP(w, r) 615 }) 616 }) 617 618 with.Get("/with_middleware", func(w http.ResponseWriter, r *http.Request) {}) 619 620 ts := httptest.NewServer(with) 621 defer ts.Close() 622 623 // Without the fix this test was committed with, this causes a panic. 
624 testRequest(t, ts, http.MethodGet, "/with_middleware", nil) 625 } 626 627 func TestMuxMiddlewareStack(t *testing.T) { 628 var stdmwInit, stdmwHandler uint64 629 stdmw := func(next http.Handler) http.Handler { 630 stdmwInit++ 631 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 632 stdmwHandler++ 633 next.ServeHTTP(w, r) 634 }) 635 } 636 _ = stdmw 637 638 var ctxmwInit, ctxmwHandler uint64 639 ctxmw := func(next http.Handler) http.Handler { 640 ctxmwInit++ 641 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 642 ctxmwHandler++ 643 ctx := r.Context() 644 ctx = context.WithValue(ctx, ctxKey{"count.ctxmwHandler"}, ctxmwHandler) 645 r = r.WithContext(ctx) 646 next.ServeHTTP(w, r) 647 }) 648 } 649 650 var inCtxmwInit, inCtxmwHandler uint64 651 inCtxmw := func(next http.Handler) http.Handler { 652 inCtxmwInit++ 653 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 654 inCtxmwHandler++ 655 next.ServeHTTP(w, r) 656 }) 657 } 658 659 r := NewRouter() 660 r.Use(stdmw) 661 r.Use(ctxmw) 662 r.Use(func(next http.Handler) http.Handler { 663 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 664 if r.URL.Path == "/ping" { 665 w.Write([]byte("pong")) 666 return 667 } 668 next.ServeHTTP(w, r) 669 }) 670 }) 671 672 var handlerCount uint64 673 674 r.With(inCtxmw).Get("/", func(w http.ResponseWriter, r *http.Request) { 675 handlerCount++ 676 ctx := r.Context() 677 ctxmwHandlerCount := ctx.Value(ctxKey{"count.ctxmwHandler"}).(uint64) 678 w.Write([]byte(fmt.Sprintf("inits:%d reqs:%d ctxValue:%d", ctxmwInit, handlerCount, ctxmwHandlerCount))) 679 }) 680 681 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { 682 w.Write([]byte("wooot")) 683 }) 684 685 ts := httptest.NewServer(r) 686 defer ts.Close() 687 688 testRequest(t, ts, "GET", "/", nil) 689 testRequest(t, ts, "GET", "/", nil) 690 var body string 691 _, body = testRequest(t, ts, "GET", "/", nil) 692 if body != "inits:1 reqs:3 ctxValue:3" { 
693 t.Fatalf("got: '%s'", body) 694 } 695 696 _, body = testRequest(t, ts, "GET", "/ping", nil) 697 if body != "pong" { 698 t.Fatalf("got: '%s'", body) 699 } 700 } 701 702 func TestMuxRouteGroups(t *testing.T) { 703 var stdmwInit, stdmwHandler uint64 704 705 stdmw := func(next http.Handler) http.Handler { 706 stdmwInit++ 707 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 708 stdmwHandler++ 709 next.ServeHTTP(w, r) 710 }) 711 } 712 713 var stdmwInit2, stdmwHandler2 uint64 714 stdmw2 := func(next http.Handler) http.Handler { 715 stdmwInit2++ 716 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 717 stdmwHandler2++ 718 next.ServeHTTP(w, r) 719 }) 720 } 721 722 r := NewRouter() 723 r.Group(func(r Router) { 724 r.Use(stdmw) 725 r.Get("/group", func(w http.ResponseWriter, r *http.Request) { 726 w.Write([]byte("root group")) 727 }) 728 }) 729 r.Group(func(r Router) { 730 r.Use(stdmw2) 731 r.Get("/group2", func(w http.ResponseWriter, r *http.Request) { 732 w.Write([]byte("root group2")) 733 }) 734 }) 735 736 ts := httptest.NewServer(r) 737 defer ts.Close() 738 739 // GET /group 740 _, body := testRequest(t, ts, "GET", "/group", nil) 741 if body != "root group" { 742 t.Fatalf("got: '%s'", body) 743 } 744 if stdmwInit != 1 || stdmwHandler != 1 { 745 t.Logf("stdmw counters failed, should be 1:1, got %d:%d", stdmwInit, stdmwHandler) 746 } 747 748 // GET /group2 749 _, body = testRequest(t, ts, "GET", "/group2", nil) 750 if body != "root group2" { 751 t.Fatalf("got: '%s'", body) 752 } 753 if stdmwInit2 != 1 || stdmwHandler2 != 1 { 754 t.Fatalf("stdmw2 counters failed, should be 1:1, got %d:%d", stdmwInit2, stdmwHandler2) 755 } 756 } 757 758 func TestMuxBig(t *testing.T) { 759 r := bigMux() 760 761 ts := httptest.NewServer(r) 762 defer ts.Close() 763 764 var body, expected string 765 766 _, body = testRequest(t, ts, "GET", "/favicon.ico", nil) 767 if body != "fav" { 768 t.Fatalf("got '%s'", body) 769 } 770 _, body = testRequest(t, ts, 
"GET", "/hubs/4/view", nil) 771 if body != "/hubs/4/view reqid:1 session:anonymous" { 772 t.Fatalf("got '%v'", body) 773 } 774 _, body = testRequest(t, ts, "GET", "/hubs/4/view/index.html", nil) 775 if body != "/hubs/4/view/index.html reqid:1 session:anonymous" { 776 t.Fatalf("got '%s'", body) 777 } 778 _, body = testRequest(t, ts, "POST", "/hubs/ethereumhub/view/index.html", nil) 779 if body != "/hubs/ethereumhub/view/index.html reqid:1 session:anonymous" { 780 t.Fatalf("got '%s'", body) 781 } 782 _, body = testRequest(t, ts, "GET", "/", nil) 783 if body != "/ reqid:1 session:elvis" { 784 t.Fatalf("got '%s'", body) 785 } 786 _, body = testRequest(t, ts, "GET", "/suggestions", nil) 787 if body != "/suggestions reqid:1 session:elvis" { 788 t.Fatalf("got '%s'", body) 789 } 790 _, body = testRequest(t, ts, "GET", "/woot/444/hiiii", nil) 791 if body != "/woot/444/hiiii" { 792 t.Fatalf("got '%s'", body) 793 } 794 _, body = testRequest(t, ts, "GET", "/hubs/123", nil) 795 expected = "/hubs/123 reqid:1 session:elvis" 796 if body != expected { 797 t.Fatalf("expected:%s got:%s", expected, body) 798 } 799 _, body = testRequest(t, ts, "GET", "/hubs/123/touch", nil) 800 if body != "/hubs/123/touch reqid:1 session:elvis" { 801 t.Fatalf("got '%s'", body) 802 } 803 _, body = testRequest(t, ts, "GET", "/hubs/123/webhooks", nil) 804 if body != "/hubs/123/webhooks reqid:1 session:elvis" { 805 t.Fatalf("got '%s'", body) 806 } 807 _, body = testRequest(t, ts, "GET", "/hubs/123/posts", nil) 808 if body != "/hubs/123/posts reqid:1 session:elvis" { 809 t.Fatalf("got '%s'", body) 810 } 811 _, body = testRequest(t, ts, "GET", "/folders", nil) 812 if body != "404 page not found\n" { 813 t.Fatalf("got '%s'", body) 814 } 815 _, body = testRequest(t, ts, "GET", "/folders/", nil) 816 if body != "/folders/ reqid:1 session:elvis" { 817 t.Fatalf("got '%s'", body) 818 } 819 _, body = testRequest(t, ts, "GET", "/folders/public", nil) 820 if body != "/folders/public reqid:1 session:elvis" { 821 
t.Fatalf("got '%s'", body) 822 } 823 _, body = testRequest(t, ts, "GET", "/folders/nothing", nil) 824 if body != "404 page not found\n" { 825 t.Fatalf("got '%s'", body) 826 } 827 } 828 829 func bigMux() Router { 830 var r *Mux 831 var sr3 *Mux 832 // var sr1, sr2, sr3, sr4, sr5, sr6 *Mux 833 r = NewRouter() 834 r.Use(func(next http.Handler) http.Handler { 835 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 836 ctx := context.WithValue(r.Context(), ctxKey{"requestID"}, "1") 837 next.ServeHTTP(w, r.WithContext(ctx)) 838 }) 839 }) 840 r.Use(func(next http.Handler) http.Handler { 841 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 842 next.ServeHTTP(w, r) 843 }) 844 }) 845 r.Group(func(r Router) { 846 r.Use(func(next http.Handler) http.Handler { 847 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 848 ctx := context.WithValue(r.Context(), ctxKey{"session.user"}, "anonymous") 849 next.ServeHTTP(w, r.WithContext(ctx)) 850 }) 851 }) 852 r.Get("/favicon.ico", func(w http.ResponseWriter, r *http.Request) { 853 w.Write([]byte("fav")) 854 }) 855 r.Get("/hubs/{hubID}/view", func(w http.ResponseWriter, r *http.Request) { 856 ctx := r.Context() 857 s := fmt.Sprintf("/hubs/%s/view reqid:%s session:%s", URLParam(r, "hubID"), 858 ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"})) 859 w.Write([]byte(s)) 860 }) 861 r.Get("/hubs/{hubID}/view/*", func(w http.ResponseWriter, r *http.Request) { 862 ctx := r.Context() 863 s := fmt.Sprintf("/hubs/%s/view/%s reqid:%s session:%s", URLParamFromCtx(ctx, "hubID"), 864 URLParam(r, "*"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"})) 865 w.Write([]byte(s)) 866 }) 867 r.Post("/hubs/{hubSlug}/view/*", func(w http.ResponseWriter, r *http.Request) { 868 ctx := r.Context() 869 s := fmt.Sprintf("/hubs/%s/view/%s reqid:%s session:%s", URLParamFromCtx(ctx, "hubSlug"), 870 URLParam(r, "*"), ctx.Value(ctxKey{"requestID"}), 
ctx.Value(ctxKey{"session.user"})) 871 w.Write([]byte(s)) 872 }) 873 }) 874 r.Group(func(r Router) { 875 r.Use(func(next http.Handler) http.Handler { 876 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 877 ctx := context.WithValue(r.Context(), ctxKey{"session.user"}, "elvis") 878 next.ServeHTTP(w, r.WithContext(ctx)) 879 }) 880 }) 881 r.Get("/", func(w http.ResponseWriter, r *http.Request) { 882 ctx := r.Context() 883 s := fmt.Sprintf("/ reqid:%s session:%s", ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"})) 884 w.Write([]byte(s)) 885 }) 886 r.Get("/suggestions", func(w http.ResponseWriter, r *http.Request) { 887 ctx := r.Context() 888 s := fmt.Sprintf("/suggestions reqid:%s session:%s", ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"})) 889 w.Write([]byte(s)) 890 }) 891 892 r.Get("/woot/{wootID}/*", func(w http.ResponseWriter, r *http.Request) { 893 s := fmt.Sprintf("/woot/%s/%s", URLParam(r, "wootID"), URLParam(r, "*")) 894 w.Write([]byte(s)) 895 }) 896 897 r.Route("/hubs", func(r Router) { 898 _ = r.(*Mux) // sr1 899 r.Route("/{hubID}", func(r Router) { 900 _ = r.(*Mux) // sr2 901 r.Get("/", func(w http.ResponseWriter, r *http.Request) { 902 ctx := r.Context() 903 s := fmt.Sprintf("/hubs/%s reqid:%s session:%s", 904 URLParam(r, "hubID"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"})) 905 w.Write([]byte(s)) 906 }) 907 r.Get("/touch", func(w http.ResponseWriter, r *http.Request) { 908 ctx := r.Context() 909 s := fmt.Sprintf("/hubs/%s/touch reqid:%s session:%s", URLParam(r, "hubID"), 910 ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"})) 911 w.Write([]byte(s)) 912 }) 913 914 sr3 = NewRouter() 915 sr3.Get("/", func(w http.ResponseWriter, r *http.Request) { 916 ctx := r.Context() 917 s := fmt.Sprintf("/hubs/%s/webhooks reqid:%s session:%s", URLParam(r, "hubID"), 918 ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"})) 919 w.Write([]byte(s)) 920 }) 921 
sr3.Route("/{webhookID}", func(r Router) {
					_ = r.(*Mux) // sr4
					r.Get("/", func(w http.ResponseWriter, r *http.Request) {
						ctx := r.Context()
						s := fmt.Sprintf("/hubs/%s/webhooks/%s reqid:%s session:%s", URLParam(r, "hubID"),
							URLParam(r, "webhookID"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
						w.Write([]byte(s))
					})
				})

				r.Mount("/webhooks", Chain(func(next http.Handler) http.Handler {
					return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
						next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ctxKey{"hook"}, true)))
					})
				}).Handler(sr3))

				r.Route("/posts", func(r Router) {
					_ = r.(*Mux) // sr5
					r.Get("/", func(w http.ResponseWriter, r *http.Request) {
						ctx := r.Context()
						s := fmt.Sprintf("/hubs/%s/posts reqid:%s session:%s", URLParam(r, "hubID"),
							ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
						w.Write([]byte(s))
					})
				})
			})
		})

		r.Route("/folders/", func(r Router) {
			_ = r.(*Mux) // sr6
			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
				ctx := r.Context()
				s := fmt.Sprintf("/folders/ reqid:%s session:%s",
					ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
				w.Write([]byte(s))
			})
			r.Get("/public", func(w http.ResponseWriter, r *http.Request) {
				ctx := r.Context()
				s := fmt.Sprintf("/folders/public reqid:%s session:%s",
					ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
				w.Write([]byte(s))
			})
		})
	})

	return r
}

// TestMuxSubroutesBasic checks that routes registered through nested Route()
// sub-routers resolve for the index path, a static child path and a
// parameterised child path (including a static path below the parameter).
func TestMuxSubroutesBasic(t *testing.T) {
	hIndex := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("index"))
	})
	hArticlesList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("articles-list"))
	})
	hSearchArticles := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("search-articles"))
	})
	hGetArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(fmt.Sprintf("get-article:%s", URLParam(r, "id"))))
	})
	hSyncArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(fmt.Sprintf("sync-article:%s", URLParam(r, "id"))))
	})

	r := NewRouter()
	// var rr1, rr2 *Mux
	r.Get("/", hIndex)
	r.Route("/articles", func(r Router) {
		// rr1 = r.(*Mux)
		r.Get("/", hArticlesList)
		r.Get("/search", hSearchArticles)
		r.Route("/{id}", func(r Router) {
			// rr2 = r.(*Mux)
			r.Get("/", hGetArticle)
			r.Get("/sync", hSyncArticle)
		})
	})

	// log.Println("~~~~~~~~~")
	// log.Println("~~~~~~~~~")
	// debugPrintTree(0, 0, r.tree, 0)
	// log.Println("~~~~~~~~~")
	// log.Println("~~~~~~~~~")

	// log.Println("~~~~~~~~~")
	// log.Println("~~~~~~~~~")
	// debugPrintTree(0, 0, rr1.tree, 0)
	// log.Println("~~~~~~~~~")
	// log.Println("~~~~~~~~~")

	// log.Println("~~~~~~~~~")
	// log.Println("~~~~~~~~~")
	// debugPrintTree(0, 0, rr2.tree, 0)
	// log.Println("~~~~~~~~~")
	// log.Println("~~~~~~~~~")

	ts := httptest.NewServer(r)
	defer ts.Close()

	var body, expected string

	_, body = testRequest(t, ts, "GET", "/", nil)
	expected = "index"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	_, body = testRequest(t, ts, "GET", "/articles", nil)
	expected = "articles-list"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	_, body = testRequest(t, ts, "GET", "/articles/search", nil)
	expected = "search-articles"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	_, body = testRequest(t, ts, "GET", "/articles/123", nil)
	expected = "get-article:123"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	_, body = testRequest(t, ts, "GET", "/articles/123/sync", nil)
	expected = "sync-article:123"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
}

// TestMuxSubroutes exercises Mount() and Route() under parameterised
// prefixes. It checks that an explicit route can override a mounted
// sub-router's trailing-slash path. It also checks that RoutePatterns
// accumulates one pattern per routing level for a doubly-mounted handler.
func TestMuxSubroutes(t *testing.T) {
	hHubView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hub1"))
	})
	hHubView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hub2"))
	})
	hHubView3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hub3"))
	})
	hAccountView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("account1"))
	})
	hAccountView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("account2"))
	})

	r := NewRouter()
	r.Get("/hubs/{hubID}/view", hHubView1)
	r.Get("/hubs/{hubID}/view/*", hHubView2)

	sr := NewRouter()
	sr.Get("/", hHubView3)
	r.Mount("/hubs/{hubID}/users", sr)
	r.Get("/hubs/{hubID}/users/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hub3 override"))
	})

	sr3 := NewRouter()
	sr3.Get("/", hAccountView1)
	sr3.Get("/hi", hAccountView2)

	// var sr2 *Mux
	r.Route("/accounts/{accountID}", func(r Router) {
		_ = r.(*Mux) // sr2
		// r.Get("/", hAccountView1)
		r.Mount("/", sr3)
	})

	// This is the same as the r.Route() call mounted on sr2
	// sr2 := NewRouter()
	// sr2.Mount("/", sr3)
	// r.Mount("/accounts/{accountID}", sr2)

	ts := httptest.NewServer(r)
	defer ts.Close()

	var body, expected string

	_, body = testRequest(t, ts, "GET", "/hubs/123/view", nil)
	expected = "hub1"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	_, body = testRequest(t, ts, "GET", "/hubs/123/view/index.html", nil)
	expected = "hub2"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	_, body = testRequest(t, ts, "GET", "/hubs/123/users", nil)
	expected = "hub3"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	_, body = testRequest(t, ts, "GET", "/hubs/123/users/", nil)
	expected = "hub3 override"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
	_, body = testRequest(t, ts, "GET", "/accounts/44", nil)
	expected = "account1"
	if body != expected {
		t.Fatalf("request:%s expected:%s got:%s", "GET /accounts/44", expected, body)
	}
	_, body = testRequest(t, ts, "GET", "/accounts/44/hi", nil)
	expected = "account2"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}

	// Test that we're building the routingPatterns properly
	router := r
	req, _ := http.NewRequest("GET", "/accounts/44/hi", nil)

	rctx := NewRouteContext()
	req = req.WithContext(context.WithValue(req.Context(), RouteCtxKey, rctx))

	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	body = w.Body.String()
	expected = "account2"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}

	routePatterns := rctx.RoutePatterns
	if len(routePatterns) != 3 {
		t.Fatalf("expected 3 routing patterns, got:%d", len(routePatterns))
	}
	expected = "/accounts/{accountID}/*"
	if routePatterns[0] != expected {
		t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[0])
	}
	expected = "/*"
	if routePatterns[1] != expected {
		t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[1])
	}
	expected = "/hi"
	if routePatterns[2] != expected {
		t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[2])
	}
}

// TestSingleHandler checks that URLParam reads values from a route context
// that was attached to the request by hand, without going through a Mux.
func TestSingleHandler(t *testing.T) {
	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		name := URLParam(r, "name")
		w.Write([]byte("hi " + name))
	})

	r, _ := http.NewRequest("GET", "/", nil)
	rctx := NewRouteContext()
	r = r.WithContext(context.WithValue(r.Context(), RouteCtxKey, rctx))
	rctx.URLParams.Add("name", "joe")

	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)

	body := w.Body.String()
	expected := "hi joe"
	if body != expected {
		t.Fatalf("expected:%s got:%s", expected, body)
	}
}

// TODO: a Router wrapper test..
//
// type ACLMux struct {
// 	*Mux
// 	XX string
// }
//
// func NewACLMux() *ACLMux {
// 	return &ACLMux{Mux: NewRouter(), XX: "hihi"}
// }
//
// // TODO: this should be supported...
// func TestWoot(t *testing.T) {
// 	var r Router = NewRouter()
//
// 	var r2 Router = NewACLMux() //NewRouter()
// 	r2.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
// 		w.Write([]byte("hi"))
// 	})
//
// 	r.Mount("/", r2)
// }

// TestServeHTTPExistingContext checks that values already stored in the
// incoming request's context are visible to both a matched handler and the
// NotFound handler.
func TestServeHTTPExistingContext(t *testing.T) {
	r := NewRouter()
	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
		w.Write([]byte(s))
	})
	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
		w.WriteHeader(404)
		w.Write([]byte(s))
	})

	testcases := []struct {
		Method         string
		Path           string
		Ctx            context.Context
		ExpectedStatus int
		ExpectedBody   string
	}{
		{
			Method:         "GET",
			Path:           "/hi",
			Ctx:            context.WithValue(context.Background(), ctxKey{"testCtx"}, "hi ctx"),
			ExpectedStatus: 200,
			ExpectedBody:   "hi ctx",
		},
		{
			Method:         "GET",
			Path:           "/hello",
			Ctx:            context.WithValue(context.Background(), ctxKey{"testCtx"}, "nothing here ctx"),
			ExpectedStatus: 404,
			ExpectedBody:   "nothing here ctx",
		},
	}

	for _, tc := range testcases {
		resp := httptest.NewRecorder()
		req, err := http.NewRequest(tc.Method, tc.Path, nil)
		if err != nil {
			t.Fatalf("%v", err)
		}
		req = req.WithContext(tc.Ctx)
		r.ServeHTTP(resp, req)
		b, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			t.Fatalf("%v", err)
		}
		if resp.Code != tc.ExpectedStatus {
			t.Fatalf("%v != %v", tc.ExpectedStatus, resp.Code)
		}
		if string(b) != tc.ExpectedBody {
			t.Fatalf("%s != %s", tc.ExpectedBody, b)
		}
	}
}

// TestNestedGroups checks how middleware stacks compose across Group, Route
// and With. Each handler prints how many times mwIncreaseCounter ran for it.
func TestNestedGroups(t *testing.T) {
	handlerPrintCounter := func(w http.ResponseWriter, r *http.Request) {
		counter, _ := r.Context().Value(ctxKey{"counter"}).(int)
		w.Write([]byte(fmt.Sprintf("%v", counter)))
	}

	mwIncreaseCounter := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			counter, _ := ctx.Value(ctxKey{"counter"}).(int)
			counter++
			ctx = context.WithValue(ctx, ctxKey{"counter"}, counter)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}

	// Each route represents the value of its counter (number of applied middlewares).
	r := NewRouter() // counter == 0
	r.Get("/0", handlerPrintCounter)
	r.Group(func(r Router) {
		r.Use(mwIncreaseCounter) // counter == 1
		r.Get("/1", handlerPrintCounter)

		// r.Handle(GET, "/2", Chain(mwIncreaseCounter).HandlerFunc(handlerPrintCounter))
		r.With(mwIncreaseCounter).Get("/2", handlerPrintCounter)

		r.Group(func(r Router) {
			r.Use(mwIncreaseCounter, mwIncreaseCounter) // counter == 3
			r.Get("/3", handlerPrintCounter)
		})
		r.Route("/", func(r Router) {
			r.Use(mwIncreaseCounter, mwIncreaseCounter) // counter == 3

			// r.Handle(GET, "/4", Chain(mwIncreaseCounter).HandlerFunc(handlerPrintCounter))
			r.With(mwIncreaseCounter).Get("/4", handlerPrintCounter)

			r.Group(func(r Router) {
				r.Use(mwIncreaseCounter, mwIncreaseCounter) // counter == 5
				r.Get("/5", handlerPrintCounter)
				// r.Handle(GET, "/6", Chain(mwIncreaseCounter).HandlerFunc(handlerPrintCounter))
				r.With(mwIncreaseCounter).Get("/6", handlerPrintCounter)

			})
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	for _, route := range []string{"0", "1", "2", "3", "4", "5", "6"} {
		if _, body := testRequest(t, ts, "GET", "/"+route, nil); body != route {
			t.Errorf("expected %v, got %v", route, body)
		}
	}
}

// TestMiddlewarePanicOnLateUse expects Use() to panic when it is called
// after a route has already been registered on the same Mux.
func TestMiddlewarePanicOnLateUse(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hello\n"))
	}

	mw := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(w, r)
		})
	}

	defer func() {
		if recover() == nil {
			t.Error("expected panic()")
		}
	}()

	r := NewRouter()
	r.Get("/", handler)
	r.Use(mw) // Too late to apply middleware, we're expecting panic().
}

// TestMountingExistingPath expects a second Mount() on the same pattern to panic.
func TestMountingExistingPath(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {}

	defer func() {
		if recover() == nil {
			t.Error("expected panic()")
		}
	}()

	r := NewRouter()
	r.Get("/", handler)
	r.Mount("/hi", http.HandlerFunc(handler))
	r.Mount("/hi", http.HandlerFunc(handler))
}

// TestMountingSimilarPattern checks that mounting "/foobar" and then "/foo"
// (where one pattern is a prefix of the other) leaves an unrelated route working.
func TestMountingSimilarPattern(t *testing.T) {
	r := NewRouter()
	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("bye"))
	})

	r2 := NewRouter()
	r2.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("foobar"))
	})

	r3 := NewRouter()
	r3.Get("/", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("foo"))
	})

	r.Mount("/foobar", r2)
	r.Mount("/foo", r3)

	ts := httptest.NewServer(r)
	defer ts.Close()

	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
		// t.Fatal, not t.Fatalf: body is data, not a format string.
		t.Fatal(body)
	}
}

// TestMuxEmptyParams checks that empty path segments match URL parameters
// and yield empty strings.
func TestMuxEmptyParams(t *testing.T) {
	r := NewRouter()
	r.Get(`/users/{x}/{y}/{z}`, func(w http.ResponseWriter, r *http.Request) {
		x := URLParam(r, "x")
		y := URLParam(r, "y")
		z := URLParam(r, "z")
		w.Write([]byte(fmt.Sprintf("%s-%s-%s", x, y, z)))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	if _, body := testRequest(t, ts, "GET", "/users/a/b/c", nil); body != "a-b-c" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/users///c", nil); body != "--c" {
		t.Fatal(body)
	}
}

// TestMuxMissingParams checks that a regexp-constrained parameter does not
// match an empty segment, so the request falls through to NotFound.
func TestMuxMissingParams(t *testing.T) {
	r := NewRouter()
	r.Get(`/user/{userId:\d+}`, func(w http.ResponseWriter, r *http.Request) {
		userID := URLParam(r, "userId")
		w.Write([]byte(fmt.Sprintf("userId = '%s'", userID)))
	})
	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(404)
		w.Write([]byte("nothing here"))
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	if _, body := testRequest(t, ts, "GET", "/user/123", nil); body != "userId = '123'" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/user/", nil); body != "nothing here" {
		t.Fatal(body)
	}
}

// TestMuxWildcardRoute expects registering a pattern with "*" before the
// end of the path to panic.
func TestMuxWildcardRoute(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {}

	defer func() {
		if recover() == nil {
			t.Error("expected panic()")
		}
	}()

	r := NewRouter()
	r.Get("/*/wildcard/must/be/at/end", handler)
}

// TestMuxWildcardRouteCheckTwo is TestMuxWildcardRoute with an additional
// URL parameter after the misplaced wildcard.
func TestMuxWildcardRouteCheckTwo(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {}

	defer func() {
		if recover() == nil {
			t.Error("expected panic()")
		}
	}()

	r := NewRouter()
	r.Get("/*/wildcard/{must}/be/at/end", handler)
}

// TestMuxRegexp checks that a regexp parameter allowing zero characters
// ([0-9]*) matches an empty segment inside a Route() prefix.
func TestMuxRegexp(t *testing.T) {
	r := NewRouter()
	r.Route("/{param:[0-9]*}/test", func(r Router) {
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte(fmt.Sprintf("Hi: %s", URLParam(r, "param"))))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	if _, body := testRequest(t, ts, "GET", "//test", nil); body != "Hi: " {
		t.Fatal(body)
	}
}

// TestMuxRegexp2 checks a regexp parameter that contains its own braces
// ({2,3}) and is embedded between static text in one segment.
func TestMuxRegexp2(t *testing.T) {
	r := NewRouter()
	r.Get("/foo-{suffix:[a-z]{2,3}}.json", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(URLParam(r, "suffix")))
	})
	ts := httptest.NewServer(r)
	defer ts.Close()

	if _, body := testRequest(t, ts, "GET", "/foo-.json", nil); body != "" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/foo-abc.json", nil); body != "abc" {
		t.Fatal(body)
	}
}

// TestMuxRegexp3 checks dispatch between overlapping regexp routes that
// differ by parameter constraint and HTTP method, both at the top level and
// inside a Route() sub-router.
func TestMuxRegexp3(t *testing.T) {
	r := NewRouter()
	r.Get("/one/{firstId:[a-z0-9-]+}/{secondId:[a-z]+}/first", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("first"))
	})
	r.Get("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("second"))
	})
	r.Delete("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("third"))
	})

	r.Route("/one", func(r Router) {
		r.Get("/{dns:[a-z-0-9_]+}", func(writer http.ResponseWriter, request *http.Request) {
			writer.Write([]byte("_"))
		})
		r.Get("/{dns:[a-z-0-9_]+}/info", func(writer http.ResponseWriter, request *http.Request) {
			writer.Write([]byte("_"))
		})
		r.Delete("/{id:[0-9]+}", func(writer http.ResponseWriter, request *http.Request) {
			writer.Write([]byte("forth"))
		})
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	if _, body := testRequest(t, ts, "GET", "/one/hello/peter/first", nil); body != "first" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/one/hithere/123/second", nil); body != "second" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "DELETE", "/one/hithere/123/second", nil); body != "third" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "DELETE", "/one/123", nil); body != "forth" {
		t.Fatal(body)
	}
}

// TestMuxSubrouterWildcardParam checks that "{param}" and "{param}/*" routes
// populate both the named parameter and the "*" wildcard. It covers routes
// registered directly and routes registered in a Route() sub-router.
func TestMuxSubrouterWildcardParam(t *testing.T) {
	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "param:%v *:%v", URLParam(r, "param"), URLParam(r, "*"))
	})

	r := NewRouter()

	r.Get("/bare/{param}", h)
	r.Get("/bare/{param}/*", h)

	r.Route("/case0", func(r Router) {
		r.Get("/{param}", h)
		r.Get("/{param}/*", h)
	})

	ts := httptest.NewServer(r)
	defer ts.Close()

	if _, body := testRequest(t, ts, "GET", "/bare/hi", nil); body != "param:hi *:" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/bare/hi/yes", nil); body != "param:hi *:yes" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/case0/hi", nil); body != "param:hi *:" {
		t.Fatal(body)
	}
	if _, body := testRequest(t, ts, "GET", "/case0/hi/yes", nil); body != "param:hi *:yes" {
		t.Fatal(body)
	}
}

// TestMuxContextIsThreadSafe serves many requests concurrently and cancels
// each request's context from another goroutine while the handler is
// running. It is intended to surface races when run with -race.
func TestMuxContextIsThreadSafe(t *testing.T) {
	router := NewRouter()
	router.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
		ctx, cancel := context.WithTimeout(r.Context(), 1*time.Millisecond)
		defer cancel()

		<-ctx.Done()
	})

	wg := sync.WaitGroup{}

	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 10000; j++ {
				w := httptest.NewRecorder()
				r, err := http.NewRequest("GET", "/ok", nil)
				if err != nil {
					// t.Fatal (FailNow) may only be called from the test
					// goroutine; report the failure and stop this worker.
					t.Error(err)
					return
				}

				ctx, cancel := context.WithCancel(r.Context())
				r = r.WithContext(ctx)

				go func() {
					cancel()
				}()
				router.ServeHTTP(w, r)
			}
		}()
	}
	wg.Wait()
}

// TestEscapedURLParams checks that percent-encoded slashes ("%2f") inside a
// parameter stay in their escaped form. The parameter must not be split
// across segments, and the later parameters must still parse.
func TestEscapedURLParams(t *testing.T) {
	m := NewRouter()
	m.Get("/api/{identifier}/{region}/{size}/{rotation}/*", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		rctx := RouteContext(r.Context())
		if rctx == nil {
			t.Error("no context")
			return
		}
		identifier := URLParam(r, "identifier")
		if identifier != "http:%2f%2fexample.com%2fimage.png" {
			t.Errorf("identifier path parameter incorrect %s", identifier)
			return
		}
		region := URLParam(r, "region")
		if region != "full" {
			t.Errorf("region path parameter incorrect %s", region)
			return
		}
		size := URLParam(r, "size")
		if size != "max" {
			t.Errorf("size path parameter incorrect %s", size)
			return
		}
		rotation := URLParam(r, "rotation")
		if rotation != "0" {
			t.Errorf("rotation path parameter incorrect %s", rotation)
			return
		}
		w.Write([]byte("success"))
	})

	ts := httptest.NewServer(m)
	defer ts.Close()

	if _, body := testRequest(t, ts, "GET", "/api/http:%2f%2fexample.com%2fimage.png/full/max/0/color.png", nil); body != "success" {
		t.Fatal(body)
	}
}

// TestMuxMatch checks Mux.Match for a method/path pair that is registered and
// for one whose path exists but whose method (HEAD) is not registered there.
func TestMuxMatch(t *testing.T) {
	r := NewRouter()
	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Test", "yes")
		w.Write([]byte("bye"))
	})
	r.Route("/articles", func(r Router) {
		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
			id := URLParam(r, "id")
			w.Header().Set("X-Article", id)
			w.Write([]byte("article:" + id))
		})
	})
	r.Route("/users", func(r Router) {
		r.Head("/{id}", func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("X-User", "-")
			w.Write([]byte("user"))
		})
		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
			id := URLParam(r, "id")
			w.Header().Set("X-User", id)
			w.Write([]byte("user:" + id))
		})
	})

	tctx := NewRouteContext()

	tctx.Reset()
	if r.Match(tctx, "GET", "/users/1") == false {
		t.Fatal("expecting to find match for route:", "GET", "/users/1")
	}

	tctx.Reset()
	if r.Match(tctx, "HEAD", "/articles/10") == true {
		t.Fatal("not expecting to find match for route:", "HEAD", "/articles/10")
	}
}

// TestServerBaseContext checks that values from http.Server.BaseContext, and
// the server/local-address values that net/http injects, reach the handler.
func TestServerBaseContext(t *testing.T) {
	r := NewRouter()
	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
		baseYes := r.Context().Value(ctxKey{"base"}).(string)
		if _, ok := r.Context().Value(http.ServerContextKey).(*http.Server); !ok {
			panic("missing server context")
		}
		if _, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); !ok {
			panic("missing local addr context")
		}
		w.Write([]byte(baseYes))
	})

	// Setup http Server with a base context
	ctx := context.WithValue(context.Background(), ctxKey{"base"}, "yes")
	ts := httptest.NewUnstartedServer(r)
	ts.Config.BaseContext = func(_ net.Listener) context.Context {
		return ctx
	}
	ts.Start()

	defer ts.Close()

	if _, body := testRequest(t, ts, "GET", "/", nil); body != "yes" {
		t.Fatal(body)
	}
}

// testRequest sends method+path to ts using http.DefaultClient. It returns
// the response and its fully-read body as a string. Any request-building,
// transport or read error fails the calling test via t.Fatal.
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
	t.Helper()

	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
		return nil, ""
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
		return nil, ""
	}
	// Register Close immediately so the body is released even if the read
	// below fails (t.Fatal unwinds via runtime.Goexit, which runs defers).
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
		return nil, ""
	}

	return resp, string(respBody)
}

// testHandler runs h in-process against an httptest.ResponseRecorder. It
// returns the recorded response and body without starting a server. A
// malformed method/path fails the test instead of passing a nil request to h.
func testHandler(t *testing.T, h http.Handler, method, path string, body io.Reader) (*http.Response, string) {
	r, err := http.NewRequest(method, path, body)
	if err != nil {
		t.Fatal(err)
		return nil, ""
	}
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	return w.Result(), w.Body.String()
}

// testFileSystem is an http.FileSystem whose Open is delegated to a
// caller-supplied function.
type testFileSystem struct {
	open func(name string) (http.File, error)
}

// Open calls the configured open function.
func (fs *testFileSystem) Open(name string) (http.File, error) {
	return fs.open(name)
}

// testFile is an in-memory stub of http.File.
type testFile struct {
	name     string
	contents []byte
}

// Close is a no-op.
func (tf *testFile) Close() error {
	return nil
}

// Read copies contents into p but always reports len(p) bytes read and never
// returns io.EOF. It is not a conforming io.Reader; callers must bound their
// reads (e.g. by the Stat size).
func (tf *testFile) Read(p []byte) (n int, err error) {
	copy(p, tf.contents)
	return len(p), nil
}

// Seek is a no-op that always reports offset 0.
func (tf *testFile) Seek(offset int64, whence int) (int64, error) {
	return 0, nil
}

// Readdir returns a single entry describing this file, ignoring count.
func (tf *testFile) Readdir(count int) ([]os.FileInfo, error) {
	stat, _ := tf.Stat()
	return []os.FileInfo{stat}, nil
}

// Stat reports the file's name and the length of its contents.
func (tf *testFile) Stat() (os.FileInfo, error) {
	return &testFileInfo{tf.name, int64(len(tf.contents))}, nil
}

// testFileInfo is a minimal os.FileInfo for a regular (non-directory) file.
type testFileInfo struct {
	name string
	size int64
}

// os.FileInfo implementation; ModTime reports the current time on each call.
func (tfi *testFileInfo) Name() string       { return tfi.name }
func (tfi *testFileInfo) Size() int64        { return tfi.size }
func (tfi *testFileInfo) Mode() os.FileMode  { return 0755 }
func (tfi *testFileInfo) ModTime() time.Time { return time.Now() }
func (tfi *testFileInfo) IsDir() bool        { return false }
func (tfi *testFileInfo) Sys() interface{}   { return nil }

// ctxKey is the context key type used by these tests. Wrapping the name in a
// struct keeps the keys distinct from plain string keys.
type ctxKey struct {
	name string
}

// String describes the key, e.g. for debugging output.
func (k ctxKey) String() string {
	return "context value " + k.name
}

// BenchmarkMux measures ServeHTTP for static, parameterised and
// sub-router routes. Each path runs as its own sub-benchmark.
func BenchmarkMux(b *testing.B) {
	h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	h3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	h4 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	h5 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	h6 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

	mx := NewRouter()
	mx.Get("/", h1)
	mx.Get("/hi", h2)
	mx.Get("/sup/{id}/and/{this}", h3)
	mx.Get("/sup/{id}/{bar:foo}/{this}", h3)

	mx.Route("/sharing/{x}/{hash}", func(mx Router) {
		mx.Get("/", h4)          // subrouter-1
		mx.Get("/{network}", h5) // subrouter-1
		mx.Get("/twitter", h5)
		mx.Route("/direct", func(mx Router) {
			mx.Get("/", h6) // subrouter-2
			mx.Get("/download", h6)
		})
	})

	routes := []string{
		"/",
		"/hi",
		"/sup/123/and/this",
		"/sup/123/foo/this",
		"/sharing/z/aBc",                 // subrouter-1
		"/sharing/z/aBc/twitter",         // subrouter-1
		"/sharing/z/aBc/direct",          // subrouter-2
		"/sharing/z/aBc/direct/download", // subrouter-2
	}

	for _, path := range routes {
		b.Run("route:"+path, func(b *testing.B) {
			w := httptest.NewRecorder()
			r, _ := http.NewRequest("GET", path, nil)

			b.ReportAllocs()
			b.ResetTimer()

			for i := 0; i < b.N; i++ {
				mx.ServeHTTP(w, r)
			}
		})
	}
}