github.com/gopherd/gonum@v0.0.4/stat/samplemv/samplemv.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package samplemv
     6  
     7  import (
     8  	"errors"
     9  	"math"
    10  
    11  	"math/rand"
    12  
    13  	"github.com/gopherd/gonum/mat"
    14  	"github.com/gopherd/gonum/stat/distmv"
    15  )
    16  
// errLengthMismatch is the panic message used when the length of a weights
// slice does not equal the number of rows in the accompanying batch.
const errLengthMismatch = "samplemv: slice length mismatch"
    18  
// Compile-time assertions that the sampling types in this file implement
// the Sampler and WeightedSampler interfaces.
var (
	_ Sampler = LatinHypercube{}
	_ Sampler = (*Rejection)(nil)
	_ Sampler = IID{}

	_ WeightedSampler = SampleUniformWeighted{}
	_ WeightedSampler = Importance{}
)
    27  
    28  func min(a, b int) int {
    29  	if a < b {
    30  		return a
    31  	}
    32  	return b
    33  }
    34  
// Sampler generates a batch of samples according to the rule specified by the
// implementing type. The number of samples generated is equal to rows(batch),
// one sample per row, and the samples are stored in-place into the input.
type Sampler interface {
	// Sample overwrites every row of batch with a newly generated sample.
	Sample(batch *mat.Dense)
}
    41  
// WeightedSampler generates a batch of samples and their relative weights
// according to the rule specified by the implementing type. The number of samples
// generated is equal to rows(batch), and the samples and weights
// are stored in-place into the inputs. The length of weights must equal
// rows(batch), otherwise SampleWeighted will panic.
type WeightedSampler interface {
	// SampleWeighted stores one sample per row of batch and the weight of
	// the i-th sample in weights[i].
	SampleWeighted(batch *mat.Dense, weights []float64)
}
    50  
// SampleUniformWeighted wraps a Sampler type to create a WeightedSampler where all
// weights are equal to 1.
type SampleUniformWeighted struct {
	// Sampler is the embedded sampler used to generate the samples.
	Sampler
}
    56  
    57  // SampleWeighted generates rows(batch) samples from the embedded Sampler type
    58  // and sets all of the weights equal to 1. If rows(batch) and len(weights)
    59  // of weights are not equal, SampleWeighted will panic.
    60  func (w SampleUniformWeighted) SampleWeighted(batch *mat.Dense, weights []float64) {
    61  	r, _ := batch.Dims()
    62  	if r != len(weights) {
    63  		panic(errLengthMismatch)
    64  	}
    65  	w.Sample(batch)
    66  	for i := range weights {
    67  		weights[i] = 1
    68  	}
    69  }
    70  
// LatinHypercube is a type for sampling using Latin hypercube sampling
// from the given distribution. If Src is not nil, it will be used to generate
// random numbers, otherwise the global functions of math/rand will be used.
//
// Latin hypercube sampling divides the cumulative distribution function into equally
// spaced bins and guarantees that one sample is generated per bin. Within each bin,
// the location is randomly sampled. The distmv.NewUnitUniform function can be used
// for easy sampling from the unit hypercube.
type LatinHypercube struct {
	// Q converts the per-dimension probabilities generated by the Latin
	// hypercube procedure into samples through its Quantile method.
	Q distmv.Quantiler
	// Src is the optional source of randomness; it may be nil.
	Src rand.Source
}
    83  
// Sample generates rows(batch) samples using the LatinHypercube generation
// procedure and stores them in-place, one sample per row of batch. The
// quantile transform l.Q and the random source l.Src are passed through to
// the generation procedure unchanged.
func (l LatinHypercube) Sample(batch *mat.Dense) {
	latinHypercube(batch, l.Q, l.Src)
}
    89  
    90  func latinHypercube(batch *mat.Dense, q distmv.Quantiler, src rand.Source) {
    91  	r, c := batch.Dims()
    92  	var f64 func() float64
    93  	var perm func(int) []int
    94  	if src != nil {
    95  		r := rand.New(src)
    96  		f64 = r.Float64
    97  		perm = r.Perm
    98  	} else {
    99  		f64 = rand.Float64
   100  		perm = rand.Perm
   101  	}
   102  	r64 := float64(r)
   103  	for i := 0; i < c; i++ {
   104  		p := perm(r)
   105  		for j := 0; j < r; j++ {
   106  			v := f64()/r64 + float64(j)/r64
   107  			batch.Set(p[j], i, v)
   108  		}
   109  	}
   110  	p := make([]float64, c)
   111  	for i := 0; i < r; i++ {
   112  		copy(p, batch.RawRowView(i))
   113  		q.Quantile(batch.RawRowView(i), p)
   114  	}
   115  }
   116  
// Importance is a type for performing importance sampling using the given
// Target and Proposal distributions.
//
// Importance sampling is a variance reduction technique where samples are
// generated from a proposal distribution, q(x), instead of the target distribution
// p(x). This allows relatively unlikely samples in p(x) to be generated more frequently.
//
// The importance sampling weight at x is given by p(x)/q(x). To reduce variance,
// a good proposal distribution will bound this sampling weight. This implies the
// support of q(x) should be at least as broad as p(x), and q(x) should be "fatter tailed"
// than p(x).
type Importance struct {
	// Target is the distribution p(x); only its log-probability is evaluated.
	Target distmv.LogProber
	// Proposal is the distribution q(x) that samples are drawn from and whose
	// log-probability is evaluated at each sample.
	Proposal distmv.RandLogProber
}
   132  
