github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/sampleuv/sample.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sampleuv
     6  
     7  import (
     8  	"errors"
     9  	"math"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"github.com/jingcheng-WU/gonum/stat/distuv"
    14  )
    15  
// badLengthMismatch is the panic message used when a batch and its weights
// slice have different lengths.
const badLengthMismatch = "sample: slice length mismatch"
    17  
// Compile-time checks that the concrete sampler types satisfy the
// Sampler and WeightedSampler interfaces.
var (
	_ Sampler = LatinHypercube{}
	_ Sampler = MetropolisHastings{}
	_ Sampler = (*Rejection)(nil)
	_ Sampler = IIDer{}

	_ WeightedSampler = SampleUniformWeighted{}
	_ WeightedSampler = Importance{}
)
    27  
    28  func min(a, b int) int {
    29  	if a < b {
    30  		return a
    31  	}
    32  	return b
    33  }
    34  
// Sampler generates a batch of samples according to the rule specified by the
// implementing type. The number of samples generated is equal to len(batch),
// and the samples are stored in-place into the input.
type Sampler interface {
	// Sample fills batch with len(batch) generated samples.
	Sample(batch []float64)
}
    41  
// WeightedSampler generates a batch of samples and their relative weights
// according to the rule specified by the implementing type. The number of samples
// generated is equal to len(batch), and the samples and weights
// are stored in-place into the inputs. The length of weights must equal
// len(batch), otherwise SampleWeighted will panic.
type WeightedSampler interface {
	// SampleWeighted fills batch with samples and weights with their
	// corresponding relative weights.
	SampleWeighted(batch, weights []float64)
}
    50  
// SampleUniformWeighted wraps a Sampler type to create a WeightedSampler where all
// weights are equal.
type SampleUniformWeighted struct {
	// Sampler generates the underlying (unweighted) samples.
	Sampler
}
    56  
    57  // SampleWeighted generates len(batch) samples from the embedded Sampler type
    58  // and sets all of the weights equal to 1. If len(batch) and len(weights)
    59  // are not equal, SampleWeighted will panic.
    60  func (w SampleUniformWeighted) SampleWeighted(batch, weights []float64) {
    61  	if len(batch) != len(weights) {
    62  		panic(badLengthMismatch)
    63  	}
    64  	w.Sample(batch)
    65  	for i := range weights {
    66  		weights[i] = 1
    67  	}
    68  }
    69  
// LatinHypercube is a type for sampling using Latin hypercube sampling
// from the given distribution. If src is not nil, it will be used to generate
// random numbers, otherwise rand.Float64 will be used.
//
// Latin hypercube sampling divides the cumulative distribution function into equally
// spaced bins and guarantees that one sample is generated per bin. Within each bin,
// the location is randomly sampled. The distuv.UnitUniform variable can be used
// for easy sampling from the unit hypercube.
type LatinHypercube struct {
	// Q maps uniform variates in (0,1) to sample locations via the
	// inverse CDF (quantile function).
	Q   distuv.Quantiler
	// Src is the source of randomness; if nil the global source is used.
	Src rand.Source
}
    82  
// Sample generates len(batch) samples using the LatinHypercube generation
// procedure, storing them in-place into batch.
func (l LatinHypercube) Sample(batch []float64) {
	latinHypercube(batch, l.Q, l.Src)
}
    88  
    89  func latinHypercube(batch []float64, q distuv.Quantiler, src rand.Source) {
    90  	n := len(batch)
    91  	var perm []int
    92  	var f64 func() float64
    93  	if src != nil {
    94  		r := rand.New(src)
    95  		f64 = r.Float64
    96  		perm = r.Perm(n)
    97  	} else {
    98  		f64 = rand.Float64
    99  		perm = rand.Perm(n)
   100  	}
   101  	for i := range batch {
   102  		v := f64()/float64(n) + float64(i)/float64(n)
   103  		batch[perm[i]] = q.Quantile(v)
   104  	}
   105  }
   106  
// Importance is a type for performing importance sampling using the given
// Target and Proposal distributions.
//
// Importance sampling is a variance reduction technique where samples are
// generated from a proposal distribution, q(x), instead of the target distribution
// p(x). This allows relatively unlikely samples in p(x) to be generated more frequently.
//
// The importance sampling weight at x is given by p(x)/q(x). To reduce variance,
// a good proposal distribution will bound this sampling weight. This implies the
// support of q(x) should be at least as broad as p(x), and q(x) should be "fatter tailed"
// than p(x).
type Importance struct {
	// Target is the nominal distribution p(x); only its log-probability
	// is required.
	Target   distuv.LogProber
	// Proposal is the distribution q(x) that samples are actually drawn from.
	Proposal distuv.RandLogProber
}
   122  
