github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/distmv/dirichlet.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package distmv
     6  
     7  import (
     8  	"math"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"github.com/jingcheng-WU/gonum/floats"
    13  	"github.com/jingcheng-WU/gonum/mat"
    14  	"github.com/jingcheng-WU/gonum/stat/distuv"
    15  )
    16  
// Dirichlet implements the Dirichlet probability distribution.
//
// The Dirichlet distribution is a continuous probability distribution that
// generates elements over the probability simplex, i.e. ||x||_1 = 1. The Dirichlet
// distribution is the conjugate prior to the categorical distribution and the
// multivariate version of the beta distribution. The probability of a point x is
//  1/Beta(α) \prod_i x_i^(α_i - 1)
// where Beta(α) is the multivariate Beta function (see the mathext package).
//
// For more information see https://en.wikipedia.org/wiki/Dirichlet_distribution
type Dirichlet struct {
	alpha []float64   // concentration parameters; a private copy made by NewDirichlet
	dim   int         // number of dimensions, equal to len(alpha)
	src   rand.Source // source of randomness passed to the gamma samplers in Rand

	lbeta    float64 // log of the multivariate Beta function of alpha, cached by NewDirichlet
	sumAlpha float64 // sum of the entries of alpha, cached by NewDirichlet
}
    35  
    36  // NewDirichlet creates a new dirichlet distribution with the given parameters alpha.
    37  // NewDirichlet will panic if len(alpha) == 0, or if any alpha is <= 0.
    38  func NewDirichlet(alpha []float64, src rand.Source) *Dirichlet {
    39  	dim := len(alpha)
    40  	if dim == 0 {
    41  		panic(badZeroDimension)
    42  	}
    43  	for _, v := range alpha {
    44  		if v <= 0 {
    45  			panic("dirichlet: non-positive alpha")
    46  		}
    47  	}
    48  	a := make([]float64, len(alpha))
    49  	copy(a, alpha)
    50  	d := &Dirichlet{
    51  		alpha: a,
    52  		dim:   dim,
    53  		src:   src,
    54  	}
    55  	d.lbeta, d.sumAlpha = d.genLBeta(a)
    56  	return d
    57  }
    58  
    59  // CovarianceMatrix calculates the covariance matrix of the distribution,
    60  // storing the result in dst. Upon return, the value at element {i, j} of the
    61  // covariance matrix is equal to the covariance of the i^th and j^th variables.
    62  //  covariance(i, j) = E[(x_i - E[x_i])(x_j - E[x_j])]
    63  // If the dst matrix is empty it will be resized to the correct dimensions,
    64  // otherwise dst must match the dimension of the receiver or CovarianceMatrix
    65  // will panic.
    66  func (d *Dirichlet) CovarianceMatrix(dst *mat.SymDense) {
    67  	if dst.IsEmpty() {
    68  		*dst = *(dst.GrowSym(d.dim).(*mat.SymDense))
    69  	} else if dst.Symmetric() != d.dim {
    70  		panic("dirichelet: input matrix size mismatch")
    71  	}
    72  	scale := 1 / (d.sumAlpha * d.sumAlpha * (d.sumAlpha + 1))
    73  	for i := 0; i < d.dim; i++ {
    74  		ai := d.alpha[i]
    75  		v := ai * (d.sumAlpha - ai) * scale
    76  		dst.SetSym(i, i, v)
    77  		for j := i + 1; j < d.dim; j++ {
    78  			aj := d.alpha[j]
    79  			v := -ai * aj * scale
    80  			dst.SetSym(i, j, v)
    81  		}
    82  	}
    83  }
    84  
    85  // genLBeta computes the generalized LBeta function.
    86  func (d *Dirichlet) genLBeta(alpha []float64) (lbeta, sumAlpha float64) {
    87  	for _, alpha := range d.alpha {
    88  		lg, _ := math.Lgamma(alpha)
    89  		lbeta += lg
    90  		sumAlpha += alpha
    91  	}
    92  	lg, _ := math.Lgamma(sumAlpha)
    93  	return lbeta - lg, sumAlpha
    94  }
    95  
    96  // Dim returns the dimension of the distribution.
    97  func (d *Dirichlet) Dim() int {
    98  	return d.dim
    99  }
   100  
   101  // LogProb computes the log of the pdf of the point x.
   102  //
   103  // It does not check that ||x||_1 = 1.
   104  func (d *Dirichlet) LogProb(x []float64) float64 {
   105  	dim := d.dim
   106  	if len(x) != dim {
   107  		panic(badSizeMismatch)
   108  	}
   109  	var lprob float64
   110  	for i, x := range x {
   111  		lprob += (d.alpha[i] - 1) * math.Log(x)
   112  	}
   113  	lprob -= d.lbeta
   114  	return lprob
   115  }
   116  
   117  // Mean returns the mean of the probability distribution at x. If the
   118  // input argument is nil, a new slice will be allocated, otherwise the result
   119  // will be put in-place into the receiver.
   120  func (d *Dirichlet) Mean(x []float64) []float64 {
   121  	x = reuseAs(x, d.dim)
   122  	copy(x, d.alpha)
   123  	floats.Scale(1/d.sumAlpha, x)
   124  	return x
   125  }
   126  
   127  // Prob computes the value of the probability density function at x.
   128  func (d *Dirichlet) Prob(x []float64) float64 {
   129  	return math.Exp(d.LogProb(x))
   130  }
   131  
   132  // Rand generates a random number according to the distributon.
   133  // If the input slice is nil, new memory is allocated, otherwise the result is stored
   134  // in place.
   135  func (d *Dirichlet) Rand(x []float64) []float64 {
   136  	x = reuseAs(x, d.dim)
   137  	for i := range x {
   138  		x[i] = distuv.Gamma{Alpha: d.alpha[i], Beta: 1, Src: d.src}.Rand()
   139  	}
   140  	sum := floats.Sum(x)
   141  	floats.Scale(1/sum, x)
   142  	return x
   143  }