github.com/gopherd/gonum@v0.0.4/stat/distmv/general_test.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package distmv
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"github.com/gopherd/gonum/floats"
    12  	"github.com/gopherd/gonum/mat"
    13  	"github.com/gopherd/gonum/stat"
    14  )
    15  
    16  type prober interface {
    17  	Prob(x []float64) float64
    18  	LogProb(x []float64) float64
    19  }
    20  
    21  type probCase struct {
    22  	dist    prober
    23  	loc     []float64
    24  	logProb float64
    25  }
    26  
    27  func testProbability(t *testing.T, cases []probCase) {
    28  	for _, test := range cases {
    29  		logProb := test.dist.LogProb(test.loc)
    30  		if math.Abs(logProb-test.logProb) > 1e-14 {
    31  			t.Errorf("LogProb mismatch: want: %v, got: %v", test.logProb, logProb)
    32  		}
    33  		prob := test.dist.Prob(test.loc)
    34  		if math.Abs(prob-math.Exp(test.logProb)) > 1e-14 {
    35  			t.Errorf("Prob mismatch: want: %v, got: %v", math.Exp(test.logProb), prob)
    36  		}
    37  	}
    38  }
    39  
    40  func generateSamples(x *mat.Dense, r Rander) {
    41  	n, _ := x.Dims()
    42  	for i := 0; i < n; i++ {
    43  		r.Rand(x.RawRowView(i))
    44  	}
    45  }
    46  
    47  type Meaner interface {
    48  	Mean([]float64) []float64
    49  }
    50  
    51  func checkMean(t *testing.T, cas int, x *mat.Dense, m Meaner, tol float64) {
    52  	mean := m.Mean(nil)
    53  
    54  	// Check that the answer is identical when using nil or non-nil.
    55  	mean2 := make([]float64, len(mean))
    56  	m.Mean(mean2)
    57  	if !floats.Equal(mean, mean2) {
    58  		t.Errorf("Mean mismatch when providing nil and slice. Case %v", cas)
    59  	}
    60  
    61  	// Check that the mean matches the samples.
    62  	r, _ := x.Dims()
    63  	col := make([]float64, r)
    64  	meanEst := make([]float64, len(mean))
    65  	for i := range meanEst {
    66  		meanEst[i] = stat.Mean(mat.Col(col, i, x), nil)
    67  	}
    68  	if !floats.EqualApprox(mean, meanEst, tol) {
    69  		t.Errorf("Returned mean and sample mean mismatch. Case %v. Empirical %v, returned %v", cas, meanEst, mean)
    70  	}
    71  }
    72  
    73  type Cover interface {
    74  	CovarianceMatrix(*mat.SymDense)
    75  }
    76  
    77  func checkCov(t *testing.T, cas int, x *mat.Dense, c Cover, tol float64) {
    78  	var cov mat.SymDense
    79  	c.CovarianceMatrix(&cov)
    80  	n := cov.SymmetricDim()
    81  	cov2 := mat.NewSymDense(n, nil)
    82  	c.CovarianceMatrix(cov2)
    83  	if !mat.Equal(&cov, cov2) {
    84  		t.Errorf("Cov mismatch when providing nil and matrix. Case %v", cas)
    85  	}
    86  	var cov3 mat.SymDense
    87  	c.CovarianceMatrix(&cov3)
    88  	if !mat.Equal(&cov, &cov3) {
    89  		t.Errorf("Cov mismatch when providing zero matrix. Case %v", cas)
    90  	}
    91  
    92  	// Check that the covariance matrix matches the samples
    93  	var covEst mat.SymDense
    94  	stat.CovarianceMatrix(&covEst, x, nil)
    95  	if !mat.EqualApprox(&covEst, &cov, tol) {
    96  		t.Errorf("Return cov and sample cov mismatch. Cas %v.\nGot:\n%0.4v\nWant:\n%0.4v", cas, mat.Formatted(&cov), mat.Formatted(&covEst))
    97  	}
    98  }