github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/distmv/statdist.go

// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package distmv

import (
	"math"

	"github.com/jingcheng-WU/gonum/floats"
	"github.com/jingcheng-WU/gonum/mat"
	"github.com/jingcheng-WU/gonum/mathext"
	"github.com/jingcheng-WU/gonum/spatial/r1"
	"github.com/jingcheng-WU/gonum/stat"
)

// Bhattacharyya is a type for computing the Bhattacharyya distance between
// probability distributions.
//
// The Bhattacharyya distance is defined as
//  D_B = -ln(BC(l,r))
//  BC = \int_-∞^∞ (l(x)r(x))^(1/2) dx
// where BC is known as the Bhattacharyya coefficient.
// The Bhattacharyya distance is related to the Hellinger distance by
//  H(l,r) = sqrt(1-BC(l,r))
// For more information, see
//  https://en.wikipedia.org/wiki/Bhattacharyya_distance
type Bhattacharyya struct{}

// DistNormal computes the Bhattacharyya distance between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For normal distributions, the Bhattacharyya distance is
//  Σ = (Σ_l + Σ_r)/2
//  D_B = (1/8)*(μ_l - μ_r)ᵀ*Σ^-1*(μ_l - μ_r) + (1/2)*ln(det(Σ)/(det(Σ_l)*det(Σ_r))^(1/2))
func (Bhattacharyya) DistNormal(l, r *Normal) float64 {
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}

	var sigma mat.SymDense
	sigma.AddSym(&l.sigma, &r.sigma)
	sigma.ScaleSym(0.5, &sigma)

	var chol mat.Cholesky
	chol.Factorize(&sigma)

	mahalanobis := stat.Mahalanobis(mat.NewVecDense(dim, l.mu), mat.NewVecDense(dim, r.mu), &chol)
	mahalanobisSq := mahalanobis * mahalanobis

	dl := l.chol.LogDet()
	dr := r.chol.LogDet()
	ds := chol.LogDet()

	return 0.125*mahalanobisSq + 0.5*ds - 0.25*dl - 0.25*dr
}

// DistUniform computes the Bhattacharyya distance between uniform distributions l and r.
// The dimensions of the input distributions must match or DistUniform will panic.
func (Bhattacharyya) DistUniform(l, r *Uniform) float64 {
	if len(l.bounds) != len(r.bounds) {
		panic(badSizeMismatch)
	}
	// BC = \int \sqrt(l(x)r(x)) dx, which for uniform distributions is a constant
	// over the volume where both distributions have positive probability.
	// Compute the overlap and the value of sqrt(l(x)r(x)). The entropy is the
	// negative log probability of the distribution (used instead of LogProb so
	// it is not necessary to construct an x value).
	//
	//  BC = volume * sqrt(l(x)r(x))
	//  logBC = log(volume) + 0.5*(logL + logR)
	//  D_B = -logBC
	return -unifLogVolOverlap(l.bounds, r.bounds) + 0.5*(l.Entropy()+r.Entropy())
}

// unifLogVolOverlap computes the log of the volume of the hyper-rectangle where
// both uniform distributions have positive probability.
func unifLogVolOverlap(b1, b2 []r1.Interval) float64 {
	var logVolOverlap float64
	for dim, v1 := range b1 {
		v2 := b2[dim]
		// If the intervals don't overlap, then the volume is 0.
		if v1.Max <= v2.Min || v2.Max <= v1.Min {
			return math.Inf(-1)
		}
		vol := math.Min(v1.Max, v2.Max) - math.Max(v1.Min, v2.Min)
		logVolOverlap += math.Log(vol)
	}
	return logVolOverlap
}
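// bhattacharyyaDistNormalSketch is a minimal usage sketch of
// Bhattacharyya.DistNormal (a hypothetical helper, not part of the distmv
// API). It assumes the package's NewNormal constructor, which returns
// ok == false when sigma is not positive definite.
func bhattacharyyaDistNormalSketch() float64 {
	// Two bivariate normals with different means and covariances.
	l, ok := NewNormal([]float64{0, 0}, mat.NewSymDense(2, []float64{2, 0, 0, 2}), nil)
	if !ok {
		panic("sigma is not positive definite")
	}
	r, ok := NewNormal([]float64{1, 1}, mat.NewSymDense(2, []float64{1, 0, 0, 1}), nil)
	if !ok {
		panic("sigma is not positive definite")
	}
	// The Hellinger distance could be recovered from the result via
	// H = sqrt(1 - exp(-D_B)).
	return Bhattacharyya{}.DistNormal(l, r)
}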
// CrossEntropy is a type for computing the cross-entropy between probability
// distributions.
//
// The cross-entropy is defined as
//  - \int_x l(x) log(r(x)) dx = KL(l || r) + H(l)
// where KL is the Kullback-Leibler divergence and H is the entropy.
// For more information, see
//  https://en.wikipedia.org/wiki/Cross_entropy
type CrossEntropy struct{}

// DistNormal returns the cross-entropy between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
func (CrossEntropy) DistNormal(l, r *Normal) float64 {
	if l.Dim() != r.Dim() {
		panic(badSizeMismatch)
	}
	kl := KullbackLeibler{}.DistNormal(l, r)
	return kl + l.Entropy()
}

// Hellinger is a type for computing the Hellinger distance between probability
// distributions.
//
// The Hellinger distance is defined as
//  H^2(l,r) = 1/2 * \int_x (\sqrt(l(x)) - \sqrt(r(x)))^2 dx
// and is bounded between 0 and 1. Note that the above formula defines the
// squared Hellinger distance, while this type computes the Hellinger distance
// itself. The Hellinger distance is related to the Bhattacharyya distance by
//  H^2 = 1 - exp(-D_B)
// For more information, see
//  https://en.wikipedia.org/wiki/Hellinger_distance
type Hellinger struct{}

// DistNormal returns the Hellinger distance between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// See the documentation of Bhattacharyya.DistNormal for the formula for normal
// distributions.
func (Hellinger) DistNormal(l, r *Normal) float64 {
	if l.Dim() != r.Dim() {
		panic(badSizeMismatch)
	}
	db := Bhattacharyya{}.DistNormal(l, r)
	bc := math.Exp(-db)
	return math.Sqrt(1 - bc)
}

// KullbackLeibler is a type for computing the Kullback-Leibler divergence from l to r.
//
// The Kullback-Leibler divergence is defined as
//  D_KL(l || r) = \int_x l(x) log(l(x)/r(x)) dx
// Note that the Kullback-Leibler divergence is not symmetric with respect to
// the order of the input arguments.
type KullbackLeibler struct{}

// DistDirichlet returns the Kullback-Leibler divergence between Dirichlet
// distributions l and r. The dimensions of the input distributions must match
// or DistDirichlet will panic.
//
// For two Dirichlet distributions, the KL divergence is computed as
//  D_KL(l || r) = log Γ(α_0_l) - \sum_i log Γ(α_i_l) - log Γ(α_0_r) + \sum_i log Γ(α_i_r)
//                 + \sum_i (α_i_l - α_i_r)(ψ(α_i_l) - ψ(α_0_l))
// where Γ is the gamma function, ψ is the digamma function, and α_0 is the
// sum of the Dirichlet parameters.
func (KullbackLeibler) DistDirichlet(l, r *Dirichlet) float64 {
	// http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
	if l.Dim() != r.Dim() {
		panic(badSizeMismatch)
	}
	l0, _ := math.Lgamma(l.sumAlpha)
	r0, _ := math.Lgamma(r.sumAlpha)
	dl := mathext.Digamma(l.sumAlpha)

	var l1, r1, c float64
	for i, al := range l.alpha {
		ar := r.alpha[i]
		vl, _ := math.Lgamma(al)
		l1 += vl
		vr, _ := math.Lgamma(ar)
		r1 += vr
		c += (al - ar) * (mathext.Digamma(al) - dl)
	}
	return l0 - l1 - r0 + r1 + c
}
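// klDistDirichletSketch is a minimal usage sketch of
// KullbackLeibler.DistDirichlet (a hypothetical helper, not part of the
// distmv API). It assumes the package's NewDirichlet constructor, which
// panics if any parameter is non-positive.
func klDistDirichletSketch() float64 {
	l := NewDirichlet([]float64{2, 3, 4}, nil)
	r := NewDirichlet([]float64{1, 1, 1}, nil)
	// The divergence is non-negative and zero only when the parameters match;
	// note that DistDirichlet(l, r) != DistDirichlet(r, l) in general.
	return KullbackLeibler{}.DistDirichlet(l, r)
}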
// DistNormal returns the Kullback-Leibler divergence between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For two normal distributions, the KL divergence is computed as
//  D_KL(l || r) = 0.5*[ln(|Σ_r|) - ln(|Σ_l|) + (μ_l - μ_r)ᵀ*Σ_r^-1*(μ_l - μ_r) + tr(Σ_r^-1*Σ_l) - d]
func (KullbackLeibler) DistNormal(l, r *Normal) float64 {
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}

	mahalanobis := stat.Mahalanobis(mat.NewVecDense(dim, l.mu), mat.NewVecDense(dim, r.mu), &r.chol)
	mahalanobisSq := mahalanobis * mahalanobis

	// TODO(btracey): Optimize where there is a SolveCholeskySym.
	// TODO(btracey): There may be a more efficient way to just compute the trace.
	// Compute tr(Σ_r^-1*Σ_l) using the fact that Σ_l = Uᵀ * U.
	var u mat.TriDense
	l.chol.UTo(&u)
	var m mat.Dense
	err := r.chol.SolveTo(&m, u.T())
	if err != nil {
		return math.NaN()
	}
	m.Mul(&m, &u)
	tr := mat.Trace(&m)

	return r.logSqrtDet - l.logSqrtDet + 0.5*(mahalanobisSq+tr-float64(l.dim))
}

// DistUniform returns the Kullback-Leibler divergence between uniform distributions
// l and r. The dimensions of the input distributions must match or DistUniform
// will panic.
func (KullbackLeibler) DistUniform(l, r *Uniform) float64 {
	bl := l.Bounds(nil)
	br := r.Bounds(nil)
	if len(bl) != len(br) {
		panic(badSizeMismatch)
	}

	// The KL divergence is ∞ if l is not completely contained within r, because
	// then r(x) is zero while l(x) is non-zero for some x.
	contained := true
	for i, v := range bl {
		if v.Min < br[i].Min || br[i].Max < v.Max {
			contained = false
			break
		}
	}
	if !contained {
		return math.Inf(1)
	}

	// The KL divergence is finite.
	//
	// KL defines 0*ln(0) = 0, so regions where l(x) = 0 contribute nothing.
	// Within the support of l, both densities are constant, so log(l(x)/r(x))
	// is constant, and integrating it against l(x) leaves it unchanged because
	// l integrates to one. For a uniform distribution the entropy is -log(p(x)),
	// so the constant densities are recovered from the entropies.
	logPx := -l.Entropy()
	logQx := -r.Entropy()
	return logPx - logQx
}
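// klDistUniformSketch is a minimal usage sketch of KullbackLeibler.DistUniform
// (a hypothetical helper, not part of the distmv API). It assumes the
// package's NewUniform constructor, which takes the bounds as []r1.Interval.
// The support of l is contained in the support of r, so the divergence is
// finite: log(vol(r)) - log(vol(l)) = log(4).
func klDistUniformSketch() float64 {
	l := NewUniform([]r1.Interval{{Min: 0, Max: 1}, {Min: 0, Max: 1}}, nil)
	r := NewUniform([]r1.Interval{{Min: -1, Max: 1}, {Min: -1, Max: 1}}, nil)
	return KullbackLeibler{}.DistUniform(l, r)
}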
// Renyi is a type for computing the Rényi divergence of order α from l to r.
//
// The Rényi divergence with α > 0, α ≠ 1 is defined as
//  D_α(l || r) = 1/(α-1) log(\int_-∞^∞ l(x)^α r(x)^(1-α) dx)
// The Rényi divergence has special forms for α = 0 and α = 1. This type does
// not implement α = ∞. For α = 0,
//  D_0(l || r) = -log \int_-∞^∞ r(x)1{l(x)>0} dx
// that is, the negative log probability under r(x) that l(x) > 0.
// When α = 1, the Rényi divergence is equal to the Kullback-Leibler divergence.
// When α = 0.5, the Rényi divergence is equal to twice the Bhattacharyya distance.
//
// The parameter α must satisfy 0 ≤ α < ∞ or the distance functions will panic.
type Renyi struct {
	Alpha float64
}

// DistNormal returns the Rényi divergence between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For two normal distributions, the Rényi divergence is computed as
//  Σ_α = (1-α)Σ_l + αΣ_r
//  D_α(l || r) = α/2 * (μ_l - μ_r)ᵀ*Σ_α^-1*(μ_l - μ_r) + 1/(2(1-α))*ln(|Σ_α|/(|Σ_l|^(1-α)*|Σ_r|^α))
//
// For a more nicely formatted version of the formula, see Eq. 15 of
//  Kolchinsky, Artemy, and Brendan D. Tracey. "Estimating Mixture Entropy
//  with Pairwise Distances." arXiv preprint arXiv:1706.02419 (2017).
// Note that this formula is for the Chernoff divergence, which differs from
// the Rényi divergence by a factor of 1-α. Also be aware that most sources in
// the literature report this formula incorrectly.
func (renyi Renyi) DistNormal(l, r *Normal) float64 {
	if renyi.Alpha < 0 {
		panic("renyi: alpha < 0")
	}
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}
	if renyi.Alpha == 0 {
		return 0
	}
	if renyi.Alpha == 1 {
		return KullbackLeibler{}.DistNormal(l, r)
	}

	logDetL := l.chol.LogDet()
	logDetR := r.chol.LogDet()

	// Σ_α = (1-α)Σ_l + αΣ_r.
	sigA := mat.NewSymDense(dim, nil)
	for i := 0; i < dim; i++ {
		for j := i; j < dim; j++ {
			v := (1-renyi.Alpha)*l.sigma.At(i, j) + renyi.Alpha*r.sigma.At(i, j)
			sigA.SetSym(i, j, v)
		}
	}

	var chol mat.Cholesky
	ok := chol.Factorize(sigA)
	if !ok {
		return math.NaN()
	}
	logDetA := chol.LogDet()

	mahalanobis := stat.Mahalanobis(mat.NewVecDense(dim, l.mu), mat.NewVecDense(dim, r.mu), &chol)
	mahalanobisSq := mahalanobis * mahalanobis

	return (renyi.Alpha/2)*mahalanobisSq + 1/(2*(1-renyi.Alpha))*(logDetA-(1-renyi.Alpha)*logDetL-renyi.Alpha*logDetR)
}

// Wasserstein is a type for computing the Wasserstein distance between two
// probability distributions.
//
// The Wasserstein distance is defined as
//  W(l,r) := inf 𝔼(||X-Y||_2^2)^1/2
// where the infimum is taken over all joint distributions of (X,Y) whose
// marginals are l and r.
// For more information, see
//  https://en.wikipedia.org/wiki/Wasserstein_metric
type Wasserstein struct{}

// DistNormal returns the Wasserstein distance between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// The Wasserstein distance for normal distributions is
//  d^2 = ||m_l - m_r||_2^2 + Tr(Σ_l + Σ_r - 2(Σ_l^(1/2)*Σ_r*Σ_l^(1/2))^(1/2))
// For more information, see
//  http://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/
func (Wasserstein) DistNormal(l, r *Normal) float64 {
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}

	d := floats.Distance(l.mu, r.mu, 2)
	d = d * d

	// Compute Σ_l^(1/2).
	var ssl mat.SymDense
	err := ssl.PowPSD(&l.sigma, 0.5)
	if err != nil {
		panic(err)
	}
	// Compute Σ_l^(1/2)*Σ_r*Σ_l^(1/2).
	var mean mat.Dense
	mean.Mul(&ssl, &r.sigma)
	mean.Mul(&mean, &ssl)

	// Reinterpret as a SymDense and take the square root.
	meanSym := mat.NewSymDense(dim, mean.RawMatrix().Data)
	err = ssl.PowPSD(meanSym, 0.5)
	if err != nil {
		panic(err)
	}

	tr := mat.Trace(&r.sigma)
	tl := mat.Trace(&l.sigma)
	tm := mat.Trace(&ssl)

	return d + tl + tr - 2*tm
}
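// wassersteinDistNormalSketch is a minimal usage sketch of
// Wasserstein.DistNormal (a hypothetical helper, not part of the distmv API).
// With Σ_l = Σ_r = I the trace terms cancel, so the returned value is the d^2
// of the formula above: ||m_l - m_r||_2^2 = 3^2 + 4^2 = 25.
func wassersteinDistNormalSketch() float64 {
	eye := mat.NewSymDense(2, []float64{1, 0, 0, 1})
	l, _ := NewNormal([]float64{0, 0}, eye, nil)
	r, _ := NewNormal([]float64{3, 4}, eye, nil)
	return Wasserstein{}.DistNormal(l, r)
}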