github.com/grafana/pyroscope@v1.18.0/pkg/util/http_test.go (about)

     1  package util
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/go-kit/log"
    14  	"github.com/grafana/dskit/user"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/require"
    17  
    18  	"github.com/grafana/pyroscope/pkg/tenant"
    19  )
    20  
    21  func TestWriteTextResponse(t *testing.T) {
    22  	w := httptest.NewRecorder()
    23  
    24  	WriteTextResponse(w, "hello world")
    25  
    26  	assert.Equal(t, 200, w.Code)
    27  	assert.Equal(t, "hello world", w.Body.String())
    28  	assert.Equal(t, "text/plain", w.Header().Get("Content-Type"))
    29  }
    30  
    31  func TestMultitenantMiddleware(t *testing.T) {
    32  	w := httptest.NewRecorder()
    33  	r := httptest.NewRequest("GET", "http://localhost:8080", nil)
    34  
    35  	// No org ID header.
    36  	m := AuthenticateUser(true).Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    37  		id, err := tenant.ExtractTenantIDFromContext(r.Context())
    38  		require.NoError(t, err)
    39  		assert.Equal(t, "1", id)
    40  	}))
    41  	m.ServeHTTP(w, r)
    42  	assert.Equal(t, http.StatusUnauthorized, w.Code)
    43  
    44  	w = httptest.NewRecorder()
    45  	r.Header.Set("X-Scope-OrgID", "1")
    46  	m.ServeHTTP(w, r)
    47  	assert.Equal(t, http.StatusOK, w.Code)
    48  
    49  	// No org ID header without auth.
    50  	r = httptest.NewRequest("GET", "http://localhost:8080", nil)
    51  	m = AuthenticateUser(false).Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    52  		id, err := tenant.ExtractTenantIDFromContext(r.Context())
    53  		require.NoError(t, err)
    54  		assert.Equal(t, tenant.DefaultTenantID, id)
    55  	}))
    56  	m.ServeHTTP(w, r)
    57  	assert.Equal(t, http.StatusOK, w.Code)
    58  }
    59  
    60  func removeLogFields(line string, fields ...string) string {
    61  	for _, field := range fields {
    62  		// find field
    63  		needle := field + "="
    64  		pos := strings.Index(line, needle)
    65  		if pos < 0 {
    66  			continue
    67  		}
    68  
    69  		// find space after field
    70  		offset := pos + len(needle)
    71  		posSpace := strings.Index(line[offset:], " ")
    72  		if posSpace < 0 {
    73  			// remove all after needle
    74  			line = line[0:offset]
    75  			continue
    76  		}
    77  
    78  		// remove value
    79  		line = line[:offset] + line[offset+posSpace:]
    80  	}
    81  
    82  	return line
    83  
    84  }
    85  
    86  type errorRecorder struct {
    87  	writeErr error
    88  }
    89  
    90  func (r *errorRecorder) Write([]byte) (int, error) { return 0, r.writeErr }
    91  
    92  func (*errorRecorder) Header() http.Header { return make(http.Header) }
    93  
    94  func (*errorRecorder) WriteHeader(statusCode int) {}
    95  
    96  func TestHTTPLog(t *testing.T) {
    97  	ctxTenant := user.InjectOrgID(context.Background(), "my-tenant")
    98  	for _, tc := range []struct {
    99  		name          string
   100  		log           *Log
   101  		ctx           context.Context
   102  		reqBody       io.Reader
   103  		writeErr      error
   104  		setHeaderList []string
   105  		statusCode    int
   106  		message       string
   107  	}{
   108  		{
   109  			name: "Header logging disabled",
   110  			log: &Log{
   111  				LogRequestHeaders: false,
   112  			},
   113  			setHeaderList: []string{"good-header", "authorization"},
   114  			message:       `level=debug method=GET uri=http://example.com/foo status=200 duration= msg="http request processed"`,
   115  		},
   116  		{
   117  			name: "Header logging enable",
   118  			log: &Log{
   119  				LogRequestHeaders: true,
   120  			},
   121  			setHeaderList: []string{"good-header", "authorization"},
   122  			message:       `level=debug method=GET uri=http://example.com/foo status=200 duration= request_header_Good-Header=good-headerValue msg="http request processed"`,
   123  		},
   124  		{
   125  			name: "Extra Header excluded",
   126  			log: &Log{
   127  				LogRequestHeaders:        true,
   128  				LogRequestExcludeHeaders: []string{"bad-header"},
   129  			},
   130  			setHeaderList: []string{"good-header", "bad-header", "authorization"},
   131  			message:       `level=debug method=GET uri=http://example.com/foo status=200 duration= request_header_Good-Header=good-headerValue msg="http request processed"`,
   132  		},
   133  		{
   134  			name: "Extra Header with different casing",
   135  			log: &Log{
   136  				LogRequestHeaders:        true,
   137  				LogRequestExcludeHeaders: []string{"Bad-Header"},
   138  			},
   139  			setHeaderList: []string{"good-header", "bad-header", "authorization"},
   140  			message:       `level=debug method=GET uri=http://example.com/foo status=200 duration= request_header_Good-Header=good-headerValue msg="http request processed"`,
   141  		},
   142  		{
   143  			name: "Two Extra Headers excluded",
   144  			log: &Log{
   145  				LogRequestHeaders:        true,
   146  				LogRequestExcludeHeaders: []string{"bad-header", "bad-header2"},
   147  			},
   148  			setHeaderList: []string{"good-header", "bad-header", "bad-header2", "authorization"},
   149  			message:       `level=debug method=GET uri=http://example.com/foo status=200 duration= request_header_Good-Header=good-headerValue msg="http request processed"`,
   150  		},
   151  		{
   152  			name: "Status code 500 should still log headers",
   153  			log: &Log{
   154  				LogRequestHeaders:        false,
   155  				LogRequestExcludeHeaders: []string{"bad-header"},
   156  			},
   157  			setHeaderList: []string{"good-header", "bad-header", "authorization"},
   158  			message:       `level=warn method=GET uri=http://example.com/foo status=500 duration= request_header_Good-Header=good-headerValue msg="http request failed" response_body="<html><body>Hello world!</body></html>"`,
   159  
   160  			statusCode: http.StatusInternalServerError,
   161  		},
   162  		{
   163  			name:    "Log request body size latency",
   164  			log:     &Log{},
   165  			reqBody: strings.NewReader("Hello World! I am a request body."),
   166  			message: `level=debug method=GET uri=http://example.com/foo status=200 duration= request_body_size=33B request_body_read_duration= msg="http request processed"`,
   167  		},
   168  		{
   169  			name:     "Write errors should be shown at warning level",
   170  			log:      &Log{},
   171  			writeErr: errors.New("some error"),
   172  			message:  `level=warn method=GET uri=http://example.com/foo status=200 duration= msg="http request failed" err="some error"`,
   173  		},
   174  		{
   175  			name:     "Context cancelled requests should not be at warning level",
   176  			log:      &Log{},
   177  			writeErr: context.Canceled,
   178  			message:  `level=debug method=GET uri=http://example.com/foo status=200 duration= msg="request cancelled"`,
   179  		},
   180  		{
   181  			name:    "Tenant id should be logged",
   182  			ctx:     ctxTenant,
   183  			log:     &Log{},
   184  			message: `level=debug tenant=my-tenant method=GET uri=http://example.com/foo status=200 duration= msg="http request processed"`,
   185  		},
   186  	} {
   187  		t.Run(tc.name, func(t *testing.T) {
   188  			buf := bytes.NewBuffer(nil)
   189  
   190  			if tc.statusCode == 0 {
   191  				tc.statusCode = http.StatusOK
   192  			}
   193  
   194  			ctx := tc.ctx
   195  			if ctx == nil {
   196  				ctx = context.Background()
   197  			}
   198  
   199  			tc.log.Log = log.NewLogfmtLogger(buf)
   200  
   201  			handler := func(w http.ResponseWriter, r *http.Request) {
   202  				if r.Body != nil {
   203  					_, _ = io.Copy(io.Discard, r.Body)
   204  				}
   205  				w.WriteHeader(tc.statusCode)
   206  				_, _ = io.WriteString(w, "<html><body>Hello world!</body></html>")
   207  			}
   208  			loggingHandler := tc.log.Wrap(http.HandlerFunc(handler))
   209  
   210  			req := httptest.NewRequestWithContext(ctx, "GET", "http://example.com/foo", tc.reqBody)
   211  			for _, header := range tc.setHeaderList {
   212  				req.Header.Set(header, header+"Value")
   213  			}
   214  
   215  			var recorder http.ResponseWriter = httptest.NewRecorder()
   216  			if tc.writeErr != nil {
   217  				recorder = &errorRecorder{writeErr: tc.writeErr}
   218  			}
   219  			loggingHandler.ServeHTTP(recorder, req)
   220  
   221  			output := buf.String()
   222  			assert.Equal(t, tc.message, removeLogFields(strings.TrimSpace(output), "duration", "request_body_read_duration"))
   223  		})
   224  	}
   225  }