github.com/mendersoftware/go-lib-micro@v0.0.0-20240304135804-e8e39c59b148/requestid/middleware_test.go (about) 1 // Copyright 2023 Northern.tech AS 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package requestid 16 17 import ( 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 22 "github.com/ant0ine/go-json-rest/rest" 23 "github.com/ant0ine/go-json-rest/rest/test" 24 "github.com/gin-gonic/gin" 25 "github.com/google/uuid" 26 "github.com/stretchr/testify/assert" 27 ) 28 29 func init() { 30 gin.SetMode(gin.ReleaseMode) // please just shut up 31 } 32 33 func TestGinMiddleware(t *testing.T) { 34 t.Parallel() 35 testCases := []struct { 36 Name string 37 38 Options *MiddlewareOptions 39 40 Headers http.Header 41 }{{ 42 Name: "Request with ID", 43 44 Headers: func() http.Header { 45 hdr := http.Header{} 46 hdr.Set(RequestIdHeader, "test") 47 return hdr 48 }(), 49 }, { 50 Name: "Request generated ID", 51 52 Options: NewMiddlewareOptions(). 53 SetGenerateRequestID(true), 54 }} 55 56 for i := range testCases { 57 tc := testCases[i] 58 t.Run(tc.Name, func(t *testing.T) { 59 t.Parallel() 60 61 router := gin.New() 62 router.Use(Middleware(tc.Options)) 63 router.GET("/test", func(c *gin.Context) {}) 64 65 w := httptest.NewRecorder() 66 req, _ := http.NewRequest("GET", "http://mender.io/test", nil) 67 for k, v := range tc.Headers { 68 for _, vv := range v { 69 req.Header.Add(k, vv) 70 } 71 } 72 router.ServeHTTP(w, req) 73 74 rsp := w.Result() 75 76 if id := tc.Headers.Get(RequestIdHeader); id != "" { 77 rspID := rsp.Header.Get(RequestIdHeader) 78 assert.Equal(t, id, rspID) 79 } else { 80 if tc.Options.GenerateRequestID != nil && 81 *tc.Options.GenerateRequestID { 82 _, err := uuid.Parse(rsp.Header.Get(RequestIdHeader)) 83 assert.NoError(t, err, "Generated requestID is not a UUID") 84 } else { 85 assert.Empty(t, rsp.Header.Get(RequestIdHeader)) 86 } 87 } 88 }) 89 } 90 } 91 92 func TestRequestIdMiddlewareWithReqID(t *testing.T) { 93 api := rest.NewApi() 94 95 api.Use(&RequestIdMiddleware{}) 96 97 reqid := "4420a5b9-dbf2-4e5d-8b4f-3cf2013d04af" 98 api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) { 99 assert.Equal(t, reqid, FromContext(r.Context())) 100 w.WriteJson(map[string]string{"foo": "bar"}) 101 })) 102 103 handler := api.MakeHandler() 104 105 req := test.MakeSimpleRequest("GET", "http://localhost/", nil) 106 req.Header.Set(RequestIdHeader, reqid) 107 108 recorded := test.RunRequest(t, handler, req) 109 recorded.CodeIs(200) 110 recorded.ContentTypeIsJson() 111 recorded.HeaderIs(RequestIdHeader, reqid) 112 113 } 114 115 func TestRequestIdMiddlewareNoReqID(t *testing.T) { 116 api := rest.NewApi() 117 118 api.Use(&RequestIdMiddleware{}) 119 120 api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) { 121 reqid := FromContext(r.Context()) 122 _, err := uuid.Parse(reqid) 123 assert.NoError(t, err) 124 w.WriteJson(map[string]string{"foo": "bar"}) 125 })) 126 127 handler := api.MakeHandler() 128 129 req := test.MakeSimpleRequest("GET", "http://localhost/", nil) 130 recorded := test.RunRequest(t, handler, req) 131 recorded.CodeIs(200) 132 recorded.ContentTypeIsJson() 133 outReqIdStr := recorded.Recorder.HeaderMap.Get(RequestIdHeader) 134 _, err := uuid.Parse(outReqIdStr) 135 assert.NoError(t, err) 136 }