// Source listing: gonum.org/v1/gonum@v0.14.0/stat/distmat/wishart.go

// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package distmat

import (
	"math"
	"sync"

	"golang.org/x/exp/rand"

	"gonum.org/v1/gonum/mat"
	"gonum.org/v1/gonum/mathext"
	"gonum.org/v1/gonum/stat/distuv"
)

// Wishart is a distribution over d×d symmetric positive definite matrices. It
// is parametrized by a scalar degrees of freedom parameter ν and a d×d positive
// definite matrix V.
//
// The Wishart PDF is given by
//
//	p(X) = [|X|^((ν-d-1)/2) * exp(-tr(V^-1 * X)/2)] / [2^(ν*d/2) * |V|^(ν/2) * Γ_d(ν/2)]
//
// where X is a d×d PSD matrix, ν > d-1, |·| denotes the determinant, tr is the
// trace and Γ_d is the multivariate gamma function.
//
// See https://en.wikipedia.org/wiki/Wishart_distribution for more information.
type Wishart struct {
	nu  float64     // degrees of freedom ν
	src rand.Source // source of randomness handed to the samplers in RandCholTo

	dim     int          // order d of the matrix V
	cholv   mat.Cholesky // Cholesky factorization of V
	logdetv float64      // log|V|, computed once from cholv in NewWishart
	upper   mat.TriDense // upper triangular Cholesky factor of V

	once sync.Once     // guards the one-time construction of v in setV
	v    *mat.SymDense // only stored if needed
}

