github.com/gopherd/gonum@v0.0.4/stat/distuv/categorical_test.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package distuv
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"math/rand"
    12  
    13  	"github.com/gopherd/gonum/floats"
    14  	"github.com/gopherd/gonum/floats/scalar"
    15  )
    16  
    17  const (
    18  	Tiny   = 2
    19  	Small  = 5
    20  	Medium = 10
    21  	Large  = 100
    22  	Huge   = 1000
    23  )
    24  
    25  func TestCategoricalProb(t *testing.T) {
    26  	t.Parallel()
    27  	for _, test := range [][]float64{
    28  		{1, 2, 3, 0},
    29  	} {
    30  		dist := NewCategorical(test, nil)
    31  		norm := make([]float64, len(test))
    32  		floats.Scale(1/floats.Sum(norm), norm)
    33  		for i, v := range norm {
    34  			p := dist.Prob(float64(i))
    35  			if math.Abs(p-v) > 1e-14 {
    36  				t.Errorf("Probability mismatch element %d", i)
    37  			}
    38  			logP := dist.LogProb(float64(i))
    39  			if math.Abs(logP-math.Log(v)) > 1e-14 {
    40  				t.Errorf("Log-probability mismatch element %d", i)
    41  			}
    42  			p = dist.Prob(float64(i) + 0.5)
    43  			if p != 0 {
    44  				t.Errorf("Non-zero probability for non-integer x")
    45  			}
    46  			logP = dist.LogProb(float64(i) + 0.5)
    47  			if !math.IsInf(logP, -1) {
    48  				t.Errorf("Log-probability for non-integer x is not -Inf")
    49  			}
    50  		}
    51  		p := dist.Prob(-1)
    52  		if p != 0 {
    53  			t.Errorf("Non-zero probability for -1")
    54  		}
    55  		logP := dist.LogProb(-1)
    56  		if !math.IsInf(logP, -1) {
    57  			t.Errorf("Log-probability for -1 is not -Inf")
    58  		}
    59  		p = dist.Prob(float64(len(test)))
    60  		if p != 0 {
    61  			t.Errorf("Non-zero probability for len(test)")
    62  		}
    63  		logP = dist.LogProb(float64(len(test)))
    64  		if !math.IsInf(logP, -1) {
    65  			t.Errorf("Log-probability for len(test) is not -Inf")
    66  		}
    67  	}
    68  }
    69  
    70  func TestCategoricalRand(t *testing.T) {
    71  	t.Parallel()
    72  	for _, test := range [][]float64{
    73  		{1, 2, 3, 0},
    74  	} {
    75  		dist := NewCategorical(test, nil)
    76  		nSamples := 2000000
    77  		counts := sampleCategorical(t, dist, nSamples)
    78  
    79  		probs := make([]float64, len(test))
    80  		for i := range probs {
    81  			probs[i] = dist.Prob(float64(i))
    82  		}
    83  		same := samedDistCategorical(dist, counts, probs, 1e-2)
    84  		if !same {
    85  			t.Errorf("Probability mismatch. Want %v, got %v", probs, counts)
    86  		}
    87  
    88  		dist.Reweight(len(test)-1, 10)
    89  		counts = sampleCategorical(t, dist, nSamples)
    90  		probs = make([]float64, len(test))
    91  		for i := range probs {
    92  			probs[i] = dist.Prob(float64(i))
    93  		}
    94  		same = samedDistCategorical(dist, counts, probs, 1e-2)
    95  		if !same {
    96  			t.Errorf("Probability mismatch after Reweight. Want %v, got %v", probs, counts)
    97  		}
    98  
    99  		w := make([]float64, len(test))
   100  		for i := range w {
   101  			w[i] = rand.Float64()
   102  		}
   103  
   104  		dist.ReweightAll(w)
   105  		counts = sampleCategorical(t, dist, nSamples)
   106  		probs = make([]float64, len(test))
   107  		for i := range probs {
   108  			probs[i] = dist.Prob(float64(i))
   109  		}
   110  		same = samedDistCategorical(dist, counts, probs, 1e-2)
   111  		if !same {
   112  			t.Errorf("Probability mismatch after ReweightAll. Want %v, got %v", probs, counts)
   113  		}
   114  	}
   115  }
   116  
   117  func TestCategoricalReweight(t *testing.T) {
   118  	t.Parallel()
   119  	dist := NewCategorical([]float64{1, 1}, nil)
   120  	if !panics(func() { dist.Reweight(0, -1) }) {
   121  		t.Errorf("Reweight did not panic for negative weight")
   122  	}
   123  	dist.Reweight(0, 0)
   124  	if !panics(func() { dist.Reweight(1, 0) }) {
   125  		t.Errorf("Reweight did not panic when trying to set the last positive weight to zero")
   126  	}
   127  }
   128  
   129  func TestCategoricalReweightAll(t *testing.T) {
   130  	t.Parallel()
   131  	w := []float64{0, 1, 2, 1}
   132  	dist := NewCategorical(w, nil)
   133  	if !panics(func() { dist.ReweightAll([]float64{1, 1}) }) {
   134  		t.Errorf("ReweightAll did not panic for different number of weights")
   135  	}
   136  	w[0] = -1
   137  	if !panics(func() { dist.ReweightAll(w) }) {
   138  		t.Errorf("ReweightAll did not panic for a negative weight")
   139  	}
   140  	w = []float64{0, 0, 0, 0}
   141  	if !panics(func() { dist.ReweightAll(w) }) {
   142  		t.Errorf("ReweightAll did not panic for weights which are all zero")
   143  	}
   144  }
   145  
   146  func sampleCategorical(t *testing.T, dist Categorical, nSamples int) []float64 {
   147  	counts := make([]float64, dist.Len())
   148  	for i := 0; i < nSamples; i++ {
   149  		v := dist.Rand()
   150  		if float64(int(v)) != v {
   151  			t.Fatalf("Random number is not an integer")
   152  		}
   153  		counts[int(v)]++
   154  	}
   155  	sum := floats.Sum(counts)
   156  	floats.Scale(1/sum, counts)
   157  	return counts
   158  }
   159  
   160  func samedDistCategorical(dist Categorical, counts, probs []float64, tol float64) bool {
   161  	same := true
   162  	for i, prob := range probs {
   163  		if prob == 0 && counts[i] != 0 {
   164  			same = false
   165  			break
   166  		}
   167  		if !scalar.EqualWithinAbsOrRel(prob, counts[i], tol, tol) {
   168  			same = false
   169  			break
   170  		}
   171  	}
   172  	return same
   173  }
   174  
   175  func TestCategoricalCDF(t *testing.T) {
   176  	t.Parallel()
   177  	for _, test := range [][]float64{
   178  		{1, 2, 3, 0, 4},
   179  	} {
   180  		c := make([]float64, len(test))
   181  		copy(c, test)
   182  		floats.Scale(1/floats.Sum(c), c)
   183  		sum := make([]float64, len(test))
   184  		floats.CumSum(sum, c)
   185  
   186  		dist := NewCategorical(test, nil)
   187  		cdf := dist.CDF(-0.5)
   188  		if cdf != 0 {
   189  			t.Errorf("CDF of negative number not zero")
   190  		}
   191  		for i := range c {
   192  			cdf := dist.CDF(float64(i))
   193  			if math.Abs(cdf-sum[i]) > 1e-14 {
   194  				t.Errorf("CDF mismatch %v. Want %v, got %v.", float64(i), sum[i], cdf)
   195  			}
   196  			cdfp := dist.CDF(float64(i) + 0.5)
   197  			if cdfp != cdf {
   198  				t.Errorf("CDF mismatch for non-integer input")
   199  			}
   200  		}
   201  	}
   202  }
   203  
   204  func TestCategoricalEntropy(t *testing.T) {
   205  	t.Parallel()
   206  	for _, test := range []struct {
   207  		weights []float64
   208  		entropy float64
   209  	}{
   210  		{
   211  			weights: []float64{1, 1},
   212  			entropy: math.Ln2,
   213  		},
   214  		{
   215  			weights: []float64{1, 1, 1, 1},
   216  			entropy: math.Log(4),
   217  		},
   218  		{
   219  			weights: []float64{0, 0, 1, 1, 0, 0},
   220  			entropy: math.Ln2,
   221  		},
   222  	} {
   223  		dist := NewCategorical(test.weights, nil)
   224  		entropy := dist.Entropy()
   225  		if math.IsNaN(entropy) || math.Abs(entropy-test.entropy) > 1e-14 {
   226  			t.Errorf("Entropy mismatch. Want %v, got %v.", test.entropy, entropy)
   227  		}
   228  	}
   229  }
   230  
   231  func TestCategoricalMean(t *testing.T) {
   232  	t.Parallel()
   233  	for _, test := range []struct {
   234  		weights []float64
   235  		mean    float64
   236  	}{
   237  		{
   238  			weights: []float64{10, 0, 0, 0},
   239  			mean:    0,
   240  		},
   241  		{
   242  			weights: []float64{0, 10, 0, 0},
   243  			mean:    1,
   244  		},
   245  		{
   246  			weights: []float64{1, 2, 3, 4},
   247  			mean:    2,
   248  		},
   249  	} {
   250  		dist := NewCategorical(test.weights, nil)
   251  		mean := dist.Mean()
   252  		if math.IsNaN(mean) || math.Abs(mean-test.mean) > 1e-14 {
   253  			t.Errorf("Entropy mismatch. Want %v, got %v.", test.mean, mean)
   254  		}
   255  	}
   256  }
   257  
   258  func BenchmarkCategoricalRandTiny(b *testing.B)   { benchmarkCategoricalRand(b, Tiny) }
   259  func BenchmarkCategoricalRandSmall(b *testing.B)  { benchmarkCategoricalRand(b, Small) }
   260  func BenchmarkCategoricalRandMedium(b *testing.B) { benchmarkCategoricalRand(b, Medium) }
   261  func BenchmarkCategoricalRandLarge(b *testing.B)  { benchmarkCategoricalRand(b, Large) }
   262  func BenchmarkCategoricalRandHuge(b *testing.B)   { benchmarkCategoricalRand(b, Huge) }
   263  
   264  func benchmarkCategoricalRand(b *testing.B, size int) {
   265  	src := rand.NewSource(1)
   266  	rng := rand.New(src)
   267  	weights := make([]float64, size)
   268  	for i := 0; i < size; i++ {
   269  		weights[i] = rng.Float64() + 0.001
   270  	}
   271  	dist := NewCategorical(weights, src)
   272  	for i := 0; i < b.N; i++ {
   273  		dist.Rand()
   274  	}
   275  }