github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/ratelimiter/ratelimiter.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package ratelimiter
     7  
     8  import (
     9  	"net"
    10  	"sync"
    11  	"time"
    12  )
    13  
    // Token-bucket parameters for the per-address limiter. Tokens are
    // nanoseconds: Allow credits an entry one token for every nanosecond
    // that passes between its packets.
    const (
    	// packetsPerSecond is the sustained rate admitted per address.
    	packetsPerSecond = 20
    	// packetsBurstable is the bucket capacity, expressed in packets.
    	packetsBurstable = 5

    	// packetCost is what one packet costs: one second's worth of tokens
    	// divided by the sustained rate.
    	packetCost = 1000000000 / packetsPerSecond
    	// maxTokens caps how many tokens an entry can accumulate.
    	maxTokens = packetsBurstable * packetCost
    )

    // garbageCollectTime is how long an entry may go without a packet before
    // cleanup is allowed to discard it.
    const garbageCollectTime = time.Second
    21  
    // RatelimiterEntry is the token bucket kept for a single source address.
    // Once an entry has been published in a table, its fields are accessed
    // with mu held.
    type RatelimiterEntry struct {
    	mu sync.Mutex // guards lastTime and tokens

    	lastTime time.Time // when tokens was last brought up to date
    	tokens   int64     // available credit, in nanoseconds
    }
    27  
    28  type Ratelimiter struct {
    29  	mu      sync.RWMutex
    30  	timeNow func() time.Time
    31  
    32  	stopReset chan struct{} // send to reset, close to stop
    33  	tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
    34  	tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
    35  }
    36  
    37  func (rate *Ratelimiter) Close() {
    38  	rate.mu.Lock()
    39  	defer rate.mu.Unlock()
    40  
    41  	if rate.stopReset != nil {
    42  		close(rate.stopReset)
    43  	}
    44  }
    45  
    46  func (rate *Ratelimiter) Init() {
    47  	rate.mu.Lock()
    48  	defer rate.mu.Unlock()
    49  
    50  	if rate.timeNow == nil {
    51  		rate.timeNow = time.Now
    52  	}
    53  
    54  	// stop any ongoing garbage collection routine
    55  	if rate.stopReset != nil {
    56  		close(rate.stopReset)
    57  	}
    58  
    59  	rate.stopReset = make(chan struct{})
    60  	rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
    61  	rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
    62  
    63  	stopReset := rate.stopReset // store in case Init is called again.
    64  
    65  	// Start garbage collection routine.
    66  	go func() {
    67  		ticker := time.NewTicker(time.Second)
    68  		ticker.Stop()
    69  		for {
    70  			select {
    71  			case _, ok := <-stopReset:
    72  				ticker.Stop()
    73  				if !ok {
    74  					return
    75  				}
    76  				ticker = time.NewTicker(time.Second)
    77  			case <-ticker.C:
    78  				if rate.cleanup() {
    79  					ticker.Stop()
    80  				}
    81  			}
    82  		}
    83  	}()
    84  }
    85  
    86  func (rate *Ratelimiter) cleanup() (empty bool) {
    87  	rate.mu.Lock()
    88  	defer rate.mu.Unlock()
    89  
    90  	for key, entry := range rate.tableIPv4 {
    91  		entry.mu.Lock()
    92  		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
    93  			delete(rate.tableIPv4, key)
    94  		}
    95  		entry.mu.Unlock()
    96  	}
    97  
    98  	for key, entry := range rate.tableIPv6 {
    99  		entry.mu.Lock()
   100  		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
   101  			delete(rate.tableIPv6, key)
   102  		}
   103  		entry.mu.Unlock()
   104  	}
   105  
   106  	return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
   107  }
   108  
   109  func (rate *Ratelimiter) Allow(ip net.IP) bool {
   110  	var entry *RatelimiterEntry
   111  	var keyIPv4 [net.IPv4len]byte
   112  	var keyIPv6 [net.IPv6len]byte
   113  
   114  	// lookup entry
   115  
   116  	IPv4 := ip.To4()
   117  	IPv6 := ip.To16()
   118  
   119  	rate.mu.RLock()
   120  
   121  	if IPv4 != nil {
   122  		copy(keyIPv4[:], IPv4)
   123  		entry = rate.tableIPv4[keyIPv4]
   124  	} else {
   125  		copy(keyIPv6[:], IPv6)
   126  		entry = rate.tableIPv6[keyIPv6]
   127  	}
   128  
   129  	rate.mu.RUnlock()
   130  
   131  	// make new entry if not found
   132  
   133  	if entry == nil {
   134  		entry = new(RatelimiterEntry)
   135  		entry.tokens = maxTokens - packetCost
   136  		entry.lastTime = rate.timeNow()
   137  		rate.mu.Lock()
   138  		if IPv4 != nil {
   139  			rate.tableIPv4[keyIPv4] = entry
   140  			if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
   141  				rate.stopReset <- struct{}{}
   142  			}
   143  		} else {
   144  			rate.tableIPv6[keyIPv6] = entry
   145  			if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
   146  				rate.stopReset <- struct{}{}
   147  			}
   148  		}
   149  		rate.mu.Unlock()
   150  		return true
   151  	}
   152  
   153  	// add tokens to entry
   154  
   155  	entry.mu.Lock()
   156  	now := rate.timeNow()
   157  	entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
   158  	entry.lastTime = now
   159  	if entry.tokens > maxTokens {
   160  		entry.tokens = maxTokens
   161  	}
   162  
   163  	// subtract cost of packet
   164  
   165  	if entry.tokens > packetCost {
   166  		entry.tokens -= packetCost
   167  		entry.mu.Unlock()
   168  		return true
   169  	}
   170  	entry.mu.Unlock()
   171  	return false
   172  }