gonum.org/v1/gonum@v0.14.0/stat/samplemv/sample_test.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 package samplemv 5 6 import ( 7 "fmt" 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/floats" 14 "gonum.org/v1/gonum/mat" 15 "gonum.org/v1/gonum/spatial/r1" 16 "gonum.org/v1/gonum/stat" 17 "gonum.org/v1/gonum/stat/distmv" 18 ) 19 20 type lhDist interface { 21 Quantile(x, p []float64) []float64 22 CDF(p, x []float64) []float64 23 Dim() int 24 } 25 26 func TestLatinHypercube(t *testing.T) { 27 src := rand.New(rand.NewSource(1)) 28 for _, nSamples := range []int{1, 2, 5, 10, 20} { 29 for _, dist := range []lhDist{ 30 distmv.NewUniform([]r1.Interval{{Min: 0, Max: 3}}, src), 31 distmv.NewUniform([]r1.Interval{{Min: 0, Max: 3}, {Min: -1, Max: 5}, {Min: -4, Max: -1}}, src), 32 } { 33 dim := dist.Dim() 34 batch := mat.NewDense(nSamples, dim, nil) 35 LatinHypercube{Src: src, Q: dist}.Sample(batch) 36 // Latin hypercube should have one entry per hyperrow. 37 present := make([][]bool, nSamples) 38 for i := range present { 39 present[i] = make([]bool, dim) 40 } 41 cdf := make([]float64, dim) 42 for i := 0; i < nSamples; i++ { 43 dist.CDF(cdf, batch.RawRowView(i)) 44 for j := 0; j < dim; j++ { 45 p := cdf[j] 46 quadrant := int(math.Floor(p * float64(nSamples))) 47 present[quadrant][j] = true 48 } 49 } 50 allPresent := true 51 for i := 0; i < nSamples; i++ { 52 for j := 0; j < dim; j++ { 53 if !present[i][j] { 54 allPresent = false 55 } 56 } 57 } 58 if !allPresent { 59 t.Errorf("All quadrants not present") 60 } 61 } 62 } 63 } 64 65 func TestImportance(t *testing.T) { 66 src := rand.New(rand.NewSource(1)) 67 // Test by finding the expected value of a multi-variate normal. 68 dim := 3 69 target, ok := randomNormal(dim, src) 70 if !ok { 71 t.Fatal("bad test, sigma not pos def") 72 } 73 74 muImp := make([]float64, dim) 75 sigmaImp := mat.NewSymDense(dim, nil) 76 for i := 0; i < dim; i++ { 77 sigmaImp.SetSym(i, i, 3) 78 } 79 proposal, ok := distmv.NewNormal(muImp, sigmaImp, src) 80 if !ok { 81 t.Fatal("bad test, sigma not pos def") 82 } 83 84 nSamples := 200000 85 batch := mat.NewDense(nSamples, dim, nil) 86 weights := make([]float64, nSamples) 87 Importance{Target: target, Proposal: proposal}.SampleWeighted(batch, weights) 88 89 compareNormal(t, target, batch, weights, 5e-2, 5e-2) 90 } 91 92 func TestRejection(t *testing.T) { 93 src := rand.New(rand.NewSource(1)) 94 // Test by finding the expected value of a uniform. 95 dim := 3 96 bounds := make([]r1.Interval, dim) 97 for i := 0; i < dim; i++ { 98 min := src.NormFloat64() 99 max := src.NormFloat64() 100 if min > max { 101 min, max = max, min 102 } 103 bounds[i].Min = min 104 bounds[i].Max = max 105 } 106 target := distmv.NewUniform(bounds, src) 107 mu := target.Mean(nil) 108 109 muImp := make([]float64, dim) 110 sigmaImp := mat.NewSymDense(dim, nil) 111 for i := 0; i < dim; i++ { 112 sigmaImp.SetSym(i, i, 6) 113 } 114 proposal, ok := distmv.NewNormal(muImp, sigmaImp, src) 115 if !ok { 116 t.Fatal("bad test, sigma not pos def") 117 } 118 119 nSamples := 1000 120 batch := mat.NewDense(nSamples, dim, nil) 121 weights := make([]float64, nSamples) 122 rej := Rejection{Target: target, Proposal: proposal, C: 1000, Src: src} 123 rej.Sample(batch) 124 err := rej.Err() 125 if err != nil { 126 t.Error("Bad test, nan samples") 127 } 128 129 for i := 0; i < dim; i++ { 130 col := mat.Col(nil, i, batch) 131 ev := stat.Mean(col, weights) 132 if math.Abs(ev-mu[i]) > 1e-2 { 133 t.Errorf("Mean mismatch: Want %v, got %v", mu[i], ev) 134 } 135 } 136 } 137 138 func TestMetropolisHastings(t *testing.T) { 139 src := rand.New(rand.NewSource(1)) 140 // Test by finding the expected value of a normal distribution. 141 dim := 3 142 target, ok := randomNormal(dim, src) 143 if !ok { 144 t.Fatal("bad test, sigma not pos def") 145 } 146 147 sigmaImp := mat.NewSymDense(dim, nil) 148 for i := 0; i < dim; i++ { 149 sigmaImp.SetSym(i, i, 0.25) 150 } 151 proposal, ok := NewProposalNormal(sigmaImp, src) 152 if !ok { 153 t.Fatal("bad test, sigma not pos def") 154 } 155 156 nSamples := 100000 157 burnin := 5000 158 batch := mat.NewDense(nSamples, dim, nil) 159 initial := make([]float64, dim) 160 metropolisHastings(batch, initial, target, proposal, src) 161 batch = batch.Slice(burnin, nSamples, 0, dim).(*mat.Dense) 162 163 compareNormal(t, target, batch, nil, 5e-1, 5e-1) 164 } 165 166 // randomNormal constructs a random Normal distribution using the provided 167 // random source. 168 func randomNormal(dim int, src *rand.Rand) (*distmv.Normal, bool) { 169 data := make([]float64, dim*dim) 170 for i := range data { 171 data[i] = src.Float64() 172 } 173 a := mat.NewDense(dim, dim, data) 174 var sigma mat.SymDense 175 sigma.SymOuterK(1, a) 176 mu := make([]float64, dim) 177 for i := range mu { 178 mu[i] = rand.NormFloat64() 179 } 180 return distmv.NewNormal(mu, &sigma, src) 181 } 182 183 func compareNormal(t *testing.T, want *distmv.Normal, batch *mat.Dense, weights []float64, meanTol, covTol float64) { 184 t.Helper() 185 186 dim := want.Dim() 187 mu := want.Mean(nil) 188 var sigma mat.SymDense 189 want.CovarianceMatrix(&sigma) 190 n, _ := batch.Dims() 191 if weights == nil { 192 weights = make([]float64, n) 193 for i := range weights { 194 weights[i] = 1 195 } 196 } 197 for i := 0; i < dim; i++ { 198 col := mat.Col(nil, i, batch) 199 ev := stat.Mean(col, weights) 200 if math.Abs(ev-mu[i]) > meanTol { 201 t.Errorf("Mean mismatch: Want %v, got %v", mu[i], ev) 202 } 203 } 204 205 var cov mat.SymDense 206 stat.CovarianceMatrix(&cov, batch, weights) 207 if !mat.EqualApprox(&cov, &sigma, covTol) { 208 t.Errorf("Covariance matrix mismatch") 209 } 210 } 211 212 func TestMetropolisHastingser(t *testing.T) { 213 for _, test := range []struct { 214 dim, burnin, rate, samples int 215 }{ 216 {3, 10, 1, 1}, 217 {3, 10, 2, 1}, 218 {3, 10, 1, 2}, 219 {3, 10, 3, 2}, 220 {3, 10, 7, 4}, 221 {3, 10, 7, 4}, 222 223 {3, 11, 51, 103}, 224 {3, 11, 103, 51}, 225 {3, 51, 11, 103}, 226 {3, 51, 103, 11}, 227 {3, 103, 11, 51}, 228 {3, 103, 51, 11}, 229 } { 230 src := rand.New(rand.NewSource(1)) 231 dim := test.dim 232 233 initial := make([]float64, dim) 234 target, ok := randomNormal(dim, src) 235 if !ok { 236 t.Fatal("bad test, sigma not pos def") 237 } 238 239 sigmaImp := mat.NewSymDense(dim, nil) 240 for i := 0; i < dim; i++ { 241 sigmaImp.SetSym(i, i, 0.25) 242 } 243 244 // Test the Metropolis Hastingser by generating all the samples, then generating 245 // the same samples with a burnin and rate. 246 src = rand.New(rand.NewSource(1)) 247 proposal, ok := NewProposalNormal(sigmaImp, src) 248 if !ok { 249 t.Fatal("bad test, sigma not pos def") 250 } 251 252 mh := MetropolisHastingser{ 253 Initial: initial, 254 Target: target, 255 Proposal: proposal, 256 Src: src, 257 BurnIn: 0, 258 Rate: 0, 259 } 260 samples := test.samples 261 burnin := test.burnin 262 rate := test.rate 263 fullBatch := mat.NewDense(1+burnin+rate*(samples-1), dim, nil) 264 mh.Sample(fullBatch) 265 266 src = rand.New(rand.NewSource(1)) 267 proposal, _ = NewProposalNormal(sigmaImp, src) 268 mh = MetropolisHastingser{ 269 Initial: initial, 270 Target: target, 271 Proposal: proposal, 272 Src: src, 273 BurnIn: burnin, 274 Rate: rate, 275 } 276 batch := mat.NewDense(samples, dim, nil) 277 mh.Sample(batch) 278 279 same := true 280 count := burnin 281 for i := 0; i < samples; i++ { 282 if !floats.Equal(batch.RawRowView(i), fullBatch.RawRowView(count)) { 283 fmt.Println("sample ", i, "is different") 284 same = false 285 break 286 } 287 count += rate 288 } 289 290 if !same { 291 fmt.Printf("%v\n", mat.Formatted(batch)) 292 fmt.Printf("%v\n", mat.Formatted(fullBatch)) 293 294 t.Errorf("sampling mismatch: dim = %v, burnin = %v, rate = %v, samples = %v", dim, burnin, rate, samples) 295 } 296 } 297 }