github.com/go-board/x-go@v0.1.2-0.20220610024734-db1323f6cb15/xnet/xhttp/middleware_test.go (about) 1 package xhttp 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "net/http" 9 "net/http/httptest" 10 "net/url" 11 "testing" 12 "time" 13 14 "github.com/stretchr/testify/require" 15 ) 16 17 func logHandler(w io.StringWriter) Middleware { 18 return MiddlewareFn(func(h http.Handler) http.Handler { 19 return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { 20 now := time.Now() 21 h.ServeHTTP(writer, request) 22 w.WriteString(fmt.Sprintf("request method: %s, path: %s, latency: %s\n", request.Method, request.URL.Path, time.Since(now))) 23 }) 24 }) 25 } 26 27 func responseHandler(body string) Middleware { 28 return MiddlewareFn(func(h http.Handler) http.Handler { 29 return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { 30 io.WriteString(writer, body) 31 }) 32 }) 33 } 34 35 func TestMiddleware(t *testing.T) { 36 httpHandler := func(writer http.ResponseWriter, request *http.Request) { 37 time.Sleep(time.Millisecond * 100) 38 } 39 40 b := &bytes.Buffer{} 41 logHandler(b).Next(http.HandlerFunc(httpHandler)).ServeHTTP(nil, &http.Request{Method: http.MethodGet, URL: &url.URL{Path: "/v2/api/user/info"}}) 42 require.Contains(t, b.String(), http.MethodGet) 43 require.Contains(t, b.String(), "/v2/api/user/info") 44 } 45 46 func TestChainedMiddleware(t *testing.T) { 47 httpHandler := func(writer http.ResponseWriter, request *http.Request) { 48 time.Sleep(time.Millisecond * 100) 49 } 50 b := &bytes.Buffer{} 51 resp := httptest.NewRecorder() 52 logIt := logHandler(b).Next(http.HandlerFunc(httpHandler)) 53 respIt := responseHandler("Hello,world").Next(logIt) 54 respIt.ServeHTTP(resp, &http.Request{Method: http.MethodGet, URL: &url.URL{Path: "/v2/api/user/info"}}) 55 body, err := ioutil.ReadAll(resp.Body) 56 require.Nil(t, err, "err must be nil") 57 require.Equal(t, "Hello,world", string(body)) 58 }