github.com/gopherd/gonum@v0.0.4/stat/samplemv/metropolishastings.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package samplemv 6 7 import ( 8 "math" 9 10 "math/rand" 11 12 "github.com/gopherd/gonum/mat" 13 "github.com/gopherd/gonum/stat/distmv" 14 ) 15 16 var _ Sampler = MetropolisHastingser{} 17 18 // MHProposal defines a proposal distribution for Metropolis Hastings. 19 type MHProposal interface { 20 // ConditionalLogProb returns the probability of the first argument 21 // conditioned on being at the second argument. 22 // p(x|y) 23 // ConditionalLogProb panics if the input slices are not the same length. 24 ConditionalLogProb(x, y []float64) (prob float64) 25 26 // ConditionalRand generates a new random location conditioned being at the 27 // location y. If the first argument is nil, a new slice is allocated and 28 // returned. Otherwise, the random location is stored in-place into the first 29 // argument, and ConditionalRand will panic if the input slice lengths differ. 30 ConditionalRand(x, y []float64) []float64 31 } 32 33 // MetropolisHastingser is a type for generating samples using the Metropolis Hastings 34 // algorithm (http://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm), 35 // with the given target and proposal distributions, starting at the location 36 // specified by Initial. If src != nil, it will be used to generate random 37 // numbers, otherwise rand.Float64 will be used. 38 // 39 // Metropolis-Hastings is a Markov-chain Monte Carlo algorithm that generates 40 // samples according to the distribution specified by target using the Markov 41 // chain implicitly defined by the proposal distribution. At each 42 // iteration, a proposal point is generated randomly from the current location. 43 // This proposal point is accepted with probability 44 // p = min(1, (target(new) * proposal(current|new)) / (target(current) * proposal(new|current))) 45 // If the new location is accepted, it becomes the new current location. 46 // If it is rejected, the current location remains. This is the sample stored in 47 // batch, ignoring BurnIn and Rate (discussed below). 48 // 49 // The samples in Metropolis Hastings are correlated with one another through the 50 // Markov chain. As a result, the initial value can have a significant influence 51 // on the early samples, and so, typically, the first samples generated by the chain 52 // are ignored. This is known as "burn-in", and the number of samples ignored 53 // at the beginning is specified by BurnIn. The proper BurnIn value will depend 54 // on the mixing time of the Markov chain defined by the target and proposal 55 // distributions. 56 // 57 // Many choose to have a sampling "rate" where a number of samples 58 // are ignored in between each kept sample. This helps decorrelate 59 // the samples from one another, but also reduces the number of available samples. 60 // This value is specified by Rate. If Rate is 0 it is defaulted to 1 (keep 61 // every sample). 62 // 63 // The initial value is NOT changed during calls to Sample. 64 type MetropolisHastingser struct { 65 Initial []float64 66 Target distmv.LogProber 67 Proposal MHProposal 68 Src rand.Source 69 70 BurnIn int 71 Rate int 72 } 73 74 // Sample generates rows(batch) samples using the Metropolis Hastings sample 75 // generation method. The initial location is NOT updated during the call to Sample. 76 // 77 // The number of columns in batch must equal len(m.Initial), otherwise Sample 78 // will panic. 79 func (m MetropolisHastingser) Sample(batch *mat.Dense) { 80 burnIn := m.BurnIn 81 rate := m.Rate 82 if rate == 0 { 83 rate = 1 84 } 85 r, c := batch.Dims() 86 if len(m.Initial) != c { 87 panic("metropolishastings: length mismatch") 88 } 89 90 // Use the optimal size for the temporary memory to allow the fewest calls 91 // to MetropolisHastings. The case where tmp shadows samples must be 92 // aligned with the logic after burn-in so that tmp does not shadow samples 93 // during the rate portion. 94 tmp := batch 95 if rate > r { 96 tmp = mat.NewDense(rate, c, nil) 97 } 98 rTmp, _ := tmp.Dims() 99 100 // Perform burn-in. 101 remaining := burnIn 102 initial := make([]float64, c) 103 copy(initial, m.Initial) 104 for remaining != 0 { 105 newSamp := min(rTmp, remaining) 106 metropolisHastings(tmp.Slice(0, newSamp, 0, c).(*mat.Dense), initial, m.Target, m.Proposal, m.Src) 107 copy(initial, tmp.RawRowView(newSamp-1)) 108 remaining -= newSamp 109 } 110 111 if rate == 1 { 112 metropolisHastings(batch, initial, m.Target, m.Proposal, m.Src) 113 return 114 } 115 116 if rTmp <= r { 117 tmp = mat.NewDense(rate, c, nil) 118 } 119 120 // Take a single sample from the chain. 121 metropolisHastings(batch.Slice(0, 1, 0, c).(*mat.Dense), initial, m.Target, m.Proposal, m.Src) 122 123 copy(initial, batch.RawRowView(0)) 124 // For all of the other samples, first generate Rate samples and then actually 125 // accept the last one. 126 for i := 1; i < r; i++ { 127 metropolisHastings(tmp, initial, m.Target, m.Proposal, m.Src) 128 v := tmp.RawRowView(rate - 1) 129 batch.SetRow(i, v) 130 copy(initial, v) 131 } 132 } 133 134 func metropolisHastings(batch *mat.Dense, initial []float64, target distmv.LogProber, proposal MHProposal, src rand.Source) { 135 f64 := rand.Float64 136 if src != nil { 137 f64 = rand.New(src).Float64 138 } 139 if len(initial) == 0 { 140 panic("metropolishastings: zero length initial") 141 } 142 r, _ := batch.Dims() 143 current := make([]float64, len(initial)) 144 copy(current, initial) 145 proposed := make([]float64, len(initial)) 146 currentLogProb := target.LogProb(initial) 147 for i := 0; i < r; i++ { 148 proposal.ConditionalRand(proposed, current) 149 proposedLogProb := target.LogProb(proposed) 150 probTo := proposal.ConditionalLogProb(proposed, current) 151 probBack := proposal.ConditionalLogProb(current, proposed) 152 153 accept := math.Exp(proposedLogProb + probBack - probTo - currentLogProb) 154 if accept > f64() { 155 copy(current, proposed) 156 currentLogProb = proposedLogProb 157 } 158 batch.SetRow(i, current) 159 } 160 } 161 162 // ProposalNormal is a sampling distribution for Metropolis-Hastings. It has a 163 // fixed covariance matrix and changes the mean based on the current sampling 164 // location. 165 type ProposalNormal struct { 166 normal *distmv.Normal 167 } 168 169 // NewProposalNormal constructs a new ProposalNormal for use as a proposal 170 // distribution for Metropolis-Hastings. ProposalNormal is a multivariate normal 171 // distribution (implemented by distmv.Normal) where the covariance matrix is fixed 172 // and the mean of the distribution changes. 173 // 174 // NewProposalNormal returns {nil, false} if the covariance matrix is not positive-definite. 175 func NewProposalNormal(sigma *mat.SymDense, src rand.Source) (*ProposalNormal, bool) { 176 mu := make([]float64, sigma.SymmetricDim()) 177 normal, ok := distmv.NewNormal(mu, sigma, src) 178 if !ok { 179 return nil, false 180 } 181 p := &ProposalNormal{ 182 normal: normal, 183 } 184 return p, true 185 } 186 187 // ConditionalLogProb returns the probability of the first argument conditioned on 188 // being at the second argument. 189 // p(x|y) 190 // ConditionalLogProb panics if the input slices are not the same length or 191 // are not equal to the dimension of the covariance matrix. 192 func (p *ProposalNormal) ConditionalLogProb(x, y []float64) (prob float64) { 193 // Either SetMean or LogProb will panic if the slice lengths are innaccurate. 194 p.normal.SetMean(y) 195 return p.normal.LogProb(x) 196 } 197 198 // ConditionalRand generates a new random location conditioned being at the 199 // location y. If the first argument is nil, a new slice is allocated and 200 // returned. Otherwise, the random location is stored in-place into the first 201 // argument, and ConditionalRand will panic if the input slice lengths differ or 202 // if they are not equal to the dimension of the covariance matrix. 203 func (p *ProposalNormal) ConditionalRand(x, y []float64) []float64 { 204 if x == nil { 205 x = make([]float64, p.normal.Dim()) 206 } 207 if len(x) != len(y) { 208 panic(errLengthMismatch) 209 } 210 p.normal.SetMean(y) 211 p.normal.Rand(x) 212 return x 213 }