github.com/koomox/wireguard-go@v0.0.0-20230722134753-17a50b2f22a3/ratelimiter/ratelimiter.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package ratelimiter 7 8 import ( 9 "net/netip" 10 "sync" 11 "time" 12 ) 13 14 const ( 15 packetsPerSecond = 20 16 packetsBurstable = 5 17 garbageCollectTime = time.Second 18 packetCost = 1000000000 / packetsPerSecond 19 maxTokens = packetCost * packetsBurstable 20 ) 21 22 type RatelimiterEntry struct { 23 mu sync.Mutex 24 lastTime time.Time 25 tokens int64 26 } 27 28 type Ratelimiter struct { 29 mu sync.RWMutex 30 timeNow func() time.Time 31 32 stopReset chan struct{} // send to reset, close to stop 33 table map[netip.Addr]*RatelimiterEntry 34 } 35 36 func (rate *Ratelimiter) Close() { 37 rate.mu.Lock() 38 defer rate.mu.Unlock() 39 40 if rate.stopReset != nil { 41 close(rate.stopReset) 42 } 43 } 44 45 func (rate *Ratelimiter) Init() { 46 rate.mu.Lock() 47 defer rate.mu.Unlock() 48 49 if rate.timeNow == nil { 50 rate.timeNow = time.Now 51 } 52 53 // stop any ongoing garbage collection routine 54 if rate.stopReset != nil { 55 close(rate.stopReset) 56 } 57 58 rate.stopReset = make(chan struct{}) 59 rate.table = make(map[netip.Addr]*RatelimiterEntry) 60 61 stopReset := rate.stopReset // store in case Init is called again. 62 63 // Start garbage collection routine. 64 go func() { 65 ticker := time.NewTicker(time.Second) 66 ticker.Stop() 67 for { 68 select { 69 case _, ok := <-stopReset: 70 ticker.Stop() 71 if !ok { 72 return 73 } 74 ticker = time.NewTicker(time.Second) 75 case <-ticker.C: 76 if rate.cleanup() { 77 ticker.Stop() 78 } 79 } 80 } 81 }() 82 } 83 84 func (rate *Ratelimiter) cleanup() (empty bool) { 85 rate.mu.Lock() 86 defer rate.mu.Unlock() 87 88 for key, entry := range rate.table { 89 entry.mu.Lock() 90 if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { 91 delete(rate.table, key) 92 } 93 entry.mu.Unlock() 94 } 95 96 return len(rate.table) == 0 97 } 98 99 func (rate *Ratelimiter) Allow(ip netip.Addr) bool { 100 var entry *RatelimiterEntry 101 // lookup entry 102 rate.mu.RLock() 103 entry = rate.table[ip] 104 rate.mu.RUnlock() 105 106 // make new entry if not found 107 if entry == nil { 108 entry = new(RatelimiterEntry) 109 entry.tokens = maxTokens - packetCost 110 entry.lastTime = rate.timeNow() 111 rate.mu.Lock() 112 rate.table[ip] = entry 113 if len(rate.table) == 1 { 114 rate.stopReset <- struct{}{} 115 } 116 rate.mu.Unlock() 117 return true 118 } 119 120 // add tokens to entry 121 entry.mu.Lock() 122 now := rate.timeNow() 123 entry.tokens += now.Sub(entry.lastTime).Nanoseconds() 124 entry.lastTime = now 125 if entry.tokens > maxTokens { 126 entry.tokens = maxTokens 127 } 128 129 // subtract cost of packet 130 if entry.tokens > packetCost { 131 entry.tokens -= packetCost 132 entry.mu.Unlock() 133 return true 134 } 135 entry.mu.Unlock() 136 return false 137 }