// SampleWeighted generates len(batch) samples using the Importance sampling
// generation procedure, storing the samples and their importance weights
// in-place.
//
// The length of weights must equal the length of batch, otherwise Importance will panic.
func (l Importance) SampleWeighted(batch, weights []float64) {
	importance(batch, weights, l.Target, l.Proposal)
}
   130  
   131  func importance(batch, weights []float64, target distuv.LogProber, proposal distuv.RandLogProber) {
   132  	if len(batch) != len(weights) {
   133  		panic(badLengthMismatch)
   134  	}
   135  	for i := range batch {
   136  		v := proposal.Rand()
   137  		batch[i] = v
   138  		weights[i] = math.Exp(target.LogProb(v) - proposal.LogProb(v))
   139  	}
   140  }
   141  
// ErrRejection is returned (via Rejection.Err) when the constant in Rejection
// is not sufficiently high, i.e. an acceptance ratio above 1 was observed.
var ErrRejection = errors.New("rejection: acceptance ratio above 1")
   144  
// Rejection is a type for sampling using the rejection sampling algorithm.
//
// Rejection sampling generates points from the target distribution by using
// the proposal distribution. At each step of the algorithm, the proposed point
// is accepted with probability
//  p = target(x) / (proposal(x) * c)
// where target(x) is the probability of the point according to the target distribution
// and proposal(x) is the probability according to the proposal distribution.
// The constant c must be chosen such that target(x) < proposal(x) * c for all x.
// The expected number of proposed samples is len(samples) * c.
//
// The number of proposed locations during sampling can be found with a call to
// Proposed. If there was an error during sampling, all elements of samples are
// set to NaN and the error can be accessed with the Err method. If src != nil,
// it will be used to generate random numbers, otherwise rand.Float64 will be used.
//
// Target may return the true (log of) the probability of the location, or it may return
// a value that is proportional to the probability (logprob + constant). This is
// useful for cases where the probability distribution is only known up to a normalization
// constant.
type Rejection struct {
	// C is the bounding constant c described above.
	C        float64
	// Target is the distribution to sample from, possibly unnormalized.
	Target   distuv.LogProber
	// Proposal is the distribution proposals are drawn from.
	Proposal distuv.RandLogProber
	// Src is the source of randomness; if nil the global source is used.
	Src      rand.Source

	err      error // outcome of the most recent call to Sample
	proposed int   // number of proposals made during the most recent call to Sample
}
   174  
// Err returns nil if the most recent call to Sample was successful, and returns
// ErrRejection if it was not.
func (r *Rejection) Err() error {
	return r.err
}
   180  
// Proposed returns the number of samples proposed during the most recent call to
// Sample (accepted and rejected alike).
func (r *Rejection) Proposed() int {
	return r.proposed
}
   186  
   187  // Sample generates len(batch) using the Rejection sampling generation procedure.
   188  // Rejection sampling may fail if the constant is insufficiently high, as described
   189  // in the type comment for Rejection. If the generation fails, the samples
   190  // are set to math.NaN(), and a call to Err will return a non-nil value.
   191  func (r *Rejection) Sample(batch []float64) {
   192  	r.err = nil
   193  	r.proposed = 0
   194  	proposed, ok := rejection(batch, r.Target, r.Proposal, r.C, r.Src)
   195  	if !ok {
   196  		r.err = ErrRejection
   197  	}
   198  	r.proposed = proposed
   199  }
   200  
   201  func rejection(batch []float64, target distuv.LogProber, proposal distuv.RandLogProber, c float64, src rand.Source) (nProposed int, ok bool) {
   202  	if c < 1 {
   203  		panic("rejection: acceptance constant must be greater than 1")
   204  	}
   205  	f64 := rand.Float64
   206  	if src != nil {
   207  		f64 = rand.New(src).Float64
   208  	}
   209  	var idx int
   210  	for {
   211  		nProposed++
   212  		v := proposal.Rand()
   213  		qx := proposal.LogProb(v)
   214  		px := target.LogProb(v)
   215  		accept := math.Exp(px-qx) / c
   216  		if accept > 1 {
   217  			// Invalidate the whole result and return a failure.
   218  			for i := range batch {
   219  				batch[i] = math.NaN()
   220  			}
   221  			return nProposed, false
   222  		}
   223  		if accept > f64() {
   224  			batch[idx] = v
   225  			idx++
   226  			if idx == len(batch) {
   227  				break
   228  			}
   229  		}
   230  	}
   231  	return nProposed, true
   232  }
   233  
// MHProposal defines a proposal distribution for Metropolis Hastings.
type MHProposal interface {
	// ConditionalLogProb returns the log of the probability of the first
	// argument conditioned on being at the second argument,
	//  ln(p(x|y))
	ConditionalLogProb(x, y float64) (prob float64)

	// ConditionalRand generates a new random location conditioned being at the
	// location y.
	ConditionalRand(y float64) (x float64)
}
   245  
