github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/httputil/logging.go (about) 1 package httputil 2 3 import ( 4 "context" 5 "net" 6 "net/http" 7 "strings" 8 "time" 9 10 "github.com/google/uuid" 11 "github.com/sirupsen/logrus" 12 "github.com/treeverse/lakefs/pkg/logging" 13 ) 14 15 type contextKey string 16 17 const ( 18 RequestIDContextKey contextKey = "request_id" 19 AuditLogEndMessage string = "HTTP call ended" 20 ) 21 22 type ResponseRecordingWriter struct { 23 StatusCode int 24 ResponseSize int64 25 Writer http.ResponseWriter 26 } 27 28 func (w *ResponseRecordingWriter) Header() http.Header { 29 return w.Writer.Header() 30 } 31 32 func (w *ResponseRecordingWriter) Write(data []byte) (int, error) { 33 written, err := w.Writer.Write(data) 34 w.ResponseSize += int64(written) 35 return written, err 36 } 37 38 func (w *ResponseRecordingWriter) WriteHeader(statusCode int) { 39 w.StatusCode = statusCode 40 w.Writer.WriteHeader(statusCode) 41 } 42 43 func RequestID(r *http.Request) (*http.Request, string) { 44 ctx := r.Context() 45 resp := ctx.Value(RequestIDContextKey) 46 var reqID string 47 if resp == nil { 48 // assign a request ID for this request 49 reqID = uuid.New().String() 50 r = r.WithContext(context.WithValue(ctx, RequestIDContextKey, reqID)) 51 } else { 52 reqID = resp.(string) 53 } 54 return r, reqID 55 } 56 57 func SourceIP(r *http.Request) string { 58 sourceIP, sourcePort, err := net.SplitHostPort(r.RemoteAddr) 59 60 if err != nil { 61 return err.Error() 62 } 63 return sourceIP + ":" + sourcePort 64 } 65 66 func DefaultLoggingMiddleware(requestIDHeaderName string, fields logging.Fields, middlewareLogLevel string) func(next http.Handler) http.Handler { 67 return func(next http.Handler) http.Handler { 68 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 69 startTime := time.Now() 70 writer := &ResponseRecordingWriter{Writer: w, StatusCode: http.StatusOK} 71 r, reqID := RequestID(r) 72 client := GetRequestLakeFSClient(r) 73 sourceIP := SourceIP(r) 74 75 // add default fields to context 76 requestFields := logging.Fields{ 77 logging.PathFieldKey: r.RequestURI, 78 logging.MethodFieldKey: r.Method, 79 logging.HostFieldKey: r.Host, 80 logging.RequestIDFieldKey: reqID, 81 } 82 for k, v := range fields { 83 requestFields[k] = v 84 } 85 r = r.WithContext(logging.AddFields(r.Context(), requestFields)) 86 writer.Header().Set(requestIDHeaderName, reqID) 87 next.ServeHTTP(writer, r) // handle the request 88 89 loggingFields := logging.Fields{ 90 "took": time.Since(startTime), 91 "status_code": writer.StatusCode, 92 "sent_bytes": writer.ResponseSize, 93 "client": client, 94 logging.LogAudit: true, 95 "source_ip": sourceIP, 96 } 97 98 logLevel := strings.ToLower(middlewareLogLevel) 99 if logLevel == "null" || logLevel == "none" { 100 logging.FromContext(r.Context()).WithFields(loggingFields).Debug(AuditLogEndMessage) 101 } else { 102 level, _ := logrus.ParseLevel(logLevel) 103 logging.FromContext(r.Context()).WithFields(loggingFields).Log(level, AuditLogEndMessage) 104 } 105 }) 106 } 107 } 108 109 func LoggingMiddleware(requestIDHeaderName string, fields logging.Fields, loggingMiddlewareLevel string, traceRequestHeaders bool) func(next http.Handler) http.Handler { 110 if strings.ToLower(loggingMiddlewareLevel) == "trace" { 111 return TracingMiddleware(requestIDHeaderName, fields, traceRequestHeaders) 112 } 113 return DefaultLoggingMiddleware(requestIDHeaderName, fields, loggingMiddlewareLevel) 114 }