gonum.org/v1/gonum@v0.14.0/stat/distmv/statdist.go

// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package distmv

import (
	"math"

	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/mat"
	"gonum.org/v1/gonum/mathext"
	"gonum.org/v1/gonum/spatial/r1"
	"gonum.org/v1/gonum/stat"
)

// Bhattacharyya is a type for computing the Bhattacharyya distance between
// probability distributions.
//
// The Bhattacharyya distance is defined as
//
//	D_B = -ln(BC(l,r))
//	BC(l,r) = \int_-∞^∞ (l(x)r(x))^(1/2) dx
//
// where BC is known as the Bhattacharyya coefficient.
// The Bhattacharyya distance is related to the Hellinger distance by
//
//	H(l,r) = sqrt(1-BC(l,r))
//
// For more information, see
//
//	https://en.wikipedia.org/wiki/Bhattacharyya_distance
type Bhattacharyya struct{}

// DistNormal computes the Bhattacharyya distance between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For Normal distributions, the Bhattacharyya distance is
//
//	Σ = (Σ_l + Σ_r)/2
//	D_B = (1/8)*(μ_l - μ_r)ᵀ*Σ^-1*(μ_l - μ_r) + (1/2)*ln(det(Σ)/(det(Σ_l)*det(Σ_r))^(1/2))
func (Bhattacharyya) DistNormal(l, r *Normal) float64 {
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}

	var sigma mat.SymDense
	sigma.AddSym(&l.sigma, &r.sigma)
	sigma.ScaleSym(0.5, &sigma)

	var chol mat.Cholesky
	if ok := chol.Factorize(&sigma); !ok {
		// The average of two positive-definite matrices is positive definite,
		// so a failure here indicates severe ill-conditioning.
		return math.NaN()
	}

	mahalanobis := stat.Mahalanobis(mat.NewVecDense(dim, l.mu), mat.NewVecDense(dim, r.mu), &chol)
	mahalanobisSq := mahalanobis * mahalanobis

	dl := l.chol.LogDet()
	dr := r.chol.LogDet()
	ds := chol.LogDet()

	return 0.125*mahalanobisSq + 0.5*ds - 0.25*dl - 0.25*dr
}

// DistUniform computes the Bhattacharyya distance between uniform distributions l and r.
// The dimensions of the input distributions must match or DistUniform will panic.
func (Bhattacharyya) DistUniform(l, r *Uniform) float64 {
	if len(l.bounds) != len(r.bounds) {
		panic(badSizeMismatch)
	}
	// BC = \int \sqrt(l(x)r(x)) dx, which for uniform distributions is a
	// constant over the volume where both distributions have positive
	// probability. Compute the overlap volume and the value of \sqrt(l(x)r(x)).
	// The entropy is the negative log probability of the distribution, used
	// here instead of LogProb so that no x value needs to be constructed.
	//
	// BC = volume * \sqrt(l(x)r(x))
	// logBC = log(volume) + 0.5*(logL + logR)
	// D_B = -logBC
	return -unifLogVolOverlap(l.bounds, r.bounds) + 0.5*(l.Entropy()+r.Entropy())
}

// unifLogVolOverlap computes the log of the volume of the hyper-rectangle where
// both uniform distributions have positive probability.
func unifLogVolOverlap(b1, b2 []r1.Interval) float64 {
	var logVolOverlap float64
	for dim, v1 := range b1 {
		v2 := b2[dim]
		// If the intervals don't overlap, the volume is 0.
		if v1.Max <= v2.Min || v2.Max <= v1.Min {
			return math.Inf(-1)
		}
		vol := math.Min(v1.Max, v2.Max) - math.Max(v1.Min, v2.Min)
		logVolOverlap += math.Log(vol)
	}
	return logVolOverlap
}
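
// The following function is an editorial sketch, not part of the original
// file. It illustrates the documented relation H(l,r) = sqrt(1-BC(l,r)),
// equivalently H = sqrt(1 - exp(-D_B)), on two hypothetical one-dimensional
// Normals; the means and variances are arbitrary illustrative values.
func exampleBhattacharyyaHellinger() float64 {
	l, _ := NewNormal([]float64{0}, mat.NewSymDense(1, []float64{1}), nil)
	r, _ := NewNormal([]float64{1}, mat.NewSymDense(1, []float64{2}), nil)
	db := Bhattacharyya{}.DistNormal(l, r)
	h := Hellinger{}.DistNormal(l, r)
	// The difference should be zero up to floating-point error.
	return h - math.Sqrt(1-math.Exp(-db))
}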

// CrossEntropy is a type for computing the cross-entropy between probability
// distributions.
//
// The cross-entropy is defined as
//
//	-\int_x l(x) log(r(x)) dx = KL(l || r) + H(l)
//
// where KL is the Kullback-Leibler divergence and H is the entropy.
// For more information, see
//
//	https://en.wikipedia.org/wiki/Cross_entropy
type CrossEntropy struct{}

// DistNormal returns the cross-entropy between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
func (CrossEntropy) DistNormal(l, r *Normal) float64 {
	if l.Dim() != r.Dim() {
		panic(badSizeMismatch)
	}
	kl := KullbackLeibler{}.DistNormal(l, r)
	return kl + l.Entropy()
}

// Hellinger is a type for computing the Hellinger distance between probability
// distributions.
//
// The Hellinger distance is defined as
//
//	H^2(l,r) = 1/2 * \int_x (\sqrt(l(x)) - \sqrt(r(x)))^2 dx
//
// and is bounded between 0 and 1. Note that the above formula defines the
// squared Hellinger distance; the DistNormal method returns the Hellinger
// distance itself. The Hellinger distance is related to the Bhattacharyya
// distance by
//
//	H^2 = 1 - exp(-D_B)
//
// For more information, see
//
//	https://en.wikipedia.org/wiki/Hellinger_distance
type Hellinger struct{}

// DistNormal returns the Hellinger distance between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// See the documentation of Bhattacharyya.DistNormal for the formula for Normal
// distributions.
func (Hellinger) DistNormal(l, r *Normal) float64 {
	if l.Dim() != r.Dim() {
		panic(badSizeMismatch)
	}
	db := Bhattacharyya{}.DistNormal(l, r)
	bc := math.Exp(-db)
	return math.Sqrt(1 - bc)
}

// KullbackLeibler is a type for computing the Kullback-Leibler divergence from l to r.
//
// The Kullback-Leibler divergence is defined as
//
//	D_KL(l || r) = \int_x l(x) log(l(x)/r(x)) dx
//
// Note that the Kullback-Leibler divergence is not symmetric with respect to
// the order of the input arguments.
type KullbackLeibler struct{}

// DistDirichlet returns the Kullback-Leibler divergence between Dirichlet
// distributions l and r. The dimensions of the input distributions must match
// or DistDirichlet will panic.
//
// For two Dirichlet distributions, the KL divergence is computed as
//
//	D_KL(l || r) = log Γ(α_0_l) - \sum_i log Γ(α_i_l) - log Γ(α_0_r) + \sum_i log Γ(α_i_r)
//	             + \sum_i (α_i_l - α_i_r)(ψ(α_i_l) - ψ(α_0_l))
//
// where Γ is the gamma function, ψ is the digamma function, and α_0 is the
// sum of the Dirichlet parameters.
func (KullbackLeibler) DistDirichlet(l, r *Dirichlet) float64 {
	// http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
	if l.Dim() != r.Dim() {
		panic(badSizeMismatch)
	}
	l0, _ := math.Lgamma(l.sumAlpha)
	r0, _ := math.Lgamma(r.sumAlpha)
	dl := mathext.Digamma(l.sumAlpha)

	var l1, r1, c float64
	for i, al := range l.alpha {
		ar := r.alpha[i]
		vl, _ := math.Lgamma(al)
		l1 += vl
		vr, _ := math.Lgamma(ar)
		r1 += vr
		c += (al - ar) * (mathext.Digamma(al) - dl)
	}
	return l0 - l1 - r0 + r1 + c
}
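
// The following function is an editorial sketch, not part of the original
// file. It shows DistDirichlet on hypothetical parameter vectors: the
// divergence from a distribution to itself is exactly zero (every term in the
// formula cancels), and it is strictly positive for distinct parameters.
func exampleKullbackLeiblerDirichlet() (self, other float64) {
	l := NewDirichlet([]float64{2, 3, 4}, nil)
	r := NewDirichlet([]float64{1, 1, 1}, nil)
	self = KullbackLeibler{}.DistDirichlet(l, l)  // 0
	other = KullbackLeibler{}.DistDirichlet(l, r) // > 0 by Gibbs' inequality
	return self, other
}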

// DistNormal returns the Kullback-Leibler divergence between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For two normal distributions, the KL divergence is computed as
//
//	D_KL(l || r) = 0.5*[ln(|Σ_r|) - ln(|Σ_l|) + (μ_l - μ_r)ᵀ*Σ_r^-1*(μ_l - μ_r) + tr(Σ_r^-1*Σ_l) - d]
//
// where d is the dimension of the distributions.
func (KullbackLeibler) DistNormal(l, r *Normal) float64 {
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}

	mahalanobis := stat.Mahalanobis(mat.NewVecDense(dim, l.mu), mat.NewVecDense(dim, r.mu), &r.chol)
	mahalanobisSq := mahalanobis * mahalanobis

	// TODO(btracey): Optimize where there is a SolveCholeskySym.
	// TODO(btracey): There may be a more efficient way to just compute the trace.
	// Compute tr(Σ_r^-1*Σ_l) using the fact that Σ_l = Uᵀ * U.
	var u mat.TriDense
	l.chol.UTo(&u)
	var m mat.Dense
	err := r.chol.SolveTo(&m, u.T())
	if err != nil {
		return math.NaN()
	}
	m.Mul(&m, &u)
	tr := mat.Trace(&m)

	return r.logSqrtDet - l.logSqrtDet + 0.5*(mahalanobisSq+tr-float64(l.dim))
}

// DistUniform returns the Kullback-Leibler divergence between uniform distributions
// l and r. The dimensions of the input distributions must match or DistUniform
// will panic.
func (KullbackLeibler) DistUniform(l, r *Uniform) float64 {
	bl := l.Bounds(nil)
	br := r.Bounds(nil)
	if len(bl) != len(br) {
		panic(badSizeMismatch)
	}

	// The KL divergence is ∞ if l is not completely contained within r,
	// because then r(x) is zero while l(x) is non-zero for some x.
	contained := true
	for i, v := range bl {
		if v.Min < br[i].Min || br[i].Max < v.Max {
			contained = false
			break
		}
	}
	if !contained {
		return math.Inf(1)
	}

	// The KL divergence is finite.
	//
	// KL defines 0*ln(0) = 0, so there is no contribution to the divergence
	// where l(x) = 0. Within l's support, l(x) and r(x) are constant, so
	// log(l(x)/r(x)) is constant and integrating it against l(x) leaves that
	// constant unchanged. The entropy is -log(p(x)) for a uniform
	// distribution, which gives both constants without constructing an x value.
	logPx := -l.Entropy()
	logQx := -r.Entropy()
	return logPx - logQx
}
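
// The following function is an editorial sketch, not part of the original
// file. It shows the two branches of DistUniform on hypothetical bounds: when
// l's support is contained in r's, the divergence reduces to the log of the
// volume ratio; otherwise it is +∞.
func exampleKullbackLeiblerUniform() (contained, uncontained float64) {
	inner := NewUniform([]r1.Interval{{Min: 0, Max: 1}}, nil)
	outer := NewUniform([]r1.Interval{{Min: -1, Max: 2}}, nil)
	contained = KullbackLeibler{}.DistUniform(inner, outer)   // log(3/1)
	uncontained = KullbackLeibler{}.DistUniform(outer, inner) // +Inf
	return contained, uncontained
}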

// Renyi is a type for computing the Rényi divergence of order α from l to r.
//
// The Rényi divergence with α > 0, α ≠ 1 is defined as
//
//	D_α(l || r) = 1/(α-1) log(\int_-∞^∞ l(x)^α r(x)^(1-α) dx)
//
// The Rényi divergence has special forms for α = 0 and α = 1. This type does
// not implement α = ∞. For α = 0,
//
//	D_0(l || r) = -log \int_-∞^∞ r(x)1{l(x)>0} dx
//
// that is, the negative log of the probability under r(x) of the region where
// l(x) > 0. When α = 1, the Rényi divergence is equal to the Kullback-Leibler
// divergence. When α = 0.5, the Rényi divergence is equal to twice the
// Bhattacharyya distance.
//
// The parameter α must be in 0 ≤ α < ∞ or the distance functions will panic.
type Renyi struct {
	Alpha float64
}

// DistNormal returns the Rényi divergence between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For two normal distributions, the Rényi divergence is computed as
//
//	Σ_α = (1-α)*Σ_l + α*Σ_r
//	D_α(l||r) = α/2 * (μ_l - μ_r)ᵀ*Σ_α^-1*(μ_l - μ_r) + 1/(2(1-α)) * ln(|Σ_α|/(|Σ_l|^(1-α)*|Σ_r|^α))
//
// For a more nicely formatted version of the formula, see Eq. 15 of
//
//	Kolchinsky, Artemy, and Brendan D. Tracey. "Estimating Mixture Entropy
//	with Pairwise Distances." arXiv preprint arXiv:1706.02419 (2017).
//
// Note that the formula in that reference is for the Chernoff divergence,
// which differs from the Rényi divergence by a factor of 1-α. Also be aware
// that most sources in the literature report this formula incorrectly.
func (renyi Renyi) DistNormal(l, r *Normal) float64 {
	if renyi.Alpha < 0 {
		panic("renyi: alpha < 0")
	}
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}
	if renyi.Alpha == 0 {
		// Normal distributions have full support, so the integral in the
		// α = 0 form is 1 and D_0 = -log(1) = 0.
		return 0
	}
	if renyi.Alpha == 1 {
		return KullbackLeibler{}.DistNormal(l, r)
	}

	logDetL := l.chol.LogDet()
	logDetR := r.chol.LogDet()

	// Σ_α = (1-α)Σ_l + αΣ_r.
	sigA := mat.NewSymDense(dim, nil)
	for i := 0; i < dim; i++ {
		for j := i; j < dim; j++ {
			v := (1-renyi.Alpha)*l.sigma.At(i, j) + renyi.Alpha*r.sigma.At(i, j)
			sigA.SetSym(i, j, v)
		}
	}

	var chol mat.Cholesky
	ok := chol.Factorize(sigA)
	if !ok {
		return math.NaN()
	}
	logDetA := chol.LogDet()

	mahalanobis := stat.Mahalanobis(mat.NewVecDense(dim, l.mu), mat.NewVecDense(dim, r.mu), &chol)
	mahalanobisSq := mahalanobis * mahalanobis

	return (renyi.Alpha/2)*mahalanobisSq + 1/(2*(1-renyi.Alpha))*(logDetA-(1-renyi.Alpha)*logDetL-renyi.Alpha*logDetR)
}

// Wasserstein is a type for computing the Wasserstein distance between two
// probability distributions.
//
// The Wasserstein distance is defined as
//
//	W(l,r) := inf 𝔼(||X-Y||_2^2)^(1/2)
//
// where the infimum is taken over all joint distributions of X and Y whose
// marginals are l and r respectively.
// For more information, see
//
//	https://en.wikipedia.org/wiki/Wasserstein_metric
type Wasserstein struct{}

// DistNormal returns the squared Wasserstein distance between normal
// distributions l and r. The dimensions of the input distributions must match
// or DistNormal will panic.
//
// The squared Wasserstein distance for Normal distributions is
//
//	d^2 = ||m_l - m_r||_2^2 + Tr(Σ_l + Σ_r - 2*(Σ_l^(1/2)*Σ_r*Σ_l^(1/2))^(1/2))
//
// For more information, see
//
//	http://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/
func (Wasserstein) DistNormal(l, r *Normal) float64 {
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}

	d := floats.Distance(l.mu, r.mu, 2)
	d = d * d

	// Compute Σ_l^(1/2).
	var ssl mat.SymDense
	err := ssl.PowPSD(&l.sigma, 0.5)
	if err != nil {
		panic(err)
	}
	// Compute Σ_l^(1/2)*Σ_r*Σ_l^(1/2).
	var mean mat.Dense
	mean.Mul(&ssl, &r.sigma)
	mean.Mul(&mean, &ssl)

	// Reinterpret as a SymDense and take the matrix square root
	// (Σ_l^(1/2)*Σ_r*Σ_l^(1/2))^(1/2).
	meanSym := mat.NewSymDense(dim, mean.RawMatrix().Data)
	err = ssl.PowPSD(meanSym, 0.5)
	if err != nil {
		panic(err)
	}

	tr := mat.Trace(&r.sigma)
	tl := mat.Trace(&l.sigma)
	tm := mat.Trace(&ssl)

	return d + tl + tr - 2*tm
}
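
// The following function is an editorial sketch, not part of the original
// file. It exercises the remaining relations on two hypothetical Normals with
// arbitrary parameters: at α = 0.5 the Rényi divergence is twice the
// Bhattacharyya distance (Σ_α reduces to the averaged covariance used there),
// and Wasserstein.DistNormal returns the squared quantity d^2 from its
// formula, so the distance itself requires a square root.
func exampleRenyiWasserstein() (renyiGap, wassersteinDist float64) {
	l, _ := NewNormal([]float64{0, 0}, mat.NewSymDense(2, []float64{1, 0, 0, 1}), nil)
	r, _ := NewNormal([]float64{1, 1}, mat.NewSymDense(2, []float64{2, 0, 0, 2}), nil)
	renyiGap = Renyi{Alpha: 0.5}.DistNormal(l, r) - 2*Bhattacharyya{}.DistNormal(l, r) // ~0
	wassersteinDist = math.Sqrt(Wasserstein{}.DistNormal(l, r))
	return renyiGap, wassersteinDist
}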