github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_array.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package sampler 5 6 import ( 7 "cmp" 8 9 "github.com/MetalBlockchain/metalgo/utils" 10 "github.com/MetalBlockchain/metalgo/utils/math" 11 ) 12 13 var ( 14 _ Weighted = (*weightedArray)(nil) 15 _ utils.Sortable[weightedArrayElement] = weightedArrayElement{} 16 ) 17 18 type weightedArrayElement struct { 19 cumulativeWeight uint64 20 index int 21 } 22 23 // Note that this sorts in order of decreasing weight. 24 func (e weightedArrayElement) Compare(other weightedArrayElement) int { 25 return cmp.Compare(other.cumulativeWeight, e.cumulativeWeight) 26 } 27 28 // Sampling is performed by executing a modified binary search over the provided 29 // elements. Rather than cutting the remaining dataset in half, the algorithm 30 // attempt to just in to where it think the value will be assuming a linear 31 // distribution of the element weights. 32 // 33 // Initialization takes O(n * log(n)) time, where n is the number of elements 34 // that can be sampled. 35 // Sampling can take up to O(n) time. If the distribution is linearly 36 // distributed, then the runtime is constant. 37 type weightedArray struct { 38 arr []weightedArrayElement 39 } 40 41 func (s *weightedArray) Initialize(weights []uint64) error { 42 numWeights := len(weights) 43 if numWeights <= cap(s.arr) { 44 s.arr = s.arr[:numWeights] 45 } else { 46 s.arr = make([]weightedArrayElement, numWeights) 47 } 48 49 for i, weight := range weights { 50 s.arr[i] = weightedArrayElement{ 51 cumulativeWeight: weight, 52 index: i, 53 } 54 } 55 56 // Optimize so that the array is closer to the uniform distribution 57 utils.Sort(s.arr) 58 59 maxIndex := len(s.arr) - 1 60 oneIfOdd := 1 & maxIndex 61 oneIfEven := 1 - oneIfOdd 62 end := maxIndex - oneIfEven 63 for i := 1; i < end; i += 2 { 64 s.arr[i], s.arr[end] = s.arr[end], s.arr[i] 65 end -= 2 66 } 67 68 cumulativeWeight := uint64(0) 69 for i := 0; i < len(s.arr); i++ { 70 newWeight, err := math.Add64( 71 cumulativeWeight, 72 s.arr[i].cumulativeWeight, 73 ) 74 if err != nil { 75 return err 76 } 77 cumulativeWeight = newWeight 78 s.arr[i].cumulativeWeight = cumulativeWeight 79 } 80 81 return nil 82 } 83 84 func (s *weightedArray) Sample(value uint64) (int, bool) { 85 if len(s.arr) == 0 || s.arr[len(s.arr)-1].cumulativeWeight <= value { 86 return 0, false 87 } 88 minIndex := 0 89 maxIndex := len(s.arr) - 1 90 maxCumulativeWeight := float64(s.arr[len(s.arr)-1].cumulativeWeight) 91 index := int((float64(value) * float64(maxIndex+1)) / maxCumulativeWeight) 92 93 for { 94 previousWeight := uint64(0) 95 if index > 0 { 96 previousWeight = s.arr[index-1].cumulativeWeight 97 } 98 currentElem := s.arr[index] 99 currentWeight := currentElem.cumulativeWeight 100 if previousWeight <= value && value < currentWeight { 101 return currentElem.index, true 102 } 103 104 if value < previousWeight { 105 // go to the left 106 maxIndex = index - 1 107 } else { 108 // go to the right 109 minIndex = index + 1 110 } 111 112 minWeight := uint64(0) 113 if minIndex > 0 { 114 minWeight = s.arr[minIndex-1].cumulativeWeight 115 } 116 maxWeight := s.arr[maxIndex].cumulativeWeight 117 118 valueRange := maxWeight - minWeight 119 adjustedLookupValue := value - minWeight 120 indexRange := maxIndex - minIndex + 1 121 lookupMass := float64(adjustedLookupValue) * float64(indexRange) 122 123 index = int(lookupMass/float64(valueRange)) + minIndex 124 } 125 }