code.vegaprotocol.io/vega@v0.79.0/libs/http/rate_limit.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package http
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"net"
    22  	"sync"
    23  	"time"
    24  
    25  	"code.vegaprotocol.io/vega/libs/config/encoding"
    26  )
    27  
// RateLimitConfig holds the settings used to build a RateLimit.
type RateLimitConfig struct {
	// CoolDown is how long an identifier must wait after an accepted request
	// before NewRequest accepts another one from it.
	CoolDown encoding.Duration `description:"rate-limit duration, e.g. 10s, 1m30s, 24h0m0s" long:"coolDown"`

	// AllowList lists subnets in CIDR notation (each entry is parsed with
	// net.ParseCIDR by NewRateLimit). Requests whose IP falls inside any of
	// them are never rate-limited.
	AllowList []string `description:"a list of ip/subnets, e.g. 10.0.0.0/8, 192.168.0.0/16" long:"allowList"`

	// allowList is the parsed form of AllowList, populated by NewRateLimit.
	allowList []net.IPNet
}
    35  
    36  type RateLimit struct {
    37  	cfg RateLimitConfig
    38  	// map of any_identifier -> time until request can be allowed
    39  	requests map[string]time.Time
    40  
    41  	mu sync.Mutex
    42  }
    43  
    44  func NewRateLimit(ctx context.Context, cfg RateLimitConfig) (*RateLimit, error) {
    45  	cfg.allowList = make([]net.IPNet, len(cfg.AllowList))
    46  	for i, allowItem := range cfg.AllowList {
    47  		_, ipnet, err := net.ParseCIDR(allowItem)
    48  		if err != nil {
    49  			return nil, fmt.Errorf("failed to parse AllowList entry: %s", allowItem)
    50  		}
    51  		cfg.allowList[i] = *ipnet
    52  	}
    53  	r := &RateLimit{
    54  		cfg:      cfg,
    55  		requests: map[string]time.Time{},
    56  	}
    57  	go r.startCleanup(ctx)
    58  	return r, nil
    59  }
    60  
    61  // NewRequest returns nil if the rate has not been exceeded.
    62  func (r *RateLimit) NewRequest(identifier, ip string) error {
    63  	r.mu.Lock()
    64  	defer r.mu.Unlock()
    65  
    66  	if r.isAllowListed(ip) {
    67  		return nil
    68  	}
    69  
    70  	until, found := r.requests[identifier]
    71  	if !found {
    72  		until = time.Time{}
    73  		r.requests[identifier] = until
    74  	}
    75  	// just check in case the time is expired already, and
    76  	// we had not the cleanup run
    77  	if time.Now().Before(until) {
    78  		// we are already greylisted,
    79  		// another request came in while still greylisted
    80  		return fmt.Errorf("rate-limited (%s) until %v", identifier, r.requests[identifier])
    81  	}
    82  
    83  	// greylist for the minimal duration
    84  	r.requests[identifier] = time.Now().Add(r.cfg.CoolDown.Duration)
    85  
    86  	return nil
    87  }
    88  
    89  func (r *RateLimit) isAllowListed(ip string) bool {
    90  	netIP := net.ParseIP(ip)
    91  	for _, allowItem := range r.cfg.allowList {
    92  		if allowItem.Contains(netIP) {
    93  			return true
    94  		}
    95  	}
    96  	return false
    97  }
    98  
    99  func (r *RateLimit) startCleanup(ctx context.Context) {
   100  	ticker := time.NewTicker(1 * time.Minute)
   101  	defer ticker.Stop()
   102  
   103  	for {
   104  		select {
   105  		case <-ctx.Done():
   106  			return
   107  		case <-ticker.C:
   108  			now := time.Now()
   109  			r.mu.Lock()
   110  			for identifier, until := range r.requests {
   111  				// if time is elapsed, remove from the map
   112  				if until.Before(now) {
   113  					delete(r.requests, identifier)
   114  				}
   115  			}
   116  			r.mu.Unlock()
   117  		}
   118  	}
   119  }