gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/stat/samplemv/samplemv.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package samplemv
     6  
     7  import (
     8  	"errors"
     9  	"math"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/mat"
    14  	"gonum.org/v1/gonum/stat/distmv"
    15  )
    16  
// errLengthMismatch is the panic message used when the number of rows in a
// batch does not match the length of the accompanying weights slice.
const errLengthMismatch = "samplemv: slice length mismatch"
    18  
// Compile-time checks that the sampler types in this file satisfy the
// interfaces they are meant to implement. Note that Rejection satisfies
// Sampler only through its pointer type.
var (
	_ Sampler = LatinHypercube{}
	_ Sampler = (*Rejection)(nil)
	_ Sampler = IID{}

	_ WeightedSampler = SampleUniformWeighted{}
	_ WeightedSampler = Importance{}
)
    27  
// Sampler generates a batch of samples according to the rule specified by the
// implementing type. The number of samples generated is equal to rows(batch),
// and the samples are stored in-place into the input, one sample per row.
type Sampler interface {
	// Sample overwrites the rows of batch with newly generated samples.
	Sample(batch *mat.Dense)
}
    34  
// WeightedSampler generates a batch of samples and their relative weights
// according to the rule specified by the implementing type. The number of samples
// generated is equal to rows(batch), and the samples and weights
// are stored in-place into the inputs. The length of weights must equal
// rows(batch), otherwise SampleWeighted will panic.
type WeightedSampler interface {
	// SampleWeighted overwrites the rows of batch with newly generated
	// samples and stores the weight of the i-th sample in weights[i].
	SampleWeighted(batch *mat.Dense, weights []float64)
}
    43  
// SampleUniformWeighted wraps a Sampler type to create a WeightedSampler where all
// weights are equal. The embedded Sampler produces the samples themselves.
type SampleUniformWeighted struct {
	Sampler
}
    49  
    50  // SampleWeighted generates rows(batch) samples from the embedded Sampler type
    51  // and sets all of the weights equal to 1. If rows(batch) and len(weights)
    52  // of weights are not equal, SampleWeighted will panic.
    53  func (w SampleUniformWeighted) SampleWeighted(batch *mat.Dense, weights []float64) {
    54  	r, _ := batch.Dims()
    55  	if r != len(weights) {
    56  		panic(errLengthMismatch)
    57  	}
    58  	w.Sample(batch)
    59  	for i := range weights {
    60  		weights[i] = 1
    61  	}
    62  }
    63  
// LatinHypercube is a type for sampling using Latin hypercube sampling
// from the distribution whose quantile function is given by Q. If Src is not
// nil, it will be used to generate random numbers, otherwise the global
// rand.Float64 and rand.Perm functions will be used.
//
// Latin hypercube sampling divides the cumulative distribution function into equally
// spaced bins and guarantees that one sample is generated per bin. Within each bin,
// the location is randomly sampled. The distmv.NewUnitUniform function can be used
// for easy sampling from the unit hypercube.
type LatinHypercube struct {
	Q   distmv.Quantiler
	Src rand.Source
}
    76  
// Sample generates rows(batch) samples using the LatinHypercube generation
// procedure, storing one sample in each row of batch. The quantile function Q
// and the random source Src of the receiver are used for the generation.
func (l LatinHypercube) Sample(batch *mat.Dense) {
	latinHypercube(batch, l.Q, l.Src)
}
    82  
    83  func latinHypercube(batch *mat.Dense, q distmv.Quantiler, src rand.Source) {
    84  	r, c := batch.Dims()
    85  	var f64 func() float64
    86  	var perm func(int) []int
    87  	if src != nil {
    88  		r := rand.New(src)
    89  		f64 = r.Float64
    90  		perm = r.Perm
    91  	} else {
    92  		f64 = rand.Float64
    93  		perm = rand.Perm
    94  	}
    95  	r64 := float64(r)
    96  	for i := 0; i < c; i++ {
    97  		p := perm(r)
    98  		for j := 0; j < r; j++ {
    99  			v := f64()/r64 + float64(j)/r64
   100  			batch.Set(p[j], i, v)
   101  		}
   102  	}
   103  	p := make([]float64, c)
   104  	for i := 0; i < r; i++ {
   105  		copy(p, batch.RawRowView(i))
   106  		q.Quantile(batch.RawRowView(i), p)
   107  	}
   108  }
   109  
// Importance is a type for performing importance sampling using the given
// Target and Proposal distributions. Samples are drawn from Proposal and
// weighted using the log-probabilities reported by Target and Proposal.
//
// Importance sampling is a variance reduction technique where samples are
// generated from a proposal distribution, q(x), instead of the target distribution
// p(x). This allows relatively unlikely samples in p(x) to be generated more frequently.
//
// The importance sampling weight at x is given by p(x)/q(x). To reduce variance,
// a good proposal distribution will bound this sampling weight. This implies the
// support of q(x) should be at least as broad as p(x), and q(x) should be "fatter tailed"
// than p(x).
type Importance struct {
	Target   distmv.LogProber
	Proposal distmv.RandLogProber
}
   125  
