github.com/kyma-incubator/compass/components/director@v0.0.0-20230623144113-d764f56ff805/pkg/log/request_test.go (about)

     1  package log_test
     2  
     3  import (
     4  	"bytes"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"net/url"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/sirupsen/logrus"
    12  
    13  	"github.com/kyma-incubator/compass/components/director/pkg/log"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  func TestRequestLoggerGeneratesCorrelationIDWhenNotFoundInHeaders(t *testing.T) {
    18  	response := httptest.NewRecorder()
    19  
    20  	testURL, err := url.Parse("http://localhost:8080")
    21  	require.NoError(t, err)
    22  	request := &http.Request{
    23  		Method: http.MethodPost,
    24  		URL:    testURL,
    25  		Header: map[string][]string{},
    26  	}
    27  
    28  	handler := log.RequestLogger()
    29  	handler(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
    30  		entry := log.C(request.Context())
    31  
    32  		correlationIDFromLogger, exists := entry.Data[log.FieldRequestID]
    33  		require.True(t, exists)
    34  		require.NotEmpty(t, correlationIDFromLogger)
    35  	})).ServeHTTP(response, request)
    36  }
    37  
    38  func TestRequestLoggerUseCorrelationIDFromHeaderIfProvided(t *testing.T) {
    39  	correlationID := "test-correlation-id"
    40  	response := httptest.NewRecorder()
    41  
    42  	testURL, err := url.Parse("http://localhost:8080")
    43  	require.NoError(t, err)
    44  	request := &http.Request{
    45  		Method: http.MethodPost,
    46  		URL:    testURL,
    47  		Header: map[string][]string{},
    48  	}
    49  	request.Header.Set("x-request-id", correlationID)
    50  	request.Header.Set("X-Real-IP", "127.0.0.1")
    51  
    52  	handler := log.RequestLogger()
    53  	handler(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
    54  		entry := log.C(request.Context())
    55  
    56  		correlationIDFromLogger, exists := entry.Data[log.FieldRequestID]
    57  		require.True(t, exists)
    58  		require.Equal(t, correlationID, correlationIDFromLogger)
    59  	})).ServeHTTP(response, request)
    60  }
    61  
    62  func TestRequestLoggerWithMDC(t *testing.T) {
    63  	response := httptest.NewRecorder()
    64  	testURL, err := url.Parse("http://localhost:8080")
    65  	require.NoError(t, err)
    66  
    67  	request := &http.Request{
    68  		Method: http.MethodPost,
    69  		URL:    testURL,
    70  		Header: map[string][]string{},
    71  	}
    72  
    73  	oldLogger := logrus.StandardLogger().Out
    74  	buf := bytes.Buffer{}
    75  	logrus.SetOutput(&buf)
    76  	defer logrus.SetOutput(oldLogger)
    77  
    78  	handler := log.RequestLogger()
    79  	handler(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
    80  		hasMdc := false
    81  		if mdc := log.MdcFromContext(request.Context()); nil != mdc {
    82  			hasMdc = true
    83  			mdc.Set("test", "test")
    84  		}
    85  		require.True(t, hasMdc, "There is no MDC in the request context")
    86  
    87  		//remove the "Started handling ..." line
    88  		buf.Reset()
    89  	})).ServeHTTP(response, request)
    90  
    91  	logLine := buf.String()
    92  	hasMdcMessage := strings.Contains(logLine, "test=test")
    93  	require.True(t, hasMdcMessage, "The log line does not contain the MDC content: %v", logLine)
    94  }
    95  
    96  func TestRequestLoggerDebugPaths(t *testing.T) {
    97  	response := httptest.NewRecorder()
    98  	testURL, err := url.Parse("http://localhost:8080")
    99  	require.NoError(t, err)
   100  
   101  	request := &http.Request{
   102  		Method: http.MethodPost,
   103  		URL:    testURL,
   104  		Header: map[string][]string{},
   105  	}
   106  
   107  	oldLogger := logrus.StandardLogger().Out
   108  	buf := bytes.Buffer{}
   109  	logrus.SetOutput(&buf)
   110  	logrus.SetLevel(logrus.DebugLevel)
   111  	defer logrus.SetOutput(oldLogger)
   112  
   113  	emptyHandlerFunc := func(writer http.ResponseWriter, request *http.Request) {}
   114  
   115  	const debugPath = "/healthz"
   116  	handler := log.RequestLogger(debugPath)
   117  	handler(http.HandlerFunc(emptyHandlerFunc)).ServeHTTP(response, request)
   118  
   119  	logs := buf.String()
   120  	require.Contains(t, logs, `level=info msg="Started handling request..."`)
   121  	require.Contains(t, logs, `level=info msg="Finished handling request..."`)
   122  
   123  	buf.Reset()
   124  	request.URL.Path = debugPath
   125  	handler(http.HandlerFunc(emptyHandlerFunc)).ServeHTTP(response, request)
   126  
   127  	logs = buf.String()
   128  	require.Contains(t, logs, `level=debug msg="Started handling request..."`)
   129  	require.Contains(t, logs, `level=debug msg="Finished handling request..."`)
   130  }