// SampleWeighted generates rows(batch) samples using the Importance sampling
// generation procedure, storing the samples in the rows of batch and the
// importance weight of each sample in the matching element of weights.
//
// The length of weights must equal rows(batch), otherwise SampleWeighted will panic.
func (l Importance) SampleWeighted(batch *mat.Dense, weights []float64) {
	importance(batch, weights, l.Target, l.Proposal)
}
   140  
   141  func importance(batch *mat.Dense, weights []float64, target distmv.LogProber, proposal distmv.RandLogProber) {
   142  	r, _ := batch.Dims()
   143  	if r != len(weights) {
   144  		panic(errLengthMismatch)
   145  	}
   146  	for i := 0; i < r; i++ {
   147  		v := batch.RawRowView(i)
   148  		proposal.Rand(v)
   149  		weights[i] = math.Exp(target.LogProb(v) - proposal.LogProb(v))
   150  	}
   151  }
   152  
// ErrRejection is the error reported by the Err method of Rejection when a
// proposed point has an acceptance ratio above 1, which means the constant
// in Rejection is not sufficiently high.
var ErrRejection = errors.New("rejection: acceptance ratio above 1")
   155  
// Rejection is a type for sampling using the rejection sampling algorithm.
//
// Rejection sampling generates points from the target distribution by using
// the proposal distribution. At each step of the algorithm, the proposed point
// is accepted with probability
//  p = target(x) / (proposal(x) * c)
// where target(x) is the probability of the point according to the target distribution
// and proposal(x) is the probability according to the proposal distribution.
// The constant c must be chosen such that target(x) <= proposal(x) * c for all x.
// The expected number of proposed samples is len(samples) * c.
//
// The number of proposed locations during sampling can be found with a call to
// Proposed. If there was an error during sampling, all elements of samples are
// set to NaN and the error can be accessed with the Err method. If Src != nil,
// it will be used to generate random numbers, otherwise rand.Float64 will be used.
//
// Target may return the true log of the probability of the location, or it may return
// a value that is proportional to the probability (logprob + constant). This is
// useful for cases where the probability distribution is only known up to a normalization
// constant.
type Rejection struct {
	// C is the constant c in the acceptance probability. Sampling panics
	// if C is less than 1.
	C float64
	// Target is the distribution to sample from; only its log-probability
	// is evaluated.
	Target distmv.LogProber
	// Proposal generates the candidate points and evaluates their
	// log-probability.
	Proposal distmv.RandLogProber
	// Src is the optional source of randomness for the accept/reject
	// decision; it may be nil.
	Src rand.Source

	// err and proposed record the outcome of the most recent call to Sample.
	err      error
	proposed int
}
   185  
// Err returns nil if the most recent call to Sample was successful, and returns
// ErrRejection if it was not. Err also returns nil if Sample has not been
// called yet.
func (r *Rejection) Err() error {
	return r.err
}
   191  
// Proposed returns the number of samples proposed during the most recent call to
// Sample, counting both accepted and rejected proposals. It returns 0 if
// Sample has not been called yet.
func (r *Rejection) Proposed() int {
	return r.proposed
}
   197  
   198  // Sample generates rows(batch) using the Rejection sampling generation procedure.
   199  // Rejection sampling may fail if the constant is insufficiently high, as described
   200  // in the type comment for Rejection. If the generation fails, the samples
   201  // are set to math.NaN(), and a call to Err will return a non-nil value.
   202  func (r *Rejection) Sample(batch *mat.Dense) {
   203  	r.err = nil
   204  	r.proposed = 0
   205  	proposed, ok := rejection(batch, r.Target, r.Proposal, r.C, r.Src)
   206  	if !ok {
   207  		r.err = ErrRejection
   208  	}
   209  	r.proposed = proposed
   210  }
   211  
   212  func rejection(batch *mat.Dense, target distmv.LogProber, proposal distmv.RandLogProber, c float64, src rand.Source) (nProposed int, ok bool) {
   213  	if c < 1 {
   214  		panic("rejection: acceptance constant must be greater than 1")
   215  	}
   216  	f64 := rand.Float64
   217  	if src != nil {
   218  		f64 = rand.New(src).Float64
   219  	}
   220  	r, dim := batch.Dims()
   221  	v := make([]float64, dim)
   222  	var idx int
   223  	for {
   224  		nProposed++
   225  		proposal.Rand(v)
   226  		qx := proposal.LogProb(v)
   227  		px := target.LogProb(v)
   228  		accept := math.Exp(px-qx) / c
   229  		if accept > 1 {
   230  			// Invalidate the whole result and return a failure.
   231  			for i := 0; i < r; i++ {
   232  				for j := 0; j < dim; j++ {
   233  					batch.Set(i, j, math.NaN())
   234  				}
   235  			}
   236  			return nProposed, false
   237  		}
   238  		if accept > f64() {
   239  			batch.SetRow(idx, v)
   240  			idx++
   241  			if idx == r {
   242  				break
   243  			}
   244  		}
   245  	}
   246  	return nProposed, true
   247  }
   248  
// IID generates a set of independently and identically distributed samples from
// the input distribution.
type IID struct {
	// Dist is the distribution each sample is drawn from through its Rand
	// method.
	Dist distmv.Rander
}
   254  
   255  // Sample generates a set of identically and independently distributed samples.
   256  func (iid IID) Sample(batch *mat.Dense) {
   257  	r, _ := batch.Dims()
   258  	for i := 0; i < r; i++ {
   259  		iid.Dist.Rand(batch.RawRowView(i))
   260  	}
   261  }