github.com/prysmaticlabs/prysm@v1.4.4/beacon-chain/sync/rate_limiter.go (about)

     1  package sync
     2  
     3  import (
     4  	"reflect"
     5  	"sync"
     6  
     7  	"github.com/kevinms/leakybucket-go"
     8  	"github.com/libp2p/go-libp2p-core/network"
     9  	"github.com/pkg/errors"
    10  	"github.com/prysmaticlabs/prysm/beacon-chain/p2p"
    11  	p2ptypes "github.com/prysmaticlabs/prysm/beacon-chain/p2p/types"
    12  	"github.com/prysmaticlabs/prysm/cmd/beacon-chain/flags"
    13  	"github.com/sirupsen/logrus"
    14  	"github.com/trailofbits/go-mutexasserts"
    15  )
    16  
    17  const defaultBurstLimit = 5
    18  
    19  // Dummy topic to validate all incoming rpc requests.
    20  const rpcLimiterTopic = "rpc-limiter-topic"
    21  
    22  type limiter struct {
    23  	limiterMap map[string]*leakybucket.Collector
    24  	p2p        p2p.P2P
    25  	sync.RWMutex
    26  }
    27  
    28  // Instantiates a multi-rpc protocol rate limiter, providing
    29  // separate collectors for each topic.
    30  func newRateLimiter(p2pProvider p2p.P2P) *limiter {
    31  	// add encoding suffix
    32  	addEncoding := func(topic string) string {
    33  		return topic + p2pProvider.Encoding().ProtocolSuffix()
    34  	}
    35  	// Initialize block limits.
    36  	allowedBlocksPerSecond := float64(flags.Get().BlockBatchLimit)
    37  	allowedBlocksBurst := int64(flags.Get().BlockBatchLimitBurstFactor * flags.Get().BlockBatchLimit)
    38  
    39  	// Set topic map for all rpc topics.
    40  	topicMap := make(map[string]*leakybucket.Collector, len(p2p.RPCTopicMappings))
    41  	// Goodbye Message
    42  	topicMap[addEncoding(p2p.RPCGoodByeTopicV1)] = leakybucket.NewCollector(1, 1, false /* deleteEmptyBuckets */)
    43  	// MetadataV0 Message
    44  	topicMap[addEncoding(p2p.RPCMetaDataTopicV1)] = leakybucket.NewCollector(1, defaultBurstLimit, false /* deleteEmptyBuckets */)
    45  	// Ping Message
    46  	topicMap[addEncoding(p2p.RPCPingTopicV1)] = leakybucket.NewCollector(1, defaultBurstLimit, false /* deleteEmptyBuckets */)
    47  	// Status Message
    48  	topicMap[addEncoding(p2p.RPCStatusTopicV1)] = leakybucket.NewCollector(1, defaultBurstLimit, false /* deleteEmptyBuckets */)
    49  
    50  	// Use a single collector for block requests
    51  	blockCollector := leakybucket.NewCollector(allowedBlocksPerSecond, allowedBlocksBurst, false /* deleteEmptyBuckets */)
    52  
    53  	// BlocksByRoots requests
    54  	topicMap[addEncoding(p2p.RPCBlocksByRootTopicV1)] = blockCollector
    55  
    56  	// BlockByRange requests
    57  	topicMap[addEncoding(p2p.RPCBlocksByRangeTopicV1)] = blockCollector
    58  
    59  	// General topic for all rpc requests.
    60  	topicMap[rpcLimiterTopic] = leakybucket.NewCollector(5, defaultBurstLimit*2, false /* deleteEmptyBuckets */)
    61  
    62  	return &limiter{limiterMap: topicMap, p2p: p2pProvider}
    63  }
    64  
    65  // Returns the current topic collector for the provided topic.
    66  func (l *limiter) topicCollector(topic string) (*leakybucket.Collector, error) {
    67  	l.RLock()
    68  	defer l.RUnlock()
    69  	return l.retrieveCollector(topic)
    70  }
    71  
    72  // validates a request with the accompanying cost.
    73  func (l *limiter) validateRequest(stream network.Stream, amt uint64) error {
    74  	l.RLock()
    75  	defer l.RUnlock()
    76  
    77  	topic := string(stream.Protocol())
    78  
    79  	collector, err := l.retrieveCollector(topic)
    80  	if err != nil {
    81  		return err
    82  	}
    83  	key := stream.Conn().RemotePeer().String()
    84  	remaining := collector.Remaining(key)
    85  	// Treat each request as a minimum of 1.
    86  	if amt == 0 {
    87  		amt = 1
    88  	}
    89  	if amt > uint64(remaining) {
    90  		l.p2p.Peers().Scorers().BadResponsesScorer().Increment(stream.Conn().RemotePeer())
    91  		writeErrorResponseToStream(responseCodeInvalidRequest, p2ptypes.ErrRateLimited.Error(), stream, l.p2p)
    92  		return p2ptypes.ErrRateLimited
    93  	}
    94  	return nil
    95  }
    96  
    97  // This is used to validate all incoming rpc streams from external peers.
    98  func (l *limiter) validateRawRpcRequest(stream network.Stream) error {
    99  	l.RLock()
   100  	defer l.RUnlock()
   101  
   102  	topic := rpcLimiterTopic
   103  
   104  	collector, err := l.retrieveCollector(topic)
   105  	if err != nil {
   106  		return err
   107  	}
   108  	key := stream.Conn().RemotePeer().String()
   109  	remaining := collector.Remaining(key)
   110  	// Treat each request as a minimum of 1.
   111  	amt := int64(1)
   112  	if amt > remaining {
   113  		l.p2p.Peers().Scorers().BadResponsesScorer().Increment(stream.Conn().RemotePeer())
   114  		writeErrorResponseToStream(responseCodeInvalidRequest, p2ptypes.ErrRateLimited.Error(), stream, l.p2p)
   115  		return p2ptypes.ErrRateLimited
   116  	}
   117  	return nil
   118  }
   119  
   120  // adds the cost to our leaky bucket for the topic.
   121  func (l *limiter) add(stream network.Stream, amt int64) {
   122  	l.Lock()
   123  	defer l.Unlock()
   124  
   125  	topic := string(stream.Protocol())
   126  	log := l.topicLogger(topic)
   127  
   128  	collector, err := l.retrieveCollector(topic)
   129  	if err != nil {
   130  		log.Errorf("collector with topic '%s' does not exist", topic)
   131  		return
   132  	}
   133  	key := stream.Conn().RemotePeer().String()
   134  	collector.Add(key, amt)
   135  }
   136  
   137  // adds the cost to our leaky bucket for the peer.
   138  func (l *limiter) addRawStream(stream network.Stream) {
   139  	l.Lock()
   140  	defer l.Unlock()
   141  
   142  	topic := rpcLimiterTopic
   143  	log := l.topicLogger(topic)
   144  
   145  	collector, err := l.retrieveCollector(topic)
   146  	if err != nil {
   147  		log.Errorf("collector with topic '%s' does not exist", topic)
   148  		return
   149  	}
   150  	key := stream.Conn().RemotePeer().String()
   151  	collector.Add(key, 1)
   152  }
   153  
   154  // frees all the collectors and removes them.
   155  func (l *limiter) free() {
   156  	l.Lock()
   157  	defer l.Unlock()
   158  
   159  	tempMap := map[uintptr]bool{}
   160  	for t, collector := range l.limiterMap {
   161  		// Check if collector has already been cleared off
   162  		// as all collectors are not distinct from each other.
   163  		ptr := reflect.ValueOf(collector).Pointer()
   164  		if tempMap[ptr] {
   165  			// Remove from map
   166  			delete(l.limiterMap, t)
   167  			continue
   168  		}
   169  		collector.Free()
   170  		// Remove from map
   171  		delete(l.limiterMap, t)
   172  		tempMap[ptr] = true
   173  	}
   174  }
   175  
   176  // not to be used outside the rate limiter file as it is unsafe for concurrent usage
   177  // and is protected by a lock on all of its usages here.
   178  func (l *limiter) retrieveCollector(topic string) (*leakybucket.Collector, error) {
   179  	if !mutexasserts.RWMutexLocked(&l.RWMutex) && !mutexasserts.RWMutexRLocked(&l.RWMutex) {
   180  		return nil, errors.New("limiter.retrieveCollector: caller must hold read/write lock")
   181  	}
   182  	collector, ok := l.limiterMap[topic]
   183  	if !ok {
   184  		return nil, errors.Errorf("collector does not exist for topic %s", topic)
   185  	}
   186  	return collector, nil
   187  }
   188  
   189  func (l *limiter) topicLogger(topic string) *logrus.Entry {
   190  	return log.WithField("rate limiter", topic)
   191  }