github.com/tailscale/wireguard-go@v0.0.20201119-0.20210522003738-46b531feb08a/ratelimiter/ratelimiter.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package ratelimiter 7 8 import ( 9 "net" 10 "sync" 11 "time" 12 ) 13 14 const ( 15 packetsPerSecond = 20 16 packetsBurstable = 5 17 garbageCollectTime = time.Second 18 packetCost = 1000000000 / packetsPerSecond 19 maxTokens = packetCost * packetsBurstable 20 ) 21 22 type RatelimiterEntry struct { 23 mu sync.Mutex 24 lastTime time.Time 25 tokens int64 26 } 27 28 type Ratelimiter struct { 29 mu sync.RWMutex 30 timeNow func() time.Time 31 32 stopReset chan struct{} // send to reset, close to stop 33 tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry 34 tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry 35 } 36 37 func (rate *Ratelimiter) Close() { 38 rate.mu.Lock() 39 defer rate.mu.Unlock() 40 41 if rate.stopReset != nil { 42 close(rate.stopReset) 43 } 44 } 45 46 func (rate *Ratelimiter) Init() { 47 rate.mu.Lock() 48 defer rate.mu.Unlock() 49 50 if rate.timeNow == nil { 51 rate.timeNow = time.Now 52 } 53 54 // stop any ongoing garbage collection routine 55 if rate.stopReset != nil { 56 close(rate.stopReset) 57 } 58 59 rate.stopReset = make(chan struct{}) 60 rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) 61 rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) 62 63 stopReset := rate.stopReset // store in case Init is called again. 64 65 // Start garbage collection routine. 66 go func() { 67 ticker := time.NewTicker(time.Second) 68 ticker.Stop() 69 for { 70 select { 71 case _, ok := <-stopReset: 72 ticker.Stop() 73 if !ok { 74 return 75 } 76 ticker = time.NewTicker(time.Second) 77 case <-ticker.C: 78 if rate.cleanup() { 79 ticker.Stop() 80 } 81 } 82 } 83 }() 84 } 85 86 func (rate *Ratelimiter) cleanup() (empty bool) { 87 rate.mu.Lock() 88 defer rate.mu.Unlock() 89 90 for key, entry := range rate.tableIPv4 { 91 entry.mu.Lock() 92 if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { 93 delete(rate.tableIPv4, key) 94 } 95 entry.mu.Unlock() 96 } 97 98 for key, entry := range rate.tableIPv6 { 99 entry.mu.Lock() 100 if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { 101 delete(rate.tableIPv6, key) 102 } 103 entry.mu.Unlock() 104 } 105 106 return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 107 } 108 109 func (rate *Ratelimiter) Allow(ip net.IP) bool { 110 var entry *RatelimiterEntry 111 var keyIPv4 [net.IPv4len]byte 112 var keyIPv6 [net.IPv6len]byte 113 114 // lookup entry 115 116 IPv4 := ip.To4() 117 IPv6 := ip.To16() 118 119 rate.mu.RLock() 120 121 if IPv4 != nil { 122 copy(keyIPv4[:], IPv4) 123 entry = rate.tableIPv4[keyIPv4] 124 } else { 125 copy(keyIPv6[:], IPv6) 126 entry = rate.tableIPv6[keyIPv6] 127 } 128 129 rate.mu.RUnlock() 130 131 // make new entry if not found 132 133 if entry == nil { 134 entry = new(RatelimiterEntry) 135 entry.tokens = maxTokens - packetCost 136 entry.lastTime = rate.timeNow() 137 rate.mu.Lock() 138 if IPv4 != nil { 139 rate.tableIPv4[keyIPv4] = entry 140 if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { 141 rate.stopReset <- struct{}{} 142 } 143 } else { 144 rate.tableIPv6[keyIPv6] = entry 145 if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 { 146 rate.stopReset <- struct{}{} 147 } 148 } 149 rate.mu.Unlock() 150 return true 151 } 152 153 // add tokens to entry 154 155 entry.mu.Lock() 156 now := rate.timeNow() 157 entry.tokens += now.Sub(entry.lastTime).Nanoseconds() 158 entry.lastTime = now 159 if entry.tokens > maxTokens { 160 entry.tokens = maxTokens 161 } 162 163 // subtract cost of packet 164 165 if entry.tokens > packetCost { 166 entry.tokens -= packetCost 167 entry.mu.Unlock() 168 return true 169 } 170 entry.mu.Unlock() 171 return false 172 }