github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/distmv/statdist_test.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distmv 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "github.com/jingcheng-WU/gonum/floats" 14 "github.com/jingcheng-WU/gonum/floats/scalar" 15 "github.com/jingcheng-WU/gonum/mat" 16 "github.com/jingcheng-WU/gonum/spatial/r1" 17 ) 18 19 func TestBhattacharyyaNormal(t *testing.T) { 20 for cas, test := range []struct { 21 am, bm []float64 22 ac, bc *mat.SymDense 23 samples int 24 tol float64 25 }{ 26 { 27 am: []float64{2, 3}, 28 ac: mat.NewSymDense(2, []float64{3, -1, -1, 2}), 29 bm: []float64{-1, 1}, 30 bc: mat.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}), 31 samples: 100000, 32 tol: 3e-1, 33 }, 34 } { 35 rnd := rand.New(rand.NewSource(1)) 36 a, ok := NewNormal(test.am, test.ac, rnd) 37 if !ok { 38 panic("bad test") 39 } 40 b, ok := NewNormal(test.bm, test.bc, rnd) 41 if !ok { 42 panic("bad test") 43 } 44 want := bhattacharyyaSample(a.Dim(), test.samples, a, b) 45 got := Bhattacharyya{}.DistNormal(a, b) 46 if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { 47 t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, want) 48 } 49 50 // Bhattacharyya should by symmetric 51 got2 := Bhattacharyya{}.DistNormal(b, a) 52 if math.Abs(got-got2) > 1e-14 { 53 t.Errorf("Bhattacharyya distance not symmetric") 54 } 55 } 56 } 57 58 func TestBhattacharyyaUniform(t *testing.T) { 59 rnd := rand.New(rand.NewSource(1)) 60 for cas, test := range []struct { 61 a, b *Uniform 62 samples int 63 tol float64 64 }{ 65 { 66 a: NewUniform([]r1.Interval{{Min: -3, Max: 2}, {Min: -5, Max: 8}}, rnd), 67 b: NewUniform([]r1.Interval{{Min: -4, Max: 1}, {Min: -7, Max: 10}}, rnd), 68 samples: 100000, 69 tol: 1e-2, 70 }, 71 { 72 a: NewUniform([]r1.Interval{{Min: -3, Max: 2}, {Min: -5, Max: 8}}, rnd), 73 b: NewUniform([]r1.Interval{{Min: -5, Max: -4}, {Min: -7, Max: 10}}, rnd), 74 samples: 100000, 75 tol: 1e-2, 76 }, 77 } { 78 a, b := test.a, test.b 79 want := bhattacharyyaSample(a.Dim(), test.samples, a, b) 80 got := Bhattacharyya{}.DistUniform(a, b) 81 if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { 82 t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, want) 83 } 84 // Bhattacharyya should by symmetric 85 got2 := Bhattacharyya{}.DistUniform(b, a) 86 if math.Abs(got-got2) > 1e-14 { 87 t.Errorf("Bhattacharyya distance not symmetric") 88 } 89 } 90 } 91 92 // bhattacharyyaSample finds an estimate of the Bhattacharyya coefficient through 93 // sampling. 94 func bhattacharyyaSample(dim, samples int, l RandLogProber, r LogProber) float64 { 95 lBhatt := make([]float64, samples) 96 x := make([]float64, dim) 97 for i := 0; i < samples; i++ { 98 // Do importance sampling over a: \int sqrt(a*b)/a * a dx 99 l.Rand(x) 100 pa := l.LogProb(x) 101 pb := r.LogProb(x) 102 lBhatt[i] = 0.5*pb - 0.5*pa 103 } 104 logBc := floats.LogSumExp(lBhatt) - math.Log(float64(samples)) 105 return -logBc 106 } 107 108 func TestCrossEntropyNormal(t *testing.T) { 109 for cas, test := range []struct { 110 am, bm []float64 111 ac, bc *mat.SymDense 112 samples int 113 tol float64 114 }{ 115 { 116 am: []float64{2, 3}, 117 ac: mat.NewSymDense(2, []float64{3, -1, -1, 2}), 118 bm: []float64{-1, 1}, 119 bc: mat.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}), 120 samples: 100000, 121 tol: 1e-2, 122 }, 123 } { 124 rnd := rand.New(rand.NewSource(1)) 125 a, ok := NewNormal(test.am, test.ac, rnd) 126 if !ok { 127 panic("bad test") 128 } 129 b, ok := NewNormal(test.bm, test.bc, rnd) 130 if !ok { 131 panic("bad test") 132 } 133 var ce float64 134 x := make([]float64, a.Dim()) 135 for i := 0; i < test.samples; i++ { 136 a.Rand(x) 137 ce -= b.LogProb(x) 138 } 139 ce /= float64(test.samples) 140 got := CrossEntropy{}.DistNormal(a, b) 141 if !scalar.EqualWithinAbsOrRel(ce, got, test.tol, test.tol) { 142 t.Errorf("CrossEntropy mismatch, case %d: got %v, want %v", cas, got, ce) 143 } 144 } 145 } 146 147 func TestHellingerNormal(t *testing.T) { 148 for cas, test := range []struct { 149 am, bm []float64 150 ac, bc *mat.SymDense 151 samples int 152 tol float64 153 }{ 154 { 155 am: []float64{2, 3}, 156 ac: mat.NewSymDense(2, []float64{3, -1, -1, 2}), 157 bm: []float64{-1, 1}, 158 bc: mat.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}), 159 samples: 100000, 160 tol: 5e-1, 161 }, 162 } { 163 rnd := rand.New(rand.NewSource(1)) 164 a, ok := NewNormal(test.am, test.ac, rnd) 165 if !ok { 166 panic("bad test") 167 } 168 b, ok := NewNormal(test.bm, test.bc, rnd) 169 if !ok { 170 panic("bad test") 171 } 172 lAitchEDoubleHockeySticks := make([]float64, test.samples) 173 x := make([]float64, a.Dim()) 174 for i := 0; i < test.samples; i++ { 175 // Do importance sampling over a: \int (\sqrt(a)-\sqrt(b))^2/a * a dx 176 a.Rand(x) 177 pa := a.LogProb(x) 178 pb := b.LogProb(x) 179 d := math.Exp(0.5*pa) - math.Exp(0.5*pb) 180 d = d * d 181 lAitchEDoubleHockeySticks[i] = math.Log(d) - pa 182 } 183 want := math.Sqrt(0.5 * math.Exp(floats.LogSumExp(lAitchEDoubleHockeySticks)-math.Log(float64(test.samples)))) 184 got := Hellinger{}.DistNormal(a, b) 185 if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { 186 t.Errorf("Hellinger mismatch, case %d: got %v, want %v", cas, got, want) 187 } 188 } 189 } 190 191 func TestKullbackLeiblerDirichlet(t *testing.T) { 192 rnd := rand.New(rand.NewSource(1)) 193 for cas, test := range []struct { 194 a, b *Dirichlet 195 samples int 196 tol float64 197 }{ 198 { 199 a: NewDirichlet([]float64{2, 3, 4}, rnd), 200 b: NewDirichlet([]float64{4, 2, 1.1}, rnd), 201 samples: 100000, 202 tol: 1e-2, 203 }, 204 { 205 a: NewDirichlet([]float64{2, 3, 4, 0.1, 8}, rnd), 206 b: NewDirichlet([]float64{2, 2, 6, 0.5, 9}, rnd), 207 samples: 100000, 208 tol: 1e-2, 209 }, 210 } { 211 a, b := test.a, test.b 212 want := klSample(a.Dim(), test.samples, a, b) 213 got := KullbackLeibler{}.DistDirichlet(a, b) 214 if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { 215 t.Errorf("Kullback-Leibler mismatch, case %d: got %v, want %v", cas, got, want) 216 } 217 } 218 } 219 220 func TestKullbackLeiblerNormal(t *testing.T) { 221 for cas, test := range []struct { 222 am, bm []float64 223 ac, bc *mat.SymDense 224 samples int 225 tol float64 226 }{ 227 { 228 am: []float64{2, 3}, 229 ac: mat.NewSymDense(2, []float64{3, -1, -1, 2}), 230 bm: []float64{-1, 1}, 231 bc: mat.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}), 232 samples: 10000, 233 tol: 1e-2, 234 }, 235 } { 236 rnd := rand.New(rand.NewSource(1)) 237 a, ok := NewNormal(test.am, test.ac, rnd) 238 if !ok { 239 panic("bad test") 240 } 241 b, ok := NewNormal(test.bm, test.bc, rnd) 242 if !ok { 243 panic("bad test") 244 } 245 want := klSample(a.Dim(), test.samples, a, b) 246 got := KullbackLeibler{}.DistNormal(a, b) 247 if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { 248 t.Errorf("Case %d, KL mismatch: got %v, want %v", cas, got, want) 249 } 250 } 251 } 252 253 func TestKullbackLeiblerUniform(t *testing.T) { 254 rnd := rand.New(rand.NewSource(1)) 255 for cas, test := range []struct { 256 a, b *Uniform 257 samples int 258 tol float64 259 }{ 260 { 261 a: NewUniform([]r1.Interval{{Min: -5, Max: 2}, {Min: -7, Max: 12}}, rnd), 262 b: NewUniform([]r1.Interval{{Min: -4, Max: 1}, {Min: -7, Max: 10}}, rnd), 263 samples: 100000, 264 tol: 1e-2, 265 }, 266 { 267 a: NewUniform([]r1.Interval{{Min: -5, Max: 2}, {Min: -7, Max: 12}}, rnd), 268 b: NewUniform([]r1.Interval{{Min: -9, Max: -6}, {Min: -7, Max: 10}}, rnd), 269 samples: 100000, 270 tol: 1e-2, 271 }, 272 } { 273 a, b := test.a, test.b 274 want := klSample(a.Dim(), test.samples, a, b) 275 got := KullbackLeibler{}.DistUniform(a, b) 276 if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { 277 t.Errorf("Kullback-Leibler mismatch, case %d: got %v, want %v", cas, got, want) 278 } 279 } 280 } 281 282 // klSample finds an estimate of the Kullback-Leibler divergence through sampling. 283 func klSample(dim, samples int, l RandLogProber, r LogProber) float64 { 284 var klmc float64 285 x := make([]float64, dim) 286 for i := 0; i < samples; i++ { 287 l.Rand(x) 288 pa := l.LogProb(x) 289 pb := r.LogProb(x) 290 klmc += pa - pb 291 } 292 return klmc / float64(samples) 293 } 294 295 func TestRenyiNormal(t *testing.T) { 296 for cas, test := range []struct { 297 am, bm []float64 298 ac, bc *mat.SymDense 299 alpha float64 300 samples int 301 tol float64 302 }{ 303 { 304 am: []float64{2, 3}, 305 ac: mat.NewSymDense(2, []float64{3, -1, -1, 2}), 306 bm: []float64{-1, 1}, 307 bc: mat.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}), 308 alpha: 0.3, 309 samples: 10000, 310 tol: 3e-1, 311 }, 312 } { 313 rnd := rand.New(rand.NewSource(1)) 314 a, ok := NewNormal(test.am, test.ac, rnd) 315 if !ok { 316 panic("bad test") 317 } 318 b, ok := NewNormal(test.bm, test.bc, rnd) 319 if !ok { 320 panic("bad test") 321 } 322 want := renyiSample(a.Dim(), test.samples, test.alpha, a, b) 323 got := Renyi{Alpha: test.alpha}.DistNormal(a, b) 324 if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) { 325 t.Errorf("Case %d: Renyi sampling mismatch: got %v, want %v", cas, got, want) 326 } 327 328 // Compare with Bhattacharyya. 329 want = 2 * Bhattacharyya{}.DistNormal(a, b) 330 got = Renyi{Alpha: 0.5}.DistNormal(a, b) 331 if !scalar.EqualWithinAbsOrRel(want, got, 1e-10, 1e-10) { 332 t.Errorf("Case %d: Renyi mismatch with Bhattacharyya: got %v, want %v", cas, got, want) 333 } 334 335 // Compare with KL in both directions. 336 want = KullbackLeibler{}.DistNormal(a, b) 337 got = Renyi{Alpha: 0.9999999}.DistNormal(a, b) // very close to 1 but not equal to 1. 338 if !scalar.EqualWithinAbsOrRel(want, got, 1e-6, 1e-6) { 339 t.Errorf("Case %d: Renyi mismatch with KL(a||b): got %v, want %v", cas, got, want) 340 } 341 want = KullbackLeibler{}.DistNormal(b, a) 342 got = Renyi{Alpha: 0.9999999}.DistNormal(b, a) // very close to 1 but not equal to 1. 343 if !scalar.EqualWithinAbsOrRel(want, got, 1e-6, 1e-6) { 344 t.Errorf("Case %d: Renyi mismatch with KL(b||a): got %v, want %v", cas, got, want) 345 } 346 } 347 } 348 349 // renyiSample finds an estimate of the Rényi divergence through sampling. 350 // Note that this sampling procedure only works if l has broader support than r. 351 func renyiSample(dim, samples int, alpha float64, l RandLogProber, r LogProber) float64 { 352 rmcs := make([]float64, samples) 353 x := make([]float64, dim) 354 for i := 0; i < samples; i++ { 355 l.Rand(x) 356 pa := l.LogProb(x) 357 pb := r.LogProb(x) 358 rmcs[i] = (alpha-1)*pa + (1-alpha)*pb 359 } 360 return 1 / (alpha - 1) * (floats.LogSumExp(rmcs) - math.Log(float64(samples))) 361 }