code.vegaprotocol.io/vega@v0.79.0/datanode/ratelimit/naughtystep.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package ratelimit
    17  
    18  import (
    19  	"sync"
    20  	"time"
    21  
    22  	"code.vegaprotocol.io/vega/logging"
    23  
    24  	"github.com/didip/tollbooth/v7"
    25  	"github.com/didip/tollbooth/v7/limiter"
    26  	"go.uber.org/zap"
    27  )
    28  
    29  // naughtyStep is a struct for keeping track of bad behavior and bans.
    30  //
    31  // You get put on the naughty step if you make requests despite having run out of tokens.
    32  // The naughty step has it's own rate limiter, and its tokens are spent every time a failed
    33  // (due to rate limiting) API call is made. If you run out of naughty tokens, then you get
    34  // banned for a period of time.
    35  
    36  type naughtyStep struct {
    37  	log    *logging.Logger
    38  	lmt    *limiter.Limiter
    39  	bans   map[string]time.Time
    40  	mu     sync.RWMutex
    41  	banFor time.Duration
    42  }
    43  
    44  func newNaughtyStep(log *logging.Logger, rate float64, burst int, banFor, pruneEvery time.Duration) *naughtyStep {
    45  	limitOpts := limiter.ExpirableOptions{DefaultExpirationTTL: pruneEvery}
    46  	lmt := tollbooth.NewLimiter(rate, &limitOpts)
    47  	lmt.SetBurst(burst)
    48  
    49  	ns := naughtyStep{
    50  		log:    log,
    51  		lmt:    lmt,
    52  		bans:   make(map[string]time.Time),
    53  		banFor: banFor,
    54  	}
    55  
    56  	go func() {
    57  		for range time.Tick(pruneEvery) {
    58  			ns.prune()
    59  		}
    60  	}()
    61  
    62  	return &ns
    63  }
    64  
    65  func (n *naughtyStep) enabled() bool {
    66  	return n.banFor > 0
    67  }
    68  
    69  func (n *naughtyStep) smackBottom(ip string) {
    70  	if !n.enabled() {
    71  		return
    72  	}
    73  
    74  	if n.lmt.LimitReached(ip) {
    75  		n.ban(ip)
    76  		n.log.Info("banned for requesting past rate limit", zap.String("ip", ip))
    77  	}
    78  }
    79  
    80  func (n *naughtyStep) ban(ip string) {
    81  	n.mu.Lock()
    82  	defer n.mu.Unlock()
    83  
    84  	n.bans[ip] = time.Now().Add(n.banFor)
    85  }
    86  
    87  func (n *naughtyStep) isBanned(ip string) bool {
    88  	n.mu.RLock()
    89  	defer n.mu.RUnlock()
    90  
    91  	if bannedUntil, ok := n.bans[ip]; ok {
    92  		if time.Now().Before(bannedUntil) {
    93  			return true
    94  		}
    95  	}
    96  	return false
    97  }
    98  
    99  func (n *naughtyStep) prune() {
   100  	n.mu.Lock()
   101  	defer n.mu.Unlock()
   102  
   103  	for ip, until := range n.bans {
   104  		if time.Now().After(until) {
   105  			delete(n.bans, ip)
   106  		}
   107  	}
   108  }