gonum.org/v1/gonum@v0.14.0/stat/distuv/statdist_test.go (about) 1 // Copyright ©2018 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distuv 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/floats" 14 "gonum.org/v1/gonum/floats/scalar" 15 ) 16 17 func TestBhattacharyyaBeta(t *testing.T) { 18 t.Parallel() 19 rnd := rand.New(rand.NewSource(1)) 20 for cas, test := range []struct { 21 a, b Beta 22 samples int 23 tol float64 24 }{ 25 { 26 a: Beta{Alpha: 1, Beta: 2, Src: rnd}, 27 b: Beta{Alpha: 1, Beta: 4, Src: rnd}, 28 samples: 100000, 29 tol: 1e-2, 30 }, 31 { 32 a: Beta{Alpha: 0.5, Beta: 0.4, Src: rnd}, 33 b: Beta{Alpha: 0.7, Beta: 0.2, Src: rnd}, 34 samples: 100000, 35 tol: 1e-2, 36 }, 37 { 38 a: Beta{Alpha: 3, Beta: 5, Src: rnd}, 39 b: Beta{Alpha: 5, Beta: 3, Src: rnd}, 40 samples: 100000, 41 tol: 1e-2, 42 }, 43 } { 44 want := bhattacharyyaSample(test.samples, test.a, test.b) 45 got := Bhattacharyya{}.DistBeta(test.a, test.b) 46 if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { 47 t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, want) 48 } 49 50 // Bhattacharyya should be symmetric 51 got2 := Bhattacharyya{}.DistBeta(test.b, test.a) 52 if math.Abs(got-got2) > 1e-14 { 53 t.Errorf("Bhattacharyya distance not symmetric") 54 } 55 } 56 } 57 58 func TestBhattacharyyaNormal(t *testing.T) { 59 t.Parallel() 60 rnd := rand.New(rand.NewSource(1)) 61 for cas, test := range []struct { 62 a, b Normal 63 samples int 64 tol float64 65 }{ 66 { 67 a: Normal{Mu: 1, Sigma: 2, Src: rnd}, 68 b: Normal{Mu: 1, Sigma: 4, Src: rnd}, 69 samples: 100000, 70 tol: 1e-2, 71 }, 72 { 73 a: Normal{Mu: 0, Sigma: 2, Src: rnd}, 74 b: Normal{Mu: 2, Sigma: 2, Src: rnd}, 75 samples: 100000, 76 tol: 1e-2, 77 }, 78 { 79 a: Normal{Mu: 0, Sigma: 5, Src: rnd}, 80 b: Normal{Mu: 2, Sigma: 0.1, Src: rnd}, 81 samples: 200000, 82 tol: 1e-2, 83 }, 84 } { 85 want := bhattacharyyaSample(test.samples, test.a, test.b) 86 got := Bhattacharyya{}.DistNormal(test.a, test.b) 87 if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { 88 t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, want) 89 } 90 91 // Bhattacharyya should be symmetric 92 got2 := Bhattacharyya{}.DistNormal(test.b, test.a) 93 if math.Abs(got-got2) > 1e-14 { 94 t.Errorf("Bhattacharyya distance not symmetric") 95 } 96 } 97 } 98 99 // bhattacharyyaSample finds an estimate of the Bhattacharyya coefficient through 100 // sampling. 101 func bhattacharyyaSample(samples int, l RandLogProber, r LogProber) float64 { 102 lBhatt := make([]float64, samples) 103 for i := 0; i < samples; i++ { 104 // Do importance sampling over a: \int sqrt(a*b)/a * a dx 105 x := l.Rand() 106 pa := l.LogProb(x) 107 pb := r.LogProb(x) 108 lBhatt[i] = 0.5*pb - 0.5*pa 109 } 110 logBc := floats.LogSumExp(lBhatt) - math.Log(float64(samples)) 111 return -logBc 112 } 113 114 func TestKullbackLeiblerBeta(t *testing.T) { 115 t.Parallel() 116 rnd := rand.New(rand.NewSource(1)) 117 for cas, test := range []struct { 118 a, b Beta 119 samples int 120 tol float64 121 }{ 122 { 123 a: Beta{Alpha: 1, Beta: 2, Src: rnd}, 124 b: Beta{Alpha: 1, Beta: 4, Src: rnd}, 125 samples: 100000, 126 tol: 1e-2, 127 }, 128 { 129 a: Beta{Alpha: 0.5, Beta: 0.4, Src: rnd}, 130 b: Beta{Alpha: 0.7, Beta: 0.2, Src: rnd}, 131 samples: 100000, 132 tol: 1e-2, 133 }, 134 { 135 a: Beta{Alpha: 3, Beta: 5, Src: rnd}, 136 b: Beta{Alpha: 5, Beta: 3, Src: rnd}, 137 samples: 100000, 138 tol: 1e-2, 139 }, 140 } { 141 a, b := test.a, test.b 142 want := klSample(test.samples, a, b) 143 got := KullbackLeibler{}.DistBeta(a, b) 144 if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { 145 t.Errorf("Kullback-Leibler mismatch, case %d: got %v, want %v", cas, got, want) 146 } 147 } 148 good := Beta{0.5, 0.5, nil} 149 bad := Beta{0, 1, nil} 150 if !panics(func() { KullbackLeibler{}.DistBeta(bad, good) }) { 151 t.Errorf("Expected Kullback-Leibler to panic when called with invalid left Beta distribution") 152 } 153 if !panics(func() { KullbackLeibler{}.DistBeta(good, bad) }) { 154 t.Errorf("Expected Kullback-Leibler to panic when called with invalid right Beta distribution") 155 } 156 bad = Beta{1, 0, nil} 157 if !panics(func() { KullbackLeibler{}.DistBeta(bad, good) }) { 158 t.Errorf("Expected Kullback-Leibler to panic when called with invalid left Beta distribution") 159 } 160 if !panics(func() { KullbackLeibler{}.DistBeta(good, bad) }) { 161 t.Errorf("Expected Kullback-Leibler to panic when called with invalid right Beta distribution") 162 } 163 } 164 165 func TestKullbackLeiblerNormal(t *testing.T) { 166 t.Parallel() 167 rnd := rand.New(rand.NewSource(1)) 168 for cas, test := range []struct { 169 a, b Normal 170 samples int 171 tol float64 172 }{ 173 { 174 a: Normal{Mu: 1, Sigma: 2, Src: rnd}, 175 b: Normal{Mu: 1, Sigma: 4, Src: rnd}, 176 samples: 100000, 177 tol: 1e-2, 178 }, 179 { 180 a: Normal{Mu: 0, Sigma: 2, Src: rnd}, 181 b: Normal{Mu: 2, Sigma: 2, Src: rnd}, 182 samples: 100000, 183 tol: 1e-2, 184 }, 185 { 186 a: Normal{Mu: 0, Sigma: 5, Src: rnd}, 187 b: Normal{Mu: 2, Sigma: 0.1, Src: rnd}, 188 samples: 100000, 189 tol: 1e-2, 190 }, 191 } { 192 a, b := test.a, test.b 193 want := klSample(test.samples, a, b) 194 got := KullbackLeibler{}.DistNormal(a, b) 195 if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { 196 t.Errorf("Kullback-Leibler mismatch, case %d: got %v, want %v", cas, got, want) 197 } 198 } 199 } 200 201 // klSample finds an estimate of the Kullback-Leibler divergence through sampling. 202 func klSample(samples int, l RandLogProber, r LogProber) float64 { 203 var klmc float64 204 for i := 0; i < samples; i++ { 205 x := l.Rand() 206 pa := l.LogProb(x) 207 pb := r.LogProb(x) 208 klmc += pa - pb 209 } 210 return klmc / float64(samples) 211 } 212 213 func TestHellingerBeta(t *testing.T) { 214 t.Parallel() 215 rnd := rand.New(rand.NewSource(1)) 216 const tol = 1e-15 217 for cas, test := range []struct { 218 a, b Beta 219 }{ 220 { 221 a: Beta{Alpha: 1, Beta: 2, Src: rnd}, 222 b: Beta{Alpha: 1, Beta: 4, Src: rnd}, 223 }, 224 { 225 a: Beta{Alpha: 0.5, Beta: 0.4, Src: rnd}, 226 b: Beta{Alpha: 0.7, Beta: 0.2, Src: rnd}, 227 }, 228 { 229 a: Beta{Alpha: 3, Beta: 5, Src: rnd}, 230 b: Beta{Alpha: 5, Beta: 3, Src: rnd}, 231 }, 232 } { 233 got := Hellinger{}.DistBeta(test.a, test.b) 234 want := math.Sqrt(1 - math.Exp(-Bhattacharyya{}.DistBeta(test.a, test.b))) 235 if !scalar.EqualWithinAbsOrRel(got, want, tol, tol) { 236 t.Errorf("Hellinger mismatch, case %d: got %v, want %v", cas, got, want) 237 } 238 } 239 } 240 241 func TestHellingerNormal(t *testing.T) { 242 t.Parallel() 243 rnd := rand.New(rand.NewSource(1)) 244 const tol = 1e-15 245 for cas, test := range []struct { 246 a, b Normal 247 }{ 248 { 249 a: Normal{Mu: 1, Sigma: 2, Src: rnd}, 250 b: Normal{Mu: 1, Sigma: 4, Src: rnd}, 251 }, 252 { 253 a: Normal{Mu: 0, Sigma: 2, Src: rnd}, 254 b: Normal{Mu: 2, Sigma: 2, Src: rnd}, 255 }, 256 { 257 a: Normal{Mu: 0, Sigma: 5, Src: rnd}, 258 b: Normal{Mu: 2, Sigma: 0.1, Src: rnd}, 259 }, 260 } { 261 got := Hellinger{}.DistNormal(test.a, test.b) 262 want := math.Sqrt(1 - math.Exp(-Bhattacharyya{}.DistNormal(test.a, test.b))) 263 if !scalar.EqualWithinAbsOrRel(got, want, tol, tol) { 264 t.Errorf("Hellinger mismatch, case %d: got %v, want %v", cas, got, want) 265 } 266 } 267 }