github.com/grafana/pyroscope@v1.18.0/pkg/util/http_test.go (about) 1 package util 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "io" 8 "net/http" 9 "net/http/httptest" 10 "strings" 11 "testing" 12 13 "github.com/go-kit/log" 14 "github.com/grafana/dskit/user" 15 "github.com/stretchr/testify/assert" 16 "github.com/stretchr/testify/require" 17 18 "github.com/grafana/pyroscope/pkg/tenant" 19 ) 20 21 func TestWriteTextResponse(t *testing.T) { 22 w := httptest.NewRecorder() 23 24 WriteTextResponse(w, "hello world") 25 26 assert.Equal(t, 200, w.Code) 27 assert.Equal(t, "hello world", w.Body.String()) 28 assert.Equal(t, "text/plain", w.Header().Get("Content-Type")) 29 } 30 31 func TestMultitenantMiddleware(t *testing.T) { 32 w := httptest.NewRecorder() 33 r := httptest.NewRequest("GET", "http://localhost:8080", nil) 34 35 // No org ID header. 36 m := AuthenticateUser(true).Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 37 id, err := tenant.ExtractTenantIDFromContext(r.Context()) 38 require.NoError(t, err) 39 assert.Equal(t, "1", id) 40 })) 41 m.ServeHTTP(w, r) 42 assert.Equal(t, http.StatusUnauthorized, w.Code) 43 44 w = httptest.NewRecorder() 45 r.Header.Set("X-Scope-OrgID", "1") 46 m.ServeHTTP(w, r) 47 assert.Equal(t, http.StatusOK, w.Code) 48 49 // No org ID header without auth. 50 r = httptest.NewRequest("GET", "http://localhost:8080", nil) 51 m = AuthenticateUser(false).Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 52 id, err := tenant.ExtractTenantIDFromContext(r.Context()) 53 require.NoError(t, err) 54 assert.Equal(t, tenant.DefaultTenantID, id) 55 })) 56 m.ServeHTTP(w, r) 57 assert.Equal(t, http.StatusOK, w.Code) 58 } 59 60 func removeLogFields(line string, fields ...string) string { 61 for _, field := range fields { 62 // find field 63 needle := field + "=" 64 pos := strings.Index(line, needle) 65 if pos < 0 { 66 continue 67 } 68 69 // find space after field 70 offset := pos + len(needle) 71 posSpace := strings.Index(line[offset:], " ") 72 if posSpace < 0 { 73 // remove all after needle 74 line = line[0:offset] 75 continue 76 } 77 78 // remove value 79 line = line[:offset] + line[offset+posSpace:] 80 } 81 82 return line 83 84 } 85 86 type errorRecorder struct { 87 writeErr error 88 } 89 90 func (r *errorRecorder) Write([]byte) (int, error) { return 0, r.writeErr } 91 92 func (*errorRecorder) Header() http.Header { return make(http.Header) } 93 94 func (*errorRecorder) WriteHeader(statusCode int) {} 95 96 func TestHTTPLog(t *testing.T) { 97 ctxTenant := user.InjectOrgID(context.Background(), "my-tenant") 98 for _, tc := range []struct { 99 name string 100 log *Log 101 ctx context.Context 102 reqBody io.Reader 103 writeErr error 104 setHeaderList []string 105 statusCode int 106 message string 107 }{ 108 { 109 name: "Header logging disabled", 110 log: &Log{ 111 LogRequestHeaders: false, 112 }, 113 setHeaderList: []string{"good-header", "authorization"}, 114 message: `level=debug method=GET uri=http://example.com/foo status=200 duration= msg="http request processed"`, 115 }, 116 { 117 name: "Header logging enable", 118 log: &Log{ 119 LogRequestHeaders: true, 120 }, 121 setHeaderList: []string{"good-header", "authorization"}, 122 message: `level=debug method=GET uri=http://example.com/foo status=200 duration= request_header_Good-Header=good-headerValue msg="http request processed"`, 123 }, 124 { 125 name: "Extra Header excluded", 126 log: &Log{ 127 LogRequestHeaders: true, 128 LogRequestExcludeHeaders: []string{"bad-header"}, 129 }, 130 setHeaderList: []string{"good-header", "bad-header", "authorization"}, 131 message: `level=debug method=GET uri=http://example.com/foo status=200 duration= request_header_Good-Header=good-headerValue msg="http request processed"`, 132 }, 133 { 134 name: "Extra Header with different casing", 135 log: &Log{ 136 LogRequestHeaders: true, 137 LogRequestExcludeHeaders: []string{"Bad-Header"}, 138 }, 139 setHeaderList: []string{"good-header", "bad-header", "authorization"}, 140 message: `level=debug method=GET uri=http://example.com/foo status=200 duration= request_header_Good-Header=good-headerValue msg="http request processed"`, 141 }, 142 { 143 name: "Two Extra Headers excluded", 144 log: &Log{ 145 LogRequestHeaders: true, 146 LogRequestExcludeHeaders: []string{"bad-header", "bad-header2"}, 147 }, 148 setHeaderList: []string{"good-header", "bad-header", "bad-header2", "authorization"}, 149 message: `level=debug method=GET uri=http://example.com/foo status=200 duration= request_header_Good-Header=good-headerValue msg="http request processed"`, 150 }, 151 { 152 name: "Status code 500 should still log headers", 153 log: &Log{ 154 LogRequestHeaders: false, 155 LogRequestExcludeHeaders: []string{"bad-header"}, 156 }, 157 setHeaderList: []string{"good-header", "bad-header", "authorization"}, 158 message: `level=warn method=GET uri=http://example.com/foo status=500 duration= request_header_Good-Header=good-headerValue msg="http request failed" response_body="<html><body>Hello world!</body></html>"`, 159 160 statusCode: http.StatusInternalServerError, 161 }, 162 { 163 name: "Log request body size latency", 164 log: &Log{}, 165 reqBody: strings.NewReader("Hello World! I am a request body."), 166 message: `level=debug method=GET uri=http://example.com/foo status=200 duration= request_body_size=33B request_body_read_duration= msg="http request processed"`, 167 }, 168 { 169 name: "Write errors should be shown at warning level", 170 log: &Log{}, 171 writeErr: errors.New("some error"), 172 message: `level=warn method=GET uri=http://example.com/foo status=200 duration= msg="http request failed" err="some error"`, 173 }, 174 { 175 name: "Context cancelled requests should not be at warning level", 176 log: &Log{}, 177 writeErr: context.Canceled, 178 message: `level=debug method=GET uri=http://example.com/foo status=200 duration= msg="request cancelled"`, 179 }, 180 { 181 name: "Tenant id should be logged", 182 ctx: ctxTenant, 183 log: &Log{}, 184 message: `level=debug tenant=my-tenant method=GET uri=http://example.com/foo status=200 duration= msg="http request processed"`, 185 }, 186 } { 187 t.Run(tc.name, func(t *testing.T) { 188 buf := bytes.NewBuffer(nil) 189 190 if tc.statusCode == 0 { 191 tc.statusCode = http.StatusOK 192 } 193 194 ctx := tc.ctx 195 if ctx == nil { 196 ctx = context.Background() 197 } 198 199 tc.log.Log = log.NewLogfmtLogger(buf) 200 201 handler := func(w http.ResponseWriter, r *http.Request) { 202 if r.Body != nil { 203 _, _ = io.Copy(io.Discard, r.Body) 204 } 205 w.WriteHeader(tc.statusCode) 206 _, _ = io.WriteString(w, "<html><body>Hello world!</body></html>") 207 } 208 loggingHandler := tc.log.Wrap(http.HandlerFunc(handler)) 209 210 req := httptest.NewRequestWithContext(ctx, "GET", "http://example.com/foo", tc.reqBody) 211 for _, header := range tc.setHeaderList { 212 req.Header.Set(header, header+"Value") 213 } 214 215 var recorder http.ResponseWriter = httptest.NewRecorder() 216 if tc.writeErr != nil { 217 recorder = &errorRecorder{writeErr: tc.writeErr} 218 } 219 loggingHandler.ServeHTTP(recorder, req) 220 221 output := buf.String() 222 assert.Equal(t, tc.message, removeLogFields(strings.TrimSpace(output), "duration", "request_body_read_duration")) 223 }) 224 } 225 }