github.com/cryptogateway/go-paymex@v0.0.0-20210204174735-96277fb1e602/les/utils/weighted_select.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package utils
    18  
    19  import (
    20  	"math"
    21  	"math/rand"
    22  
    23  	"github.com/cryptogateway/go-paymex/log"
    24  )
    25  
    26  type (
    27  	// WeightedRandomSelect is capable of weighted random selection from a set of items
    28  	WeightedRandomSelect struct {
    29  		root *wrsNode
    30  		idx  map[WrsItem]int
    31  		wfn  WeightFn
    32  	}
    33  	WrsItem  interface{}
    34  	WeightFn func(interface{}) uint64
    35  )
    36  
    37  // NewWeightedRandomSelect returns a new WeightedRandomSelect structure
    38  func NewWeightedRandomSelect(wfn WeightFn) *WeightedRandomSelect {
    39  	return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[WrsItem]int), wfn: wfn}
    40  }
    41  
    42  // Update updates an item's weight, adds it if it was non-existent or removes it if
    43  // the new weight is zero. Note that explicitly updating decreasing weights is not necessary.
    44  func (w *WeightedRandomSelect) Update(item WrsItem) {
    45  	w.setWeight(item, w.wfn(item))
    46  }
    47  
    48  // Remove removes an item from the set
    49  func (w *WeightedRandomSelect) Remove(item WrsItem) {
    50  	w.setWeight(item, 0)
    51  }
    52  
    53  // IsEmpty returns true if the set is empty
    54  func (w *WeightedRandomSelect) IsEmpty() bool {
    55  	return w.root.sumCost == 0
    56  }
    57  
    58  // setWeight sets an item's weight to a specific value (removes it if zero)
    59  func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) {
    60  	if weight > math.MaxInt64-w.root.sumCost {
    61  		// old weight is still included in sumCost, remove and check again
    62  		w.setWeight(item, 0)
    63  		if weight > math.MaxInt64-w.root.sumCost {
    64  			log.Error("WeightedRandomSelect overflow", "sumCost", w.root.sumCost, "new weight", weight)
    65  			weight = math.MaxInt64 - w.root.sumCost
    66  		}
    67  	}
    68  	idx, ok := w.idx[item]
    69  	if ok {
    70  		w.root.setWeight(idx, weight)
    71  		if weight == 0 {
    72  			delete(w.idx, item)
    73  		}
    74  	} else {
    75  		if weight != 0 {
    76  			if w.root.itemCnt == w.root.maxItems {
    77  				// add a new level
    78  				newRoot := &wrsNode{sumCost: w.root.sumCost, itemCnt: w.root.itemCnt, level: w.root.level + 1, maxItems: w.root.maxItems * wrsBranches}
    79  				newRoot.items[0] = w.root
    80  				newRoot.weights[0] = w.root.sumCost
    81  				w.root = newRoot
    82  			}
    83  			w.idx[item] = w.root.insert(item, weight)
    84  		}
    85  	}
    86  }
    87  
    88  // Choose randomly selects an item from the set, with a chance proportional to its
    89  // current weight. If the weight of the chosen element has been decreased since the
    90  // last stored value, returns it with a newWeight/oldWeight chance, otherwise just
    91  // updates its weight and selects another one
    92  func (w *WeightedRandomSelect) Choose() WrsItem {
    93  	for {
    94  		if w.root.sumCost == 0 {
    95  			return nil
    96  		}
    97  		val := uint64(rand.Int63n(int64(w.root.sumCost)))
    98  		choice, lastWeight := w.root.choose(val)
    99  		weight := w.wfn(choice)
   100  		if weight != lastWeight {
   101  			w.setWeight(choice, weight)
   102  		}
   103  		if weight >= lastWeight || uint64(rand.Int63n(int64(lastWeight))) < weight {
   104  			return choice
   105  		}
   106  	}
   107  }
   108  
   109  const wrsBranches = 8 // max number of branches in the wrsNode tree
   110  
   111  // wrsNode is a node of a tree structure that can store WrsItems or further wrsNodes.
   112  type wrsNode struct {
   113  	items                    [wrsBranches]interface{}
   114  	weights                  [wrsBranches]uint64
   115  	sumCost                  uint64
   116  	level, itemCnt, maxItems int
   117  }
   118  
   119  // insert recursively inserts a new item to the tree and returns the item index
   120  func (n *wrsNode) insert(item WrsItem, weight uint64) int {
   121  	branch := 0
   122  	for n.items[branch] != nil && (n.level == 0 || n.items[branch].(*wrsNode).itemCnt == n.items[branch].(*wrsNode).maxItems) {
   123  		branch++
   124  		if branch == wrsBranches {
   125  			panic(nil)
   126  		}
   127  	}
   128  	n.itemCnt++
   129  	n.sumCost += weight
   130  	n.weights[branch] += weight
   131  	if n.level == 0 {
   132  		n.items[branch] = item
   133  		return branch
   134  	}
   135  	var subNode *wrsNode
   136  	if n.items[branch] == nil {
   137  		subNode = &wrsNode{maxItems: n.maxItems / wrsBranches, level: n.level - 1}
   138  		n.items[branch] = subNode
   139  	} else {
   140  		subNode = n.items[branch].(*wrsNode)
   141  	}
   142  	subIdx := subNode.insert(item, weight)
   143  	return subNode.maxItems*branch + subIdx
   144  }
   145  
   146  // setWeight updates the weight of a certain item (which should exist) and returns
   147  // the change of the last weight value stored in the tree
   148  func (n *wrsNode) setWeight(idx int, weight uint64) uint64 {
   149  	if n.level == 0 {
   150  		oldWeight := n.weights[idx]
   151  		n.weights[idx] = weight
   152  		diff := weight - oldWeight
   153  		n.sumCost += diff
   154  		if weight == 0 {
   155  			n.items[idx] = nil
   156  			n.itemCnt--
   157  		}
   158  		return diff
   159  	}
   160  	branchItems := n.maxItems / wrsBranches
   161  	branch := idx / branchItems
   162  	diff := n.items[branch].(*wrsNode).setWeight(idx-branch*branchItems, weight)
   163  	n.weights[branch] += diff
   164  	n.sumCost += diff
   165  	if weight == 0 {
   166  		n.itemCnt--
   167  	}
   168  	return diff
   169  }
   170  
   171  // choose recursively selects an item from the tree and returns it along with its weight
   172  func (n *wrsNode) choose(val uint64) (WrsItem, uint64) {
   173  	for i, w := range n.weights {
   174  		if val < w {
   175  			if n.level == 0 {
   176  				return n.items[i].(WrsItem), n.weights[i]
   177  			}
   178  			return n.items[i].(*wrsNode).choose(val)
   179  		}
   180  		val -= w
   181  	}
   182  	panic(nil)
   183  }