github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/core/middleware/chain_test.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package middleware_test
     8  
     9  import (
    10  	"net/http"
    11  	"net/http/httptest"
    12  
    13  	"github.com/hechain20/hechain/core/middleware"
    14  	. "github.com/onsi/ginkgo"
    15  	. "github.com/onsi/gomega"
    16  )
    17  
    18  var _ = Describe("Chain", func() {
    19  	var (
    20  		one, two, three middleware.Middleware
    21  		chain           middleware.Chain
    22  
    23  		hello http.Handler
    24  
    25  		req  *http.Request
    26  		resp *httptest.ResponseRecorder
    27  	)
    28  
    29  	BeforeEach(func() {
    30  		one = func(next http.Handler) http.Handler {
    31  			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    32  				w.Write([]byte("1:before,"))
    33  				next.ServeHTTP(w, r)
    34  				w.Write([]byte("1:after"))
    35  			})
    36  		}
    37  		two = func(next http.Handler) http.Handler {
    38  			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    39  				w.Write([]byte("2:before,"))
    40  				next.ServeHTTP(w, r)
    41  				w.Write([]byte("2:after,"))
    42  			})
    43  		}
    44  		three = func(next http.Handler) http.Handler {
    45  			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    46  				w.Write([]byte("3:before,"))
    47  				next.ServeHTTP(w, r)
    48  				w.Write([]byte("3:after,"))
    49  			})
    50  		}
    51  		chain = middleware.NewChain(one, two, three)
    52  
    53  		hello = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    54  			w.WriteHeader(http.StatusOK)
    55  			w.Write([]byte("Hello!,"))
    56  		})
    57  
    58  		req = httptest.NewRequest("GET", "/", nil)
    59  		resp = httptest.NewRecorder()
    60  	})
    61  
    62  	It("calls middleware in the specified order", func() {
    63  		chain.Handler(hello).ServeHTTP(resp, req)
    64  		Expect(resp.Body.String()).To(Equal("1:before,2:before,3:before,Hello!,3:after,2:after,1:after"))
    65  	})
    66  
    67  	Context("when the chain is empty", func() {
    68  		BeforeEach(func() {
    69  			chain = middleware.NewChain()
    70  		})
    71  
    72  		It("calls the handler", func() {
    73  			chain.Handler(hello).ServeHTTP(resp, req)
    74  			Expect(resp.Body.String()).To(Equal("Hello!,"))
    75  		})
    76  	})
    77  
    78  	Context("when the handler is nil", func() {
    79  		It("uses the DefaultServerMux", func() {
    80  			chain.Handler(nil).ServeHTTP(resp, req)
    81  			Expect(resp.Body.String()).To(ContainSubstring("404 page not found"))
    82  		})
    83  	})
    84  })