github.com/gopherd/gonum@v0.0.4/stat/samplemv/metropolishastings.go

// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package samplemv

import (
	"math"
	"math/rand"

	"github.com/gopherd/gonum/mat"
	"github.com/gopherd/gonum/stat/distmv"
)

var _ Sampler = MetropolisHastingser{}

// MHProposal defines a proposal distribution for Metropolis-Hastings.
type MHProposal interface {
	// ConditionalLogProb returns the log of the probability density of the
	// first argument conditioned on being at the second argument.
	//  log(p(x|y))
	// ConditionalLogProb panics if the input slices are not the same length.
	ConditionalLogProb(x, y []float64) (prob float64)

	// ConditionalRand generates a new random location conditioned on being at
	// the location y. If the first argument is nil, a new slice is allocated
	// and returned. Otherwise, the random location is stored in-place into the
	// first argument, and ConditionalRand will panic if the input slice
	// lengths differ.
	ConditionalRand(x, y []float64) []float64
}
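
// The type below is an illustrative sketch and is not part of the original
// file; the name randomWalkProposal and its fields are hypothetical. It shows
// a minimal MHProposal implementation: an isotropic Gaussian random walk with
// a fixed per-coordinate step size. Because this proposal is symmetric,
// ConditionalLogProb(x, y) == ConditionalLogProb(y, x), so the proposal terms
// cancel in the Metropolis-Hastings acceptance ratio.
type randomWalkProposal struct {
	step float64    // standard deviation of each coordinate's step
	rnd  *rand.Rand // source of normal variates; must be non-nil before use
}

var _ MHProposal = randomWalkProposal{}

// ConditionalLogProb returns the log-density of x under a normal centered at y
// with isotropic variance step*step.
func (p randomWalkProposal) ConditionalLogProb(x, y []float64) float64 {
	if len(x) != len(y) {
		panic("randomWalkProposal: length mismatch")
	}
	var logProb float64
	for i := range x {
		d := (x[i] - y[i]) / p.step
		logProb += -0.5*d*d - math.Log(p.step) - 0.5*math.Log(2*math.Pi)
	}
	return logProb
}

// ConditionalRand perturbs y by independent normal steps, storing the result
// in x (allocating x when it is nil).
func (p randomWalkProposal) ConditionalRand(x, y []float64) []float64 {
	if x == nil {
		x = make([]float64, len(y))
	}
	if len(x) != len(y) {
		panic("randomWalkProposal: length mismatch")
	}
	for i := range x {
		x[i] = y[i] + p.step*p.rnd.NormFloat64()
	}
	return x
}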

// MetropolisHastingser is a type for generating samples using the
// Metropolis-Hastings algorithm
// (http://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm),
// with the given target and proposal distributions, starting at the location
// specified by Initial. If Src != nil, it will be used to generate random
// numbers, otherwise rand.Float64 will be used.
//
// Metropolis-Hastings is a Markov-chain Monte Carlo algorithm that generates
// samples according to the distribution specified by target using the Markov
// chain implicitly defined by the proposal distribution. At each
// iteration, a proposal point is generated randomly from the current location.
// This proposal point is accepted with probability
//  p = min(1, (target(new) * proposal(current|new)) / (target(current) * proposal(new|current)))
// If the new location is accepted, it becomes the new current location.
// If it is rejected, the current location remains. This is the sample stored in
// batch, ignoring BurnIn and Rate (discussed below).
//
// The samples in Metropolis-Hastings are correlated with one another through the
// Markov chain. As a result, the initial value can have a significant influence
// on the early samples, and so, typically, the first samples generated by the chain
// are ignored. This is known as "burn-in", and the number of samples ignored
// at the beginning is specified by BurnIn. The proper BurnIn value will depend
// on the mixing time of the Markov chain defined by the target and proposal
// distributions.
//
// Many choose to have a sampling "rate" where a number of samples
// are ignored in between each kept sample. This helps decorrelate
// the samples from one another, but also reduces the number of available samples.
// This value is specified by Rate. If Rate is 0, it defaults to 1 (keep
// every sample).
//
// The initial value is NOT changed during calls to Sample.
type MetropolisHastingser struct {
	Initial  []float64
	Target   distmv.LogProber
	Proposal MHProposal
	Src      rand.Source

	BurnIn int
	Rate   int
}
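
// The function below is an illustrative sketch, not part of the original API;
// its name and all parameter values are assumptions chosen for demonstration.
// It shows one way to configure and run MetropolisHastingser, using a
// distmv.Normal target and a ProposalNormal proposal.
func sketchSampleBivariateNormal(src rand.Source) *mat.Dense {
	// Target: a correlated bivariate normal density to sample from.
	target, ok := distmv.NewNormal(
		[]float64{0, 0},
		mat.NewSymDense(2, []float64{1, 0.8, 0.8, 1}),
		src,
	)
	if !ok {
		panic("sketch: target covariance is not positive-definite")
	}

	// Proposal: a Gaussian random walk with a fixed diagonal covariance.
	proposal, ok := NewProposalNormal(mat.NewSymDense(2, []float64{0.25, 0, 0, 0.25}), src)
	if !ok {
		panic("sketch: proposal covariance is not positive-definite")
	}

	mh := MetropolisHastingser{
		Initial:  []float64{0, 0},
		Target:   target,
		Proposal: proposal,
		Src:      src,
		BurnIn:   1000, // discard the first 1000, initialization-dependent steps
		Rate:     10,   // keep every 10th step to reduce correlation
	}

	// Each row of batch receives one kept sample.
	batch := mat.NewDense(500, 2, nil)
	mh.Sample(batch)
	return batch
}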

// Sample generates rows(batch) samples using the Metropolis-Hastings sample
// generation method. The initial location is NOT updated during the call to Sample.
//
// The number of columns in batch must equal len(m.Initial), otherwise Sample
// will panic.
func (m MetropolisHastingser) Sample(batch *mat.Dense) {
	burnIn := m.BurnIn
	rate := m.Rate
	if rate == 0 {
		rate = 1
	}
	r, c := batch.Dims()
	if len(m.Initial) != c {
		panic("metropolishastings: length mismatch")
	}

	// Size the temporary buffer as the larger of rows(batch) and rate so that
	// burn-in needs as few calls to metropolisHastings as possible. When
	// rate <= rows(batch), tmp aliases batch; the thinning loop below must
	// then replace tmp with its own allocation so that batch is not
	// overwritten.
	tmp := batch
	if rate > r {
		tmp = mat.NewDense(rate, c, nil)
	}
	rTmp, _ := tmp.Dims()

	// Perform burn-in.
	remaining := burnIn
	initial := make([]float64, c)
	copy(initial, m.Initial)
	for remaining != 0 {
		newSamp := min(rTmp, remaining)
		metropolisHastings(tmp.Slice(0, newSamp, 0, c).(*mat.Dense), initial, m.Target, m.Proposal, m.Src)
		copy(initial, tmp.RawRowView(newSamp-1))
		remaining -= newSamp
	}

	if rate == 1 {
		metropolisHastings(batch, initial, m.Target, m.Proposal, m.Src)
		return
	}

	if rTmp <= r {
		tmp = mat.NewDense(rate, c, nil)
	}

	// Take a single sample from the chain.
	metropolisHastings(batch.Slice(0, 1, 0, c).(*mat.Dense), initial, m.Target, m.Proposal, m.Src)

	copy(initial, batch.RawRowView(0))
	// For each remaining row, generate rate samples and keep only the last one.
	for i := 1; i < r; i++ {
		metropolisHastings(tmp, initial, m.Target, m.Proposal, m.Src)
		v := tmp.RawRowView(rate - 1)
		batch.SetRow(i, v)
		copy(initial, v)
	}
}
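
// chainSteps is an illustrative helper, not part of the original file. It
// mirrors the bookkeeping in Sample above: a call advances the underlying
// chain by burnIn steps, then by one step per requested row when rate <= 1,
// and otherwise by one step for the first row plus rate steps for each of the
// remaining rows.
func chainSteps(rows, burnIn, rate int) int {
	if rate <= 1 || rows == 0 {
		return burnIn + rows
	}
	return burnIn + 1 + (rows-1)*rate
}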

// metropolisHastings advances the chain from initial for rows(batch) steps,
// storing each successive location (whether newly accepted or retained) into
// the rows of batch.
func metropolisHastings(batch *mat.Dense, initial []float64, target distmv.LogProber, proposal MHProposal, src rand.Source) {
	f64 := rand.Float64
	if src != nil {
		f64 = rand.New(src).Float64
	}
	if len(initial) == 0 {
		panic("metropolishastings: zero length initial")
	}
	r, _ := batch.Dims()
	current := make([]float64, len(initial))
	copy(current, initial)
	proposed := make([]float64, len(initial))
	currentLogProb := target.LogProb(initial)
	for i := 0; i < r; i++ {
		proposal.ConditionalRand(proposed, current)
		proposedLogProb := target.LogProb(proposed)
		probTo := proposal.ConditionalLogProb(proposed, current)
		probBack := proposal.ConditionalLogProb(current, proposed)

		// Acceptance probability, computed in log space and exponentiated:
		// target(new)*proposal(current|new) / (target(current)*proposal(new|current)).
		accept := math.Exp(proposedLogProb + probBack - probTo - currentLogProb)
		if accept > f64() {
			copy(current, proposed)
			currentLogProb = proposedLogProb
		}
		batch.SetRow(i, current)
	}
}
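
// acceptanceProb is an illustrative sketch, not part of the original file. It
// spells out the acceptance probability that metropolisHastings computes
// above, in the same log-space form, and clips it to 1 as in the documented
// formula p = min(1, ...). For a symmetric proposal the two proposal terms
// cancel, leaving the plain Metropolis ratio of target densities.
func acceptanceProb(target distmv.LogProber, proposal MHProposal, current, proposed []float64) float64 {
	logRatio := target.LogProb(proposed) +
		proposal.ConditionalLogProb(current, proposed) - // log q(current|proposed)
		proposal.ConditionalLogProb(proposed, current) - // log q(proposed|current)
		target.LogProb(current)
	return math.Min(1, math.Exp(logRatio))
}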

// ProposalNormal is a sampling distribution for Metropolis-Hastings. It has a
// fixed covariance matrix and changes the mean based on the current sampling
// location.
type ProposalNormal struct {
	normal *distmv.Normal
}

// NewProposalNormal constructs a new ProposalNormal for use as a proposal
// distribution for Metropolis-Hastings. ProposalNormal is a multivariate normal
// distribution (implemented by distmv.Normal) where the covariance matrix is fixed
// and the mean of the distribution changes.
//
// NewProposalNormal returns {nil, false} if the covariance matrix is not positive-definite.
func NewProposalNormal(sigma *mat.SymDense, src rand.Source) (*ProposalNormal, bool) {
	mu := make([]float64, sigma.SymmetricDim())
	normal, ok := distmv.NewNormal(mu, sigma, src)
	if !ok {
		return nil, false
	}
	p := &ProposalNormal{
		normal: normal,
	}
	return p, true
}
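
// The function below is an illustrative sketch, not part of the original file;
// its name is hypothetical. It demonstrates the failure mode documented above:
// a covariance matrix that is not positive-definite (here a singular,
// rank-deficient matrix) is expected to make NewProposalNormal report
// ok == false, in which case the returned pointer must not be used.
func sketchNewProposalNormal() *ProposalNormal {
	// Singular 2x2 matrix: the second row duplicates the first.
	sigma := mat.NewSymDense(2, []float64{1, 1, 1, 1})
	if proposal, ok := NewProposalNormal(sigma, nil); ok {
		return proposal
	}
	// Fall back to a valid, strictly positive-definite (diagonal) covariance.
	proposal, _ := NewProposalNormal(mat.NewSymDense(2, []float64{1, 0, 0, 1}), nil)
	return proposal
}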

// ConditionalLogProb returns the log of the probability density of the first
// argument conditioned on being at the second argument:
//  log(p(x|y))
// ConditionalLogProb panics if the input slices are not the same length or if
// their length does not equal the dimension of the covariance matrix.
func (p *ProposalNormal) ConditionalLogProb(x, y []float64) (prob float64) {
	// Either SetMean or LogProb will panic if the slice lengths are incorrect.
	p.normal.SetMean(y)
	return p.normal.LogProb(x)
}

// ConditionalRand generates a new random location conditioned on being at the
// location y. If the first argument is nil, a new slice is allocated and
// returned. Otherwise, the random location is stored in-place into the first
// argument, and ConditionalRand will panic if the input slice lengths differ or
// if they do not equal the dimension of the covariance matrix.
func (p *ProposalNormal) ConditionalRand(x, y []float64) []float64 {
	if x == nil {
		x = make([]float64, p.normal.Dim())
	}
	if len(x) != len(y) {
		panic(errLengthMismatch)
	}
	p.normal.SetMean(y)
	p.normal.Rand(x)
	return x
}
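
// The function below is an illustrative sketch, not part of the original file;
// its name is hypothetical. It exercises both ProposalNormal methods:
// ConditionalRand draws a point from a normal centered at the current location
// y (allocating when the destination is nil, filling it in place otherwise),
// and ConditionalLogProb evaluates the log-density of that draw under the same
// conditional distribution, as needed for the acceptance ratio.
func sketchProposalNormal(p *ProposalNormal, y []float64) float64 {
	// nil destination: a new slice of length p.normal.Dim() is allocated.
	x := p.ConditionalRand(nil, y)

	// Reuse x as the destination for another draw; the lengths must match.
	p.ConditionalRand(x, y)

	// log q(x|y), the term that enters the Metropolis-Hastings ratio.
	return p.ConditionalLogProb(x, y)
}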