gonum.org/v1/gonum@v0.14.0/stat/distmv/studentst_test.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distmv 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/floats" 14 "gonum.org/v1/gonum/floats/scalar" 15 "gonum.org/v1/gonum/mat" 16 "gonum.org/v1/gonum/stat" 17 ) 18 19 func TestStudentTProbs(t *testing.T) { 20 src := rand.New(rand.NewSource(1)) 21 for _, test := range []struct { 22 nu float64 23 mu []float64 24 sigma *mat.SymDense 25 26 x [][]float64 27 probs []float64 28 }{ 29 { 30 nu: 3, 31 mu: []float64{0, 0}, 32 sigma: mat.NewSymDense(2, []float64{1, 0, 0, 1}), 33 34 x: [][]float64{ 35 {0, 0}, 36 {1, -1}, 37 {3, 4}, 38 {-1, -2}, 39 }, 40 // Outputs compared with WolframAlpha. 41 probs: []float64{ 42 0.159154943091895335768883, 43 0.0443811199724279860006777747927, 44 0.0005980371870904696541052658, 45 0.01370560783418571283428283, 46 }, 47 }, 48 { 49 nu: 4, 50 mu: []float64{2, -3}, 51 sigma: mat.NewSymDense(2, []float64{8, -1, -1, 5}), 52 53 x: [][]float64{ 54 {0, 0}, 55 {1, -1}, 56 {3, 4}, 57 {-1, -2}, 58 {2, -3}, 59 }, 60 // Outputs compared with WolframAlpha. 61 probs: []float64{ 62 0.007360810111491788657953608191001, 63 0.0143309905845607117740440592999, 64 0.0005307774290578041397794096037035009801668903, 65 0.0115657422475668739943625904793879, 66 0.0254851872062589062995305736215, 67 }, 68 }, 69 } { 70 s, ok := NewStudentsT(test.mu, test.sigma, test.nu, src) 71 if !ok { 72 t.Fatal("bad test") 73 } 74 for i, x := range test.x { 75 xcpy := make([]float64, len(x)) 76 copy(xcpy, x) 77 p := s.Prob(x) 78 if !floats.Same(x, xcpy) { 79 t.Errorf("X modified during call to prob, %v, %v", x, xcpy) 80 } 81 if !scalar.EqualWithinAbsOrRel(p, test.probs[i], 1e-10, 1e-10) { 82 t.Errorf("Probability mismatch. X = %v. Got %v, want %v.", x, p, test.probs[i]) 83 } 84 } 85 } 86 } 87 88 func TestStudentsTRand(t *testing.T) { 89 src := rand.New(rand.NewSource(1)) 90 for cas, test := range []struct { 91 mean []float64 92 cov *mat.SymDense 93 nu float64 94 tolcov float64 95 }{ 96 { 97 mean: []float64{0, 0}, 98 cov: mat.NewSymDense(2, []float64{1, 0, 0, 1}), 99 nu: 4, 100 tolcov: 1e-2, 101 }, 102 { 103 mean: []float64{3, 4}, 104 cov: mat.NewSymDense(2, []float64{5, 1.2, 1.2, 6}), 105 nu: 8, 106 tolcov: 1e-2, 107 }, 108 { 109 mean: []float64{3, 4, -2}, 110 cov: mat.NewSymDense(3, []float64{5, 1.2, -0.8, 1.2, 6, 0.4, -0.8, 0.4, 2}), 111 nu: 8, 112 tolcov: 1e-2, 113 }, 114 } { 115 s, ok := NewStudentsT(test.mean, test.cov, test.nu, src) 116 if !ok { 117 t.Fatal("bad test") 118 } 119 const nSamples = 1e6 120 dim := len(test.mean) 121 samps := mat.NewDense(nSamples, dim, nil) 122 for i := 0; i < nSamples; i++ { 123 s.Rand(samps.RawRowView(i)) 124 } 125 estMean := make([]float64, dim) 126 for i := range estMean { 127 estMean[i] = stat.Mean(mat.Col(nil, i, samps), nil) 128 } 129 mean := s.Mean(nil) 130 if !floats.EqualApprox(estMean, mean, 1e-2) { 131 t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean) 132 } 133 var cov, estCov mat.SymDense 134 s.CovarianceMatrix(&cov) 135 stat.CovarianceMatrix(&estCov, samps, nil) 136 if !mat.EqualApprox(&estCov, &cov, test.tolcov) { 137 t.Errorf("Case %d: Cov mismatch: want: %v, got %v", cas, &cov, &estCov) 138 } 139 } 140 } 141 142 func TestStudentsTConditional(t *testing.T) { 143 src := rand.New(rand.NewSource(1)) 144 for _, test := range []struct { 145 mean []float64 146 cov *mat.SymDense 147 nu float64 148 149 idx []int 150 value []float64 151 tolcov float64 152 }{ 153 { 154 mean: []float64{3, 4, -2}, 155 cov: mat.NewSymDense(3, []float64{5, 1.2, -0.8, 1.2, 6, 0.4, -0.8, 0.4, 2}), 156 nu: 8, 157 idx: []int{0}, 158 value: []float64{6}, 159 160 tolcov: 1e-2, 161 }, 162 } { 163 s, ok := NewStudentsT(test.mean, test.cov, test.nu, src) 164 if !ok { 165 t.Fatal("bad test") 166 } 167 168 sUp, ok := s.ConditionStudentsT(test.idx, test.value, src) 169 if !ok { 170 t.Error("unexpected failure of ConditionStudentsT") 171 } 172 173 // Compute the other values by hand the inefficient way to compare 174 newNu := test.nu + float64(len(test.idx)) 175 if newNu != sUp.nu { 176 t.Errorf("Updated nu mismatch. Got %v, want %v", s.nu, newNu) 177 } 178 dim := len(test.mean) 179 unob := findUnob(test.idx, dim) 180 ob := test.idx 181 182 muUnob := make([]float64, len(unob)) 183 for i, v := range unob { 184 muUnob[i] = test.mean[v] 185 } 186 muOb := make([]float64, len(ob)) 187 for i, v := range ob { 188 muOb[i] = test.mean[v] 189 } 190 191 var sig11, sig22 mat.SymDense 192 sig11.SubsetSym(&s.sigma, unob) 193 sig22.SubsetSym(&s.sigma, ob) 194 195 sig12 := mat.NewDense(len(unob), len(ob), nil) 196 for i := range unob { 197 for j := range ob { 198 sig12.Set(i, j, s.sigma.At(unob[i], ob[j])) 199 } 200 } 201 202 shift := make([]float64, len(ob)) 203 copy(shift, test.value) 204 floats.Sub(shift, muOb) 205 206 newMu := make([]float64, len(muUnob)) 207 newMuVec := mat.NewVecDense(len(muUnob), newMu) 208 shiftVec := mat.NewVecDense(len(shift), shift) 209 var tmp mat.VecDense 210 err := tmp.SolveVec(&sig22, shiftVec) 211 if err != nil { 212 t.Errorf("unexpected error from vector solve: %v", err) 213 } 214 newMuVec.MulVec(sig12, &tmp) 215 floats.Add(newMu, muUnob) 216 217 if !floats.EqualApprox(newMu, sUp.mu, 1e-10) { 218 t.Errorf("Mu mismatch. Got %v, want %v", sUp.mu, newMu) 219 } 220 221 var tmp2 mat.Dense 222 err = tmp2.Solve(&sig22, sig12.T()) 223 if err != nil { 224 t.Errorf("unexpected error from dense solve: %v", err) 225 } 226 227 var tmp3 mat.Dense 228 tmp3.Mul(sig12, &tmp2) 229 tmp3.Sub(&sig11, &tmp3) 230 231 dot := mat.Dot(shiftVec, &tmp) 232 tmp3.Scale((test.nu+dot)/(test.nu+float64(len(ob))), &tmp3) 233 if !mat.EqualApprox(&tmp3, &sUp.sigma, 1e-10) { 234 t.Errorf("Sigma mismatch") 235 } 236 } 237 } 238 239 func TestStudentsTMarginalSingle(t *testing.T) { 240 for _, test := range []struct { 241 mu []float64 242 sigma *mat.SymDense 243 nu float64 244 }{ 245 { 246 mu: []float64{2, 3, 4}, 247 sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), 248 nu: 5, 249 }, 250 { 251 mu: []float64{2, 3, 4, 5}, 252 sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}), 253 nu: 6, 254 }, 255 } { 256 studentst, ok := NewStudentsT(test.mu, test.sigma, test.nu, nil) 257 if !ok { 258 t.Fatalf("Bad test, covariance matrix not positive definite") 259 } 260 for i, mean := range test.mu { 261 st := studentst.MarginalStudentsTSingle(i, nil) 262 if st.Mean() != mean { 263 t.Errorf("Mean mismatch nil Sigma, idx %v: want %v, got %v.", i, mean, st.Mean()) 264 } 265 std := math.Sqrt(test.sigma.At(i, i)) 266 if math.Abs(st.Sigma-std) > 1e-14 { 267 t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, st.StdDev()) 268 } 269 if st.Nu != test.nu { 270 t.Errorf("Nu mismatch nil Sigma, idx %v: want %v, got %v ", i, test.nu, st.Nu) 271 } 272 } 273 } 274 }