github.com/gopherd/gonum@v0.0.4/stat/samplemv/sample_test.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  package samplemv
     5  
     6  import (
     7  	"fmt"
     8  	"math"
     9  	"testing"
    10  
    11  	"math/rand"
    12  
    13  	"github.com/gopherd/gonum/floats"
    14  	"github.com/gopherd/gonum/mat"
    15  	"github.com/gopherd/gonum/spatial/r1"
    16  	"github.com/gopherd/gonum/stat"
    17  	"github.com/gopherd/gonum/stat/distmv"
    18  )
    19  
    20  type lhDist interface {
    21  	Quantile(x, p []float64) []float64
    22  	CDF(p, x []float64) []float64
    23  	Dim() int
    24  }
    25  
    26  func TestLatinHypercube(t *testing.T) {
    27  	src := rand.New(rand.NewSource(1))
    28  	for _, nSamples := range []int{1, 2, 5, 10, 20} {
    29  		for _, dist := range []lhDist{
    30  			distmv.NewUniform([]r1.Interval{{Min: 0, Max: 3}}, src),
    31  			distmv.NewUniform([]r1.Interval{{Min: 0, Max: 3}, {Min: -1, Max: 5}, {Min: -4, Max: -1}}, src),
    32  		} {
    33  			dim := dist.Dim()
    34  			batch := mat.NewDense(nSamples, dim, nil)
    35  			LatinHypercube{Src: src, Q: dist}.Sample(batch)
    36  			// Latin hypercube should have one entry per hyperrow.
    37  			present := make([][]bool, nSamples)
    38  			for i := range present {
    39  				present[i] = make([]bool, dim)
    40  			}
    41  			cdf := make([]float64, dim)
    42  			for i := 0; i < nSamples; i++ {
    43  				dist.CDF(cdf, batch.RawRowView(i))
    44  				for j := 0; j < dim; j++ {
    45  					p := cdf[j]
    46  					quadrant := int(math.Floor(p * float64(nSamples)))
    47  					present[quadrant][j] = true
    48  				}
    49  			}
    50  			allPresent := true
    51  			for i := 0; i < nSamples; i++ {
    52  				for j := 0; j < dim; j++ {
    53  					if !present[i][j] {
    54  						allPresent = false
    55  					}
    56  				}
    57  			}
    58  			if !allPresent {
    59  				t.Errorf("All quadrants not present")
    60  			}
    61  		}
    62  	}
    63  }
    64  
    65  func TestImportance(t *testing.T) {
    66  	src := rand.New(rand.NewSource(1))
    67  	// Test by finding the expected value of a multi-variate normal.
    68  	dim := 3
    69  	target, ok := randomNormal(dim, src)
    70  	if !ok {
    71  		t.Fatal("bad test, sigma not pos def")
    72  	}
    73  
    74  	muImp := make([]float64, dim)
    75  	sigmaImp := mat.NewSymDense(dim, nil)
    76  	for i := 0; i < dim; i++ {
    77  		sigmaImp.SetSym(i, i, 3)
    78  	}
    79  	proposal, ok := distmv.NewNormal(muImp, sigmaImp, src)
    80  	if !ok {
    81  		t.Fatal("bad test, sigma not pos def")
    82  	}
    83  
    84  	nSamples := 200000
    85  	batch := mat.NewDense(nSamples, dim, nil)
    86  	weights := make([]float64, nSamples)
    87  	Importance{Target: target, Proposal: proposal}.SampleWeighted(batch, weights)
    88  
    89  	compareNormal(t, target, batch, weights, 5e-2, 5e-2)
    90  }
    91  
    92  func TestRejection(t *testing.T) {
    93  	src := rand.New(rand.NewSource(1))
    94  	// Test by finding the expected value of a uniform.
    95  	dim := 3
    96  	bounds := make([]r1.Interval, dim)
    97  	for i := 0; i < dim; i++ {
    98  		min := src.NormFloat64()
    99  		max := src.NormFloat64()
   100  		if min > max {
   101  			min, max = max, min
   102  		}
   103  		bounds[i].Min = min
   104  		bounds[i].Max = max
   105  	}
   106  	target := distmv.NewUniform(bounds, src)
   107  	mu := target.Mean(nil)
   108  
   109  	muImp := make([]float64, dim)
   110  	sigmaImp := mat.NewSymDense(dim, nil)
   111  	for i := 0; i < dim; i++ {
   112  		sigmaImp.SetSym(i, i, 6)
   113  	}
   114  	proposal, ok := distmv.NewNormal(muImp, sigmaImp, src)
   115  	if !ok {
   116  		t.Fatal("bad test, sigma not pos def")
   117  	}
   118  
   119  	nSamples := 1000
   120  	batch := mat.NewDense(nSamples, dim, nil)
   121  	weights := make([]float64, nSamples)
   122  	rej := Rejection{Target: target, Proposal: proposal, C: 1000, Src: src}
   123  	rej.Sample(batch)
   124  	err := rej.Err()
   125  	if err != nil {
   126  		t.Error("Bad test, nan samples")
   127  	}
   128  
   129  	for i := 0; i < dim; i++ {
   130  		col := mat.Col(nil, i, batch)
   131  		ev := stat.Mean(col, weights)
   132  		if math.Abs(ev-mu[i]) > 1e-2 {
   133  			t.Errorf("Mean mismatch: Want %v, got %v", mu[i], ev)
   134  		}
   135  	}
   136  }
   137  
   138  func TestMetropolisHastings(t *testing.T) {
   139  	src := rand.New(rand.NewSource(1))
   140  	// Test by finding the expected value of a normal distribution.
   141  	dim := 3
   142  	target, ok := randomNormal(dim, src)
   143  	if !ok {
   144  		t.Fatal("bad test, sigma not pos def")
   145  	}
   146  
   147  	sigmaImp := mat.NewSymDense(dim, nil)
   148  	for i := 0; i < dim; i++ {
   149  		sigmaImp.SetSym(i, i, 0.25)
   150  	}
   151  	proposal, ok := NewProposalNormal(sigmaImp, src)
   152  	if !ok {
   153  		t.Fatal("bad test, sigma not pos def")
   154  	}
   155  
   156  	nSamples := 100000
   157  	burnin := 5000
   158  	batch := mat.NewDense(nSamples, dim, nil)
   159  	initial := make([]float64, dim)
   160  	metropolisHastings(batch, initial, target, proposal, src)
   161  	batch = batch.Slice(burnin, nSamples, 0, dim).(*mat.Dense)
   162  
   163  	compareNormal(t, target, batch, nil, 5e-1, 5e-1)
   164  }
   165  
   166  // randomNormal constructs a random Normal distribution.
   167  func randomNormal(dim int, src *rand.Rand) (*distmv.Normal, bool) {
   168  	data := make([]float64, dim*dim)
   169  	for i := range data {
   170  		data[i] = rand.Float64()
   171  	}
   172  	a := mat.NewDense(dim, dim, data)
   173  	var sigma mat.SymDense
   174  	sigma.SymOuterK(1, a)
   175  	mu := make([]float64, dim)
   176  	for i := range mu {
   177  		mu[i] = rand.NormFloat64()
   178  	}
   179  	return distmv.NewNormal(mu, &sigma, src)
   180  }
   181  
   182  func compareNormal(t *testing.T, want *distmv.Normal, batch *mat.Dense, weights []float64, meanTol, covTol float64) {
   183  	dim := want.Dim()
   184  	mu := want.Mean(nil)
   185  	var sigma mat.SymDense
   186  	want.CovarianceMatrix(&sigma)
   187  	n, _ := batch.Dims()
   188  	if weights == nil {
   189  		weights = make([]float64, n)
   190  		for i := range weights {
   191  			weights[i] = 1
   192  		}
   193  	}
   194  	for i := 0; i < dim; i++ {
   195  		col := mat.Col(nil, i, batch)
   196  		ev := stat.Mean(col, weights)
   197  		if math.Abs(ev-mu[i]) > meanTol {
   198  			t.Errorf("Mean mismatch: Want %v, got %v", mu[i], ev)
   199  		}
   200  	}
   201  
   202  	var cov mat.SymDense
   203  	stat.CovarianceMatrix(&cov, batch, weights)
   204  	if !mat.EqualApprox(&cov, &sigma, covTol) {
   205  		t.Errorf("Covariance matrix mismatch")
   206  	}
   207  }
   208  
   209  func TestMetropolisHastingser(t *testing.T) {
   210  	for _, test := range []struct {
   211  		dim, burnin, rate, samples int
   212  	}{
   213  		{3, 10, 1, 1},
   214  		{3, 10, 2, 1},
   215  		{3, 10, 1, 2},
   216  		{3, 10, 3, 2},
   217  		{3, 10, 7, 4},
   218  		{3, 10, 7, 4},
   219  
   220  		{3, 11, 51, 103},
   221  		{3, 11, 103, 51},
   222  		{3, 51, 11, 103},
   223  		{3, 51, 103, 11},
   224  		{3, 103, 11, 51},
   225  		{3, 103, 51, 11},
   226  	} {
   227  		dim := test.dim
   228  
   229  		initial := make([]float64, dim)
   230  		target, ok := randomNormal(dim, nil)
   231  		if !ok {
   232  			t.Fatal("bad test, sigma not pos def")
   233  		}
   234  
   235  		sigmaImp := mat.NewSymDense(dim, nil)
   236  		for i := 0; i < dim; i++ {
   237  			sigmaImp.SetSym(i, i, 0.25)
   238  		}
   239  
   240  		// Test the Metropolis Hastingser by generating all the samples, then generating
   241  		// the same samples with a burnin and rate.
   242  		src := rand.New(rand.NewSource(1))
   243  		proposal, ok := NewProposalNormal(sigmaImp, src)
   244  		if !ok {
   245  			t.Fatal("bad test, sigma not pos def")
   246  		}
   247  
   248  		mh := MetropolisHastingser{
   249  			Initial:  initial,
   250  			Target:   target,
   251  			Proposal: proposal,
   252  			Src:      src,
   253  			BurnIn:   0,
   254  			Rate:     0,
   255  		}
   256  		samples := test.samples
   257  		burnin := test.burnin
   258  		rate := test.rate
   259  		fullBatch := mat.NewDense(1+burnin+rate*(samples-1), dim, nil)
   260  		mh.Sample(fullBatch)
   261  
   262  		src = rand.New(rand.NewSource(1))
   263  		proposal, _ = NewProposalNormal(sigmaImp, src)
   264  		mh = MetropolisHastingser{
   265  			Initial:  initial,
   266  			Target:   target,
   267  			Proposal: proposal,
   268  			Src:      src,
   269  			BurnIn:   burnin,
   270  			Rate:     rate,
   271  		}
   272  		batch := mat.NewDense(samples, dim, nil)
   273  		mh.Sample(batch)
   274  
   275  		same := true
   276  		count := burnin
   277  		for i := 0; i < samples; i++ {
   278  			if !floats.Equal(batch.RawRowView(i), fullBatch.RawRowView(count)) {
   279  				fmt.Println("sample ", i, "is different")
   280  				same = false
   281  				break
   282  			}
   283  			count += rate
   284  		}
   285  
   286  		if !same {
   287  			fmt.Printf("%v\n", mat.Formatted(batch))
   288  			fmt.Printf("%v\n", mat.Formatted(fullBatch))
   289  
   290  			t.Errorf("sampling mismatch: dim = %v, burnin = %v, rate = %v, samples = %v", dim, burnin, rate, samples)
   291  		}
   292  	}
   293  }