github.com/gopherd/gonum@v0.0.4/stat/distmv/general_test.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distmv 6 7 import ( 8 "math" 9 "testing" 10 11 "github.com/gopherd/gonum/floats" 12 "github.com/gopherd/gonum/mat" 13 "github.com/gopherd/gonum/stat" 14 ) 15 16 type prober interface { 17 Prob(x []float64) float64 18 LogProb(x []float64) float64 19 } 20 21 type probCase struct { 22 dist prober 23 loc []float64 24 logProb float64 25 } 26 27 func testProbability(t *testing.T, cases []probCase) { 28 for _, test := range cases { 29 logProb := test.dist.LogProb(test.loc) 30 if math.Abs(logProb-test.logProb) > 1e-14 { 31 t.Errorf("LogProb mismatch: want: %v, got: %v", test.logProb, logProb) 32 } 33 prob := test.dist.Prob(test.loc) 34 if math.Abs(prob-math.Exp(test.logProb)) > 1e-14 { 35 t.Errorf("Prob mismatch: want: %v, got: %v", math.Exp(test.logProb), prob) 36 } 37 } 38 } 39 40 func generateSamples(x *mat.Dense, r Rander) { 41 n, _ := x.Dims() 42 for i := 0; i < n; i++ { 43 r.Rand(x.RawRowView(i)) 44 } 45 } 46 47 type Meaner interface { 48 Mean([]float64) []float64 49 } 50 51 func checkMean(t *testing.T, cas int, x *mat.Dense, m Meaner, tol float64) { 52 mean := m.Mean(nil) 53 54 // Check that the answer is identical when using nil or non-nil. 55 mean2 := make([]float64, len(mean)) 56 m.Mean(mean2) 57 if !floats.Equal(mean, mean2) { 58 t.Errorf("Mean mismatch when providing nil and slice. Case %v", cas) 59 } 60 61 // Check that the mean matches the samples. 62 r, _ := x.Dims() 63 col := make([]float64, r) 64 meanEst := make([]float64, len(mean)) 65 for i := range meanEst { 66 meanEst[i] = stat.Mean(mat.Col(col, i, x), nil) 67 } 68 if !floats.EqualApprox(mean, meanEst, tol) { 69 t.Errorf("Returned mean and sample mean mismatch. Case %v. Empirical %v, returned %v", cas, meanEst, mean) 70 } 71 } 72 73 type Cover interface { 74 CovarianceMatrix(*mat.SymDense) 75 } 76 77 func checkCov(t *testing.T, cas int, x *mat.Dense, c Cover, tol float64) { 78 var cov mat.SymDense 79 c.CovarianceMatrix(&cov) 80 n := cov.SymmetricDim() 81 cov2 := mat.NewSymDense(n, nil) 82 c.CovarianceMatrix(cov2) 83 if !mat.Equal(&cov, cov2) { 84 t.Errorf("Cov mismatch when providing nil and matrix. Case %v", cas) 85 } 86 var cov3 mat.SymDense 87 c.CovarianceMatrix(&cov3) 88 if !mat.Equal(&cov, &cov3) { 89 t.Errorf("Cov mismatch when providing zero matrix. Case %v", cas) 90 } 91 92 // Check that the covariance matrix matches the samples 93 var covEst mat.SymDense 94 stat.CovarianceMatrix(&covEst, x, nil) 95 if !mat.EqualApprox(&covEst, &cov, tol) { 96 t.Errorf("Return cov and sample cov mismatch. Cas %v.\nGot:\n%0.4v\nWant:\n%0.4v", cas, mat.Formatted(&cov), mat.Formatted(&covEst)) 97 } 98 }