// MetropolisHastings is a type for generating samples using the Metropolis Hastings
// algorithm (http://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm),
// with the given target and proposal distributions, starting at the location
// specified by Initial. If src != nil, it will be used to generate random
// numbers, otherwise rand.Float64 will be used.
//
// Metropolis-Hastings is a Markov-chain Monte Carlo algorithm that generates
// samples according to the distribution specified by target using the Markov
// chain implicitly defined by the proposal distribution. At each
// iteration, a proposal point is generated randomly from the current location.
// This proposal point is accepted with probability
//  p = min(1, (target(new) * proposal(current|new)) / (target(current) * proposal(new|current)))
// If the new location is accepted, it becomes the new current location.
// If it is rejected, the current location remains. This is the sample stored in
// batch, ignoring BurnIn and Rate (discussed below).
//
// The samples in Metropolis Hastings are correlated with one another through the
// Markov chain. As a result, the initial value can have a significant influence
// on the early samples, and so, typically, the first samples generated by the chain
// are ignored. This is known as "burn-in", and the number of samples ignored
// at the beginning is specified by BurnIn. The proper BurnIn value will depend
// on the mixing time of the Markov chain defined by the target and proposal
// distributions.
//
// Many choose to have a sampling "rate" where a number of samples
// are ignored in between each kept sample. This helps decorrelate
// the samples from one another, but also reduces the number of available samples.
// This value is specified by Rate. If Rate is 0 it is defaulted to 1 (keep
// every sample).
//
// The initial value is NOT changed during calls to Sample.
type MetropolisHastings struct {
	// Initial is the starting location of the Markov chain.
	Initial  float64
	// Target is the distribution to sample from, possibly unnormalized.
	Target   distuv.LogProber
	// Proposal defines the chain's transition distribution.
	Proposal MHProposal
	// Src is the source of randomness; if nil the global source is used.
	Src      rand.Source

	// BurnIn is the number of initial chain steps to discard.
	BurnIn int
	// Rate keeps one sample per Rate chain steps; 0 is treated as 1.
	Rate   int
}
   286  
   287  // Sample generates len(batch) samples using the Metropolis Hastings sample
   288  // generation method. The initial location is NOT updated during the call to Sample.
   289  func (m MetropolisHastings) Sample(batch []float64) {
   290  	burnIn := m.BurnIn
   291  	rate := m.Rate
   292  	if rate == 0 {
   293  		rate = 1
   294  	}
   295  
   296  	// Use the optimal size for the temporary memory to allow the fewest calls
   297  	// to MetropolisHastings. The case where tmp shadows samples must be
   298  	// aligned with the logic after burn-in so that tmp does not shadow samples
   299  	// during the rate portion.
   300  	tmp := batch
   301  	if rate > len(batch) {
   302  		tmp = make([]float64, rate)
   303  	}
   304  
   305  	// Perform burn-in.
   306  	remaining := burnIn
   307  	initial := m.Initial
   308  	for remaining != 0 {
   309  		newSamp := min(len(tmp), remaining)
   310  		metropolisHastings(tmp[newSamp:], initial, m.Target, m.Proposal, m.Src)
   311  		initial = tmp[newSamp-1]
   312  		remaining -= newSamp
   313  	}
   314  
   315  	if rate == 1 {
   316  		metropolisHastings(batch, initial, m.Target, m.Proposal, m.Src)
   317  		return
   318  	}
   319  
   320  	if len(tmp) <= len(batch) {
   321  		tmp = make([]float64, rate)
   322  	}
   323  
   324  	// Take a single sample from the chain
   325  	metropolisHastings(batch[0:1], initial, m.Target, m.Proposal, m.Src)
   326  	initial = batch[0]
   327  
   328  	// For all of the other samples, first generate Rate samples and then actually
   329  	// accept the last one.
   330  	for i := 1; i < len(batch); i++ {
   331  		metropolisHastings(tmp, initial, m.Target, m.Proposal, m.Src)
   332  		v := tmp[rate-1]
   333  		batch[i] = v
   334  		initial = v
   335  	}
   336  }
   337  
   338  func metropolisHastings(batch []float64, initial float64, target distuv.LogProber, proposal MHProposal, src rand.Source) {
   339  	f64 := rand.Float64
   340  	if src != nil {
   341  		f64 = rand.New(src).Float64
   342  	}
   343  	current := initial
   344  	currentLogProb := target.LogProb(initial)
   345  	for i := range batch {
   346  		proposed := proposal.ConditionalRand(current)
   347  		proposedLogProb := target.LogProb(proposed)
   348  		probTo := proposal.ConditionalLogProb(proposed, current)
   349  		probBack := proposal.ConditionalLogProb(current, proposed)
   350  
   351  		accept := math.Exp(proposedLogProb + probBack - probTo - currentLogProb)
   352  		if accept > f64() {
   353  			current = proposed
   354  			currentLogProb = proposedLogProb
   355  		}
   356  		batch[i] = current
   357  	}
   358  }
   359  
// IIDer generates a set of independently and identically distributed samples from
// the input distribution.
type IIDer struct {
	// Dist is the distribution each sample is drawn from.
	Dist distuv.Rander
}
   365  
   366  // Sample generates a set of identically and independently distributed samples.
   367  func (iid IIDer) Sample(batch []float64) {
   368  	for i := range batch {
   369  		batch[i] = iid.Dist.Rand()
   370  	}
   371  }