github.com/go-chi/chi@v1.5.5/mux_test.go (about)

     1  package chi
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"os"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  )
    17  
    18  func TestMuxBasic(t *testing.T) {
    19  	var count uint64
    20  	countermw := func(next http.Handler) http.Handler {
    21  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    22  			count++
    23  			next.ServeHTTP(w, r)
    24  		})
    25  	}
    26  
    27  	usermw := func(next http.Handler) http.Handler {
    28  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    29  			ctx := r.Context()
    30  			ctx = context.WithValue(ctx, ctxKey{"user"}, "peter")
    31  			r = r.WithContext(ctx)
    32  			next.ServeHTTP(w, r)
    33  		})
    34  	}
    35  
    36  	exmw := func(next http.Handler) http.Handler {
    37  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    38  			ctx := context.WithValue(r.Context(), ctxKey{"ex"}, "a")
    39  			r = r.WithContext(ctx)
    40  			next.ServeHTTP(w, r)
    41  		})
    42  	}
    43  
    44  	logbuf := bytes.NewBufferString("")
    45  	logmsg := "logmw test"
    46  	logmw := func(next http.Handler) http.Handler {
    47  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    48  			logbuf.WriteString(logmsg)
    49  			next.ServeHTTP(w, r)
    50  		})
    51  	}
    52  
    53  	cxindex := func(w http.ResponseWriter, r *http.Request) {
    54  		ctx := r.Context()
    55  		user := ctx.Value(ctxKey{"user"}).(string)
    56  		w.WriteHeader(200)
    57  		w.Write([]byte(fmt.Sprintf("hi %s", user)))
    58  	}
    59  
    60  	ping := func(w http.ResponseWriter, r *http.Request) {
    61  		w.WriteHeader(200)
    62  		w.Write([]byte("."))
    63  	}
    64  
    65  	headPing := func(w http.ResponseWriter, r *http.Request) {
    66  		w.Header().Set("X-Ping", "1")
    67  		w.WriteHeader(200)
    68  	}
    69  
    70  	createPing := func(w http.ResponseWriter, r *http.Request) {
    71  		// create ....
    72  		w.WriteHeader(201)
    73  	}
    74  
    75  	pingAll := func(w http.ResponseWriter, r *http.Request) {
    76  		w.WriteHeader(200)
    77  		w.Write([]byte("ping all"))
    78  	}
    79  
    80  	pingAll2 := func(w http.ResponseWriter, r *http.Request) {
    81  		w.WriteHeader(200)
    82  		w.Write([]byte("ping all2"))
    83  	}
    84  
    85  	pingOne := func(w http.ResponseWriter, r *http.Request) {
    86  		idParam := URLParam(r, "id")
    87  		w.WriteHeader(200)
    88  		w.Write([]byte(fmt.Sprintf("ping one id: %s", idParam)))
    89  	}
    90  
    91  	pingWoop := func(w http.ResponseWriter, r *http.Request) {
    92  		w.WriteHeader(200)
    93  		w.Write([]byte("woop." + URLParam(r, "iidd")))
    94  	}
    95  
    96  	catchAll := func(w http.ResponseWriter, r *http.Request) {
    97  		w.WriteHeader(200)
    98  		w.Write([]byte("catchall"))
    99  	}
   100  
   101  	m := NewRouter()
   102  	m.Use(countermw)
   103  	m.Use(usermw)
   104  	m.Use(exmw)
   105  	m.Use(logmw)
   106  	m.Get("/", cxindex)
   107  	m.Method("GET", "/ping", http.HandlerFunc(ping))
   108  	m.MethodFunc("GET", "/pingall", pingAll)
   109  	m.MethodFunc("get", "/ping/all", pingAll)
   110  	m.Get("/ping/all2", pingAll2)
   111  
   112  	m.Head("/ping", headPing)
   113  	m.Post("/ping", createPing)
   114  	m.Get("/ping/{id}", pingWoop)
   115  	m.Get("/ping/{id}", pingOne) // expected to overwrite to pingOne handler
   116  	m.Get("/ping/{iidd}/woop", pingWoop)
   117  	m.HandleFunc("/admin/*", catchAll)
   118  	// m.Post("/admin/*", catchAll)
   119  
   120  	ts := httptest.NewServer(m)
   121  	defer ts.Close()
   122  
   123  	// GET /
   124  	if _, body := testRequest(t, ts, "GET", "/", nil); body != "hi peter" {
   125  		t.Fatalf(body)
   126  	}
   127  	tlogmsg, _ := logbuf.ReadString(0)
   128  	if tlogmsg != logmsg {
   129  		t.Error("expecting log message from middleware:", logmsg)
   130  	}
   131  
   132  	// GET /ping
   133  	if _, body := testRequest(t, ts, "GET", "/ping", nil); body != "." {
   134  		t.Fatalf(body)
   135  	}
   136  
   137  	// GET /pingall
   138  	if _, body := testRequest(t, ts, "GET", "/pingall", nil); body != "ping all" {
   139  		t.Fatalf(body)
   140  	}
   141  
   142  	// GET /ping/all
   143  	if _, body := testRequest(t, ts, "GET", "/ping/all", nil); body != "ping all" {
   144  		t.Fatalf(body)
   145  	}
   146  
   147  	// GET /ping/all2
   148  	if _, body := testRequest(t, ts, "GET", "/ping/all2", nil); body != "ping all2" {
   149  		t.Fatalf(body)
   150  	}
   151  
   152  	// GET /ping/123
   153  	if _, body := testRequest(t, ts, "GET", "/ping/123", nil); body != "ping one id: 123" {
   154  		t.Fatalf(body)
   155  	}
   156  
   157  	// GET /ping/allan
   158  	if _, body := testRequest(t, ts, "GET", "/ping/allan", nil); body != "ping one id: allan" {
   159  		t.Fatalf(body)
   160  	}
   161  
   162  	// GET /ping/1/woop
   163  	if _, body := testRequest(t, ts, "GET", "/ping/1/woop", nil); body != "woop.1" {
   164  		t.Fatalf(body)
   165  	}
   166  
   167  	// HEAD /ping
   168  	resp, err := http.Head(ts.URL + "/ping")
   169  	if err != nil {
   170  		t.Fatal(err)
   171  	}
   172  	if resp.StatusCode != 200 {
   173  		t.Error("head failed, should be 200")
   174  	}
   175  	if resp.Header.Get("X-Ping") == "" {
   176  		t.Error("expecting X-Ping header")
   177  	}
   178  
   179  	// GET /admin/catch-this
   180  	if _, body := testRequest(t, ts, "GET", "/admin/catch-thazzzzz", nil); body != "catchall" {
   181  		t.Fatalf(body)
   182  	}
   183  
   184  	// POST /admin/catch-this
   185  	resp, err = http.Post(ts.URL+"/admin/casdfsadfs", "text/plain", bytes.NewReader([]byte{}))
   186  	if err != nil {
   187  		t.Fatal(err)
   188  	}
   189  
   190  	body, err := ioutil.ReadAll(resp.Body)
   191  	if err != nil {
   192  		t.Fatal(err)
   193  	}
   194  	defer resp.Body.Close()
   195  
   196  	if resp.StatusCode != 200 {
   197  		t.Error("POST failed, should be 200")
   198  	}
   199  
   200  	if string(body) != "catchall" {
   201  		t.Error("expecting response body: 'catchall'")
   202  	}
   203  
   204  	// Custom http method DIE /ping/1/woop
   205  	if resp, body := testRequest(t, ts, "DIE", "/ping/1/woop", nil); body != "" || resp.StatusCode != 405 {
   206  		t.Fatalf(fmt.Sprintf("expecting 405 status and empty body, got %d '%s'", resp.StatusCode, body))
   207  	}
   208  }
   209  
   210  func TestMuxMounts(t *testing.T) {
   211  	r := NewRouter()
   212  
   213  	r.Get("/{hash}", func(w http.ResponseWriter, r *http.Request) {
   214  		v := URLParam(r, "hash")
   215  		w.Write([]byte(fmt.Sprintf("/%s", v)))
   216  	})
   217  
   218  	r.Route("/{hash}/share", func(r Router) {
   219  		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
   220  			v := URLParam(r, "hash")
   221  			w.Write([]byte(fmt.Sprintf("/%s/share", v)))
   222  		})
   223  		r.Get("/{network}", func(w http.ResponseWriter, r *http.Request) {
   224  			v := URLParam(r, "hash")
   225  			n := URLParam(r, "network")
   226  			w.Write([]byte(fmt.Sprintf("/%s/share/%s", v, n)))
   227  		})
   228  	})
   229  
   230  	m := NewRouter()
   231  	m.Mount("/sharing", r)
   232  
   233  	ts := httptest.NewServer(m)
   234  	defer ts.Close()
   235  
   236  	if _, body := testRequest(t, ts, "GET", "/sharing/aBc", nil); body != "/aBc" {
   237  		t.Fatalf(body)
   238  	}
   239  	if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share", nil); body != "/aBc/share" {
   240  		t.Fatalf(body)
   241  	}
   242  	if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share/twitter", nil); body != "/aBc/share/twitter" {
   243  		t.Fatalf(body)
   244  	}
   245  }
   246  
   247  func TestMuxPlain(t *testing.T) {
   248  	r := NewRouter()
   249  	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
   250  		w.Write([]byte("bye"))
   251  	})
   252  	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
   253  		w.WriteHeader(404)
   254  		w.Write([]byte("nothing here"))
   255  	})
   256  
   257  	ts := httptest.NewServer(r)
   258  	defer ts.Close()
   259  
   260  	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
   261  		t.Fatalf(body)
   262  	}
   263  	if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" {
   264  		t.Fatalf(body)
   265  	}
   266  }
   267  
   268  func TestMuxEmptyRoutes(t *testing.T) {
   269  	mux := NewRouter()
   270  
   271  	apiRouter := NewRouter()
   272  	// oops, we forgot to declare any route handlers
   273  
   274  	mux.Handle("/api*", apiRouter)
   275  
   276  	if _, body := testHandler(t, mux, "GET", "/", nil); body != "404 page not found\n" {
   277  		t.Fatalf(body)
   278  	}
   279  
   280  	if _, body := testHandler(t, apiRouter, "GET", "/", nil); body != "404 page not found\n" {
   281  		t.Fatalf(body)
   282  	}
   283  }
   284  
   285  // Test a mux that routes a trailing slash, see also middleware/strip_test.go
   286  // for an example of using a middleware to handle trailing slashes.
   287  func TestMuxTrailingSlash(t *testing.T) {
   288  	r := NewRouter()
   289  	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
   290  		w.WriteHeader(404)
   291  		w.Write([]byte("nothing here"))
   292  	})
   293  
   294  	subRoutes := NewRouter()
   295  	indexHandler := func(w http.ResponseWriter, r *http.Request) {
   296  		accountID := URLParam(r, "accountID")
   297  		w.Write([]byte(accountID))
   298  	}
   299  	subRoutes.Get("/", indexHandler)
   300  
   301  	r.Mount("/accounts/{accountID}", subRoutes)
   302  	r.Get("/accounts/{accountID}/", indexHandler)
   303  
   304  	ts := httptest.NewServer(r)
   305  	defer ts.Close()
   306  
   307  	if _, body := testRequest(t, ts, "GET", "/accounts/admin", nil); body != "admin" {
   308  		t.Fatalf(body)
   309  	}
   310  	if _, body := testRequest(t, ts, "GET", "/accounts/admin/", nil); body != "admin" {
   311  		t.Fatalf(body)
   312  	}
   313  	if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" {
   314  		t.Fatalf(body)
   315  	}
   316  }
   317  
   318  func TestMuxNestedNotFound(t *testing.T) {
   319  	r := NewRouter()
   320  
   321  	r.Use(func(next http.Handler) http.Handler {
   322  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   323  			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw"}, "mw"))
   324  			next.ServeHTTP(w, r)
   325  		})
   326  	})
   327  
   328  	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
   329  		w.Write([]byte("bye"))
   330  	})
   331  
   332  	r.With(func(next http.Handler) http.Handler {
   333  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   334  			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"with"}, "with"))
   335  			next.ServeHTTP(w, r)
   336  		})
   337  	}).NotFound(func(w http.ResponseWriter, r *http.Request) {
   338  		chkMw := r.Context().Value(ctxKey{"mw"}).(string)
   339  		chkWith := r.Context().Value(ctxKey{"with"}).(string)
   340  		w.WriteHeader(404)
   341  		w.Write([]byte(fmt.Sprintf("root 404 %s %s", chkMw, chkWith)))
   342  	})
   343  
   344  	sr1 := NewRouter()
   345  
   346  	sr1.Get("/sub", func(w http.ResponseWriter, r *http.Request) {
   347  		w.Write([]byte("sub"))
   348  	})
   349  	sr1.Group(func(sr1 Router) {
   350  		sr1.Use(func(next http.Handler) http.Handler {
   351  			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   352  				r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw2"}, "mw2"))
   353  				next.ServeHTTP(w, r)
   354  			})
   355  		})
   356  		sr1.NotFound(func(w http.ResponseWriter, r *http.Request) {
   357  			chkMw2 := r.Context().Value(ctxKey{"mw2"}).(string)
   358  			w.WriteHeader(404)
   359  			w.Write([]byte(fmt.Sprintf("sub 404 %s", chkMw2)))
   360  		})
   361  	})
   362  
   363  	sr2 := NewRouter()
   364  	sr2.Get("/sub", func(w http.ResponseWriter, r *http.Request) {
   365  		w.Write([]byte("sub2"))
   366  	})
   367  
   368  	r.Mount("/admin1", sr1)
   369  	r.Mount("/admin2", sr2)
   370  
   371  	ts := httptest.NewServer(r)
   372  	defer ts.Close()
   373  
   374  	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
   375  		t.Fatalf(body)
   376  	}
   377  	if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "root 404 mw with" {
   378  		t.Fatalf(body)
   379  	}
   380  	if _, body := testRequest(t, ts, "GET", "/admin1/sub", nil); body != "sub" {
   381  		t.Fatalf(body)
   382  	}
   383  	if _, body := testRequest(t, ts, "GET", "/admin1/nope", nil); body != "sub 404 mw2" {
   384  		t.Fatalf(body)
   385  	}
   386  	if _, body := testRequest(t, ts, "GET", "/admin2/sub", nil); body != "sub2" {
   387  		t.Fatalf(body)
   388  	}
   389  
   390  	// Not found pages should bubble up to the root.
   391  	if _, body := testRequest(t, ts, "GET", "/admin2/nope", nil); body != "root 404 mw with" {
   392  		t.Fatalf(body)
   393  	}
   394  }
   395  
   396  func TestMuxNestedMethodNotAllowed(t *testing.T) {
   397  	r := NewRouter()
   398  	r.Get("/root", func(w http.ResponseWriter, r *http.Request) {
   399  		w.Write([]byte("root"))
   400  	})
   401  	r.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
   402  		w.WriteHeader(405)
   403  		w.Write([]byte("root 405"))
   404  	})
   405  
   406  	sr1 := NewRouter()
   407  	sr1.Get("/sub1", func(w http.ResponseWriter, r *http.Request) {
   408  		w.Write([]byte("sub1"))
   409  	})
   410  	sr1.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
   411  		w.WriteHeader(405)
   412  		w.Write([]byte("sub1 405"))
   413  	})
   414  
   415  	sr2 := NewRouter()
   416  	sr2.Get("/sub2", func(w http.ResponseWriter, r *http.Request) {
   417  		w.Write([]byte("sub2"))
   418  	})
   419  
   420  	pathVar := NewRouter()
   421  	pathVar.Get("/{var}", func(w http.ResponseWriter, r *http.Request) {
   422  		w.Write([]byte("pv"))
   423  	})
   424  	pathVar.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
   425  		w.WriteHeader(405)
   426  		w.Write([]byte("pv 405"))
   427  	})
   428  
   429  	r.Mount("/prefix1", sr1)
   430  	r.Mount("/prefix2", sr2)
   431  	r.Mount("/pathVar", pathVar)
   432  
   433  	ts := httptest.NewServer(r)
   434  	defer ts.Close()
   435  
   436  	if _, body := testRequest(t, ts, "GET", "/root", nil); body != "root" {
   437  		t.Fatalf(body)
   438  	}
   439  	if _, body := testRequest(t, ts, "PUT", "/root", nil); body != "root 405" {
   440  		t.Fatalf(body)
   441  	}
   442  	if _, body := testRequest(t, ts, "GET", "/prefix1/sub1", nil); body != "sub1" {
   443  		t.Fatalf(body)
   444  	}
   445  	if _, body := testRequest(t, ts, "PUT", "/prefix1/sub1", nil); body != "sub1 405" {
   446  		t.Fatalf(body)
   447  	}
   448  	if _, body := testRequest(t, ts, "GET", "/prefix2/sub2", nil); body != "sub2" {
   449  		t.Fatalf(body)
   450  	}
   451  	if _, body := testRequest(t, ts, "PUT", "/prefix2/sub2", nil); body != "root 405" {
   452  		t.Fatalf(body)
   453  	}
   454  	if _, body := testRequest(t, ts, "GET", "/pathVar/myvar", nil); body != "pv" {
   455  		t.Fatalf(body)
   456  	}
   457  	if _, body := testRequest(t, ts, "DELETE", "/pathVar/myvar", nil); body != "pv 405" {
   458  		t.Fatalf(body)
   459  	}
   460  }
   461  
   462  func TestMuxComplicatedNotFound(t *testing.T) {
   463  	decorateRouter := func(r *Mux) {
   464  		// Root router with groups
   465  		r.Get("/auth", func(w http.ResponseWriter, r *http.Request) {
   466  			w.Write([]byte("auth get"))
   467  		})
   468  		r.Route("/public", func(r Router) {
   469  			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
   470  				w.Write([]byte("public get"))
   471  			})
   472  		})
   473  
   474  		// sub router with groups
   475  		sub0 := NewRouter()
   476  		sub0.Route("/resource", func(r Router) {
   477  			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
   478  				w.Write([]byte("private get"))
   479  			})
   480  		})
   481  		r.Mount("/private", sub0)
   482  
   483  		// sub router with groups
   484  		sub1 := NewRouter()
   485  		sub1.Route("/resource", func(r Router) {
   486  			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
   487  				w.Write([]byte("private get"))
   488  			})
   489  		})
   490  		r.With(func(next http.Handler) http.Handler { return next }).Mount("/private_mw", sub1)
   491  	}
   492  
   493  	testNotFound := func(t *testing.T, r *Mux) {
   494  		ts := httptest.NewServer(r)
   495  		defer ts.Close()
   496  
   497  		// check that we didn't break correct routes
   498  		if _, body := testRequest(t, ts, "GET", "/auth", nil); body != "auth get" {
   499  			t.Fatalf(body)
   500  		}
   501  		if _, body := testRequest(t, ts, "GET", "/public", nil); body != "public get" {
   502  			t.Fatalf(body)
   503  		}
   504  		if _, body := testRequest(t, ts, "GET", "/public/", nil); body != "public get" {
   505  			t.Fatalf(body)
   506  		}
   507  		if _, body := testRequest(t, ts, "GET", "/private/resource", nil); body != "private get" {
   508  			t.Fatalf(body)
   509  		}
   510  		// check custom not-found on all levels
   511  		if _, body := testRequest(t, ts, "GET", "/nope", nil); body != "custom not-found" {
   512  			t.Fatalf(body)
   513  		}
   514  		if _, body := testRequest(t, ts, "GET", "/public/nope", nil); body != "custom not-found" {
   515  			t.Fatalf(body)
   516  		}
   517  		if _, body := testRequest(t, ts, "GET", "/private/nope", nil); body != "custom not-found" {
   518  			t.Fatalf(body)
   519  		}
   520  		if _, body := testRequest(t, ts, "GET", "/private/resource/nope", nil); body != "custom not-found" {
   521  			t.Fatalf(body)
   522  		}
   523  		if _, body := testRequest(t, ts, "GET", "/private_mw/nope", nil); body != "custom not-found" {
   524  			t.Fatalf(body)
   525  		}
   526  		if _, body := testRequest(t, ts, "GET", "/private_mw/resource/nope", nil); body != "custom not-found" {
   527  			t.Fatalf(body)
   528  		}
   529  		// check custom not-found on trailing slash routes
   530  		if _, body := testRequest(t, ts, "GET", "/auth/", nil); body != "custom not-found" {
   531  			t.Fatalf(body)
   532  		}
   533  	}
   534  
   535  	t.Run("pre", func(t *testing.T) {
   536  		r := NewRouter()
   537  		r.NotFound(func(w http.ResponseWriter, r *http.Request) {
   538  			w.Write([]byte("custom not-found"))
   539  		})
   540  		decorateRouter(r)
   541  		testNotFound(t, r)
   542  	})
   543  
   544  	t.Run("post", func(t *testing.T) {
   545  		r := NewRouter()
   546  		decorateRouter(r)
   547  		r.NotFound(func(w http.ResponseWriter, r *http.Request) {
   548  			w.Write([]byte("custom not-found"))
   549  		})
   550  		testNotFound(t, r)
   551  	})
   552  }
   553  
   554  func TestMuxWith(t *testing.T) {
   555  	var cmwInit1, cmwHandler1 uint64
   556  	var cmwInit2, cmwHandler2 uint64
   557  	mw1 := func(next http.Handler) http.Handler {
   558  		cmwInit1++
   559  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   560  			cmwHandler1++
   561  			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline1"}, "yes"))
   562  			next.ServeHTTP(w, r)
   563  		})
   564  	}
   565  	mw2 := func(next http.Handler) http.Handler {
   566  		cmwInit2++
   567  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   568  			cmwHandler2++
   569  			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline2"}, "yes"))
   570  			next.ServeHTTP(w, r)
   571  		})
   572  	}
   573  
   574  	r := NewRouter()
   575  	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
   576  		w.Write([]byte("bye"))
   577  	})
   578  	r.With(mw1).With(mw2).Get("/inline", func(w http.ResponseWriter, r *http.Request) {
   579  		v1 := r.Context().Value(ctxKey{"inline1"}).(string)
   580  		v2 := r.Context().Value(ctxKey{"inline2"}).(string)
   581  		w.Write([]byte(fmt.Sprintf("inline %s %s", v1, v2)))
   582  	})
   583  
   584  	ts := httptest.NewServer(r)
   585  	defer ts.Close()
   586  
   587  	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
   588  		t.Fatalf(body)
   589  	}
   590  	if _, body := testRequest(t, ts, "GET", "/inline", nil); body != "inline yes yes" {
   591  		t.Fatalf(body)
   592  	}
   593  	if cmwInit1 != 1 {
   594  		t.Fatalf("expecting cmwInit1 to be 1, got %d", cmwInit1)
   595  	}
   596  	if cmwHandler1 != 1 {
   597  		t.Fatalf("expecting cmwHandler1 to be 1, got %d", cmwHandler1)
   598  	}
   599  	if cmwInit2 != 1 {
   600  		t.Fatalf("expecting cmwInit2 to be 1, got %d", cmwInit2)
   601  	}
   602  	if cmwHandler2 != 1 {
   603  		t.Fatalf("expecting cmwHandler2 to be 1, got %d", cmwHandler2)
   604  	}
   605  }
   606  
   607  func TestRouterFromMuxWith(t *testing.T) {
   608  	t.Parallel()
   609  
   610  	r := NewRouter()
   611  
   612  	with := r.With(func(next http.Handler) http.Handler {
   613  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   614  			next.ServeHTTP(w, r)
   615  		})
   616  	})
   617  
   618  	with.Get("/with_middleware", func(w http.ResponseWriter, r *http.Request) {})
   619  
   620  	ts := httptest.NewServer(with)
   621  	defer ts.Close()
   622  
   623  	// Without the fix this test was committed with, this causes a panic.
   624  	testRequest(t, ts, http.MethodGet, "/with_middleware", nil)
   625  }
   626  
   627  func TestMuxMiddlewareStack(t *testing.T) {
   628  	var stdmwInit, stdmwHandler uint64
   629  	stdmw := func(next http.Handler) http.Handler {
   630  		stdmwInit++
   631  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   632  			stdmwHandler++
   633  			next.ServeHTTP(w, r)
   634  		})
   635  	}
   636  	_ = stdmw
   637  
   638  	var ctxmwInit, ctxmwHandler uint64
   639  	ctxmw := func(next http.Handler) http.Handler {
   640  		ctxmwInit++
   641  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   642  			ctxmwHandler++
   643  			ctx := r.Context()
   644  			ctx = context.WithValue(ctx, ctxKey{"count.ctxmwHandler"}, ctxmwHandler)
   645  			r = r.WithContext(ctx)
   646  			next.ServeHTTP(w, r)
   647  		})
   648  	}
   649  
   650  	var inCtxmwInit, inCtxmwHandler uint64
   651  	inCtxmw := func(next http.Handler) http.Handler {
   652  		inCtxmwInit++
   653  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   654  			inCtxmwHandler++
   655  			next.ServeHTTP(w, r)
   656  		})
   657  	}
   658  
   659  	r := NewRouter()
   660  	r.Use(stdmw)
   661  	r.Use(ctxmw)
   662  	r.Use(func(next http.Handler) http.Handler {
   663  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   664  			if r.URL.Path == "/ping" {
   665  				w.Write([]byte("pong"))
   666  				return
   667  			}
   668  			next.ServeHTTP(w, r)
   669  		})
   670  	})
   671  
   672  	var handlerCount uint64
   673  
   674  	r.With(inCtxmw).Get("/", func(w http.ResponseWriter, r *http.Request) {
   675  		handlerCount++
   676  		ctx := r.Context()
   677  		ctxmwHandlerCount := ctx.Value(ctxKey{"count.ctxmwHandler"}).(uint64)
   678  		w.Write([]byte(fmt.Sprintf("inits:%d reqs:%d ctxValue:%d", ctxmwInit, handlerCount, ctxmwHandlerCount)))
   679  	})
   680  
   681  	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
   682  		w.Write([]byte("wooot"))
   683  	})
   684  
   685  	ts := httptest.NewServer(r)
   686  	defer ts.Close()
   687  
   688  	testRequest(t, ts, "GET", "/", nil)
   689  	testRequest(t, ts, "GET", "/", nil)
   690  	var body string
   691  	_, body = testRequest(t, ts, "GET", "/", nil)
   692  	if body != "inits:1 reqs:3 ctxValue:3" {
   693  		t.Fatalf("got: '%s'", body)
   694  	}
   695  
   696  	_, body = testRequest(t, ts, "GET", "/ping", nil)
   697  	if body != "pong" {
   698  		t.Fatalf("got: '%s'", body)
   699  	}
   700  }
   701  
   702  func TestMuxRouteGroups(t *testing.T) {
   703  	var stdmwInit, stdmwHandler uint64
   704  
   705  	stdmw := func(next http.Handler) http.Handler {
   706  		stdmwInit++
   707  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   708  			stdmwHandler++
   709  			next.ServeHTTP(w, r)
   710  		})
   711  	}
   712  
   713  	var stdmwInit2, stdmwHandler2 uint64
   714  	stdmw2 := func(next http.Handler) http.Handler {
   715  		stdmwInit2++
   716  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   717  			stdmwHandler2++
   718  			next.ServeHTTP(w, r)
   719  		})
   720  	}
   721  
   722  	r := NewRouter()
   723  	r.Group(func(r Router) {
   724  		r.Use(stdmw)
   725  		r.Get("/group", func(w http.ResponseWriter, r *http.Request) {
   726  			w.Write([]byte("root group"))
   727  		})
   728  	})
   729  	r.Group(func(r Router) {
   730  		r.Use(stdmw2)
   731  		r.Get("/group2", func(w http.ResponseWriter, r *http.Request) {
   732  			w.Write([]byte("root group2"))
   733  		})
   734  	})
   735  
   736  	ts := httptest.NewServer(r)
   737  	defer ts.Close()
   738  
   739  	// GET /group
   740  	_, body := testRequest(t, ts, "GET", "/group", nil)
   741  	if body != "root group" {
   742  		t.Fatalf("got: '%s'", body)
   743  	}
   744  	if stdmwInit != 1 || stdmwHandler != 1 {
   745  		t.Logf("stdmw counters failed, should be 1:1, got %d:%d", stdmwInit, stdmwHandler)
   746  	}
   747  
   748  	// GET /group2
   749  	_, body = testRequest(t, ts, "GET", "/group2", nil)
   750  	if body != "root group2" {
   751  		t.Fatalf("got: '%s'", body)
   752  	}
   753  	if stdmwInit2 != 1 || stdmwHandler2 != 1 {
   754  		t.Fatalf("stdmw2 counters failed, should be 1:1, got %d:%d", stdmwInit2, stdmwHandler2)
   755  	}
   756  }
   757  
   758  func TestMuxBig(t *testing.T) {
   759  	r := bigMux()
   760  
   761  	ts := httptest.NewServer(r)
   762  	defer ts.Close()
   763  
   764  	var body, expected string
   765  
   766  	_, body = testRequest(t, ts, "GET", "/favicon.ico", nil)
   767  	if body != "fav" {
   768  		t.Fatalf("got '%s'", body)
   769  	}
   770  	_, body = testRequest(t, ts, "GET", "/hubs/4/view", nil)
   771  	if body != "/hubs/4/view reqid:1 session:anonymous" {
   772  		t.Fatalf("got '%v'", body)
   773  	}
   774  	_, body = testRequest(t, ts, "GET", "/hubs/4/view/index.html", nil)
   775  	if body != "/hubs/4/view/index.html reqid:1 session:anonymous" {
   776  		t.Fatalf("got '%s'", body)
   777  	}
   778  	_, body = testRequest(t, ts, "POST", "/hubs/ethereumhub/view/index.html", nil)
   779  	if body != "/hubs/ethereumhub/view/index.html reqid:1 session:anonymous" {
   780  		t.Fatalf("got '%s'", body)
   781  	}
   782  	_, body = testRequest(t, ts, "GET", "/", nil)
   783  	if body != "/ reqid:1 session:elvis" {
   784  		t.Fatalf("got '%s'", body)
   785  	}
   786  	_, body = testRequest(t, ts, "GET", "/suggestions", nil)
   787  	if body != "/suggestions reqid:1 session:elvis" {
   788  		t.Fatalf("got '%s'", body)
   789  	}
   790  	_, body = testRequest(t, ts, "GET", "/woot/444/hiiii", nil)
   791  	if body != "/woot/444/hiiii" {
   792  		t.Fatalf("got '%s'", body)
   793  	}
   794  	_, body = testRequest(t, ts, "GET", "/hubs/123", nil)
   795  	expected = "/hubs/123 reqid:1 session:elvis"
   796  	if body != expected {
   797  		t.Fatalf("expected:%s got:%s", expected, body)
   798  	}
   799  	_, body = testRequest(t, ts, "GET", "/hubs/123/touch", nil)
   800  	if body != "/hubs/123/touch reqid:1 session:elvis" {
   801  		t.Fatalf("got '%s'", body)
   802  	}
   803  	_, body = testRequest(t, ts, "GET", "/hubs/123/webhooks", nil)
   804  	if body != "/hubs/123/webhooks reqid:1 session:elvis" {
   805  		t.Fatalf("got '%s'", body)
   806  	}
   807  	_, body = testRequest(t, ts, "GET", "/hubs/123/posts", nil)
   808  	if body != "/hubs/123/posts reqid:1 session:elvis" {
   809  		t.Fatalf("got '%s'", body)
   810  	}
   811  	_, body = testRequest(t, ts, "GET", "/folders", nil)
   812  	if body != "404 page not found\n" {
   813  		t.Fatalf("got '%s'", body)
   814  	}
   815  	_, body = testRequest(t, ts, "GET", "/folders/", nil)
   816  	if body != "/folders/ reqid:1 session:elvis" {
   817  		t.Fatalf("got '%s'", body)
   818  	}
   819  	_, body = testRequest(t, ts, "GET", "/folders/public", nil)
   820  	if body != "/folders/public reqid:1 session:elvis" {
   821  		t.Fatalf("got '%s'", body)
   822  	}
   823  	_, body = testRequest(t, ts, "GET", "/folders/nothing", nil)
   824  	if body != "404 page not found\n" {
   825  		t.Fatalf("got '%s'", body)
   826  	}
   827  }
   828  
   829  func bigMux() Router {
   830  	var r *Mux
   831  	var sr3 *Mux
   832  	// var sr1, sr2, sr3, sr4, sr5, sr6 *Mux
   833  	r = NewRouter()
   834  	r.Use(func(next http.Handler) http.Handler {
   835  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   836  			ctx := context.WithValue(r.Context(), ctxKey{"requestID"}, "1")
   837  			next.ServeHTTP(w, r.WithContext(ctx))
   838  		})
   839  	})
   840  	r.Use(func(next http.Handler) http.Handler {
   841  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   842  			next.ServeHTTP(w, r)
   843  		})
   844  	})
   845  	r.Group(func(r Router) {
   846  		r.Use(func(next http.Handler) http.Handler {
   847  			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   848  				ctx := context.WithValue(r.Context(), ctxKey{"session.user"}, "anonymous")
   849  				next.ServeHTTP(w, r.WithContext(ctx))
   850  			})
   851  		})
   852  		r.Get("/favicon.ico", func(w http.ResponseWriter, r *http.Request) {
   853  			w.Write([]byte("fav"))
   854  		})
   855  		r.Get("/hubs/{hubID}/view", func(w http.ResponseWriter, r *http.Request) {
   856  			ctx := r.Context()
   857  			s := fmt.Sprintf("/hubs/%s/view reqid:%s session:%s", URLParam(r, "hubID"),
   858  				ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   859  			w.Write([]byte(s))
   860  		})
   861  		r.Get("/hubs/{hubID}/view/*", func(w http.ResponseWriter, r *http.Request) {
   862  			ctx := r.Context()
   863  			s := fmt.Sprintf("/hubs/%s/view/%s reqid:%s session:%s", URLParamFromCtx(ctx, "hubID"),
   864  				URLParam(r, "*"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   865  			w.Write([]byte(s))
   866  		})
   867  		r.Post("/hubs/{hubSlug}/view/*", func(w http.ResponseWriter, r *http.Request) {
   868  			ctx := r.Context()
   869  			s := fmt.Sprintf("/hubs/%s/view/%s reqid:%s session:%s", URLParamFromCtx(ctx, "hubSlug"),
   870  				URLParam(r, "*"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   871  			w.Write([]byte(s))
   872  		})
   873  	})
   874  	r.Group(func(r Router) {
   875  		r.Use(func(next http.Handler) http.Handler {
   876  			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   877  				ctx := context.WithValue(r.Context(), ctxKey{"session.user"}, "elvis")
   878  				next.ServeHTTP(w, r.WithContext(ctx))
   879  			})
   880  		})
   881  		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
   882  			ctx := r.Context()
   883  			s := fmt.Sprintf("/ reqid:%s session:%s", ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   884  			w.Write([]byte(s))
   885  		})
   886  		r.Get("/suggestions", func(w http.ResponseWriter, r *http.Request) {
   887  			ctx := r.Context()
   888  			s := fmt.Sprintf("/suggestions reqid:%s session:%s", ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   889  			w.Write([]byte(s))
   890  		})
   891  
   892  		r.Get("/woot/{wootID}/*", func(w http.ResponseWriter, r *http.Request) {
   893  			s := fmt.Sprintf("/woot/%s/%s", URLParam(r, "wootID"), URLParam(r, "*"))
   894  			w.Write([]byte(s))
   895  		})
   896  
   897  		r.Route("/hubs", func(r Router) {
   898  			_ = r.(*Mux) // sr1
   899  			r.Route("/{hubID}", func(r Router) {
   900  				_ = r.(*Mux) // sr2
   901  				r.Get("/", func(w http.ResponseWriter, r *http.Request) {
   902  					ctx := r.Context()
   903  					s := fmt.Sprintf("/hubs/%s reqid:%s session:%s",
   904  						URLParam(r, "hubID"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   905  					w.Write([]byte(s))
   906  				})
   907  				r.Get("/touch", func(w http.ResponseWriter, r *http.Request) {
   908  					ctx := r.Context()
   909  					s := fmt.Sprintf("/hubs/%s/touch reqid:%s session:%s", URLParam(r, "hubID"),
   910  						ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   911  					w.Write([]byte(s))
   912  				})
   913  
   914  				sr3 = NewRouter()
   915  				sr3.Get("/", func(w http.ResponseWriter, r *http.Request) {
   916  					ctx := r.Context()
   917  					s := fmt.Sprintf("/hubs/%s/webhooks reqid:%s session:%s", URLParam(r, "hubID"),
   918  						ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   919  					w.Write([]byte(s))
   920  				})
   921  				sr3.Route("/{webhookID}", func(r Router) {
   922  					_ = r.(*Mux) // sr4
   923  					r.Get("/", func(w http.ResponseWriter, r *http.Request) {
   924  						ctx := r.Context()
   925  						s := fmt.Sprintf("/hubs/%s/webhooks/%s reqid:%s session:%s", URLParam(r, "hubID"),
   926  							URLParam(r, "webhookID"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   927  						w.Write([]byte(s))
   928  					})
   929  				})
   930  
   931  				r.Mount("/webhooks", Chain(func(next http.Handler) http.Handler {
   932  					return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   933  						next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ctxKey{"hook"}, true)))
   934  					})
   935  				}).Handler(sr3))
   936  
   937  				r.Route("/posts", func(r Router) {
   938  					_ = r.(*Mux) // sr5
   939  					r.Get("/", func(w http.ResponseWriter, r *http.Request) {
   940  						ctx := r.Context()
   941  						s := fmt.Sprintf("/hubs/%s/posts reqid:%s session:%s", URLParam(r, "hubID"),
   942  							ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   943  						w.Write([]byte(s))
   944  					})
   945  				})
   946  			})
   947  		})
   948  
   949  		r.Route("/folders/", func(r Router) {
   950  			_ = r.(*Mux) // sr6
   951  			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
   952  				ctx := r.Context()
   953  				s := fmt.Sprintf("/folders/ reqid:%s session:%s",
   954  					ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   955  				w.Write([]byte(s))
   956  			})
   957  			r.Get("/public", func(w http.ResponseWriter, r *http.Request) {
   958  				ctx := r.Context()
   959  				s := fmt.Sprintf("/folders/public reqid:%s session:%s",
   960  					ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
   961  				w.Write([]byte(s))
   962  			})
   963  		})
   964  	})
   965  
   966  	return r
   967  }
   968  
   969  func TestMuxSubroutesBasic(t *testing.T) {
   970  	hIndex := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   971  		w.Write([]byte("index"))
   972  	})
   973  	hArticlesList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   974  		w.Write([]byte("articles-list"))
   975  	})
   976  	hSearchArticles := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   977  		w.Write([]byte("search-articles"))
   978  	})
   979  	hGetArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   980  		w.Write([]byte(fmt.Sprintf("get-article:%s", URLParam(r, "id"))))
   981  	})
   982  	hSyncArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   983  		w.Write([]byte(fmt.Sprintf("sync-article:%s", URLParam(r, "id"))))
   984  	})
   985  
   986  	r := NewRouter()
   987  	// var rr1, rr2 *Mux
   988  	r.Get("/", hIndex)
   989  	r.Route("/articles", func(r Router) {
   990  		// rr1 = r.(*Mux)
   991  		r.Get("/", hArticlesList)
   992  		r.Get("/search", hSearchArticles)
   993  		r.Route("/{id}", func(r Router) {
   994  			// rr2 = r.(*Mux)
   995  			r.Get("/", hGetArticle)
   996  			r.Get("/sync", hSyncArticle)
   997  		})
   998  	})
   999  
  1000  	// log.Println("~~~~~~~~~")
  1001  	// log.Println("~~~~~~~~~")
  1002  	// debugPrintTree(0, 0, r.tree, 0)
  1003  	// log.Println("~~~~~~~~~")
  1004  	// log.Println("~~~~~~~~~")
  1005  
  1006  	// log.Println("~~~~~~~~~")
  1007  	// log.Println("~~~~~~~~~")
  1008  	// debugPrintTree(0, 0, rr1.tree, 0)
  1009  	// log.Println("~~~~~~~~~")
  1010  	// log.Println("~~~~~~~~~")
  1011  
  1012  	// log.Println("~~~~~~~~~")
  1013  	// log.Println("~~~~~~~~~")
  1014  	// debugPrintTree(0, 0, rr2.tree, 0)
  1015  	// log.Println("~~~~~~~~~")
  1016  	// log.Println("~~~~~~~~~")
  1017  
  1018  	ts := httptest.NewServer(r)
  1019  	defer ts.Close()
  1020  
  1021  	var body, expected string
  1022  
  1023  	_, body = testRequest(t, ts, "GET", "/", nil)
  1024  	expected = "index"
  1025  	if body != expected {
  1026  		t.Fatalf("expected:%s got:%s", expected, body)
  1027  	}
  1028  	_, body = testRequest(t, ts, "GET", "/articles", nil)
  1029  	expected = "articles-list"
  1030  	if body != expected {
  1031  		t.Fatalf("expected:%s got:%s", expected, body)
  1032  	}
  1033  	_, body = testRequest(t, ts, "GET", "/articles/search", nil)
  1034  	expected = "search-articles"
  1035  	if body != expected {
  1036  		t.Fatalf("expected:%s got:%s", expected, body)
  1037  	}
  1038  	_, body = testRequest(t, ts, "GET", "/articles/123", nil)
  1039  	expected = "get-article:123"
  1040  	if body != expected {
  1041  		t.Fatalf("expected:%s got:%s", expected, body)
  1042  	}
  1043  	_, body = testRequest(t, ts, "GET", "/articles/123/sync", nil)
  1044  	expected = "sync-article:123"
  1045  	if body != expected {
  1046  		t.Fatalf("expected:%s got:%s", expected, body)
  1047  	}
  1048  }
  1049  
  1050  func TestMuxSubroutes(t *testing.T) {
  1051  	hHubView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1052  		w.Write([]byte("hub1"))
  1053  	})
  1054  	hHubView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1055  		w.Write([]byte("hub2"))
  1056  	})
  1057  	hHubView3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1058  		w.Write([]byte("hub3"))
  1059  	})
  1060  	hAccountView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1061  		w.Write([]byte("account1"))
  1062  	})
  1063  	hAccountView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1064  		w.Write([]byte("account2"))
  1065  	})
  1066  
  1067  	r := NewRouter()
  1068  	r.Get("/hubs/{hubID}/view", hHubView1)
  1069  	r.Get("/hubs/{hubID}/view/*", hHubView2)
  1070  
  1071  	sr := NewRouter()
  1072  	sr.Get("/", hHubView3)
  1073  	r.Mount("/hubs/{hubID}/users", sr)
  1074  	r.Get("/hubs/{hubID}/users/", func(w http.ResponseWriter, r *http.Request) {
  1075  		w.Write([]byte("hub3 override"))
  1076  	})
  1077  
  1078  	sr3 := NewRouter()
  1079  	sr3.Get("/", hAccountView1)
  1080  	sr3.Get("/hi", hAccountView2)
  1081  
  1082  	// var sr2 *Mux
  1083  	r.Route("/accounts/{accountID}", func(r Router) {
  1084  		_ = r.(*Mux) // sr2
  1085  		// r.Get("/", hAccountView1)
  1086  		r.Mount("/", sr3)
  1087  	})
  1088  
  1089  	// This is the same as the r.Route() call mounted on sr2
  1090  	// sr2 := NewRouter()
  1091  	// sr2.Mount("/", sr3)
  1092  	// r.Mount("/accounts/{accountID}", sr2)
  1093  
  1094  	ts := httptest.NewServer(r)
  1095  	defer ts.Close()
  1096  
  1097  	var body, expected string
  1098  
  1099  	_, body = testRequest(t, ts, "GET", "/hubs/123/view", nil)
  1100  	expected = "hub1"
  1101  	if body != expected {
  1102  		t.Fatalf("expected:%s got:%s", expected, body)
  1103  	}
  1104  	_, body = testRequest(t, ts, "GET", "/hubs/123/view/index.html", nil)
  1105  	expected = "hub2"
  1106  	if body != expected {
  1107  		t.Fatalf("expected:%s got:%s", expected, body)
  1108  	}
  1109  	_, body = testRequest(t, ts, "GET", "/hubs/123/users", nil)
  1110  	expected = "hub3"
  1111  	if body != expected {
  1112  		t.Fatalf("expected:%s got:%s", expected, body)
  1113  	}
  1114  	_, body = testRequest(t, ts, "GET", "/hubs/123/users/", nil)
  1115  	expected = "hub3 override"
  1116  	if body != expected {
  1117  		t.Fatalf("expected:%s got:%s", expected, body)
  1118  	}
  1119  	_, body = testRequest(t, ts, "GET", "/accounts/44", nil)
  1120  	expected = "account1"
  1121  	if body != expected {
  1122  		t.Fatalf("request:%s expected:%s got:%s", "GET /accounts/44", expected, body)
  1123  	}
  1124  	_, body = testRequest(t, ts, "GET", "/accounts/44/hi", nil)
  1125  	expected = "account2"
  1126  	if body != expected {
  1127  		t.Fatalf("expected:%s got:%s", expected, body)
  1128  	}
  1129  
  1130  	// Test that we're building the routingPatterns properly
  1131  	router := r
  1132  	req, _ := http.NewRequest("GET", "/accounts/44/hi", nil)
  1133  
  1134  	rctx := NewRouteContext()
  1135  	req = req.WithContext(context.WithValue(req.Context(), RouteCtxKey, rctx))
  1136  
  1137  	w := httptest.NewRecorder()
  1138  	router.ServeHTTP(w, req)
  1139  
  1140  	body = w.Body.String()
  1141  	expected = "account2"
  1142  	if body != expected {
  1143  		t.Fatalf("expected:%s got:%s", expected, body)
  1144  	}
  1145  
  1146  	routePatterns := rctx.RoutePatterns
  1147  	if len(rctx.RoutePatterns) != 3 {
  1148  		t.Fatalf("expected 3 routing patterns, got:%d", len(rctx.RoutePatterns))
  1149  	}
  1150  	expected = "/accounts/{accountID}/*"
  1151  	if routePatterns[0] != expected {
  1152  		t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[0])
  1153  	}
  1154  	expected = "/*"
  1155  	if routePatterns[1] != expected {
  1156  		t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[1])
  1157  	}
  1158  	expected = "/hi"
  1159  	if routePatterns[2] != expected {
  1160  		t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[2])
  1161  	}
  1162  
  1163  }
  1164  
  1165  func TestSingleHandler(t *testing.T) {
  1166  	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1167  		name := URLParam(r, "name")
  1168  		w.Write([]byte("hi " + name))
  1169  	})
  1170  
  1171  	r, _ := http.NewRequest("GET", "/", nil)
  1172  	rctx := NewRouteContext()
  1173  	r = r.WithContext(context.WithValue(r.Context(), RouteCtxKey, rctx))
  1174  	rctx.URLParams.Add("name", "joe")
  1175  
  1176  	w := httptest.NewRecorder()
  1177  	h.ServeHTTP(w, r)
  1178  
  1179  	body := w.Body.String()
  1180  	expected := "hi joe"
  1181  	if body != expected {
  1182  		t.Fatalf("expected:%s got:%s", expected, body)
  1183  	}
  1184  }
  1185  
