github.com/weaveworks/common@v0.0.0-20230728070032-dd9e68f319d5/middleware/logging_test.go (about)

     1  package middleware
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"testing"
    11  
    12  	"github.com/sirupsen/logrus"
    13  	"github.com/stretchr/testify/require"
    14  
    15  	"github.com/weaveworks/common/logging"
    16  )
    17  
    18  func TestBadWriteLogging(t *testing.T) {
    19  	for _, tc := range []struct {
    20  		err         error
    21  		logContains []string
    22  	}{{
    23  		err:         context.Canceled,
    24  		logContains: []string{"debug", "request cancelled: context canceled"},
    25  	}, {
    26  		err:         errors.New("yolo"),
    27  		logContains: []string{"warning", "error: yolo"},
    28  	}, {
    29  		err:         nil,
    30  		logContains: []string{"debug", "GET http://example.com/foo (200)"},
    31  	}} {
    32  		buf := bytes.NewBuffer(nil)
    33  		logrusLogger := logrus.New()
    34  		logrusLogger.Out = buf
    35  		logrusLogger.Level = logrus.DebugLevel
    36  
    37  		loggingMiddleware := Log{
    38  			Log: logging.Logrus(logrusLogger),
    39  		}
    40  		handler := func(w http.ResponseWriter, r *http.Request) {
    41  			io.WriteString(w, "<html><body>Hello World!</body></html>")
    42  		}
    43  		loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler))
    44  
    45  		req := httptest.NewRequest("GET", "http://example.com/foo", nil)
    46  		recorder := httptest.NewRecorder()
    47  
    48  		w := errorWriter{
    49  			err: tc.err,
    50  			w:   recorder,
    51  		}
    52  		loggingHandler.ServeHTTP(w, req)
    53  
    54  		for _, content := range tc.logContains {
    55  			require.True(t, bytes.Contains(buf.Bytes(), []byte(content)))
    56  		}
    57  	}
    58  }
    59  
    60  func TestDisabledSuccessfulRequestsLogging(t *testing.T) {
    61  	for _, tc := range []struct {
    62  		err         error
    63  		disableLog  bool
    64  		logContains string
    65  	}{
    66  		{
    67  			err:        nil,
    68  			disableLog: false,
    69  		}, {
    70  			err:         nil,
    71  			disableLog:  true,
    72  			logContains: "",
    73  		},
    74  	} {
    75  		buf := bytes.NewBuffer(nil)
    76  		logrusLogger := logrus.New()
    77  		logrusLogger.Out = buf
    78  		logrusLogger.Level = logrus.DebugLevel
    79  
    80  		loggingMiddleware := Log{
    81  			Log:                      logging.Logrus(logrusLogger),
    82  			DisableRequestSuccessLog: tc.disableLog,
    83  		}
    84  
    85  		handler := func(w http.ResponseWriter, r *http.Request) {
    86  			io.WriteString(w, "<html><body>Hello World!</body></html>") //nolint:errcheck
    87  		}
    88  		loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler))
    89  
    90  		req := httptest.NewRequest("GET", "http://example.com/foo", nil)
    91  		recorder := httptest.NewRecorder()
    92  
    93  		w := errorWriter{
    94  			err: tc.err,
    95  			w:   recorder,
    96  		}
    97  		loggingHandler.ServeHTTP(w, req)
    98  		content := buf.String()
    99  
   100  		if !tc.disableLog {
   101  			require.Contains(t, content, "GET http://example.com/foo (200)")
   102  		} else {
   103  			require.NotContains(t, content, "(200)")
   104  			require.Empty(t, content)
   105  		}
   106  	}
   107  }
   108  
   109  func TestLoggingRequestsAtInfoLevel(t *testing.T) {
   110  	for _, tc := range []struct {
   111  		err         error
   112  		logContains []string
   113  	}{{
   114  		err:         context.Canceled,
   115  		logContains: []string{"info", "request cancelled: context canceled"},
   116  	}, {
   117  		err:         nil,
   118  		logContains: []string{"info", "GET http://example.com/foo (200)"},
   119  	}} {
   120  		buf := bytes.NewBuffer(nil)
   121  		logrusLogger := logrus.New()
   122  		logrusLogger.Out = buf
   123  		logrusLogger.Level = logrus.DebugLevel
   124  
   125  		loggingMiddleware := Log{
   126  			Log:                   logging.Logrus(logrusLogger),
   127  			LogRequestAtInfoLevel: true,
   128  		}
   129  		handler := func(w http.ResponseWriter, r *http.Request) {
   130  			io.WriteString(w, "<html><body>Hello World!</body></html>")
   131  		}
   132  		loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler))
   133  
   134  		req := httptest.NewRequest("GET", "http://example.com/foo", nil)
   135  		recorder := httptest.NewRecorder()
   136  
   137  		w := errorWriter{
   138  			err: tc.err,
   139  			w:   recorder,
   140  		}
   141  		loggingHandler.ServeHTTP(w, req)
   142  
   143  		for _, content := range tc.logContains {
   144  			require.True(t, bytes.Contains(buf.Bytes(), []byte(content)))
   145  		}
   146  	}
   147  }
   148  
   149  func TestLoggingRequestWithExcludedHeaders(t *testing.T) {
   150  	defaultHeaders := []string{"Authorization", "Cookie", "X-Csrf-Token"}
   151  	for _, tc := range []struct {
   152  		name              string
   153  		setHeaderList     []string
   154  		excludeHeaderList []string
   155  		mustNotContain    []string
   156  	}{
   157  		{
   158  			name:           "Default excluded headers are excluded",
   159  			setHeaderList:  defaultHeaders,
   160  			mustNotContain: defaultHeaders,
   161  		},
   162  		{
   163  			name:              "Extra configured header is also excluded",
   164  			setHeaderList:     append(defaultHeaders, "X-Secret-Header"),
   165  			excludeHeaderList: []string{"X-Secret-Header"},
   166  			mustNotContain:    append(defaultHeaders, "X-Secret-Header"),
   167  		},
   168  		{
   169  			name:              "Multiple extra configured headers are also excluded",
   170  			setHeaderList:     append(defaultHeaders, "X-Secret-Header", "X-Secret-Header-2"),
   171  			excludeHeaderList: []string{"X-Secret-Header", "X-Secret-Header-2"},
   172  			mustNotContain:    append(defaultHeaders, "X-Secret-Header", "X-Secret-Header-2"),
   173  		},
   174  	} {
   175  		t.Run(tc.name, func(t *testing.T) {
   176  			buf := bytes.NewBuffer(nil)
   177  			logrusLogger := logrus.New()
   178  			logrusLogger.Out = buf
   179  			logrusLogger.Level = logrus.DebugLevel
   180  
   181  			loggingMiddleware := NewLogMiddleware(logging.Logrus(logrusLogger), true, false, nil, tc.excludeHeaderList)
   182  
   183  			handler := func(w http.ResponseWriter, r *http.Request) {
   184  				_, _ = io.WriteString(w, "<html><body>Hello world!</body></html>")
   185  			}
   186  			loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler))
   187  
   188  			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
   189  			for _, header := range tc.setHeaderList {
   190  				req.Header.Set(header, header)
   191  			}
   192  
   193  			recorder := httptest.NewRecorder()
   194  			loggingHandler.ServeHTTP(recorder, req)
   195  
   196  			output := buf.String()
   197  			for _, header := range tc.mustNotContain {
   198  				require.NotContains(t, output, header)
   199  			}
   200  		})
   201  	}
   202  }
   203  
   204  type errorWriter struct {
   205  	err error
   206  
   207  	w http.ResponseWriter
   208  }
   209  
   210  func (e errorWriter) Header() http.Header {
   211  	return e.w.Header()
   212  }
   213  
   214  func (e errorWriter) WriteHeader(statusCode int) {
   215  	e.w.WriteHeader(statusCode)
   216  }
   217  
   218  func (e errorWriter) Write(b []byte) (int, error) {
   219  	if e.err != nil {
   220  		return 0, e.err
   221  	}
   222  
   223  	return e.w.Write(b)
   224  }