gonum.org/v1/gonum@v0.14.0/stat/distuv/categorical_test.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distuv 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/floats" 14 "gonum.org/v1/gonum/floats/scalar" 15 ) 16 17 const ( 18 Tiny = 2 19 Small = 5 20 Medium = 10 21 Large = 100 22 Huge = 1000 23 ) 24 25 func TestCategoricalProb(t *testing.T) { 26 t.Parallel() 27 for _, test := range [][]float64{ 28 {1, 2, 3, 0}, 29 } { 30 dist := NewCategorical(test, nil) 31 norm := make([]float64, len(test)) 32 floats.Scale(1/floats.Sum(norm), norm) 33 for i, v := range norm { 34 p := dist.Prob(float64(i)) 35 if math.Abs(p-v) > 1e-14 { 36 t.Errorf("Probability mismatch element %d", i) 37 } 38 logP := dist.LogProb(float64(i)) 39 if math.Abs(logP-math.Log(v)) > 1e-14 { 40 t.Errorf("Log-probability mismatch element %d", i) 41 } 42 p = dist.Prob(float64(i) + 0.5) 43 if p != 0 { 44 t.Errorf("Non-zero probability for non-integer x") 45 } 46 logP = dist.LogProb(float64(i) + 0.5) 47 if !math.IsInf(logP, -1) { 48 t.Errorf("Log-probability for non-integer x is not -Inf") 49 } 50 } 51 p := dist.Prob(-1) 52 if p != 0 { 53 t.Errorf("Non-zero probability for -1") 54 } 55 logP := dist.LogProb(-1) 56 if !math.IsInf(logP, -1) { 57 t.Errorf("Log-probability for -1 is not -Inf") 58 } 59 p = dist.Prob(float64(len(test))) 60 if p != 0 { 61 t.Errorf("Non-zero probability for len(test)") 62 } 63 logP = dist.LogProb(float64(len(test))) 64 if !math.IsInf(logP, -1) { 65 t.Errorf("Log-probability for len(test) is not -Inf") 66 } 67 } 68 } 69 70 func TestCategoricalRand(t *testing.T) { 71 t.Parallel() 72 for _, test := range [][]float64{ 73 {1, 2, 3, 0}, 74 } { 75 dist := NewCategorical(test, nil) 76 nSamples := 2000000 77 counts := sampleCategorical(t, dist, nSamples) 78 79 probs := make([]float64, len(test)) 80 for i := range probs { 81 probs[i] = dist.Prob(float64(i)) 82 } 83 same := samedDistCategorical(dist, counts, probs, 1e-2) 84 if !same { 85 t.Errorf("Probability mismatch. Want %v, got %v", probs, counts) 86 } 87 88 dist.Reweight(len(test)-1, 10) 89 counts = sampleCategorical(t, dist, nSamples) 90 probs = make([]float64, len(test)) 91 for i := range probs { 92 probs[i] = dist.Prob(float64(i)) 93 } 94 same = samedDistCategorical(dist, counts, probs, 1e-2) 95 if !same { 96 t.Errorf("Probability mismatch after Reweight. Want %v, got %v", probs, counts) 97 } 98 99 w := make([]float64, len(test)) 100 for i := range w { 101 w[i] = rand.Float64() 102 } 103 104 dist.ReweightAll(w) 105 counts = sampleCategorical(t, dist, nSamples) 106 probs = make([]float64, len(test)) 107 for i := range probs { 108 probs[i] = dist.Prob(float64(i)) 109 } 110 same = samedDistCategorical(dist, counts, probs, 1e-2) 111 if !same { 112 t.Errorf("Probability mismatch after ReweightAll. Want %v, got %v", probs, counts) 113 } 114 } 115 } 116 117 func TestCategoricalReweight(t *testing.T) { 118 t.Parallel() 119 dist := NewCategorical([]float64{1, 1}, nil) 120 if !panics(func() { dist.Reweight(0, -1) }) { 121 t.Errorf("Reweight did not panic for negative weight") 122 } 123 dist.Reweight(0, 0) 124 if !panics(func() { dist.Reweight(1, 0) }) { 125 t.Errorf("Reweight did not panic when trying to set the last positive weight to zero") 126 } 127 } 128 129 func TestCategoricalReweightAll(t *testing.T) { 130 t.Parallel() 131 w := []float64{0, 1, 2, 1} 132 dist := NewCategorical(w, nil) 133 if !panics(func() { dist.ReweightAll([]float64{1, 1}) }) { 134 t.Errorf("ReweightAll did not panic for different number of weights") 135 } 136 w[0] = -1 137 if !panics(func() { dist.ReweightAll(w) }) { 138 t.Errorf("ReweightAll did not panic for a negative weight") 139 } 140 w = []float64{0, 0, 0, 0} 141 if !panics(func() { dist.ReweightAll(w) }) { 142 t.Errorf("ReweightAll did not panic for weights which are all zero") 143 } 144 } 145 146 func sampleCategorical(t *testing.T, dist Categorical, nSamples int) []float64 { 147 counts := make([]float64, dist.Len()) 148 for i := 0; i < nSamples; i++ { 149 v := dist.Rand() 150 if float64(int(v)) != v { 151 t.Fatalf("Random number is not an integer") 152 } 153 counts[int(v)]++ 154 } 155 sum := floats.Sum(counts) 156 floats.Scale(1/sum, counts) 157 return counts 158 } 159 160 func samedDistCategorical(dist Categorical, counts, probs []float64, tol float64) bool { 161 same := true 162 for i, prob := range probs { 163 if prob == 0 && counts[i] != 0 { 164 same = false 165 break 166 } 167 if !scalar.EqualWithinAbsOrRel(prob, counts[i], tol, tol) { 168 same = false 169 break 170 } 171 } 172 return same 173 } 174 175 func TestCategoricalCDF(t *testing.T) { 176 t.Parallel() 177 for _, test := range [][]float64{ 178 {1, 2, 3, 0, 4}, 179 } { 180 c := make([]float64, len(test)) 181 copy(c, test) 182 floats.Scale(1/floats.Sum(c), c) 183 sum := make([]float64, len(test)) 184 floats.CumSum(sum, c) 185 186 dist := NewCategorical(test, nil) 187 cdf := dist.CDF(-0.5) 188 if cdf != 0 { 189 t.Errorf("CDF of negative number not zero") 190 } 191 for i := range c { 192 cdf := dist.CDF(float64(i)) 193 if math.Abs(cdf-sum[i]) > 1e-14 { 194 t.Errorf("CDF mismatch %v. Want %v, got %v.", float64(i), sum[i], cdf) 195 } 196 cdfp := dist.CDF(float64(i) + 0.5) 197 if cdfp != cdf { 198 t.Errorf("CDF mismatch for non-integer input") 199 } 200 } 201 } 202 } 203 204 func TestCategoricalEntropy(t *testing.T) { 205 t.Parallel() 206 for _, test := range []struct { 207 weights []float64 208 entropy float64 209 }{ 210 { 211 weights: []float64{1, 1}, 212 entropy: math.Ln2, 213 }, 214 { 215 weights: []float64{1, 1, 1, 1}, 216 entropy: math.Log(4), 217 }, 218 { 219 weights: []float64{0, 0, 1, 1, 0, 0}, 220 entropy: math.Ln2, 221 }, 222 } { 223 dist := NewCategorical(test.weights, nil) 224 entropy := dist.Entropy() 225 if math.IsNaN(entropy) || math.Abs(entropy-test.entropy) > 1e-14 { 226 t.Errorf("Entropy mismatch. Want %v, got %v.", test.entropy, entropy) 227 } 228 } 229 } 230 231 func TestCategoricalMean(t *testing.T) { 232 t.Parallel() 233 for _, test := range []struct { 234 weights []float64 235 mean float64 236 }{ 237 { 238 weights: []float64{10, 0, 0, 0}, 239 mean: 0, 240 }, 241 { 242 weights: []float64{0, 10, 0, 0}, 243 mean: 1, 244 }, 245 { 246 weights: []float64{1, 2, 3, 4}, 247 mean: 2, 248 }, 249 } { 250 dist := NewCategorical(test.weights, nil) 251 mean := dist.Mean() 252 if math.IsNaN(mean) || math.Abs(mean-test.mean) > 1e-14 { 253 t.Errorf("Entropy mismatch. Want %v, got %v.", test.mean, mean) 254 } 255 } 256 } 257 258 func BenchmarkCategoricalRandTiny(b *testing.B) { benchmarkCategoricalRand(b, Tiny) } 259 func BenchmarkCategoricalRandSmall(b *testing.B) { benchmarkCategoricalRand(b, Small) } 260 func BenchmarkCategoricalRandMedium(b *testing.B) { benchmarkCategoricalRand(b, Medium) } 261 func BenchmarkCategoricalRandLarge(b *testing.B) { benchmarkCategoricalRand(b, Large) } 262 func BenchmarkCategoricalRandHuge(b *testing.B) { benchmarkCategoricalRand(b, Huge) } 263 264 func benchmarkCategoricalRand(b *testing.B, size int) { 265 src := rand.NewSource(1) 266 rng := rand.New(src) 267 weights := make([]float64, size) 268 for i := 0; i < size; i++ { 269 weights[i] = rng.Float64() + 0.001 270 } 271 dist := NewCategorical(weights, src) 272 for i := 0; i < b.N; i++ { 273 dist.Rand() 274 } 275 }