gonum.org/v1/gonum@v0.14.0/stat/distmv/statdist_test.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package distmv
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/floats"
    14  	"gonum.org/v1/gonum/floats/scalar"
    15  	"gonum.org/v1/gonum/mat"
    16  	"gonum.org/v1/gonum/spatial/r1"
    17  )
    18  
    19  func TestBhattacharyyaNormal(t *testing.T) {
    20  	for cas, test := range []struct {
    21  		am, bm  []float64
    22  		ac, bc  *mat.SymDense
    23  		samples int
    24  		tol     float64
    25  	}{
    26  		{
    27  			am:      []float64{2, 3},
    28  			ac:      mat.NewSymDense(2, []float64{3, -1, -1, 2}),
    29  			bm:      []float64{-1, 1},
    30  			bc:      mat.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
    31  			samples: 100000,
    32  			tol:     3e-1,
    33  		},
    34  	} {
    35  		rnd := rand.New(rand.NewSource(1))
    36  		a, ok := NewNormal(test.am, test.ac, rnd)
    37  		if !ok {
    38  			panic("bad test")
    39  		}
    40  		b, ok := NewNormal(test.bm, test.bc, rnd)
    41  		if !ok {
    42  			panic("bad test")
    43  		}
    44  		want := bhattacharyyaSample(a.Dim(), test.samples, a, b)
    45  		got := Bhattacharyya{}.DistNormal(a, b)
    46  		if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) {
    47  			t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, want)
    48  		}
    49  
    50  		// Bhattacharyya should by symmetric
    51  		got2 := Bhattacharyya{}.DistNormal(b, a)
    52  		if math.Abs(got-got2) > 1e-14 {
    53  			t.Errorf("Bhattacharyya distance not symmetric")
    54  		}
    55  	}
    56  }
    57  
    58  func TestBhattacharyyaUniform(t *testing.T) {
    59  	rnd := rand.New(rand.NewSource(1))
    60  	for cas, test := range []struct {
    61  		a, b    *Uniform
    62  		samples int
    63  		tol     float64
    64  	}{
    65  		{
    66  			a:       NewUniform([]r1.Interval{{Min: -3, Max: 2}, {Min: -5, Max: 8}}, rnd),
    67  			b:       NewUniform([]r1.Interval{{Min: -4, Max: 1}, {Min: -7, Max: 10}}, rnd),
    68  			samples: 100000,
    69  			tol:     1e-2,
    70  		},
    71  		{
    72  			a:       NewUniform([]r1.Interval{{Min: -3, Max: 2}, {Min: -5, Max: 8}}, rnd),
    73  			b:       NewUniform([]r1.Interval{{Min: -5, Max: -4}, {Min: -7, Max: 10}}, rnd),
    74  			samples: 100000,
    75  			tol:     1e-2,
    76  		},
    77  	} {
    78  		a, b := test.a, test.b
    79  		want := bhattacharyyaSample(a.Dim(), test.samples, a, b)
    80  		got := Bhattacharyya{}.DistUniform(a, b)
    81  		if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) {
    82  			t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, want)
    83  		}
    84  		// Bhattacharyya should by symmetric
    85  		got2 := Bhattacharyya{}.DistUniform(b, a)
    86  		if math.Abs(got-got2) > 1e-14 {
    87  			t.Errorf("Bhattacharyya distance not symmetric")
    88  		}
    89  	}
    90  }
    91  
    92  // bhattacharyyaSample finds an estimate of the Bhattacharyya coefficient through
    93  // sampling.
    94  func bhattacharyyaSample(dim, samples int, l RandLogProber, r LogProber) float64 {
    95  	lBhatt := make([]float64, samples)
    96  	x := make([]float64, dim)
    97  	for i := 0; i < samples; i++ {
    98  		// Do importance sampling over a: \int sqrt(a*b)/a * a dx
    99  		l.Rand(x)
   100  		pa := l.LogProb(x)
   101  		pb := r.LogProb(x)
   102  		lBhatt[i] = 0.5*pb - 0.5*pa
   103  	}
   104  	logBc := floats.LogSumExp(lBhatt) - math.Log(float64(samples))
   105  	return -logBc
   106  }
   107  
   108  func TestCrossEntropyNormal(t *testing.T) {
   109  	for cas, test := range []struct {
   110  		am, bm  []float64
   111  		ac, bc  *mat.SymDense
   112  		samples int
   113  		tol     float64
   114  	}{
   115  		{
   116  			am:      []float64{2, 3},
   117  			ac:      mat.NewSymDense(2, []float64{3, -1, -1, 2}),
   118  			bm:      []float64{-1, 1},
   119  			bc:      mat.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
   120  			samples: 100000,
   121  			tol:     1e-2,
   122  		},
   123  	} {
   124  		rnd := rand.New(rand.NewSource(1))
   125  		a, ok := NewNormal(test.am, test.ac, rnd)
   126  		if !ok {
   127  			panic("bad test")
   128  		}
   129  		b, ok := NewNormal(test.bm, test.bc, rnd)
   130  		if !ok {
   131  			panic("bad test")
   132  		}
   133  		var ce float64
   134  		x := make([]float64, a.Dim())
   135  		for i := 0; i < test.samples; i++ {
   136  			a.Rand(x)
   137  			ce -= b.LogProb(x)
   138  		}
   139  		ce /= float64(test.samples)
   140  		got := CrossEntropy{}.DistNormal(a, b)
   141  		if !scalar.EqualWithinAbsOrRel(ce, got, test.tol, test.tol) {
   142  			t.Errorf("CrossEntropy mismatch, case %d: got %v, want %v", cas, got, ce)
   143  		}
   144  	}
   145  }
   146  
   147  func TestHellingerNormal(t *testing.T) {
   148  	for cas, test := range []struct {
   149  		am, bm  []float64
   150  		ac, bc  *mat.SymDense
   151  		samples int
   152  		tol     float64
   153  	}{
   154  		{
   155  			am:      []float64{2, 3},
   156  			ac:      mat.NewSymDense(2, []float64{3, -1, -1, 2}),
   157  			bm:      []float64{-1, 1},
   158  			bc:      mat.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
   159  			samples: 100000,
   160  			tol:     5e-1,
   161  		},
   162  	} {
   163  		rnd := rand.New(rand.NewSource(1))
   164  		a, ok := NewNormal(test.am, test.ac, rnd)
   165  		if !ok {
   166  			panic("bad test")
   167  		}
   168  		b, ok := NewNormal(test.bm, test.bc, rnd)
   169  		if !ok {
   170  			panic("bad test")
   171  		}
   172  		lAitchEDoubleHockeySticks := make([]float64, test.samples)
   173  		x := make([]float64, a.Dim())
   174  		for i := 0; i < test.samples; i++ {
   175  			// Do importance sampling over a: \int (\sqrt(a)-\sqrt(b))^2/a * a dx
   176  			a.Rand(x)
   177  			pa := a.LogProb(x)
   178  			pb := b.LogProb(x)
   179  			d := math.Exp(0.5*pa) - math.Exp(0.5*pb)
   180  			d = d * d
   181  			lAitchEDoubleHockeySticks[i] = math.Log(d) - pa
   182  		}
   183  		want := math.Sqrt(0.5 * math.Exp(floats.LogSumExp(lAitchEDoubleHockeySticks)-math.Log(float64(test.samples))))
   184  		got := Hellinger{}.DistNormal(a, b)
   185  		if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) {
   186  			t.Errorf("Hellinger mismatch, case %d: got %v, want %v", cas, got, want)
   187  		}
   188  	}
   189  }
   190  
   191  func TestKullbackLeiblerDirichlet(t *testing.T) {
   192  	rnd := rand.New(rand.NewSource(1))
   193  	for cas, test := range []struct {
   194  		a, b    *Dirichlet
   195  		samples int
   196  		tol     float64
   197  	}{
   198  		{
   199  			a:       NewDirichlet([]float64{2, 3, 4}, rnd),
   200  			b:       NewDirichlet([]float64{4, 2, 1.1}, rnd),
   201  			samples: 100000,
   202  			tol:     1e-2,
   203  		},
   204  		{
   205  			a:       NewDirichlet([]float64{2, 3, 4, 0.1, 8}, rnd),
   206  			b:       NewDirichlet([]float64{2, 2, 6, 0.5, 9}, rnd),
   207  			samples: 100000,
   208  			tol:     1e-2,
   209  		},
   210  	} {
   211  		a, b := test.a, test.b
   212  		want := klSample(a.Dim(), test.samples, a, b)
   213  		got := KullbackLeibler{}.DistDirichlet(a, b)
   214  		if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) {
   215  			t.Errorf("Kullback-Leibler mismatch, case %d: got %v, want %v", cas, got, want)
   216  		}
   217  	}
   218  }
   219  
   220  func TestKullbackLeiblerNormal(t *testing.T) {
   221  	for cas, test := range []struct {
   222  		am, bm  []float64
   223  		ac, bc  *mat.SymDense
   224  		samples int
   225  		tol     float64
   226  	}{
   227  		{
   228  			am:      []float64{2, 3},
   229  			ac:      mat.NewSymDense(2, []float64{3, -1, -1, 2}),
   230  			bm:      []float64{-1, 1},
   231  			bc:      mat.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
   232  			samples: 10000,
   233  			tol:     1e-2,
   234  		},
   235  	} {
   236  		rnd := rand.New(rand.NewSource(1))
   237  		a, ok := NewNormal(test.am, test.ac, rnd)
   238  		if !ok {
   239  			panic("bad test")
   240  		}
   241  		b, ok := NewNormal(test.bm, test.bc, rnd)
   242  		if !ok {
   243  			panic("bad test")
   244  		}
   245  		want := klSample(a.Dim(), test.samples, a, b)
   246  		got := KullbackLeibler{}.DistNormal(a, b)
   247  		if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) {
   248  			t.Errorf("Case %d, KL mismatch: got %v, want %v", cas, got, want)
   249  		}
   250  	}
   251  }
   252  
   253  func TestKullbackLeiblerUniform(t *testing.T) {
   254  	rnd := rand.New(rand.NewSource(1))
   255  	for cas, test := range []struct {
   256  		a, b    *Uniform
   257  		samples int
   258  		tol     float64
   259  	}{
   260  		{
   261  			a:       NewUniform([]r1.Interval{{Min: -5, Max: 2}, {Min: -7, Max: 12}}, rnd),
   262  			b:       NewUniform([]r1.Interval{{Min: -4, Max: 1}, {Min: -7, Max: 10}}, rnd),
   263  			samples: 100000,
   264  			tol:     1e-2,
   265  		},
   266  		{
   267  			a:       NewUniform([]r1.Interval{{Min: -5, Max: 2}, {Min: -7, Max: 12}}, rnd),
   268  			b:       NewUniform([]r1.Interval{{Min: -9, Max: -6}, {Min: -7, Max: 10}}, rnd),
   269  			samples: 100000,
   270  			tol:     1e-2,
   271  		},
   272  	} {
   273  		a, b := test.a, test.b
   274  		want := klSample(a.Dim(), test.samples, a, b)
   275  		got := KullbackLeibler{}.DistUniform(a, b)
   276  		if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) {
   277  			t.Errorf("Kullback-Leibler mismatch, case %d: got %v, want %v", cas, got, want)
   278  		}
   279  	}
   280  }
   281  
   282  // klSample finds an estimate of the Kullback-Leibler divergence through sampling.
   283  func klSample(dim, samples int, l RandLogProber, r LogProber) float64 {
   284  	var klmc float64
   285  	x := make([]float64, dim)
   286  	for i := 0; i < samples; i++ {
   287  		l.Rand(x)
   288  		pa := l.LogProb(x)
   289  		pb := r.LogProb(x)
   290  		klmc += pa - pb
   291  	}
   292  	return klmc / float64(samples)
   293  }
   294  
   295  func TestRenyiNormal(t *testing.T) {
   296  	for cas, test := range []struct {
   297  		am, bm  []float64
   298  		ac, bc  *mat.SymDense
   299  		alpha   float64
   300  		samples int
   301  		tol     float64
   302  	}{
   303  		{
   304  			am:      []float64{2, 3},
   305  			ac:      mat.NewSymDense(2, []float64{3, -1, -1, 2}),
   306  			bm:      []float64{-1, 1},
   307  			bc:      mat.NewSymDense(2, []float64{1.5, 0.2, 0.2, 0.9}),
   308  			alpha:   0.3,
   309  			samples: 10000,
   310  			tol:     3e-1,
   311  		},
   312  	} {
   313  		rnd := rand.New(rand.NewSource(1))
   314  		a, ok := NewNormal(test.am, test.ac, rnd)
   315  		if !ok {
   316  			panic("bad test")
   317  		}
   318  		b, ok := NewNormal(test.bm, test.bc, rnd)
   319  		if !ok {
   320  			panic("bad test")
   321  		}
   322  		want := renyiSample(a.Dim(), test.samples, test.alpha, a, b)
   323  		got := Renyi{Alpha: test.alpha}.DistNormal(a, b)
   324  		if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) {
   325  			t.Errorf("Case %d: Renyi sampling mismatch: got %v, want %v", cas, got, want)
   326  		}
   327  
   328  		// Compare with Bhattacharyya.
   329  		want = 2 * Bhattacharyya{}.DistNormal(a, b)
   330  		got = Renyi{Alpha: 0.5}.DistNormal(a, b)
   331  		if !scalar.EqualWithinAbsOrRel(want, got, 1e-10, 1e-10) {
   332  			t.Errorf("Case %d: Renyi mismatch with Bhattacharyya: got %v, want %v", cas, got, want)
   333  		}
   334  
   335  		// Compare with KL in both directions.
   336  		want = KullbackLeibler{}.DistNormal(a, b)
   337  		got = Renyi{Alpha: 0.9999999}.DistNormal(a, b) // very close to 1 but not equal to 1.
   338  		if !scalar.EqualWithinAbsOrRel(want, got, 1e-6, 1e-6) {
   339  			t.Errorf("Case %d: Renyi mismatch with KL(a||b): got %v, want %v", cas, got, want)
   340  		}
   341  		want = KullbackLeibler{}.DistNormal(b, a)
   342  		got = Renyi{Alpha: 0.9999999}.DistNormal(b, a) // very close to 1 but not equal to 1.
   343  		if !scalar.EqualWithinAbsOrRel(want, got, 1e-6, 1e-6) {
   344  			t.Errorf("Case %d: Renyi mismatch with KL(b||a): got %v, want %v", cas, got, want)
   345  		}
   346  	}
   347  }
   348  
   349  // renyiSample finds an estimate of the Rényi divergence through sampling.
   350  // Note that this sampling procedure only works if l has broader support than r.
   351  func renyiSample(dim, samples int, alpha float64, l RandLogProber, r LogProber) float64 {
   352  	rmcs := make([]float64, samples)
   353  	x := make([]float64, dim)
   354  	for i := 0; i < samples; i++ {
   355  		l.Rand(x)
   356  		pa := l.LogProb(x)
   357  		pb := r.LogProb(x)
   358  		rmcs[i] = (alpha-1)*pa + (1-alpha)*pb
   359  	}
   360  	return 1 / (alpha - 1) * (floats.LogSumExp(rmcs) - math.Log(float64(samples)))
   361  }