github.com/gopherd/gonum@v0.0.4/stat/samplemv/samplemv.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package samplemv 6 7 import ( 8 "errors" 9 "math" 10 11 "math/rand" 12 13 "github.com/gopherd/gonum/mat" 14 "github.com/gopherd/gonum/stat/distmv" 15 ) 16 17 const errLengthMismatch = "samplemv: slice length mismatch" 18 19 var ( 20 _ Sampler = LatinHypercube{} 21 _ Sampler = (*Rejection)(nil) 22 _ Sampler = IID{} 23 24 _ WeightedSampler = SampleUniformWeighted{} 25 _ WeightedSampler = Importance{} 26 ) 27 28 func min(a, b int) int { 29 if a < b { 30 return a 31 } 32 return b 33 } 34 35 // Sampler generates a batch of samples according to the rule specified by the 36 // implementing type. The number of samples generated is equal to rows(batch), 37 // and the samples are stored in-place into the input. 38 type Sampler interface { 39 Sample(batch *mat.Dense) 40 } 41 42 // WeightedSampler generates a batch of samples and their relative weights 43 // according to the rule specified by the implementing type. The number of samples 44 // generated is equal to rows(batch), and the samples and weights 45 // are stored in-place into the inputs. The length of weights must equal 46 // rows(batch), otherwise SampleWeighted will panic. 47 type WeightedSampler interface { 48 SampleWeighted(batch *mat.Dense, weights []float64) 49 } 50 51 // SampleUniformWeighted wraps a Sampler type to create a WeightedSampler where all 52 // weights are equal. 53 type SampleUniformWeighted struct { 54 Sampler 55 } 56 57 // SampleWeighted generates rows(batch) samples from the embedded Sampler type 58 // and sets all of the weights equal to 1. If rows(batch) and len(weights) 59 // of weights are not equal, SampleWeighted will panic. 
func (w SampleUniformWeighted) SampleWeighted(batch *mat.Dense, weights []float64) {
	r, _ := batch.Dims()
	if r != len(weights) {
		panic(errLengthMismatch)
	}
	w.Sample(batch)
	// Every sample carries the same (unit) weight.
	for i := range weights {
		weights[i] = 1
	}
}

// LatinHypercube is a type for sampling using Latin hypercube sampling
// from the given distribution. If Src is not nil, it will be used to generate
// random numbers, otherwise the global math/rand functions (rand.Float64 and
// rand.Perm) will be used.
//
// Latin hypercube sampling divides the cumulative distribution function into equally
// spaced bins and guarantees that one sample is generated per bin. Within each bin,
// the location is randomly sampled. The distmv.NewUnitUniform function can be used
// for easy sampling from the unit hypercube.
type LatinHypercube struct {
	Q   distmv.Quantiler
	Src rand.Source
}

// Sample generates rows(batch) samples using the LatinHypercube generation
// procedure.
func (l LatinHypercube) Sample(batch *mat.Dense) {
	latinHypercube(batch, l.Q, l.Src)
}

// latinHypercube fills batch in place with a Latin hypercube sample mapped
// through q.
//
// For every column it splits [0, 1) into rows(batch) equal-width bins, draws
// one uniform point inside each bin, and assigns the bins to rows through a
// random permutation. Each row is then transformed from the unit hypercube
// with q.Quantile. Random numbers come from src when it is non-nil, and from
// the global math/rand functions otherwise.
func latinHypercube(batch *mat.Dense, q distmv.Quantiler, src rand.Source) {
	rows, cols := batch.Dims()
	f64 := rand.Float64
	perm := rand.Perm
	if src != nil {
		// Named rnd so it does not shadow the row count.
		rnd := rand.New(src)
		f64 = rnd.Float64
		perm = rnd.Perm
	}
	n := float64(rows)
	for col := 0; col < cols; col++ {
		order := perm(rows)
		for bin := 0; bin < rows; bin++ {
			// A uniform draw inside bin [bin/n, (bin+1)/n), placed in the row
			// that the permutation assigned to this bin.
			v := f64()/n + float64(bin)/n
			batch.Set(order[bin], col, v)
		}
	}
	// distmv.Quantiler stores its result in the first argument, so the unit
	// coordinates are copied out of the row before it is overwritten.
	unit := make([]float64, cols)
	for i := 0; i < rows; i++ {
		row := batch.RawRowView(i)
		copy(unit, row)
		q.Quantile(row, unit)
	}
}

// Importance is a type for performing importance sampling using the given
// Target and Proposal distributions.
//
// Importance sampling is a variance reduction technique where samples are
// generated from a proposal distribution, q(x), instead of the target distribution
// p(x).
// This allows relatively unlikely samples in p(x) to be generated more frequently.
//
// The importance sampling weight at x is given by p(x)/q(x). To reduce variance,
// a good proposal distribution will bound this sampling weight. This implies the
// support of q(x) should be at least as broad as p(x), and q(x) should be "fatter tailed"
// than p(x).
type Importance struct {
	Target   distmv.LogProber
	Proposal distmv.RandLogProber
}

