github.com/go-board/x-go@v0.1.2-0.20220610024734-db1323f6cb15/xnet/xhttp/middleware_test.go (about)

     1  package xhttp
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  func logHandler(w io.StringWriter) Middleware {
    18  	return MiddlewareFn(func(h http.Handler) http.Handler {
    19  		return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
    20  			now := time.Now()
    21  			h.ServeHTTP(writer, request)
    22  			w.WriteString(fmt.Sprintf("request method: %s, path: %s, latency: %s\n", request.Method, request.URL.Path, time.Since(now)))
    23  		})
    24  	})
    25  }
    26  
    27  func responseHandler(body string) Middleware {
    28  	return MiddlewareFn(func(h http.Handler) http.Handler {
    29  		return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
    30  			io.WriteString(writer, body)
    31  		})
    32  	})
    33  }
    34  
    35  func TestMiddleware(t *testing.T) {
    36  	httpHandler := func(writer http.ResponseWriter, request *http.Request) {
    37  		time.Sleep(time.Millisecond * 100)
    38  	}
    39  
    40  	b := &bytes.Buffer{}
    41  	logHandler(b).Next(http.HandlerFunc(httpHandler)).ServeHTTP(nil, &http.Request{Method: http.MethodGet, URL: &url.URL{Path: "/v2/api/user/info"}})
    42  	require.Contains(t, b.String(), http.MethodGet)
    43  	require.Contains(t, b.String(), "/v2/api/user/info")
    44  }
    45  
    46  func TestChainedMiddleware(t *testing.T) {
    47  	httpHandler := func(writer http.ResponseWriter, request *http.Request) {
    48  		time.Sleep(time.Millisecond * 100)
    49  	}
    50  	b := &bytes.Buffer{}
    51  	resp := httptest.NewRecorder()
    52  	logIt := logHandler(b).Next(http.HandlerFunc(httpHandler))
    53  	respIt := responseHandler("Hello,world").Next(logIt)
    54  	respIt.ServeHTTP(resp, &http.Request{Method: http.MethodGet, URL: &url.URL{Path: "/v2/api/user/info"}})
    55  	body, err := ioutil.ReadAll(resp.Body)
    56  	require.Nil(t, err, "err must be nil")
    57  	require.Equal(t, "Hello,world", string(body))
    58  }