github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/pkg/correlation/middleware_test.go (about)

     1  package correlation_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"testing"
     9  
    10  	"github.com/kyma-incubator/compass/components/director/pkg/correlation"
    11  	"github.com/stretchr/testify/assert"
    12  )
    13  
    14  const expectedRequestID = "123"
    15  
    16  func TestContextEnrichMiddleware_AttachCorrelationIDToContext(t *testing.T) {
    17  	// given
    18  	handler := correlation.AttachCorrelationIDToContext()
    19  
    20  	t.Run("when x-request-id header is present it's added as correlation header to the request context and headers", func(t *testing.T) {
    21  		nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    22  			headersFromContext, ok := r.Context().Value(correlation.HeadersContextKey).(correlation.Headers)
    23  			assert.True(t, ok)
    24  
    25  			actual, ok := headersFromContext[correlation.RequestIDHeaderKey]
    26  			assert.True(t, ok)
    27  			assert.Equal(t, actual, expectedRequestID)
    28  
    29  			headerFromRequest := r.Header.Get(correlation.RequestIDHeaderKey)
    30  			assert.Equal(t, headerFromRequest, expectedRequestID)
    31  
    32  			correlationID := correlation.CorrelationIDFromContext(r.Context())
    33  			assert.Equal(t, correlationID, expectedRequestID)
    34  		})
    35  
    36  		req := httptest.NewRequest("GET", "/", nil)
    37  		req.Header.Set(correlation.RequestIDHeaderKey, expectedRequestID)
    38  
    39  		handler(nextHandler).ServeHTTP(httptest.NewRecorder(), req)
    40  	})
    41  
    42  	t.Run("when no identifying headers are present a new correlation header is added to the request context and headers", func(t *testing.T) {
    43  		nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    44  			headersFromContext, ok := r.Context().Value(correlation.HeadersContextKey).(correlation.Headers)
    45  			assert.True(t, ok)
    46  
    47  			requestIDHeader, ok := headersFromContext[correlation.RequestIDHeaderKey]
    48  			assert.True(t, ok)
    49  			assert.NotEmpty(t, requestIDHeader)
    50  
    51  			headerFromRequest := r.Header.Get(correlation.RequestIDHeaderKey)
    52  			assert.NotEmpty(t, headerFromRequest)
    53  
    54  			correlationID := correlation.CorrelationIDFromContext(r.Context())
    55  			assert.NotEmpty(t, correlationID)
    56  		})
    57  
    58  		req := httptest.NewRequest("GET", "/", nil)
    59  		handler(nextHandler).ServeHTTP(httptest.NewRecorder(), req)
    60  	})
    61  }
    62  
    63  func TestContextEnrichMiddleware_HeadersForRequest(t *testing.T) {
    64  	// given
    65  	headerKeys := []string{"x-request-id", "x-b3-traceid", "x-b3-spanid", "x-b3-parentspanid", "x-b3-sampled", "x-b3-flags", "b3"}
    66  
    67  	for _, header := range headerKeys {
    68  		t.Run(fmt.Sprintf("returns %s when %s header is present", header, header), func(t *testing.T) {
    69  			req := httptest.NewRequest("GET", "/", nil)
    70  			req.Header.Set(correlation.RequestIDHeaderKey, expectedRequestID)
    71  
    72  			headers := correlation.HeadersForRequest(req)
    73  			actualRequestID, ok := headers[correlation.RequestIDHeaderKey]
    74  			assert.True(t, ok)
    75  			assert.Equal(t, expectedRequestID, actualRequestID)
    76  		})
    77  	}
    78  }
    79  
    80  func TestContextEnrichMiddleware_HeadersForRequest_WellKnownCorrelationIDsAreAddedToRequestIfPresentInContextHeaders(t *testing.T) {
    81  	wellKnownHeaderKey := "x-b3-traceid"
    82  	wellKnownHeaderValue := "35b74672-9f48-4361-8f47-408832bd5a25"
    83  
    84  	ctx := correlation.SaveCorrelationIDHeaderToContext(context.Background(), &wellKnownHeaderKey, &wellKnownHeaderValue)
    85  
    86  	req := httptest.NewRequest("GET", "/", nil)
    87  	req = req.WithContext(ctx)
    88  
    89  	headers := correlation.HeadersForRequest(req)
    90  	actualRequestID, ok := headers[correlation.RequestIDHeaderKey]
    91  	assert.True(t, ok)
    92  	assert.NotEmpty(t, actualRequestID)
    93  
    94  	actualWellKnownRequestIDFromMap, ok := headers[wellKnownHeaderKey]
    95  	assert.True(t, ok)
    96  	assert.Equal(t, wellKnownHeaderValue, actualWellKnownRequestIDFromMap)
    97  
    98  	wellKnownHeaderKeyTitleCase := "X-B3-Traceid"
    99  	actualWellKnownRequestIDFromRequest := req.Header[wellKnownHeaderKeyTitleCase][0]
   100  	assert.True(t, ok)
   101  	assert.Equal(t, wellKnownHeaderValue, actualWellKnownRequestIDFromRequest)
   102  }
   103  
   104  func TestContextEnrichMiddleware_HeadersForRequest_AdditionalContextHeadersAreAddedToRequest(t *testing.T) {
   105  	headerKey := "X-Additional-Request-Id"
   106  	headerValue := "35b74672-9f48-4361-8f47-408832bd5a25"
   107  
   108  	ctx := correlation.SaveCorrelationIDHeaderToContext(context.Background(), &headerKey, &headerValue)
   109  
   110  	req := httptest.NewRequest("GET", "/", nil)
   111  	req = req.WithContext(ctx)
   112  
   113  	headers := correlation.HeadersForRequest(req)
   114  	actualRequestID, ok := headers[correlation.RequestIDHeaderKey]
   115  	assert.True(t, ok)
   116  	assert.NotEmpty(t, actualRequestID)
   117  
   118  	actualAdditionalRequestID, ok := headers[headerKey]
   119  	assert.True(t, ok)
   120  	assert.Equal(t, headerValue, actualAdditionalRequestID)
   121  
   122  	actualAdditionalRequestIDFromRequest := req.Header[headerKey][0]
   123  	assert.True(t, ok)
   124  	assert.Equal(t, headerValue, actualAdditionalRequestIDFromRequest)
   125  }