github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/distuv/statdist_test.go (about)

     1  // Copyright ©2018 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package distuv
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"github.com/jingcheng-WU/gonum/floats"
    14  	"github.com/jingcheng-WU/gonum/floats/scalar"
    15  )
    16  
    17  func TestBhattacharyyaBeta(t *testing.T) {
    18  	t.Parallel()
    19  	rnd := rand.New(rand.NewSource(1))
    20  	for cas, test := range []struct {
    21  		a, b    Beta
    22  		samples int
    23  		tol     float64
    24  	}{
    25  		{
    26  			a:       Beta{Alpha: 1, Beta: 2, Src: rnd},
    27  			b:       Beta{Alpha: 1, Beta: 4, Src: rnd},
    28  			samples: 100000,
    29  			tol:     1e-2,
    30  		},
    31  		{
    32  			a:       Beta{Alpha: 0.5, Beta: 0.4, Src: rnd},
    33  			b:       Beta{Alpha: 0.7, Beta: 0.2, Src: rnd},
    34  			samples: 100000,
    35  			tol:     1e-2,
    36  		},
    37  		{
    38  			a:       Beta{Alpha: 3, Beta: 5, Src: rnd},
    39  			b:       Beta{Alpha: 5, Beta: 3, Src: rnd},
    40  			samples: 100000,
    41  			tol:     1e-2,
    42  		},
    43  	} {
    44  		want := bhattacharyyaSample(test.samples, test.a, test.b)
    45  		got := Bhattacharyya{}.DistBeta(test.a, test.b)
    46  		if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) {
    47  			t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, want)
    48  		}
    49  
    50  		// Bhattacharyya should be symmetric
    51  		got2 := Bhattacharyya{}.DistBeta(test.b, test.a)
    52  		if math.Abs(got-got2) > 1e-14 {
    53  			t.Errorf("Bhattacharyya distance not symmetric")
    54  		}
    55  	}
    56  }
    57  
    58  func TestBhattacharyyaNormal(t *testing.T) {
    59  	t.Parallel()
    60  	rnd := rand.New(rand.NewSource(1))
    61  	for cas, test := range []struct {
    62  		a, b    Normal
    63  		samples int
    64  		tol     float64
    65  	}{
    66  		{
    67  			a:       Normal{Mu: 1, Sigma: 2, Src: rnd},
    68  			b:       Normal{Mu: 1, Sigma: 4, Src: rnd},
    69  			samples: 100000,
    70  			tol:     1e-2,
    71  		},
    72  		{
    73  			a:       Normal{Mu: 0, Sigma: 2, Src: rnd},
    74  			b:       Normal{Mu: 2, Sigma: 2, Src: rnd},
    75  			samples: 100000,
    76  			tol:     1e-2,
    77  		},
    78  		{
    79  			a:       Normal{Mu: 0, Sigma: 5, Src: rnd},
    80  			b:       Normal{Mu: 2, Sigma: 0.1, Src: rnd},
    81  			samples: 200000,
    82  			tol:     1e-2,
    83  		},
    84  	} {
    85  		want := bhattacharyyaSample(test.samples, test.a, test.b)
    86  		got := Bhattacharyya{}.DistNormal(test.a, test.b)
    87  		if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) {
    88  			t.Errorf("Bhattacharyya mismatch, case %d: got %v, want %v", cas, got, want)
    89  		}
    90  
    91  		// Bhattacharyya should be symmetric
    92  		got2 := Bhattacharyya{}.DistNormal(test.b, test.a)
    93  		if math.Abs(got-got2) > 1e-14 {
    94  			t.Errorf("Bhattacharyya distance not symmetric")
    95  		}
    96  	}
    97  }
    98  
    99  // bhattacharyyaSample finds an estimate of the Bhattacharyya coefficient through
   100  // sampling.
   101  func bhattacharyyaSample(samples int, l RandLogProber, r LogProber) float64 {
   102  	lBhatt := make([]float64, samples)
   103  	for i := 0; i < samples; i++ {
   104  		// Do importance sampling over a: \int sqrt(a*b)/a * a dx
   105  		x := l.Rand()
   106  		pa := l.LogProb(x)
   107  		pb := r.LogProb(x)
   108  		lBhatt[i] = 0.5*pb - 0.5*pa
   109  	}
   110  	logBc := floats.LogSumExp(lBhatt) - math.Log(float64(samples))
   111  	return -logBc
   112  }
   113  
   114  func TestKullbackLeiblerBeta(t *testing.T) {
   115  	t.Parallel()
   116  	rnd := rand.New(rand.NewSource(1))
   117  	for cas, test := range []struct {
   118  		a, b    Beta
   119  		samples int
   120  		tol     float64
   121  	}{
   122  		{
   123  			a:       Beta{Alpha: 1, Beta: 2, Src: rnd},
   124  			b:       Beta{Alpha: 1, Beta: 4, Src: rnd},
   125  			samples: 100000,
   126  			tol:     1e-2,
   127  		},
   128  		{
   129  			a:       Beta{Alpha: 0.5, Beta: 0.4, Src: rnd},
   130  			b:       Beta{Alpha: 0.7, Beta: 0.2, Src: rnd},
   131  			samples: 100000,
   132  			tol:     1e-2,
   133  		},
   134  		{
   135  			a:       Beta{Alpha: 3, Beta: 5, Src: rnd},
   136  			b:       Beta{Alpha: 5, Beta: 3, Src: rnd},
   137  			samples: 100000,
   138  			tol:     1e-2,
   139  		},
   140  	} {
   141  		a, b := test.a, test.b
   142  		want := klSample(test.samples, a, b)
   143  		got := KullbackLeibler{}.DistBeta(a, b)
   144  		if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) {
   145  			t.Errorf("Kullback-Leibler mismatch, case %d: got %v, want %v", cas, got, want)
   146  		}
   147  	}
   148  	good := Beta{0.5, 0.5, nil}
   149  	bad := Beta{0, 1, nil}
   150  	if !panics(func() { KullbackLeibler{}.DistBeta(bad, good) }) {
   151  		t.Errorf("Expected Kullback-Leibler to panic when called with invalid left Beta distribution")
   152  	}
   153  	if !panics(func() { KullbackLeibler{}.DistBeta(good, bad) }) {
   154  		t.Errorf("Expected Kullback-Leibler to panic when called with invalid right Beta distribution")
   155  	}
   156  	bad = Beta{1, 0, nil}
   157  	if !panics(func() { KullbackLeibler{}.DistBeta(bad, good) }) {
   158  		t.Errorf("Expected Kullback-Leibler to panic when called with invalid left Beta distribution")
   159  	}
   160  	if !panics(func() { KullbackLeibler{}.DistBeta(good, bad) }) {
   161  		t.Errorf("Expected Kullback-Leibler to panic when called with invalid right Beta distribution")
   162  	}
   163  }
   164  
   165  func TestKullbackLeiblerNormal(t *testing.T) {
   166  	t.Parallel()
   167  	rnd := rand.New(rand.NewSource(1))
   168  	for cas, test := range []struct {
   169  		a, b    Normal
   170  		samples int
   171  		tol     float64
   172  	}{
   173  		{
   174  			a:       Normal{Mu: 1, Sigma: 2, Src: rnd},
   175  			b:       Normal{Mu: 1, Sigma: 4, Src: rnd},
   176  			samples: 100000,
   177  			tol:     1e-2,
   178  		},
   179  		{
   180  			a:       Normal{Mu: 0, Sigma: 2, Src: rnd},
   181  			b:       Normal{Mu: 2, Sigma: 2, Src: rnd},
   182  			samples: 100000,
   183  			tol:     1e-2,
   184  		},
   185  		{
   186  			a:       Normal{Mu: 0, Sigma: 5, Src: rnd},
   187  			b:       Normal{Mu: 2, Sigma: 0.1, Src: rnd},
   188  			samples: 100000,
   189  			tol:     1e-2,
   190  		},
   191  	} {
   192  		a, b := test.a, test.b
   193  		want := klSample(test.samples, a, b)
   194  		got := KullbackLeibler{}.DistNormal(a, b)
   195  		if !scalar.EqualWithinAbsOrRel(want, got, test.tol, test.tol) {
   196  			t.Errorf("Kullback-Leibler mismatch, case %d: got %v, want %v", cas, got, want)
   197  		}
   198  	}
   199  }
   200  
   201  // klSample finds an estimate of the Kullback-Leibler divergence through sampling.
   202  func klSample(samples int, l RandLogProber, r LogProber) float64 {
   203  	var klmc float64
   204  	for i := 0; i < samples; i++ {
   205  		x := l.Rand()
   206  		pa := l.LogProb(x)
   207  		pb := r.LogProb(x)
   208  		klmc += pa - pb
   209  	}
   210  	return klmc / float64(samples)
   211  }
   212  
   213  func TestHellingerBeta(t *testing.T) {
   214  	t.Parallel()
   215  	rnd := rand.New(rand.NewSource(1))
   216  	const tol = 1e-15
   217  	for cas, test := range []struct {
   218  		a, b Beta
   219  	}{
   220  		{
   221  			a: Beta{Alpha: 1, Beta: 2, Src: rnd},
   222  			b: Beta{Alpha: 1, Beta: 4, Src: rnd},
   223  		},
   224  		{
   225  			a: Beta{Alpha: 0.5, Beta: 0.4, Src: rnd},
   226  			b: Beta{Alpha: 0.7, Beta: 0.2, Src: rnd},
   227  		},
   228  		{
   229  			a: Beta{Alpha: 3, Beta: 5, Src: rnd},
   230  			b: Beta{Alpha: 5, Beta: 3, Src: rnd},
   231  		},
   232  	} {
   233  		got := Hellinger{}.DistBeta(test.a, test.b)
   234  		want := math.Sqrt(1 - math.Exp(-Bhattacharyya{}.DistBeta(test.a, test.b)))
   235  		if !scalar.EqualWithinAbsOrRel(got, want, tol, tol) {
   236  			t.Errorf("Hellinger mismatch, case %d: got %v, want %v", cas, got, want)
   237  		}
   238  	}
   239  }
   240  
   241  func TestHellingerNormal(t *testing.T) {
   242  	t.Parallel()
   243  	rnd := rand.New(rand.NewSource(1))
   244  	const tol = 1e-15
   245  	for cas, test := range []struct {
   246  		a, b Normal
   247  	}{
   248  		{
   249  			a: Normal{Mu: 1, Sigma: 2, Src: rnd},
   250  			b: Normal{Mu: 1, Sigma: 4, Src: rnd},
   251  		},
   252  		{
   253  			a: Normal{Mu: 0, Sigma: 2, Src: rnd},
   254  			b: Normal{Mu: 2, Sigma: 2, Src: rnd},
   255  		},
   256  		{
   257  			a: Normal{Mu: 0, Sigma: 5, Src: rnd},
   258  			b: Normal{Mu: 2, Sigma: 0.1, Src: rnd},
   259  		},
   260  	} {
   261  		got := Hellinger{}.DistNormal(test.a, test.b)
   262  		want := math.Sqrt(1 - math.Exp(-Bhattacharyya{}.DistNormal(test.a, test.b)))
   263  		if !scalar.EqualWithinAbsOrRel(got, want, tol, tol) {
   264  			t.Errorf("Hellinger mismatch, case %d: got %v, want %v", cas, got, want)
   265  		}
   266  	}
   267  }