github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/ratelimiter/ratelimiter.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package ratelimiter
     7  
     8  import (
     9  	"sync"
    10  	"time"
    11  
    12  	"golang.zx2c4.com/go118/netip"
    13  )
    14  
    15  const (
    16  	packetsPerSecond   = 20
    17  	packetsBurstable   = 5
    18  	garbageCollectTime = time.Second
    19  	packetCost         = 1000000000 / packetsPerSecond
    20  	maxTokens          = packetCost * packetsBurstable
    21  )
    22  
    23  type RatelimiterEntry struct {
    24  	mu       sync.Mutex
    25  	lastTime time.Time
    26  	tokens   int64
    27  }
    28  
    29  type Ratelimiter struct {
    30  	mu      sync.RWMutex
    31  	timeNow func() time.Time
    32  
    33  	stopReset chan struct{} // send to reset, close to stop
    34  	table     map[netip.Addr]*RatelimiterEntry
    35  }
    36  
    37  func (rate *Ratelimiter) Close() {
    38  	rate.mu.Lock()
    39  	defer rate.mu.Unlock()
    40  
    41  	if rate.stopReset != nil {
    42  		close(rate.stopReset)
    43  	}
    44  }
    45  
    46  func (rate *Ratelimiter) Init() {
    47  	rate.mu.Lock()
    48  	defer rate.mu.Unlock()
    49  
    50  	if rate.timeNow == nil {
    51  		rate.timeNow = time.Now
    52  	}
    53  
    54  	// stop any ongoing garbage collection routine
    55  	if rate.stopReset != nil {
    56  		close(rate.stopReset)
    57  	}
    58  
    59  	rate.stopReset = make(chan struct{})
    60  	rate.table = make(map[netip.Addr]*RatelimiterEntry)
    61  
    62  	stopReset := rate.stopReset // store in case Init is called again.
    63  
    64  	// Start garbage collection routine.
    65  	go func() {
    66  		ticker := time.NewTicker(time.Second)
    67  		ticker.Stop()
    68  		for {
    69  			select {
    70  			case _, ok := <-stopReset:
    71  				ticker.Stop()
    72  				if !ok {
    73  					return
    74  				}
    75  				ticker = time.NewTicker(time.Second)
    76  			case <-ticker.C:
    77  				if rate.cleanup() {
    78  					ticker.Stop()
    79  				}
    80  			}
    81  		}
    82  	}()
    83  }
    84  
    85  func (rate *Ratelimiter) cleanup() (empty bool) {
    86  	rate.mu.Lock()
    87  	defer rate.mu.Unlock()
    88  
    89  	for key, entry := range rate.table {
    90  		entry.mu.Lock()
    91  		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
    92  			delete(rate.table, key)
    93  		}
    94  		entry.mu.Unlock()
    95  	}
    96  
    97  	return len(rate.table) == 0
    98  }
    99  
   100  func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
   101  	var entry *RatelimiterEntry
   102  	// lookup entry
   103  	rate.mu.RLock()
   104  	entry = rate.table[ip]
   105  	rate.mu.RUnlock()
   106  
   107  	// make new entry if not found
   108  	if entry == nil {
   109  		entry = new(RatelimiterEntry)
   110  		entry.tokens = maxTokens - packetCost
   111  		entry.lastTime = rate.timeNow()
   112  		rate.mu.Lock()
   113  		rate.table[ip] = entry
   114  		if len(rate.table) == 1 {
   115  			rate.stopReset <- struct{}{}
   116  		}
   117  		rate.mu.Unlock()
   118  		return true
   119  	}
   120  
   121  	// add tokens to entry
   122  	entry.mu.Lock()
   123  	now := rate.timeNow()
   124  	entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
   125  	entry.lastTime = now
   126  	if entry.tokens > maxTokens {
   127  		entry.tokens = maxTokens
   128  	}
   129  
   130  	// subtract cost of packet
   131  	if entry.tokens > packetCost {
   132  		entry.tokens -= packetCost
   133  		entry.mu.Unlock()
   134  		return true
   135  	}
   136  	entry.mu.Unlock()
   137  	return false
   138  }