// TODO: add a test for mounting a Router wrapper type (a struct embedding *Mux).
  1187  //
  1188  // type ACLMux struct {
  1189  // 	*Mux
  1190  // 	XX string
  1191  // }
  1192  //
  1193  // func NewACLMux() *ACLMux {
  1194  // 	return &ACLMux{Mux: NewRouter(), XX: "hihi"}
  1195  // }
  1196  //
  1197  // // TODO: this should be supported...
  1198  // func TestWoot(t *testing.T) {
  1199  // 	var r Router = NewRouter()
  1200  //
  1201  // 	var r2 Router = NewACLMux() //NewRouter()
  1202  // 	r2.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
  1203  // 		w.Write([]byte("hi"))
  1204  // 	})
  1205  //
  1206  // 	r.Mount("/", r2)
  1207  // }
  1208  
  1209  func TestServeHTTPExistingContext(t *testing.T) {
  1210  	r := NewRouter()
  1211  	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
  1212  		s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
  1213  		w.Write([]byte(s))
  1214  	})
  1215  	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
  1216  		s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
  1217  		w.WriteHeader(404)
  1218  		w.Write([]byte(s))
  1219  	})
  1220  
  1221  	testcases := []struct {
  1222  		Method         string
  1223  		Path           string
  1224  		Ctx            context.Context
  1225  		ExpectedStatus int
  1226  		ExpectedBody   string
  1227  	}{
  1228  		{
  1229  			Method:         "GET",
  1230  			Path:           "/hi",
  1231  			Ctx:            context.WithValue(context.Background(), ctxKey{"testCtx"}, "hi ctx"),
  1232  			ExpectedStatus: 200,
  1233  			ExpectedBody:   "hi ctx",
  1234  		},
  1235  		{
  1236  			Method:         "GET",
  1237  			Path:           "/hello",
  1238  			Ctx:            context.WithValue(context.Background(), ctxKey{"testCtx"}, "nothing here ctx"),
  1239  			ExpectedStatus: 404,
  1240  			ExpectedBody:   "nothing here ctx",
  1241  		},
  1242  	}
  1243  
  1244  	for _, tc := range testcases {
  1245  		resp := httptest.NewRecorder()
  1246  		req, err := http.NewRequest(tc.Method, tc.Path, nil)
  1247  		if err != nil {
  1248  			t.Fatalf("%v", err)
  1249  		}
  1250  		req = req.WithContext(tc.Ctx)
  1251  		r.ServeHTTP(resp, req)
  1252  		b, err := ioutil.ReadAll(resp.Body)
  1253  		if err != nil {
  1254  			t.Fatalf("%v", err)
  1255  		}
  1256  		if resp.Code != tc.ExpectedStatus {
  1257  			t.Fatalf("%v != %v", tc.ExpectedStatus, resp.Code)
  1258  		}
  1259  		if string(b) != tc.ExpectedBody {
  1260  			t.Fatalf("%s != %s", tc.ExpectedBody, b)
  1261  		}
  1262  	}
  1263  }
  1264  
  1265  func TestNestedGroups(t *testing.T) {
  1266  	handlerPrintCounter := func(w http.ResponseWriter, r *http.Request) {
  1267  		counter, _ := r.Context().Value(ctxKey{"counter"}).(int)
  1268  		w.Write([]byte(fmt.Sprintf("%v", counter)))
  1269  	}
  1270  
  1271  	mwIncreaseCounter := func(next http.Handler) http.Handler {
  1272  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1273  			ctx := r.Context()
  1274  			counter, _ := ctx.Value(ctxKey{"counter"}).(int)
  1275  			counter++
  1276  			ctx = context.WithValue(ctx, ctxKey{"counter"}, counter)
  1277  			next.ServeHTTP(w, r.WithContext(ctx))
  1278  		})
  1279  	}
  1280  
  1281  	// Each route represents value of its counter (number of applied middlewares).
  1282  	r := NewRouter() // counter == 0
  1283  	r.Get("/0", handlerPrintCounter)
  1284  	r.Group(func(r Router) {
  1285  		r.Use(mwIncreaseCounter) // counter == 1
  1286  		r.Get("/1", handlerPrintCounter)
  1287  
  1288  		// r.Handle(GET, "/2", Chain(mwIncreaseCounter).HandlerFunc(handlerPrintCounter))
  1289  		r.With(mwIncreaseCounter).Get("/2", handlerPrintCounter)
  1290  
  1291  		r.Group(func(r Router) {
  1292  			r.Use(mwIncreaseCounter, mwIncreaseCounter) // counter == 3
  1293  			r.Get("/3", handlerPrintCounter)
  1294  		})
  1295  		r.Route("/", func(r Router) {
  1296  			r.Use(mwIncreaseCounter, mwIncreaseCounter) // counter == 3
  1297  
  1298  			// r.Handle(GET, "/4", Chain(mwIncreaseCounter).HandlerFunc(handlerPrintCounter))
  1299  			r.With(mwIncreaseCounter).Get("/4", handlerPrintCounter)
  1300  
  1301  			r.Group(func(r Router) {
  1302  				r.Use(mwIncreaseCounter, mwIncreaseCounter) // counter == 5
  1303  				r.Get("/5", handlerPrintCounter)
  1304  				// r.Handle(GET, "/6", Chain(mwIncreaseCounter).HandlerFunc(handlerPrintCounter))
  1305  				r.With(mwIncreaseCounter).Get("/6", handlerPrintCounter)
  1306  
  1307  			})
  1308  		})
  1309  	})
  1310  
  1311  	ts := httptest.NewServer(r)
  1312  	defer ts.Close()
  1313  
  1314  	for _, route := range []string{"0", "1", "2", "3", "4", "5", "6"} {
  1315  		if _, body := testRequest(t, ts, "GET", "/"+route, nil); body != route {
  1316  			t.Errorf("expected %v, got %v", route, body)
  1317  		}
  1318  	}
  1319  }
  1320  
  1321  func TestMiddlewarePanicOnLateUse(t *testing.T) {
  1322  	handler := func(w http.ResponseWriter, r *http.Request) {
  1323  		w.Write([]byte("hello\n"))
  1324  	}
  1325  
  1326  	mw := func(next http.Handler) http.Handler {
  1327  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1328  			next.ServeHTTP(w, r)
  1329  		})
  1330  	}
  1331  
  1332  	defer func() {
  1333  		if recover() == nil {
  1334  			t.Error("expected panic()")
  1335  		}
  1336  	}()
  1337  
  1338  	r := NewRouter()
  1339  	r.Get("/", handler)
  1340  	r.Use(mw) // Too late to apply middleware, we're expecting panic().
  1341  }
  1342  
  1343  func TestMountingExistingPath(t *testing.T) {
  1344  	handler := func(w http.ResponseWriter, r *http.Request) {}
  1345  
  1346  	defer func() {
  1347  		if recover() == nil {
  1348  			t.Error("expected panic()")
  1349  		}
  1350  	}()
  1351  
  1352  	r := NewRouter()
  1353  	r.Get("/", handler)
  1354  	r.Mount("/hi", http.HandlerFunc(handler))
  1355  	r.Mount("/hi", http.HandlerFunc(handler))
  1356  }
  1357  
  1358  func TestMountingSimilarPattern(t *testing.T) {
  1359  	r := NewRouter()
  1360  	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
  1361  		w.Write([]byte("bye"))
  1362  	})
  1363  
  1364  	r2 := NewRouter()
  1365  	r2.Get("/", func(w http.ResponseWriter, r *http.Request) {
  1366  		w.Write([]byte("foobar"))
  1367  	})
  1368  
  1369  	r3 := NewRouter()
  1370  	r3.Get("/", func(w http.ResponseWriter, r *http.Request) {
  1371  		w.Write([]byte("foo"))
  1372  	})
  1373  
  1374  	r.Mount("/foobar", r2)
  1375  	r.Mount("/foo", r3)
  1376  
  1377  	ts := httptest.NewServer(r)
  1378  	defer ts.Close()
  1379  
  1380  	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
  1381  		t.Fatalf(body)
  1382  	}
  1383  }
  1384  
  1385  func TestMuxEmptyParams(t *testing.T) {
  1386  	r := NewRouter()
  1387  	r.Get(`/users/{x}/{y}/{z}`, func(w http.ResponseWriter, r *http.Request) {
  1388  		x := URLParam(r, "x")
  1389  		y := URLParam(r, "y")
  1390  		z := URLParam(r, "z")
  1391  		w.Write([]byte(fmt.Sprintf("%s-%s-%s", x, y, z)))
  1392  	})
  1393  
  1394  	ts := httptest.NewServer(r)
  1395  	defer ts.Close()
  1396  
  1397  	if _, body := testRequest(t, ts, "GET", "/users/a/b/c", nil); body != "a-b-c" {
  1398  		t.Fatalf(body)
  1399  	}
  1400  	if _, body := testRequest(t, ts, "GET", "/users///c", nil); body != "--c" {
  1401  		t.Fatalf(body)
  1402  	}
  1403  }
  1404  
  1405  func TestMuxMissingParams(t *testing.T) {
  1406  	r := NewRouter()
  1407  	r.Get(`/user/{userId:\d+}`, func(w http.ResponseWriter, r *http.Request) {
  1408  		userID := URLParam(r, "userId")
  1409  		w.Write([]byte(fmt.Sprintf("userId = '%s'", userID)))
  1410  	})
  1411  	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
  1412  		w.WriteHeader(404)
  1413  		w.Write([]byte("nothing here"))
  1414  	})
  1415  
  1416  	ts := httptest.NewServer(r)
  1417  	defer ts.Close()
  1418  
  1419  	if _, body := testRequest(t, ts, "GET", "/user/123", nil); body != "userId = '123'" {
  1420  		t.Fatalf(body)
  1421  	}
  1422  	if _, body := testRequest(t, ts, "GET", "/user/", nil); body != "nothing here" {
  1423  		t.Fatalf(body)
  1424  	}
  1425  }
  1426  
  1427  func TestMuxWildcardRoute(t *testing.T) {
  1428  	handler := func(w http.ResponseWriter, r *http.Request) {}
  1429  
  1430  	defer func() {
  1431  		if recover() == nil {
  1432  			t.Error("expected panic()")
  1433  		}
  1434  	}()
  1435  
  1436  	r := NewRouter()
  1437  	r.Get("/*/wildcard/must/be/at/end", handler)
  1438  }
  1439  
  1440  func TestMuxWildcardRouteCheckTwo(t *testing.T) {
  1441  	handler := func(w http.ResponseWriter, r *http.Request) {}
  1442  
  1443  	defer func() {
  1444  		if recover() == nil {
  1445  			t.Error("expected panic()")
  1446  		}
  1447  	}()
  1448  
  1449  	r := NewRouter()
  1450  	r.Get("/*/wildcard/{must}/be/at/end", handler)
  1451  }
  1452  
  1453  func TestMuxRegexp(t *testing.T) {
  1454  	r := NewRouter()
  1455  	r.Route("/{param:[0-9]*}/test", func(r Router) {
  1456  		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
  1457  			w.Write([]byte(fmt.Sprintf("Hi: %s", URLParam(r, "param"))))
  1458  		})
  1459  	})
  1460  
  1461  	ts := httptest.NewServer(r)
  1462  	defer ts.Close()
  1463  
  1464  	if _, body := testRequest(t, ts, "GET", "//test", nil); body != "Hi: " {
  1465  		t.Fatalf(body)
  1466  	}
  1467  }
  1468  
  1469  func TestMuxRegexp2(t *testing.T) {
  1470  	r := NewRouter()
  1471  	r.Get("/foo-{suffix:[a-z]{2,3}}.json", func(w http.ResponseWriter, r *http.Request) {
  1472  		w.Write([]byte(URLParam(r, "suffix")))
  1473  	})
  1474  	ts := httptest.NewServer(r)
  1475  	defer ts.Close()
  1476  
  1477  	if _, body := testRequest(t, ts, "GET", "/foo-.json", nil); body != "" {
  1478  		t.Fatalf(body)
  1479  	}
  1480  	if _, body := testRequest(t, ts, "GET", "/foo-abc.json", nil); body != "abc" {
  1481  		t.Fatalf(body)
  1482  	}
  1483  }
  1484  
  1485  func TestMuxRegexp3(t *testing.T) {
  1486  	r := NewRouter()
  1487  	r.Get("/one/{firstId:[a-z0-9-]+}/{secondId:[a-z]+}/first", func(w http.ResponseWriter, r *http.Request) {
  1488  		w.Write([]byte("first"))
  1489  	})
  1490  	r.Get("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) {
  1491  		w.Write([]byte("second"))
  1492  	})
  1493  	r.Delete("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) {
  1494  		w.Write([]byte("third"))
  1495  	})
  1496  
  1497  	r.Route("/one", func(r Router) {
  1498  		r.Get("/{dns:[a-z-0-9_]+}", func(writer http.ResponseWriter, request *http.Request) {
  1499  			writer.Write([]byte("_"))
  1500  		})
  1501  		r.Get("/{dns:[a-z-0-9_]+}/info", func(writer http.ResponseWriter, request *http.Request) {
  1502  			writer.Write([]byte("_"))
  1503  		})
  1504  		r.Delete("/{id:[0-9]+}", func(writer http.ResponseWriter, request *http.Request) {
  1505  			writer.Write([]byte("forth"))
  1506  		})
  1507  	})
  1508  
  1509  	ts := httptest.NewServer(r)
  1510  	defer ts.Close()
  1511  
  1512  	if _, body := testRequest(t, ts, "GET", "/one/hello/peter/first", nil); body != "first" {
  1513  		t.Fatalf(body)
  1514  	}
  1515  	if _, body := testRequest(t, ts, "GET", "/one/hithere/123/second", nil); body != "second" {
  1516  		t.Fatalf(body)
  1517  	}
  1518  	if _, body := testRequest(t, ts, "DELETE", "/one/hithere/123/second", nil); body != "third" {
  1519  		t.Fatalf(body)
  1520  	}
  1521  	if _, body := testRequest(t, ts, "DELETE", "/one/123", nil); body != "forth" {
  1522  		t.Fatalf(body)
  1523  	}
  1524  }
  1525  
  1526  func TestMuxSubrouterWildcardParam(t *testing.T) {
  1527  	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1528  		fmt.Fprintf(w, "param:%v *:%v", URLParam(r, "param"), URLParam(r, "*"))
  1529  	})
  1530  
  1531  	r := NewRouter()
  1532  
  1533  	r.Get("/bare/{param}", h)
  1534  	r.Get("/bare/{param}/*", h)
  1535  
  1536  	r.Route("/case0", func(r Router) {
  1537  		r.Get("/{param}", h)
  1538  		r.Get("/{param}/*", h)
  1539  	})
  1540  
  1541  	ts := httptest.NewServer(r)
  1542  	defer ts.Close()
  1543  
  1544  	if _, body := testRequest(t, ts, "GET", "/bare/hi", nil); body != "param:hi *:" {
  1545  		t.Fatalf(body)
  1546  	}
  1547  	if _, body := testRequest(t, ts, "GET", "/bare/hi/yes", nil); body != "param:hi *:yes" {
  1548  		t.Fatalf(body)
  1549  	}
  1550  	if _, body := testRequest(t, ts, "GET", "/case0/hi", nil); body != "param:hi *:" {
  1551  		t.Fatalf(body)
  1552  	}
  1553  	if _, body := testRequest(t, ts, "GET", "/case0/hi/yes", nil); body != "param:hi *:yes" {
  1554  		t.Fatalf(body)
  1555  	}
  1556  }
  1557  
  1558  func TestMuxContextIsThreadSafe(t *testing.T) {
  1559  	router := NewRouter()
  1560  	router.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
  1561  		ctx, cancel := context.WithTimeout(r.Context(), 1*time.Millisecond)
  1562  		defer cancel()
  1563  
  1564  		<-ctx.Done()
  1565  	})
  1566  
  1567  	wg := sync.WaitGroup{}
  1568  
  1569  	for i := 0; i < 100; i++ {
  1570  		wg.Add(1)
  1571  		go func() {
  1572  			defer wg.Done()
  1573  			for j := 0; j < 10000; j++ {
  1574  				w := httptest.NewRecorder()
  1575  				r, err := http.NewRequest("GET", "/ok", nil)
  1576  				if err != nil {
  1577  					t.Fatal(err)
  1578  				}
  1579  
  1580  				ctx, cancel := context.WithCancel(r.Context())
  1581  				r = r.WithContext(ctx)
  1582  
  1583  				go func() {
  1584  					cancel()
  1585  				}()
  1586  				router.ServeHTTP(w, r)
  1587  			}
  1588  		}()
  1589  	}
  1590  	wg.Wait()
  1591  }
  1592  
  1593  func TestEscapedURLParams(t *testing.T) {
  1594  	m := NewRouter()
  1595  	m.Get("/api/{identifier}/{region}/{size}/{rotation}/*", func(w http.ResponseWriter, r *http.Request) {
  1596  		w.WriteHeader(200)
  1597  		rctx := RouteContext(r.Context())
  1598  		if rctx == nil {
  1599  			t.Error("no context")
  1600  			return
  1601  		}
  1602  		identifier := URLParam(r, "identifier")
  1603  		if identifier != "http:%2f%2fexample.com%2fimage.png" {
  1604  			t.Errorf("identifier path parameter incorrect %s", identifier)
  1605  			return
  1606  		}
  1607  		region := URLParam(r, "region")
  1608  		if region != "full" {
  1609  			t.Errorf("region path parameter incorrect %s", region)
  1610  			return
  1611  		}
  1612  		size := URLParam(r, "size")
  1613  		if size != "max" {
  1614  			t.Errorf("size path parameter incorrect %s", size)
  1615  			return
  1616  		}
  1617  		rotation := URLParam(r, "rotation")
  1618  		if rotation != "0" {
  1619  			t.Errorf("rotation path parameter incorrect %s", rotation)
  1620  			return
  1621  		}
  1622  		w.Write([]byte("success"))
  1623  	})
  1624  
  1625  	ts := httptest.NewServer(m)
  1626  	defer ts.Close()
  1627  
  1628  	if _, body := testRequest(t, ts, "GET", "/api/http:%2f%2fexample.com%2fimage.png/full/max/0/color.png", nil); body != "success" {
  1629  		t.Fatalf(body)
  1630  	}
  1631  }
  1632  
  1633  func TestMuxMatch(t *testing.T) {
  1634  	r := NewRouter()
  1635  	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
  1636  		w.Header().Set("X-Test", "yes")
  1637  		w.Write([]byte("bye"))
  1638  	})
  1639  	r.Route("/articles", func(r Router) {
  1640  		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
  1641  			id := URLParam(r, "id")
  1642  			w.Header().Set("X-Article", id)
  1643  			w.Write([]byte("article:" + id))
  1644  		})
  1645  	})
  1646  	r.Route("/users", func(r Router) {
  1647  		r.Head("/{id}", func(w http.ResponseWriter, r *http.Request) {
  1648  			w.Header().Set("X-User", "-")
  1649  			w.Write([]byte("user"))
  1650  		})
  1651  		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
  1652  			id := URLParam(r, "id")
  1653  			w.Header().Set("X-User", id)
  1654  			w.Write([]byte("user:" + id))
  1655  		})
  1656  	})
  1657  
  1658  	tctx := NewRouteContext()
  1659  
  1660  	tctx.Reset()
  1661  	if r.Match(tctx, "GET", "/users/1") == false {
  1662  		t.Fatal("expecting to find match for route:", "GET", "/users/1")
  1663  	}
  1664  
  1665  	tctx.Reset()
  1666  	if r.Match(tctx, "HEAD", "/articles/10") == true {
  1667  		t.Fatal("not expecting to find match for route:", "HEAD", "/articles/10")
  1668  	}
  1669  }
  1670  
  1671  func TestServerBaseContext(t *testing.T) {
  1672  	r := NewRouter()
  1673  	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
  1674  		baseYes := r.Context().Value(ctxKey{"base"}).(string)
  1675  		if _, ok := r.Context().Value(http.ServerContextKey).(*http.Server); !ok {
  1676  			panic("missing server context")
  1677  		}
  1678  		if _, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); !ok {
  1679  			panic("missing local addr context")
  1680  		}
  1681  		w.Write([]byte(baseYes))
  1682  	})
  1683  
  1684  	// Setup http Server with a base context
  1685  	ctx := context.WithValue(context.Background(), ctxKey{"base"}, "yes")
  1686  	ts := httptest.NewUnstartedServer(r)
  1687  	ts.Config.BaseContext = func(_ net.Listener) context.Context {
  1688  		return ctx
  1689  	}
  1690  	ts.Start()
  1691  
  1692  	defer ts.Close()
  1693  
  1694  	if _, body := testRequest(t, ts, "GET", "/", nil); body != "yes" {
  1695  		t.Fatalf(body)
  1696  	}
  1697  }
  1698  
  1699  func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
  1700  	req, err := http.NewRequest(method, ts.URL+path, body)
  1701  	if err != nil {
  1702  		t.Fatal(err)
  1703  		return nil, ""
  1704  	}
  1705  
  1706  	resp, err := http.DefaultClient.Do(req)
  1707  	if err != nil {
  1708  		t.Fatal(err)
  1709  		return nil, ""
  1710  	}
  1711  
  1712  	respBody, err := ioutil.ReadAll(resp.Body)
  1713  	if err != nil {
  1714  		t.Fatal(err)
  1715  		return nil, ""
  1716  	}
  1717  	defer resp.Body.Close()
  1718  
  1719  	return resp, string(respBody)
  1720  }
  1721  
  1722  func testHandler(t *testing.T, h http.Handler, method, path string, body io.Reader) (*http.Response, string) {
  1723  	r, _ := http.NewRequest(method, path, body)
  1724  	w := httptest.NewRecorder()
  1725  	h.ServeHTTP(w, r)
  1726  	return w.Result(), w.Body.String()
  1727  }
  1728  
  1729  type testFileSystem struct {
  1730  	open func(name string) (http.File, error)
  1731  }
  1732  
  1733  func (fs *testFileSystem) Open(name string) (http.File, error) {
  1734  	return fs.open(name)
  1735  }
  1736  
  1737  type testFile struct {
  1738  	name     string
  1739  	contents []byte
  1740  }
  1741  
  1742  func (tf *testFile) Close() error {
  1743  	return nil
  1744  }
  1745  
  1746  func (tf *testFile) Read(p []byte) (n int, err error) {
  1747  	copy(p, tf.contents)
  1748  	return len(p), nil
  1749  }
  1750  
  1751  func (tf *testFile) Seek(offset int64, whence int) (int64, error) {
  1752  	return 0, nil
  1753  }
  1754  
  1755  func (tf *testFile) Readdir(count int) ([]os.FileInfo, error) {
  1756  	stat, _ := tf.Stat()
  1757  	return []os.FileInfo{stat}, nil
  1758  }
  1759  
  1760  func (tf *testFile) Stat() (os.FileInfo, error) {
  1761  	return &testFileInfo{tf.name, int64(len(tf.contents))}, nil
  1762  }
  1763  
  1764  type testFileInfo struct {
  1765  	name string
  1766  	size int64
  1767  }
  1768  
  1769  func (tfi *testFileInfo) Name() string       { return tfi.name }
  1770  func (tfi *testFileInfo) Size() int64        { return tfi.size }
  1771  func (tfi *testFileInfo) Mode() os.FileMode  { return 0755 }
  1772  func (tfi *testFileInfo) ModTime() time.Time { return time.Now() }
  1773  func (tfi *testFileInfo) IsDir() bool        { return false }
  1774  func (tfi *testFileInfo) Sys() interface{}   { return nil }
  1775  
  1776  type ctxKey struct {
  1777  	name string
  1778  }
  1779  
  1780  func (k ctxKey) String() string {
  1781  	return "context value " + k.name
  1782  }
  1783  
  1784  func BenchmarkMux(b *testing.B) {
  1785  	h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
  1786  	h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
  1787  	h3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
  1788  	h4 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
  1789  	h5 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
  1790  	h6 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
  1791  
  1792  	mx := NewRouter()
  1793  	mx.Get("/", h1)
  1794  	mx.Get("/hi", h2)
  1795  	mx.Get("/sup/{id}/and/{this}", h3)
  1796  	mx.Get("/sup/{id}/{bar:foo}/{this}", h3)
  1797  
  1798  	mx.Route("/sharing/{x}/{hash}", func(mx Router) {
  1799  		mx.Get("/", h4)          // subrouter-1
  1800  		mx.Get("/{network}", h5) // subrouter-1
  1801  		mx.Get("/twitter", h5)
  1802  		mx.Route("/direct", func(mx Router) {
  1803  			mx.Get("/", h6) // subrouter-2
  1804  			mx.Get("/download", h6)
  1805  		})
  1806  	})
  1807  
  1808  	routes := []string{
  1809  		"/",
  1810  		"/hi",
  1811  		"/sup/123/and/this",
  1812  		"/sup/123/foo/this",
  1813  		"/sharing/z/aBc",                 // subrouter-1
  1814  		"/sharing/z/aBc/twitter",         // subrouter-1
  1815  		"/sharing/z/aBc/direct",          // subrouter-2
  1816  		"/sharing/z/aBc/direct/download", // subrouter-2
  1817  	}
  1818  
  1819  	for _, path := range routes {
  1820  		b.Run("route:"+path, func(b *testing.B) {
  1821  			w := httptest.NewRecorder()
  1822  			r, _ := http.NewRequest("GET", path, nil)
  1823  
  1824  			b.ReportAllocs()
  1825  			b.ResetTimer()
  1826  
  1827  			for i := 0; i < b.N; i++ {
  1828  				mx.ServeHTTP(w, r)
  1829  			}
  1830  		})
  1831  	}
  1832  }