gonum.org/v1/gonum@v0.14.0/stat/distuv/general_test.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distuv 6 7 import ( 8 "fmt" 9 "math" 10 "testing" 11 12 "gonum.org/v1/gonum/diff/fd" 13 "gonum.org/v1/gonum/floats" 14 ) 15 16 type univariateProbPoint struct { 17 loc float64 18 logProb float64 19 cumProb float64 20 prob float64 21 } 22 23 type UniProbDist interface { 24 Prob(float64) float64 25 CDF(float64) float64 26 LogProb(float64) float64 27 Quantile(float64) float64 28 Survival(float64) float64 29 } 30 31 func absEq(a, b float64) bool { 32 return absEqTol(a, b, 1e-14) 33 } 34 35 func absEqTol(a, b, tol float64) bool { 36 if math.IsNaN(a) || math.IsNaN(b) { 37 // NaN is not equal to anything. 38 return false 39 } 40 // This is expressed as the inverse to catch the 41 // case a = Inf and b = Inf of the same sign. 42 return !(math.Abs(a-b) > tol) 43 } 44 45 // TODO: Implement a better test for Quantile 46 func testDistributionProbs(t *testing.T, dist UniProbDist, name string, pts []univariateProbPoint) { 47 for _, pt := range pts { 48 logProb := dist.LogProb(pt.loc) 49 if !absEq(logProb, pt.logProb) { 50 t.Errorf("Log probability doesnt match for "+name+" at %v. Expected %v. Found %v", pt.loc, pt.logProb, logProb) 51 } 52 prob := dist.Prob(pt.loc) 53 if !absEq(prob, pt.prob) { 54 t.Errorf("Probability doesn't match for "+name+" at %v. Expected %v. Found %v", pt.loc, pt.prob, prob) 55 } 56 cumProb := dist.CDF(pt.loc) 57 if !absEq(cumProb, pt.cumProb) { 58 t.Errorf("Cumulative Probability doesn't match for "+name+". Expected %v. Found %v", pt.cumProb, cumProb) 59 } 60 if !absEq(dist.Survival(pt.loc), 1-pt.cumProb) { 61 t.Errorf("Survival doesn't match for %v. Expected %v, Found %v", name, 1-pt.cumProb, dist.Survival(pt.loc)) 62 } 63 if pt.prob != 0 { 64 if math.Abs(dist.Quantile(pt.cumProb)-pt.loc) > 1e-4 { 65 fmt.Println("true =", pt.loc) 66 fmt.Println("calculated=", dist.Quantile(pt.cumProb)) 67 t.Errorf("Quantile doesn't match for "+name+", loc = %v", pt.loc) 68 } 69 } 70 } 71 } 72 73 type ConjugateUpdater interface { 74 NumParameters() int 75 parameters([]Parameter) []Parameter 76 77 NumSuffStat() int 78 SuffStat([]float64, []float64, []float64) float64 79 ConjugateUpdate([]float64, float64, []float64) 80 81 Rand() float64 82 } 83 84 func testConjugateUpdate(t *testing.T, newFittable func() ConjugateUpdater) { 85 for i, test := range []struct { 86 samps []float64 87 weights []float64 88 }{ 89 { 90 samps: randn(newFittable(), 10), 91 weights: nil, 92 }, 93 { 94 samps: randn(newFittable(), 10), 95 weights: ones(10), 96 }, 97 { 98 samps: randn(newFittable(), 10), 99 weights: randn(&Exponential{Rate: 1}, 10), 100 }, 101 } { 102 // ensure that conjugate produces the same result both incrementally and all at once 103 incDist := newFittable() 104 stats := make([]float64, incDist.NumSuffStat()) 105 prior := make([]float64, incDist.NumParameters()) 106 for j := range test.samps { 107 var incWeights, allWeights []float64 108 if test.weights != nil { 109 incWeights = test.weights[j : j+1] 110 allWeights = test.weights[0 : j+1] 111 } 112 nsInc := incDist.SuffStat(stats, test.samps[j:j+1], incWeights) 113 incDist.ConjugateUpdate(stats, nsInc, prior) 114 115 allDist := newFittable() 116 nsAll := allDist.SuffStat(stats, test.samps[0:j+1], allWeights) 117 allDist.ConjugateUpdate(stats, nsAll, make([]float64, allDist.NumParameters())) 118 if !parametersEqual(incDist.parameters(nil), allDist.parameters(nil), 1e-12) { 119 t.Errorf("prior doesn't match after incremental update for (%d, %d). Incremental is %v, all at once is %v", i, j, incDist, allDist) 120 } 121 122 if test.weights == nil { 123 onesDist := newFittable() 124 nsOnes := onesDist.SuffStat(stats, test.samps[0:j+1], ones(j+1)) 125 onesDist.ConjugateUpdate(stats, nsOnes, make([]float64, onesDist.NumParameters())) 126 if !parametersEqual(onesDist.parameters(nil), incDist.parameters(nil), 1e-14) { 127 t.Errorf("nil and uniform weighted prior doesn't match for incremental update for (%d, %d). Uniform weighted is %v, nil is %v", i, j, onesDist, incDist) 128 } 129 if !parametersEqual(onesDist.parameters(nil), allDist.parameters(nil), 1e-14) { 130 t.Errorf("nil and uniform weighted prior doesn't match for all at once update for (%d, %d). Uniform weighted is %v, nil is %v", i, j, onesDist, incDist) 131 } 132 } 133 } 134 } 135 testSuffStatPanics(t, newFittable) 136 testConjugateUpdatePanics(t, newFittable) 137 } 138 139 func testSuffStatPanics(t *testing.T, newFittable func() ConjugateUpdater) { 140 dist := newFittable() 141 sample := randn(dist, 10) 142 if !panics(func() { dist.SuffStat(make([]float64, dist.NumSuffStat()), sample, make([]float64, len(sample)+1)) }) { 143 t.Errorf("Expected panic for mismatch between samples and weights lengths") 144 } 145 if !panics(func() { dist.SuffStat(make([]float64, dist.NumSuffStat()+1), sample, nil) }) { 146 t.Errorf("Expected panic for wrong sufficient statistic length") 147 } 148 } 149 150 func testConjugateUpdatePanics(t *testing.T, newFittable func() ConjugateUpdater) { 151 dist := newFittable() 152 if !panics(func() { 153 dist.ConjugateUpdate(make([]float64, dist.NumSuffStat()+1), 100, make([]float64, dist.NumParameters())) 154 }) { 155 t.Errorf("Expected panic for wrong sufficient statistic length") 156 } 157 if !panics(func() { 158 dist.ConjugateUpdate(make([]float64, dist.NumSuffStat()), 100, make([]float64, dist.NumParameters()+1)) 159 }) { 160 t.Errorf("Expected panic for wrong prior strength length") 161 } 162 } 163 164 // randn generates a specified number of random samples 165 func randn(dist Rander, n int) []float64 { 166 x := make([]float64, n) 167 for i := range x { 168 x[i] = dist.Rand() 169 } 170 return x 171 } 172 173 func ones(n int) []float64 { 174 x := make([]float64, n) 175 for i := range x { 176 x[i] = 1 177 } 178 return x 179 } 180 181 func parametersEqual(p1, p2 []Parameter, tol float64) bool { 182 for i, p := range p1 { 183 if p.Name != p2[i].Name { 184 return false 185 } 186 if math.Abs(p.Value-p2[i].Value) > tol { 187 return false 188 } 189 } 190 return true 191 } 192 193 type derivParamTester interface { 194 LogProb(x float64) float64 195 Score(deriv []float64, x float64) []float64 196 ScoreInput(x float64) float64 197 Quantile(p float64) float64 198 NumParameters() int 199 parameters([]Parameter) []Parameter 200 setParameters([]Parameter) 201 } 202 203 func testDerivParam(t *testing.T, d derivParamTester) { 204 // Tests that the derivative matches for a number of different quantiles 205 // along the distribution. 206 nTest := 10 207 quantiles := make([]float64, nTest) 208 floats.Span(quantiles, 0.1, 0.9) 209 210 scoreInPlace := make([]float64, d.NumParameters()) 211 fdDerivParam := make([]float64, d.NumParameters()) 212 213 if !panics(func() { d.Score(make([]float64, d.NumParameters()+1), 0) }) { 214 t.Errorf("Expected panic for wrong derivative slice length") 215 } 216 if !panics(func() { d.parameters(make([]Parameter, d.NumParameters()+1)) }) { 217 t.Errorf("Expected panic for wrong parameter slice length") 218 } 219 220 initParams := d.parameters(nil) 221 tooLongParams := make([]Parameter, len(initParams)+1) 222 copy(tooLongParams, initParams) 223 if !panics(func() { d.setParameters(tooLongParams) }) { 224 t.Errorf("Expected panic for wrong parameter slice length") 225 } 226 badNameParams := make([]Parameter, len(initParams)) 227 copy(badNameParams, initParams) 228 const badName = "__badName__" 229 for i := 0; i < len(initParams); i++ { 230 badNameParams[i].Name = badName 231 if !panics(func() { d.setParameters(badNameParams) }) { 232 t.Errorf("Expected panic for wrong %d-th parameter name", i) 233 } 234 badNameParams[i].Name = initParams[i].Name 235 } 236 237 init := make([]float64, d.NumParameters()) 238 for i, v := range initParams { 239 init[i] = v.Value 240 } 241 for _, v := range quantiles { 242 d.setParameters(initParams) 243 x := d.Quantile(v) 244 score := d.Score(scoreInPlace, x) 245 if &score[0] != &scoreInPlace[0] { 246 t.Errorf("Returned a different derivative slice than passed in. Got %v, want %v", score, scoreInPlace) 247 } 248 logProbParams := func(p []float64) float64 { 249 params := d.parameters(nil) 250 for i, v := range p { 251 params[i].Value = v 252 } 253 d.setParameters(params) 254 return d.LogProb(x) 255 } 256 fd.Gradient(fdDerivParam, logProbParams, init, nil) 257 if !floats.EqualApprox(scoreInPlace, fdDerivParam, 1e-6) { 258 t.Errorf("Score mismatch at x = %g. Want %v, got %v", x, fdDerivParam, scoreInPlace) 259 } 260 d.setParameters(initParams) 261 score2 := d.Score(nil, x) 262 if !floats.EqualApprox(score2, scoreInPlace, 1e-14) { 263 t.Errorf("Score mismatch when input nil Want %v, got %v", score2, scoreInPlace) 264 } 265 logProbInput := func(x2 float64) float64 { 266 return d.LogProb(x2) 267 } 268 scoreInput := d.ScoreInput(x) 269 fdDerivInput := fd.Derivative(logProbInput, x, nil) 270 if !absEqTol(scoreInput, fdDerivInput, 1e-6) { 271 t.Errorf("ScoreInput mismatch at x = %g. Want %v, got %v", x, fdDerivInput, scoreInput) 272 } 273 } 274 }