gonum.org/v1/gonum@v0.14.0/stat/distmv/dirichlet.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package distmv
     6  
     7  import (
     8  	"math"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"gonum.org/v1/gonum/floats"
    13  	"gonum.org/v1/gonum/mat"
    14  	"gonum.org/v1/gonum/stat/distuv"
    15  )
    16  
    17  // Dirichlet implements the Dirichlet probability distribution.
    18  //
    19  // The Dirichlet distribution is a continuous probability distribution that
    20  // generates elements over the probability simplex, i.e. ||x||_1 = 1. The Dirichlet
    21  // distribution is the conjugate prior to the categorical distribution and the
    22  // multivariate version of the beta distribution. The probability of a point x is
    23  //
    24  //	1/Beta(α) \prod_i x_i^(α_i - 1)
    25  //
    26  // where Beta(α) is the multivariate Beta function (see the mathext package).
    27  //
    28  // For more information see https://en.wikipedia.org/wiki/Dirichlet_distribution
    29  type Dirichlet struct {
    30  	alpha []float64
    31  	dim   int
    32  	src   rand.Source
    33  
    34  	lbeta    float64
    35  	sumAlpha float64
    36  }
    37  
    38  // NewDirichlet creates a new dirichlet distribution with the given parameters alpha.
    39  // NewDirichlet will panic if len(alpha) == 0, or if any alpha is <= 0.
    40  func NewDirichlet(alpha []float64, src rand.Source) *Dirichlet {
    41  	dim := len(alpha)
    42  	if dim == 0 {
    43  		panic(badZeroDimension)
    44  	}
    45  	for _, v := range alpha {
    46  		if v <= 0 {
    47  			panic("dirichlet: non-positive alpha")
    48  		}
    49  	}
    50  	a := make([]float64, len(alpha))
    51  	copy(a, alpha)
    52  	d := &Dirichlet{
    53  		alpha: a,
    54  		dim:   dim,
    55  		src:   src,
    56  	}
    57  	d.lbeta, d.sumAlpha = d.genLBeta(a)
    58  	return d
    59  }
    60  
    61  // CovarianceMatrix calculates the covariance matrix of the distribution,
    62  // storing the result in dst. Upon return, the value at element {i, j} of the
    63  // covariance matrix is equal to the covariance of the i^th and j^th variables.
    64  //
    65  //	covariance(i, j) = E[(x_i - E[x_i])(x_j - E[x_j])]
    66  //
    67  // If the dst matrix is empty it will be resized to the correct dimensions,
    68  // otherwise dst must match the dimension of the receiver or CovarianceMatrix
    69  // will panic.
    70  func (d *Dirichlet) CovarianceMatrix(dst *mat.SymDense) {
    71  	if dst.IsEmpty() {
    72  		*dst = *(dst.GrowSym(d.dim).(*mat.SymDense))
    73  	} else if dst.SymmetricDim() != d.dim {
    74  		panic("dirichelet: input matrix size mismatch")
    75  	}
    76  	scale := 1 / (d.sumAlpha * d.sumAlpha * (d.sumAlpha + 1))
    77  	for i := 0; i < d.dim; i++ {
    78  		ai := d.alpha[i]
    79  		v := ai * (d.sumAlpha - ai) * scale
    80  		dst.SetSym(i, i, v)
    81  		for j := i + 1; j < d.dim; j++ {
    82  			aj := d.alpha[j]
    83  			v := -ai * aj * scale
    84  			dst.SetSym(i, j, v)
    85  		}
    86  	}
    87  }
    88  
    89  // genLBeta computes the generalized LBeta function.
    90  func (d *Dirichlet) genLBeta(alpha []float64) (lbeta, sumAlpha float64) {
    91  	for _, alpha := range d.alpha {
    92  		lg, _ := math.Lgamma(alpha)
    93  		lbeta += lg
    94  		sumAlpha += alpha
    95  	}
    96  	lg, _ := math.Lgamma(sumAlpha)
    97  	return lbeta - lg, sumAlpha
    98  }
    99  
   100  // Dim returns the dimension of the distribution.
   101  func (d *Dirichlet) Dim() int {
   102  	return d.dim
   103  }
   104  
   105  // LogProb computes the log of the pdf of the point x.
   106  //
   107  // It does not check that ||x||_1 = 1.
   108  func (d *Dirichlet) LogProb(x []float64) float64 {
   109  	dim := d.dim
   110  	if len(x) != dim {
   111  		panic(badSizeMismatch)
   112  	}
   113  	var lprob float64
   114  	for i, x := range x {
   115  		lprob += (d.alpha[i] - 1) * math.Log(x)
   116  	}
   117  	lprob -= d.lbeta
   118  	return lprob
   119  }
   120  
   121  // Mean returns the mean of the probability distribution at x. If the
   122  // input argument is nil, a new slice will be allocated, otherwise the result
   123  // will be put in-place into the receiver.
   124  func (d *Dirichlet) Mean(x []float64) []float64 {
   125  	x = reuseAs(x, d.dim)
   126  	copy(x, d.alpha)
   127  	floats.Scale(1/d.sumAlpha, x)
   128  	return x
   129  }
   130  
   131  // Prob computes the value of the probability density function at x.
   132  func (d *Dirichlet) Prob(x []float64) float64 {
   133  	return math.Exp(d.LogProb(x))
   134  }
   135  
   136  // Rand generates a random number according to the distributon.
   137  // If the input slice is nil, new memory is allocated, otherwise the result is stored
   138  // in place.
   139  func (d *Dirichlet) Rand(x []float64) []float64 {
   140  	x = reuseAs(x, d.dim)
   141  	for i := range x {
   142  		x[i] = distuv.Gamma{Alpha: d.alpha[i], Beta: 1, Src: d.src}.Rand()
   143  	}
   144  	sum := floats.Sum(x)
   145  	floats.Scale(1/sum, x)
   146  	return x
   147  }