gonum.org/v1/gonum@v0.14.0/stat/distmv/dirichlet.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distmv 6 7 import ( 8 "math" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/floats" 13 "gonum.org/v1/gonum/mat" 14 "gonum.org/v1/gonum/stat/distuv" 15 ) 16 17 // Dirichlet implements the Dirichlet probability distribution. 18 // 19 // The Dirichlet distribution is a continuous probability distribution that 20 // generates elements over the probability simplex, i.e. ||x||_1 = 1. The Dirichlet 21 // distribution is the conjugate prior to the categorical distribution and the 22 // multivariate version of the beta distribution. The probability of a point x is 23 // 24 // 1/Beta(α) \prod_i x_i^(α_i - 1) 25 // 26 // where Beta(α) is the multivariate Beta function (see the mathext package). 27 // 28 // For more information see https://en.wikipedia.org/wiki/Dirichlet_distribution 29 type Dirichlet struct { 30 alpha []float64 31 dim int 32 src rand.Source 33 34 lbeta float64 35 sumAlpha float64 36 } 37 38 // NewDirichlet creates a new dirichlet distribution with the given parameters alpha. 39 // NewDirichlet will panic if len(alpha) == 0, or if any alpha is <= 0. 40 func NewDirichlet(alpha []float64, src rand.Source) *Dirichlet { 41 dim := len(alpha) 42 if dim == 0 { 43 panic(badZeroDimension) 44 } 45 for _, v := range alpha { 46 if v <= 0 { 47 panic("dirichlet: non-positive alpha") 48 } 49 } 50 a := make([]float64, len(alpha)) 51 copy(a, alpha) 52 d := &Dirichlet{ 53 alpha: a, 54 dim: dim, 55 src: src, 56 } 57 d.lbeta, d.sumAlpha = d.genLBeta(a) 58 return d 59 } 60 61 // CovarianceMatrix calculates the covariance matrix of the distribution, 62 // storing the result in dst. Upon return, the value at element {i, j} of the 63 // covariance matrix is equal to the covariance of the i^th and j^th variables. 64 // 65 // covariance(i, j) = E[(x_i - E[x_i])(x_j - E[x_j])] 66 // 67 // If the dst matrix is empty it will be resized to the correct dimensions, 68 // otherwise dst must match the dimension of the receiver or CovarianceMatrix 69 // will panic. 70 func (d *Dirichlet) CovarianceMatrix(dst *mat.SymDense) { 71 if dst.IsEmpty() { 72 *dst = *(dst.GrowSym(d.dim).(*mat.SymDense)) 73 } else if dst.SymmetricDim() != d.dim { 74 panic("dirichelet: input matrix size mismatch") 75 } 76 scale := 1 / (d.sumAlpha * d.sumAlpha * (d.sumAlpha + 1)) 77 for i := 0; i < d.dim; i++ { 78 ai := d.alpha[i] 79 v := ai * (d.sumAlpha - ai) * scale 80 dst.SetSym(i, i, v) 81 for j := i + 1; j < d.dim; j++ { 82 aj := d.alpha[j] 83 v := -ai * aj * scale 84 dst.SetSym(i, j, v) 85 } 86 } 87 } 88 89 // genLBeta computes the generalized LBeta function. 90 func (d *Dirichlet) genLBeta(alpha []float64) (lbeta, sumAlpha float64) { 91 for _, alpha := range d.alpha { 92 lg, _ := math.Lgamma(alpha) 93 lbeta += lg 94 sumAlpha += alpha 95 } 96 lg, _ := math.Lgamma(sumAlpha) 97 return lbeta - lg, sumAlpha 98 } 99 100 // Dim returns the dimension of the distribution. 101 func (d *Dirichlet) Dim() int { 102 return d.dim 103 } 104 105 // LogProb computes the log of the pdf of the point x. 106 // 107 // It does not check that ||x||_1 = 1. 108 func (d *Dirichlet) LogProb(x []float64) float64 { 109 dim := d.dim 110 if len(x) != dim { 111 panic(badSizeMismatch) 112 } 113 var lprob float64 114 for i, x := range x { 115 lprob += (d.alpha[i] - 1) * math.Log(x) 116 } 117 lprob -= d.lbeta 118 return lprob 119 } 120 121 // Mean returns the mean of the probability distribution at x. If the 122 // input argument is nil, a new slice will be allocated, otherwise the result 123 // will be put in-place into the receiver. 124 func (d *Dirichlet) Mean(x []float64) []float64 { 125 x = reuseAs(x, d.dim) 126 copy(x, d.alpha) 127 floats.Scale(1/d.sumAlpha, x) 128 return x 129 } 130 131 // Prob computes the value of the probability density function at x. 132 func (d *Dirichlet) Prob(x []float64) float64 { 133 return math.Exp(d.LogProb(x)) 134 } 135 136 // Rand generates a random number according to the distributon. 137 // If the input slice is nil, new memory is allocated, otherwise the result is stored 138 // in place. 139 func (d *Dirichlet) Rand(x []float64) []float64 { 140 x = reuseAs(x, d.dim) 141 for i := range x { 142 x[i] = distuv.Gamma{Alpha: d.alpha[i], Beta: 1, Src: d.src}.Rand() 143 } 144 sum := floats.Sum(x) 145 floats.Scale(1/sum, x) 146 return x 147 }