gonum.org/v1/gonum@v0.14.0/stat/distmv/statdist.go

// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package distmv

import (
	"math"

	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/mat"
	"gonum.org/v1/gonum/mathext"
	"gonum.org/v1/gonum/spatial/r1"
	"gonum.org/v1/gonum/stat"
)

// Bhattacharyya is a type for computing the Bhattacharyya distance between
// probability distributions.
//
// The Bhattacharyya distance is defined as
//
//	D_B = -ln(BC(l,r))
//	BC(l,r) = \int_-∞^∞ (l(x)r(x))^(1/2) dx
//
// where BC is known as the Bhattacharyya coefficient.
// The Bhattacharyya distance is related to the Hellinger distance by
//
//	H(l,r) = sqrt(1-BC(l,r))
//
// For more information, see
//
//	https://en.wikipedia.org/wiki/Bhattacharyya_distance
type Bhattacharyya struct{}

// DistNormal computes the Bhattacharyya distance between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For Normal distributions, the Bhattacharyya distance is
//
//	Σ = (Σ_l + Σ_r)/2
//	D_B = (1/8)*(μ_l - μ_r)ᵀ*Σ^-1*(μ_l - μ_r) + (1/2)*ln(det(Σ)/(det(Σ_l)*det(Σ_r))^(1/2))
func (Bhattacharyya) DistNormal(l, r *Normal) float64 {
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}

	var sigma mat.SymDense
	sigma.AddSym(&l.sigma, &r.sigma)
	sigma.ScaleSym(0.5, &sigma)

	// The average of two positive-definite matrices is positive definite,
	// but guard against numerical failure of the factorization, mirroring
	// Renyi.DistNormal.
	var chol mat.Cholesky
	if ok := chol.Factorize(&sigma); !ok {
		return math.NaN()
	}

	mahalanobis := stat.Mahalanobis(mat.NewVecDense(dim, l.mu), mat.NewVecDense(dim, r.mu), &chol)
	mahalanobisSq := mahalanobis * mahalanobis

	dl := l.chol.LogDet()
	dr := r.chol.LogDet()
	ds := chol.LogDet()

	return 0.125*mahalanobisSq + 0.5*ds - 0.25*dl - 0.25*dr
}
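
// exampleBhattacharyyaNormal is an illustrative sketch, not part of the
// original file: it constructs two bivariate normals with NewNormal from
// this package and returns their Bhattacharyya distance. For l == r the
// distance is 0; it grows as the means and covariances separate.
func exampleBhattacharyyaNormal() float64 {
	l, ok := NewNormal([]float64{0, 0}, mat.NewSymDense(2, []float64{1, 0, 0, 1}), nil)
	if !ok {
		panic("statdist: bad covariance")
	}
	r, ok := NewNormal([]float64{1, 1}, mat.NewSymDense(2, []float64{2, 0.5, 0.5, 2}), nil)
	if !ok {
		panic("statdist: bad covariance")
	}
	return Bhattacharyya{}.DistNormal(l, r)
}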

// DistUniform computes the Bhattacharyya distance between uniform distributions l and r.
// The dimensions of the input distributions must match or DistUniform will panic.
func (Bhattacharyya) DistUniform(l, r *Uniform) float64 {
	if len(l.bounds) != len(r.bounds) {
		panic(badSizeMismatch)
	}
	// BC = \int \sqrt(p(x)q(x)) dx, which for uniform distributions is a
	// constant over the volume where both distributions have positive
	// probability. Compute the overlap volume and the value of
	// sqrt(p(x)q(x)). The entropy is the negative log probability of the
	// distribution (used instead of LogProb so it is not necessary to
	// construct an x value).
	//
	// BC = volume * sqrt(p(x)q(x))
	// logBC = log(volume) + 0.5*(logP + logQ)
	// D_B = -logBC
	return -unifLogVolOverlap(l.bounds, r.bounds) + 0.5*(l.Entropy()+r.Entropy())
}

// unifLogVolOverlap computes the log of the volume of the hyper-rectangle where
// both uniform distributions have positive probability.
func unifLogVolOverlap(b1, b2 []r1.Interval) float64 {
	var logVolOverlap float64
	for dim, v1 := range b1 {
		v2 := b2[dim]
		// If the intervals do not overlap in this dimension, the overlap
		// volume is 0 and its log is -∞.
		if v1.Max <= v2.Min || v2.Max <= v1.Min {
			return math.Inf(-1)
		}
		vol := math.Min(v1.Max, v2.Max) - math.Max(v1.Min, v2.Min)
		logVolOverlap += math.Log(vol)
	}
	return logVolOverlap
}
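
// exampleBhattacharyyaUniform is an illustrative sketch, not part of the
// original file: two unit squares offset by 0.5 in each dimension overlap
// in a region of volume 0.25, and both densities are 1 on their supports,
// so D_B = -ln(0.25) ≈ 1.386.
func exampleBhattacharyyaUniform() float64 {
	l := NewUniform([]r1.Interval{{Min: 0, Max: 1}, {Min: 0, Max: 1}}, nil)
	r := NewUniform([]r1.Interval{{Min: 0.5, Max: 1.5}, {Min: 0.5, Max: 1.5}}, nil)
	return Bhattacharyya{}.DistUniform(l, r)
}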

// CrossEntropy is a type for computing the cross-entropy between probability
// distributions.
//
// The cross-entropy is defined as
//
//	-\int_x l(x) log(r(x)) dx = KL(l || r) + H(l)
//
// where KL is the Kullback-Leibler divergence and H is the entropy.
// For more information, see
//
//	https://en.wikipedia.org/wiki/Cross_entropy
type CrossEntropy struct{}

// DistNormal returns the cross-entropy between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
func (CrossEntropy) DistNormal(l, r *Normal) float64 {
	if l.Dim() != r.Dim() {
		panic(badSizeMismatch)
	}
	kl := KullbackLeibler{}.DistNormal(l, r)
	return kl + l.Entropy()
}
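
// exampleCrossEntropyNormal is an illustrative sketch, not part of the
// original file: for identical distributions KL(l || l) = 0, so the
// cross-entropy reduces to the entropy of l.
func exampleCrossEntropyNormal() (ce, entropy float64) {
	l, _ := NewNormal([]float64{0, 0}, mat.NewSymDense(2, []float64{1, 0, 0, 1}), nil)
	return CrossEntropy{}.DistNormal(l, l), l.Entropy() // the two values are equal
}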

// Hellinger is a type for computing the Hellinger distance between probability
// distributions.
//
// The Hellinger distance is defined as
//
//	H^2(l,r) = 1/2 * \int_x (\sqrt(l(x)) - \sqrt(r(x)))^2 dx
//
// and is bounded between 0 and 1. Note that the above formula defines the
// squared Hellinger distance, while the methods of this type return the
// Hellinger distance itself.
// The Hellinger distance is related to the Bhattacharyya distance by
//
//	H^2 = 1 - exp(-D_B)
//
// For more information, see
//
//	https://en.wikipedia.org/wiki/Hellinger_distance
type Hellinger struct{}

// DistNormal returns the Hellinger distance between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// See the documentation of Bhattacharyya.DistNormal for the formula for Normal
// distributions.
func (Hellinger) DistNormal(l, r *Normal) float64 {
	if l.Dim() != r.Dim() {
		panic(badSizeMismatch)
	}
	db := Bhattacharyya{}.DistNormal(l, r)
	bc := math.Exp(-db)
	return math.Sqrt(1 - bc)
}
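
// exampleHellingerNormal is an illustrative sketch, not part of the original
// file: it evaluates the documented relation H = sqrt(1 - exp(-D_B)) between
// the Hellinger and Bhattacharyya distances for a pair of 1-D normals; the
// two returned values agree.
func exampleHellingerNormal() (h, fromBhattacharyya float64) {
	l, _ := NewNormal([]float64{0}, mat.NewSymDense(1, []float64{1}), nil)
	r, _ := NewNormal([]float64{1}, mat.NewSymDense(1, []float64{2}), nil)
	db := Bhattacharyya{}.DistNormal(l, r)
	return Hellinger{}.DistNormal(l, r), math.Sqrt(1 - math.Exp(-db))
}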

// KullbackLeibler is a type for computing the Kullback-Leibler divergence from l to r.
//
// The Kullback-Leibler divergence is defined as
//
//	D_KL(l || r) = \int_x l(x) log(l(x)/r(x)) dx
//
// Note that the Kullback-Leibler divergence is not symmetric with respect to
// the order of the input arguments.
type KullbackLeibler struct{}

// DistDirichlet returns the Kullback-Leibler divergence between Dirichlet
// distributions l and r. The dimensions of the input distributions must match
// or DistDirichlet will panic.
//
// For two Dirichlet distributions, the KL divergence is computed as
//
//	D_KL(l || r) = log Γ(α_0_l) - \sum_i log Γ(α_i_l) - log Γ(α_0_r) + \sum_i log Γ(α_i_r)
//	               + \sum_i (α_i_l - α_i_r)(ψ(α_i_l) - ψ(α_0_l))
//
// where Γ is the gamma function, ψ is the digamma function, and α_0 is the
// sum of the Dirichlet parameters.
func (KullbackLeibler) DistDirichlet(l, r *Dirichlet) float64 {
	// http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
	if l.Dim() != r.Dim() {
		panic(badSizeMismatch)
	}
	l0, _ := math.Lgamma(l.sumAlpha)
	r0, _ := math.Lgamma(r.sumAlpha)
	dl := mathext.Digamma(l.sumAlpha)

	var l1, r1, c float64
	for i, al := range l.alpha {
		ar := r.alpha[i]
		vl, _ := math.Lgamma(al)
		l1 += vl
		vr, _ := math.Lgamma(ar)
		r1 += vr
		c += (al - ar) * (mathext.Digamma(al) - dl)
	}
	return l0 - l1 - r0 + r1 + c
}
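
// exampleKLDirichlet is an illustrative sketch, not part of the original
// file: the divergence from a Dirichlet to itself is 0, and it is strictly
// positive for distinct parameter vectors.
func exampleKLDirichlet() (self, distinct float64) {
	l := NewDirichlet([]float64{2, 3, 4}, nil)
	r := NewDirichlet([]float64{1, 1, 1}, nil)
	return KullbackLeibler{}.DistDirichlet(l, l), KullbackLeibler{}.DistDirichlet(l, r)
}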

// DistNormal returns the Kullback-Leibler divergence between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For two normal distributions, the KL divergence is computed as
//
//	D_KL(l || r) = 0.5*[ln(|Σ_r|) - ln(|Σ_l|) + (μ_l - μ_r)ᵀ*Σ_r^-1*(μ_l - μ_r) + tr(Σ_r^-1*Σ_l) - d]
func (KullbackLeibler) DistNormal(l, r *Normal) float64 {
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}

	mahalanobis := stat.Mahalanobis(mat.NewVecDense(dim, l.mu), mat.NewVecDense(dim, r.mu), &r.chol)
	mahalanobisSq := mahalanobis * mahalanobis

	// TODO(btracey): Optimize where there is a SolveCholeskySym
	// TODO(btracey): There may be a more efficient way to just compute the trace
	// Compute tr(Σ_r^-1*Σ_l) using the fact that Σ_l = Uᵀ * U.
	var u mat.TriDense
	l.chol.UTo(&u)
	var m mat.Dense
	err := r.chol.SolveTo(&m, u.T())
	if err != nil {
		return math.NaN()
	}
	m.Mul(&m, &u)
	tr := mat.Trace(&m)

	// r.logSqrtDet - l.logSqrtDet is 0.5*(ln(|Σ_r|) - ln(|Σ_l|)).
	return r.logSqrtDet - l.logSqrtDet + 0.5*(mahalanobisSq+tr-float64(dim))
}
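
// exampleKLNormal is an illustrative sketch, not part of the original file:
// for 1-D normals the divergence has the scalar closed form
// ln(σ_r/σ_l) + (σ_l^2 + (μ_l-μ_r)^2)/(2σ_r^2) - 1/2, computed here
// alongside DistNormal; the two returned values agree.
func exampleKLNormal() (kl, scalar float64) {
	muL, muR := 0.0, 1.0
	varL, varR := 1.0, 2.0
	l, _ := NewNormal([]float64{muL}, mat.NewSymDense(1, []float64{varL}), nil)
	r, _ := NewNormal([]float64{muR}, mat.NewSymDense(1, []float64{varR}), nil)
	scalar = 0.5*math.Log(varR/varL) + (varL+(muL-muR)*(muL-muR))/(2*varR) - 0.5
	return KullbackLeibler{}.DistNormal(l, r), scalar
}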

// DistUniform returns the Kullback-Leibler divergence between uniform distributions
// l and r. The dimensions of the input distributions must match or DistUniform
// will panic.
func (KullbackLeibler) DistUniform(l, r *Uniform) float64 {
	bl := l.Bounds(nil)
	br := r.Bounds(nil)
	if len(bl) != len(br) {
		panic(badSizeMismatch)
	}

	// The KL is ∞ if l is not completely contained within r, because then
	// r(x) is zero while l(x) is non-zero for some x.
	contained := true
	for i, v := range bl {
		if v.Min < br[i].Min || br[i].Max < v.Max {
			contained = false
			break
		}
	}
	if !contained {
		return math.Inf(1)
	}

	// The KL divergence is finite.
	//
	// KL defines 0*ln(0) = 0, so there is no contribution to KL where l(x) = 0.
	// Where l(x) is non-zero, both densities are constant, so log(l(x)/r(x))
	// is a constant that is integrated against l(x), and l(x) integrates to
	// one. For a uniform distribution, the entropy is the negative log of
	// that constant density.
	logPx := -l.Entropy()
	logQx := -r.Entropy()
	return logPx - logQx
}
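
// exampleKLUniform is an illustrative sketch, not part of the original file:
// when l's support is contained in r's, the divergence is the log ratio of
// the support volumes (here ln 2); when the supports are disjoint it is +∞.
func exampleKLUniform() (nested, disjoint float64) {
	l := NewUniform([]r1.Interval{{Min: 0, Max: 1}}, nil)
	r := NewUniform([]r1.Interval{{Min: 0, Max: 2}}, nil)
	q := NewUniform([]r1.Interval{{Min: 3, Max: 4}}, nil)
	return KullbackLeibler{}.DistUniform(l, r), KullbackLeibler{}.DistUniform(l, q)
}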

// Renyi is a type for computing the Rényi divergence of order α from l to r.
//
// The Rényi divergence with α > 0, α ≠ 1 is defined as
//
//	D_α(l || r) = 1/(α-1) log(\int_-∞^∞ l(x)^α r(x)^(1-α) dx)
//
// The Rényi divergence has special forms for α = 0 and α = 1. This type does
// not implement α = ∞. For α = 0,
//
//	D_0(l || r) = -log \int_-∞^∞ r(x)1{l(x)>0} dx
//
// that is, the negative log probability under r(x) that l(x) > 0.
// When α = 1, the Rényi divergence is equal to the Kullback-Leibler divergence.
// The Rényi divergence is also equal to twice the Bhattacharyya distance when α = 0.5.
//
// The parameter α must be in 0 ≤ α < ∞ or the distance functions will panic.
type Renyi struct {
	Alpha float64
}

// DistNormal returns the Rényi divergence between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// For two normal distributions, the Rényi divergence is computed as
//
//	Σ_α = (1-α) Σ_l + αΣ_r
//	D_α(l||r) = α/2 * (μ_l - μ_r)ᵀ*Σ_α^-1*(μ_l - μ_r) + 1/(2(1-α))*ln(|Σ_α|/(|Σ_l|^(1-α)*|Σ_r|^α))
//
// For a more nicely formatted version of the formula, see Eq. 15 of
//
//	Kolchinsky, Artemy, and Brendan D. Tracey. "Estimating Mixture Entropy
//	with Pairwise Distances." arXiv preprint arXiv:1706.02419 (2017).
//
// Note that the formula there is for the Chernoff divergence, which differs
// from the Rényi divergence by a factor of 1-α. Also be aware that most
// sources in the literature report this formula incorrectly.
func (renyi Renyi) DistNormal(l, r *Normal) float64 {
	if renyi.Alpha < 0 {
		panic("renyi: alpha < 0")
	}
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}
	if renyi.Alpha == 0 {
		// Normal distributions have full support, so the integral of r(x)
		// over the support of l(x) is 1, and D_0 = -log(1) = 0.
		return 0
	}
	if renyi.Alpha == 1 {
		return KullbackLeibler{}.DistNormal(l, r)
	}

	logDetL := l.chol.LogDet()
	logDetR := r.chol.LogDet()

	// Σ_α = (1-α)Σ_l + αΣ_r.
	sigA := mat.NewSymDense(dim, nil)
	for i := 0; i < dim; i++ {
		for j := i; j < dim; j++ {
			v := (1-renyi.Alpha)*l.sigma.At(i, j) + renyi.Alpha*r.sigma.At(i, j)
			sigA.SetSym(i, j, v)
		}
	}

	// For α > 1, Σ_α is not guaranteed to be positive definite; return NaN
	// if the factorization fails.
	var chol mat.Cholesky
	ok := chol.Factorize(sigA)
	if !ok {
		return math.NaN()
	}
	logDetA := chol.LogDet()

	mahalanobis := stat.Mahalanobis(mat.NewVecDense(dim, l.mu), mat.NewVecDense(dim, r.mu), &chol)
	mahalanobisSq := mahalanobis * mahalanobis

	return (renyi.Alpha/2)*mahalanobisSq + 1/(2*(1-renyi.Alpha))*(logDetA-(1-renyi.Alpha)*logDetL-renyi.Alpha*logDetR)
}
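
// exampleRenyiNormal is an illustrative sketch, not part of the original
// file: it exercises the two documented identities, that the divergence at
// α = 1 equals the Kullback-Leibler divergence and that the divergence at
// α = 0.5 equals twice the Bhattacharyya distance.
func exampleRenyiNormal() (klGap, bhattGap float64) {
	l, _ := NewNormal([]float64{0, 0}, mat.NewSymDense(2, []float64{1, 0, 0, 1}), nil)
	r, _ := NewNormal([]float64{1, -1}, mat.NewSymDense(2, []float64{2, 0.3, 0.3, 1}), nil)
	klGap = Renyi{Alpha: 1}.DistNormal(l, r) - KullbackLeibler{}.DistNormal(l, r)
	bhattGap = Renyi{Alpha: 0.5}.DistNormal(l, r) - 2*Bhattacharyya{}.DistNormal(l, r)
	return klGap, bhattGap // both ≈ 0 up to floating-point error
}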

// Wasserstein is a type for computing the Wasserstein distance between two
// probability distributions.
//
// The Wasserstein distance is defined as
//
//	W(l,r) := inf 𝔼(||X-Y||_2^2)^(1/2)
//
// where the infimum is taken over all joint distributions of (X,Y) whose
// marginals are l and r.
// For more information, see
//
//	https://en.wikipedia.org/wiki/Wasserstein_metric
type Wasserstein struct{}

// DistNormal returns the Wasserstein distance between normal distributions l and r.
// The dimensions of the input distributions must match or DistNormal will panic.
//
// The Wasserstein distance for Normal distributions satisfies
//
//	d^2 = ||m_l - m_r||_2^2 + Tr(Σ_l + Σ_r - 2(Σ_l^(1/2)*Σ_r*Σ_l^(1/2))^(1/2))
//
// and DistNormal returns d.
// For more information, see
//
//	http://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/
func (Wasserstein) DistNormal(l, r *Normal) float64 {
	dim := l.Dim()
	if dim != r.Dim() {
		panic(badSizeMismatch)
	}

	d := floats.Distance(l.mu, r.mu, 2)
	d = d * d

	// Compute Σ_l^(1/2).
	var ssl mat.SymDense
	err := ssl.PowPSD(&l.sigma, 0.5)
	if err != nil {
		panic(err)
	}
	// Compute Σ_l^(1/2)*Σ_r*Σ_l^(1/2).
	var inner mat.Dense
	inner.Mul(&ssl, &r.sigma)
	inner.Mul(&inner, &ssl)

	// Reinterpret the product as a SymDense and take its square root,
	// reusing ssl to hold (Σ_l^(1/2)*Σ_r*Σ_l^(1/2))^(1/2).
	innerSym := mat.NewSymDense(dim, inner.RawMatrix().Data)
	err = ssl.PowPSD(innerSym, 0.5)
	if err != nil {
		panic(err)
	}

	tr := mat.Trace(&r.sigma)
	tl := mat.Trace(&l.sigma)
	tm := mat.Trace(&ssl)

	// d + tl + tr - 2*tm is the squared distance d^2 from the formula
	// above; return the distance itself.
	return math.Sqrt(d + tl + tr - 2*tm)
}
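
// exampleWassersteinNormal is an illustrative sketch, not part of the
// original file: for 1-D normals the distance reduces to
// sqrt((μ_l-μ_r)^2 + (σ_l-σ_r)^2), computed here alongside DistNormal; the
// two returned values agree.
func exampleWassersteinNormal() (dist, scalar float64) {
	l, _ := NewNormal([]float64{0}, mat.NewSymDense(1, []float64{1}), nil)
	r, _ := NewNormal([]float64{3}, mat.NewSymDense(1, []float64{4}), nil)
	// σ_l = 1 and σ_r = 2, so d^2 = 3^2 + (1-2)^2 = 10.
	scalar = math.Sqrt(10)
	return Wasserstein{}.DistNormal(l, r), scalar
}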