github.com/gopherd/gonum@v0.0.4/stat/samplemv/sample_test.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 package samplemv 5 6 import ( 7 "fmt" 8 "math" 9 "testing" 10 11 "math/rand" 12 13 "github.com/gopherd/gonum/floats" 14 "github.com/gopherd/gonum/mat" 15 "github.com/gopherd/gonum/spatial/r1" 16 "github.com/gopherd/gonum/stat" 17 "github.com/gopherd/gonum/stat/distmv" 18 ) 19 20 type lhDist interface { 21 Quantile(x, p []float64) []float64 22 CDF(p, x []float64) []float64 23 Dim() int 24 } 25 26 func TestLatinHypercube(t *testing.T) { 27 src := rand.New(rand.NewSource(1)) 28 for _, nSamples := range []int{1, 2, 5, 10, 20} { 29 for _, dist := range []lhDist{ 30 distmv.NewUniform([]r1.Interval{{Min: 0, Max: 3}}, src), 31 distmv.NewUniform([]r1.Interval{{Min: 0, Max: 3}, {Min: -1, Max: 5}, {Min: -4, Max: -1}}, src), 32 } { 33 dim := dist.Dim() 34 batch := mat.NewDense(nSamples, dim, nil) 35 LatinHypercube{Src: src, Q: dist}.Sample(batch) 36 // Latin hypercube should have one entry per hyperrow. 37 present := make([][]bool, nSamples) 38 for i := range present { 39 present[i] = make([]bool, dim) 40 } 41 cdf := make([]float64, dim) 42 for i := 0; i < nSamples; i++ { 43 dist.CDF(cdf, batch.RawRowView(i)) 44 for j := 0; j < dim; j++ { 45 p := cdf[j] 46 quadrant := int(math.Floor(p * float64(nSamples))) 47 present[quadrant][j] = true 48 } 49 } 50 allPresent := true 51 for i := 0; i < nSamples; i++ { 52 for j := 0; j < dim; j++ { 53 if !present[i][j] { 54 allPresent = false 55 } 56 } 57 } 58 if !allPresent { 59 t.Errorf("All quadrants not present") 60 } 61 } 62 } 63 } 64 65 func TestImportance(t *testing.T) { 66 src := rand.New(rand.NewSource(1)) 67 // Test by finding the expected value of a multi-variate normal. 68 dim := 3 69 target, ok := randomNormal(dim, src) 70 if !ok { 71 t.Fatal("bad test, sigma not pos def") 72 } 73 74 muImp := make([]float64, dim) 75 sigmaImp := mat.NewSymDense(dim, nil) 76 for i := 0; i < dim; i++ { 77 sigmaImp.SetSym(i, i, 3) 78 } 79 proposal, ok := distmv.NewNormal(muImp, sigmaImp, src) 80 if !ok { 81 t.Fatal("bad test, sigma not pos def") 82 } 83 84 nSamples := 200000 85 batch := mat.NewDense(nSamples, dim, nil) 86 weights := make([]float64, nSamples) 87 Importance{Target: target, Proposal: proposal}.SampleWeighted(batch, weights) 88 89 compareNormal(t, target, batch, weights, 5e-2, 5e-2) 90 } 91 92 func TestRejection(t *testing.T) { 93 src := rand.New(rand.NewSource(1)) 94 // Test by finding the expected value of a uniform. 95 dim := 3 96 bounds := make([]r1.Interval, dim) 97 for i := 0; i < dim; i++ { 98 min := src.NormFloat64() 99 max := src.NormFloat64() 100 if min > max { 101 min, max = max, min 102 } 103 bounds[i].Min = min 104 bounds[i].Max = max 105 } 106 target := distmv.NewUniform(bounds, src) 107 mu := target.Mean(nil) 108 109 muImp := make([]float64, dim) 110 sigmaImp := mat.NewSymDense(dim, nil) 111 for i := 0; i < dim; i++ { 112 sigmaImp.SetSym(i, i, 6) 113 } 114 proposal, ok := distmv.NewNormal(muImp, sigmaImp, src) 115 if !ok { 116 t.Fatal("bad test, sigma not pos def") 117 } 118 119 nSamples := 1000 120 batch := mat.NewDense(nSamples, dim, nil) 121 weights := make([]float64, nSamples) 122 rej := Rejection{Target: target, Proposal: proposal, C: 1000, Src: src} 123 rej.Sample(batch) 124 err := rej.Err() 125 if err != nil { 126 t.Error("Bad test, nan samples") 127 } 128 129 for i := 0; i < dim; i++ { 130 col := mat.Col(nil, i, batch) 131 ev := stat.Mean(col, weights) 132 if math.Abs(ev-mu[i]) > 1e-2 { 133 t.Errorf("Mean mismatch: Want %v, got %v", mu[i], ev) 134 } 135 } 136 } 137 138 func TestMetropolisHastings(t *testing.T) { 139 src := rand.New(rand.NewSource(1)) 140 // Test by finding the expected value of a normal distribution. 141 dim := 3 142 target, ok := randomNormal(dim, src) 143 if !ok { 144 t.Fatal("bad test, sigma not pos def") 145 } 146 147 sigmaImp := mat.NewSymDense(dim, nil) 148 for i := 0; i < dim; i++ { 149 sigmaImp.SetSym(i, i, 0.25) 150 } 151 proposal, ok := NewProposalNormal(sigmaImp, src) 152 if !ok { 153 t.Fatal("bad test, sigma not pos def") 154 } 155 156 nSamples := 100000 157 burnin := 5000 158 batch := mat.NewDense(nSamples, dim, nil) 159 initial := make([]float64, dim) 160 metropolisHastings(batch, initial, target, proposal, src) 161 batch = batch.Slice(burnin, nSamples, 0, dim).(*mat.Dense) 162 163 compareNormal(t, target, batch, nil, 5e-1, 5e-1) 164 } 165 166 // randomNormal constructs a random Normal distribution. 167 func randomNormal(dim int, src *rand.Rand) (*distmv.Normal, bool) { 168 data := make([]float64, dim*dim) 169 for i := range data { 170 data[i] = rand.Float64() 171 } 172 a := mat.NewDense(dim, dim, data) 173 var sigma mat.SymDense 174 sigma.SymOuterK(1, a) 175 mu := make([]float64, dim) 176 for i := range mu { 177 mu[i] = rand.NormFloat64() 178 } 179 return distmv.NewNormal(mu, &sigma, src) 180 } 181 182 func compareNormal(t *testing.T, want *distmv.Normal, batch *mat.Dense, weights []float64, meanTol, covTol float64) { 183 dim := want.Dim() 184 mu := want.Mean(nil) 185 var sigma mat.SymDense 186 want.CovarianceMatrix(&sigma) 187 n, _ := batch.Dims() 188 if weights == nil { 189 weights = make([]float64, n) 190 for i := range weights { 191 weights[i] = 1 192 } 193 } 194 for i := 0; i < dim; i++ { 195 col := mat.Col(nil, i, batch) 196 ev := stat.Mean(col, weights) 197 if math.Abs(ev-mu[i]) > meanTol { 198 t.Errorf("Mean mismatch: Want %v, got %v", mu[i], ev) 199 } 200 } 201 202 var cov mat.SymDense 203 stat.CovarianceMatrix(&cov, batch, weights) 204 if !mat.EqualApprox(&cov, &sigma, covTol) { 205 t.Errorf("Covariance matrix mismatch") 206 } 207 } 208 209 func TestMetropolisHastingser(t *testing.T) { 210 for _, test := range []struct { 211 dim, burnin, rate, samples int 212 }{ 213 {3, 10, 1, 1}, 214 {3, 10, 2, 1}, 215 {3, 10, 1, 2}, 216 {3, 10, 3, 2}, 217 {3, 10, 7, 4}, 218 {3, 10, 7, 4}, 219 220 {3, 11, 51, 103}, 221 {3, 11, 103, 51}, 222 {3, 51, 11, 103}, 223 {3, 51, 103, 11}, 224 {3, 103, 11, 51}, 225 {3, 103, 51, 11}, 226 } { 227 dim := test.dim 228 229 initial := make([]float64, dim) 230 target, ok := randomNormal(dim, nil) 231 if !ok { 232 t.Fatal("bad test, sigma not pos def") 233 } 234 235 sigmaImp := mat.NewSymDense(dim, nil) 236 for i := 0; i < dim; i++ { 237 sigmaImp.SetSym(i, i, 0.25) 238 } 239 240 // Test the Metropolis Hastingser by generating all the samples, then generating 241 // the same samples with a burnin and rate. 242 src := rand.New(rand.NewSource(1)) 243 proposal, ok := NewProposalNormal(sigmaImp, src) 244 if !ok { 245 t.Fatal("bad test, sigma not pos def") 246 } 247 248 mh := MetropolisHastingser{ 249 Initial: initial, 250 Target: target, 251 Proposal: proposal, 252 Src: src, 253 BurnIn: 0, 254 Rate: 0, 255 } 256 samples := test.samples 257 burnin := test.burnin 258 rate := test.rate 259 fullBatch := mat.NewDense(1+burnin+rate*(samples-1), dim, nil) 260 mh.Sample(fullBatch) 261 262 src = rand.New(rand.NewSource(1)) 263 proposal, _ = NewProposalNormal(sigmaImp, src) 264 mh = MetropolisHastingser{ 265 Initial: initial, 266 Target: target, 267 Proposal: proposal, 268 Src: src, 269 BurnIn: burnin, 270 Rate: rate, 271 } 272 batch := mat.NewDense(samples, dim, nil) 273 mh.Sample(batch) 274 275 same := true 276 count := burnin 277 for i := 0; i < samples; i++ { 278 if !floats.Equal(batch.RawRowView(i), fullBatch.RawRowView(count)) { 279 fmt.Println("sample ", i, "is different") 280 same = false 281 break 282 } 283 count += rate 284 } 285 286 if !same { 287 fmt.Printf("%v\n", mat.Formatted(batch)) 288 fmt.Printf("%v\n", mat.Formatted(fullBatch)) 289 290 t.Errorf("sampling mismatch: dim = %v, burnin = %v, rate = %v, samples = %v", dim, burnin, rate, samples) 291 } 292 } 293 }