// SampleWeighted generates rows(batch) samples using the Importance sampling
// generation procedure, storing the samples in the rows of batch and the
// corresponding importance weights in weights.
//
// The length of weights must equal rows(batch), otherwise SampleWeighted will panic.
func (l Importance) SampleWeighted(batch *mat.Dense, weights []float64) {
	importance(batch, weights, l.Target, l.Proposal)
}
   133  
   134  func importance(batch *mat.Dense, weights []float64, target distmv.LogProber, proposal distmv.RandLogProber) {
   135  	r, _ := batch.Dims()
   136  	if r != len(weights) {
   137  		panic(errLengthMismatch)
   138  	}
   139  	for i := 0; i < r; i++ {
   140  		v := batch.RawRowView(i)
   141  		proposal.Rand(v)
   142  		weights[i] = math.Exp(target.LogProb(v) - proposal.LogProb(v))
   143  	}
   144  }
   145  
// ErrRejection is the error reported by Rejection.Err when the constant in
// Rejection is not sufficiently high, that is, when a proposed point was found
// to have an acceptance ratio above 1.
var ErrRejection = errors.New("rejection: acceptance ratio above 1")
   148  
// Rejection is a type for sampling using the rejection sampling algorithm.
//
// Rejection sampling generates points from the target distribution by using
// the proposal distribution. At each step of the algorithm, the proposed point
// is accepted with probability
//
//	p = target(x) / (proposal(x) * c)
//
// where target(x) is the probability of the point according to the target distribution
// and proposal(x) is the probability according to the proposal distribution.
// The constant c, given by the field C, must be chosen such that
// target(x) < proposal(x) * c for all x.
// The expected number of proposed samples is rows(batch) * c.
//
// The number of proposed locations during sampling can be found with a call to
// Proposed. If there was an error during sampling, all elements of batch are
// set to NaN and the error can be accessed with the Err method. If Src != nil,
// it will be used to generate random numbers, otherwise rand.Float64 will be used.
//
// Target may return the true (log of) the probability of the location, or it may return
// a value that is proportional to the probability (logprob + constant). This is
// useful for cases where the probability distribution is only known up to a normalization
// constant.
type Rejection struct {
	C        float64
	Target   distmv.LogProber
	Proposal distmv.RandLogProber
	Src      rand.Source

	// Outcome of the most recent call to Sample, exposed through the Err
	// and Proposed methods.
	err      error
	proposed int
}
   180  
// Err returns nil if the most recent call to Sample was successful, and returns
// ErrRejection if it was not. It also returns nil if Sample has not been called.
func (r *Rejection) Err() error {
	return r.err
}
   186  
// Proposed returns the number of samples proposed during the most recent call to
// Sample. This counts accepted and rejected proposals alike.
func (r *Rejection) Proposed() int {
	return r.proposed
}
   192  
   193  // Sample generates rows(batch) using the Rejection sampling generation procedure.
   194  // Rejection sampling may fail if the constant is insufficiently high, as described
   195  // in the type comment for Rejection. If the generation fails, the samples
   196  // are set to math.NaN(), and a call to Err will return a non-nil value.
   197  func (r *Rejection) Sample(batch *mat.Dense) {
   198  	r.err = nil
   199  	r.proposed = 0
   200  	proposed, ok := rejection(batch, r.Target, r.Proposal, r.C, r.Src)
   201  	if !ok {
   202  		r.err = ErrRejection
   203  	}
   204  	r.proposed = proposed
   205  }
   206  
   207  func rejection(batch *mat.Dense, target distmv.LogProber, proposal distmv.RandLogProber, c float64, src rand.Source) (nProposed int, ok bool) {
   208  	if c < 1 {
   209  		panic("rejection: acceptance constant must be greater than 1")
   210  	}
   211  	f64 := rand.Float64
   212  	if src != nil {
   213  		f64 = rand.New(src).Float64
   214  	}
   215  	r, dim := batch.Dims()
   216  	v := make([]float64, dim)
   217  	var idx int
   218  	for {
   219  		nProposed++
   220  		proposal.Rand(v)
   221  		qx := proposal.LogProb(v)
   222  		px := target.LogProb(v)
   223  		accept := math.Exp(px-qx) / c
   224  		if accept > 1 {
   225  			// Invalidate the whole result and return a failure.
   226  			for i := 0; i < r; i++ {
   227  				for j := 0; j < dim; j++ {
   228  					batch.Set(i, j, math.NaN())
   229  				}
   230  			}
   231  			return nProposed, false
   232  		}
   233  		if accept > f64() {
   234  			batch.SetRow(idx, v)
   235  			idx++
   236  			if idx == r {
   237  				break
   238  			}
   239  		}
   240  	}
   241  	return nProposed, true
   242  }
   243  
// IID generates a set of independently and identically distributed samples from
// the input distribution Dist.
type IID struct {
	Dist distmv.Rander
}
   249  
   250  // Sample generates a set of identically and independently distributed samples.
   251  func (iid IID) Sample(batch *mat.Dense) {
   252  	r, _ := batch.Dims()
   253  	for i := 0; i < r; i++ {
   254  		iid.Dist.Rand(batch.RawRowView(i))
   255  	}
   256  }