github.com/weaveworks/common@v0.0.0-20230728070032-dd9e68f319d5/middleware/logging_test.go (about) 1 package middleware 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "io" 8 "net/http" 9 "net/http/httptest" 10 "testing" 11 12 "github.com/sirupsen/logrus" 13 "github.com/stretchr/testify/require" 14 15 "github.com/weaveworks/common/logging" 16 ) 17 18 func TestBadWriteLogging(t *testing.T) { 19 for _, tc := range []struct { 20 err error 21 logContains []string 22 }{{ 23 err: context.Canceled, 24 logContains: []string{"debug", "request cancelled: context canceled"}, 25 }, { 26 err: errors.New("yolo"), 27 logContains: []string{"warning", "error: yolo"}, 28 }, { 29 err: nil, 30 logContains: []string{"debug", "GET http://example.com/foo (200)"}, 31 }} { 32 buf := bytes.NewBuffer(nil) 33 logrusLogger := logrus.New() 34 logrusLogger.Out = buf 35 logrusLogger.Level = logrus.DebugLevel 36 37 loggingMiddleware := Log{ 38 Log: logging.Logrus(logrusLogger), 39 } 40 handler := func(w http.ResponseWriter, r *http.Request) { 41 io.WriteString(w, "<html><body>Hello World!</body></html>") 42 } 43 loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler)) 44 45 req := httptest.NewRequest("GET", "http://example.com/foo", nil) 46 recorder := httptest.NewRecorder() 47 48 w := errorWriter{ 49 err: tc.err, 50 w: recorder, 51 } 52 loggingHandler.ServeHTTP(w, req) 53 54 for _, content := range tc.logContains { 55 require.True(t, bytes.Contains(buf.Bytes(), []byte(content))) 56 } 57 } 58 } 59 60 func TestDisabledSuccessfulRequestsLogging(t *testing.T) { 61 for _, tc := range []struct { 62 err error 63 disableLog bool 64 logContains string 65 }{ 66 { 67 err: nil, 68 disableLog: false, 69 }, { 70 err: nil, 71 disableLog: true, 72 logContains: "", 73 }, 74 } { 75 buf := bytes.NewBuffer(nil) 76 logrusLogger := logrus.New() 77 logrusLogger.Out = buf 78 logrusLogger.Level = logrus.DebugLevel 79 80 loggingMiddleware := Log{ 81 Log: logging.Logrus(logrusLogger), 82 DisableRequestSuccessLog: tc.disableLog, 83 } 84 85 handler := func(w http.ResponseWriter, r *http.Request) { 86 io.WriteString(w, "<html><body>Hello World!</body></html>") //nolint:errcheck 87 } 88 loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler)) 89 90 req := httptest.NewRequest("GET", "http://example.com/foo", nil) 91 recorder := httptest.NewRecorder() 92 93 w := errorWriter{ 94 err: tc.err, 95 w: recorder, 96 } 97 loggingHandler.ServeHTTP(w, req) 98 content := buf.String() 99 100 if !tc.disableLog { 101 require.Contains(t, content, "GET http://example.com/foo (200)") 102 } else { 103 require.NotContains(t, content, "(200)") 104 require.Empty(t, content) 105 } 106 } 107 } 108 109 func TestLoggingRequestsAtInfoLevel(t *testing.T) { 110 for _, tc := range []struct { 111 err error 112 logContains []string 113 }{{ 114 err: context.Canceled, 115 logContains: []string{"info", "request cancelled: context canceled"}, 116 }, { 117 err: nil, 118 logContains: []string{"info", "GET http://example.com/foo (200)"}, 119 }} { 120 buf := bytes.NewBuffer(nil) 121 logrusLogger := logrus.New() 122 logrusLogger.Out = buf 123 logrusLogger.Level = logrus.DebugLevel 124 125 loggingMiddleware := Log{ 126 Log: logging.Logrus(logrusLogger), 127 LogRequestAtInfoLevel: true, 128 } 129 handler := func(w http.ResponseWriter, r *http.Request) { 130 io.WriteString(w, "<html><body>Hello World!</body></html>") 131 } 132 loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler)) 133 134 req := httptest.NewRequest("GET", "http://example.com/foo", nil) 135 recorder := httptest.NewRecorder() 136 137 w := errorWriter{ 138 err: tc.err, 139 w: recorder, 140 } 141 loggingHandler.ServeHTTP(w, req) 142 143 for _, content := range tc.logContains { 144 require.True(t, bytes.Contains(buf.Bytes(), []byte(content))) 145 } 146 } 147 } 148 149 func TestLoggingRequestWithExcludedHeaders(t *testing.T) { 150 defaultHeaders := []string{"Authorization", "Cookie", "X-Csrf-Token"} 151 for _, tc := range []struct { 152 name string 153 setHeaderList []string 154 excludeHeaderList []string 155 mustNotContain []string 156 }{ 157 { 158 name: "Default excluded headers are excluded", 159 setHeaderList: defaultHeaders, 160 mustNotContain: defaultHeaders, 161 }, 162 { 163 name: "Extra configured header is also excluded", 164 setHeaderList: append(defaultHeaders, "X-Secret-Header"), 165 excludeHeaderList: []string{"X-Secret-Header"}, 166 mustNotContain: append(defaultHeaders, "X-Secret-Header"), 167 }, 168 { 169 name: "Multiple extra configured headers are also excluded", 170 setHeaderList: append(defaultHeaders, "X-Secret-Header", "X-Secret-Header-2"), 171 excludeHeaderList: []string{"X-Secret-Header", "X-Secret-Header-2"}, 172 mustNotContain: append(defaultHeaders, "X-Secret-Header", "X-Secret-Header-2"), 173 }, 174 } { 175 t.Run(tc.name, func(t *testing.T) { 176 buf := bytes.NewBuffer(nil) 177 logrusLogger := logrus.New() 178 logrusLogger.Out = buf 179 logrusLogger.Level = logrus.DebugLevel 180 181 loggingMiddleware := NewLogMiddleware(logging.Logrus(logrusLogger), true, false, nil, tc.excludeHeaderList) 182 183 handler := func(w http.ResponseWriter, r *http.Request) { 184 _, _ = io.WriteString(w, "<html><body>Hello world!</body></html>") 185 } 186 loggingHandler := loggingMiddleware.Wrap(http.HandlerFunc(handler)) 187 188 req := httptest.NewRequest("GET", "http://example.com/foo", nil) 189 for _, header := range tc.setHeaderList { 190 req.Header.Set(header, header) 191 } 192 193 recorder := httptest.NewRecorder() 194 loggingHandler.ServeHTTP(recorder, req) 195 196 output := buf.String() 197 for _, header := range tc.mustNotContain { 198 require.NotContains(t, output, header) 199 } 200 }) 201 } 202 } 203 204 type errorWriter struct { 205 err error 206 207 w http.ResponseWriter 208 } 209 210 func (e errorWriter) Header() http.Header { 211 return e.w.Header() 212 } 213 214 func (e errorWriter) WriteHeader(statusCode int) { 215 e.w.WriteHeader(statusCode) 216 } 217 218 func (e errorWriter) Write(b []byte) (int, error) { 219 if e.err != nil { 220 return 0, e.err 221 } 222 223 return e.w.Write(b) 224 }