github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/distmv/dirichlet.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distmv 6 7 import ( 8 "math" 9 10 "golang.org/x/exp/rand" 11 12 "github.com/jingcheng-WU/gonum/floats" 13 "github.com/jingcheng-WU/gonum/mat" 14 "github.com/jingcheng-WU/gonum/stat/distuv" 15 ) 16 17 // Dirichlet implements the Dirichlet probability distribution. 18 // 19 // The Dirichlet distribution is a continuous probability distribution that 20 // generates elements over the probability simplex, i.e. ||x||_1 = 1. The Dirichlet 21 // distribution is the conjugate prior to the categorical distribution and the 22 // multivariate version of the beta distribution. The probability of a point x is 23 // 1/Beta(α) \prod_i x_i^(α_i - 1) 24 // where Beta(α) is the multivariate Beta function (see the mathext package). 25 // 26 // For more information see https://en.wikipedia.org/wiki/Dirichlet_distribution 27 type Dirichlet struct { 28 alpha []float64 29 dim int 30 src rand.Source 31 32 lbeta float64 33 sumAlpha float64 34 } 35 36 // NewDirichlet creates a new dirichlet distribution with the given parameters alpha. 37 // NewDirichlet will panic if len(alpha) == 0, or if any alpha is <= 0. 38 func NewDirichlet(alpha []float64, src rand.Source) *Dirichlet { 39 dim := len(alpha) 40 if dim == 0 { 41 panic(badZeroDimension) 42 } 43 for _, v := range alpha { 44 if v <= 0 { 45 panic("dirichlet: non-positive alpha") 46 } 47 } 48 a := make([]float64, len(alpha)) 49 copy(a, alpha) 50 d := &Dirichlet{ 51 alpha: a, 52 dim: dim, 53 src: src, 54 } 55 d.lbeta, d.sumAlpha = d.genLBeta(a) 56 return d 57 } 58 59 // CovarianceMatrix calculates the covariance matrix of the distribution, 60 // storing the result in dst. Upon return, the value at element {i, j} of the 61 // covariance matrix is equal to the covariance of the i^th and j^th variables. 62 // covariance(i, j) = E[(x_i - E[x_i])(x_j - E[x_j])] 63 // If the dst matrix is empty it will be resized to the correct dimensions, 64 // otherwise dst must match the dimension of the receiver or CovarianceMatrix 65 // will panic. 66 func (d *Dirichlet) CovarianceMatrix(dst *mat.SymDense) { 67 if dst.IsEmpty() { 68 *dst = *(dst.GrowSym(d.dim).(*mat.SymDense)) 69 } else if dst.Symmetric() != d.dim { 70 panic("dirichelet: input matrix size mismatch") 71 } 72 scale := 1 / (d.sumAlpha * d.sumAlpha * (d.sumAlpha + 1)) 73 for i := 0; i < d.dim; i++ { 74 ai := d.alpha[i] 75 v := ai * (d.sumAlpha - ai) * scale 76 dst.SetSym(i, i, v) 77 for j := i + 1; j < d.dim; j++ { 78 aj := d.alpha[j] 79 v := -ai * aj * scale 80 dst.SetSym(i, j, v) 81 } 82 } 83 } 84 85 // genLBeta computes the generalized LBeta function. 86 func (d *Dirichlet) genLBeta(alpha []float64) (lbeta, sumAlpha float64) { 87 for _, alpha := range d.alpha { 88 lg, _ := math.Lgamma(alpha) 89 lbeta += lg 90 sumAlpha += alpha 91 } 92 lg, _ := math.Lgamma(sumAlpha) 93 return lbeta - lg, sumAlpha 94 } 95 96 // Dim returns the dimension of the distribution. 97 func (d *Dirichlet) Dim() int { 98 return d.dim 99 } 100 101 // LogProb computes the log of the pdf of the point x. 102 // 103 // It does not check that ||x||_1 = 1. 104 func (d *Dirichlet) LogProb(x []float64) float64 { 105 dim := d.dim 106 if len(x) != dim { 107 panic(badSizeMismatch) 108 } 109 var lprob float64 110 for i, x := range x { 111 lprob += (d.alpha[i] - 1) * math.Log(x) 112 } 113 lprob -= d.lbeta 114 return lprob 115 } 116 117 // Mean returns the mean of the probability distribution at x. If the 118 // input argument is nil, a new slice will be allocated, otherwise the result 119 // will be put in-place into the receiver. 120 func (d *Dirichlet) Mean(x []float64) []float64 { 121 x = reuseAs(x, d.dim) 122 copy(x, d.alpha) 123 floats.Scale(1/d.sumAlpha, x) 124 return x 125 } 126 127 // Prob computes the value of the probability density function at x. 128 func (d *Dirichlet) Prob(x []float64) float64 { 129 return math.Exp(d.LogProb(x)) 130 } 131 132 // Rand generates a random number according to the distributon. 133 // If the input slice is nil, new memory is allocated, otherwise the result is stored 134 // in place. 135 func (d *Dirichlet) Rand(x []float64) []float64 { 136 x = reuseAs(x, d.dim) 137 for i := range x { 138 x[i] = distuv.Gamma{Alpha: d.alpha[i], Beta: 1, Src: d.src}.Rand() 139 } 140 sum := floats.Sum(x) 141 floats.Scale(1/sum, x) 142 return x 143 }