gonum.org/v1/gonum@v0.14.0/stat/distmv/dirichlet_test.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package distmv
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/mat"
    14  )
    15  
    16  func TestDirichlet(t *testing.T) {
    17  	// Data from Scipy.
    18  	for cas, test := range []struct {
    19  		Dir  *Dirichlet
    20  		x    []float64
    21  		prob float64
    22  	}{
    23  		{
    24  			NewDirichlet([]float64{1, 1, 1}, nil),
    25  			[]float64{0.2, 0.3, 0.5},
    26  			2.0,
    27  		},
    28  		{
    29  			NewDirichlet([]float64{0.6, 10, 8.7}, nil),
    30  			[]float64{0.2, 0.3, 0.5},
    31  			0.24079612737071665,
    32  		},
    33  	} {
    34  		p := test.Dir.Prob(test.x)
    35  		if math.Abs(p-test.prob) > 1e-14 {
    36  			t.Errorf("Probablility mismatch. Case %v. Got %v, want %v", cas, p, test.prob)
    37  		}
    38  	}
    39  
    40  	rnd := rand.New(rand.NewSource(1))
    41  	for cas, test := range []struct {
    42  		Dir *Dirichlet
    43  	}{
    44  		{
    45  			NewDirichlet([]float64{1, 1, 1}, rnd),
    46  		},
    47  		{
    48  			NewDirichlet([]float64{2, 3}, rnd),
    49  		},
    50  		{
    51  			NewDirichlet([]float64{0.2, 0.3}, rnd),
    52  		},
    53  		{
    54  			NewDirichlet([]float64{0.2, 4}, rnd),
    55  		},
    56  		{
    57  			NewDirichlet([]float64{0.1, 4, 20}, rnd),
    58  		},
    59  	} {
    60  		const n = 1e5
    61  		d := test.Dir
    62  		dim := d.Dim()
    63  		x := mat.NewDense(n, dim, nil)
    64  		generateSamples(x, d)
    65  		checkMean(t, cas, x, d, 1e-2)
    66  		checkCov(t, cas, x, d, 1e-2)
    67  	}
    68  }