github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/ratelimiter/ratelimiter.go (about) 1 // SPDX-License-Identifier: MPL-2.0 2 /* 3 * Copyright (C) 2024 The Noisy Sockets Authors. 4 * 5 * This Source Code Form is subject to the terms of the Mozilla Public 6 * License, v. 2.0. If a copy of the MPL was not distributed with this 7 * file, You can obtain one at http://mozilla.org/MPL/2.0/. 8 * 9 * Portions of this file are based on code originally from wireguard-go, 10 * 11 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 12 * 13 * Permission is hereby granted, free of charge, to any person obtaining a copy of 14 * this software and associated documentation files (the "Software"), to deal in 15 * the Software without restriction, including without limitation the rights to 16 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 17 * of the Software, and to permit persons to whom the Software is furnished to do 18 * so, subject to the following conditions: 19 * 20 * The above copyright notice and this permission notice shall be included in all 21 * copies or substantial portions of the Software. 22 * 23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 * SOFTWARE. 30 */ 31 32 package ratelimiter 33 34 import ( 35 "net/netip" 36 "sync" 37 "time" 38 ) 39 40 const ( 41 packetsPerSecond = 20 42 packetsBurstable = 5 43 garbageCollectTime = time.Second 44 packetCost = 1000000000 / packetsPerSecond 45 maxTokens = packetCost * packetsBurstable 46 ) 47 48 type RatelimiterEntry struct { 49 mu sync.Mutex 50 lastTime time.Time 51 tokens int64 52 } 53 54 type Ratelimiter struct { 55 mu sync.RWMutex 56 timeNow func() time.Time 57 58 stopReset chan struct{} // send to reset, close to stop 59 table map[netip.Addr]*RatelimiterEntry 60 } 61 62 func (rate *Ratelimiter) Close() error { 63 rate.mu.Lock() 64 defer rate.mu.Unlock() 65 66 if rate.stopReset != nil { 67 close(rate.stopReset) 68 } 69 70 return nil 71 } 72 73 func (rate *Ratelimiter) Init() { 74 rate.mu.Lock() 75 defer rate.mu.Unlock() 76 77 if rate.timeNow == nil { 78 rate.timeNow = time.Now 79 } 80 81 // stop any ongoing garbage collection routine 82 if rate.stopReset != nil { 83 close(rate.stopReset) 84 } 85 86 rate.stopReset = make(chan struct{}) 87 rate.table = make(map[netip.Addr]*RatelimiterEntry) 88 89 stopReset := rate.stopReset // store in case Init is called again. 90 91 // Start garbage collection routine. 92 go func() { 93 ticker := time.NewTicker(time.Second) 94 ticker.Stop() 95 for { 96 select { 97 case _, ok := <-stopReset: 98 ticker.Stop() 99 if !ok { 100 return 101 } 102 ticker = time.NewTicker(time.Second) 103 case <-ticker.C: 104 if rate.cleanup() { 105 ticker.Stop() 106 } 107 } 108 } 109 }() 110 } 111 112 func (rate *Ratelimiter) cleanup() (empty bool) { 113 rate.mu.Lock() 114 defer rate.mu.Unlock() 115 116 for key, entry := range rate.table { 117 entry.mu.Lock() 118 if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { 119 delete(rate.table, key) 120 } 121 entry.mu.Unlock() 122 } 123 124 return len(rate.table) == 0 125 } 126 127 func (rate *Ratelimiter) Allow(ip netip.Addr) bool { 128 var entry *RatelimiterEntry 129 // lookup entry 130 rate.mu.RLock() 131 entry = rate.table[ip] 132 rate.mu.RUnlock() 133 134 // make new entry if not found 135 if entry == nil { 136 entry = new(RatelimiterEntry) 137 entry.tokens = maxTokens - packetCost 138 entry.lastTime = rate.timeNow() 139 rate.mu.Lock() 140 rate.table[ip] = entry 141 if len(rate.table) == 1 { 142 rate.stopReset <- struct{}{} 143 } 144 rate.mu.Unlock() 145 return true 146 } 147 148 // add tokens to entry 149 entry.mu.Lock() 150 now := rate.timeNow() 151 entry.tokens += now.Sub(entry.lastTime).Nanoseconds() 152 entry.lastTime = now 153 if entry.tokens > maxTokens { 154 entry.tokens = maxTokens 155 } 156 157 // subtract cost of packet 158 if entry.tokens > packetCost { 159 entry.tokens -= packetCost 160 entry.mu.Unlock() 161 return true 162 } 163 entry.mu.Unlock() 164 return false 165 }