github.com/cryptogateway/go-paymex@v0.0.0-20210204174735-96277fb1e602/les/utils/weighted_select.go (about) 1 // Copyright 2016 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package utils 18 19 import ( 20 "math" 21 "math/rand" 22 23 "github.com/cryptogateway/go-paymex/log" 24 ) 25 26 type ( 27 // WeightedRandomSelect is capable of weighted random selection from a set of items 28 WeightedRandomSelect struct { 29 root *wrsNode 30 idx map[WrsItem]int 31 wfn WeightFn 32 } 33 WrsItem interface{} 34 WeightFn func(interface{}) uint64 35 ) 36 37 // NewWeightedRandomSelect returns a new WeightedRandomSelect structure 38 func NewWeightedRandomSelect(wfn WeightFn) *WeightedRandomSelect { 39 return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[WrsItem]int), wfn: wfn} 40 } 41 42 // Update updates an item's weight, adds it if it was non-existent or removes it if 43 // the new weight is zero. Note that explicitly updating decreasing weights is not necessary. 44 func (w *WeightedRandomSelect) Update(item WrsItem) { 45 w.setWeight(item, w.wfn(item)) 46 } 47 48 // Remove removes an item from the set 49 func (w *WeightedRandomSelect) Remove(item WrsItem) { 50 w.setWeight(item, 0) 51 } 52 53 // IsEmpty returns true if the set is empty 54 func (w *WeightedRandomSelect) IsEmpty() bool { 55 return w.root.sumCost == 0 56 } 57 58 // setWeight sets an item's weight to a specific value (removes it if zero) 59 func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) { 60 if weight > math.MaxInt64-w.root.sumCost { 61 // old weight is still included in sumCost, remove and check again 62 w.setWeight(item, 0) 63 if weight > math.MaxInt64-w.root.sumCost { 64 log.Error("WeightedRandomSelect overflow", "sumCost", w.root.sumCost, "new weight", weight) 65 weight = math.MaxInt64 - w.root.sumCost 66 } 67 } 68 idx, ok := w.idx[item] 69 if ok { 70 w.root.setWeight(idx, weight) 71 if weight == 0 { 72 delete(w.idx, item) 73 } 74 } else { 75 if weight != 0 { 76 if w.root.itemCnt == w.root.maxItems { 77 // add a new level 78 newRoot := &wrsNode{sumCost: w.root.sumCost, itemCnt: w.root.itemCnt, level: w.root.level + 1, maxItems: w.root.maxItems * wrsBranches} 79 newRoot.items[0] = w.root 80 newRoot.weights[0] = w.root.sumCost 81 w.root = newRoot 82 } 83 w.idx[item] = w.root.insert(item, weight) 84 } 85 } 86 } 87 88 // Choose randomly selects an item from the set, with a chance proportional to its 89 // current weight. If the weight of the chosen element has been decreased since the 90 // last stored value, returns it with a newWeight/oldWeight chance, otherwise just 91 // updates its weight and selects another one 92 func (w *WeightedRandomSelect) Choose() WrsItem { 93 for { 94 if w.root.sumCost == 0 { 95 return nil 96 } 97 val := uint64(rand.Int63n(int64(w.root.sumCost))) 98 choice, lastWeight := w.root.choose(val) 99 weight := w.wfn(choice) 100 if weight != lastWeight { 101 w.setWeight(choice, weight) 102 } 103 if weight >= lastWeight || uint64(rand.Int63n(int64(lastWeight))) < weight { 104 return choice 105 } 106 } 107 } 108 109 const wrsBranches = 8 // max number of branches in the wrsNode tree 110 111 // wrsNode is a node of a tree structure that can store WrsItems or further wrsNodes. 112 type wrsNode struct { 113 items [wrsBranches]interface{} 114 weights [wrsBranches]uint64 115 sumCost uint64 116 level, itemCnt, maxItems int 117 } 118 119 // insert recursively inserts a new item to the tree and returns the item index 120 func (n *wrsNode) insert(item WrsItem, weight uint64) int { 121 branch := 0 122 for n.items[branch] != nil && (n.level == 0 || n.items[branch].(*wrsNode).itemCnt == n.items[branch].(*wrsNode).maxItems) { 123 branch++ 124 if branch == wrsBranches { 125 panic(nil) 126 } 127 } 128 n.itemCnt++ 129 n.sumCost += weight 130 n.weights[branch] += weight 131 if n.level == 0 { 132 n.items[branch] = item 133 return branch 134 } 135 var subNode *wrsNode 136 if n.items[branch] == nil { 137 subNode = &wrsNode{maxItems: n.maxItems / wrsBranches, level: n.level - 1} 138 n.items[branch] = subNode 139 } else { 140 subNode = n.items[branch].(*wrsNode) 141 } 142 subIdx := subNode.insert(item, weight) 143 return subNode.maxItems*branch + subIdx 144 } 145 146 // setWeight updates the weight of a certain item (which should exist) and returns 147 // the change of the last weight value stored in the tree 148 func (n *wrsNode) setWeight(idx int, weight uint64) uint64 { 149 if n.level == 0 { 150 oldWeight := n.weights[idx] 151 n.weights[idx] = weight 152 diff := weight - oldWeight 153 n.sumCost += diff 154 if weight == 0 { 155 n.items[idx] = nil 156 n.itemCnt-- 157 } 158 return diff 159 } 160 branchItems := n.maxItems / wrsBranches 161 branch := idx / branchItems 162 diff := n.items[branch].(*wrsNode).setWeight(idx-branch*branchItems, weight) 163 n.weights[branch] += diff 164 n.sumCost += diff 165 if weight == 0 { 166 n.itemCnt-- 167 } 168 return diff 169 } 170 171 // choose recursively selects an item from the tree and returns it along with its weight 172 func (n *wrsNode) choose(val uint64) (WrsItem, uint64) { 173 for i, w := range n.weights { 174 if val < w { 175 if n.level == 0 { 176 return n.items[i].(WrsItem), n.weights[i] 177 } 178 return n.items[i].(*wrsNode).choose(val) 179 } 180 val -= w 181 } 182 panic(nil) 183 }