github.com/castai/kvisor@v1.7.1-0.20240516114728-b3572a2607b5/pkg/logging/ratelimiter_handler.go (about)

     1  package logging
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log/slog"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"golang.org/x/time/rate"
    11  )
    12  
    13  func NewRateLimiterHandler(ctx context.Context, next slog.Handler, cfg RateLimiterConfig) slog.Handler {
    14  	droppedLogsCounters := map[slog.Level]*atomic.Uint64{
    15  		slog.LevelDebug: {},
    16  		slog.LevelInfo:  {},
    17  		slog.LevelWarn:  {},
    18  		slog.LevelError: {},
    19  	}
    20  	logsRate := cfg.Limit
    21  	burst := cfg.Burst
    22  	if cfg.Inform {
    23  		go printDroppedLogsCounter(ctx, droppedLogsCounters)
    24  	}
    25  	return &RateLimiterHandler{
    26  		next: next,
    27  		rt: map[slog.Level]*rate.Limiter{
    28  			slog.LevelDebug: rate.NewLimiter(logsRate, burst),
    29  			slog.LevelInfo:  rate.NewLimiter(logsRate, burst),
    30  			slog.LevelWarn:  rate.NewLimiter(logsRate, burst),
    31  			slog.LevelError: rate.NewLimiter(logsRate, burst),
    32  		},
    33  		droppedLogsCounters: droppedLogsCounters,
    34  	}
    35  }
    36  
// RateLimiterHandler is a slog.Handler that throttles records per level.
// Its Enabled method takes one token from the limiter for the record's level.
// It reports false, and counts the drop, when no token is available. Handle
// forwards records to next unchanged.
type RateLimiterHandler struct {
	// next is the wrapped handler that receives records passing the limiter.
	next                slog.Handler
	// rt holds one limiter per level. It is keyed by slog.LevelDebug, LevelInfo,
	// LevelWarn and LevelError as built by NewRateLimiterHandler. Handlers
	// derived via WithAttrs/WithGroup share this same map.
	rt                  map[slog.Level]*rate.Limiter
	// droppedLogsCounters counts records rejected by rt, per level. It uses the
	// same keys as rt and is shared with derived handlers in the same way.
	droppedLogsCounters map[slog.Level]*atomic.Uint64
}
    42  
    43  func (s *RateLimiterHandler) Enabled(ctx context.Context, level slog.Level) bool {
    44  	if !s.next.Enabled(ctx, level) {
    45  		return false
    46  	}
    47  	if !s.rt[level].Allow() {
    48  		s.droppedLogsCounters[level].Add(1)
    49  		return false
    50  	}
    51  	return true
    52  }
    53  
    54  func (s *RateLimiterHandler) Handle(ctx context.Context, record slog.Record) error {
    55  	return s.next.Handle(ctx, record)
    56  }
    57  
    58  func (s *RateLimiterHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
    59  	return &RateLimiterHandler{
    60  		next:                s.next.WithAttrs(attrs),
    61  		rt:                  s.rt,
    62  		droppedLogsCounters: s.droppedLogsCounters,
    63  	}
    64  }
    65  
    66  func (s *RateLimiterHandler) WithGroup(name string) slog.Handler {
    67  	return &RateLimiterHandler{
    68  		next:                s.next.WithGroup(name),
    69  		rt:                  s.rt,
    70  		droppedLogsCounters: s.droppedLogsCounters,
    71  	}
    72  }
    73  
    74  func printDroppedLogsCounter(ctx context.Context, droppedLogsCounters map[slog.Level]*atomic.Uint64) {
    75  	ticker := time.NewTicker(5 * time.Second)
    76  	defer ticker.Stop()
    77  	for {
    78  		select {
    79  		case <-ctx.Done():
    80  			return
    81  		case <-ticker.C:
    82  			for level, val := range droppedLogsCounters {
    83  				count := val.Load()
    84  				if count > 0 {
    85  					slog.Warn(fmt.Sprintf("logs rate limit, dropped %d lines for level %s", count, level.String()))
    86  					val.Store(0)
    87  				}
    88  			}
    89  		}
    90  	}
    91  }