github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/ratelimiter/ratelimiter.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package ratelimiter
     7  
     8  import (
     9  	"net/netip"
    10  	"sync"
    11  	"time"
    12  )
    13  
    14  const (
    15  	packetsPerSecond   = 20
    16  	packetsBurstable   = 5
    17  	garbageCollectTime = time.Second
    18  	packetCost         = 1000000000 / packetsPerSecond
    19  	maxTokens          = packetCost * packetsBurstable
    20  )
    21  
    22  type RatelimiterEntry struct {
    23  	mu       sync.Mutex
    24  	lastTime time.Time
    25  	tokens   int64
    26  }
    27  
    28  type Ratelimiter struct {
    29  	mu      sync.RWMutex
    30  	timeNow func() time.Time
    31  
    32  	stopReset chan struct{} // send to reset, close to stop
    33  	table     map[netip.Addr]*RatelimiterEntry
    34  }
    35  
    36  func (rate *Ratelimiter) Close() {
    37  	rate.mu.Lock()
    38  	defer rate.mu.Unlock()
    39  
    40  	if rate.stopReset != nil {
    41  		close(rate.stopReset)
    42  	}
    43  }
    44  
    45  func (rate *Ratelimiter) Init() {
    46  	rate.mu.Lock()
    47  	defer rate.mu.Unlock()
    48  
    49  	if rate.timeNow == nil {
    50  		rate.timeNow = time.Now
    51  	}
    52  
    53  	// stop any ongoing garbage collection routine
    54  	if rate.stopReset != nil {
    55  		close(rate.stopReset)
    56  	}
    57  
    58  	rate.stopReset = make(chan struct{})
    59  	rate.table = make(map[netip.Addr]*RatelimiterEntry)
    60  
    61  	stopReset := rate.stopReset // store in case Init is called again.
    62  
    63  	// Start garbage collection routine.
    64  	go func() {
    65  		ticker := time.NewTicker(time.Second)
    66  		ticker.Stop()
    67  		for {
    68  			select {
    69  			case _, ok := <-stopReset:
    70  				ticker.Stop()
    71  				if !ok {
    72  					return
    73  				}
    74  				ticker = time.NewTicker(time.Second)
    75  			case <-ticker.C:
    76  				if rate.cleanup() {
    77  					ticker.Stop()
    78  				}
    79  			}
    80  		}
    81  	}()
    82  }
    83  
    84  func (rate *Ratelimiter) cleanup() (empty bool) {
    85  	rate.mu.Lock()
    86  	defer rate.mu.Unlock()
    87  
    88  	for key, entry := range rate.table {
    89  		entry.mu.Lock()
    90  		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
    91  			delete(rate.table, key)
    92  		}
    93  		entry.mu.Unlock()
    94  	}
    95  
    96  	return len(rate.table) == 0
    97  }
    98  
    99  func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
   100  	var entry *RatelimiterEntry
   101  	// lookup entry
   102  	rate.mu.RLock()
   103  	entry = rate.table[ip]
   104  	rate.mu.RUnlock()
   105  
   106  	// make new entry if not found
   107  	if entry == nil {
   108  		entry = new(RatelimiterEntry)
   109  		entry.tokens = maxTokens - packetCost
   110  		entry.lastTime = rate.timeNow()
   111  		rate.mu.Lock()
   112  		rate.table[ip] = entry
   113  		if len(rate.table) == 1 {
   114  			rate.stopReset <- struct{}{}
   115  		}
   116  		rate.mu.Unlock()
   117  		return true
   118  	}
   119  
   120  	// add tokens to entry
   121  	entry.mu.Lock()
   122  	now := rate.timeNow()
   123  	entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
   124  	entry.lastTime = now
   125  	if entry.tokens > maxTokens {
   126  		entry.tokens = maxTokens
   127  	}
   128  
   129  	// subtract cost of packet
   130  	if entry.tokens > packetCost {
   131  		entry.tokens -= packetCost
   132  		entry.mu.Unlock()
   133  		return true
   134  	}
   135  	entry.mu.Unlock()
   136  	return false
   137  }