github.com/zak-blake/goa@v1.4.1/middleware/request_id_test.go (about) 1 package middleware_test 2 3 import ( 4 "net/http" 5 "net/url" 6 7 "context" 8 9 "github.com/goadesign/goa" 10 "github.com/goadesign/goa/middleware" 11 . "github.com/onsi/ginkgo" 12 . "github.com/onsi/gomega" 13 ) 14 15 var _ = Describe("RequestID", func() { 16 const reqID = "request id" 17 var ctx context.Context 18 var rw http.ResponseWriter 19 var req *http.Request 20 var params url.Values 21 var service *goa.Service 22 23 BeforeEach(func() { 24 service = newService(nil) 25 26 var err error 27 req, err = http.NewRequest("GET", "/goo", nil) 28 Ω(err).ShouldNot(HaveOccurred()) 29 req.Header.Set(middleware.RequestIDHeader, reqID) 30 rw = new(testResponseWriter) 31 params = url.Values{"query": []string{"value"}} 32 service.Encoder.Register(goa.NewJSONEncoder, "*/*") 33 ctx = newContext(service, rw, req, params) 34 }) 35 36 It("sets the request ID in the context", func() { 37 var newCtx context.Context 38 h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { 39 newCtx = ctx 40 return service.Send(ctx, 200, "ok") 41 } 42 rg := middleware.RequestID()(h) 43 Ω(rg(ctx, rw, req)).ShouldNot(HaveOccurred()) 44 Ω(middleware.ContextRequestID(newCtx)).Should(Equal(reqID)) 45 }) 46 47 It("truncates request ID when it exceeds a default limit", func() { 48 var newCtx context.Context 49 h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { 50 newCtx = ctx 51 return service.Send(ctx, 200, "ok") 52 } 53 tooLong := makeRequestID(2 * middleware.DefaultRequestIDLengthLimit) 54 expected := makeRequestID(middleware.DefaultRequestIDLengthLimit) 55 req.Header.Set(middleware.RequestIDHeader, tooLong) 56 rg := middleware.RequestID()(h) 57 Ω(rg(ctx, rw, req)).ShouldNot(HaveOccurred()) 58 Ω(middleware.ContextRequestID(newCtx)).Should(Equal(expected)) 59 }) 60 61 It("sets the request ID in the context for a custom header and limit", func() { 62 var newCtx context.Context 63 h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { 64 newCtx = ctx 65 return service.Send(ctx, 200, "ok") 66 } 67 req.Header.Set("Foo", "abcdefghij") 68 rg := middleware.RequestIDWithHeaderAndLengthLimit("Foo", 7)(h) 69 Ω(rg(ctx, rw, req)).ShouldNot(HaveOccurred()) 70 Ω(middleware.ContextRequestID(newCtx)).Should(Equal("abcdefg")) 71 }) 72 73 It("allows any request ID when length limit is negative", func() { 74 var newCtx context.Context 75 h := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error { 76 newCtx = ctx 77 return service.Send(ctx, 200, "ok") 78 } 79 original := makeRequestID(2 * middleware.DefaultRequestIDLengthLimit) 80 req.Header.Set(middleware.RequestIDHeader, string(original)) 81 rg := middleware.RequestIDWithHeaderAndLengthLimit(middleware.RequestIDHeader, -1)(h) 82 Ω(rg(ctx, rw, req)).ShouldNot(HaveOccurred()) 83 Ω(middleware.ContextRequestID(newCtx)).Should(Equal(string(original))) 84 }) 85 86 }) 87 88 func makeRequestID(length int) string { 89 buffer := make([]byte, length) 90 for i := range buffer { 91 buffer[i] = 'x' 92 } 93 return string(buffer) 94 }