github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_heap.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package sampler 5 6 import ( 7 "cmp" 8 9 "github.com/MetalBlockchain/metalgo/utils" 10 "github.com/MetalBlockchain/metalgo/utils/math" 11 ) 12 13 var ( 14 _ Weighted = (*weightedHeap)(nil) 15 _ utils.Sortable[weightedHeapElement] = weightedHeapElement{} 16 ) 17 18 type weightedHeapElement struct { 19 weight uint64 20 cumulativeWeight uint64 21 index int 22 } 23 24 // Compare the elements. Weight is in decreasing order. Index is in increasing 25 // order. 26 func (e weightedHeapElement) Compare(other weightedHeapElement) int { 27 // By accounting for the initial index of the weights, this results in a 28 // stable sort. We do this rather than using `sort.Stable` because of the 29 // reported change in performance of the sort used. 30 if weightCmp := cmp.Compare(other.weight, e.weight); weightCmp != 0 { 31 return weightCmp 32 } 33 return cmp.Compare(e.index, other.index) 34 } 35 36 // Sampling is performed by executing a search over a tree of elements in the 37 // order of their probabilistic occurrence. 38 // 39 // Initialization takes O(n * log(n)) time, where n is the number of elements 40 // that can be sampled. 41 // Sampling can take up to O(log(n)) time. As the distribution becomes more 42 // biased, sampling will become faster in expectation. 43 type weightedHeap struct { 44 heap []weightedHeapElement 45 } 46 47 func (s *weightedHeap) Initialize(weights []uint64) error { 48 numWeights := len(weights) 49 if numWeights <= cap(s.heap) { 50 s.heap = s.heap[:numWeights] 51 } else { 52 s.heap = make([]weightedHeapElement, numWeights) 53 } 54 for i, weight := range weights { 55 s.heap[i] = weightedHeapElement{ 56 weight: weight, 57 cumulativeWeight: weight, 58 index: i, 59 } 60 } 61 62 // Optimize so that the most probable values are at the top of the heap 63 utils.Sort(s.heap) 64 65 // Initialize the heap 66 for i := len(s.heap) - 1; i > 0; i-- { 67 // Explicitly performing a shift here allows the compiler to avoid 68 // checking for negative numbers, which saves a couple cycles 69 parentIndex := (i - 1) >> 1 70 newWeight, err := math.Add64( 71 s.heap[parentIndex].cumulativeWeight, 72 s.heap[i].cumulativeWeight, 73 ) 74 if err != nil { 75 return err 76 } 77 s.heap[parentIndex].cumulativeWeight = newWeight 78 } 79 80 return nil 81 } 82 83 func (s *weightedHeap) Sample(value uint64) (int, bool) { 84 if len(s.heap) == 0 || s.heap[0].cumulativeWeight <= value { 85 return 0, false 86 } 87 88 index := 0 89 for { 90 currentElement := s.heap[index] 91 currentWeight := currentElement.weight 92 if value < currentWeight { 93 return currentElement.index, true 94 } 95 value -= currentWeight 96 97 // We shouldn't return the root, so check the left child 98 index = index*2 + 1 99 100 if leftWeight := s.heap[index].cumulativeWeight; leftWeight <= value { 101 // If the weight is greater than the left weight, you should move to 102 // the right child 103 value -= leftWeight 104 index++ 105 } 106 } 107 }