github.com/epsagon/epsagon-go@v1.39.0/wrappers/net/http/client_test.go (about)

     1  package epsagonhttp
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/epsagon/epsagon-go/epsagon"
    13  	"github.com/epsagon/epsagon-go/protocol"
    14  	"github.com/epsagon/epsagon-go/tracer"
    15  	. "github.com/onsi/ginkgo"
    16  	. "github.com/onsi/gomega"
    17  )
    18  
    19  const TEST_RESPONSE_STRING = "response_test_string"
    20  
    21  func TestEpsagonHTTPWrappers(t *testing.T) {
    22  	RegisterFailHandler(Fail)
    23  	RunSpecs(t, "epsagon http wrapper suite")
    24  }
    25  
    26  func verifyTraceIDExists(event *protocol.Event) {
    27  	traceID, ok := event.Resource.Metadata[tracer.EpsagonHTTPTraceIDKey]
    28  	Expect(ok).To(BeTrue())
    29  	Expect(traceID).To(Not(BeZero()))
    30  }
    31  
    32  func verifyTraceIDNotExists(event *protocol.Event) {
    33  	Expect(event.Resource.Metadata).NotTo(
    34  		HaveKey(tracer.EpsagonHTTPTraceIDKey))
    35  }
    36  
    37  func verifyResponseSuccess(response *http.Response, err error) {
    38  	Expect(err).To(BeNil())
    39  	defer response.Body.Close()
    40  	responseData, err := ioutil.ReadAll(response.Body)
    41  	Expect(err).To(BeNil())
    42  	responseString := string(responseData)
    43  	Expect(responseString).To(Equal(TEST_RESPONSE_STRING))
    44  }
    45  
    46  type mockTransport struct {
    47  	called bool
    48  }
    49  
    50  func (m *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
    51  	m.called = true
    52  	return http.DefaultTransport.RoundTrip(req)
    53  }
    54  
    55  var _ = Describe("ClientWrapper", func() {
    56  	var (
    57  		events        []*protocol.Event
    58  		exceptions    []*protocol.Exception
    59  		requests      []*http.Request
    60  		testServer    *httptest.Server
    61  		response_data []byte
    62  	)
    63  	BeforeEach(func() {
    64  		requests = make([]*http.Request, 0)
    65  		events = make([]*protocol.Event, 0)
    66  		exceptions = make([]*protocol.Exception, 0)
    67  		response_data = []byte(TEST_RESPONSE_STRING)
    68  		tracer.GlobalTracer = &tracer.MockedEpsagonTracer{
    69  			Events:     &events,
    70  			Exceptions: &exceptions,
    71  			Config:     &tracer.Config{},
    72  		}
    73  		testServer = httptest.NewServer(http.HandlerFunc(
    74  			func(res http.ResponseWriter, req *http.Request) {
    75  				requests = append(requests, req)
    76  				res.Write(response_data)
    77  			}))
    78  	})
    79  	AfterEach(func() {
    80  		tracer.GlobalTracer = nil
    81  		testServer.Close()
    82  	})
    83  
    84  	Describe(".Do", func() {
    85  		BeforeEach(func() {
    86  			events = make([]*protocol.Event, 0)
    87  			exceptions = make([]*protocol.Exception, 0)
    88  			requests = make([]*http.Request, 0)
    89  		})
    90  		Context("sending a request to existing server", func() {
    91  			It("adds an event with no error", func() {
    92  				client := Wrap(http.Client{})
    93  				req, err := http.NewRequest(http.MethodGet, testServer.URL, nil)
    94  				if err != nil {
    95  					Fail("couldn't create request")
    96  				}
    97  				client.Do(req)
    98  				Expect(requests).To(HaveLen(1))
    99  				Expect(events).To(HaveLen(1))
   100  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   101  				verifyTraceIDExists(events[0])
   102  			})
   103  		})
   104  		Context("sending a request to existing server, no tracer", func() {
   105  			It("adds an event with no error", func() {
   106  				tracer.GlobalTracer = nil
   107  				client := Wrap(http.Client{})
   108  				req, err := http.NewRequest(http.MethodGet, testServer.URL, nil)
   109  				if err != nil {
   110  					Fail("couldn't create request")
   111  				}
   112  				response, err := client.Do(req)
   113  				verifyResponseSuccess(response, err)
   114  			})
   115  		})
   116  		Context("request to whitelisted url", func() {
   117  			It("Adds event with trace ID", func() {
   118  				client := Wrap(http.Client{})
   119  				req, err := http.NewRequest(
   120  					http.MethodGet,
   121  					fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN),
   122  					nil,
   123  				)
   124  				if err != nil {
   125  					Fail("couldn't create request")
   126  				}
   127  				client.Do(req)
   128  				Expect(events).To(HaveLen(1))
   129  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR))
   130  				verifyTraceIDExists(events[0])
   131  			})
   132  		})
   133  		Context("request to blacklisted url", func() {
   134  			It("Adds event with trace ID", func() {
   135  				client := Wrap(http.Client{})
   136  				req, err := http.NewRequest(
   137  					http.MethodGet,
   138  					fmt.Sprintf("https://%s", EPSAGON_DOMAIN),
   139  					nil,
   140  				)
   141  				if err != nil {
   142  					Fail("couldn't create request")
   143  				}
   144  				client.Do(req)
   145  				Expect(events).To(HaveLen(1))
   146  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   147  				verifyTraceIDNotExists(events[0])
   148  			})
   149  		})
   150  	})
   151  	Describe(".Get", func() {
   152  		BeforeEach(func() {
   153  			events = make([]*protocol.Event, 0)
   154  			exceptions = make([]*protocol.Exception, 0)
   155  		})
   156  		Context("request created succesfully", func() {
   157  			It("Adds event", func() {
   158  				client := Wrap(http.Client{})
   159  				client.Get(testServer.URL)
   160  				Expect(requests).To(HaveLen(1))
   161  				Expect(events).To(HaveLen(1))
   162  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   163  				Expect(events[0].Resource.Metadata["response_body"]).To(
   164  					Equal(string(response_data)))
   165  				verifyTraceIDExists(events[0])
   166  			})
   167  		})
   168  		Context("sending a request to existing server, no tracer", func() {
   169  			It("adds an event with no error", func() {
   170  				tracer.GlobalTracer = nil
   171  				client := Wrap(http.Client{})
   172  				response, err := client.Get(testServer.URL)
   173  				verifyResponseSuccess(response, err)
   174  			})
   175  		})
   176  		Context("request to whitelisted url", func() {
   177  			It("Adds event with trace ID", func() {
   178  				client := Wrap(http.Client{})
   179  				client.Get(fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN))
   180  				Expect(events).To(HaveLen(1))
   181  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR))
   182  				verifyTraceIDExists(events[0])
   183  			})
   184  		})
   185  		Context("request to blacklisted url", func() {
   186  			It("Adds event with trace ID", func() {
   187  				client := Wrap(http.Client{})
   188  				client.Get(fmt.Sprintf("https://%s", EPSAGON_DOMAIN))
   189  				Expect(events).To(HaveLen(1))
   190  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   191  				verifyTraceIDNotExists(events[0])
   192  			})
   193  		})
   194  		Context("bad input failing to create request", func() {
   195  			It("Adds event with error code error", func() {
   196  				client := Wrap(http.Client{})
   197  				client.Get(testServer.URL + "balbla")
   198  				Expect(requests).To(HaveLen(0))
   199  				Expect(events).To(HaveLen(1))
   200  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR))
   201  				verifyTraceIDNotExists(events[0])
   202  			})
   203  		})
   204  	})
   205  	Describe(".Post", func() {
   206  		BeforeEach(func() {
   207  			events = make([]*protocol.Event, 0)
   208  			exceptions = make([]*protocol.Exception, 0)
   209  		})
   210  		Context("request created succesfully", func() {
   211  			It("Adds event", func() {
   212  				client := Wrap(http.Client{})
   213  				data := "{\"hello\":\"world\"}"
   214  				client.Post(
   215  					testServer.URL,
   216  					"application/json",
   217  					strings.NewReader(data))
   218  				Expect(requests).To(HaveLen(1))
   219  				Expect(events).To(HaveLen(1))
   220  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   221  				Expect(events[0].Resource.Metadata["response_body"]).To(
   222  					Equal(string(response_data)))
   223  				Expect(events[0].Resource.Metadata["request_body"]).To(
   224  					Equal(data))
   225  				verifyTraceIDExists(events[0])
   226  			})
   227  		})
   228  		Context("sending a request to existing server, no tracer", func() {
   229  			It("adds an event with no error", func() {
   230  				tracer.GlobalTracer = nil
   231  				client := Wrap(http.Client{})
   232  				data := "{\"hello\":\"world\"}"
   233  				response, err := client.Post(
   234  					testServer.URL,
   235  					"application/json",
   236  					strings.NewReader(data))
   237  				verifyResponseSuccess(response, err)
   238  			})
   239  		})
   240  		Context("client with metadataOnly", func() {
   241  			It("Adds event", func() {
   242  				client := Wrap(http.Client{})
   243  				client.MetadataOnly = true
   244  				data := "{\"hello\":\"world\"}"
   245  				client.Post(
   246  					testServer.URL,
   247  					"application/json",
   248  					strings.NewReader(data))
   249  				Expect(requests).To(HaveLen(1))
   250  				Expect(events).To(HaveLen(1))
   251  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   252  				Expect(events[0].Resource.Metadata).NotTo(
   253  					HaveKey("response_body"))
   254  				Expect(events[0].Resource.Metadata).NotTo(
   255  					HaveKey("request_body"))
   256  				verifyTraceIDExists(events[0])
   257  			})
   258  		})
   259  		Context("request to whitelisted url", func() {
   260  			It("Adds event with trace ID", func() {
   261  				client := Wrap(http.Client{})
   262  				data := "{\"hello\":\"world\"}"
   263  				client.Post(
   264  					fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN),
   265  					"application/json",
   266  					strings.NewReader(data))
   267  				Expect(events).To(HaveLen(1))
   268  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR))
   269  				verifyTraceIDExists(events[0])
   270  			})
   271  		})
   272  		Context("request to blacklisted url", func() {
   273  			It("Adds event with trace ID", func() {
   274  				client := Wrap(http.Client{})
   275  				data := "{\"hello\":\"world\"}"
   276  				client.Post(
   277  					fmt.Sprintf("https://%s", EPSAGON_DOMAIN),
   278  					"application/json",
   279  					strings.NewReader(data))
   280  				Expect(events).To(HaveLen(1))
   281  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   282  				verifyTraceIDNotExists(events[0])
   283  			})
   284  		})
   285  		Context("bad input failing to create request", func() {
   286  			It("Adds event", func() {
   287  				client := Wrap(http.Client{})
   288  				client.Post(
   289  					testServer.URL+"blabla",
   290  					"application/json",
   291  					strings.NewReader("{\"hello\":\"world\"}"))
   292  				Expect(requests).To(HaveLen(0))
   293  				Expect(events).To(HaveLen(1))
   294  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR))
   295  				verifyTraceIDNotExists(events[0])
   296  			})
   297  		})
   298  	})
   299  	Describe(".PostForm", func() {
   300  		BeforeEach(func() {
   301  			events = make([]*protocol.Event, 0)
   302  			exceptions = make([]*protocol.Exception, 0)
   303  		})
   304  		Context("request created succesfully", func() {
   305  			It("Adds event", func() {
   306  				client := Wrap(http.Client{})
   307  				client.PostForm(
   308  					testServer.URL,
   309  					map[string][]string{
   310  						"hello": []string{"world", "of", "serverless"},
   311  					},
   312  				)
   313  				Expect(requests).To(HaveLen(1))
   314  				Expect(events).To(HaveLen(1))
   315  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   316  				verifyTraceIDExists(events[0])
   317  			})
   318  		})
   319  		Context("sending a request to existing server, no tracer", func() {
   320  			It("adds an event with no error", func() {
   321  				tracer.GlobalTracer = nil
   322  				client := Wrap(http.Client{})
   323  				response, err := client.PostForm(
   324  					testServer.URL,
   325  					map[string][]string{
   326  						"hello": []string{"world", "of", "serverless"},
   327  					},
   328  				)
   329  				verifyResponseSuccess(response, err)
   330  			})
   331  		})
   332  		Context("request to whitelisted url", func() {
   333  			It("Adds event with trace ID", func() {
   334  				client := Wrap(http.Client{})
   335  				client.PostForm(
   336  					fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN),
   337  					map[string][]string{
   338  						"hello": []string{"world", "of", "serverless"},
   339  					},
   340  				)
   341  				Expect(events).To(HaveLen(1))
   342  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR))
   343  				verifyTraceIDExists(events[0])
   344  			})
   345  		})
   346  		Context("request to blacklisted url", func() {
   347  			It("Adds event with trace ID", func() {
   348  				client := Wrap(http.Client{})
   349  				client.PostForm(
   350  					fmt.Sprintf("https://%s", EPSAGON_DOMAIN),
   351  					map[string][]string{
   352  						"hello": []string{"world", "of", "serverless"},
   353  					},
   354  				)
   355  				Expect(events).To(HaveLen(1))
   356  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   357  				verifyTraceIDNotExists(events[0])
   358  			})
   359  		})
   360  		Context("bad input failing to create request", func() {
   361  			It("Adds event with error code error", func() {
   362  				client := Wrap(http.Client{})
   363  				client.PostForm(
   364  					testServer.URL+"blabla",
   365  					map[string][]string{
   366  						"hello": []string{"world", "of", "serverless"},
   367  					},
   368  				)
   369  				Expect(requests).To(HaveLen(0))
   370  				Expect(events).To(HaveLen(1))
   371  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR))
   372  				verifyTraceIDNotExists(events[0])
   373  			})
   374  		})
   375  	})
   376  	Describe(".Head", func() {
   377  		BeforeEach(func() {
   378  			events = make([]*protocol.Event, 0)
   379  			exceptions = make([]*protocol.Exception, 0)
   380  		})
   381  		Context("request created succesfully", func() {
   382  			It("Adds event", func() {
   383  				client := Wrap(http.Client{})
   384  				client.Head(testServer.URL)
   385  				Expect(requests).To(HaveLen(1))
   386  				Expect(events).To(HaveLen(1))
   387  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   388  				verifyTraceIDExists(events[0])
   389  			})
   390  		})
   391  		Context("sending a request to existing server, no tracer", func() {
   392  			It("adds an event with no error", func() {
   393  				tracer.GlobalTracer = nil
   394  				client := Wrap(http.Client{})
   395  				_, err := client.Head(testServer.URL)
   396  				Expect(err).To(BeNil())
   397  			})
   398  		})
   399  		Context("request to whitelisted url", func() {
   400  			It("Adds event with trace ID", func() {
   401  				client := Wrap(http.Client{})
   402  				client.Head(fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN))
   403  				Expect(events).To(HaveLen(1))
   404  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR))
   405  				verifyTraceIDExists(events[0])
   406  			})
   407  		})
   408  		Context("request to blacklisted url", func() {
   409  			It("Adds event with trace ID", func() {
   410  				client := Wrap(http.Client{})
   411  				client.Head(fmt.Sprintf("https://%s", EPSAGON_DOMAIN))
   412  				Expect(events).To(HaveLen(1))
   413  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   414  				verifyTraceIDNotExists(events[0])
   415  			})
   416  		})
   417  		Context("bad input failing to create request", func() {
   418  			It("Adds event with error code error", func() {
   419  				client := Wrap(http.Client{})
   420  				client.Head(testServer.URL + "blabla")
   421  				Expect(requests).To(HaveLen(0))
   422  				Expect(events).To(HaveLen(1))
   423  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR))
   424  				verifyTraceIDNotExists(events[0])
   425  			})
   426  		})
   427  	})
   428  	Describe("http.RoundTripper", func() {
   429  		BeforeEach(func() {
   430  			events = make([]*protocol.Event, 0)
   431  			exceptions = make([]*protocol.Exception, 0)
   432  			requests = make([]*http.Request, 0)
   433  		})
   434  		Context("sending a request to existing server", func() {
   435  			It("adds an event with no error, truncating the request body", func() {
   436  				client := &http.Client{Transport: NewTracingTransport()}
   437  				data := make([]byte, 128*1024)
   438  				for i := range data {
   439  					data[i] = byte(1)
   440  				}
   441  				req, err := http.NewRequest(
   442  					http.MethodPost,
   443  					testServer.URL,
   444  					bytes.NewReader(data))
   445  				if err != nil {
   446  					Fail("couldn't create request")
   447  				}
   448  				client.Do(req)
   449  				Expect(requests).To(HaveLen(1))
   450  				Expect(events).To(HaveLen(1))
   451  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   452  				Expect([]byte(events[0].Resource.Metadata["request_body"])).To(HaveCap(epsagon.MaxMetadataSize))
   453  				verifyTraceIDExists(events[0])
   454  			})
   455  		})
   456  		Context("sending a request to existing server, no tracer", func() {
   457  			It("adds an event with no error", func() {
   458  				tracer.GlobalTracer = nil
   459  				client := &http.Client{Transport: NewTracingTransport()}
   460  				req, err := http.NewRequest(http.MethodGet, testServer.URL, nil)
   461  				if err != nil {
   462  					Fail("couldn't create request")
   463  				}
   464  				response, err := client.Do(req)
   465  				verifyResponseSuccess(response, err)
   466  			})
   467  		})
   468  		Context("request to whitelisted url", func() {
   469  			It("Adds event with trace ID", func() {
   470  				client := &http.Client{Transport: NewTracingTransport()}
   471  				req, err := http.NewRequest(
   472  					http.MethodGet,
   473  					fmt.Sprintf("https://test.%s.com", APPSYNC_API_SUBDOMAIN),
   474  					nil,
   475  				)
   476  				if err != nil {
   477  					Fail("couldn't create request")
   478  				}
   479  				client.Do(req)
   480  				Expect(events).To(HaveLen(1))
   481  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_ERROR))
   482  				verifyTraceIDExists(events[0])
   483  			})
   484  		})
   485  		Context("request to blacklisted url", func() {
   486  			It("Adds event with trace ID and the response truncated", func() {
   487  				client := &http.Client{Transport: NewTracingTransport()}
   488  				req, err := http.NewRequest(
   489  					http.MethodGet,
   490  					fmt.Sprintf("https://%s", EPSAGON_DOMAIN),
   491  					nil,
   492  				)
   493  				if err != nil {
   494  					Fail("couldn't create request")
   495  				}
   496  				client.Do(req)
   497  				Expect(events).To(HaveLen(1))
   498  				Expect([]byte(events[0].Resource.Metadata["response_body"])).To(HaveCap(10 * 1024))
   499  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   500  				verifyTraceIDNotExists(events[0])
   501  			})
   502  		})
   503  		Context("wrapping a custom transport, request created succesfully", func() {
   504  			It("Adds event", func() {
   505  				mock := &mockTransport{}
   506  				client := &http.Client{Transport: NewWrappedTracingTransport(mock)}
   507  				client.Head(testServer.URL)
   508  				Expect(requests).To(HaveLen(1))
   509  				Expect(events).To(HaveLen(1))
   510  				Expect(events[0].ErrorCode).To(Equal(protocol.ErrorCode_OK))
   511  				Expect(mock.called).To(BeTrue())
   512  				verifyTraceIDExists(events[0])
   513  			})
   514  		})
   515  	})
   516  })