gonum.org/v1/gonum@v0.14.0/stat/samplemv/metropolishastings.go

// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package samplemv

import (
	"math"

	"golang.org/x/exp/rand"

	"gonum.org/v1/gonum/mat"
	"gonum.org/v1/gonum/stat/distmv"
)

var _ Sampler = MetropolisHastingser{}

// MHProposal defines a proposal distribution for Metropolis Hastings.
type MHProposal interface {
	// ConditionalLogProb returns the log of the probability of the first
	// argument conditioned on being at the second argument,
	//	log(p(x|y))
	// ConditionalLogProb panics if the input slices are not the same length.
	ConditionalLogProb(x, y []float64) (prob float64)

	// ConditionalRand generates a new random location conditioned on being at
	// the location y. If the first argument is nil, a new slice is allocated
	// and returned. Otherwise, the random location is stored in-place into the
	// first argument, and ConditionalRand will panic if the input slice
	// lengths differ.
	ConditionalRand(x, y []float64) []float64
}

// MetropolisHastingser is a type for generating samples using the Metropolis Hastings
// algorithm (http://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm),
// with the given target and proposal distributions, starting at the location
// specified by Initial. If Src is not nil, it will be used to generate random
// numbers, otherwise rand.Float64 will be used.
//
// Metropolis-Hastings is a Markov-chain Monte Carlo algorithm that generates
// samples according to the distribution specified by Target using the Markov
// chain implicitly defined by the proposal distribution. At each iteration, a
// proposal point is generated randomly from the current location. This
// proposal point is accepted with probability
//
//	p = min(1, (target(new) * proposal(current|new)) / (target(current) * proposal(new|current)))
//
// If the new location is accepted, it becomes the new current location.
// If it is rejected, the current location remains. This is the sample stored in
// batch, ignoring BurnIn and Rate (discussed below).
//
// The samples in Metropolis Hastings are correlated with one another through the
// Markov chain. As a result, the initial value can have a significant influence
// on the early samples, and so, typically, the first samples generated by the chain
// are ignored. This is known as "burn-in", and the number of samples ignored
// at the beginning is specified by BurnIn. The proper BurnIn value will depend
// on the mixing time of the Markov chain defined by the target and proposal
// distributions.
//
// Many choose to have a sampling "rate" where a number of samples
// are ignored in between each kept sample. This helps decorrelate
// the samples from one another, but also reduces the number of available samples.
// This value is specified by Rate. If Rate is 0 it is defaulted to 1 (keep
// every sample).
//
// The initial value is NOT changed during calls to Sample.
type MetropolisHastingser struct {
	Initial  []float64
	Target   distmv.LogProber
	Proposal MHProposal
	Src      rand.Source

	BurnIn int
	Rate   int
}
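
// The sketch below is not part of the original gonum source; it is an
// illustrative example of how the fields above are typically wired together.
// The target distribution, covariance values, BurnIn, and Rate are assumptions
// chosen for demonstration, not recommendations.
//
// exampleMetropolisHastingser draws 1000 samples from a 2-dimensional standard
// normal target using ProposalNormal (defined later in this file) as the
// proposal distribution.
func exampleMetropolisHastingser() *mat.Dense {
	dim := 2
	// Target: a standard bivariate normal, which implements distmv.LogProber.
	target, ok := distmv.NewNormal(make([]float64, dim), mat.NewSymDense(dim, []float64{
		1, 0,
		0, 1,
	}), nil)
	if !ok {
		panic("target covariance is not positive-definite")
	}
	// Proposal: a random-walk normal step with a smaller, fixed covariance.
	proposal, ok := NewProposalNormal(mat.NewSymDense(dim, []float64{
		0.25, 0,
		0, 0.25,
	}), nil)
	if !ok {
		panic("proposal covariance is not positive-definite")
	}
	mh := MetropolisHastingser{
		Initial:  make([]float64, dim),
		Target:   target,
		Proposal: proposal,
		BurnIn:   1000, // discard the first 1000 chain steps
		Rate:     5,    // keep every 5th sample after burn-in
	}
	batch := mat.NewDense(1000, dim, nil)
	mh.Sample(batch)
	return batch
}
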
// Sample generates rows(batch) samples using the Metropolis Hastings sample
// generation method. The initial location is NOT updated during the call to Sample.
//
// The number of columns in batch must equal len(m.Initial), otherwise Sample
// will panic.
func (m MetropolisHastingser) Sample(batch *mat.Dense) {
	burnIn := m.BurnIn
	rate := m.Rate
	if rate == 0 {
		rate = 1
	}
	r, c := batch.Dims()
	if len(m.Initial) != c {
		panic("metropolishastings: length mismatch")
	}

	// Use the optimal size for the temporary memory to allow the fewest calls
	// to metropolisHastings. The case where tmp shadows batch must be
	// aligned with the logic after burn-in so that tmp does not shadow batch
	// during the rate portion.
	tmp := batch
	if rate > r {
		tmp = mat.NewDense(rate, c, nil)
	}
	rTmp, _ := tmp.Dims()

	// Perform burn-in.
	remaining := burnIn
	initial := make([]float64, c)
	copy(initial, m.Initial)
	for remaining != 0 {
		newSamp := min(rTmp, remaining)
		metropolisHastings(tmp.Slice(0, newSamp, 0, c).(*mat.Dense), initial, m.Target, m.Proposal, m.Src)
		copy(initial, tmp.RawRowView(newSamp-1))
		remaining -= newSamp
	}

	if rate == 1 {
		metropolisHastings(batch, initial, m.Target, m.Proposal, m.Src)
		return
	}

	if rTmp <= r {
		tmp = mat.NewDense(rate, c, nil)
	}

	// Take a single sample from the chain.
	metropolisHastings(batch.Slice(0, 1, 0, c).(*mat.Dense), initial, m.Target, m.Proposal, m.Src)

	copy(initial, batch.RawRowView(0))
	// For all of the other samples, first generate Rate samples and then actually
	// accept the last one.
	for i := 1; i < r; i++ {
		metropolisHastings(tmp, initial, m.Target, m.Proposal, m.Src)
		v := tmp.RawRowView(rate - 1)
		batch.SetRow(i, v)
		copy(initial, v)
	}
}

func metropolisHastings(batch *mat.Dense, initial []float64, target distmv.LogProber, proposal MHProposal, src rand.Source) {
	f64 := rand.Float64
	if src != nil {
		f64 = rand.New(src).Float64
	}
	if len(initial) == 0 {
		panic("metropolishastings: zero length initial")
	}
	r, _ := batch.Dims()
	current := make([]float64, len(initial))
	copy(current, initial)
	proposed := make([]float64, len(initial))
	currentLogProb := target.LogProb(initial)
	for i := 0; i < r; i++ {
		proposal.ConditionalRand(proposed, current)
		proposedLogProb := target.LogProb(proposed)
		probTo := proposal.ConditionalLogProb(proposed, current)
		probBack := proposal.ConditionalLogProb(current, proposed)

		accept := math.Exp(proposedLogProb + probBack - probTo - currentLogProb)
		if accept > f64() {
			copy(current, proposed)
			currentLogProb = proposedLogProb
		}
		batch.SetRow(i, current)
	}
}
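
// acceptanceProbability is not part of the original file; it is an
// illustrative restatement of the accept/reject rule used in
// metropolisHastings above, written out so the correspondence with the
// doc-comment formula
//
//	p = min(1, (target(new) * proposal(current|new)) / (target(current) * proposal(new|current)))
//
// is explicit. Working in log space avoids underflow of the raw densities,
// and the min with 1 is implicit in metropolisHastings because the result is
// compared against a uniform variate in [0, 1). For a symmetric proposal
// (such as ProposalNormal below), logQBack == logQTo and the rule reduces to
// the plain Metropolis ratio exp(logPNew - logPCurrent).
func acceptanceProbability(logPNew, logPCurrent, logQBack, logQTo float64) float64 {
	return math.Min(1, math.Exp(logPNew+logQBack-logQTo-logPCurrent))
}
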
// ProposalNormal is a sampling distribution for Metropolis-Hastings. It has a
// fixed covariance matrix and changes the mean based on the current sampling
// location.
type ProposalNormal struct {
	normal *distmv.Normal
}

// NewProposalNormal constructs a new ProposalNormal for use as a proposal
// distribution for Metropolis-Hastings. ProposalNormal is a multivariate normal
// distribution (implemented by distmv.Normal) where the covariance matrix is fixed
// and the mean of the distribution changes.
//
// NewProposalNormal returns {nil, false} if the covariance matrix is not positive-definite.
func NewProposalNormal(sigma *mat.SymDense, src rand.Source) (*ProposalNormal, bool) {
	mu := make([]float64, sigma.SymmetricDim())
	normal, ok := distmv.NewNormal(mu, sigma, src)
	if !ok {
		return nil, false
	}
	p := &ProposalNormal{
		normal: normal,
	}
	return p, true
}

// ConditionalLogProb returns the log of the probability of the first argument
// conditioned on being at the second argument,
//
//	log(p(x|y))
//
// ConditionalLogProb panics if the input slices are not the same length or
// are not equal to the dimension of the covariance matrix.
func (p *ProposalNormal) ConditionalLogProb(x, y []float64) (prob float64) {
	// Either SetMean or LogProb will panic if the slice lengths do not match
	// the dimension of the distribution.
	p.normal.SetMean(y)
	return p.normal.LogProb(x)
}

// ConditionalRand generates a new random location conditioned on being at the
// location y. If the first argument is nil, a new slice is allocated and
// returned. Otherwise, the random location is stored in-place into the first
// argument, and ConditionalRand will panic if the input slice lengths differ or
// if they are not equal to the dimension of the covariance matrix.
func (p *ProposalNormal) ConditionalRand(x, y []float64) []float64 {
	if x == nil {
		x = make([]float64, p.normal.Dim())
	}
	if len(x) != len(y) {
		panic(errLengthMismatch)
	}
	p.normal.SetMean(y)
	p.normal.Rand(x)
	return x
}
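
// exampleProposalNormalStep is not part of the original file; it sketches how
// ProposalNormal's two methods fit together for a single random-walk step from
// the location y: draw x ~ N(y, sigma) and evaluate the forward and reverse
// proposal log-densities that enter the Metropolis-Hastings acceptance ratio.
// Because the covariance is fixed and only the mean moves, the two
// log-densities are equal, so ProposalNormal is a symmetric proposal.
func exampleProposalNormalStep(p *ProposalNormal, y []float64) (x []float64, logQTo, logQBack float64) {
	x = p.ConditionalRand(nil, y)         // x ~ N(y, sigma); allocated because the first argument is nil
	logQTo = p.ConditionalLogProb(x, y)   // log q(x|y)
	logQBack = p.ConditionalLogProb(y, x) // log q(y|x)
	return x, logQTo, logQBack
}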