github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/core/middleware/request_id_test.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package middleware_test
     8  
     9  import (
    10  	"net/http"
    11  	"net/http/httptest"
    12  
    13  	"github.com/hechain20/hechain/core/middleware"
    14  	"github.com/hechain20/hechain/core/middleware/fakes"
    15  	. "github.com/onsi/ginkgo"
    16  	. "github.com/onsi/gomega"
    17  )
    18  
    19  var _ = Describe("WithRequestID", func() {
    20  	var (
    21  		requestID middleware.Middleware
    22  		handler   *fakes.HTTPHandler
    23  		chain     http.Handler
    24  
    25  		req  *http.Request
    26  		resp *httptest.ResponseRecorder
    27  	)
    28  
    29  	BeforeEach(func() {
    30  		handler = &fakes.HTTPHandler{}
    31  		requestID = middleware.WithRequestID(
    32  			middleware.GenerateIDFunc(func() string { return "generated-id" }),
    33  		)
    34  		chain = requestID(handler)
    35  
    36  		req = httptest.NewRequest("GET", "/", nil)
    37  		resp = httptest.NewRecorder()
    38  	})
    39  
    40  	It("propagates the generated request ID in the request context", func() {
    41  		chain.ServeHTTP(resp, req)
    42  		_, r := handler.ServeHTTPArgsForCall(0)
    43  		requestID := middleware.RequestID(r.Context())
    44  		Expect(requestID).To(Equal("generated-id"))
    45  	})
    46  
    47  	It("returns the generated request ID in a header", func() {
    48  		chain.ServeHTTP(resp, req)
    49  		Expect(resp.Result().Header.Get("X-Request-Id")).To(Equal("generated-id"))
    50  	})
    51  
    52  	Context("when a request ID is already present", func() {
    53  		BeforeEach(func() {
    54  			req.Header.Set("X-Request-Id", "received-id")
    55  		})
    56  
    57  		It("sets the received ID into the context", func() {
    58  			chain.ServeHTTP(resp, req)
    59  			_, r := handler.ServeHTTPArgsForCall(0)
    60  			requestID := middleware.RequestID(r.Context())
    61  			Expect(requestID).To(Equal("received-id"))
    62  		})
    63  
    64  		It("sets the received ID into the request", func() {
    65  			chain.ServeHTTP(resp, req)
    66  			_, r := handler.ServeHTTPArgsForCall(0)
    67  			Expect(r.Header.Get("X-Request-Id")).To(Equal("received-id"))
    68  		})
    69  
    70  		It("propagates the request ID to the response", func() {
    71  			chain.ServeHTTP(resp, req)
    72  			Expect(resp.Result().Header.Get("X-Request-Id")).To(Equal("received-id"))
    73  		})
    74  	})
    75  })