gonum.org/v1/gonum@v0.14.0/stat/distmat/wishart_test.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distmat 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/floats/scalar" 14 "gonum.org/v1/gonum/mat" 15 ) 16 17 func TestWishart(t *testing.T) { 18 for c, test := range []struct { 19 v *mat.SymDense 20 nu float64 21 xs []*mat.SymDense 22 lps []float64 23 }{ 24 // Logprob data compared with scipy. 25 { 26 v: mat.NewSymDense(2, []float64{1, 0, 0, 1}), 27 nu: 4, 28 xs: []*mat.SymDense{ 29 mat.NewSymDense(2, []float64{0.9, 0.1, 0.1, 0.9}), 30 }, 31 lps: []float64{-4.2357432031863409}, 32 }, 33 { 34 v: mat.NewSymDense(2, []float64{0.8, -0.2, -0.2, 0.7}), 35 nu: 5, 36 xs: []*mat.SymDense{ 37 mat.NewSymDense(2, []float64{0.9, 0.1, 0.1, 0.9}), 38 mat.NewSymDense(2, []float64{0.3, -0.1, -0.1, 0.7}), 39 }, 40 lps: []float64{-4.2476495605333575, -4.9993285370378633}, 41 }, 42 { 43 v: mat.NewSymDense(3, []float64{0.8, 0.3, 0.1, 0.3, 0.7, -0.1, 0.1, -0.1, 7}), 44 nu: 5, 45 xs: []*mat.SymDense{ 46 mat.NewSymDense(3, []float64{1, 0.2, -0.3, 0.2, 0.6, -0.2, -0.3, -0.2, 6}), 47 }, 48 lps: []float64{-11.010982249229421}, 49 }, 50 } { 51 w, ok := NewWishart(test.v, test.nu, nil) 52 if !ok { 53 panic("bad test") 54 } 55 for i, x := range test.xs { 56 lp := w.LogProbSym(x) 57 58 var chol mat.Cholesky 59 ok := chol.Factorize(x) 60 if !ok { 61 panic("bad test") 62 } 63 lpc := w.LogProbSymChol(&chol) 64 65 if math.Abs(lp-lpc) > 1e-14 { 66 t.Errorf("Case %d, test %d: probability mismatch between chol and not", c, i) 67 } 68 if !scalar.EqualWithinAbsOrRel(lp, test.lps[i], 1e-14, 1e-14) { 69 t.Errorf("Case %d, test %d: got %v, want %v", c, i, lp, test.lps[i]) 70 } 71 } 72 73 var ch mat.Cholesky 74 w.RandCholTo(&ch) 75 w.RandCholTo(&ch) 76 77 var s mat.SymDense 78 w.RandSymTo(&s) 79 w.RandSymTo(&s) 80 } 81 } 82 83 func TestWishartRand(t *testing.T) { 84 for c, test := range []struct { 85 v *mat.SymDense 86 nu float64 87 samples int 88 tol float64 89 }{ 90 { 91 v: mat.NewSymDense(2, []float64{0.8, -0.2, -0.2, 0.7}), 92 nu: 5, 93 samples: 30000, 94 tol: 3e-2, 95 }, 96 { 97 v: mat.NewSymDense(3, []float64{0.8, 0.3, 0.1, 0.3, 0.7, -0.1, 0.1, -0.1, 7}), 98 nu: 5, 99 samples: 30000, 100 tol: 3e-1, 101 }, 102 { 103 v: mat.NewSymDense(4, []float64{ 104 0.8, 0.3, 0.1, -0.2, 105 0.3, 0.7, -0.1, 0.4, 106 0.1, -0.1, 7, 1, 107 -0.2, -0.1, 1, 6}), 108 nu: 6, 109 samples: 30000, 110 tol: 1e-1, 111 }, 112 } { 113 rnd := rand.New(rand.NewSource(1)) 114 dim := test.v.SymmetricDim() 115 w, ok := NewWishart(test.v, test.nu, rnd) 116 if !ok { 117 panic("bad test") 118 } 119 mean := mat.NewSymDense(dim, nil) 120 x := mat.NewSymDense(dim, nil) 121 for i := 0; i < test.samples; i++ { 122 w.RandSymTo(x) 123 x.ScaleSym(1/float64(test.samples), x) 124 mean.AddSym(mean, x) 125 } 126 var trueMean mat.SymDense 127 w.MeanSymTo(&trueMean) 128 if !mat.EqualApprox(&trueMean, mean, test.tol) { 129 t.Errorf("Case %d: Mismatch between estimated and true mean. Got\n%0.4v\nWant\n%0.4v\n", c, mat.Formatted(mean), mat.Formatted(&trueMean)) 130 } 131 } 132 }