github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/sampleuv/sample.go

// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package sampleuv

import (
	"errors"
	"math"

	"golang.org/x/exp/rand"

	"github.com/jingcheng-WU/gonum/stat/distuv"
)

const badLengthMismatch = "sample: slice length mismatch"

var (
	_ Sampler = LatinHypercube{}
	_ Sampler = MetropolisHastings{}
	_ Sampler = (*Rejection)(nil)
	_ Sampler = IIDer{}

	_ WeightedSampler = SampleUniformWeighted{}
	_ WeightedSampler = Importance{}
)

func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// Sampler generates a batch of samples according to the rule specified by the
// implementing type. The number of samples generated is equal to len(batch),
// and the samples are stored in-place into the input.
type Sampler interface {
	Sample(batch []float64)
}

// WeightedSampler generates a batch of samples and their relative weights
// according to the rule specified by the implementing type. The number of
// samples generated is equal to len(batch), and the samples and weights are
// stored in-place into the inputs. The length of weights must equal
// len(batch), otherwise SampleWeighted will panic.
type WeightedSampler interface {
	SampleWeighted(batch, weights []float64)
}

// SampleUniformWeighted wraps a Sampler type to create a WeightedSampler where
// all weights are equal.
type SampleUniformWeighted struct {
	Sampler
}

// SampleWeighted generates len(batch) samples from the embedded Sampler type
// and sets all of the weights equal to 1. If len(batch) and len(weights)
// are not equal, SampleWeighted will panic.
func (w SampleUniformWeighted) SampleWeighted(batch, weights []float64) {
	if len(batch) != len(weights) {
		panic(badLengthMismatch)
	}
	w.Sample(batch)
	for i := range weights {
		weights[i] = 1
	}
}

// LatinHypercube is a type for sampling using Latin hypercube sampling
// from the given distribution. If src is not nil, it will be used to generate
// random numbers, otherwise rand.Float64 will be used.
//
// Latin hypercube sampling divides the cumulative distribution function into
// equally spaced bins and guarantees that one sample is generated per bin.
// Within each bin, the location is randomly sampled. The distuv.UnitUniform
// variable can be used for easy sampling from the unit hypercube.
type LatinHypercube struct {
	Q   distuv.Quantiler
	Src rand.Source
}

// Sample generates len(batch) samples using the LatinHypercube generation
// procedure.
func (l LatinHypercube) Sample(batch []float64) {
	latinHypercube(batch, l.Q, l.Src)
}

func latinHypercube(batch []float64, q distuv.Quantiler, src rand.Source) {
	n := len(batch)
	var perm []int
	var f64 func() float64
	if src != nil {
		r := rand.New(src)
		f64 = r.Float64
		perm = r.Perm(n)
	} else {
		f64 = rand.Float64
		perm = rand.Perm(n)
	}
	for i := range batch {
		// Draw uniformly within the i'th equal-probability bin of the
		// CDF and map the value back through the quantile function.
		v := f64()/float64(n) + float64(i)/float64(n)
		batch[perm[i]] = q.Quantile(v)
	}
}
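// exampleLatinHypercube is a minimal usage sketch, not part of the original
// API, showing how LatinHypercube might be used to draw stratified samples
// from a standard normal distribution. It assumes distuv.Normal satisfies
// distuv.Quantiler, as it does in gonum.
func exampleLatinHypercube() []float64 {
	l := LatinHypercube{
		Q:   distuv.Normal{Mu: 0, Sigma: 1},
		Src: rand.NewSource(1),
	}
	// Each of the 100 equal-probability bins of the normal CDF
	// contributes exactly one of the 100 samples.
	batch := make([]float64, 100)
	l.Sample(batch)
	return batch
}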
// Importance is a type for performing importance sampling using the given
// Target and Proposal distributions.
//
// Importance sampling is a variance reduction technique where samples are
// generated from a proposal distribution, q(x), instead of the target
// distribution p(x). This allows relatively unlikely samples in p(x) to be
// generated more frequently.
//
// The importance sampling weight at x is given by p(x)/q(x). To reduce
// variance, a good proposal distribution will bound this sampling weight.
// This implies the support of q(x) should be at least as broad as that of
// p(x), and q(x) should have fatter tails than p(x).
type Importance struct {
	Target   distuv.LogProber
	Proposal distuv.RandLogProber
}

// SampleWeighted generates len(batch) samples using the Importance sampling
// generation procedure.
//
// The length of weights must equal the length of batch, otherwise
// SampleWeighted will panic.
func (l Importance) SampleWeighted(batch, weights []float64) {
	importance(batch, weights, l.Target, l.Proposal)
}

func importance(batch, weights []float64, target distuv.LogProber, proposal distuv.RandLogProber) {
	if len(batch) != len(weights) {
		panic(badLengthMismatch)
	}
	for i := range batch {
		v := proposal.Rand()
		batch[i] = v
		weights[i] = math.Exp(target.LogProb(v) - proposal.LogProb(v))
	}
}
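// exampleImportance is a minimal usage sketch, not part of the original API,
// estimating the mean of a narrow normal target from samples drawn from a
// wider normal proposal. The weighted mean sum(w_i*x_i)/sum(w_i) is a
// self-normalized importance sampling estimate of the target expectation.
func exampleImportance() float64 {
	imp := Importance{
		Target:   distuv.Normal{Mu: 1, Sigma: 0.5},
		Proposal: distuv.Normal{Mu: 0, Sigma: 3, Src: rand.NewSource(2)}, // fatter tails than the target
	}
	batch := make([]float64, 1000)
	weights := make([]float64, 1000)
	imp.SampleWeighted(batch, weights)
	var num, den float64
	for i, x := range batch {
		num += weights[i] * x
		den += weights[i]
	}
	return num / den // approximately 1, the target mean
}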
// ErrRejection is returned when the constant in Rejection is not sufficiently
// high.
var ErrRejection = errors.New("rejection: acceptance ratio above 1")

// Rejection is a type for sampling using the rejection sampling algorithm.
//
// Rejection sampling generates points from the target distribution by using
// the proposal distribution. At each step of the algorithm, the proposed point
// is accepted with probability
//  p = target(x) / (proposal(x) * c)
// where target(x) is the probability of the point according to the target
// distribution and proposal(x) is the probability according to the proposal
// distribution. The constant c must be chosen such that
// target(x) < proposal(x) * c for all x. The expected number of proposed
// samples is len(samples) * c.
//
// The number of proposed locations during sampling can be found with a call
// to Proposed. If there was an error during sampling, all elements of samples
// are set to NaN and the error can be accessed with the Err method. If
// src != nil, it will be used to generate random numbers, otherwise
// rand.Float64 will be used.
//
// Target may return the true log of the probability of the location, or it
// may return a value that is proportional to the probability
// (logprob + constant). This is useful for cases where the probability
// distribution is only known up to a normalization constant.
type Rejection struct {
	C        float64
	Target   distuv.LogProber
	Proposal distuv.RandLogProber
	Src      rand.Source

	err      error
	proposed int
}

// Err returns nil if the most recent call to Sample was successful, and
// returns ErrRejection if it was not.
func (r *Rejection) Err() error {
	return r.err
}

// Proposed returns the number of samples proposed during the most recent call
// to Sample.
func (r *Rejection) Proposed() int {
	return r.proposed
}

// Sample generates len(batch) samples using the Rejection sampling generation
// procedure. Rejection sampling may fail if the constant is insufficiently
// high, as described in the type comment for Rejection. If the generation
// fails, the samples are set to math.NaN(), and a call to Err will return a
// non-nil value.
func (r *Rejection) Sample(batch []float64) {
	r.err = nil
	r.proposed = 0
	proposed, ok := rejection(batch, r.Target, r.Proposal, r.C, r.Src)
	if !ok {
		r.err = ErrRejection
	}
	r.proposed = proposed
}

func rejection(batch []float64, target distuv.LogProber, proposal distuv.RandLogProber, c float64, src rand.Source) (nProposed int, ok bool) {
	if c < 1 {
		panic("rejection: acceptance constant must be at least 1")
	}
	f64 := rand.Float64
	if src != nil {
		f64 = rand.New(src).Float64
	}
	var idx int
	for {
		nProposed++
		v := proposal.Rand()
		qx := proposal.LogProb(v)
		px := target.LogProb(v)
		accept := math.Exp(px-qx) / c
		if accept > 1 {
			// Invalidate the whole result and return a failure.
			for i := range batch {
				batch[i] = math.NaN()
			}
			return nProposed, false
		}
		if accept > f64() {
			batch[idx] = v
			idx++
			if idx == len(batch) {
				break
			}
		}
	}
	return nProposed, true
}
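// exampleRejection is a minimal usage sketch, not part of the original API,
// drawing from a standard normal target using a wider normal proposal. The
// constant C must satisfy target(x) < C*proposal(x) for all x; for this pair
// the density ratio peaks at 2 (at x = 0), so C = 2.5 is a loose,
// hand-picked bound.
func exampleRejection() ([]float64, error) {
	r := &Rejection{
		C:        2.5, // loose bound on target(x)/proposal(x)
		Target:   distuv.Normal{Mu: 0, Sigma: 1},
		Proposal: distuv.Normal{Mu: 0, Sigma: 2},
		Src:      rand.NewSource(1),
	}
	batch := make([]float64, 100)
	r.Sample(batch)
	if err := r.Err(); err != nil {
		return nil, err // C was too small; all samples are NaN.
	}
	return batch, nil
}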
// MHProposal defines a proposal distribution for Metropolis-Hastings.
type MHProposal interface {
	// ConditionalLogProb returns the log of the probability of the first
	// argument conditioned on being at the second argument,
	//  p(x|y)
	ConditionalLogProb(x, y float64) (prob float64)

	// ConditionalRand generates a new random location conditioned on being
	// at the location y.
	ConditionalRand(y float64) (x float64)
}

// MetropolisHastings is a type for generating samples using the
// Metropolis-Hastings algorithm
// (http://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm),
// with the given target and proposal distributions, starting at the location
// specified by Initial. If src != nil, it will be used to generate random
// numbers, otherwise rand.Float64 will be used.
//
// Metropolis-Hastings is a Markov-chain Monte Carlo algorithm that generates
// samples according to the distribution specified by target using the Markov
// chain implicitly defined by the proposal distribution. At each iteration, a
// proposal point is generated randomly from the current location. This
// proposal point is accepted with probability
//  p = min(1, (target(new) * proposal(current|new)) / (target(current) * proposal(new|current)))
// If the new location is accepted, it becomes the new current location. If it
// is rejected, the current location remains. This is the sample stored in
// batch, ignoring BurnIn and Rate (discussed below).
//
// The samples in Metropolis-Hastings are correlated with one another through
// the Markov chain. As a result, the initial value can have a significant
// influence on the early samples, and so, typically, the first samples
// generated by the chain are ignored. This is known as "burn-in", and the
// number of samples ignored at the beginning is specified by BurnIn. The
// proper BurnIn value will depend on the mixing time of the Markov chain
// defined by the target and proposal distributions.
//
// Many choose to have a sampling "rate" where a number of samples are ignored
// in between each kept sample. This helps decorrelate the samples from one
// another, but also reduces the number of available samples. This value is
// specified by Rate. If Rate is 0 it is defaulted to 1 (keep every sample).
//
// The initial value is NOT changed during calls to Sample.
type MetropolisHastings struct {
	Initial  float64
	Target   distuv.LogProber
	Proposal MHProposal
	Src      rand.Source

	BurnIn int
	Rate   int
}

// Sample generates len(batch) samples using the Metropolis-Hastings sample
// generation method. The initial location is NOT updated during the call to
// Sample.
func (m MetropolisHastings) Sample(batch []float64) {
	burnIn := m.BurnIn
	rate := m.Rate
	if rate == 0 {
		rate = 1
	}

	// Use the optimal size for the temporary memory to allow the fewest
	// calls to metropolisHastings. The case where tmp aliases batch must
	// be aligned with the logic after burn-in so that tmp does not alias
	// batch during the rate portion.
	tmp := batch
	if rate > len(batch) {
		tmp = make([]float64, rate)
	}

	// Perform burn-in.
	remaining := burnIn
	initial := m.Initial
	for remaining != 0 {
		newSamp := min(len(tmp), remaining)
		metropolisHastings(tmp[:newSamp], initial, m.Target, m.Proposal, m.Src)
		initial = tmp[newSamp-1]
		remaining -= newSamp
	}

	if rate == 1 {
		metropolisHastings(batch, initial, m.Target, m.Proposal, m.Src)
		return
	}

	// Ensure tmp does not alias batch during the rate portion.
	if len(tmp) <= len(batch) {
		tmp = make([]float64, rate)
	}

	// Take a single sample from the chain.
	metropolisHastings(batch[0:1], initial, m.Target, m.Proposal, m.Src)
	initial = batch[0]

	// For all of the other samples, first generate rate samples from the
	// chain and then keep only the last one.
	for i := 1; i < len(batch); i++ {
		metropolisHastings(tmp, initial, m.Target, m.Proposal, m.Src)
		v := tmp[rate-1]
		batch[i] = v
		initial = v
	}
}

func metropolisHastings(batch []float64, initial float64, target distuv.LogProber, proposal MHProposal, src rand.Source) {
	f64 := rand.Float64
	if src != nil {
		f64 = rand.New(src).Float64
	}
	current := initial
	currentLogProb := target.LogProb(initial)
	for i := range batch {
		proposed := proposal.ConditionalRand(current)
		proposedLogProb := target.LogProb(proposed)
		probTo := proposal.ConditionalLogProb(proposed, current)
		probBack := proposal.ConditionalLogProb(current, proposed)

		accept := math.Exp(proposedLogProb + probBack - probTo - currentLogProb)
		if accept > f64() {
			current = proposed
			currentLogProb = proposedLogProb
		}
		batch[i] = current
	}
}

// IIDer generates a set of independent and identically distributed samples
// from the input distribution.
type IIDer struct {
	Dist distuv.Rander
}

// Sample generates a set of independent and identically distributed samples.
func (iid IIDer) Sample(batch []float64) {
	for i := range batch {
		batch[i] = iid.Dist.Rand()
	}
}
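// normalProposal is an illustrative random-walk proposal, not part of the
// original API, for use with MetropolisHastings: the next point is drawn
// from a normal distribution centered at the current location. Because the
// proposal is symmetric, probTo and probBack cancel in the acceptance ratio.
type normalProposal struct {
	sigma float64
	src   rand.Source
}

func (p normalProposal) ConditionalLogProb(x, y float64) float64 {
	return distuv.Normal{Mu: y, Sigma: p.sigma}.LogProb(x)
}

func (p normalProposal) ConditionalRand(y float64) float64 {
	return distuv.Normal{Mu: y, Sigma: p.sigma, Src: p.src}.Rand()
}

// exampleMetropolisHastings is a minimal usage sketch, not part of the
// original API, sampling from a standard normal target with the random-walk
// proposal above, discarding 1000 burn-in samples and keeping every 5th
// sample thereafter to reduce autocorrelation.
func exampleMetropolisHastings() []float64 {
	mh := MetropolisHastings{
		Initial:  0,
		Target:   distuv.Normal{Mu: 0, Sigma: 1},
		Proposal: normalProposal{sigma: 0.5, src: rand.NewSource(1)},
		BurnIn:   1000,
		Rate:     5,
	}
	batch := make([]float64, 100)
	mh.Sample(batch)
	return batch
}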