github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/ratelimiter/ratelimiter.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   *
     9   * Portions of this file are based on code originally from wireguard-go,
    10   *
    11   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
    12   *
    13   * Permission is hereby granted, free of charge, to any person obtaining a copy of
    14   * this software and associated documentation files (the "Software"), to deal in
    15   * the Software without restriction, including without limitation the rights to
    16   * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
    17   * of the Software, and to permit persons to whom the Software is furnished to do
    18   * so, subject to the following conditions:
    19   *
    20   * The above copyright notice and this permission notice shall be included in all
    21   * copies or substantial portions of the Software.
    22   *
    23   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    24   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    25   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    26   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    27   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    28   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    29   * SOFTWARE.
    30   */
    31  
    32  package ratelimiter
    33  
    34  import (
    35  	"net/netip"
    36  	"sync"
    37  	"time"
    38  )
    39  
    40  const (
    41  	packetsPerSecond   = 20
    42  	packetsBurstable   = 5
    43  	garbageCollectTime = time.Second
    44  	packetCost         = 1000000000 / packetsPerSecond
    45  	maxTokens          = packetCost * packetsBurstable
    46  )
    47  
// RatelimiterEntry holds the token-bucket state that Ratelimiter keeps for
// one source address.
type RatelimiterEntry struct {
	mu       sync.Mutex // guards lastTime and tokens once the entry is in the table
	lastTime time.Time  // creation time, then the time of the latest refill in Allow; cleanup evicts idle entries
	tokens   int64      // budget in nanoseconds of accrued clock time, capped at maxTokens; a packet costs packetCost
}
    53  
// Ratelimiter decides, per source IP address, whether a packet may be
// processed, using one token bucket (RatelimiterEntry) per address.
// Init must be called before Allow; Close stops the garbage-collection
// goroutine that Init starts.
type Ratelimiter struct {
	mu      sync.RWMutex     // guards table and stopReset (Init also sets timeNow while holding it)
	timeNow func() time.Time // clock source; Init sets it to time.Now if nil

	stopReset chan struct{}                    // send to restart the cleanup ticker, close to stop the cleanup goroutine
	table     map[netip.Addr]*RatelimiterEntry // bucket state per source address
}
    61  
    62  func (rate *Ratelimiter) Close() error {
    63  	rate.mu.Lock()
    64  	defer rate.mu.Unlock()
    65  
    66  	if rate.stopReset != nil {
    67  		close(rate.stopReset)
    68  	}
    69  
    70  	return nil
    71  }
    72  
    73  func (rate *Ratelimiter) Init() {
    74  	rate.mu.Lock()
    75  	defer rate.mu.Unlock()
    76  
    77  	if rate.timeNow == nil {
    78  		rate.timeNow = time.Now
    79  	}
    80  
    81  	// stop any ongoing garbage collection routine
    82  	if rate.stopReset != nil {
    83  		close(rate.stopReset)
    84  	}
    85  
    86  	rate.stopReset = make(chan struct{})
    87  	rate.table = make(map[netip.Addr]*RatelimiterEntry)
    88  
    89  	stopReset := rate.stopReset // store in case Init is called again.
    90  
    91  	// Start garbage collection routine.
    92  	go func() {
    93  		ticker := time.NewTicker(time.Second)
    94  		ticker.Stop()
    95  		for {
    96  			select {
    97  			case _, ok := <-stopReset:
    98  				ticker.Stop()
    99  				if !ok {
   100  					return
   101  				}
   102  				ticker = time.NewTicker(time.Second)
   103  			case <-ticker.C:
   104  				if rate.cleanup() {
   105  					ticker.Stop()
   106  				}
   107  			}
   108  		}
   109  	}()
   110  }
   111  
   112  func (rate *Ratelimiter) cleanup() (empty bool) {
   113  	rate.mu.Lock()
   114  	defer rate.mu.Unlock()
   115  
   116  	for key, entry := range rate.table {
   117  		entry.mu.Lock()
   118  		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
   119  			delete(rate.table, key)
   120  		}
   121  		entry.mu.Unlock()
   122  	}
   123  
   124  	return len(rate.table) == 0
   125  }
   126  
   127  func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
   128  	var entry *RatelimiterEntry
   129  	// lookup entry
   130  	rate.mu.RLock()
   131  	entry = rate.table[ip]
   132  	rate.mu.RUnlock()
   133  
   134  	// make new entry if not found
   135  	if entry == nil {
   136  		entry = new(RatelimiterEntry)
   137  		entry.tokens = maxTokens - packetCost
   138  		entry.lastTime = rate.timeNow()
   139  		rate.mu.Lock()
   140  		rate.table[ip] = entry
   141  		if len(rate.table) == 1 {
   142  			rate.stopReset <- struct{}{}
   143  		}
   144  		rate.mu.Unlock()
   145  		return true
   146  	}
   147  
   148  	// add tokens to entry
   149  	entry.mu.Lock()
   150  	now := rate.timeNow()
   151  	entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
   152  	entry.lastTime = now
   153  	if entry.tokens > maxTokens {
   154  		entry.tokens = maxTokens
   155  	}
   156  
   157  	// subtract cost of packet
   158  	if entry.tokens > packetCost {
   159  		entry.tokens -= packetCost
   160  		entry.mu.Unlock()
   161  		return true
   162  	}
   163  	entry.mu.Unlock()
   164  	return false
   165  }