gonum.org/v1/gonum@v0.14.0/stat/samplemv/sample_test.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  package samplemv
     5  
     6  import (
     7  	"fmt"
     8  	"math"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/floats"
    14  	"gonum.org/v1/gonum/mat"
    15  	"gonum.org/v1/gonum/spatial/r1"
    16  	"gonum.org/v1/gonum/stat"
    17  	"gonum.org/v1/gonum/stat/distmv"
    18  )
    19  
    20  type lhDist interface {
    21  	Quantile(x, p []float64) []float64
    22  	CDF(p, x []float64) []float64
    23  	Dim() int
    24  }
    25  
    26  func TestLatinHypercube(t *testing.T) {
    27  	src := rand.New(rand.NewSource(1))
    28  	for _, nSamples := range []int{1, 2, 5, 10, 20} {
    29  		for _, dist := range []lhDist{
    30  			distmv.NewUniform([]r1.Interval{{Min: 0, Max: 3}}, src),
    31  			distmv.NewUniform([]r1.Interval{{Min: 0, Max: 3}, {Min: -1, Max: 5}, {Min: -4, Max: -1}}, src),
    32  		} {
    33  			dim := dist.Dim()
    34  			batch := mat.NewDense(nSamples, dim, nil)
    35  			LatinHypercube{Src: src, Q: dist}.Sample(batch)
    36  			// Latin hypercube should have one entry per hyperrow.
    37  			present := make([][]bool, nSamples)
    38  			for i := range present {
    39  				present[i] = make([]bool, dim)
    40  			}
    41  			cdf := make([]float64, dim)
    42  			for i := 0; i < nSamples; i++ {
    43  				dist.CDF(cdf, batch.RawRowView(i))
    44  				for j := 0; j < dim; j++ {
    45  					p := cdf[j]
    46  					quadrant := int(math.Floor(p * float64(nSamples)))
    47  					present[quadrant][j] = true
    48  				}
    49  			}
    50  			allPresent := true
    51  			for i := 0; i < nSamples; i++ {
    52  				for j := 0; j < dim; j++ {
    53  					if !present[i][j] {
    54  						allPresent = false
    55  					}
    56  				}
    57  			}
    58  			if !allPresent {
    59  				t.Errorf("All quadrants not present")
    60  			}
    61  		}
    62  	}
    63  }
    64  
    65  func TestImportance(t *testing.T) {
    66  	src := rand.New(rand.NewSource(1))
    67  	// Test by finding the expected value of a multi-variate normal.
    68  	dim := 3
    69  	target, ok := randomNormal(dim, src)
    70  	if !ok {
    71  		t.Fatal("bad test, sigma not pos def")
    72  	}
    73  
    74  	muImp := make([]float64, dim)
    75  	sigmaImp := mat.NewSymDense(dim, nil)
    76  	for i := 0; i < dim; i++ {
    77  		sigmaImp.SetSym(i, i, 3)
    78  	}
    79  	proposal, ok := distmv.NewNormal(muImp, sigmaImp, src)
    80  	if !ok {
    81  		t.Fatal("bad test, sigma not pos def")
    82  	}
    83  
    84  	nSamples := 200000
    85  	batch := mat.NewDense(nSamples, dim, nil)
    86  	weights := make([]float64, nSamples)
    87  	Importance{Target: target, Proposal: proposal}.SampleWeighted(batch, weights)
    88  
    89  	compareNormal(t, target, batch, weights, 5e-2, 5e-2)
    90  }
    91  
    92  func TestRejection(t *testing.T) {
    93  	src := rand.New(rand.NewSource(1))
    94  	// Test by finding the expected value of a uniform.
    95  	dim := 3
    96  	bounds := make([]r1.Interval, dim)
    97  	for i := 0; i < dim; i++ {
    98  		min := src.NormFloat64()
    99  		max := src.NormFloat64()
   100  		if min > max {
   101  			min, max = max, min
   102  		}
   103  		bounds[i].Min = min
   104  		bounds[i].Max = max
   105  	}
   106  	target := distmv.NewUniform(bounds, src)
   107  	mu := target.Mean(nil)
   108  
   109  	muImp := make([]float64, dim)
   110  	sigmaImp := mat.NewSymDense(dim, nil)
   111  	for i := 0; i < dim; i++ {
   112  		sigmaImp.SetSym(i, i, 6)
   113  	}
   114  	proposal, ok := distmv.NewNormal(muImp, sigmaImp, src)
   115  	if !ok {
   116  		t.Fatal("bad test, sigma not pos def")
   117  	}
   118  
   119  	nSamples := 1000
   120  	batch := mat.NewDense(nSamples, dim, nil)
   121  	weights := make([]float64, nSamples)
   122  	rej := Rejection{Target: target, Proposal: proposal, C: 1000, Src: src}
   123  	rej.Sample(batch)
   124  	err := rej.Err()
   125  	if err != nil {
   126  		t.Error("Bad test, nan samples")
   127  	}
   128  
   129  	for i := 0; i < dim; i++ {
   130  		col := mat.Col(nil, i, batch)
   131  		ev := stat.Mean(col, weights)
   132  		if math.Abs(ev-mu[i]) > 1e-2 {
   133  			t.Errorf("Mean mismatch: Want %v, got %v", mu[i], ev)
   134  		}
   135  	}
   136  }
   137  
   138  func TestMetropolisHastings(t *testing.T) {
   139  	src := rand.New(rand.NewSource(1))
   140  	// Test by finding the expected value of a normal distribution.
   141  	dim := 3
   142  	target, ok := randomNormal(dim, src)
   143  	if !ok {
   144  		t.Fatal("bad test, sigma not pos def")
   145  	}
   146  
   147  	sigmaImp := mat.NewSymDense(dim, nil)
   148  	for i := 0; i < dim; i++ {
   149  		sigmaImp.SetSym(i, i, 0.25)
   150  	}
   151  	proposal, ok := NewProposalNormal(sigmaImp, src)
   152  	if !ok {
   153  		t.Fatal("bad test, sigma not pos def")
   154  	}
   155  
   156  	nSamples := 100000
   157  	burnin := 5000
   158  	batch := mat.NewDense(nSamples, dim, nil)
   159  	initial := make([]float64, dim)
   160  	metropolisHastings(batch, initial, target, proposal, src)
   161  	batch = batch.Slice(burnin, nSamples, 0, dim).(*mat.Dense)
   162  
   163  	compareNormal(t, target, batch, nil, 5e-1, 5e-1)
   164  }
   165  
   166  // randomNormal constructs a random Normal distribution using the provided
   167  // random source.
   168  func randomNormal(dim int, src *rand.Rand) (*distmv.Normal, bool) {
   169  	data := make([]float64, dim*dim)
   170  	for i := range data {
   171  		data[i] = src.Float64()
   172  	}
   173  	a := mat.NewDense(dim, dim, data)
   174  	var sigma mat.SymDense
   175  	sigma.SymOuterK(1, a)
   176  	mu := make([]float64, dim)
   177  	for i := range mu {
   178  		mu[i] = rand.NormFloat64()
   179  	}
   180  	return distmv.NewNormal(mu, &sigma, src)
   181  }
   182  
   183  func compareNormal(t *testing.T, want *distmv.Normal, batch *mat.Dense, weights []float64, meanTol, covTol float64) {
   184  	t.Helper()
   185  
   186  	dim := want.Dim()
   187  	mu := want.Mean(nil)
   188  	var sigma mat.SymDense
   189  	want.CovarianceMatrix(&sigma)
   190  	n, _ := batch.Dims()
   191  	if weights == nil {
   192  		weights = make([]float64, n)
   193  		for i := range weights {
   194  			weights[i] = 1
   195  		}
   196  	}
   197  	for i := 0; i < dim; i++ {
   198  		col := mat.Col(nil, i, batch)
   199  		ev := stat.Mean(col, weights)
   200  		if math.Abs(ev-mu[i]) > meanTol {
   201  			t.Errorf("Mean mismatch: Want %v, got %v", mu[i], ev)
   202  		}
   203  	}
   204  
   205  	var cov mat.SymDense
   206  	stat.CovarianceMatrix(&cov, batch, weights)
   207  	if !mat.EqualApprox(&cov, &sigma, covTol) {
   208  		t.Errorf("Covariance matrix mismatch")
   209  	}
   210  }
   211  
   212  func TestMetropolisHastingser(t *testing.T) {
   213  	for _, test := range []struct {
   214  		dim, burnin, rate, samples int
   215  	}{
   216  		{3, 10, 1, 1},
   217  		{3, 10, 2, 1},
   218  		{3, 10, 1, 2},
   219  		{3, 10, 3, 2},
   220  		{3, 10, 7, 4},
   221  		{3, 10, 7, 4},
   222  
   223  		{3, 11, 51, 103},
   224  		{3, 11, 103, 51},
   225  		{3, 51, 11, 103},
   226  		{3, 51, 103, 11},
   227  		{3, 103, 11, 51},
   228  		{3, 103, 51, 11},
   229  	} {
   230  		src := rand.New(rand.NewSource(1))
   231  		dim := test.dim
   232  
   233  		initial := make([]float64, dim)
   234  		target, ok := randomNormal(dim, src)
   235  		if !ok {
   236  			t.Fatal("bad test, sigma not pos def")
   237  		}
   238  
   239  		sigmaImp := mat.NewSymDense(dim, nil)
   240  		for i := 0; i < dim; i++ {
   241  			sigmaImp.SetSym(i, i, 0.25)
   242  		}
   243  
   244  		// Test the Metropolis Hastingser by generating all the samples, then generating
   245  		// the same samples with a burnin and rate.
   246  		src = rand.New(rand.NewSource(1))
   247  		proposal, ok := NewProposalNormal(sigmaImp, src)
   248  		if !ok {
   249  			t.Fatal("bad test, sigma not pos def")
   250  		}
   251  
   252  		mh := MetropolisHastingser{
   253  			Initial:  initial,
   254  			Target:   target,
   255  			Proposal: proposal,
   256  			Src:      src,
   257  			BurnIn:   0,
   258  			Rate:     0,
   259  		}
   260  		samples := test.samples
   261  		burnin := test.burnin
   262  		rate := test.rate
   263  		fullBatch := mat.NewDense(1+burnin+rate*(samples-1), dim, nil)
   264  		mh.Sample(fullBatch)
   265  
   266  		src = rand.New(rand.NewSource(1))
   267  		proposal, _ = NewProposalNormal(sigmaImp, src)
   268  		mh = MetropolisHastingser{
   269  			Initial:  initial,
   270  			Target:   target,
   271  			Proposal: proposal,
   272  			Src:      src,
   273  			BurnIn:   burnin,
   274  			Rate:     rate,
   275  		}
   276  		batch := mat.NewDense(samples, dim, nil)
   277  		mh.Sample(batch)
   278  
   279  		same := true
   280  		count := burnin
   281  		for i := 0; i < samples; i++ {
   282  			if !floats.Equal(batch.RawRowView(i), fullBatch.RawRowView(count)) {
   283  				fmt.Println("sample ", i, "is different")
   284  				same = false
   285  				break
   286  			}
   287  			count += rate
   288  		}
   289  
   290  		if !same {
   291  			fmt.Printf("%v\n", mat.Formatted(batch))
   292  			fmt.Printf("%v\n", mat.Formatted(fullBatch))
   293  
   294  			t.Errorf("sampling mismatch: dim = %v, burnin = %v, rate = %v, samples = %v", dim, burnin, rate, samples)
   295  		}
   296  	}
   297  }