github.com/zak-blake/goa@v1.4.1/middleware/request_id_test.go (about)

     1  package middleware_test
     2  
     3  import (
     4  	"net/http"
     5  	"net/url"
     6  
     7  	"context"
     8  
     9  	"github.com/goadesign/goa"
    10  	"github.com/goadesign/goa/middleware"
    11  	. "github.com/onsi/ginkgo"
    12  	. "github.com/onsi/gomega"
    13  )
    14  
    15  var _ = Describe("RequestID", func() {
    16  	const reqID = "request id"
    17  	var ctx context.Context
    18  	var rw http.ResponseWriter
    19  	var req *http.Request
    20  	var params url.Values
    21  	var service *goa.Service
    22  
    23  	BeforeEach(func() {
    24  		service = newService(nil)
    25  
    26  		var err error
    27  		req, err = http.NewRequest("GET", "/goo", nil)
    28  		Ω(err).ShouldNot(HaveOccurred())
    29  		req.Header.Set(middleware.RequestIDHeader, reqID)
    30  		rw = new(testResponseWriter)
    31  		params = url.Values{"query": []string{"value"}}
    32  		service.Encoder.Register(goa.NewJSONEncoder, "*/*")
    33  		ctx = newContext(service, rw, req, params)
    34  	})
    35  
    36  	It("sets the request ID in the context", func() {
    37  		var newCtx context.Context
    38  		h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
    39  			newCtx = ctx
    40  			return service.Send(ctx, 200, "ok")
    41  		}
    42  		rg := middleware.RequestID()(h)
    43  		Ω(rg(ctx, rw, req)).ShouldNot(HaveOccurred())
    44  		Ω(middleware.ContextRequestID(newCtx)).Should(Equal(reqID))
    45  	})
    46  
    47  	It("truncates request ID when it exceeds a default limit", func() {
    48  		var newCtx context.Context
    49  		h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
    50  			newCtx = ctx
    51  			return service.Send(ctx, 200, "ok")
    52  		}
    53  		tooLong := makeRequestID(2 * middleware.DefaultRequestIDLengthLimit)
    54  		expected := makeRequestID(middleware.DefaultRequestIDLengthLimit)
    55  		req.Header.Set(middleware.RequestIDHeader, tooLong)
    56  		rg := middleware.RequestID()(h)
    57  		Ω(rg(ctx, rw, req)).ShouldNot(HaveOccurred())
    58  		Ω(middleware.ContextRequestID(newCtx)).Should(Equal(expected))
    59  	})
    60  
    61  	It("sets the request ID in the context for a custom header and limit", func() {
    62  		var newCtx context.Context
    63  		h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
    64  			newCtx = ctx
    65  			return service.Send(ctx, 200, "ok")
    66  		}
    67  		req.Header.Set("Foo", "abcdefghij")
    68  		rg := middleware.RequestIDWithHeaderAndLengthLimit("Foo", 7)(h)
    69  		Ω(rg(ctx, rw, req)).ShouldNot(HaveOccurred())
    70  		Ω(middleware.ContextRequestID(newCtx)).Should(Equal("abcdefg"))
    71  	})
    72  
    73  	It("allows any request ID when length limit is negative", func() {
    74  		var newCtx context.Context
    75  		h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
    76  			newCtx = ctx
    77  			return service.Send(ctx, 200, "ok")
    78  		}
    79  		original := makeRequestID(2 * middleware.DefaultRequestIDLengthLimit)
    80  		req.Header.Set(middleware.RequestIDHeader, string(original))
    81  		rg := middleware.RequestIDWithHeaderAndLengthLimit(middleware.RequestIDHeader, -1)(h)
    82  		Ω(rg(ctx, rw, req)).ShouldNot(HaveOccurred())
    83  		Ω(middleware.ContextRequestID(newCtx)).Should(Equal(string(original)))
    84  	})
    85  
    86  })
    87  
    88  func makeRequestID(length int) string {
    89  	buffer := make([]byte, length)
    90  	for i := range buffer {
    91  		buffer[i] = 'x'
    92  	}
    93  	return string(buffer)
    94  }