github.com/epsagon/epsagon-go@v1.39.0/wrappers/net/http/client_test.go (about) 1 package epsagonhttp 2 3 import ( 4 "bytes" 5 "fmt" 6 "io/ioutil" 7 "net/http" 8 "net/http/httptest" 9 "strings" 10 "testing" 11 12 "github.com/epsagon/epsagon-go/epsagon" 13 "github.com/epsagon/epsagon-go/protocol" 14 "github.com/epsagon/epsagon-go/tracer" 15 . "github.com/onsi/ginkgo" 16 . "github.com/onsi/gomega" 17 ) 18 19 const TEST_RESPONSE_STRING = "response_test_string" 20 21 func TestEpsagonHTTPWrappers(t *testing.T) { 22 RegisterFailHandler(Fail) 23 RunSpecs(t, "epsagon http wrapper suite") 24 } 25 26 func verifyTraceIDExists(event *protocol.Event) { 27 traceID, ok := event.Resource.Metadata[tracer.EpsagonHTTPTraceIDKey] 28 Expect(ok).To(BeTrue()) 29 Expect(traceID).To(Not(BeZero())) 30 } 31 32 func verifyTraceIDNotExists(event *protocol.Event) { 33 Expect(event.Resource.Metadata).NotTo( 34 HaveKey(tracer.EpsagonHTTPTraceIDKey)) 35 } 36 37 func verifyResponseSuccess(response *http.Response, err error) { 38 Expect(err).To(BeNil()) 39 defer response.Body.Close() 40 responseData, err := ioutil.ReadAll(response.Body) 41 Expect(err).To(BeNil()) 42 responseString := string(responseData) 43 Expect(responseString).To(Equal(TEST_RESPONSE_STRING)) 44 } 45 46 type mockTransport struct { 47 called bool 48 } 49 50 func (m *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { 51 m.called = true 52 return http.DefaultTransport.RoundTrip(req) 53 } 54 55 var _ = Describe("ClientWrapper", func() { 56 var ( 57 events []*protocol.Event 58 exceptions []*protocol.Exception 59 requests []*http.Request 60 testServer *httptest.Server 61 response_data []byte 62 ) 63 BeforeEach(func() { 64 requests = make([]*http.Request, 0) 65 events = make([]*protocol.Event, 0) 66 exceptions = make([]*protocol.Exception, 0) 67 response_data = []byte(TEST_RESPONSE_STRING) 68 tracer.GlobalTracer = &tracer.MockedEpsagonTracer{ 69 Events: &events, 70 Exceptions: &exceptions, 71 Config: &tracer.Config{}, 72 } 73 testServer = httptest.NewServer(http.HandlerFunc( 74 func(res http.ResponseWriter, req *http.Request) { 75 requests = append(requests, req) 76 res.Write(response_data) 77 })) 78 }) 79 AfterEach(func() { 80 tracer.GlobalTracer = nil 81 testServer.Close() 82 }) 83 84 Describe(".Do", func() { 85 BeforeEach(func() { 86 events = make([]*protocol.Event, 0) 87 exceptions = make([]*protocol.Exception, 0) 88 requests = make([]*http.Request, 0) 89 }) 90 Context("sending a request to existing server", func() { 91 It("adds an event with no error", func() { 92 client := Wrap(http.Client{}) 93 req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) 94 if err != nil { 95 Fail("couldn't create request") 96 } 97 client.Do(req) 98 Expect(requests).To(HaveLen(1)) 99 Expect(events).To(HaveLen(1)) 100 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 101 verifyTraceIDExists(events[0]) 102 }) 103 }) 104 Context("sending a request to existing server, no tracer", func() { 105 It("adds an event with no error", func() { 106 tracer.GlobalTracer = nil 107 client := Wrap(http.Client{}) 108 req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) 109 if err != nil { 110 Fail("couldn't create request") 111 } 112 response, err := client.Do(req) 113 verifyResponseSuccess(response, err) 114 }) 115 }) 116 Context("request to whitelisted url", func() { 117 It("Adds event with trace ID", func() { 118 client := Wrap(http.Client{}) 119 req, err := http.NewRequest( 120 http.MethodGet, 121 fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN), 122 nil, 123 ) 124 if err != nil { 125 Fail("couldn't create request") 126 } 127 client.Do(req) 128 Expect(events).To(HaveLen(1)) 129 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR)) 130 verifyTraceIDExists(events[0]) 131 }) 132 }) 133 Context("request to blacklisted url", func() { 134 It("Adds event with trace ID", func() { 135 client := Wrap(http.Client{}) 136 req, err := http.NewRequest( 137 http.MethodGet, 138 fmt.Sprintf("https://%s", EPSAGON_DOMAIN), 139 nil, 140 ) 141 if err != nil { 142 Fail("couldn't create request") 143 } 144 client.Do(req) 145 Expect(events).To(HaveLen(1)) 146 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 147 verifyTraceIDNotExists(events[0]) 148 }) 149 }) 150 }) 151 Describe(".Get", func() { 152 BeforeEach(func() { 153 events = make([]*protocol.Event, 0) 154 exceptions = make([]*protocol.Exception, 0) 155 }) 156 Context("request created succesfully", func() { 157 It("Adds event", func() { 158 client := Wrap(http.Client{}) 159 client.Get(testServer.URL) 160 Expect(requests).To(HaveLen(1)) 161 Expect(events).To(HaveLen(1)) 162 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 163 Expect(events[0].Resource.Metadata["response_body"]).To( 164 Equal(string(response_data))) 165 verifyTraceIDExists(events[0]) 166 }) 167 }) 168 Context("sending a request to existing server, no tracer", func() { 169 It("adds an event with no error", func() { 170 tracer.GlobalTracer = nil 171 client := Wrap(http.Client{}) 172 response, err := client.Get(testServer.URL) 173 verifyResponseSuccess(response, err) 174 }) 175 }) 176 Context("request to whitelisted url", func() { 177 It("Adds event with trace ID", func() { 178 client := Wrap(http.Client{}) 179 client.Get(fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN)) 180 Expect(events).To(HaveLen(1)) 181 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR)) 182 verifyTraceIDExists(events[0]) 183 }) 184 }) 185 Context("request to blacklisted url", func() { 186 It("Adds event with trace ID", func() { 187 client := Wrap(http.Client{}) 188 client.Get(fmt.Sprintf("https://%s", EPSAGON_DOMAIN)) 189 Expect(events).To(HaveLen(1)) 190 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 191 verifyTraceIDNotExists(events[0]) 192 }) 193 }) 194 Context("bad input failing to create request", func() { 195 It("Adds event with error code error", func() { 196 client := Wrap(http.Client{}) 197 client.Get(testServer.URL + "balbla") 198 Expect(requests).To(HaveLen(0)) 199 Expect(events).To(HaveLen(1)) 200 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR)) 201 verifyTraceIDNotExists(events[0]) 202 }) 203 }) 204 }) 205 Describe(".Post", func() { 206 BeforeEach(func() { 207 events = make([]*protocol.Event, 0) 208 exceptions = make([]*protocol.Exception, 0) 209 }) 210 Context("request created succesfully", func() { 211 It("Adds event", func() { 212 client := Wrap(http.Client{}) 213 data := "{\"hello\":\"world\"}" 214 client.Post( 215 testServer.URL, 216 "application/json", 217 strings.NewReader(data)) 218 Expect(requests).To(HaveLen(1)) 219 Expect(events).To(HaveLen(1)) 220 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 221 Expect(events[0].Resource.Metadata["response_body"]).To( 222 Equal(string(response_data))) 223 Expect(events[0].Resource.Metadata["request_body"]).To( 224 Equal(data)) 225 verifyTraceIDExists(events[0]) 226 }) 227 }) 228 Context("sending a request to existing server, no tracer", func() { 229 It("adds an event with no error", func() { 230 tracer.GlobalTracer = nil 231 client := Wrap(http.Client{}) 232 data := "{\"hello\":\"world\"}" 233 response, err := client.Post( 234 testServer.URL, 235 "application/json", 236 strings.NewReader(data)) 237 verifyResponseSuccess(response, err) 238 }) 239 }) 240 Context("client with metadataOnly", func() { 241 It("Adds event", func() { 242 client := Wrap(http.Client{}) 243 client.MetadataOnly = true 244 data := "{\"hello\":\"world\"}" 245 client.Post( 246 testServer.URL, 247 "application/json", 248 strings.NewReader(data)) 249 Expect(requests).To(HaveLen(1)) 250 Expect(events).To(HaveLen(1)) 251 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 252 Expect(events[0].Resource.Metadata).NotTo( 253 HaveKey("response_body")) 254 Expect(events[0].Resource.Metadata).NotTo( 255 HaveKey("request_body")) 256 verifyTraceIDExists(events[0]) 257 }) 258 }) 259 Context("request to whitelisted url", func() { 260 It("Adds event with trace ID", func() { 261 client := Wrap(http.Client{}) 262 data := "{\"hello\":\"world\"}" 263 client.Post( 264 fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN), 265 "application/json", 266 strings.NewReader(data)) 267 Expect(events).To(HaveLen(1)) 268 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR)) 269 verifyTraceIDExists(events[0]) 270 }) 271 }) 272 Context("request to blacklisted url", func() { 273 It("Adds event with trace ID", func() { 274 client := Wrap(http.Client{}) 275 data := "{\"hello\":\"world\"}" 276 client.Post( 277 fmt.Sprintf("https://%s", EPSAGON_DOMAIN), 278 "application/json", 279 strings.NewReader(data)) 280 Expect(events).To(HaveLen(1)) 281 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 282 verifyTraceIDNotExists(events[0]) 283 }) 284 }) 285 Context("bad input failing to create request", func() { 286 It("Adds event", func() { 287 client := Wrap(http.Client{}) 288 client.Post( 289 testServer.URL+"blabla", 290 "application/json", 291 strings.NewReader("{\"hello\":\"world\"}")) 292 Expect(requests).To(HaveLen(0)) 293 Expect(events).To(HaveLen(1)) 294 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR)) 295 verifyTraceIDNotExists(events[0]) 296 }) 297 }) 298 }) 299 Describe(".PostForm", func() { 300 BeforeEach(func() { 301 events = make([]*protocol.Event, 0) 302 exceptions = make([]*protocol.Exception, 0) 303 }) 304 Context("request created succesfully", func() { 305 It("Adds event", func() { 306 client := Wrap(http.Client{}) 307 client.PostForm( 308 testServer.URL, 309 map[string][]string{ 310 "hello": []string{"world", "of", "serverless"}, 311 }, 312 ) 313 Expect(requests).To(HaveLen(1)) 314 Expect(events).To(HaveLen(1)) 315 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 316 verifyTraceIDExists(events[0]) 317 }) 318 }) 319 Context("sending a request to existing server, no tracer", func() { 320 It("adds an event with no error", func() { 321 tracer.GlobalTracer = nil 322 client := Wrap(http.Client{}) 323 response, err := client.PostForm( 324 testServer.URL, 325 map[string][]string{ 326 "hello": []string{"world", "of", "serverless"}, 327 }, 328 ) 329 verifyResponseSuccess(response, err) 330 }) 331 }) 332 Context("request to whitelisted url", func() { 333 It("Adds event with trace ID", func() { 334 client := Wrap(http.Client{}) 335 client.PostForm( 336 fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN), 337 map[string][]string{ 338 "hello": []string{"world", "of", "serverless"}, 339 }, 340 ) 341 Expect(events).To(HaveLen(1)) 342 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR)) 343 verifyTraceIDExists(events[0]) 344 }) 345 }) 346 Context("request to blacklisted url", func() { 347 It("Adds event with trace ID", func() { 348 client := Wrap(http.Client{}) 349 client.PostForm( 350 fmt.Sprintf("https://%s", EPSAGON_DOMAIN), 351 map[string][]string{ 352 "hello": []string{"world", "of", "serverless"}, 353 }, 354 ) 355 Expect(events).To(HaveLen(1)) 356 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 357 verifyTraceIDNotExists(events[0]) 358 }) 359 }) 360 Context("bad input failing to create request", func() { 361 It("Adds event with error code error", func() { 362 client := Wrap(http.Client{}) 363 client.PostForm( 364 testServer.URL+"blabla", 365 map[string][]string{ 366 "hello": []string{"world", "of", "serverless"}, 367 }, 368 ) 369 Expect(requests).To(HaveLen(0)) 370 Expect(events).To(HaveLen(1)) 371 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR)) 372 verifyTraceIDNotExists(events[0]) 373 }) 374 }) 375 }) 376 Describe(".Head", func() { 377 BeforeEach(func() { 378 events = make([]*protocol.Event, 0) 379 exceptions = make([]*protocol.Exception, 0) 380 }) 381 Context("request created succesfully", func() { 382 It("Adds event", func() { 383 client := Wrap(http.Client{}) 384 client.Head(testServer.URL) 385 Expect(requests).To(HaveLen(1)) 386 Expect(events).To(HaveLen(1)) 387 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 388 verifyTraceIDExists(events[0]) 389 }) 390 }) 391 Context("sending a request to existing server, no tracer", func() { 392 It("adds an event with no error", func() { 393 tracer.GlobalTracer = nil 394 client := Wrap(http.Client{}) 395 _, err := client.Head(testServer.URL) 396 Expect(err).To(BeNil()) 397 }) 398 }) 399 Context("request to whitelisted url", func() { 400 It("Adds event with trace ID", func() { 401 client := Wrap(http.Client{}) 402 client.Head(fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN)) 403 Expect(events).To(HaveLen(1)) 404 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR)) 405 verifyTraceIDExists(events[0]) 406 }) 407 }) 408 Context("request to blacklisted url", func() { 409 It("Adds event with trace ID", func() { 410 client := Wrap(http.Client{}) 411 client.Head(fmt.Sprintf("https://%s", EPSAGON_DOMAIN)) 412 Expect(events).To(HaveLen(1)) 413 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 414 verifyTraceIDNotExists(events[0]) 415 }) 416 }) 417 Context("bad input failing to create request", func() { 418 It("Adds event with error code error", func() { 419 client := Wrap(http.Client{}) 420 client.Head(testServer.URL + "blabla") 421 Expect(requests).To(HaveLen(0)) 422 Expect(events).To(HaveLen(1)) 423 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR)) 424 verifyTraceIDNotExists(events[0]) 425 }) 426 }) 427 }) 428 Describe("http.RoundTripper", func() { 429 BeforeEach(func() { 430 events = make([]*protocol.Event, 0) 431 exceptions = make([]*protocol.Exception, 0) 432 requests = make([]*http.Request, 0) 433 }) 434 Context("sending a request to existing server", func() { 435 It("adds an event with no error, truncating the request body", func() { 436 client := &http.Client{Transport: NewTracingTransport()} 437 data := make([]byte, 128*1024) 438 for i := range data { 439 data[i] = byte(1) 440 } 441 req, err := http.NewRequest( 442 http.MethodPost, 443 testServer.URL, 444 bytes.NewReader(data)) 445 if err != nil { 446 Fail("couldn't create request") 447 } 448 client.Do(req) 449 Expect(requests).To(HaveLen(1)) 450 Expect(events).To(HaveLen(1)) 451 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 452 Expect([]byte(events[0].Resource.Metadata["request_body"])).To(HaveCap(epsagon.MaxMetadataSize)) 453 verifyTraceIDExists(events[0]) 454 }) 455 }) 456 Context("sending a request to existing server, no tracer", func() { 457 It("adds an event with no error", func() { 458 tracer.GlobalTracer = nil 459 client := &http.Client{Transport: NewTracingTransport()} 460 req, err := http.NewRequest(http.MethodGet, testServer.URL, nil) 461 if err != nil { 462 Fail("couldn't create request") 463 } 464 response, err := client.Do(req) 465 verifyResponseSuccess(response, err) 466 }) 467 }) 468 Context("request to whitelisted url", func() { 469 It("Adds event with trace ID", func() { 470 client := &http.Client{Transport: NewTracingTransport()} 471 req, err := http.NewRequest( 472 http.MethodGet, 473 fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN), 474 nil, 475 ) 476 if err != nil { 477 Fail("couldn't create request") 478 } 479 client.Do(req) 480 Expect(events).To(HaveLen(1)) 481 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR)) 482 verifyTraceIDExists(events[0]) 483 }) 484 }) 485 Context("request to blacklisted url", func() { 486 It("Adds event with trace ID and the response truncated", func() { 487 client := &http.Client{Transport: NewTracingTransport()} 488 req, err := http.NewRequest( 489 http.MethodGet, 490 fmt.Sprintf("https://%s", EPSAGON_DOMAIN), 491 nil, 492 ) 493 if err != nil { 494 Fail("couldn't create request") 495 } 496 client.Do(req) 497 Expect(events).To(HaveLen(1)) 498 Expect([]byte(events[0].Resource.Metadata["response_body"])).To(HaveCap(10 * 1024)) 499 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 500 verifyTraceIDNotExists(events[0]) 501 }) 502 }) 503 Context("wrapping a custom transport, request created succesfully", func() { 504 It("Adds event", func() { 505 mock := &mockTransport{} 506 client := &http.Client{Transport: NewWrappedTracingTransport(mock)} 507 client.Head(testServer.URL) 508 Expect(requests).To(HaveLen(1)) 509 Expect(events).To(HaveLen(1)) 510 Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK)) 511 Expect(mock.called).To(BeTrue()) 512 verifyTraceIDExists(events[0]) 513 }) 514 }) 515 }) 516 })