gonum.org/v1/gonum@v0.14.0/stat/distuv/categorical.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distuv 6 7 import ( 8 "math" 9 10 "golang.org/x/exp/rand" 11 ) 12 13 // Categorical is an extension of the Bernoulli distribution where x takes 14 // values {0, 1, ..., len(w)-1} where w is the weight vector. Categorical must 15 // be initialized with NewCategorical. 16 type Categorical struct { 17 weights []float64 18 19 // heap is a weight heap. 20 // 21 // It keeps a heap-organised sum of remaining 22 // index weights that are available to be taken 23 // from. 24 // 25 // Each element holds the sum of weights for 26 // the corresponding index, plus the sum of 27 // its children's weights; the children of 28 // an element i can be found at positions 29 // 2*(i+1)-1 and 2*(i+1). The root of the 30 // weight heap is at element 0. 31 // 32 // See comments in container/heap for an 33 // explanation of the layout of a heap. 34 heap []float64 35 36 src rand.Source 37 } 38 39 // NewCategorical constructs a new categorical distribution where the probability 40 // that x equals i is proportional to w[i]. All of the weights must be 41 // nonnegative, and at least one of the weights must be positive. 42 func NewCategorical(w []float64, src rand.Source) Categorical { 43 c := Categorical{ 44 weights: make([]float64, len(w)), 45 heap: make([]float64, len(w)), 46 src: src, 47 } 48 c.ReweightAll(w) 49 return c 50 } 51 52 // CDF computes the value of the cumulative density function at x. 53 func (c Categorical) CDF(x float64) float64 { 54 var cdf float64 55 for i, w := range c.weights { 56 if x < float64(i) { 57 break 58 } 59 cdf += w 60 } 61 return cdf / c.heap[0] 62 } 63 64 // Entropy returns the entropy of the distribution. 65 func (c Categorical) Entropy() float64 { 66 var ent float64 67 for _, w := range c.weights { 68 if w == 0 { 69 continue 70 } 71 p := w / c.heap[0] 72 ent += p * math.Log(p) 73 } 74 return -ent 75 } 76 77 // Len returns the number of values x could possibly take (the length of the 78 // initial supplied weight vector). 79 func (c Categorical) Len() int { 80 return len(c.weights) 81 } 82 83 // Mean returns the mean of the probability distribution. 84 func (c Categorical) Mean() float64 { 85 var mean float64 86 for i, v := range c.weights { 87 mean += float64(i) * v 88 } 89 return mean / c.heap[0] 90 } 91 92 // Prob computes the value of the probability density function at x. 93 func (c Categorical) Prob(x float64) float64 { 94 xi := int(x) 95 if float64(xi) != x { 96 return 0 97 } 98 if xi < 0 || xi > len(c.weights)-1 { 99 return 0 100 } 101 return c.weights[xi] / c.heap[0] 102 } 103 104 // LogProb computes the natural logarithm of the value of the probability density function at x. 105 func (c Categorical) LogProb(x float64) float64 { 106 return math.Log(c.Prob(x)) 107 } 108 109 // Rand returns a random draw from the categorical distribution. 110 func (c Categorical) Rand() float64 { 111 var r float64 112 if c.src == nil { 113 r = c.heap[0] * rand.Float64() 114 } else { 115 r = c.heap[0] * rand.New(c.src).Float64() 116 } 117 i := 1 118 last := -1 119 left := len(c.weights) 120 for { 121 if r -= c.weights[i-1]; r <= 0 { 122 break // Fall within item i-1. 123 } 124 i <<= 1 // Move to left child. 125 if d := c.heap[i-1]; r > d { 126 r -= d 127 // If enough r to pass left child, 128 // move to right child state will 129 // be caught at break above. 130 i++ 131 } 132 if i == last || left < 0 { 133 panic("categorical: bad sample") 134 } 135 last = i 136 left-- 137 } 138 return float64(i - 1) 139 } 140 141 // Reweight sets the weight of item idx to w. The input weight must be 142 // non-negative, and after reweighting at least one of the weights must be 143 // positive. 144 func (c Categorical) Reweight(idx int, w float64) { 145 if w < 0 { 146 panic("categorical: negative weight") 147 } 148 w, c.weights[idx] = c.weights[idx]-w, w 149 idx++ 150 for idx > 0 { 151 c.heap[idx-1] -= w 152 idx >>= 1 153 } 154 if c.heap[0] <= 0 { 155 panic("categorical: sum of the weights non-positive") 156 } 157 } 158 159 // ReweightAll resets the weights of the distribution. ReweightAll panics if 160 // len(w) != c.Len. All of the weights must be nonnegative, and at least one of 161 // the weights must be positive. 162 func (c Categorical) ReweightAll(w []float64) { 163 if len(w) != c.Len() { 164 panic("categorical: length of the slices do not match") 165 } 166 for _, v := range w { 167 if v < 0 { 168 panic("categorical: negative weight") 169 } 170 } 171 copy(c.weights, w) 172 c.reset() 173 } 174 175 func (c Categorical) reset() { 176 copy(c.heap, c.weights) 177 for i := len(c.heap) - 1; i > 0; i-- { 178 // Sometimes 1-based counting makes sense. 179 c.heap[((i+1)>>1)-1] += c.heap[i] 180 } 181 // TODO(btracey): Renormalization for weird weights? 182 if c.heap[0] <= 0 { 183 panic("categorical: sum of the weights non-positive") 184 } 185 }