gonum.org/v1/gonum@v0.14.0/stat/distmv/dirichlet_test.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distmv 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/mat" 14 ) 15 16 func TestDirichlet(t *testing.T) { 17 // Data from Scipy. 18 for cas, test := range []struct { 19 Dir *Dirichlet 20 x []float64 21 prob float64 22 }{ 23 { 24 NewDirichlet([]float64{1, 1, 1}, nil), 25 []float64{0.2, 0.3, 0.5}, 26 2.0, 27 }, 28 { 29 NewDirichlet([]float64{0.6, 10, 8.7}, nil), 30 []float64{0.2, 0.3, 0.5}, 31 0.24079612737071665, 32 }, 33 } { 34 p := test.Dir.Prob(test.x) 35 if math.Abs(p-test.prob) > 1e-14 { 36 t.Errorf("Probablility mismatch. Case %v. Got %v, want %v", cas, p, test.prob) 37 } 38 } 39 40 rnd := rand.New(rand.NewSource(1)) 41 for cas, test := range []struct { 42 Dir *Dirichlet 43 }{ 44 { 45 NewDirichlet([]float64{1, 1, 1}, rnd), 46 }, 47 { 48 NewDirichlet([]float64{2, 3}, rnd), 49 }, 50 { 51 NewDirichlet([]float64{0.2, 0.3}, rnd), 52 }, 53 { 54 NewDirichlet([]float64{0.2, 4}, rnd), 55 }, 56 { 57 NewDirichlet([]float64{0.1, 4, 20}, rnd), 58 }, 59 } { 60 const n = 1e5 61 d := test.Dir 62 dim := d.Dim() 63 x := mat.NewDense(n, dim, nil) 64 generateSamples(x, d) 65 checkMean(t, cas, x, d, 1e-2) 66 checkCov(t, cas, x, d, 1e-2) 67 } 68 }