github.com/MetalBlockchain/metalgo@v1.11.9/utils/sampler/weighted_linear.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package sampler 5 6 import ( 7 "cmp" 8 9 "github.com/MetalBlockchain/metalgo/utils" 10 "github.com/MetalBlockchain/metalgo/utils/math" 11 ) 12 13 var ( 14 _ Weighted = (*weightedLinear)(nil) 15 _ utils.Sortable[weightedLinearElement] = weightedLinearElement{} 16 ) 17 18 type weightedLinearElement struct { 19 cumulativeWeight uint64 20 index int 21 } 22 23 // Note that this sorts in order of decreasing cumulative weight. 24 func (e weightedLinearElement) Compare(other weightedLinearElement) int { 25 return cmp.Compare(other.cumulativeWeight, e.cumulativeWeight) 26 } 27 28 // Sampling is performed by executing a linear search over the provided elements 29 // in the order of their probabilistic occurrence. 30 // 31 // Initialization takes O(n * log(n)) time, where n is the number of elements 32 // that can be sampled. 33 // Sampling can take up to O(n) time. As the distribution becomes more biased, 34 // sampling will become faster in expectation. 35 type weightedLinear struct { 36 arr []weightedLinearElement 37 } 38 39 func (s *weightedLinear) Initialize(weights []uint64) error { 40 numWeights := len(weights) 41 if numWeights <= cap(s.arr) { 42 s.arr = s.arr[:numWeights] 43 } else { 44 s.arr = make([]weightedLinearElement, numWeights) 45 } 46 47 for i, weight := range weights { 48 s.arr[i] = weightedLinearElement{ 49 cumulativeWeight: weight, 50 index: i, 51 } 52 } 53 54 // Optimize so that the most probable values are at the front of the array 55 utils.Sort(s.arr) 56 57 for i := 1; i < len(s.arr); i++ { 58 newWeight, err := math.Add64( 59 s.arr[i-1].cumulativeWeight, 60 s.arr[i].cumulativeWeight, 61 ) 62 if err != nil { 63 return err 64 } 65 s.arr[i].cumulativeWeight = newWeight 66 } 67 68 return nil 69 } 70 71 func (s *weightedLinear) Sample(value uint64) (int, bool) { 72 if len(s.arr) == 0 || s.arr[len(s.arr)-1].cumulativeWeight <= value { 73 return 0, false 74 } 75 76 index := 0 77 for { 78 if elem := s.arr[index]; value < elem.cumulativeWeight { 79 return elem.index, true 80 } 81 index++ 82 } 83 }