gonum.org/v1/gonum@v0.14.0/stat/samplemv/samplemv.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package samplemv
     6  
     7  import (
     8  	"errors"
     9  	"math"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/mat"
    14  	"gonum.org/v1/gonum/stat/distmv"
    15  )
    16  
    17  const errLengthMismatch = "samplemv: slice length mismatch"
    18  
    19  var (
    20  	_ Sampler = LatinHypercube{}
    21  	_ Sampler = (*Rejection)(nil)
    22  	_ Sampler = IID{}
    23  
    24  	_ WeightedSampler = SampleUniformWeighted{}
    25  	_ WeightedSampler = Importance{}
    26  )
    27  
    28  func min(a, b int) int {
    29  	if a < b {
    30  		return a
    31  	}
    32  	return b
    33  }
    34  
    35  // Sampler generates a batch of samples according to the rule specified by the
    36  // implementing type. The number of samples generated is equal to rows(batch),
    37  // and the samples are stored in-place into the input.
    38  type Sampler interface {
    39  	Sample(batch *mat.Dense)
    40  }
    41  
    42  // WeightedSampler generates a batch of samples and their relative weights
    43  // according to the rule specified by the implementing type. The number of samples
    44  // generated is equal to rows(batch), and the samples and weights
    45  // are stored in-place into the inputs. The length of weights must equal
    46  // rows(batch), otherwise SampleWeighted will panic.
    47  type WeightedSampler interface {
    48  	SampleWeighted(batch *mat.Dense, weights []float64)
    49  }
    50  
    51  // SampleUniformWeighted wraps a Sampler type to create a WeightedSampler where all
    52  // weights are equal.
    53  type SampleUniformWeighted struct {
    54  	Sampler
    55  }
    56  
    57  // SampleWeighted generates rows(batch) samples from the embedded Sampler type
    58  // and sets all of the weights equal to 1. If rows(batch) and len(weights)
    59  // of weights are not equal, SampleWeighted will panic.
    60  func (w SampleUniformWeighted) SampleWeighted(batch *mat.Dense, weights []float64) {
    61  	r, _ := batch.Dims()
    62  	if r != len(weights) {
    63  		panic(errLengthMismatch)
    64  	}
    65  	w.Sample(batch)
    66  	for i := range weights {
    67  		weights[i] = 1
    68  	}
    69  }
    70  
    71  // LatinHypercube is a type for sampling using Latin hypercube sampling
    72  // from the given distribution. If src is not nil, it will be used to generate
    73  // random numbers, otherwise rand.Float64 will be used.
    74  //
    75  // Latin hypercube sampling divides the cumulative distribution function into equally
    76  // spaced bins and guarantees that one sample is generated per bin. Within each bin,
    77  // the location is randomly sampled. The distmv.NewUnitUniform function can be used
    78  // for easy sampling from the unit hypercube.
    79  type LatinHypercube struct {
    80  	Q   distmv.Quantiler
    81  	Src rand.Source
    82  }
    83  
    84  // Sample generates rows(batch) samples using the LatinHypercube generation
    85  // procedure.
    86  func (l LatinHypercube) Sample(batch *mat.Dense) {
    87  	latinHypercube(batch, l.Q, l.Src)
    88  }
    89  
    90  func latinHypercube(batch *mat.Dense, q distmv.Quantiler, src rand.Source) {
    91  	r, c := batch.Dims()
    92  	var f64 func() float64
    93  	var perm func(int) []int
    94  	if src != nil {
    95  		r := rand.New(src)
    96  		f64 = r.Float64
    97  		perm = r.Perm
    98  	} else {
    99  		f64 = rand.Float64
   100  		perm = rand.Perm
   101  	}
   102  	r64 := float64(r)
   103  	for i := 0; i < c; i++ {
   104  		p := perm(r)
   105  		for j := 0; j < r; j++ {
   106  			v := f64()/r64 + float64(j)/r64
   107  			batch.Set(p[j], i, v)
   108  		}
   109  	}
   110  	p := make([]float64, c)
   111  	for i := 0; i < r; i++ {
   112  		copy(p, batch.RawRowView(i))
   113  		q.Quantile(batch.RawRowView(i), p)
   114  	}
   115  }
   116  
   117  // Importance is a type for performing importance sampling using the given
   118  // Target and Proposal distributions.
   119  //
   120  // Importance sampling is a variance reduction technique where samples are
   121  // generated from a proposal distribution, q(x), instead of the target distribution
   122  // p(x). This allows relatively unlikely samples in p(x) to be generated more frequently.
   123  //
   124  // The importance sampling weight at x is given by p(x)/q(x). To reduce variance,
   125  // a good proposal distribution will bound this sampling weight. This implies the
   126  // support of q(x) should be at least as broad as p(x), and q(x) should be "fatter tailed"
   127  // than p(x).
   128  type Importance struct {
   129  	Target   distmv.LogProber
   130  	Proposal distmv.RandLogProber
   131  }
   132  
   133  // SampleWeighted generates rows(batch) samples using the Importance sampling
   134  // generation procedure.
   135  //
   136  // The length of weights must equal the length of batch, otherwise Importance will panic.
   137  func (l Importance) SampleWeighted(batch *mat.Dense, weights []float64) {
   138  	importance(batch, weights, l.Target, l.Proposal)
   139  }
   140  
   141  func importance(batch *mat.Dense, weights []float64, target distmv.LogProber, proposal distmv.RandLogProber) {
   142  	r, _ := batch.Dims()
   143  	if r != len(weights) {
   144  		panic(errLengthMismatch)
   145  	}
   146  	for i := 0; i < r; i++ {
   147  		v := batch.RawRowView(i)
   148  		proposal.Rand(v)
   149  		weights[i] = math.Exp(target.LogProb(v) - proposal.LogProb(v))
   150  	}
   151  }
   152  
   153  // ErrRejection is returned when the constant in Rejection is not sufficiently high.
   154  var ErrRejection = errors.New("rejection: acceptance ratio above 1")
   155  
   156  // Rejection is a type for sampling using the rejection sampling algorithm.
   157  //
   158  // Rejection sampling generates points from the target distribution by using
   159  // the proposal distribution. At each step of the algorithm, the proposed point
   160  // is accepted with probability
   161  //
   162  //	p = target(x) / (proposal(x) * c)
   163  //
   164  // where target(x) is the probability of the point according to the target distribution
   165  // and proposal(x) is the probability according to the proposal distribution.
   166  // The constant c must be chosen such that target(x) < proposal(x) * c for all x.
   167  // The expected number of proposed samples is len(samples) * c.
   168  //
   169  // The number of proposed locations during sampling can be found with a call to
   170  // Proposed. If there was an error during sampling, all elements of samples are
   171  // set to NaN and the error can be accessed with the Err method. If src != nil,
   172  // it will be used to generate random numbers, otherwise rand.Float64 will be used.
   173  //
   174  // Target may return the true (log of) the probability of the location, or it may return
   175  // a value that is proportional to the probability (logprob + constant). This is
   176  // useful for cases where the probability distribution is only known up to a normalization
   177  // constant.
   178  type Rejection struct {
   179  	C        float64
   180  	Target   distmv.LogProber
   181  	Proposal distmv.RandLogProber
   182  	Src      rand.Source
   183  
   184  	err      error
   185  	proposed int
   186  }
   187  
   188  // Err returns nil if the most recent call to sample was successful, and returns
   189  // ErrRejection if it was not.
   190  func (r *Rejection) Err() error {
   191  	return r.err
   192  }
   193  
   194  // Proposed returns the number of samples proposed during the most recent call to
   195  // Sample.
   196  func (r *Rejection) Proposed() int {
   197  	return r.proposed
   198  }
   199  
   200  // Sample generates rows(batch) using the Rejection sampling generation procedure.
   201  // Rejection sampling may fail if the constant is insufficiently high, as described
   202  // in the type comment for Rejection. If the generation fails, the samples
   203  // are set to math.NaN(), and a call to Err will return a non-nil value.
   204  func (r *Rejection) Sample(batch *mat.Dense) {
   205  	r.err = nil
   206  	r.proposed = 0
   207  	proposed, ok := rejection(batch, r.Target, r.Proposal, r.C, r.Src)
   208  	if !ok {
   209  		r.err = ErrRejection
   210  	}
   211  	r.proposed = proposed
   212  }
   213  
   214  func rejection(batch *mat.Dense, target distmv.LogProber, proposal distmv.RandLogProber, c float64, src rand.Source) (nProposed int, ok bool) {
   215  	if c < 1 {
   216  		panic("rejection: acceptance constant must be greater than 1")
   217  	}
   218  	f64 := rand.Float64
   219  	if src != nil {
   220  		f64 = rand.New(src).Float64
   221  	}
   222  	r, dim := batch.Dims()
   223  	v := make([]float64, dim)
   224  	var idx int
   225  	for {
   226  		nProposed++
   227  		proposal.Rand(v)
   228  		qx := proposal.LogProb(v)
   229  		px := target.LogProb(v)
   230  		accept := math.Exp(px-qx) / c
   231  		if accept > 1 {
   232  			// Invalidate the whole result and return a failure.
   233  			for i := 0; i < r; i++ {
   234  				for j := 0; j < dim; j++ {
   235  					batch.Set(i, j, math.NaN())
   236  				}
   237  			}
   238  			return nProposed, false
   239  		}
   240  		if accept > f64() {
   241  			batch.SetRow(idx, v)
   242  			idx++
   243  			if idx == r {
   244  				break
   245  			}
   246  		}
   247  	}
   248  	return nProposed, true
   249  }
   250  
   251  // IID generates a set of independently and identically distributed samples from
   252  // the input distribution.
   253  type IID struct {
   254  	Dist distmv.Rander
   255  }
   256  
   257  // Sample generates a set of identically and independently distributed samples.
   258  func (iid IID) Sample(batch *mat.Dense) {
   259  	r, _ := batch.Dims()
   260  	for i := 0; i < r; i++ {
   261  		iid.Dist.Rand(batch.RawRowView(i))
   262  	}
   263  }