// SampleWeighted generates rows(batch) samples using the Importance sampling
// generation procedure.
//
// The length of weights must equal rows(batch), otherwise SampleWeighted will panic.
func (imp Importance) SampleWeighted(batch *mat.Dense, weights []float64) {
	importance(batch, weights, imp.Target, imp.Proposal)
}

// importance fills each row of batch in place with a draw from proposal and
// stores in weights[i] the importance weight target(x)/proposal(x) of that
// row. It panics with errLengthMismatch if len(weights) != rows(batch).
func importance(batch *mat.Dense, weights []float64, target distmv.LogProber, proposal distmv.RandLogProber) {
	r, _ := batch.Dims()
	if r != len(weights) {
		panic(errLengthMismatch)
	}
	for i := range weights {
		x := batch.RawRowView(i)
		proposal.Rand(x)
		// The ratio is formed in log space and exponentiated once, so the
		// individual densities never need to be representable on their own.
		weights[i] = math.Exp(target.LogProb(x) - proposal.LogProb(x))
	}
}

// ErrRejection is the error reported by Rejection.Err when the constant C is
// not sufficiently high, i.e. when a proposed point had an acceptance ratio
// greater than 1.
var ErrRejection = errors.New("rejection: acceptance ratio above 1")

// Rejection is a type for sampling using the rejection sampling algorithm.
//
// Rejection sampling generates points from the target distribution by using
// the proposal distribution. At each step of the algorithm, the proposed point
// is accepted with probability
//
//	p = target(x) / (proposal(x) * c)
//
// where target(x) is the probability of the point according to the target distribution
// and proposal(x) is the probability according to the proposal distribution.
// The constant c must be chosen such that target(x) <= proposal(x) * c for all x.
// When Target is normalized, the expected number of proposed samples is
// rows(batch) * c.
//
// The number of proposed locations during sampling can be found with a call to
// Proposed. If there was an error during sampling, all elements of the batch are
// set to NaN and the error can be accessed with the Err method. If Src is not nil,
// it will be used to generate random numbers, otherwise rand.Float64 will be used.
//
// Target may return the true log-probability of the location, or it may return
// a value that is proportional to the probability (logprob + constant). This is
// useful for cases where the probability distribution is only known up to a normalization
// constant.
type Rejection struct {
	C        float64
	Target   distmv.LogProber
	Proposal distmv.RandLogProber
	Src      rand.Source

	// err and proposed record the outcome of the most recent call to Sample.
	err      error
	proposed int
}

// Err returns nil if the most recent call to Sample was successful, and returns
// ErrRejection if it was not.
func (r *Rejection) Err() error {
	return r.err
}

// Proposed returns the number of samples proposed during the most recent call to
// Sample.
func (r *Rejection) Proposed() int {
	return r.proposed
}

// Sample generates rows(batch) samples using the Rejection sampling generation procedure.
// Rejection sampling may fail if the constant is insufficiently high, as described
// in the type comment for Rejection. If the generation fails, the samples
// are set to math.NaN(), and a call to Err will return a non-nil value.
202 func (r *Rejection) Sample(batch *mat.Dense) { 203 r.err = nil 204 r.proposed = 0 205 proposed, ok := rejection(batch, r.Target, r.Proposal, r.C, r.Src) 206 if !ok { 207 r.err = ErrRejection 208 } 209 r.proposed = proposed 210 } 211 212 func rejection(batch *mat.Dense, target distmv.LogProber, proposal distmv.RandLogProber, c float64, src rand.Source) (nProposed int, ok bool) { 213 if c < 1 { 214 panic("rejection: acceptance constant must be greater than 1") 215 } 216 f64 := rand.Float64 217 if src != nil { 218 f64 = rand.New(src).Float64 219 } 220 r, dim := batch.Dims() 221 v := make([]float64, dim) 222 var idx int 223 for { 224 nProposed++ 225 proposal.Rand(v) 226 qx := proposal.LogProb(v) 227 px := target.LogProb(v) 228 accept := math.Exp(px-qx) / c 229 if accept > 1 { 230 // Invalidate the whole result and return a failure. 231 for i := 0; i < r; i++ { 232 for j := 0; j < dim; j++ { 233 batch.Set(i, j, math.NaN()) 234 } 235 } 236 return nProposed, false 237 } 238 if accept > f64() { 239 batch.SetRow(idx, v) 240 idx++ 241 if idx == r { 242 break 243 } 244 } 245 } 246 return nProposed, true 247 } 248 249 // IID generates a set of independently and identically distributed samples from 250 // the input distribution. 251 type IID struct { 252 Dist distmv.Rander 253 } 254 255 // Sample generates a set of identically and independently distributed samples. 256 func (iid IID) Sample(batch *mat.Dense) { 257 r, _ := batch.Dims() 258 for i := 0; i < r; i++ { 259 iid.Dist.Rand(batch.RawRowView(i)) 260 } 261 }