// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package utils

import (
	"sort"
	"sync"

	"github.com/ethereum/go-ethereum/p2p/enode"
)

// maxSelectionWeight is the maximum selection weight of each individual node/address
// group. Selection weights are computed as a factor of at most 1 multiplied by this
// constant.
const maxSelectionWeight = 1000000000

// Limiter protects a network request serving mechanism from denial-of-service attacks.
// It limits the total amount of resources used for serving requests while ensuring that
// the most valuable connections always have a reasonable chance of being served.
//
// Requests are queued per node ID and served one at a time by a background goroutine
// started by NewLimiter. The next node to serve is picked alternately by network
// address (each active address group gets a share) and by node value. When the summed
// cost of queued requests exceeds the configured limit, whole node queues are dropped,
// lowest priority first, and the affected nodes are penalized.
31 type Limiter struct { 32 lock sync.Mutex 33 cond *sync.Cond 34 quit bool 35 36 nodes map[enode.ID]*nodeQueue 37 addresses map[string]*addressGroup 38 addressSelect, valueSelect *WeightedRandomSelect 39 maxValue float64 40 maxCost, sumCost, sumCostLimit uint 41 selectAddressNext bool 42 } 43 44 // nodeQueue represents queued requests coming from a single node ID 45 type nodeQueue struct { 46 queue []request // always nil if penaltyCost != 0 47 id enode.ID 48 address string 49 value float64 50 flatWeight, valueWeight uint64 // current selection weights in the address/value selectors 51 sumCost uint // summed cost of requests queued by the node 52 penaltyCost uint // cumulative cost of dropped requests since last processed request 53 groupIndex int 54 } 55 56 // addressGroup is a group of node IDs that have sent their last requests from the same 57 // network address 58 type addressGroup struct { 59 nodes []*nodeQueue 60 nodeSelect *WeightedRandomSelect 61 sumFlatWeight, groupWeight uint64 62 } 63 64 // request represents an incoming request scheduled for processing 65 type request struct { 66 process chan chan struct{} 67 cost uint 68 } 69 70 // flatWeight distributes weights equally between each active network address 71 func flatWeight(item interface{}) uint64 { return item.(*nodeQueue).flatWeight } 72 73 // add adds the node queue to the address group. It is the caller's responsibility to 74 // add the address group to the address map and the address selector if it wasn't 75 // there before. 76 func (ag *addressGroup) add(nq *nodeQueue) { 77 if nq.groupIndex != -1 { 78 panic("added node queue is already in an address group") 79 } 80 l := len(ag.nodes) 81 nq.groupIndex = l 82 ag.nodes = append(ag.nodes, nq) 83 ag.sumFlatWeight += nq.flatWeight 84 ag.groupWeight = ag.sumFlatWeight / uint64(l+1) 85 ag.nodeSelect.Update(ag.nodes[l]) 86 } 87 88 // update updates the selection weight of the node queue inside the address group. 
89 // It is the caller's responsibility to update the group's selection weight in the 90 // address selector. 91 func (ag *addressGroup) update(nq *nodeQueue, weight uint64) { 92 if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq { 93 panic("updated node queue is not in this address group") 94 } 95 ag.sumFlatWeight += weight - nq.flatWeight 96 nq.flatWeight = weight 97 ag.groupWeight = ag.sumFlatWeight / uint64(len(ag.nodes)) 98 ag.nodeSelect.Update(nq) 99 } 100 101 // remove removes the node queue from the address group. It is the caller's responsibility 102 // to remove the address group from the address map if it is empty. 103 func (ag *addressGroup) remove(nq *nodeQueue) { 104 if nq.groupIndex == -1 || nq.groupIndex >= len(ag.nodes) || ag.nodes[nq.groupIndex] != nq { 105 panic("removed node queue is not in this address group") 106 } 107 108 l := len(ag.nodes) - 1 109 if nq.groupIndex != l { 110 ag.nodes[nq.groupIndex] = ag.nodes[l] 111 ag.nodes[nq.groupIndex].groupIndex = nq.groupIndex 112 } 113 nq.groupIndex = -1 114 ag.nodes = ag.nodes[:l] 115 ag.sumFlatWeight -= nq.flatWeight 116 if l >= 1 { 117 ag.groupWeight = ag.sumFlatWeight / uint64(l) 118 } else { 119 ag.groupWeight = 0 120 } 121 ag.nodeSelect.Remove(nq) 122 } 123 124 // choose selects one of the node queues belonging to the address group 125 func (ag *addressGroup) choose() *nodeQueue { 126 return ag.nodeSelect.Choose().(*nodeQueue) 127 } 128 129 // NewLimiter creates a new Limiter 130 func NewLimiter(sumCostLimit uint) *Limiter { 131 l := &Limiter{ 132 addressSelect: NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*addressGroup).groupWeight }), 133 valueSelect: NewWeightedRandomSelect(func(item interface{}) uint64 { return item.(*nodeQueue).valueWeight }), 134 nodes: make(map[enode.ID]*nodeQueue), 135 addresses: make(map[string]*addressGroup), 136 sumCostLimit: sumCostLimit, 137 } 138 l.cond = sync.NewCond(&l.lock) 139 go l.processLoop() 140 
return l 141 } 142 143 // selectionWeights calculates the selection weights of a node for both the address and 144 // the value selector. The selection weight depends on the next request cost or the 145 // summed cost of recently dropped requests. 146 func (l *Limiter) selectionWeights(reqCost uint, value float64) (flatWeight, valueWeight uint64) { 147 if value > l.maxValue { 148 l.maxValue = value 149 } 150 if value > 0 { 151 // normalize value to <= 1 152 value /= l.maxValue 153 } 154 if reqCost > l.maxCost { 155 l.maxCost = reqCost 156 } 157 relCost := float64(reqCost) / float64(l.maxCost) 158 var f float64 159 if relCost <= 0.001 { 160 f = 1 161 } else { 162 f = 0.001 / relCost 163 } 164 f *= maxSelectionWeight 165 flatWeight, valueWeight = uint64(f), uint64(f*value) 166 if flatWeight == 0 { 167 flatWeight = 1 168 } 169 return 170 } 171 172 // Add adds a new request to the node queue belonging to the given id. Value belongs 173 // to the requesting node. A higher value gives the request a higher chance of being 174 // served quickly in case of heavy load or a DDoS attack. Cost is a rough estimate 175 // of the serving cost of the request. A lower cost also gives the request a 176 // better chance. 
177 func (l *Limiter) Add(id enode.ID, address string, value float64, reqCost uint) chan chan struct{} { 178 l.lock.Lock() 179 defer l.lock.Unlock() 180 181 process := make(chan chan struct{}, 1) 182 if l.quit { 183 close(process) 184 return process 185 } 186 if reqCost == 0 { 187 reqCost = 1 188 } 189 if nq, ok := l.nodes[id]; ok { 190 if nq.queue != nil { 191 nq.queue = append(nq.queue, request{process, reqCost}) 192 nq.sumCost += reqCost 193 nq.value = value 194 if address != nq.address { 195 // known id sending request from a new address, move to different address group 196 l.removeFromGroup(nq) 197 l.addToGroup(nq, address) 198 } 199 } else { 200 // already waiting on a penalty, just add to the penalty cost and drop the request 201 nq.penaltyCost += reqCost 202 l.update(nq) 203 close(process) 204 return process 205 } 206 } else { 207 nq := &nodeQueue{ 208 queue: []request{{process, reqCost}}, 209 id: id, 210 value: value, 211 sumCost: reqCost, 212 groupIndex: -1, 213 } 214 nq.flatWeight, nq.valueWeight = l.selectionWeights(reqCost, value) 215 if len(l.nodes) == 0 { 216 l.cond.Signal() 217 } 218 l.nodes[id] = nq 219 if nq.valueWeight != 0 { 220 l.valueSelect.Update(nq) 221 } 222 l.addToGroup(nq, address) 223 } 224 l.sumCost += reqCost 225 if l.sumCost > l.sumCostLimit { 226 l.dropRequests() 227 } 228 return process 229 } 230 231 // update updates the selection weights of the node queue 232 func (l *Limiter) update(nq *nodeQueue) { 233 var cost uint 234 if nq.queue != nil { 235 cost = nq.queue[0].cost 236 } else { 237 cost = nq.penaltyCost 238 } 239 flatWeight, valueWeight := l.selectionWeights(cost, nq.value) 240 ag := l.addresses[nq.address] 241 ag.update(nq, flatWeight) 242 l.addressSelect.Update(ag) 243 nq.valueWeight = valueWeight 244 l.valueSelect.Update(nq) 245 } 246 247 // addToGroup adds the node queue to the given address group. The group is created if 248 // it does not exist yet. 
249 func (l *Limiter) addToGroup(nq *nodeQueue, address string) { 250 nq.address = address 251 ag := l.addresses[address] 252 if ag == nil { 253 ag = &addressGroup{nodeSelect: NewWeightedRandomSelect(flatWeight)} 254 l.addresses[address] = ag 255 } 256 ag.add(nq) 257 l.addressSelect.Update(ag) 258 } 259 260 // removeFromGroup removes the node queue from its address group 261 func (l *Limiter) removeFromGroup(nq *nodeQueue) { 262 ag := l.addresses[nq.address] 263 ag.remove(nq) 264 if len(ag.nodes) == 0 { 265 delete(l.addresses, nq.address) 266 } 267 l.addressSelect.Update(ag) 268 } 269 270 // remove removes the node queue from its address group, the nodes map and the value 271 // selector 272 func (l *Limiter) remove(nq *nodeQueue) { 273 l.removeFromGroup(nq) 274 if nq.valueWeight != 0 { 275 l.valueSelect.Remove(nq) 276 } 277 delete(l.nodes, nq.id) 278 } 279 280 // choose selects the next node queue to process. 281 func (l *Limiter) choose() *nodeQueue { 282 if l.valueSelect.IsEmpty() || l.selectAddressNext { 283 if ag, ok := l.addressSelect.Choose().(*addressGroup); ok { 284 l.selectAddressNext = false 285 return ag.choose() 286 } 287 } 288 nq, _ := l.valueSelect.Choose().(*nodeQueue) 289 l.selectAddressNext = true 290 return nq 291 } 292 293 // processLoop processes requests sequentially 294 func (l *Limiter) processLoop() { 295 l.lock.Lock() 296 defer l.lock.Unlock() 297 298 for { 299 if l.quit { 300 for _, nq := range l.nodes { 301 for _, request := range nq.queue { 302 close(request.process) 303 } 304 } 305 return 306 } 307 nq := l.choose() 308 if nq == nil { 309 l.cond.Wait() 310 continue 311 } 312 if nq.queue != nil { 313 request := nq.queue[0] 314 nq.queue = nq.queue[1:] 315 nq.sumCost -= request.cost 316 l.sumCost -= request.cost 317 l.lock.Unlock() 318 ch := make(chan struct{}) 319 request.process <- ch 320 <-ch 321 l.lock.Lock() 322 if len(nq.queue) > 0 { 323 l.update(nq) 324 } else { 325 l.remove(nq) 326 } 327 } else { 328 // penalized queue removed, 
next request will be added to a clean queue 329 l.remove(nq) 330 } 331 } 332 } 333 334 // Stop stops the processing loop. All queued and future requests are rejected. 335 func (l *Limiter) Stop() { 336 l.lock.Lock() 337 defer l.lock.Unlock() 338 339 l.quit = true 340 l.cond.Signal() 341 } 342 343 type ( 344 dropList []dropListItem 345 dropListItem struct { 346 nq *nodeQueue 347 priority float64 348 } 349 ) 350 351 func (l dropList) Len() int { 352 return len(l) 353 } 354 355 func (l dropList) Less(i, j int) bool { 356 return l[i].priority < l[j].priority 357 } 358 359 func (l dropList) Swap(i, j int) { 360 l[i], l[j] = l[j], l[i] 361 } 362 363 // dropRequests selects the nodes with the highest queued request cost to selection 364 // weight ratio and drops their queued request. The empty node queues stay in the 365 // selectors with a low selection weight in order to penalize these nodes. 366 func (l *Limiter) dropRequests() { 367 var ( 368 sumValue float64 369 list dropList 370 ) 371 for _, nq := range l.nodes { 372 sumValue += nq.value 373 } 374 for _, nq := range l.nodes { 375 if nq.sumCost == 0 { 376 continue 377 } 378 w := 1 / float64(len(l.addresses)*len(l.addresses[nq.address].nodes)) 379 if sumValue > 0 { 380 w += nq.value / sumValue 381 } 382 list = append(list, dropListItem{ 383 nq: nq, 384 priority: w / float64(nq.sumCost), 385 }) 386 } 387 sort.Sort(list) 388 for _, item := range list { 389 for _, request := range item.nq.queue { 390 close(request.process) 391 } 392 // make the queue penalized; no more requests are accepted until the node is 393 // selected based on the penalty cost which is the cumulative cost of all dropped 394 // requests. This ensures that sending excess requests is always penalized 395 // and incentivizes the sender to stop for a while if no replies are received. 
396 item.nq.queue = nil 397 item.nq.penaltyCost = item.nq.sumCost 398 l.sumCost -= item.nq.sumCost // penalty costs are not counted in sumCost 399 item.nq.sumCost = 0 400 l.update(item.nq) 401 if l.sumCost <= l.sumCostLimit/2 { 402 return 403 } 404 } 405 }