github.com/alwitt/goutils@v0.6.4/rest_test.go (about) 1 package goutils_test 2 3 import ( 4 "bufio" 5 "context" 6 "encoding/json" 7 "fmt" 8 "math/rand" 9 "net/http" 10 "net/http/httptest" 11 "sync" 12 "testing" 13 "time" 14 15 "github.com/alwitt/goutils" 16 "github.com/apex/log" 17 "github.com/google/uuid" 18 "github.com/gorilla/mux" 19 "github.com/stretchr/testify/assert" 20 ) 21 22 func TestRestAPIHandlerRequestIDInjection(t *testing.T) { 23 assert := assert.New(t) 24 log.SetLevel(log.DebugLevel) 25 26 // Case 0: no user request ID header defined 27 uutNoUserRequestIDHeader := goutils.RestAPIHandler{ 28 Component: goutils.Component{ 29 LogTags: log.Fields{"entity": "unit-tester"}, 30 LogTagModifiers: []goutils.LogMetadataModifier{ 31 goutils.ModifyLogMetadataByRestRequestParam, 32 }, 33 }, 34 CallRequestIDHeaderField: nil, 35 LogLevel: goutils.HTTPLogLevelDEBUG, 36 } 37 { 38 rid := uuid.New().String() 39 req, err := http.NewRequest("GET", "/testing", nil) 40 assert.Nil(err) 41 req.Header.Add("Request-ID", rid) 42 43 dummyHandler := func(w http.ResponseWriter, r *http.Request) { 44 callContext := r.Context() 45 assert.NotNil(callContext.Value(goutils.RestRequestParamKey{})) 46 v, ok := callContext.Value(goutils.RestRequestParamKey{}).(goutils.RestRequestParam) 47 assert.True(ok) 48 assert.NotEqual(rid, v.ID) 49 assert.Equal("GET", v.Method) 50 assert.Equal("/testing", v.URI) 51 } 52 53 router := mux.NewRouter() 54 respRecorder := httptest.NewRecorder() 55 router.HandleFunc("/testing", uutNoUserRequestIDHeader.LoggingMiddleware(dummyHandler)) 56 router.ServeHTTP(respRecorder, req) 57 58 assert.Equal(http.StatusOK, respRecorder.Code) 59 assert.Equal("", (respRecorder.Header().Get("Request-ID"))) 60 } 61 62 // Case 1: user request ID header defined 63 testReqIDHeader := uuid.New().String() 64 uutWithUserRequestIDHeader := goutils.RestAPIHandler{ 65 Component: goutils.Component{ 66 LogTags: log.Fields{"entity": "unit-tester"}, 67 LogTagModifiers: []goutils.LogMetadataModifier{ 68 goutils.ModifyLogMetadataByRestRequestParam, 69 }, 70 }, 71 CallRequestIDHeaderField: &testReqIDHeader, 72 LogLevel: goutils.HTTPLogLevelINFO, 73 } 74 { 75 rid := uuid.New().String() 76 req, err := http.NewRequest("DELETE", "/testing2", nil) 77 assert.Nil(err) 78 req.Header.Add(testReqIDHeader, rid) 79 80 dummyHandler := func(w http.ResponseWriter, r *http.Request) { 81 callContext := r.Context() 82 assert.NotNil(callContext.Value(goutils.RestRequestParamKey{})) 83 v, ok := callContext.Value(goutils.RestRequestParamKey{}).(goutils.RestRequestParam) 84 assert.True(ok) 85 assert.Equal(rid, v.ID) 86 assert.Equal("DELETE", v.Method) 87 assert.Equal("/testing2", v.URI) 88 } 89 90 router := mux.NewRouter() 91 respRecorder := httptest.NewRecorder() 92 router.HandleFunc("/testing2", uutWithUserRequestIDHeader.LoggingMiddleware(dummyHandler)) 93 router.ServeHTTP(respRecorder, req) 94 95 assert.Equal(http.StatusOK, respRecorder.Code) 96 assert.Equal(rid, (respRecorder.Header().Get(testReqIDHeader))) 97 } 98 } 99 100 func TestRestAPIHandlerRequestLogging(t *testing.T) { 101 assert := assert.New(t) 102 log.SetLevel(log.DebugLevel) 103 104 uut := goutils.RestAPIHandler{ 105 Component: goutils.Component{ 106 LogTags: log.Fields{"entity": "unit-tester"}, 107 LogTagModifiers: []goutils.LogMetadataModifier{ 108 goutils.ModifyLogMetadataByRestRequestParam, 109 }, 110 }, 111 DoNotLogHeaders: map[string]bool{"Not-Allowed": true}, 112 LogLevel: goutils.HTTPLogLevelDEBUG, 113 } 114 { 115 value1 := uuid.New().String() 116 value2 := uuid.New().String() 117 req, err := http.NewRequest("GET", "/testing", nil) 118 assert.Nil(err) 119 req.Header.Add("Allowed", value1) 120 req.Header.Add("Not-Allowed", value2) 121 122 dummyHandler := func(w http.ResponseWriter, r *http.Request) { 123 callContext := r.Context() 124 assert.NotNil(callContext.Value(goutils.RestRequestParamKey{})) 125 v, ok := callContext.Value(goutils.RestRequestParamKey{}).(goutils.RestRequestParam) 126 assert.True(ok) 127 assert.Equal("GET", v.Method) 128 assert.Equal("/testing", v.URI) 129 assert.Equal(value1, v.RequestHeaders.Get("Allowed")) 130 assert.Equal("", v.RequestHeaders.Get("Not-Allowed")) 131 } 132 133 router := mux.NewRouter() 134 respRecorder := httptest.NewRecorder() 135 router.HandleFunc("/testing", uut.LoggingMiddleware(dummyHandler)) 136 router.ServeHTTP(respRecorder, req) 137 138 assert.Equal(http.StatusOK, respRecorder.Code) 139 } 140 } 141 142 func TestRestAPIHandlerProcessStreamingEndpoints(t *testing.T) { 143 assert := assert.New(t) 144 log.SetLevel(log.DebugLevel) 145 146 testReqIDHeader := uuid.New().String() 147 uut := goutils.RestAPIHandler{ 148 Component: goutils.Component{ 149 LogTags: log.Fields{"entity": "unit-tester"}, 150 LogTagModifiers: []goutils.LogMetadataModifier{ 151 goutils.ModifyLogMetadataByRestRequestParam, 152 }, 153 }, 154 CallRequestIDHeaderField: &testReqIDHeader, 155 LogLevel: goutils.HTTPLogLevelDEBUG, 156 } 157 158 type testMessage struct { 159 Timestamp time.Time 160 Msg string 161 } 162 163 testMsgTX := make(chan testMessage, 1) 164 testMsgRX := make(chan testMessage, 1) 165 166 wg := sync.WaitGroup{} 167 defer wg.Wait() 168 utCtxt, ctxtCancel := context.WithCancel(context.Background()) 169 defer ctxtCancel() 170 171 // Define streaming data handler 172 testHandler := func(w http.ResponseWriter, r *http.Request) { 173 flusher, ok := w.(http.Flusher) 174 assert.True(ok) 175 w.Header().Set("Content-Type", "text/event-stream") 176 w.Header().Set("Cache-Control", "no-cache") 177 w.Header().Set("Connection", "keep-alive") 178 w.Header().Set("Access-Control-Allow-Origin", "*") 179 180 log.Debug("Starting stream response handler") 181 complete := false 182 for !complete { 183 select { 184 case <-utCtxt.Done(): 185 complete = true 186 case msg, ok := <-testMsgTX: 187 assert.True(ok) 188 t, err := json.Marshal(&msg) 189 assert.Nil(err) 190 fmt.Fprintf(w, "%s\n", t) 191 flusher.Flush() 192 log.Debugf("Sent %s\n", t) 193 } 194 } 195 log.Debug("Stoping stream response handler") 196 } 197 198 router := mux.NewRouter() 199 router.HandleFunc("/testing", uut.LoggingMiddleware(testHandler)) 200 201 // Define HTTP server 202 testServerPort := rand.Intn(30000) + 32769 203 testServerListen := fmt.Sprintf("127.0.0.1:%d", testServerPort) 204 testServer := &http.Server{ 205 Addr: testServerListen, 206 Handler: router, 207 } 208 // Start the HTTP server 209 log.Debugf("Starting test server on %s", testServerListen) 210 wg.Add(1) 211 go func() { 212 defer wg.Done() 213 if err := testServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { 214 assert.Nil(err) 215 } 216 log.Debugf("Stopped test server on %s", testServerListen) 217 }() 218 defer func() { 219 // Helper function to shutdown the server 220 ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) 221 defer cancel() 222 if err := testServer.Shutdown(ctx); err != nil { 223 assert.Nil(err) 224 } 225 }() 226 227 // Define test HTTP client 228 testClient := http.Client{} 229 req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/testing", testServerListen), nil) 230 assert.Nil(err) 231 testRID := uuid.New().String() 232 req.Header.Add(testReqIDHeader, testRID) 233 234 // Make the request in another thread 235 wg.Add(1) 236 go func() { 237 defer wg.Done() 238 var resp *http.Response 239 var err error 240 for i := 0; i < 3; i++ { 241 log.Debug("Connecting to test server") 242 resp, err = testClient.Do(req) 243 if err == nil { 244 break 245 } 246 time.Sleep(time.Millisecond * 25) 247 } 248 log.Debugf("Connected to test server http://%s/testing", testServerListen) 249 assert.Nil(err) 250 assert.Equal(http.StatusOK, resp.StatusCode) 251 assert.Equal(testRID, resp.Header.Get(testReqIDHeader)) 252 // Process the resp stream 253 scanner := bufio.NewScanner(resp.Body) 254 scanner.Split(bufio.ScanLines) 255 log.Debug("Scanning SSE stream") 256 for scanner.Scan() { 257 received := scanner.Text() 258 log.Debugf("Received: %s", received) 259 var parsed testMessage 260 assert.Nil(json.Unmarshal([]byte(received), &parsed)) 261 testMsgRX <- parsed 262 } 263 log.Debug("Stopped scanner") 264 }() 265 266 // Send message multiple times 267 for i := 0; i < 4; i++ { 268 newMsg := testMessage{Timestamp: time.Now(), Msg: uuid.New().String()} 269 testMsgTX <- newMsg 270 ctxt, lclCancel := context.WithTimeout(utCtxt, time.Millisecond*100) 271 defer lclCancel() 272 select { 273 case <-ctxt.Done(): 274 assert.Nil(ctxt.Err()) 275 case rx, ok := <-testMsgRX: 276 assert.True(ok) 277 assert.Equal(newMsg.Timestamp.UnixMicro(), rx.Timestamp.UnixMicro()) 278 assert.Equal(newMsg.Msg, rx.Msg) 279 } 280 } 281 282 // Allow for clean shutdown 283 ctxtCancel() 284 }