gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/stat/samplemv/samplemv.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package samplemv 6 7 import ( 8 "errors" 9 "math" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/mat" 14 "gonum.org/v1/gonum/stat/distmv" 15 ) 16 17 const errLengthMismatch = "samplemv: slice length mismatch" 18 19 var ( 20 _ Sampler = LatinHypercube{} 21 _ Sampler = (*Rejection)(nil) 22 _ Sampler = IID{} 23 24 _ WeightedSampler = SampleUniformWeighted{} 25 _ WeightedSampler = Importance{} 26 ) 27 28 // Sampler generates a batch of samples according to the rule specified by the 29 // implementing type. The number of samples generated is equal to rows(batch), 30 // and the samples are stored in-place into the input. 31 type Sampler interface { 32 Sample(batch *mat.Dense) 33 } 34 35 // WeightedSampler generates a batch of samples and their relative weights 36 // according to the rule specified by the implementing type. The number of samples 37 // generated is equal to rows(batch), and the samples and weights 38 // are stored in-place into the inputs. The length of weights must equal 39 // rows(batch), otherwise SampleWeighted will panic. 40 type WeightedSampler interface { 41 SampleWeighted(batch *mat.Dense, weights []float64) 42 } 43 44 // SampleUniformWeighted wraps a Sampler type to create a WeightedSampler where all 45 // weights are equal. 46 type SampleUniformWeighted struct { 47 Sampler 48 } 49 50 // SampleWeighted generates rows(batch) samples from the embedded Sampler type 51 // and sets all of the weights equal to 1. If rows(batch) and len(weights) 52 // of weights are not equal, SampleWeighted will panic. 53 func (w SampleUniformWeighted) SampleWeighted(batch *mat.Dense, weights []float64) { 54 r, _ := batch.Dims() 55 if r != len(weights) { 56 panic(errLengthMismatch) 57 } 58 w.Sample(batch) 59 for i := range weights { 60 weights[i] = 1 61 } 62 } 63 64 // LatinHypercube is a type for sampling using Latin hypercube sampling 65 // from the given distribution. If src is not nil, it will be used to generate 66 // random numbers, otherwise rand.Float64 will be used. 67 // 68 // Latin hypercube sampling divides the cumulative distribution function into equally 69 // spaced bins and guarantees that one sample is generated per bin. Within each bin, 70 // the location is randomly sampled. The distmv.NewUnitUniform function can be used 71 // for easy sampling from the unit hypercube. 72 type LatinHypercube struct { 73 Q distmv.Quantiler 74 Src rand.Source 75 } 76 77 // Sample generates rows(batch) samples using the LatinHypercube generation 78 // procedure. 79 func (l LatinHypercube) Sample(batch *mat.Dense) { 80 latinHypercube(batch, l.Q, l.Src) 81 } 82 83 func latinHypercube(batch *mat.Dense, q distmv.Quantiler, src rand.Source) { 84 r, c := batch.Dims() 85 var f64 func() float64 86 var perm func(int) []int 87 if src != nil { 88 r := rand.New(src) 89 f64 = r.Float64 90 perm = r.Perm 91 } else { 92 f64 = rand.Float64 93 perm = rand.Perm 94 } 95 r64 := float64(r) 96 for i := 0; i < c; i++ { 97 p := perm(r) 98 for j := 0; j < r; j++ { 99 v := f64()/r64 + float64(j)/r64 100 batch.Set(p[j], i, v) 101 } 102 } 103 p := make([]float64, c) 104 for i := 0; i < r; i++ { 105 copy(p, batch.RawRowView(i)) 106 q.Quantile(batch.RawRowView(i), p) 107 } 108 } 109 110 // Importance is a type for performing importance sampling using the given 111 // Target and Proposal distributions. 112 // 113 // Importance sampling is a variance reduction technique where samples are 114 // generated from a proposal distribution, q(x), instead of the target distribution 115 // p(x). This allows relatively unlikely samples in p(x) to be generated more frequently. 116 // 117 // The importance sampling weight at x is given by p(x)/q(x). To reduce variance, 118 // a good proposal distribution will bound this sampling weight. This implies the 119 // support of q(x) should be at least as broad as p(x), and q(x) should be "fatter tailed" 120 // than p(x). 121 type Importance struct { 122 Target distmv.LogProber 123 Proposal distmv.RandLogProber 124 } 125 126 // SampleWeighted generates rows(batch) samples using the Importance sampling 127 // generation procedure. 128 // 129 // The length of weights must equal the length of batch, otherwise Importance will panic. 130 func (l Importance) SampleWeighted(batch *mat.Dense, weights []float64) { 131 importance(batch, weights, l.Target, l.Proposal) 132 } 133 134 func importance(batch *mat.Dense, weights []float64, target distmv.LogProber, proposal distmv.RandLogProber) { 135 r, _ := batch.Dims() 136 if r != len(weights) { 137 panic(errLengthMismatch) 138 } 139 for i := 0; i < r; i++ { 140 v := batch.RawRowView(i) 141 proposal.Rand(v) 142 weights[i] = math.Exp(target.LogProb(v) - proposal.LogProb(v)) 143 } 144 } 145 146 // ErrRejection is returned when the constant in Rejection is not sufficiently high. 147 var ErrRejection = errors.New("rejection: acceptance ratio above 1") 148 149 // Rejection is a type for sampling using the rejection sampling algorithm. 150 // 151 // Rejection sampling generates points from the target distribution by using 152 // the proposal distribution. At each step of the algorithm, the proposed point 153 // is accepted with probability 154 // 155 // p = target(x) / (proposal(x) * c) 156 // 157 // where target(x) is the probability of the point according to the target distribution 158 // and proposal(x) is the probability according to the proposal distribution. 159 // The constant c must be chosen such that target(x) < proposal(x) * c for all x. 160 // The expected number of proposed samples is len(samples) * c. 161 // 162 // The number of proposed locations during sampling can be found with a call to 163 // Proposed. If there was an error during sampling, all elements of samples are 164 // set to NaN and the error can be accessed with the Err method. If src != nil, 165 // it will be used to generate random numbers, otherwise rand.Float64 will be used. 166 // 167 // Target may return the true (log of) the probability of the location, or it may return 168 // a value that is proportional to the probability (logprob + constant). This is 169 // useful for cases where the probability distribution is only known up to a normalization 170 // constant. 171 type Rejection struct { 172 C float64 173 Target distmv.LogProber 174 Proposal distmv.RandLogProber 175 Src rand.Source 176 177 err error 178 proposed int 179 } 180 181 // Err returns nil if the most recent call to sample was successful, and returns 182 // ErrRejection if it was not. 183 func (r *Rejection) Err() error { 184 return r.err 185 } 186 187 // Proposed returns the number of samples proposed during the most recent call to 188 // Sample. 189 func (r *Rejection) Proposed() int { 190 return r.proposed 191 } 192 193 // Sample generates rows(batch) using the Rejection sampling generation procedure. 194 // Rejection sampling may fail if the constant is insufficiently high, as described 195 // in the type comment for Rejection. If the generation fails, the samples 196 // are set to math.NaN(), and a call to Err will return a non-nil value. 197 func (r *Rejection) Sample(batch *mat.Dense) { 198 r.err = nil 199 r.proposed = 0 200 proposed, ok := rejection(batch, r.Target, r.Proposal, r.C, r.Src) 201 if !ok { 202 r.err = ErrRejection 203 } 204 r.proposed = proposed 205 } 206 207 func rejection(batch *mat.Dense, target distmv.LogProber, proposal distmv.RandLogProber, c float64, src rand.Source) (nProposed int, ok bool) { 208 if c < 1 { 209 panic("rejection: acceptance constant must be greater than 1") 210 } 211 f64 := rand.Float64 212 if src != nil { 213 f64 = rand.New(src).Float64 214 } 215 r, dim := batch.Dims() 216 v := make([]float64, dim) 217 var idx int 218 for { 219 nProposed++ 220 proposal.Rand(v) 221 qx := proposal.LogProb(v) 222 px := target.LogProb(v) 223 accept := math.Exp(px-qx) / c 224 if accept > 1 { 225 // Invalidate the whole result and return a failure. 226 for i := 0; i < r; i++ { 227 for j := 0; j < dim; j++ { 228 batch.Set(i, j, math.NaN()) 229 } 230 } 231 return nProposed, false 232 } 233 if accept > f64() { 234 batch.SetRow(idx, v) 235 idx++ 236 if idx == r { 237 break 238 } 239 } 240 } 241 return nProposed, true 242 } 243 244 // IID generates a set of independently and identically distributed samples from 245 // the input distribution. 246 type IID struct { 247 Dist distmv.Rander 248 } 249 250 // Sample generates a set of identically and independently distributed samples. 251 func (iid IID) Sample(batch *mat.Dense) { 252 r, _ := batch.Dims() 253 for i := 0; i < r; i++ { 254 iid.Dist.Rand(batch.RawRowView(i)) 255 } 256 }