github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/httputil/tracing.go (about)

     1  package httputil
     2  
     3  import (
     4  	"io"
     5  	"net/http"
     6  	"time"
     7  
     8  	"github.com/treeverse/lakefs/pkg/logging"
     9  )
    10  
    11  const (
    12  	MaxBodyBytes                      = 750               // Log lines will be < 2KiB
    13  	RequestTracingMaxRequestBodySize  = 1024 * 1024 * 50  // 50KB
    14  	RequestTracingMaxResponseBodySize = 1024 * 1024 * 150 // 150KB
    15  )
    16  
    17  type CappedBuffer struct {
    18  	SizeBytes int
    19  	cursor    int
    20  	Buffer    []byte
    21  }
    22  
    23  func (c *CappedBuffer) Write(p []byte) (n int, err error) {
    24  	// pretend to write the whole thing, but only write SizeBytes
    25  	if c.cursor >= c.SizeBytes {
    26  		return len(p), nil
    27  	}
    28  	if c.Buffer == nil {
    29  		c.Buffer = make([]byte, 0)
    30  	}
    31  	var written int
    32  	if len(p) > (c.SizeBytes - c.cursor) {
    33  		c.Buffer = append(c.Buffer, p[0:(c.SizeBytes-c.cursor)]...)
    34  		written = c.SizeBytes - c.cursor
    35  	} else {
    36  		c.Buffer = append(c.Buffer, p...)
    37  		written = len(p)
    38  	}
    39  	c.cursor += written
    40  	return len(p), nil
    41  }
    42  
    43  type responseTracingWriter struct {
    44  	StatusCode   int
    45  	ResponseSize int64
    46  	BodyRecorder *CappedBuffer
    47  
    48  	Writer      http.ResponseWriter
    49  	multiWriter io.Writer
    50  }
    51  
    52  func newResponseTracingWriter(w http.ResponseWriter, sizeInBytes int) *responseTracingWriter {
    53  	buf := &CappedBuffer{
    54  		SizeBytes: sizeInBytes,
    55  	}
    56  	mw := io.MultiWriter(w, buf)
    57  	return &responseTracingWriter{
    58  		StatusCode:   http.StatusOK,
    59  		BodyRecorder: buf,
    60  		Writer:       w,
    61  		multiWriter:  mw,
    62  	}
    63  }
    64  
    65  func (w *responseTracingWriter) Header() http.Header {
    66  	return w.Writer.Header()
    67  }
    68  
    69  func (w *responseTracingWriter) Write(data []byte) (int, error) {
    70  	return w.multiWriter.Write(data)
    71  }
    72  
    73  func (w *responseTracingWriter) WriteHeader(statusCode int) {
    74  	w.StatusCode = statusCode
    75  	w.Writer.WriteHeader(statusCode)
    76  }
    77  
    78  type requestBodyTracer struct {
    79  	body         io.ReadCloser
    80  	bodyRecorder *CappedBuffer
    81  	tee          io.Reader
    82  }
    83  
    84  func newRequestBodyTracer(body io.ReadCloser, sizeInBytes int) *requestBodyTracer {
    85  	w := &CappedBuffer{
    86  		SizeBytes: sizeInBytes,
    87  	}
    88  	return &requestBodyTracer{
    89  		body:         body,
    90  		bodyRecorder: w,
    91  		tee:          io.TeeReader(body, w),
    92  	}
    93  }
    94  
    95  func (r *requestBodyTracer) Read(p []byte) (n int, err error) {
    96  	return r.tee.Read(p)
    97  }
    98  
    99  func (r *requestBodyTracer) Close() error {
   100  	return r.body.Close()
   101  }
   102  
   103  func presentBody(body []byte) string {
   104  	if len(body) > MaxBodyBytes {
   105  		body = body[:MaxBodyBytes]
   106  	}
   107  	return string(body)
   108  }
   109  
   110  func TracingMiddleware(requestIDHeaderName string, fields logging.Fields, traceRequestHeaders bool) func(http.Handler) http.Handler {
   111  	return func(next http.Handler) http.Handler {
   112  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   113  			startTime := time.Now()
   114  			responseWriter := newResponseTracingWriter(w, RequestTracingMaxResponseBodySize)
   115  			r, reqID := RequestID(r)
   116  
   117  			// add default fields to context
   118  			requestFields := logging.Fields{
   119  				logging.PathFieldKey:      r.RequestURI,
   120  				logging.MethodFieldKey:    r.Method,
   121  				logging.HostFieldKey:      r.Host,
   122  				logging.RequestIDFieldKey: reqID,
   123  			}
   124  			for k, v := range fields {
   125  				requestFields[k] = v
   126  			}
   127  			r = r.WithContext(logging.AddFields(r.Context(), requestFields))
   128  			responseWriter.Header().Set(requestIDHeaderName, reqID)
   129  
   130  			// record request body as well
   131  			requestBodyTracer := newRequestBodyTracer(r.Body, RequestTracingMaxRequestBodySize)
   132  			r.Body = requestBodyTracer
   133  
   134  			next.ServeHTTP(responseWriter, r) // handle the request
   135  
   136  			traceFields := logging.Fields{
   137  				"took":             time.Since(startTime),
   138  				"status_code":      responseWriter.StatusCode,
   139  				"sent_bytes":       responseWriter.ResponseSize,
   140  				"request_body":     presentBody(requestBodyTracer.bodyRecorder.Buffer),
   141  				"response_body":    presentBody(responseWriter.BodyRecorder.Buffer),
   142  				"response_headers": responseWriter.Header(),
   143  			}
   144  			if traceRequestHeaders {
   145  				traceFields["request_headers"] = r.Header
   146  			}
   147  			logging.FromContext(r.Context()).
   148  				WithFields(traceFields).
   149  				Trace("HTTP call ended")
   150  		})
   151  	}
   152  }