gonum.org/v1/gonum@v0.14.0/stat/samplemv/metropolishastings.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package samplemv
     6  
     7  import (
     8  	"math"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"gonum.org/v1/gonum/mat"
    13  	"gonum.org/v1/gonum/stat/distmv"
    14  )
    15  
    16  var _ Sampler = MetropolisHastingser{}
    17  
    18  // MHProposal defines a proposal distribution for Metropolis Hastings.
    19  type MHProposal interface {
    20  	// ConditionalLogProb returns the probability of the first argument
    21  	// conditioned on being at the second argument.
    22  	//  p(x|y)
    23  	// ConditionalLogProb panics if the input slices are not the same length.
    24  	ConditionalLogProb(x, y []float64) (prob float64)
    25  
    26  	// ConditionalRand generates a new random location conditioned being at the
    27  	// location y. If the first argument is nil, a new slice is allocated and
    28  	// returned. Otherwise, the random location is stored in-place into the first
    29  	// argument, and ConditionalRand will panic if the input slice lengths differ.
    30  	ConditionalRand(x, y []float64) []float64
    31  }
    32  
    33  // MetropolisHastingser is a type for generating samples using the Metropolis Hastings
    34  // algorithm (http://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm),
    35  // with the given target and proposal distributions, starting at the location
    36  // specified by Initial. If src != nil, it will be used to generate random
    37  // numbers, otherwise rand.Float64 will be used.
    38  //
    39  // Metropolis-Hastings is a Markov-chain Monte Carlo algorithm that generates
    40  // samples according to the distribution specified by target using the Markov
    41  // chain implicitly defined by the proposal distribution. At each
    42  // iteration, a proposal point is generated randomly from the current location.
    43  // This proposal point is accepted with probability
    44  //
    45  //	p = min(1, (target(new) * proposal(current|new)) / (target(current) * proposal(new|current)))
    46  //
    47  // If the new location is accepted, it becomes the new current location.
    48  // If it is rejected, the current location remains. This is the sample stored in
    49  // batch, ignoring BurnIn and Rate (discussed below).
    50  //
    51  // The samples in Metropolis Hastings are correlated with one another through the
    52  // Markov chain. As a result, the initial value can have a significant influence
    53  // on the early samples, and so, typically, the first samples generated by the chain
    54  // are ignored. This is known as "burn-in", and the number of samples ignored
    55  // at the beginning is specified by BurnIn. The proper BurnIn value will depend
    56  // on the mixing time of the Markov chain defined by the target and proposal
    57  // distributions.
    58  //
    59  // Many choose to have a sampling "rate" where a number of samples
    60  // are ignored in between each kept sample. This helps decorrelate
    61  // the samples from one another, but also reduces the number of available samples.
    62  // This value is specified by Rate. If Rate is 0 it is defaulted to 1 (keep
    63  // every sample).
    64  //
    65  // The initial value is NOT changed during calls to Sample.
    66  type MetropolisHastingser struct {
    67  	Initial  []float64
    68  	Target   distmv.LogProber
    69  	Proposal MHProposal
    70  	Src      rand.Source
    71  
    72  	BurnIn int
    73  	Rate   int
    74  }
    75  
    76  // Sample generates rows(batch) samples using the Metropolis Hastings sample
    77  // generation method. The initial location is NOT updated during the call to Sample.
    78  //
    79  // The number of columns in batch must equal len(m.Initial), otherwise Sample
    80  // will panic.
    81  func (m MetropolisHastingser) Sample(batch *mat.Dense) {
    82  	burnIn := m.BurnIn
    83  	rate := m.Rate
    84  	if rate == 0 {
    85  		rate = 1
    86  	}
    87  	r, c := batch.Dims()
    88  	if len(m.Initial) != c {
    89  		panic("metropolishastings: length mismatch")
    90  	}
    91  
    92  	// Use the optimal size for the temporary memory to allow the fewest calls
    93  	// to MetropolisHastings. The case where tmp shadows samples must be
    94  	// aligned with the logic after burn-in so that tmp does not shadow samples
    95  	// during the rate portion.
    96  	tmp := batch
    97  	if rate > r {
    98  		tmp = mat.NewDense(rate, c, nil)
    99  	}
   100  	rTmp, _ := tmp.Dims()
   101  
   102  	// Perform burn-in.
   103  	remaining := burnIn
   104  	initial := make([]float64, c)
   105  	copy(initial, m.Initial)
   106  	for remaining != 0 {
   107  		newSamp := min(rTmp, remaining)
   108  		metropolisHastings(tmp.Slice(0, newSamp, 0, c).(*mat.Dense), initial, m.Target, m.Proposal, m.Src)
   109  		copy(initial, tmp.RawRowView(newSamp-1))
   110  		remaining -= newSamp
   111  	}
   112  
   113  	if rate == 1 {
   114  		metropolisHastings(batch, initial, m.Target, m.Proposal, m.Src)
   115  		return
   116  	}
   117  
   118  	if rTmp <= r {
   119  		tmp = mat.NewDense(rate, c, nil)
   120  	}
   121  
   122  	// Take a single sample from the chain.
   123  	metropolisHastings(batch.Slice(0, 1, 0, c).(*mat.Dense), initial, m.Target, m.Proposal, m.Src)
   124  
   125  	copy(initial, batch.RawRowView(0))
   126  	// For all of the other samples, first generate Rate samples and then actually
   127  	// accept the last one.
   128  	for i := 1; i < r; i++ {
   129  		metropolisHastings(tmp, initial, m.Target, m.Proposal, m.Src)
   130  		v := tmp.RawRowView(rate - 1)
   131  		batch.SetRow(i, v)
   132  		copy(initial, v)
   133  	}
   134  }
   135  
   136  func metropolisHastings(batch *mat.Dense, initial []float64, target distmv.LogProber, proposal MHProposal, src rand.Source) {
   137  	f64 := rand.Float64
   138  	if src != nil {
   139  		f64 = rand.New(src).Float64
   140  	}
   141  	if len(initial) == 0 {
   142  		panic("metropolishastings: zero length initial")
   143  	}
   144  	r, _ := batch.Dims()
   145  	current := make([]float64, len(initial))
   146  	copy(current, initial)
   147  	proposed := make([]float64, len(initial))
   148  	currentLogProb := target.LogProb(initial)
   149  	for i := 0; i < r; i++ {
   150  		proposal.ConditionalRand(proposed, current)
   151  		proposedLogProb := target.LogProb(proposed)
   152  		probTo := proposal.ConditionalLogProb(proposed, current)
   153  		probBack := proposal.ConditionalLogProb(current, proposed)
   154  
   155  		accept := math.Exp(proposedLogProb + probBack - probTo - currentLogProb)
   156  		if accept > f64() {
   157  			copy(current, proposed)
   158  			currentLogProb = proposedLogProb
   159  		}
   160  		batch.SetRow(i, current)
   161  	}
   162  }
   163  
   164  // ProposalNormal is a sampling distribution for Metropolis-Hastings. It has a
   165  // fixed covariance matrix and changes the mean based on the current sampling
   166  // location.
   167  type ProposalNormal struct {
   168  	normal *distmv.Normal
   169  }
   170  
   171  // NewProposalNormal constructs a new ProposalNormal for use as a proposal
   172  // distribution for Metropolis-Hastings. ProposalNormal is a multivariate normal
   173  // distribution (implemented by distmv.Normal) where the covariance matrix is fixed
   174  // and the mean of the distribution changes.
   175  //
   176  // NewProposalNormal returns {nil, false} if the covariance matrix is not positive-definite.
   177  func NewProposalNormal(sigma *mat.SymDense, src rand.Source) (*ProposalNormal, bool) {
   178  	mu := make([]float64, sigma.SymmetricDim())
   179  	normal, ok := distmv.NewNormal(mu, sigma, src)
   180  	if !ok {
   181  		return nil, false
   182  	}
   183  	p := &ProposalNormal{
   184  		normal: normal,
   185  	}
   186  	return p, true
   187  }
   188  
   189  // ConditionalLogProb returns the probability of the first argument conditioned on
   190  // being at the second argument.
   191  //
   192  //	p(x|y)
   193  //
   194  // ConditionalLogProb panics if the input slices are not the same length or
   195  // are not equal to the dimension of the covariance matrix.
   196  func (p *ProposalNormal) ConditionalLogProb(x, y []float64) (prob float64) {
   197  	// Either SetMean or LogProb will panic if the slice lengths are inaccurate.
   198  	p.normal.SetMean(y)
   199  	return p.normal.LogProb(x)
   200  }
   201  
   202  // ConditionalRand generates a new random location conditioned being at the
   203  // location y. If the first argument is nil, a new slice is allocated and
   204  // returned. Otherwise, the random location is stored in-place into the first
   205  // argument, and ConditionalRand will panic if the input slice lengths differ or
   206  // if they are not equal to the dimension of the covariance matrix.
   207  func (p *ProposalNormal) ConditionalRand(x, y []float64) []float64 {
   208  	if x == nil {
   209  		x = make([]float64, p.normal.Dim())
   210  	}
   211  	if len(x) != len(y) {
   212  		panic(errLengthMismatch)
   213  	}
   214  	p.normal.SetMean(y)
   215  	p.normal.Rand(x)
   216  	return x
   217  }