gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/mux/middleware_test.go (about)

     1  package mux
     2  
     3  import (
     4  	"bytes"
     5  	"testing"
     6  
     7  	http "gitee.com/ks-custle/core-gm/gmhttp"
     8  )
     9  
    10  type testMiddleware struct {
    11  	timesCalled uint
    12  }
    13  
    14  func (tm *testMiddleware) Middleware(h http.Handler) http.Handler {
    15  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    16  		tm.timesCalled++
    17  		h.ServeHTTP(w, r)
    18  	})
    19  }
    20  
    21  func dummyHandler(w http.ResponseWriter, r *http.Request) {}
    22  
    23  func TestMiddlewareAdd(t *testing.T) {
    24  	router := NewRouter()
    25  	router.HandleFunc("/", dummyHandler).Methods("GET")
    26  
    27  	mw := &testMiddleware{}
    28  
    29  	router.useInterface(mw)
    30  	if len(router.middlewares) != 1 || router.middlewares[0] != mw {
    31  		t.Fatal("Middleware interface was not added correctly")
    32  	}
    33  
    34  	router.Use(mw.Middleware)
    35  	if len(router.middlewares) != 2 {
    36  		t.Fatal("Middleware method was not added correctly")
    37  	}
    38  
    39  	banalMw := func(handler http.Handler) http.Handler {
    40  		return handler
    41  	}
    42  	router.Use(banalMw)
    43  	if len(router.middlewares) != 3 {
    44  		t.Fatal("Middleware function was not added correctly")
    45  	}
    46  }
    47  
    48  func TestMiddleware(t *testing.T) {
    49  	router := NewRouter()
    50  	router.HandleFunc("/", dummyHandler).Methods("GET")
    51  
    52  	mw := &testMiddleware{}
    53  	router.useInterface(mw)
    54  
    55  	rw := NewRecorder()
    56  	req := newRequest("GET", "/")
    57  
    58  	t.Run("regular middleware call", func(t *testing.T) {
    59  		router.ServeHTTP(rw, req)
    60  		if mw.timesCalled != 1 {
    61  			t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
    62  		}
    63  	})
    64  
    65  	t.Run("not called for 404", func(t *testing.T) {
    66  		req = newRequest("GET", "/not/found")
    67  		router.ServeHTTP(rw, req)
    68  		if mw.timesCalled != 1 {
    69  			t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
    70  		}
    71  	})
    72  
    73  	t.Run("not called for method mismatch", func(t *testing.T) {
    74  		req = newRequest("POST", "/")
    75  		router.ServeHTTP(rw, req)
    76  		if mw.timesCalled != 1 {
    77  			t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
    78  		}
    79  	})
    80  
    81  	t.Run("regular call using function middleware", func(t *testing.T) {
    82  		router.Use(mw.Middleware)
    83  		req = newRequest("GET", "/")
    84  		router.ServeHTTP(rw, req)
    85  		if mw.timesCalled != 3 {
    86  			t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled)
    87  		}
    88  	})
    89  }
    90  
    91  func TestMiddlewareSubrouter(t *testing.T) {
    92  	router := NewRouter()
    93  	router.HandleFunc("/", dummyHandler).Methods("GET")
    94  
    95  	subrouter := router.PathPrefix("/sub").Subrouter()
    96  	subrouter.HandleFunc("/x", dummyHandler).Methods("GET")
    97  
    98  	mw := &testMiddleware{}
    99  	subrouter.useInterface(mw)
   100  
   101  	rw := NewRecorder()
   102  	req := newRequest("GET", "/")
   103  
   104  	t.Run("not called for route outside subrouter", func(t *testing.T) {
   105  		router.ServeHTTP(rw, req)
   106  		if mw.timesCalled != 0 {
   107  			t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
   108  		}
   109  	})
   110  
   111  	t.Run("not called for subrouter root 404", func(t *testing.T) {
   112  		req = newRequest("GET", "/sub/")
   113  		router.ServeHTTP(rw, req)
   114  		if mw.timesCalled != 0 {
   115  			t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
   116  		}
   117  	})
   118  
   119  	t.Run("called once for route inside subrouter", func(t *testing.T) {
   120  		req = newRequest("GET", "/sub/x")
   121  		router.ServeHTTP(rw, req)
   122  		if mw.timesCalled != 1 {
   123  			t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
   124  		}
   125  	})
   126  
   127  	t.Run("not called for 404 inside subrouter", func(t *testing.T) {
   128  		req = newRequest("GET", "/sub/not/found")
   129  		router.ServeHTTP(rw, req)
   130  		if mw.timesCalled != 1 {
   131  			t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
   132  		}
   133  	})
   134  
   135  	t.Run("middleware added to router", func(t *testing.T) {
   136  		router.useInterface(mw)
   137  
   138  		t.Run("called once for route outside subrouter", func(t *testing.T) {
   139  			req = newRequest("GET", "/")
   140  			router.ServeHTTP(rw, req)
   141  			if mw.timesCalled != 2 {
   142  				t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled)
   143  			}
   144  		})
   145  
   146  		t.Run("called twice for route inside subrouter", func(t *testing.T) {
   147  			req = newRequest("GET", "/sub/x")
   148  			router.ServeHTTP(rw, req)
   149  			if mw.timesCalled != 4 {
   150  				t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled)
   151  			}
   152  		})
   153  	})
   154  }
   155  
   156  func TestMiddlewareExecution(t *testing.T) {
   157  	mwStr := []byte("Middleware\n")
   158  	handlerStr := []byte("Logic\n")
   159  
   160  	router := NewRouter()
   161  	router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
   162  		w.Write(handlerStr)
   163  	})
   164  
   165  	t.Run("responds normally without middleware", func(t *testing.T) {
   166  		rw := NewRecorder()
   167  		req := newRequest("GET", "/")
   168  
   169  		router.ServeHTTP(rw, req)
   170  
   171  		if !bytes.Equal(rw.Body.Bytes(), handlerStr) {
   172  			t.Fatal("Handler response is not what it should be")
   173  		}
   174  	})
   175  
   176  	t.Run("responds with handler and middleware response", func(t *testing.T) {
   177  		rw := NewRecorder()
   178  		req := newRequest("GET", "/")
   179  
   180  		router.Use(func(h http.Handler) http.Handler {
   181  			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   182  				w.Write(mwStr)
   183  				h.ServeHTTP(w, r)
   184  			})
   185  		})
   186  
   187  		router.ServeHTTP(rw, req)
   188  		if !bytes.Equal(rw.Body.Bytes(), append(mwStr, handlerStr...)) {
   189  			t.Fatal("Middleware + handler response is not what it should be")
   190  		}
   191  	})
   192  }
   193  
   194  func TestMiddlewareNotFound(t *testing.T) {
   195  	mwStr := []byte("Middleware\n")
   196  	handlerStr := []byte("Logic\n")
   197  
   198  	router := NewRouter()
   199  	router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
   200  		w.Write(handlerStr)
   201  	})
   202  	router.Use(func(h http.Handler) http.Handler {
   203  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   204  			w.Write(mwStr)
   205  			h.ServeHTTP(w, r)
   206  		})
   207  	})
   208  
   209  	// Test not found call with default handler
   210  	t.Run("not called", func(t *testing.T) {
   211  		rw := NewRecorder()
   212  		req := newRequest("GET", "/notfound")
   213  
   214  		router.ServeHTTP(rw, req)
   215  		if bytes.Contains(rw.Body.Bytes(), mwStr) {
   216  			t.Fatal("Middleware was called for a 404")
   217  		}
   218  	})
   219  
   220  	t.Run("not called with custom not found handler", func(t *testing.T) {
   221  		rw := NewRecorder()
   222  		req := newRequest("GET", "/notfound")
   223  
   224  		router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   225  			rw.Write([]byte("Custom 404 handler"))
   226  		})
   227  		router.ServeHTTP(rw, req)
   228  
   229  		if bytes.Contains(rw.Body.Bytes(), mwStr) {
   230  			t.Fatal("Middleware was called for a custom 404")
   231  		}
   232  	})
   233  }
   234  
   235  func TestMiddlewareMethodMismatch(t *testing.T) {
   236  	mwStr := []byte("Middleware\n")
   237  	handlerStr := []byte("Logic\n")
   238  
   239  	router := NewRouter()
   240  	router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
   241  		w.Write(handlerStr)
   242  	}).Methods("GET")
   243  
   244  	router.Use(func(h http.Handler) http.Handler {
   245  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   246  			w.Write(mwStr)
   247  			h.ServeHTTP(w, r)
   248  		})
   249  	})
   250  
   251  	t.Run("not called", func(t *testing.T) {
   252  		rw := NewRecorder()
   253  		req := newRequest("POST", "/")
   254  
   255  		router.ServeHTTP(rw, req)
   256  		if bytes.Contains(rw.Body.Bytes(), mwStr) {
   257  			t.Fatal("Middleware was called for a method mismatch")
   258  		}
   259  	})
   260  
   261  	t.Run("not called with custom method not allowed handler", func(t *testing.T) {
   262  		rw := NewRecorder()
   263  		req := newRequest("POST", "/")
   264  
   265  		router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   266  			rw.Write([]byte("Method not allowed"))
   267  		})
   268  		router.ServeHTTP(rw, req)
   269  
   270  		if bytes.Contains(rw.Body.Bytes(), mwStr) {
   271  			t.Fatal("Middleware was called for a method mismatch")
   272  		}
   273  	})
   274  }
   275  
   276  func TestMiddlewareNotFoundSubrouter(t *testing.T) {
   277  	mwStr := []byte("Middleware\n")
   278  	handlerStr := []byte("Logic\n")
   279  
   280  	router := NewRouter()
   281  	router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
   282  		w.Write(handlerStr)
   283  	})
   284  
   285  	subrouter := router.PathPrefix("/sub/").Subrouter()
   286  	subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
   287  		w.Write(handlerStr)
   288  	})
   289  
   290  	router.Use(func(h http.Handler) http.Handler {
   291  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   292  			w.Write(mwStr)
   293  			h.ServeHTTP(w, r)
   294  		})
   295  	})
   296  
   297  	t.Run("not called", func(t *testing.T) {
   298  		rw := NewRecorder()
   299  		req := newRequest("GET", "/sub/notfound")
   300  
   301  		router.ServeHTTP(rw, req)
   302  		if bytes.Contains(rw.Body.Bytes(), mwStr) {
   303  			t.Fatal("Middleware was called for a 404")
   304  		}
   305  	})
   306  
   307  	t.Run("not called with custom not found handler", func(t *testing.T) {
   308  		rw := NewRecorder()
   309  		req := newRequest("GET", "/sub/notfound")
   310  
   311  		subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   312  			rw.Write([]byte("Custom 404 handler"))
   313  		})
   314  		router.ServeHTTP(rw, req)
   315  
   316  		if bytes.Contains(rw.Body.Bytes(), mwStr) {
   317  			t.Fatal("Middleware was called for a custom 404")
   318  		}
   319  	})
   320  }
   321  
   322  func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
   323  	mwStr := []byte("Middleware\n")
   324  	handlerStr := []byte("Logic\n")
   325  
   326  	router := NewRouter()
   327  	router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
   328  		w.Write(handlerStr)
   329  	})
   330  
   331  	subrouter := router.PathPrefix("/sub/").Subrouter()
   332  	subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
   333  		w.Write(handlerStr)
   334  	}).Methods("GET")
   335  
   336  	router.Use(func(h http.Handler) http.Handler {
   337  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   338  			w.Write(mwStr)
   339  			h.ServeHTTP(w, r)
   340  		})
   341  	})
   342  
   343  	t.Run("not called", func(t *testing.T) {
   344  		rw := NewRecorder()
   345  		req := newRequest("POST", "/sub/")
   346  
   347  		router.ServeHTTP(rw, req)
   348  		if bytes.Contains(rw.Body.Bytes(), mwStr) {
   349  			t.Fatal("Middleware was called for a method mismatch")
   350  		}
   351  	})
   352  
   353  	t.Run("not called with custom method not allowed handler", func(t *testing.T) {
   354  		rw := NewRecorder()
   355  		req := newRequest("POST", "/sub/")
   356  
   357  		router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   358  			rw.Write([]byte("Method not allowed"))
   359  		})
   360  		router.ServeHTTP(rw, req)
   361  
   362  		if bytes.Contains(rw.Body.Bytes(), mwStr) {
   363  			t.Fatal("Middleware was called for a method mismatch")
   364  		}
   365  	})
   366  }
   367  
   368  func TestCORSMethodMiddleware(t *testing.T) {
   369  	testCases := []struct {
   370  		name                                    string
   371  		registerRoutes                          func(r *Router)
   372  		requestHeader                           http.Header
   373  		requestMethod                           string
   374  		requestPath                             string
   375  		expectedAccessControlAllowMethodsHeader string
   376  		expectedResponse                        string
   377  	}{
   378  		{
   379  			name: "does not set without OPTIONS matcher",
   380  			registerRoutes: func(r *Router) {
   381  				r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
   382  			},
   383  			requestMethod:                           "GET",
   384  			requestPath:                             "/foo",
   385  			expectedAccessControlAllowMethodsHeader: "",
   386  			expectedResponse:                        "a",
   387  		},
   388  		{
   389  			name: "sets on non OPTIONS",
   390  			registerRoutes: func(r *Router) {
   391  				r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
   392  				r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
   393  			},
   394  			requestMethod:                           "GET",
   395  			requestPath:                             "/foo",
   396  			expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
   397  			expectedResponse:                        "a",
   398  		},
   399  		{
   400  			name: "sets without preflight headers",
   401  			registerRoutes: func(r *Router) {
   402  				r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
   403  				r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
   404  			},
   405  			requestMethod:                           "OPTIONS",
   406  			requestPath:                             "/foo",
   407  			expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
   408  			expectedResponse:                        "b",
   409  		},
   410  		{
   411  			name: "does not set on error",
   412  			registerRoutes: func(r *Router) {
   413  				r.HandleFunc("/foo", stringHandler("a"))
   414  			},
   415  			requestMethod:                           "OPTIONS",
   416  			requestPath:                             "/foo",
   417  			expectedAccessControlAllowMethodsHeader: "",
   418  			expectedResponse:                        "a",
   419  		},
   420  		{
   421  			name: "sets header on valid preflight",
   422  			registerRoutes: func(r *Router) {
   423  				r.HandleFunc("/foo", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
   424  				r.HandleFunc("/foo", stringHandler("b")).Methods(http.MethodOptions)
   425  			},
   426  			requestMethod: "OPTIONS",
   427  			requestPath:   "/foo",
   428  			requestHeader: http.Header{
   429  				"Access-Control-Request-Method":  []string{"GET"},
   430  				"Access-Control-Request-Headers": []string{"Authorization"},
   431  				"Origin":                         []string{"http://example.com"},
   432  			},
   433  			expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
   434  			expectedResponse:                        "b",
   435  		},
   436  		{
   437  			name: "does not set methods from unmatching routes",
   438  			registerRoutes: func(r *Router) {
   439  				r.HandleFunc("/foo", stringHandler("c")).Methods(http.MethodDelete)
   440  				r.HandleFunc("/foo/bar", stringHandler("a")).Methods(http.MethodGet, http.MethodPut, http.MethodPatch)
   441  				r.HandleFunc("/foo/bar", stringHandler("b")).Methods(http.MethodOptions)
   442  			},
   443  			requestMethod: "OPTIONS",
   444  			requestPath:   "/foo/bar",
   445  			requestHeader: http.Header{
   446  				"Access-Control-Request-Method":  []string{"GET"},
   447  				"Access-Control-Request-Headers": []string{"Authorization"},
   448  				"Origin":                         []string{"http://example.com"},
   449  			},
   450  			expectedAccessControlAllowMethodsHeader: "GET,PUT,PATCH,OPTIONS",
   451  			expectedResponse:                        "b",
   452  		},
   453  	}
   454  
   455  	for _, tt := range testCases {
   456  		t.Run(tt.name, func(t *testing.T) {
   457  			router := NewRouter()
   458  
   459  			tt.registerRoutes(router)
   460  
   461  			router.Use(CORSMethodMiddleware(router))
   462  
   463  			rw := NewRecorder()
   464  			req := newRequest(tt.requestMethod, tt.requestPath)
   465  			req.Header = tt.requestHeader
   466  
   467  			router.ServeHTTP(rw, req)
   468  
   469  			actualMethodsHeader := rw.Header().Get("Access-Control-Allow-Methods")
   470  			if actualMethodsHeader != tt.expectedAccessControlAllowMethodsHeader {
   471  				t.Fatalf("Expected Access-Control-Allow-Methods to equal %s but got %s", tt.expectedAccessControlAllowMethodsHeader, actualMethodsHeader)
   472  			}
   473  
   474  			actualResponse := rw.Body.String()
   475  			if actualResponse != tt.expectedResponse {
   476  				t.Fatalf("Expected response to equal %s but got %s", tt.expectedResponse, actualResponse)
   477  			}
   478  		})
   479  	}
   480  }
   481  
   482  func TestCORSMethodMiddlewareSubrouter(t *testing.T) {
   483  	router := NewRouter().StrictSlash(true)
   484  
   485  	subrouter := router.PathPrefix("/test").Subrouter()
   486  	subrouter.HandleFunc("/hello", stringHandler("a")).Methods(http.MethodGet, http.MethodOptions, http.MethodPost)
   487  	subrouter.HandleFunc("/hello/{name}", stringHandler("b")).Methods(http.MethodGet, http.MethodOptions)
   488  
   489  	subrouter.Use(CORSMethodMiddleware(subrouter))
   490  
   491  	rw := NewRecorder()
   492  	req := newRequest("GET", "/test/hello/asdf")
   493  	router.ServeHTTP(rw, req)
   494  
   495  	actualMethods := rw.Header().Get("Access-Control-Allow-Methods")
   496  	expectedMethods := "GET,OPTIONS"
   497  	if actualMethods != expectedMethods {
   498  		t.Fatalf("expected methods %q but got: %q", expectedMethods, actualMethods)
   499  	}
   500  }
   501  
   502  func TestMiddlewareOnMultiSubrouter(t *testing.T) {
   503  	first := "first"
   504  	second := "second"
   505  	notFound := "404 not found"
   506  
   507  	router := NewRouter()
   508  	firstSubRouter := router.PathPrefix("/").Subrouter()
   509  	secondSubRouter := router.PathPrefix("/").Subrouter()
   510  
   511  	router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   512  		rw.Write([]byte(notFound))
   513  	})
   514  
   515  	firstSubRouter.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) {
   516  
   517  	})
   518  
   519  	secondSubRouter.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) {
   520  
   521  	})
   522  
   523  	firstSubRouter.Use(func(h http.Handler) http.Handler {
   524  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   525  			w.Write([]byte(first))
   526  			h.ServeHTTP(w, r)
   527  		})
   528  	})
   529  
   530  	secondSubRouter.Use(func(h http.Handler) http.Handler {
   531  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   532  			w.Write([]byte(second))
   533  			h.ServeHTTP(w, r)
   534  		})
   535  	})
   536  
   537  	t.Run("/first uses first middleware", func(t *testing.T) {
   538  		rw := NewRecorder()
   539  		req := newRequest("GET", "/first")
   540  
   541  		router.ServeHTTP(rw, req)
   542  		if rw.Body.String() != first {
   543  			t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", first, rw.Body.String())
   544  		}
   545  	})
   546  
   547  	t.Run("/second uses second middleware", func(t *testing.T) {
   548  		rw := NewRecorder()
   549  		req := newRequest("GET", "/second")
   550  
   551  		router.ServeHTTP(rw, req)
   552  		if rw.Body.String() != second {
   553  			t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", second, rw.Body.String())
   554  		}
   555  	})
   556  
   557  	t.Run("uses not found handler", func(t *testing.T) {
   558  		rw := NewRecorder()
   559  		req := newRequest("GET", "/second/not-exist")
   560  
   561  		router.ServeHTTP(rw, req)
   562  		if rw.Body.String() != notFound {
   563  			t.Fatalf("Notfound handler did not run: expected %s for not-exist, (got %s)", notFound, rw.Body.String())
   564  		}
   565  	})
   566  }