gitlab.com/gitlab-org/labkit@v1.21.0/correlation/outbound_http_test.go (about)

     1  package correlation
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net/http"
     7  	"testing"
     8  
     9  	"github.com/stretchr/testify/require"
    10  )
    11  
    12  var httpCorrelationTests = []struct {
    13  	name          string
    14  	ctx           context.Context
    15  	correlationID string
    16  	clientName    string
    17  	hasHeader     bool
    18  }{
    19  	{
    20  		name:          "context with value",
    21  		ctx:           context.Background(),
    22  		correlationID: "CORRELATION_ID",
    23  		clientName:    "test_client",
    24  		hasHeader:     true,
    25  	},
    26  	{
    27  		name:          "context without value",
    28  		ctx:           context.Background(),
    29  		correlationID: "",
    30  		clientName:    "",
    31  		hasHeader:     false,
    32  	},
    33  }
    34  
    35  func Test_injectRequest(t *testing.T) {
    36  	for _, tt := range httpCorrelationTests {
    37  		t.Run(tt.name, func(t *testing.T) {
    38  			require := require.New(t)
    39  
    40  			ctx := context.WithValue(tt.ctx, keyCorrelationID, tt.correlationID)
    41  			req, err := http.NewRequest("GET", "http://example.com", nil)
    42  			require.NoError(err)
    43  
    44  			req = req.WithContext(ctx)
    45  
    46  			mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
    47  				return &http.Response{}, nil
    48  			})
    49  
    50  			roundTripper := NewInstrumentedRoundTripper(mockTransport, WithClientName(tt.clientName))
    51  			roundTripper.(*instrumentedRoundTripper).injectRequest(req)
    52  
    53  			value := req.Header.Get(propagationHeader)
    54  			clientName := req.Header.Get(clientNameHeader)
    55  			require.True(tt.hasHeader == (value != ""), "Expected header existence %v. Instead got header %v", tt.hasHeader, value)
    56  			require.Equal(tt.clientName, clientName, "Expected client name value %v, got %v", tt.clientName, clientName)
    57  			if tt.hasHeader {
    58  				require.Equal(tt.correlationID, value, "Expected header value %v, got %v", tt.correlationID, value)
    59  			}
    60  		})
    61  	}
    62  }
    63  
    64  type delegatedRoundTripper struct {
    65  	delegate func(req *http.Request) (*http.Response, error)
    66  }
    67  
    68  func (c delegatedRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    69  	return c.delegate(req)
    70  }
    71  
    72  func roundTripperFunc(delegate func(req *http.Request) (*http.Response, error)) http.RoundTripper {
    73  	return &delegatedRoundTripper{delegate}
    74  }
    75  
    76  func TestInstrumentedRoundTripper(t *testing.T) {
    77  	for _, tt := range httpCorrelationTests {
    78  		t.Run(tt.name, func(t *testing.T) {
    79  			require := require.New(t)
    80  
    81  			response := &http.Response{}
    82  			mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
    83  				value := req.Header.Get(propagationHeader)
    84  				require.True(tt.hasHeader == (value != ""), "Expected header existence %v. Instead got header %v", tt.hasHeader, value)
    85  
    86  				if tt.hasHeader {
    87  					require.Equal(tt.correlationID, value, "Expected header value %v, got %v", tt.correlationID, value)
    88  				}
    89  
    90  				return response, nil
    91  			})
    92  
    93  			client := &http.Client{
    94  				Transport: NewInstrumentedRoundTripper(mockTransport),
    95  			}
    96  
    97  			ctx := context.WithValue(tt.ctx, keyCorrelationID, tt.correlationID)
    98  			req, err := http.NewRequest("GET", "http://example.com", nil)
    99  			require.NoError(err)
   100  
   101  			req = req.WithContext(ctx)
   102  
   103  			res, err := client.Do(req)
   104  			require.NoError(err)
   105  			require.Equal(response, res)
   106  		})
   107  	}
   108  }
   109  
   110  func TestInstrumentedRoundTripperFailures(t *testing.T) {
   111  	for _, tt := range httpCorrelationTests {
   112  		t.Run(tt.name+" - with errors", func(t *testing.T) {
   113  			require := require.New(t)
   114  
   115  			testErr := errors.New("test")
   116  
   117  			mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
   118  				value := req.Header.Get(propagationHeader)
   119  				require.True(tt.hasHeader == (value != ""), "Expected header existence %v. Instead got header %v", tt.hasHeader, value)
   120  
   121  				if tt.hasHeader {
   122  					require.Equal(tt.correlationID, value, "Expected header value %v, got %v", tt.correlationID, value)
   123  				}
   124  
   125  				return nil, testErr
   126  			})
   127  
   128  			client := &http.Client{
   129  				Transport: NewInstrumentedRoundTripper(mockTransport),
   130  			}
   131  
   132  			ctx := context.WithValue(tt.ctx, keyCorrelationID, tt.correlationID)
   133  			req, err := http.NewRequest("GET", "http://example.com", nil)
   134  			require.NoError(err)
   135  
   136  			req = req.WithContext(ctx)
   137  
   138  			res, err := client.Do(req)
   139  			require.Error(err)
   140  			require.Nil(res)
   141  		})
   142  	}
   143  }
   144  
   145  func TestInstrumentedRoundTripperWithoutContext(t *testing.T) {
   146  	require := require.New(t)
   147  
   148  	response := &http.Response{}
   149  	mockTransport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
   150  		return response, nil
   151  	})
   152  
   153  	client := &http.Client{
   154  		Transport: NewInstrumentedRoundTripper(mockTransport),
   155  	}
   156  
   157  	req, err := http.NewRequest("GET", "http://example.com", nil)
   158  	require.NoError(err)
   159  
   160  	res, err := client.Do(req)
   161  	require.NoError(err)
   162  	require.Equal(response, res)
   163  }