gonum.org/v1/gonum@v0.14.0/stat/distmat/wishart_test.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package distmat
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/floats/scalar"
    14  	"gonum.org/v1/gonum/mat"
    15  )
    16  
    17  func TestWishart(t *testing.T) {
    18  	for c, test := range []struct {
    19  		v   *mat.SymDense
    20  		nu  float64
    21  		xs  []*mat.SymDense
    22  		lps []float64
    23  	}{
    24  		// Logprob data compared with scipy.
    25  		{
    26  			v:  mat.NewSymDense(2, []float64{1, 0, 0, 1}),
    27  			nu: 4,
    28  			xs: []*mat.SymDense{
    29  				mat.NewSymDense(2, []float64{0.9, 0.1, 0.1, 0.9}),
    30  			},
    31  			lps: []float64{-4.2357432031863409},
    32  		},
    33  		{
    34  			v:  mat.NewSymDense(2, []float64{0.8, -0.2, -0.2, 0.7}),
    35  			nu: 5,
    36  			xs: []*mat.SymDense{
    37  				mat.NewSymDense(2, []float64{0.9, 0.1, 0.1, 0.9}),
    38  				mat.NewSymDense(2, []float64{0.3, -0.1, -0.1, 0.7}),
    39  			},
    40  			lps: []float64{-4.2476495605333575, -4.9993285370378633},
    41  		},
    42  		{
    43  			v:  mat.NewSymDense(3, []float64{0.8, 0.3, 0.1, 0.3, 0.7, -0.1, 0.1, -0.1, 7}),
    44  			nu: 5,
    45  			xs: []*mat.SymDense{
    46  				mat.NewSymDense(3, []float64{1, 0.2, -0.3, 0.2, 0.6, -0.2, -0.3, -0.2, 6}),
    47  			},
    48  			lps: []float64{-11.010982249229421},
    49  		},
    50  	} {
    51  		w, ok := NewWishart(test.v, test.nu, nil)
    52  		if !ok {
    53  			panic("bad test")
    54  		}
    55  		for i, x := range test.xs {
    56  			lp := w.LogProbSym(x)
    57  
    58  			var chol mat.Cholesky
    59  			ok := chol.Factorize(x)
    60  			if !ok {
    61  				panic("bad test")
    62  			}
    63  			lpc := w.LogProbSymChol(&chol)
    64  
    65  			if math.Abs(lp-lpc) > 1e-14 {
    66  				t.Errorf("Case %d, test %d: probability mismatch between chol and not", c, i)
    67  			}
    68  			if !scalar.EqualWithinAbsOrRel(lp, test.lps[i], 1e-14, 1e-14) {
    69  				t.Errorf("Case %d, test %d: got %v, want %v", c, i, lp, test.lps[i])
    70  			}
    71  		}
    72  
    73  		var ch mat.Cholesky
    74  		w.RandCholTo(&ch)
    75  		w.RandCholTo(&ch)
    76  
    77  		var s mat.SymDense
    78  		w.RandSymTo(&s)
    79  		w.RandSymTo(&s)
    80  	}
    81  }
    82  
    83  func TestWishartRand(t *testing.T) {
    84  	for c, test := range []struct {
    85  		v       *mat.SymDense
    86  		nu      float64
    87  		samples int
    88  		tol     float64
    89  	}{
    90  		{
    91  			v:       mat.NewSymDense(2, []float64{0.8, -0.2, -0.2, 0.7}),
    92  			nu:      5,
    93  			samples: 30000,
    94  			tol:     3e-2,
    95  		},
    96  		{
    97  			v:       mat.NewSymDense(3, []float64{0.8, 0.3, 0.1, 0.3, 0.7, -0.1, 0.1, -0.1, 7}),
    98  			nu:      5,
    99  			samples: 30000,
   100  			tol:     3e-1,
   101  		},
   102  		{
   103  			v: mat.NewSymDense(4, []float64{
   104  				0.8, 0.3, 0.1, -0.2,
   105  				0.3, 0.7, -0.1, 0.4,
   106  				0.1, -0.1, 7, 1,
   107  				-0.2, -0.1, 1, 6}),
   108  			nu:      6,
   109  			samples: 30000,
   110  			tol:     1e-1,
   111  		},
   112  	} {
   113  		rnd := rand.New(rand.NewSource(1))
   114  		dim := test.v.SymmetricDim()
   115  		w, ok := NewWishart(test.v, test.nu, rnd)
   116  		if !ok {
   117  			panic("bad test")
   118  		}
   119  		mean := mat.NewSymDense(dim, nil)
   120  		x := mat.NewSymDense(dim, nil)
   121  		for i := 0; i < test.samples; i++ {
   122  			w.RandSymTo(x)
   123  			x.ScaleSym(1/float64(test.samples), x)
   124  			mean.AddSym(mean, x)
   125  		}
   126  		var trueMean mat.SymDense
   127  		w.MeanSymTo(&trueMean)
   128  		if !mat.EqualApprox(&trueMean, mean, test.tol) {
   129  			t.Errorf("Case %d: Mismatch between estimated and true mean. Got\n%0.4v\nWant\n%0.4v\n", c, mat.Formatted(mean), mat.Formatted(&trueMean))
   130  		}
   131  	}
   132  }