github.com/amazechain/amc@v0.1.3/internal/sync/rate_limiter.go

package sync

import (
	"fmt"
	"reflect"
	"sync"
	"time"

	"github.com/amazechain/amc/internal/p2p"
	leakybucket "github.com/amazechain/amc/internal/p2p/leaky-bucket"
	p2ptypes "github.com/amazechain/amc/internal/p2p/types"
	"github.com/amazechain/amc/log"
	"github.com/libp2p/go-libp2p/core/network"
	"github.com/pkg/errors"
	"github.com/trailofbits/go-mutexasserts"
)

// defaultBurstLimit is the default burst capacity used for the per-topic collectors.
const defaultBurstLimit = 5

// leakyBucketPeriod is the leak period used for the non-block rpc collectors.
const leakyBucketPeriod = 1 * time.Second

// rpcLimiterTopic is a dummy topic used to rate limit all incoming rpc requests as a whole.
const rpcLimiterTopic = "rpc-limiter-topic"

// limiter maps each rpc topic to its own leaky-bucket collector. The embedded
// RWMutex guards limiterMap and must be held when accessing the collectors.
type limiter struct {
	limiterMap map[string]*leakybucket.Collector
	p2p        p2p.P2P
	sync.RWMutex
}

// newRateLimiter instantiates a rate limiter for the rpc protocols, providing
// a separate leaky-bucket collector for each topic.
func newRateLimiter(p2pProvider p2p.P2P) *limiter {
	// addEncoding appends the encoding protocol suffix to a topic.
	addEncoding := func(topic string) string {
		return topic + p2pProvider.Encoding().ProtocolSuffix()
	}

	// Initialize block limits.
	allowedBlocksPerSecond := float64(p2pProvider.GetConfig().P2PLimit.BlockBatchLimit)
	allowedBlocksBurst := int64(p2pProvider.GetConfig().P2PLimit.BlockBatchLimitBurstFactor * p2pProvider.GetConfig().P2PLimit.BlockBatchLimit)

	blockLimiterPeriod := time.Duration(p2pProvider.GetConfig().P2PLimit.BlockBatchLimiterPeriod) * time.Second

	// Set topic map for all rpc topics.
	topicMap := make(map[string]*leakybucket.Collector, len(p2p.RPCTopicMappings))
	// Goodbye Message
	topicMap[addEncoding(p2p.RPCGoodByeTopicV1)] = leakybucket.NewCollector(1, 1, leakyBucketPeriod, false /* deleteEmptyBuckets */)
	// Ping Message
	topicMap[addEncoding(p2p.RPCPingTopicV1)] = leakybucket.NewCollector(1, defaultBurstLimit, leakyBucketPeriod, false /* deleteEmptyBuckets */)
	// Status Message
	topicMap[addEncoding(p2p.RPCStatusTopicV1)] = leakybucket.NewCollector(1, defaultBurstLimit, leakyBucketPeriod, false /* deleteEmptyBuckets */)

	// Bodies Message
	topicMap[addEncoding(p2p.RPCBodiesDataTopicV1)] = leakybucket.NewCollector(allowedBlocksPerSecond, allowedBlocksBurst, blockLimiterPeriod, false /* deleteEmptyBuckets */)

	// Headers Message
	topicMap[addEncoding(p2p.RPCHeadersDataTopicV1)] = leakybucket.NewCollector(allowedBlocksPerSecond, allowedBlocksBurst, blockLimiterPeriod, false /* deleteEmptyBuckets */)

	// General topic for all rpc requests.
	topicMap[rpcLimiterTopic] = leakybucket.NewCollector(5, defaultBurstLimit*2, leakyBucketPeriod, false /* deleteEmptyBuckets */)

	return &limiter{limiterMap: topicMap, p2p: p2pProvider}
}
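
// A hedged illustration, not part of the original file: the keys in limiterMap
// are full protocol IDs (topic plus encoding suffix), which is exactly what
// stream.Protocol() yields in validateRequest below, so lookups can use the
// protocol string directly. exampleTopicKey is a hypothetical helper.
func exampleTopicKey(p2pProvider p2p.P2P) string {
	// e.g. the ping topic followed by the active encoding's protocol suffix.
	return p2p.RPCPingTopicV1 + p2pProvider.Encoding().ProtocolSuffix()
}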

// topicCollector returns the collector registered for the provided topic.
func (l *limiter) topicCollector(topic string) (*leakybucket.Collector, error) {
	l.RLock()
	defer l.RUnlock()
	return l.retrieveCollector(topic)
}

// validateRequest checks whether the remote peer still has enough capacity in
// the topic's leaky bucket to cover a request of the given cost. On failure the
// peer is penalized and a rate-limited error response is written to the stream.
func (l *limiter) validateRequest(stream network.Stream, amt uint64) error {
	l.RLock()
	defer l.RUnlock()

	topic := string(stream.Protocol())

	collector, err := l.retrieveCollector(topic)
	if err != nil {
		return err
	}
	key := stream.Conn().RemotePeer().String()
	remaining := collector.Remaining(key)
	// Treat each request as a minimum of 1.
	if amt == 0 {
		amt = 1
	}
	if amt > uint64(remaining) {
		log.Warn("validate request failure",
			"key", key,
			"topic", topic,
			"count", amt,
			"remaining", remaining,
		)

		l.p2p.Peers().Scorers().BadResponsesScorer().Increment(stream.Conn().RemotePeer())
		writeErrorResponseToStream(responseCodeInvalidRequest, p2ptypes.ErrRateLimited.Error(), stream, l.p2p)
		return p2ptypes.ErrRateLimited
	}
	return nil
}
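
// A hedged usage sketch, not part of the original file: how an rpc handler in
// this package might pair validateRequest with add. exampleServeBatch and
// requestedCount are hypothetical names; real handlers live elsewhere.
func exampleServeBatch(l *limiter, stream network.Stream, requestedCount uint64) error {
	// Reject the request before doing any work if the peer is over its budget;
	// validateRequest has already scored the peer and written the error response.
	if err := l.validateRequest(stream, requestedCount); err != nil {
		return err
	}
	// Charge the accepted cost against the peer's bucket for this topic.
	l.add(stream, int64(requestedCount))
	return nil
}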

// validateRawRpcRequest rate limits all incoming rpc streams from external
// peers using the catch-all collector.
func (l *limiter) validateRawRpcRequest(stream network.Stream) error {
	l.RLock()
	defer l.RUnlock()

	topic := rpcLimiterTopic

	collector, err := l.retrieveCollector(topic)
	if err != nil {
		return err
	}
	key := stream.Conn().RemotePeer().String()
	remaining := collector.Remaining(key)
	// Treat each request as a minimum of 1.
	amt := int64(1)
	if amt > remaining {
		l.p2p.Peers().Scorers().BadResponsesScorer().Increment(stream.Conn().RemotePeer())
		writeErrorResponseToStream(responseCodeInvalidRequest, p2ptypes.ErrRateLimited.Error(), stream, l.p2p)
		return p2ptypes.ErrRateLimited
	}
	return nil
}

// add charges the given cost to the peer's leaky bucket for the stream's topic.
func (l *limiter) add(stream network.Stream, amt int64) {
	l.Lock()
	defer l.Unlock()

	topic := string(stream.Protocol())

	collector, err := l.retrieveCollector(topic)
	if err != nil {
		log.Error(fmt.Sprintf("collector with topic '%s' does not exist", topic), "rate limiter", topic)
		return
	}
	key := stream.Conn().RemotePeer().String()
	collector.Add(key, amt)
}

// addRawStream charges a cost of 1 to the peer's bucket in the catch-all rpc collector.
func (l *limiter) addRawStream(stream network.Stream) {
	l.Lock()
	defer l.Unlock()

	topic := rpcLimiterTopic

	collector, err := l.retrieveCollector(topic)
	if err != nil {
		log.Error(fmt.Sprintf("collector with topic '%s' does not exist", topic), "rate limiter", topic)
		return
	}
	key := stream.Conn().RemotePeer().String()
	collector.Add(key, 1)
}
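
// A hedged sketch, not part of the original file: the raw-stream pair above can
// gate a brand-new inbound stream before its request is even decoded. The
// function name exampleRawStreamGate is hypothetical.
func exampleRawStreamGate(l *limiter, stream network.Stream) bool {
	// Check the peer's overall rpc budget; on failure the peer has already been
	// scored and an error response written, so the caller should close the stream.
	if err := l.validateRawRpcRequest(stream); err != nil {
		return false
	}
	// Record the accepted stream against the catch-all collector.
	l.addRawStream(stream)
	return true
}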

// free releases every collector and empties the limiter map.
func (l *limiter) free() {
	l.Lock()
	defer l.Unlock()

	tempMap := map[uintptr]bool{}
	for t, collector := range l.limiterMap {
		// Skip collectors that have already been freed, since multiple topics
		// may share the same collector instance.
		ptr := reflect.ValueOf(collector).Pointer()
		if tempMap[ptr] {
			// Remove from map
			delete(l.limiterMap, t)
			continue
		}
		collector.Free()
		// Remove from map
		delete(l.limiterMap, t)
		tempMap[ptr] = true
	}
}
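
// A hedged illustration, not part of the original file, of why free dedupes by
// pointer: if two topics ever register the same collector, Free must run only
// once for it. The topic names and exampleSharedCollectors are hypothetical.
func exampleSharedCollectors() *limiter {
	shared := leakybucket.NewCollector(1, defaultBurstLimit, leakyBucketPeriod, false /* deleteEmptyBuckets */)
	return &limiter{limiterMap: map[string]*leakybucket.Collector{
		"example-topic-a": shared,
		"example-topic-b": shared, // same pointer: free calls Free once and deletes both entries
	}}
}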

// retrieveCollector is unsafe for concurrent use and must not be called outside
// this file; every caller is expected to hold the limiter's read or write lock,
// which is asserted below.
func (l *limiter) retrieveCollector(topic string) (*leakybucket.Collector, error) {
	if !mutexasserts.RWMutexLocked(&l.RWMutex) && !mutexasserts.RWMutexRLocked(&l.RWMutex) {
		return nil, errors.New("limiter.retrieveCollector: caller must hold read/write lock")
	}
	collector, ok := l.limiterMap[topic]
	if !ok {
		return nil, errors.Errorf("collector does not exist for topic %s", topic)
	}
	return collector, nil
}
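
// A hedged sketch, not part of the original file, of the locking contract
// asserted above: any new helper added to this file must take the read or write
// lock before calling retrieveCollector. exampleRemaining is a hypothetical method.
func (l *limiter) exampleRemaining(topic, peerID string) (int64, error) {
	l.RLock()
	defer l.RUnlock()
	collector, err := l.retrieveCollector(topic)
	if err != nil {
		return 0, err
	}
	// Remaining reports how much capacity the peer has left in this bucket.
	return collector.Remaining(peerID), nil
}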