// NewWishart returns a new Wishart distribution with the given shape matrix v
// and degrees of freedom parameter nu. The boolean result reports whether the
// creation was successful; it is false, and the returned distribution is nil,
// when the Cholesky factorization of v fails.
//
// NewWishart panics if nu <= d - 1 where d is the order of v.
48 func NewWishart(v mat.Symmetric, nu float64, src rand.Source) (*Wishart, bool) { 49 dim := v.SymmetricDim() 50 if nu <= float64(dim-1) { 51 panic("wishart: nu must be greater than dim-1") 52 } 53 var chol mat.Cholesky 54 ok := chol.Factorize(v) 55 if !ok { 56 return nil, false 57 } 58 59 var u mat.TriDense 60 chol.UTo(&u) 61 62 w := &Wishart{ 63 nu: nu, 64 src: src, 65 66 dim: dim, 67 cholv: chol, 68 logdetv: chol.LogDet(), 69 upper: u, 70 } 71 return w, true 72 } 73 74 // MeanSymTo calculates the mean matrix of the distribution in and stores it in dst. 75 // If dst is empty, it is resized to be an d×d symmetric matrix where d is the order 76 // of the receiver. When dst is non-empty, MeanSymTo panics if dst is not d×d. 77 func (w *Wishart) MeanSymTo(dst *mat.SymDense) { 78 if dst.IsEmpty() { 79 dst.ReuseAsSym(w.dim) 80 } else if dst.SymmetricDim() != w.dim { 81 panic(badDim) 82 } 83 w.setV() 84 dst.CopySym(w.v) 85 dst.ScaleSym(w.nu, dst) 86 } 87 88 // ProbSym returns the probability of the symmetric matrix x. If x is not positive 89 // definite (the Cholesky decomposition fails), it has 0 probability. 90 func (w *Wishart) ProbSym(x mat.Symmetric) float64 { 91 return math.Exp(w.LogProbSym(x)) 92 } 93 94 // LogProbSym returns the log of the probability of the input symmetric matrix. 95 // 96 // LogProbSym returns -∞ if the input matrix is not positive definite (the Cholesky 97 // decomposition fails). 98 func (w *Wishart) LogProbSym(x mat.Symmetric) float64 { 99 dim := x.SymmetricDim() 100 if dim != w.dim { 101 panic(badDim) 102 } 103 var chol mat.Cholesky 104 ok := chol.Factorize(x) 105 if !ok { 106 return math.Inf(-1) 107 } 108 return w.logProbSymChol(&chol) 109 } 110 111 // LogProbSymChol returns the log of the probability of the input symmetric matrix 112 // given its Cholesky decomposition. 
113 func (w *Wishart) LogProbSymChol(cholX *mat.Cholesky) float64 { 114 dim := cholX.SymmetricDim() 115 if dim != w.dim { 116 panic(badDim) 117 } 118 return w.logProbSymChol(cholX) 119 } 120 121 func (w *Wishart) logProbSymChol(cholX *mat.Cholesky) float64 { 122 // The PDF is 123 // p(X) = [|X|^((ν-d-1)/2) * exp(-tr(V^-1 * X)/2)] / [2^(ν*d/2) * |V|^(ν/2) * Γ_d(ν/2)] 124 // The LogPDF is thus 125 // (ν-d-1)/2 * log(|X|) - tr(V^-1 * X)/2 - (ν*d/2)*log(2) - ν/2 * log(|V|) - log(Γ_d(ν/2)) 126 logdetx := cholX.LogDet() 127 128 // Compute tr(V^-1 * X), using the fact that X = Uᵀ * U. 129 var u mat.TriDense 130 cholX.UTo(&u) 131 132 var vinvx mat.Dense 133 err := w.cholv.SolveTo(&vinvx, u.T()) 134 if err != nil { 135 return math.Inf(-1) 136 } 137 vinvx.Mul(&vinvx, &u) 138 tr := mat.Trace(&vinvx) 139 140 fnu := float64(w.nu) 141 fdim := float64(w.dim) 142 143 return 0.5*((fnu-fdim-1)*logdetx-tr-fnu*fdim*math.Ln2-fnu*w.logdetv) - mathext.MvLgamma(0.5*fnu, w.dim) 144 } 145 146 // RandSymTo generates a random symmetric matrix from the distribution. 147 // If dst is empty, it is resized to be an d×d symmetric matrix where d is the order 148 // of the receiver. When dst is non-empty, RandSymTo panics if dst is not d×d. 149 func (w *Wishart) RandSymTo(dst *mat.SymDense) { 150 var c mat.Cholesky 151 w.RandCholTo(&c) 152 c.ToSym(dst) 153 } 154 155 // RandCholTo generates the Cholesky decomposition of a random matrix from the distribution. 156 // If dst is empty, it is resized to be an d×d symmetric matrix where d is the order 157 // of the receiver. When dst is non-empty, RandCholTo panics if dst is not d×d. 158 func (w *Wishart) RandCholTo(dst *mat.Cholesky) { 159 // TODO(btracey): Modify the code if the underlying data from dst is exposed 160 // to avoid the dim^2 allocation here. 
161 162 // Use the Bartlett Decomposition, which says that 163 // X ~ L A Aᵀ Lᵀ 164 // Where A is a lower triangular matrix in which the diagonal of A is 165 // generated from the square roots of χ^2 random variables, and the 166 // off-diagonals are generated from standard normal variables. 167 // The above gives the cholesky decomposition of X, where L_x = L A. 168 // 169 // mat works with the upper triagular decomposition, so we would like to do 170 // the same. We can instead say that 171 // U_x = L_xᵀ = (L * A)ᵀ = Aᵀ * Lᵀ = Aᵀ * U 172 // Instead, generate Aᵀ, by using the procedure above, except as an upper 173 // triangular matrix. 174 norm := distuv.Normal{ 175 Mu: 0, 176 Sigma: 1, 177 Src: w.src, 178 } 179 180 t := mat.NewTriDense(w.dim, mat.Upper, nil) 181 for i := 0; i < w.dim; i++ { 182 v := distuv.ChiSquared{ 183 K: w.nu - float64(i), 184 Src: w.src, 185 }.Rand() 186 t.SetTri(i, i, math.Sqrt(v)) 187 } 188 for i := 0; i < w.dim; i++ { 189 for j := i + 1; j < w.dim; j++ { 190 t.SetTri(i, j, norm.Rand()) 191 } 192 } 193 194 t.MulTri(t, &w.upper) 195 dst.SetFromU(t) 196 } 197 198 // setV computes and stores the covariance matrix of the distribution. 199 func (w *Wishart) setV() { 200 w.once.Do(func() { 201 w.v = mat.NewSymDense(w.dim, nil) 202 w.cholv.ToSym(w.v) 203 }) 204 }