github.com/grafana/pyroscope@v1.18.0/pkg/util/http.go (about)

     1  package util
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"encoding/json"
     8  	"errors"
     9  	"html/template"
    10  	"io"
    11  	"net"
    12  	"net/http"
    13  	"net/textproto"
    14  	"slices"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  
    19  	"github.com/grafana/dskit/instrument"
    20  	"go.opentelemetry.io/otel/baggage"
    21  
    22  	"github.com/dustin/go-humanize"
    23  	"github.com/felixge/httpsnoop"
    24  	"github.com/go-kit/log"
    25  	"github.com/go-kit/log/level"
    26  	"github.com/gorilla/mux"
    27  	"github.com/grafana/dskit/middleware"
    28  	"github.com/grafana/dskit/multierror"
    29  	"github.com/grafana/dskit/tracing"
    30  	"github.com/grafana/dskit/user"
    31  	"github.com/opentracing/opentracing-go"
    32  	otlog "github.com/opentracing/opentracing-go/log"
    33  	"github.com/prometheus/client_golang/prometheus"
    34  	"golang.org/x/net/http2"
    35  	"gopkg.in/yaml.v3"
    36  
    37  	"github.com/grafana/pyroscope/pkg/tenant"
    38  	httputil "github.com/grafana/pyroscope/pkg/util/http"
    39  	"github.com/grafana/pyroscope/pkg/util/nethttp"
    40  )
    41  
    42  var defaultTransport http.RoundTripper = &http2.Transport{
    43  	AllowHTTP:        true,
    44  	ReadIdleTimeout:  30 * time.Second,
    45  	WriteByteTimeout: 30 * time.Second,
    46  	PingTimeout:      90 * time.Second,
    47  	DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
    48  		return net.Dial(network, addr)
    49  	},
    50  }
    51  
// timeNow returns the current time. It is a package-level variable rather
// than a direct call to time.Now — presumably so tests can substitute a fake
// clock (TODO confirm against the package tests).
var timeNow = time.Now
    53  
    54  type RoundTripperFunc func(req *http.Request) (*http.Response, error)
    55  
    56  func (f RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
    57  	return f(req)
    58  }
    59  
// RoundTripperInstrumentFunc decorates a RoundTripper: it receives the next
// RoundTripper in the chain and returns the RoundTripper that should be used
// in its place.
type RoundTripperInstrumentFunc func(next http.RoundTripper) http.RoundTripper
    61  
    62  // InstrumentedDefaultHTTPClient returns an http client configured with some
    63  // default settings which is wrapped with a variety of instrumented
    64  // RoundTrippers.
    65  func InstrumentedDefaultHTTPClient(instruments ...RoundTripperInstrumentFunc) *http.Client {
    66  	client := &http.Client{
    67  		Transport: defaultTransport,
    68  	}
    69  	return InstrumentedHTTPClient(client, instruments...)
    70  }
    71  
    72  // InstrumentedHTTPClient adds the associated instrumentation middlewares to the
    73  // provided http client.
    74  func InstrumentedHTTPClient(client *http.Client, instruments ...RoundTripperInstrumentFunc) *http.Client {
    75  	for i := len(instruments) - 1; i >= 0; i-- {
    76  		client.Transport = instruments[i](client.Transport)
    77  	}
    78  	return client
    79  }
    80  
    81  // WithTracingTransport wraps the given RoundTripper with a tracing instrumented
    82  // one.
    83  func WithTracingTransport() RoundTripperInstrumentFunc {
    84  	return func(next http.RoundTripper) http.RoundTripper {
    85  		next = &nethttp.Transport{RoundTripper: next}
    86  		return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
    87  			req = nethttp.TraceRequest(opentracing.GlobalTracer(), req)
    88  			return next.RoundTrip(req)
    89  		})
    90  	}
    91  }
    92  
    93  // WithBaggageTransport will set the Baggage header on the request if there is
    94  // any baggage in the context and it was not already set.
    95  func WithBaggageTransport() RoundTripperInstrumentFunc {
    96  	return func(next http.RoundTripper) http.RoundTripper {
    97  		return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
    98  			_, ok := req.Header["Baggage"]
    99  			if ok {
   100  				return next.RoundTrip(req)
   101  			}
   102  
   103  			b := baggage.FromContext(req.Context())
   104  			if b.Len() == 0 {
   105  				return next.RoundTrip(req)
   106  			}
   107  
   108  			req.Header.Set("Baggage", b.String())
   109  			return next.RoundTrip(req)
   110  		})
   111  	}
   112  }
   113  
   114  // WriteYAMLResponse writes some YAML as a HTTP response.
   115  func WriteYAMLResponse(w http.ResponseWriter, v interface{}) {
   116  	// There is not standardised content-type for YAML, text/plain ensures the
   117  	// YAML is displayed in the browser instead of offered as a download
   118  	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   119  
   120  	data, err := yaml.Marshal(v)
   121  	if err != nil {
   122  		httputil.Error(w, err)
   123  		return
   124  	}
   125  
   126  	// We ignore errors here, because we cannot do anything about them.
   127  	// Write will trigger sending Status code, so we cannot send a different status code afterwards.
   128  	// Also this isn't internal error, but error communicating with client.
   129  	_, _ = w.Write(data)
   130  }
   131  
