github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/plugin/rate/rate_limit_logger.go (about)

     1  package rate
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/http"
     7  
     8  	"github.com/felixge/httpsnoop"
     9  	"github.com/hellofresh/stats-go/bucket"
    10  	"github.com/hellofresh/stats-go/client"
    11  	log "github.com/sirupsen/logrus"
    12  	"github.com/ulule/limiter/v3"
    13  )
    14  
    15  const (
    16  	limiterSection = "limiter"
    17  	limiterMetric  = "state"
    18  )
    19  
    20  // NewRateLimitLogger logs the IP of blocked users with rate limit
    21  func NewRateLimitLogger(lmt *limiter.Limiter, statsClient client.Client, trustForwardHeaders bool) func(handler http.Handler) http.Handler {
    22  	return func(handler http.Handler) http.Handler {
    23  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    24  			log.Debug("Starting RateLimitLogger.WriterWrapper middleware")
    25  
    26  			m := httpsnoop.CaptureMetrics(handler, w, r)
    27  
    28  			limiterIP := limiter.GetIP(r, limiter.Options{TrustForwardHeader: trustForwardHeaders})
    29  			if m.Code == http.StatusTooManyRequests {
    30  				log.WithFields(log.Fields{
    31  					"ip_address":  limiterIP.String(),
    32  					"request_uri": r.RequestURI,
    33  				}).Warning("Rate Limit exceeded for this IP")
    34  			}
    35  
    36  			trackLimitState(lmt, statsClient, limiterIP, r)
    37  		})
    38  	}
    39  }
    40  
    41  func trackLimitState(lmt *limiter.Limiter, statsClient client.Client, limiterIP net.IP, r *http.Request) {
    42  	ctx, err := lmt.Peek(context.Background(), limiterIP.String())
    43  	if err != nil {
    44  		log.WithError(err).WithFields(log.Fields{
    45  			"ip_address":  limiterIP.String(),
    46  			"request_uri": r.RequestURI,
    47  		}).Error("Failed to get limiter ctx from request")
    48  		return
    49  	}
    50  
    51  	requestsPerformed := ctx.Limit - ctx.Remaining
    52  	limitState := requestsPerformed * 100 / ctx.Limit
    53  
    54  	operation := bucket.BuildHTTPRequestMetricOperation(r, statsClient.GetHTTPMetricCallback())
    55  	// replace request method with fixed section name
    56  	operation[0] = limiterMetric
    57  
    58  	statsClient.TrackState(limiterSection, operation, int(limitState))
    59  }