gonum.org/v1/gonum@v0.14.0/stat/samplemv/samplemv.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package samplemv 6 7 import ( 8 "errors" 9 "math" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/mat" 14 "gonum.org/v1/gonum/stat/distmv" 15 ) 16 17 const errLengthMismatch = "samplemv: slice length mismatch" 18 19 var ( 20 _ Sampler = LatinHypercube{} 21 _ Sampler = (*Rejection)(nil) 22 _ Sampler = IID{} 23 24 _ WeightedSampler = SampleUniformWeighted{} 25 _ WeightedSampler = Importance{} 26 ) 27 28 func min(a, b int) int { 29 if a < b { 30 return a 31 } 32 return b 33 } 34 35 // Sampler generates a batch of samples according to the rule specified by the 36 // implementing type. The number of samples generated is equal to rows(batch), 37 // and the samples are stored in-place into the input. 38 type Sampler interface { 39 Sample(batch *mat.Dense) 40 } 41 42 // WeightedSampler generates a batch of samples and their relative weights 43 // according to the rule specified by the implementing type. The number of samples 44 // generated is equal to rows(batch), and the samples and weights 45 // are stored in-place into the inputs. The length of weights must equal 46 // rows(batch), otherwise SampleWeighted will panic. 47 type WeightedSampler interface { 48 SampleWeighted(batch *mat.Dense, weights []float64) 49 } 50 51 // SampleUniformWeighted wraps a Sampler type to create a WeightedSampler where all 52 // weights are equal. 53 type SampleUniformWeighted struct { 54 Sampler 55 } 56 57 // SampleWeighted generates rows(batch) samples from the embedded Sampler type 58 // and sets all of the weights equal to 1. If rows(batch) and len(weights) 59 // of weights are not equal, SampleWeighted will panic. 60 func (w SampleUniformWeighted) SampleWeighted(batch *mat.Dense, weights []float64) { 61 r, _ := batch.Dims() 62 if r != len(weights) { 63 panic(errLengthMismatch) 64 } 65 w.Sample(batch) 66 for i := range weights { 67 weights[i] = 1 68 } 69 } 70 71 // LatinHypercube is a type for sampling using Latin hypercube sampling 72 // from the given distribution. If src is not nil, it will be used to generate 73 // random numbers, otherwise rand.Float64 will be used. 74 // 75 // Latin hypercube sampling divides the cumulative distribution function into equally 76 // spaced bins and guarantees that one sample is generated per bin. Within each bin, 77 // the location is randomly sampled. The distmv.NewUnitUniform function can be used 78 // for easy sampling from the unit hypercube. 79 type LatinHypercube struct { 80 Q distmv.Quantiler 81 Src rand.Source 82 } 83 84 // Sample generates rows(batch) samples using the LatinHypercube generation 85 // procedure. 86 func (l LatinHypercube) Sample(batch *mat.Dense) { 87 latinHypercube(batch, l.Q, l.Src) 88 } 89 90 func latinHypercube(batch *mat.Dense, q distmv.Quantiler, src rand.Source) { 91 r, c := batch.Dims() 92 var f64 func() float64 93 var perm func(int) []int 94 if src != nil { 95 r := rand.New(src) 96 f64 = r.Float64 97 perm = r.Perm 98 } else { 99 f64 = rand.Float64 100 perm = rand.Perm 101 } 102 r64 := float64(r) 103 for i := 0; i < c; i++ { 104 p := perm(r) 105 for j := 0; j < r; j++ { 106 v := f64()/r64 + float64(j)/r64 107 batch.Set(p[j], i, v) 108 } 109 } 110 p := make([]float64, c) 111 for i := 0; i < r; i++ { 112 copy(p, batch.RawRowView(i)) 113 q.Quantile(batch.RawRowView(i), p) 114 } 115 } 116 117 // Importance is a type for performing importance sampling using the given 118 // Target and Proposal distributions. 119 // 120 // Importance sampling is a variance reduction technique where samples are 121 // generated from a proposal distribution, q(x), instead of the target distribution 122 // p(x). This allows relatively unlikely samples in p(x) to be generated more frequently. 123 // 124 // The importance sampling weight at x is given by p(x)/q(x). To reduce variance, 125 // a good proposal distribution will bound this sampling weight. This implies the 126 // support of q(x) should be at least as broad as p(x), and q(x) should be "fatter tailed" 127 // than p(x). 128 type Importance struct { 129 Target distmv.LogProber 130 Proposal distmv.RandLogProber 131 } 132 133 // SampleWeighted generates rows(batch) samples using the Importance sampling 134 // generation procedure. 135 // 136 // The length of weights must equal the length of batch, otherwise Importance will panic. 137 func (l Importance) SampleWeighted(batch *mat.Dense, weights []float64) { 138 importance(batch, weights, l.Target, l.Proposal) 139 } 140 141 func importance(batch *mat.Dense, weights []float64, target distmv.LogProber, proposal distmv.RandLogProber) { 142 r, _ := batch.Dims() 143 if r != len(weights) { 144 panic(errLengthMismatch) 145 } 146 for i := 0; i < r; i++ { 147 v := batch.RawRowView(i) 148 proposal.Rand(v) 149 weights[i] = math.Exp(target.LogProb(v) - proposal.LogProb(v)) 150 } 151 } 152 153 // ErrRejection is returned when the constant in Rejection is not sufficiently high. 154 var ErrRejection = errors.New("rejection: acceptance ratio above 1") 155 156 // Rejection is a type for sampling using the rejection sampling algorithm. 157 // 158 // Rejection sampling generates points from the target distribution by using 159 // the proposal distribution. At each step of the algorithm, the proposed point 160 // is accepted with probability 161 // 162 // p = target(x) / (proposal(x) * c) 163 // 164 // where target(x) is the probability of the point according to the target distribution 165 // and proposal(x) is the probability according to the proposal distribution. 166 // The constant c must be chosen such that target(x) < proposal(x) * c for all x. 167 // The expected number of proposed samples is len(samples) * c. 168 // 169 // The number of proposed locations during sampling can be found with a call to 170 // Proposed. If there was an error during sampling, all elements of samples are 171 // set to NaN and the error can be accessed with the Err method. If src != nil, 172 // it will be used to generate random numbers, otherwise rand.Float64 will be used. 173 // 174 // Target may return the true (log of) the probability of the location, or it may return 175 // a value that is proportional to the probability (logprob + constant). This is 176 // useful for cases where the probability distribution is only known up to a normalization 177 // constant. 178 type Rejection struct { 179 C float64 180 Target distmv.LogProber 181 Proposal distmv.RandLogProber 182 Src rand.Source 183 184 err error 185 proposed int 186 } 187 188 // Err returns nil if the most recent call to sample was successful, and returns 189 // ErrRejection if it was not. 190 func (r *Rejection) Err() error { 191 return r.err 192 } 193 194 // Proposed returns the number of samples proposed during the most recent call to 195 // Sample. 196 func (r *Rejection) Proposed() int { 197 return r.proposed 198 } 199 200 // Sample generates rows(batch) using the Rejection sampling generation procedure. 201 // Rejection sampling may fail if the constant is insufficiently high, as described 202 // in the type comment for Rejection. If the generation fails, the samples 203 // are set to math.NaN(), and a call to Err will return a non-nil value. 204 func (r *Rejection) Sample(batch *mat.Dense) { 205 r.err = nil 206 r.proposed = 0 207 proposed, ok := rejection(batch, r.Target, r.Proposal, r.C, r.Src) 208 if !ok { 209 r.err = ErrRejection 210 } 211 r.proposed = proposed 212 } 213 214 func rejection(batch *mat.Dense, target distmv.LogProber, proposal distmv.RandLogProber, c float64, src rand.Source) (nProposed int, ok bool) { 215 if c < 1 { 216 panic("rejection: acceptance constant must be greater than 1") 217 } 218 f64 := rand.Float64 219 if src != nil { 220 f64 = rand.New(src).Float64 221 } 222 r, dim := batch.Dims() 223 v := make([]float64, dim) 224 var idx int 225 for { 226 nProposed++ 227 proposal.Rand(v) 228 qx := proposal.LogProb(v) 229 px := target.LogProb(v) 230 accept := math.Exp(px-qx) / c 231 if accept > 1 { 232 // Invalidate the whole result and return a failure. 233 for i := 0; i < r; i++ { 234 for j := 0; j < dim; j++ { 235 batch.Set(i, j, math.NaN()) 236 } 237 } 238 return nProposed, false 239 } 240 if accept > f64() { 241 batch.SetRow(idx, v) 242 idx++ 243 if idx == r { 244 break 245 } 246 } 247 } 248 return nProposed, true 249 } 250 251 // IID generates a set of independently and identically distributed samples from 252 // the input distribution. 253 type IID struct { 254 Dist distmv.Rander 255 } 256 257 // Sample generates a set of identically and independently distributed samples. 258 func (iid IID) Sample(batch *mat.Dense) { 259 r, _ := batch.Dims() 260 for i := 0; i < r; i++ { 261 iid.Dist.Rand(batch.RawRowView(i)) 262 } 263 }