github.com/yuanzimu/bsc@v1.1.4/les/utils/limiter.go (about)

     1  // Copyright 2020 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package utils
    18  
    19  import (
    20  	"sort"
    21  	"sync"
    22  
    23  	"github.com/ethereum/go-ethereum/p2p/enode"
    24  )
    25  
    26  const maxSelectionWeight = 1000000000 // maximum selection weight of each individual node/address group
    27  
    28  // Limiter protects a network request serving mechanism from denial-of-service attacks.
    29  // It limits the total amount of resources used for serving requests while ensuring that
    30  // the most valuable connections always have a reasonable chance of being served.
    31  type Limiter struct {
    32  	lock sync.Mutex
    33  	cond *sync.Cond
    34  	quit bool
    35  
    36  	nodes                          map[enode.ID]*nodeQueue
    37  	addresses                      map[string]*addressGroup
    38  	addressSelect, valueSelect     *WeightedRandomSelect
    39  	maxValue                       float64
    40  	maxCost, sumCost, sumCostLimit uint
    41  	selectAddressNext              bool
    42  }
    43  
    44  // nodeQueue represents queued requests coming from a single node ID
    45  type nodeQueue struct {
    46  	queue                   []request // always nil if penaltyCost != 0
    47  	id                      enode.ID
    48  	address                 string
    49  	value                   float64
    50  	flatWeight, valueWeight uint64 // current selection weights in the address/value selectors
    51  	sumCost                 uint   // summed cost of requests queued by the node
    52  	penaltyCost             uint   // cumulative cost of dropped requests since last processed request
    53  	groupIndex              int
    54  }
    55  
    56  // addressGroup is a group of node IDs that have sent their last requests from the same
    57  // network address
    58  type addressGroup struct {
    59  	nodes                      []*nodeQueue
    60  	nodeSelect                 *WeightedRandomSelect
    61  	sumFlatWeight, groupWeight uint64
    62  }
    63  
    64  // request represents an incoming request scheduled for processing
    65  type request struct {
    66  	process chan chan struct{}
    67  	cost    uint
    68  }
    69  
    70  // flatWeight distributes weights equally between each active network address
    71  func flatWeight(item interface{}) uint64 { return item.(*nodeQueue).flatWeight }
    72  
    73  // add adds the node queue to the address group. It is the caller's responsibility to
    74  // add the address group to the address map and the address selector if it wasn't
    75  // there before.
    76  func (ag *addressGroup) add(nq *nodeQueue) {
    77  	if nq.groupIndex != -1 {
    78  		panic("added node queue is already in an address group")
    79  	}
    80  	l := len(ag.nodes)
    81  	nq.groupIndex = l
    82  	ag.nodes = append(ag.nodes, nq)
    83  	ag.sumFlatWeight += nq.flatWeight
    84  	ag.groupWeight = ag.sumFlatWeight / uint64(l+1)
    85  	ag.nodeSelect.Update(ag.nodes[l])
    86  }
    87  
    88  // update updates the selection weight of the node queue inside the address group.
    89  // It is the caller's responsibility to update the group's selection weight in the
    90  // address selector.
    91  func (ag *addressGroup) update(nq *nodeQueue, weight uint64) {
    92  	if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq {
    93  		panic("updated node queue is not in this address group")
    94  	}
    95  	ag.sumFlatWeight += weight - nq.flatWeight
    96  	nq.flatWeight = weight
    97  	ag.groupWeight = ag.sumFlatWeight / uint64(len(ag.nodes))
    98  	ag.nodeSelect.Update(nq)
    99  }
   100  
   101  // remove removes the node queue from the address group. It is the caller's responsibility
   102  // to remove the address group from the address map if it is empty.
   103  func (ag *addressGroup) remove(nq *nodeQueue) {
   104  	if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq {
   105  		panic("removed node queue is not in this address group")
   106  	}
   107  
   108  	l := len(ag.nodes) - 1
   109  	if nq.groupIndex != l {
   110  		ag.nodes[nq.groupIndex] = ag.nodes[l]
   111  		ag.nodes[nq.groupIndex].groupIndex = nq.groupIndex
   112  	}
   113  	nq.groupIndex = -1
   114  	ag.nodes = ag.nodes[:l]
   115  	ag.sumFlatWeight -= nq.flatWeight
   116  	if l >= 1 {
   117  		ag.groupWeight = ag.sumFlatWeight / uint64(l)
   118  	} else {
   119  		ag.groupWeight = 0
   120  	}
   121  	ag.nodeSelect.Remove(nq)
   122  }
   123  
   124  // choose selects one of the node queues belonging to the address group
   125  func (ag *addressGroup) choose() *nodeQueue {
   126  	return ag.nodeSelect.Choose().(*nodeQueue)
   127  }
   128  
   129  // NewLimiter creates a new Limiter
   130  func NewLimiter(sumCostLimit uint) *Limiter {
   131  	l := &Limiter{
   132  		addressSelect: NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*addressGroup).groupWeight }),
   133  		valueSelect:   NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*nodeQueue).valueWeight }),
   134  		nodes:         make(map[enode.ID]*nodeQueue),
   135  		addresses:     make(map[string]*addressGroup),
   136  		sumCostLimit:  sumCostLimit,
   137  	}
   138  	l.cond = sync.NewCond(&l.lock)
   139  	go l.processLoop()
   140  	return l
   141  }
   142  
   143  // selectionWeights calculates the selection weights of a node for both the address and
   144  // the value selector. The selection weight depends on the next request cost or the
   145  // summed cost of recently dropped requests.
   146  func (l *Limiter) selectionWeights(reqCost uint, value float64) (flatWeight, valueWeight uint64) {
   147  	if value > l.maxValue {
   148  		l.maxValue = value
   149  	}
   150  	if value > 0 {
   151  		// normalize value to <= 1
   152  		value /= l.maxValue
   153  	}
   154  	if reqCost > l.maxCost {
   155  		l.maxCost = reqCost
   156  	}
   157  	relCost := float64(reqCost) / float64(l.maxCost)
   158  	var f float64
   159  	if relCost <= 0.001 {
   160  		f = 1
   161  	} else {
   162  		f = 0.001 / relCost
   163  	}
   164  	f *= maxSelectionWeight
   165  	flatWeight, valueWeight = uint64(f), uint64(f*value)
   166  	if flatWeight == 0 {
   167  		flatWeight = 1
   168  	}
   169  	return
   170  }
   171  
   172  // Add adds a new request to the node queue belonging to the given id. Value belongs
   173  // to the requesting node. A higher value gives the request a higher chance of being
   174  // served quickly in case of heavy load or a DDoS attack. Cost is a rough estimate
   175  // of the serving cost of the request. A lower cost also gives the request a
   176  // better chance.
   177  func (l *Limiter) Add(id enode.ID, address string, value float64, reqCost uint) chan chan struct{} {
   178  	l.lock.Lock()
   179  	defer l.lock.Unlock()
   180  
   181  	process := make(chan chan struct{}, 1)
   182  	if l.quit {
   183  		close(process)
   184  		return process
   185  	}
   186  	if reqCost == 0 {
   187  		reqCost = 1
   188  	}
   189  	if nq, ok := l.nodes[id]; ok {
   190  		if nq.queue != nil {
   191  			nq.queue = append(nq.queue, request{process, reqCost})
   192  			nq.sumCost += reqCost
   193  			nq.value = value
   194  			if address != nq.address {
   195  				// known id sending request from a new address, move to different address group
   196  				l.removeFromGroup(nq)
   197  				l.addToGroup(nq, address)
   198  			}
   199  		} else {
   200  			// already waiting on a penalty, just add to the penalty cost and drop the request
   201  			nq.penaltyCost += reqCost
   202  			l.update(nq)
   203  			close(process)
   204  			return process
   205  		}
   206  	} else {
   207  		nq := &nodeQueue{
   208  			queue:      []request{{process, reqCost}},
   209  			id:         id,
   210  			value:      value,
   211  			sumCost:    reqCost,
   212  			groupIndex: -1,
   213  		}
   214  		nq.flatWeight, nq.valueWeight = l.selectionWeights(reqCost, value)
   215  		if len(l.nodes) == 0 {
   216  			l.cond.Signal()
   217  		}
   218  		l.nodes[id] = nq
   219  		if nq.valueWeight != 0 {
   220  			l.valueSelect.Update(nq)
   221  		}
   222  		l.addToGroup(nq, address)
   223  	}
   224  	l.sumCost += reqCost
   225  	if l.sumCost > l.sumCostLimit {
   226  		l.dropRequests()
   227  	}
   228  	return process
   229  }
   230  
   231  // update updates the selection weights of the node queue
   232  func (l *Limiter) update(nq *nodeQueue) {
   233  	var cost uint
   234  	if nq.queue != nil {
   235  		cost = nq.queue[0].cost
   236  	} else {
   237  		cost = nq.penaltyCost
   238  	}
   239  	flatWeight, valueWeight := l.selectionWeights(cost, nq.value)
   240  	ag := l.addresses[nq.address]
   241  	ag.update(nq, flatWeight)
   242  	l.addressSelect.Update(ag)
   243  	nq.valueWeight = valueWeight
   244  	l.valueSelect.Update(nq)
   245  }
   246  
   247  // addToGroup adds the node queue to the given address group. The group is created if
   248  // it does not exist yet.
   249  func (l *Limiter) addToGroup(nq *nodeQueue, address string) {
   250  	nq.address = address
   251  	ag := l.addresses[address]
   252  	if ag == nil {
   253  		ag = &addressGroup{nodeSelect: NewWeightedRandomSelect(flatWeight)}
   254  		l.addresses[address] = ag
   255  	}
   256  	ag.add(nq)
   257  	l.addressSelect.Update(ag)
   258  }
   259  
   260  // removeFromGroup removes the node queue from its address group
   261  func (l *Limiter) removeFromGroup(nq *nodeQueue) {
   262  	ag := l.addresses[nq.address]
   263  	ag.remove(nq)
   264  	if len(ag.nodes) == 0 {
   265  		delete(l.addresses, nq.address)
   266  	}
   267  	l.addressSelect.Update(ag)
   268  }
   269  
   270  // remove removes the node queue from its address group, the nodes map and the value
   271  // selector
   272  func (l *Limiter) remove(nq *nodeQueue) {
   273  	l.removeFromGroup(nq)
   274  	if nq.valueWeight != 0 {
   275  		l.valueSelect.Remove(nq)
   276  	}
   277  	delete(l.nodes, nq.id)
   278  }
   279  
   280  // choose selects the next node queue to process.
   281  func (l *Limiter) choose() *nodeQueue {
   282  	if l.valueSelect.IsEmpty() || l.selectAddressNext {
   283  		if ag, ok := l.addressSelect.Choose().(*addressGroup); ok {
   284  			l.selectAddressNext = false
   285  			return ag.choose()
   286  		}
   287  	}
   288  	nq, _ := l.valueSelect.Choose().(*nodeQueue)
   289  	l.selectAddressNext = true
   290  	return nq
   291  }
   292  
   293  // processLoop processes requests sequentially
   294  func (l *Limiter) processLoop() {
   295  	l.lock.Lock()
   296  	defer l.lock.Unlock()
   297  
   298  	for {
   299  		if l.quit {
   300  			for _, nq := range l.nodes {
   301  				for _, request := range nq.queue {
   302  					close(request.process)
   303  				}
   304  			}
   305  			return
   306  		}
   307  		nq := l.choose()
   308  		if nq == nil {
   309  			l.cond.Wait()
   310  			continue
   311  		}
   312  		if nq.queue != nil {
   313  			request := nq.queue[0]
   314  			nq.queue = nq.queue[1:]
   315  			nq.sumCost -= request.cost
   316  			l.sumCost -= request.cost
   317  			l.lock.Unlock()
   318  			ch := make(chan struct{})
   319  			request.process <- ch
   320  			<-ch
   321  			l.lock.Lock()
   322  			if len(nq.queue) > 0 {
   323  				l.update(nq)
   324  			} else {
   325  				l.remove(nq)
   326  			}
   327  		} else {
   328  			// penalized queue removed, next request will be added to a clean queue
   329  			l.remove(nq)
   330  		}
   331  	}
   332  }
   333  
   334  // Stop stops the processing loop. All queued and future requests are rejected.
   335  func (l *Limiter) Stop() {
   336  	l.lock.Lock()
   337  	defer l.lock.Unlock()
   338  
   339  	l.quit = true
   340  	l.cond.Signal()
   341  }
   342  
   343  type (
   344  	dropList     []dropListItem
   345  	dropListItem struct {
   346  		nq       *nodeQueue
   347  		priority float64
   348  	}
   349  )
   350  
   351  func (l dropList) Len() int {
   352  	return len(l)
   353  }
   354  
   355  func (l dropList) Less(i, j int) bool {
   356  	return l[i].priority < l[j].priority
   357  }
   358  
   359  func (l dropList) Swap(i, j int) {
   360  	l[i], l[j] = l[j], l[i]
   361  }
   362  
   363  // dropRequests selects the nodes with the highest queued request cost to selection
   364  // weight ratio and drops their queued request. The empty node queues stay in the
   365  // selectors with a low selection weight in order to penalize these nodes.
   366  func (l *Limiter) dropRequests() {
   367  	var (
   368  		sumValue float64
   369  		list     dropList
   370  	)
   371  	for _, nq := range l.nodes {
   372  		sumValue += nq.value
   373  	}
   374  	for _, nq := range l.nodes {
   375  		if nq.sumCost == 0 {
   376  			continue
   377  		}
   378  		w := 1 / float64(len(l.addresses)*len(l.addresses[nq.address].nodes))
   379  		if sumValue > 0 {
   380  			w += nq.value / sumValue
   381  		}
   382  		list = append(list, dropListItem{
   383  			nq:       nq,
   384  			priority: w / float64(nq.sumCost),
   385  		})
   386  	}
   387  	sort.Sort(list)
   388  	for _, item := range list {
   389  		for _, request := range item.nq.queue {
   390  			close(request.process)
   391  		}
   392  		// make the queue penalized; no more requests are accepted until the node is
   393  		// selected based on the penalty cost which is the cumulative cost of all dropped
   394  		// requests. This ensures that sending excess requests is always penalized
   395  		// and incentivizes the sender to stop for a while if no replies are received.
   396  		item.nq.queue = nil
   397  		item.nq.penaltyCost = item.nq.sumCost
   398  		l.sumCost -= item.nq.sumCost // penalty costs are not counted in sumCost
   399  		item.nq.sumCost = 0
   400  		l.update(item.nq)
   401  		if l.sumCost <= l.sumCostLimit/2 {
   402  			return
   403  		}
   404  	}
   405  }