github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/httputil/logging.go (about)

     1  package httputil
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/google/uuid"
    11  	"github.com/sirupsen/logrus"
    12  	"github.com/treeverse/lakefs/pkg/logging"
    13  )
    14  
    15  type contextKey string
    16  
    17  const (
    18  	RequestIDContextKey contextKey = "request_id"
    19  	AuditLogEndMessage  string     = "HTTP call ended"
    20  )
    21  
    22  type ResponseRecordingWriter struct {
    23  	StatusCode   int
    24  	ResponseSize int64
    25  	Writer       http.ResponseWriter
    26  }
    27  
    28  func (w *ResponseRecordingWriter) Header() http.Header {
    29  	return w.Writer.Header()
    30  }
    31  
    32  func (w *ResponseRecordingWriter) Write(data []byte) (int, error) {
    33  	written, err := w.Writer.Write(data)
    34  	w.ResponseSize += int64(written)
    35  	return written, err
    36  }
    37  
    38  func (w *ResponseRecordingWriter) WriteHeader(statusCode int) {
    39  	w.StatusCode = statusCode
    40  	w.Writer.WriteHeader(statusCode)
    41  }
    42  
    43  func RequestID(r *http.Request) (*http.Request, string) {
    44  	ctx := r.Context()
    45  	resp := ctx.Value(RequestIDContextKey)
    46  	var reqID string
    47  	if resp == nil {
    48  		// assign a request ID for this request
    49  		reqID = uuid.New().String()
    50  		r = r.WithContext(context.WithValue(ctx, RequestIDContextKey, reqID))
    51  	} else {
    52  		reqID = resp.(string)
    53  	}
    54  	return r, reqID
    55  }
    56  
    57  func SourceIP(r *http.Request) string {
    58  	sourceIP, sourcePort, err := net.SplitHostPort(r.RemoteAddr)
    59  
    60  	if err != nil {
    61  		return err.Error()
    62  	}
    63  	return sourceIP + ":" + sourcePort
    64  }
    65  
    66  func DefaultLoggingMiddleware(requestIDHeaderName string, fields logging.Fields, middlewareLogLevel string) func(next http.Handler) http.Handler {
    67  	return func(next http.Handler) http.Handler {
    68  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    69  			startTime := time.Now()
    70  			writer := &ResponseRecordingWriter{Writer: w, StatusCode: http.StatusOK}
    71  			r, reqID := RequestID(r)
    72  			client := GetRequestLakeFSClient(r)
    73  			sourceIP := SourceIP(r)
    74  
    75  			// add default fields to context
    76  			requestFields := logging.Fields{
    77  				logging.PathFieldKey:      r.RequestURI,
    78  				logging.MethodFieldKey:    r.Method,
    79  				logging.HostFieldKey:      r.Host,
    80  				logging.RequestIDFieldKey: reqID,
    81  			}
    82  			for k, v := range fields {
    83  				requestFields[k] = v
    84  			}
    85  			r = r.WithContext(logging.AddFields(r.Context(), requestFields))
    86  			writer.Header().Set(requestIDHeaderName, reqID)
    87  			next.ServeHTTP(writer, r) // handle the request
    88  
    89  			loggingFields := logging.Fields{
    90  				"took":           time.Since(startTime),
    91  				"status_code":    writer.StatusCode,
    92  				"sent_bytes":     writer.ResponseSize,
    93  				"client":         client,
    94  				logging.LogAudit: true,
    95  				"source_ip":      sourceIP,
    96  			}
    97  
    98  			logLevel := strings.ToLower(middlewareLogLevel)
    99  			if logLevel == "null" || logLevel == "none" {
   100  				logging.FromContext(r.Context()).WithFields(loggingFields).Debug(AuditLogEndMessage)
   101  			} else {
   102  				level, _ := logrus.ParseLevel(logLevel)
   103  				logging.FromContext(r.Context()).WithFields(loggingFields).Log(level, AuditLogEndMessage)
   104  			}
   105  		})
   106  	}
   107  }
   108  
   109  func LoggingMiddleware(requestIDHeaderName string, fields logging.Fields, loggingMiddlewareLevel string, traceRequestHeaders bool) func(next http.Handler) http.Handler {
   110  	if strings.ToLower(loggingMiddlewareLevel) == "trace" {
   111  		return TracingMiddleware(requestIDHeaderName, fields, traceRequestHeaders)
   112  	}
   113  	return DefaultLoggingMiddleware(requestIDHeaderName, fields, loggingMiddlewareLevel)
   114  }