github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/distmv/studentst_test.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package distmv
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"github.com/jingcheng-WU/gonum/floats"
    14  	"github.com/jingcheng-WU/gonum/floats/scalar"
    15  	"github.com/jingcheng-WU/gonum/mat"
    16  	"github.com/jingcheng-WU/gonum/stat"
    17  )
    18  
    19  func TestStudentTProbs(t *testing.T) {
    20  	src := rand.New(rand.NewSource(1))
    21  	for _, test := range []struct {
    22  		nu    float64
    23  		mu    []float64
    24  		sigma *mat.SymDense
    25  
    26  		x     [][]float64
    27  		probs []float64
    28  	}{
    29  		{
    30  			nu:    3,
    31  			mu:    []float64{0, 0},
    32  			sigma: mat.NewSymDense(2, []float64{1, 0, 0, 1}),
    33  
    34  			x: [][]float64{
    35  				{0, 0},
    36  				{1, -1},
    37  				{3, 4},
    38  				{-1, -2},
    39  			},
    40  			// Outputs compared with WolframAlpha.
    41  			probs: []float64{
    42  				0.159154943091895335768883,
    43  				0.0443811199724279860006777747927,
    44  				0.0005980371870904696541052658,
    45  				0.01370560783418571283428283,
    46  			},
    47  		},
    48  		{
    49  			nu:    4,
    50  			mu:    []float64{2, -3},
    51  			sigma: mat.NewSymDense(2, []float64{8, -1, -1, 5}),
    52  
    53  			x: [][]float64{
    54  				{0, 0},
    55  				{1, -1},
    56  				{3, 4},
    57  				{-1, -2},
    58  				{2, -3},
    59  			},
    60  			// Outputs compared with WolframAlpha.
    61  			probs: []float64{
    62  				0.007360810111491788657953608191001,
    63  				0.0143309905845607117740440592999,
    64  				0.0005307774290578041397794096037035009801668903,
    65  				0.0115657422475668739943625904793879,
    66  				0.0254851872062589062995305736215,
    67  			},
    68  		},
    69  	} {
    70  		s, ok := NewStudentsT(test.mu, test.sigma, test.nu, src)
    71  		if !ok {
    72  			t.Fatal("bad test")
    73  		}
    74  		for i, x := range test.x {
    75  			xcpy := make([]float64, len(x))
    76  			copy(xcpy, x)
    77  			p := s.Prob(x)
    78  			if !floats.Same(x, xcpy) {
    79  				t.Errorf("X modified during call to prob, %v, %v", x, xcpy)
    80  			}
    81  			if !scalar.EqualWithinAbsOrRel(p, test.probs[i], 1e-10, 1e-10) {
    82  				t.Errorf("Probability mismatch. X = %v. Got %v, want %v.", x, p, test.probs[i])
    83  			}
    84  		}
    85  	}
    86  }
    87  
    88  func TestStudentsTRand(t *testing.T) {
    89  	src := rand.New(rand.NewSource(1))
    90  	for cas, test := range []struct {
    91  		mean   []float64
    92  		cov    *mat.SymDense
    93  		nu     float64
    94  		tolcov float64
    95  	}{
    96  		{
    97  			mean:   []float64{0, 0},
    98  			cov:    mat.NewSymDense(2, []float64{1, 0, 0, 1}),
    99  			nu:     4,
   100  			tolcov: 1e-2,
   101  		},
   102  		{
   103  			mean:   []float64{3, 4},
   104  			cov:    mat.NewSymDense(2, []float64{5, 1.2, 1.2, 6}),
   105  			nu:     8,
   106  			tolcov: 1e-2,
   107  		},
   108  		{
   109  			mean:   []float64{3, 4, -2},
   110  			cov:    mat.NewSymDense(3, []float64{5, 1.2, -0.8, 1.2, 6, 0.4, -0.8, 0.4, 2}),
   111  			nu:     8,
   112  			tolcov: 1e-2,
   113  		},
   114  	} {
   115  		s, ok := NewStudentsT(test.mean, test.cov, test.nu, src)
   116  		if !ok {
   117  			t.Fatal("bad test")
   118  		}
   119  		const nSamples = 1e6
   120  		dim := len(test.mean)
   121  		samps := mat.NewDense(nSamples, dim, nil)
   122  		for i := 0; i < nSamples; i++ {
   123  			s.Rand(samps.RawRowView(i))
   124  		}
   125  		estMean := make([]float64, dim)
   126  		for i := range estMean {
   127  			estMean[i] = stat.Mean(mat.Col(nil, i, samps), nil)
   128  		}
   129  		mean := s.Mean(nil)
   130  		if !floats.EqualApprox(estMean, mean, 1e-2) {
   131  			t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean)
   132  		}
   133  		var cov, estCov mat.SymDense
   134  		s.CovarianceMatrix(&cov)
   135  		stat.CovarianceMatrix(&estCov, samps, nil)
   136  		if !mat.EqualApprox(&estCov, &cov, test.tolcov) {
   137  			t.Errorf("Case %d: Cov mismatch: want: %v, got %v", cas, &cov, &estCov)
   138  		}
   139  	}
   140  }
   141  
   142  func TestStudentsTConditional(t *testing.T) {
   143  	src := rand.New(rand.NewSource(1))
   144  	for _, test := range []struct {
   145  		mean []float64
   146  		cov  *mat.SymDense
   147  		nu   float64
   148  
   149  		idx    []int
   150  		value  []float64
   151  		tolcov float64
   152  	}{
   153  		{
   154  			mean:  []float64{3, 4, -2},
   155  			cov:   mat.NewSymDense(3, []float64{5, 1.2, -0.8, 1.2, 6, 0.4, -0.8, 0.4, 2}),
   156  			nu:    8,
   157  			idx:   []int{0},
   158  			value: []float64{6},
   159  
   160  			tolcov: 1e-2,
   161  		},
   162  	} {
   163  		s, ok := NewStudentsT(test.mean, test.cov, test.nu, src)
   164  		if !ok {
   165  			t.Fatal("bad test")
   166  		}
   167  
   168  		sUp, ok := s.ConditionStudentsT(test.idx, test.value, src)
   169  		if !ok {
   170  			t.Error("unexpected failure of ConditionStudentsT")
   171  		}
   172  
   173  		// Compute the other values by hand the inefficient way to compare
   174  		newNu := test.nu + float64(len(test.idx))
   175  		if newNu != sUp.nu {
   176  			t.Errorf("Updated nu mismatch. Got %v, want %v", s.nu, newNu)
   177  		}
   178  		dim := len(test.mean)
   179  		unob := findUnob(test.idx, dim)
   180  		ob := test.idx
   181  
   182  		muUnob := make([]float64, len(unob))
   183  		for i, v := range unob {
   184  			muUnob[i] = test.mean[v]
   185  		}
   186  		muOb := make([]float64, len(ob))
   187  		for i, v := range ob {
   188  			muOb[i] = test.mean[v]
   189  		}
   190  
   191  		var sig11, sig22 mat.SymDense
   192  		sig11.SubsetSym(&s.sigma, unob)
   193  		sig22.SubsetSym(&s.sigma, ob)
   194  
   195  		sig12 := mat.NewDense(len(unob), len(ob), nil)
   196  		for i := range unob {
   197  			for j := range ob {
   198  				sig12.Set(i, j, s.sigma.At(unob[i], ob[j]))
   199  			}
   200  		}
   201  
   202  		shift := make([]float64, len(ob))
   203  		copy(shift, test.value)
   204  		floats.Sub(shift, muOb)
   205  
   206  		newMu := make([]float64, len(muUnob))
   207  		newMuVec := mat.NewVecDense(len(muUnob), newMu)
   208  		shiftVec := mat.NewVecDense(len(shift), shift)
   209  		var tmp mat.VecDense
   210  		err := tmp.SolveVec(&sig22, shiftVec)
   211  		if err != nil {
   212  			t.Errorf("unexpected error from vector solve: %v", err)
   213  		}
   214  		newMuVec.MulVec(sig12, &tmp)
   215  		floats.Add(newMu, muUnob)
   216  
   217  		if !floats.EqualApprox(newMu, sUp.mu, 1e-10) {
   218  			t.Errorf("Mu mismatch. Got %v, want %v", sUp.mu, newMu)
   219  		}
   220  
   221  		var tmp2 mat.Dense
   222  		err = tmp2.Solve(&sig22, sig12.T())
   223  		if err != nil {
   224  			t.Errorf("unexpected error from dense solve: %v", err)
   225  		}
   226  
   227  		var tmp3 mat.Dense
   228  		tmp3.Mul(sig12, &tmp2)
   229  		tmp3.Sub(&sig11, &tmp3)
   230  
   231  		dot := mat.Dot(shiftVec, &tmp)
   232  		tmp3.Scale((test.nu+dot)/(test.nu+float64(len(ob))), &tmp3)
   233  		if !mat.EqualApprox(&tmp3, &sUp.sigma, 1e-10) {
   234  			t.Errorf("Sigma mismatch")
   235  		}
   236  	}
   237  }
   238  
   239  func TestStudentsTMarginalSingle(t *testing.T) {
   240  	for _, test := range []struct {
   241  		mu    []float64
   242  		sigma *mat.SymDense
   243  		nu    float64
   244  	}{
   245  		{
   246  			mu:    []float64{2, 3, 4},
   247  			sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}),
   248  			nu:    5,
   249  		},
   250  		{
   251  			mu:    []float64{2, 3, 4, 5},
   252  			sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}),
   253  			nu:    6,
   254  		},
   255  	} {
   256  		studentst, ok := NewStudentsT(test.mu, test.sigma, test.nu, nil)
   257  		if !ok {
   258  			t.Fatalf("Bad test, covariance matrix not positive definite")
   259  		}
   260  		for i, mean := range test.mu {
   261  			st := studentst.MarginalStudentsTSingle(i, nil)
   262  			if st.Mean() != mean {
   263  				t.Errorf("Mean mismatch nil Sigma, idx %v: want %v, got %v.", i, mean, st.Mean())
   264  			}
   265  			std := math.Sqrt(test.sigma.At(i, i))
   266  			if math.Abs(st.Sigma-std) > 1e-14 {
   267  				t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, st.StdDev())
   268  			}
   269  			if st.Nu != test.nu {
   270  				t.Errorf("Nu mismatch nil Sigma, idx %v: want %v, got %v ", i, test.nu, st.Nu)
   271  			}
   272  		}
   273  	}
   274  }