github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/distmat/wishart.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distmat 6 7 import ( 8 "math" 9 "sync" 10 11 "golang.org/x/exp/rand" 12 13 "github.com/jingcheng-WU/gonum/mat" 14 "github.com/jingcheng-WU/gonum/mathext" 15 "github.com/jingcheng-WU/gonum/stat/distuv" 16 ) 17 18 // Wishart is a distribution over d×d positive symmetric definite matrices. It 19 // is parametrized by a scalar degrees of freedom parameter ν and a d×d positive 20 // definite matrix V. 21 // 22 // The Wishart PDF is given by 23 // p(X) = [|X|^((ν-d-1)/2) * exp(-tr(V^-1 * X)/2)] / [2^(ν*d/2) * |V|^(ν/2) * Γ_d(ν/2)] 24 // where X is a d×d PSD matrix, ν > d-1, |·| denotes the determinant, tr is the 25 // trace and Γ_d is the multivariate gamma function. 26 // 27 // See https://en.wikipedia.org/wiki/Wishart_distribution for more information. 28 type Wishart struct { 29 nu float64 30 src rand.Source 31 32 dim int 33 cholv mat.Cholesky 34 logdetv float64 35 upper mat.TriDense 36 37 once sync.Once 38 v *mat.SymDense // only stored if needed 39 } 40 41 // NewWishart returns a new Wishart distribution with the given shape matrix and 42 // degrees of freedom parameter. NewWishart returns whether the creation was 43 // successful. 44 // 45 // NewWishart panics if nu <= d - 1 where d is the order of v. 46 func NewWishart(v mat.Symmetric, nu float64, src rand.Source) (*Wishart, bool) { 47 dim := v.Symmetric() 48 if nu <= float64(dim-1) { 49 panic("wishart: nu must be greater than dim-1") 50 } 51 var chol mat.Cholesky 52 ok := chol.Factorize(v) 53 if !ok { 54 return nil, false 55 } 56 57 var u mat.TriDense 58 chol.UTo(&u) 59 60 w := &Wishart{ 61 nu: nu, 62 src: src, 63 64 dim: dim, 65 cholv: chol, 66 logdetv: chol.LogDet(), 67 upper: u, 68 } 69 return w, true 70 } 71 72 // MeanSymTo calculates the mean matrix of the distribution in and stores it in dst. 73 // If dst is empty, it is resized to be an d×d symmetric matrix where d is the order 74 // of the receiver. When dst is non-empty, MeanSymTo panics if dst is not d×d. 75 func (w *Wishart) MeanSymTo(dst *mat.SymDense) { 76 if dst.IsEmpty() { 77 dst.ReuseAsSym(w.dim) 78 } else if dst.Symmetric() != w.dim { 79 panic(badDim) 80 } 81 w.setV() 82 dst.CopySym(w.v) 83 dst.ScaleSym(w.nu, dst) 84 } 85 86 // ProbSym returns the probability of the symmetric matrix x. If x is not positive 87 // definite (the Cholesky decomposition fails), it has 0 probability. 88 func (w *Wishart) ProbSym(x mat.Symmetric) float64 { 89 return math.Exp(w.LogProbSym(x)) 90 } 91 92 // LogProbSym returns the log of the probability of the input symmetric matrix. 93 // 94 // LogProbSym returns -∞ if the input matrix is not positive definite (the Cholesky 95 // decomposition fails). 96 func (w *Wishart) LogProbSym(x mat.Symmetric) float64 { 97 dim := x.Symmetric() 98 if dim != w.dim { 99 panic(badDim) 100 } 101 var chol mat.Cholesky 102 ok := chol.Factorize(x) 103 if !ok { 104 return math.Inf(-1) 105 } 106 return w.logProbSymChol(&chol) 107 } 108 109 // LogProbSymChol returns the log of the probability of the input symmetric matrix 110 // given its Cholesky decomposition. 111 func (w *Wishart) LogProbSymChol(cholX *mat.Cholesky) float64 { 112 dim := cholX.Symmetric() 113 if dim != w.dim { 114 panic(badDim) 115 } 116 return w.logProbSymChol(cholX) 117 } 118 119 func (w *Wishart) logProbSymChol(cholX *mat.Cholesky) float64 { 120 // The PDF is 121 // p(X) = [|X|^((ν-d-1)/2) * exp(-tr(V^-1 * X)/2)] / [2^(ν*d/2) * |V|^(ν/2) * Γ_d(ν/2)] 122 // The LogPDF is thus 123 // (ν-d-1)/2 * log(|X|) - tr(V^-1 * X)/2 - (ν*d/2)*log(2) - ν/2 * log(|V|) - log(Γ_d(ν/2)) 124 logdetx := cholX.LogDet() 125 126 // Compute tr(V^-1 * X), using the fact that X = Uᵀ * U. 127 var u mat.TriDense 128 cholX.UTo(&u) 129 130 var vinvx mat.Dense 131 err := w.cholv.SolveTo(&vinvx, u.T()) 132 if err != nil { 133 return math.Inf(-1) 134 } 135 vinvx.Mul(&vinvx, &u) 136 tr := mat.Trace(&vinvx) 137 138 fnu := float64(w.nu) 139 fdim := float64(w.dim) 140 141 return 0.5*((fnu-fdim-1)*logdetx-tr-fnu*fdim*math.Ln2-fnu*w.logdetv) - mathext.MvLgamma(0.5*fnu, w.dim) 142 } 143 144 // RandSymTo generates a random symmetric matrix from the distribution. 145 // If dst is empty, it is resized to be an d×d symmetric matrix where d is the order 146 // of the receiver. When dst is non-empty, RandSymTo panics if dst is not d×d. 147 func (w *Wishart) RandSymTo(dst *mat.SymDense) { 148 var c mat.Cholesky 149 w.RandCholTo(&c) 150 c.ToSym(dst) 151 } 152 153 // RandCholTo generates the Cholesky decomposition of a random matrix from the distribution. 154 // If dst is empty, it is resized to be an d×d symmetric matrix where d is the order 155 // of the receiver. When dst is non-empty, RandCholTo panics if dst is not d×d. 156 func (w *Wishart) RandCholTo(dst *mat.Cholesky) { 157 // TODO(btracey): Modify the code if the underlying data from dst is exposed 158 // to avoid the dim^2 allocation here. 159 160 // Use the Bartlett Decomposition, which says that 161 // X ~ L A Aᵀ Lᵀ 162 // Where A is a lower triangular matrix in which the diagonal of A is 163 // generated from the square roots of χ^2 random variables, and the 164 // off-diagonals are generated from standard normal variables. 165 // The above gives the cholesky decomposition of X, where L_x = L A. 166 // 167 // mat works with the upper triagular decomposition, so we would like to do 168 // the same. We can instead say that 169 // U_x = L_xᵀ = (L * A)ᵀ = Aᵀ * Lᵀ = Aᵀ * U 170 // Instead, generate Aᵀ, by using the procedure above, except as an upper 171 // triangular matrix. 172 norm := distuv.Normal{ 173 Mu: 0, 174 Sigma: 1, 175 Src: w.src, 176 } 177 178 t := mat.NewTriDense(w.dim, mat.Upper, nil) 179 for i := 0; i < w.dim; i++ { 180 v := distuv.ChiSquared{ 181 K: w.nu - float64(i), 182 Src: w.src, 183 }.Rand() 184 t.SetTri(i, i, math.Sqrt(v)) 185 } 186 for i := 0; i < w.dim; i++ { 187 for j := i + 1; j < w.dim; j++ { 188 t.SetTri(i, j, norm.Rand()) 189 } 190 } 191 192 t.MulTri(t, &w.upper) 193 dst.SetFromU(t) 194 } 195 196 // setV computes and stores the covariance matrix of the distribution. 197 func (w *Wishart) setV() { 198 w.once.Do(func() { 199 w.v = mat.NewSymDense(w.dim, nil) 200 w.cholv.ToSym(w.v) 201 }) 202 }