github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/sampleuv/weighted.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this code is governed by a BSD-style 3 // license that can be found in the LICENSE file 4 5 package sampleuv 6 7 import ( 8 "golang.org/x/exp/rand" 9 10 "github.com/jingcheng-WU/gonum/floats/scalar" 11 ) 12 13 // Weighted provides sampling without replacement from a collection of items with 14 // non-uniform probability. 15 type Weighted struct { 16 weights []float64 17 // heap is a weight heap. 18 // 19 // It keeps a heap-organised sum of remaining 20 // index weights that are available to be taken 21 // from. 22 // 23 // Each element holds the sum of weights for 24 // the corresponding index, plus the sum of 25 // its children's weights; the children of 26 // an element i can be found at positions 27 // 2*(i+1)-1 and 2*(i+1). The root of the 28 // weight heap is at element 0. 29 // 30 // See comments in container/heap for an 31 // explanation of the layout of a heap. 32 heap []float64 33 rnd *rand.Rand 34 } 35 36 // NewWeighted returns a Weighted for the weights w. If src is nil, rand.Rand is 37 // used as the random number generator. 38 // 39 // Note that sampling from weights with a high variance or overall low absolute 40 // value sum may result in problems with numerical stability. 41 func NewWeighted(w []float64, src rand.Source) Weighted { 42 s := Weighted{ 43 weights: make([]float64, len(w)), 44 heap: make([]float64, len(w)), 45 } 46 if src != nil { 47 s.rnd = rand.New(src) 48 } 49 s.ReweightAll(w) 50 return s 51 } 52 53 // Len returns the number of items held by the Weighted, including items 54 // already taken. 55 func (s Weighted) Len() int { return len(s.weights) } 56 57 // Take returns an index from the Weighted with probability proportional 58 // to the weight of the item. The weight of the item is then set to zero. 59 // Take returns false if there are no items remaining. 60 func (s Weighted) Take() (idx int, ok bool) { 61 const small = 1e-12 62 if scalar.EqualWithinAbsOrRel(s.heap[0], 0, small, small) { 63 return -1, false 64 } 65 66 var r float64 67 if s.rnd == nil { 68 r = s.heap[0] * rand.Float64() 69 } else { 70 r = s.heap[0] * s.rnd.Float64() 71 } 72 i := 1 73 last := -1 74 left := len(s.weights) 75 for { 76 if r -= s.weights[i-1]; r <= 0 { 77 break // Fall within item i-1. 78 } 79 i <<= 1 // Move to left child. 80 if d := s.heap[i-1]; r > d { 81 r -= d 82 // If enough r to pass left child 83 // move to right child state will 84 // be caught at break above. 85 i++ 86 } 87 if i == last || left < 0 { 88 // No progression. 89 return -1, false 90 } 91 last = i 92 left-- 93 } 94 95 w, idx := s.weights[i-1], i-1 96 97 s.weights[i-1] = 0 98 for i > 0 { 99 s.heap[i-1] -= w 100 // The following condition is necessary to 101 // handle floating point error. If we see 102 // a heap value below zero, we know we need 103 // to rebuild it. 104 if s.heap[i-1] < 0 { 105 s.reset() 106 return idx, true 107 } 108 i >>= 1 109 } 110 111 return idx, true 112 } 113 114 // Reweight sets the weight of item idx to w. 115 func (s Weighted) Reweight(idx int, w float64) { 116 w, s.weights[idx] = s.weights[idx]-w, w 117 idx++ 118 for idx > 0 { 119 s.heap[idx-1] -= w 120 idx >>= 1 121 } 122 } 123 124 // ReweightAll sets the weight of all items in the Weighted. ReweightAll 125 // panics if len(w) != s.Len. 126 func (s Weighted) ReweightAll(w []float64) { 127 if len(w) != s.Len() { 128 panic("floats: length of the slices do not match") 129 } 130 copy(s.weights, w) 131 s.reset() 132 } 133 134 func (s Weighted) reset() { 135 copy(s.heap, s.weights) 136 for i := len(s.heap) - 1; i > 0; i-- { 137 // Sometimes 1-based counting makes sense. 138 s.heap[((i+1)>>1)-1] += s.heap[i] 139 } 140 }