github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/ratelimiter/ratelimiter.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package ratelimiter 7 8 import ( 9 "sync" 10 "time" 11 12 "golang.zx2c4.com/go118/netip" 13 ) 14 15 const ( 16 packetsPerSecond = 20 17 packetsBurstable = 5 18 garbageCollectTime = time.Second 19 packetCost = 1000000000 / packetsPerSecond 20 maxTokens = packetCost * packetsBurstable 21 ) 22 23 type RatelimiterEntry struct { 24 mu sync.Mutex 25 lastTime time.Time 26 tokens int64 27 } 28 29 type Ratelimiter struct { 30 mu sync.RWMutex 31 timeNow func() time.Time 32 33 stopReset chan struct{} // send to reset, close to stop 34 table map[netip.Addr]*RatelimiterEntry 35 } 36 37 func (rate *Ratelimiter) Close() { 38 rate.mu.Lock() 39 defer rate.mu.Unlock() 40 41 if rate.stopReset != nil { 42 close(rate.stopReset) 43 } 44 } 45 46 func (rate *Ratelimiter) Init() { 47 rate.mu.Lock() 48 defer rate.mu.Unlock() 49 50 if rate.timeNow == nil { 51 rate.timeNow = time.Now 52 } 53 54 // stop any ongoing garbage collection routine 55 if rate.stopReset != nil { 56 close(rate.stopReset) 57 } 58 59 rate.stopReset = make(chan struct{}) 60 rate.table = make(map[netip.Addr]*RatelimiterEntry) 61 62 stopReset := rate.stopReset // store in case Init is called again. 63 64 // Start garbage collection routine. 65 go func() { 66 ticker := time.NewTicker(time.Second) 67 ticker.Stop() 68 for { 69 select { 70 case _, ok := <-stopReset: 71 ticker.Stop() 72 if !ok { 73 return 74 } 75 ticker = time.NewTicker(time.Second) 76 case <-ticker.C: 77 if rate.cleanup() { 78 ticker.Stop() 79 } 80 } 81 } 82 }() 83 } 84 85 func (rate *Ratelimiter) cleanup() (empty bool) { 86 rate.mu.Lock() 87 defer rate.mu.Unlock() 88 89 for key, entry := range rate.table { 90 entry.mu.Lock() 91 if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { 92 delete(rate.table, key) 93 } 94 entry.mu.Unlock() 95 } 96 97 return len(rate.table) == 0 98 } 99 100 func (rate *Ratelimiter) Allow(ip netip.Addr) bool { 101 var entry *RatelimiterEntry 102 // lookup entry 103 rate.mu.RLock() 104 entry = rate.table[ip] 105 rate.mu.RUnlock() 106 107 // make new entry if not found 108 if entry == nil { 109 entry = new(RatelimiterEntry) 110 entry.tokens = maxTokens - packetCost 111 entry.lastTime = rate.timeNow() 112 rate.mu.Lock() 113 rate.table[ip] = entry 114 if len(rate.table) == 1 { 115 rate.stopReset <- struct{}{} 116 } 117 rate.mu.Unlock() 118 return true 119 } 120 121 // add tokens to entry 122 entry.mu.Lock() 123 now := rate.timeNow() 124 entry.tokens += now.Sub(entry.lastTime).Nanoseconds() 125 entry.lastTime = now 126 if entry.tokens > maxTokens { 127 entry.tokens = maxTokens 128 } 129 130 // subtract cost of packet 131 if entry.tokens > packetCost { 132 entry.tokens -= packetCost 133 entry.mu.Unlock() 134 return true 135 } 136 entry.mu.Unlock() 137 return false 138 }