// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this code is governed by a BSD-style
// license that can be found in the LICENSE file

// Weighted provides sampling without replacement from a collection of items with
// non-uniform probability.
//
// All methods use value receivers, but the weights and the heap are held in
// slices, so copies of a Weighted share the same underlying state.
type Weighted struct {
	weights []float64
	// heap is a weight heap.
	//
	// It keeps a heap-organised sum of remaining
	// index weights that are available to be taken
	// from.
	//
	// Each element holds the sum of weights for
	// the corresponding index, plus the sum of
	// its children's weights; the children of
	// an element i can be found at positions
	// 2*(i+1)-1 and 2*(i+1). The root of the
	// weight heap is at element 0.
	//
	// See comments in container/heap for an
	// explanation of the layout of a heap.
	heap []float64
	rnd  *rand.Rand
}

// NewWeighted returns a Weighted for the weights w. The weights are copied,
// so later changes to w do not affect the returned Weighted. If src is nil,
// the package-level functions of rand are used as the random number
// generator.
//
// Note that sampling from weights with a high variance or overall low absolute
// value sum may result in problems with numerical stability.
func NewWeighted(w []float64, src rand.Source) Weighted {
	s := Weighted{
		weights: make([]float64, len(w)),
		heap:    make([]float64, len(w)),
	}
	if src != nil {
		s.rnd = rand.New(src)
	}
	s.ReweightAll(w)
	return s
}

// Len returns the number of items held by the Weighted, including items
// already taken.
func (s Weighted) Len() int { return len(s.weights) }

// Take returns an index from the Weighted with probability proportional
// to the weight of the item. The weight of the item is then set to zero.
// Take returns -1 and false if there are no items remaining, which includes
// the case of a Weighted that holds no items at all.
func (s Weighted) Take() (idx int, ok bool) {
	// A Weighted with no items (built from an empty weight slice, or the
	// zero value) has no root element, so check the length before
	// reading the total remaining weight from the root.
	if len(s.heap) == 0 || s.heap[0] == 0 {
		return -1, false
	}

	var r float64
	if s.rnd == nil {
		r = rand.Float64()
	} else {
		r = s.rnd.Float64()
	}

	// Scale the uniform variate to the total remaining weight, then
	// walk down the heap consuming r until it falls within an item.
	r *= s.heap[0]
	i := 0
	for {
		r -= s.weights[i]
		if r < 0 {
			break // Fall within item i.
		}

		li := i*2 + 1 // Move to left child.
		// Left node should exist, because r is non-negative,
		// but there could be floating point errors, so we
		// check index explicitly.
		if li >= len(s.heap) {
			break
		}

		i = li

		d := s.heap[i]
		if r >= d {
			// If there is enough r to pass left child try to
			// move to the right child.
			r -= d
			ri := i + 1

			if ri >= len(s.heap) {
				break
			}

			i = ri
		}
	}

	// Zero the taken item's weight so that it cannot be taken again.
	s.Reweight(i, 0)

	return i, true
}

// Reweight sets the weight of item idx to w. Reweight panics if idx is
// out of range.
func (s Weighted) Reweight(idx int, w float64) {
	s.weights[idx] = w

	// We want to keep the heap state here consistent
	// with the result of a reset call. So we sum
	// weights in the same order, since floating point
	// addition is not associative.
	for {
		w = s.weights[idx]

		ri := idx*2 + 2
		if ri < len(s.heap) {
			w += s.heap[ri]
		}

		li := ri - 1
		if li < len(s.heap) {
			w += s.heap[li]
		}

		s.heap[idx] = w

		if idx == 0 {
			break
		}

		// Continue with the parent until the root has been updated.
		idx = (idx - 1) / 2
	}
}

// ReweightAll sets the weight of all items in the Weighted. ReweightAll
// panics if len(w) != s.Len.
func (s Weighted) ReweightAll(w []float64) {
	if len(w) != s.Len() {
		panic("floats: length of the slices do not match")
	}
	copy(s.weights, w)
	s.reset()
}

// reset rebuilds the weight heap from the current item weights.
func (s Weighted) reset() {
	copy(s.heap, s.weights)
	// Accumulate from the last element towards the root so that each
	// element's subtree sum is complete before it is added to its parent.
	for i := len(s.heap) - 1; i > 0; i-- {
		// Sometimes 1-based counting makes sense.
		s.heap[((i+1)>>1)-1] += s.heap[i]
	}
}