// maxResponseBodyInLogs is the initial byte budget that Log.Wrap hands to
// captureResponseBody when recording the body of a failed response.
const (
	maxResponseBodyInLogs = 4096 // At most 4k bytes from response bodies in our logs.
)
   135  
// Log middleware logs http requests
type Log struct {
	// Log is the base logger; per-request fields are layered on top of it.
	Log log.Logger
	// LogRequestHeaders, when true, attaches the (filtered) request headers to
	// every request log line. When false they are attached only to the log
	// lines of failed requests.
	LogRequestHeaders bool
	// LogRequestExcludeHeaders lists header names that must never be logged,
	// in addition to middleware.AlwaysExcludedHeaders.
	LogRequestExcludeHeaders []string
	LogRequestAtInfoLevel    bool // LogRequestAtInfoLevel true -> log requests at info log level
	// SourceIPs, when non-nil, is used to add a "sourceIPs" field to the log.
	SourceIPs *middleware.SourceIPExtractor

	// filterHeaderMap is the set of canonicalised header names to exclude. It
	// is built lazily, exactly once, guarded by filterHeaderOnce.
	filterHeaderMap  map[string]struct{}
	filterHeaderOnce sync.Once
}
   147  
   148  func (l *Log) filterHeader(key string) bool {
   149  	// ensure map is populated once
   150  	l.filterHeaderOnce.Do(func() {
   151  		l.filterHeaderMap = make(map[string]struct{})
   152  		for _, k := range l.LogRequestExcludeHeaders {
   153  			l.filterHeaderMap[textproto.CanonicalMIMEHeaderKey(k)] = struct{}{}
   154  		}
   155  		for k := range middleware.AlwaysExcludedHeaders {
   156  			l.filterHeaderMap[textproto.CanonicalMIMEHeaderKey(k)] = struct{}{}
   157  		}
   158  	})
   159  	_, filter := l.filterHeaderMap[key]
   160  	return filter
   161  }
   162  
   163  func (l *Log) extractHeaders(req *http.Request) []any {
   164  	// Populate header list first and sort it
   165  	logKeys := make([]string, 0, len(req.Header))
   166  	for k := range req.Header {
   167  		if l.filterHeader(k) {
   168  			continue
   169  		}
   170  		logKeys = append(logKeys, k)
   171  	}
   172  	slices.SortFunc(logKeys, strings.Compare)
   173  
   174  	// build the log fields
   175  	logFields := make([]any, 0, len(logKeys)*2)
   176  	for _, k := range logKeys {
   177  		logFields = append(
   178  			logFields,
   179  			"request_header_"+k,
   180  			req.Header.Get(k),
   181  		)
   182  	}
   183  
   184  	return logFields
   185  }
   186  
   187  // logWithRequest information from the request and context as fields.
   188  func (l *Log) logWithRequest(r *http.Request) log.Logger {
   189  	localLog := l.Log
   190  	traceID, ok := tracing.ExtractTraceID(r.Context())
   191  	if ok {
   192  		localLog = log.With(localLog, "traceID", traceID)
   193  	}
   194  
   195  	if l.SourceIPs != nil {
   196  		ips := l.SourceIPs.Get(r)
   197  		if ips != "" {
   198  			localLog = log.With(localLog, "sourceIPs", ips)
   199  		}
   200  	}
   201  
   202  	tenantID := r.Header.Get(user.OrgIDHeaderName)
   203  	if tenantID == "" {
   204  		id, err := user.ExtractOrgID(r.Context())
   205  		if err == nil {
   206  			tenantID = id
   207  		}
   208  	}
   209  	if tenantID != "" {
   210  		localLog = log.With(localLog, "tenant", tenantID)
   211  	}
   212  
   213  	return localLog
   214  }
   215  
   216  // measure request body size
   217  type reqBody struct {
   218  	b    io.ReadCloser
   219  	read byteSize
   220  
   221  	start    time.Time
   222  	duration time.Duration
   223  
   224  	sp opentracing.Span
   225  }
   226  
   227  func (w *reqBody) Read(p []byte) (int, error) {
   228  	if w.start.IsZero() {
   229  		w.start = timeNow()
   230  		if w.sp != nil {
   231  			w.sp.LogFields(otlog.String("msg", "start reading body from request"))
   232  		}
   233  	}
   234  	n, err := w.b.Read(p)
   235  	if n > 0 {
   236  		w.read += byteSize(n)
   237  	}
   238  	if err == io.EOF {
   239  		w.duration = timeNow().Sub(w.start)
   240  		if w.sp != nil {
   241  			w.sp.LogFields(otlog.String("msg", "read body from request"))
   242  			if w.read > 0 {
   243  				w.sp.SetTag("request_body_size", w.read)
   244  			}
   245  		}
   246  	}
   247  	return n, err
   248  }
   249  
   250  func (w *reqBody) Close() error {
   251  	return w.b.Close()
   252  }
   253  
   254  type byteSize uint64
   255  
   256  func (bs byteSize) String() string {
   257  	return strings.Replace(humanize.IBytes(uint64(bs)), " ", "", 1)
   258  }
   259  
   260  // Wrap implements Middleware
   261  func (l *Log) Wrap(next http.Handler) http.Handler {
   262  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   263  		begin := timeNow()
   264  		uri := r.RequestURI // Capture the URI before running next, as it may get rewritten
   265  		requestLog := l.logWithRequest(r)
   266  		// Log headers before running 'next' in case other interceptors change the data.
   267  
   268  		var (
   269  			httpErr       multierror.MultiError
   270  			httpCode      = http.StatusOK
   271  			headerWritten bool
   272  			buf           bytes.Buffer
   273  			bodyLeft      = maxResponseBodyInLogs
   274  		)
   275  
   276  		headerFields := l.extractHeaders(r)
   277  
   278  		wrapped := httpsnoop.Wrap(w, httpsnoop.Hooks{
   279  			WriteHeader: func(next httpsnoop.WriteHeaderFunc) httpsnoop.WriteHeaderFunc {
   280  				return func(code int) {
   281  					next(code)
   282  					if !headerWritten {
   283  						httpCode = code
   284  						headerWritten = true
   285  					}
   286  				}
   287  			},
   288  
   289  			Write: func(next httpsnoop.WriteFunc) httpsnoop.WriteFunc {
   290  				return func(p []byte) (int, error) {
   291  					n, err := next(p)
   292  					headerWritten = true
   293  					httpErr.Add(err)
   294  					if httpCode >= 400 && httpCode < 600 {
   295  						bodyLeft = captureResponseBody(p, bodyLeft, &buf)
   296  					}
   297  					return n, err
   298  				}
   299  			},
   300  
   301  			ReadFrom: func(next httpsnoop.ReadFromFunc) httpsnoop.ReadFromFunc {
   302  				return func(src io.Reader) (int64, error) {
   303  					n, err := next(src)
   304  					headerWritten = true
   305  					httpErr.Add(err)
   306  					return n, err
   307  				}
   308  			},
   309  		})
   310  
   311  		origBody := r.Body
   312  		defer func() {
   313  			// No need to leak our Body wrapper beyond the scope of this handler.
   314  			r.Body = origBody
   315  		}()
   316  
   317  		rBody := &reqBody{
   318  			b:  origBody,
   319  			sp: opentracing.SpanFromContext(r.Context()),
   320  		}
   321  		r.Body = rBody
   322  
   323  		next.ServeHTTP(wrapped, r)
   324  
   325  		statusCode, writeErr := httpCode, httpErr.Err()
   326  
   327  		requestLog = log.With(requestLog, "method", r.Method, "uri", uri, "status", statusCode, "duration", time.Since(begin))
   328  
   329  		if l.LogRequestHeaders {
   330  			requestLog = log.With(requestLog, headerFields...)
   331  		}
   332  		if rBody.read > 0 {
   333  			requestLog = log.With(requestLog, "request_body_size", rBody.read)
   334  			if rBody.duration > 0 {
   335  				requestLog = log.With(requestLog, "request_body_read_duration", rBody.duration)
   336  			}
   337  		}
   338  
   339  		requestLvl := level.Debug
   340  		if l.LogRequestAtInfoLevel {
   341  			requestLvl = level.Info
   342  		}
   343  
   344  		// log successful requests
   345  		if writeErr == nil && (100 <= statusCode && statusCode < 400) {
   346  			requestLvl(requestLog).Log("msg", "http request processed")
   347  			return
   348  		}
   349  
   350  		// context cancelled is not considered a failure
   351  		if writeErr != nil && errors.Is(writeErr, context.Canceled) {
   352  			requestLvl(requestLog).Log("msg", "request cancelled")
   353  			return
   354  		}
   355  
   356  		// add request headers if not anyhow added
   357  		if !l.LogRequestHeaders {
   358  			requestLog = log.With(requestLog, headerFields...)
   359  		}
   360  
   361  		// writeError shouldn't log the body
   362  		if writeErr != nil {
   363  			level.Warn(requestLog).Log("msg", "http request failed", "err", writeErr)
   364  			return
   365  		}
   366  
   367  		level.Warn(requestLog).Log("msg", "http request failed", "response_body", buf.Bytes())
   368  	})
   369  }
   370  
   371  func captureResponseBody(data []byte, bodyBytesLeft int, buf *bytes.Buffer) int {
   372  	if bodyBytesLeft <= 0 {
   373  		return 0
   374  	}
   375  	if len(data) > bodyBytesLeft {
   376  		buf.Write(data[:bodyBytesLeft])
   377  		_, _ = io.WriteString(buf, "...")
   378  		return 0
   379  	} else {
   380  		buf.Write(data)
   381  		return bodyBytesLeft - len(data)
   382  	}
   383  }
   384  
   385  // NewHTTPMetricMiddleware creates a new middleware that automatically instruments HTTP requests from the given router.
   386  func NewHTTPMetricMiddleware(mux *mux.Router, namespace string, reg prometheus.Registerer) (middleware.Interface, error) {
   387  	// Prometheus histograms for requests.
   388  	requestDuration := prometheus.NewHistogramVec(prometheus.HistogramOpts{
   389  		Namespace: namespace,
   390  		Name:      "request_duration_seconds",
   391  		Help:      "Time (in seconds) spent serving HTTP requests.",
   392  		Buckets:   instrument.DefBuckets,
   393  	}, []string{"method", "route", "status_code", "ws"})
   394  	err := reg.Register(requestDuration)
   395  	if err != nil {
   396  		already, ok := err.(prometheus.AlreadyRegisteredError)
   397  		if ok {
   398  			requestDuration = already.ExistingCollector.(*prometheus.HistogramVec)
   399  		} else {
   400  			return nil, err
   401  		}
   402  	}
   403  
   404  	receivedMessageSize := prometheus.NewHistogramVec(prometheus.HistogramOpts{
   405  		Namespace: namespace,
   406  		Name:      "request_message_bytes",
   407  		Help:      "Size (in bytes) of messages received in the request.",
   408  		Buckets:   middleware.BodySizeBuckets,
   409  	}, []string{"method", "route"})
   410  	err = reg.Register(receivedMessageSize)
   411  	if err != nil {
   412  		already, ok := err.(prometheus.AlreadyRegisteredError)
   413  		if ok {
   414  			receivedMessageSize = already.ExistingCollector.(*prometheus.HistogramVec)
   415  		} else {
   416  			return nil, err
   417  		}
   418  	}
   419  
   420  	sentMessageSize := prometheus.NewHistogramVec(prometheus.HistogramOpts{
   421  		Namespace: namespace,
   422  		Name:      "response_message_bytes",
   423  		Help:      "Size (in bytes) of messages sent in response.",
   424  		Buckets:   middleware.BodySizeBuckets,
   425  	}, []string{"method", "route"})
   426  
   427  	err = reg.Register(sentMessageSize)
   428  	if err != nil {
   429  		already, ok := err.(prometheus.AlreadyRegisteredError)
   430  		if ok {
   431  			sentMessageSize = already.ExistingCollector.(*prometheus.HistogramVec)
   432  		} else {
   433  			return nil, err
   434  		}
   435  	}
   436  
   437  	inflightRequests := prometheus.NewGaugeVec(prometheus.GaugeOpts{
   438  		Namespace: namespace,
   439  		Name:      "inflight_requests",
   440  		Help:      "Current number of inflight requests.",
   441  	}, []string{"method", "route"})
   442  	err = reg.Register(inflightRequests)
   443  	if err != nil {
   444  		already, ok := err.(prometheus.AlreadyRegisteredError)
   445  		if ok {
   446  			inflightRequests = already.ExistingCollector.(*prometheus.GaugeVec)
   447  		} else {
   448  			return nil, err
   449  		}
   450  	}
   451  	return middleware.Instrument{
   452  		Duration:         requestDuration,
   453  		RequestBodySize:  receivedMessageSize,
   454  		ResponseBodySize: sentMessageSize,
   455  		InflightRequests: inflightRequests,
   456  	}, nil
   457  }
   458  
   459  // WriteHTMLResponse sends message as text/html response with 200 status code.
   460  func WriteHTMLResponse(w http.ResponseWriter, message string) {
   461  	w.Header().Set("Content-Type", "text/html")
   462  
   463  	// Ignore inactionable errors.
   464  	_, _ = w.Write([]byte(message))
   465  }
   466  
   467  // WriteTextResponse sends message as text/plain response with 200 status code.
   468  func WriteTextResponse(w http.ResponseWriter, message string) {
   469  	w.Header().Set("Content-Type", "text/plain")
   470  
   471  	// Ignore inactionable errors.
   472  	_, _ = w.Write([]byte(message))
   473  }
   474  
   475  // RenderHTTPResponse either responds with JSON or a rendered HTML page using the passed in template
   476  // by checking the Accepts header.
   477  func RenderHTTPResponse(w http.ResponseWriter, v interface{}, t *template.Template, r *http.Request) {
   478  	accept := r.Header.Get("Accept")
   479  	if strings.Contains(accept, "application/json") {
   480  		WriteJSONResponse(w, v)
   481  		return
   482  	}
   483  
   484  	w.Header().Set("Content-Type", "text/html; charset=utf-8")
   485  	err := t.Execute(w, v)
   486  	if err != nil {
   487  		httputil.Error(w, err)
   488  	}
   489  }
   490  
   491  // WriteJSONResponse writes some JSON as a HTTP response.
   492  func WriteJSONResponse(w http.ResponseWriter, v interface{}) {
   493  	w.Header().Set("Content-Type", "application/json")
   494  
   495  	data, err := json.Marshal(v)
   496  	if err != nil {
   497  		httputil.Error(w, err)
   498  		return
   499  	}
   500  
   501  	// We ignore errors here, because we cannot do anything about them.
   502  	// Write will trigger sending Status code, so we cannot send a different status code afterwards.
   503  	// Also this isn't internal error, but error communicating with client.
   504  	_, _ = w.Write(data)
   505  }
   506  
   507  // AuthenticateUser propagates the user ID from HTTP headers back to the request's context.
   508  // If on is false, it will inject the default tenant ID.
   509  func AuthenticateUser(on bool) middleware.Interface {
   510  	// TODO: @petethepig This logic is copied in otlp.*ingestHandler.Export. We should unify
   511  	return middleware.Func(func(next http.Handler) http.Handler {
   512  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   513  			if !on {
   514  				next.ServeHTTP(w, r.WithContext(user.InjectOrgID(r.Context(), tenant.DefaultTenantID)))
   515  				return
   516  			}
   517  			_, ctx, err := user.ExtractOrgIDFromHTTPRequest(r)
   518  			if err != nil {
   519  				httputil.ErrorWithStatus(w, err, http.StatusUnauthorized)
   520  				return
   521  			}
   522  			next.ServeHTTP(w, r.WithContext(ctx))
   523  		})
   524  	})
   525  }