gitlab.com/gitlab-org/labkit@v1.21.0/correlation/outbound_http_test.go (about) 1 package correlation 2 3 import ( 4 "context" 5 "errors" 6 "net/http" 7 "testing" 8 9 "github.com/stretchr/testify/require" 10 ) 11 12 var httpCorrelationTests = []struct { 13 name string 14 ctx context.Context 15 correlationID string 16 clientName string 17 hasHeader bool 18 }{ 19 { 20 name: "context with value", 21 ctx: context.Background(), 22 correlationID: "CORRELATION_ID", 23 clientName: "test_client", 24 hasHeader: true, 25 }, 26 { 27 name: "context without value", 28 ctx: context.Background(), 29 correlationID: "", 30 clientName: "", 31 hasHeader: false, 32 }, 33 } 34 35 func Test_injectRequest(t *testing.T) { 36 for _, tt := range httpCorrelationTests { 37 t.Run(tt.name, func(t *testing.T) { 38 require := require.New(t) 39 40 ctx := context.WithValue(tt.ctx, keyCorrelationID, tt.correlationID) 41 req, err := http.NewRequest("GET", "http://example.com", nil) 42 require.NoError(err) 43 44 req = req.WithContext(ctx) 45 46 mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { 47 return &http.Response{}, nil 48 }) 49 50 roundTripper := NewInstrumentedRoundTripper(mockTransport, WithClientName(tt.clientName)) 51 roundTripper.(*instrumentedRoundTripper).injectRequest(req) 52 53 value := req.Header.Get(propagationHeader) 54 clientName := req.Header.Get(clientNameHeader) 55 require.True(tt.hasHeader == (value != ""), "Expected header existence %v. Instead got header %v", tt.hasHeader, value) 56 require.Equal(tt.clientName, clientName, "Expected client name value %v, got %v", tt.clientName, clientName) 57 if tt.hasHeader { 58 require.Equal(tt.correlationID, value, "Expected header value %v, got %v", tt.correlationID, value) 59 } 60 }) 61 } 62 } 63 64 type delegatedRoundTripper struct { 65 delegate func(req *http.Request) (*http.Response, error) 66 } 67 68 func (c delegatedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 69 return c.delegate(req) 70 } 71 72 func roundTripperFunc(delegate func(req *http.Request) (*http.Response, error)) http.RoundTripper { 73 return &delegatedRoundTripper{delegate} 74 } 75 76 func TestInstrumentedRoundTripper(t *testing.T) { 77 for _, tt := range httpCorrelationTests { 78 t.Run(tt.name, func(t *testing.T) { 79 require := require.New(t) 80 81 response := &http.Response{} 82 mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { 83 value := req.Header.Get(propagationHeader) 84 require.True(tt.hasHeader == (value != ""), "Expected header existence %v. Instead got header %v", tt.hasHeader, value) 85 86 if tt.hasHeader { 87 require.Equal(tt.correlationID, value, "Expected header value %v, got %v", tt.correlationID, value) 88 } 89 90 return response, nil 91 }) 92 93 client := &http.Client{ 94 Transport: NewInstrumentedRoundTripper(mockTransport), 95 } 96 97 ctx := context.WithValue(tt.ctx, keyCorrelationID, tt.correlationID) 98 req, err := http.NewRequest("GET", "http://example.com", nil) 99 require.NoError(err) 100 101 req = req.WithContext(ctx) 102 103 res, err := client.Do(req) 104 require.NoError(err) 105 require.Equal(response, res) 106 }) 107 } 108 } 109 110 func TestInstrumentedRoundTripperFailures(t *testing.T) { 111 for _, tt := range httpCorrelationTests { 112 t.Run(tt.name+" - with errors", func(t *testing.T) { 113 require := require.New(t) 114 115 testErr := errors.New("test") 116 117 mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { 118 value := req.Header.Get(propagationHeader) 119 require.True(tt.hasHeader == (value != ""), "Expected header existence %v. Instead got header %v", tt.hasHeader, value) 120 121 if tt.hasHeader { 122 require.Equal(tt.correlationID, value, "Expected header value %v, got %v", tt.correlationID, value) 123 } 124 125 return nil, testErr 126 }) 127 128 client := &http.Client{ 129 Transport: NewInstrumentedRoundTripper(mockTransport), 130 } 131 132 ctx := context.WithValue(tt.ctx, keyCorrelationID, tt.correlationID) 133 req, err := http.NewRequest("GET", "http://example.com", nil) 134 require.NoError(err) 135 136 req = req.WithContext(ctx) 137 138 res, err := client.Do(req) 139 require.Error(err) 140 require.Nil(res) 141 }) 142 } 143 } 144 145 func TestInstrumentedRoundTripperWithoutContext(t *testing.T) { 146 require := require.New(t) 147 148 response := &http.Response{} 149 mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) { 150 return response, nil 151 }) 152 153 client := &http.Client{ 154 Transport: NewInstrumentedRoundTripper(mockTransport), 155 } 156 157 req, err := http.NewRequest("GET", "http://example.com", nil) 158 require.NoError(err) 159 160 res, err := client.Do(req) 161 require.NoError(err) 162 require.Equal(response, res) 163 }