gonum.org/v1/gonum@v0.14.0/stat/distuv/general_test.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package distuv
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"gonum.org/v1/gonum/diff/fd"
    13  	"gonum.org/v1/gonum/floats"
    14  )
    15  
    16  type univariateProbPoint struct {
    17  	loc     float64
    18  	logProb float64
    19  	cumProb float64
    20  	prob    float64
    21  }
    22  
    23  type UniProbDist interface {
    24  	Prob(float64) float64
    25  	CDF(float64) float64
    26  	LogProb(float64) float64
    27  	Quantile(float64) float64
    28  	Survival(float64) float64
    29  }
    30  
    31  func absEq(a, b float64) bool {
    32  	return absEqTol(a, b, 1e-14)
    33  }
    34  
    35  func absEqTol(a, b, tol float64) bool {
    36  	if math.IsNaN(a) || math.IsNaN(b) {
    37  		// NaN is not equal to anything.
    38  		return false
    39  	}
    40  	// This is expressed as the inverse to catch the
    41  	// case a = Inf and b = Inf of the same sign.
    42  	return !(math.Abs(a-b) > tol)
    43  }
    44  
    45  // TODO: Implement a better test for Quantile
    46  func testDistributionProbs(t *testing.T, dist UniProbDist, name string, pts []univariateProbPoint) {
    47  	for _, pt := range pts {
    48  		logProb := dist.LogProb(pt.loc)
    49  		if !absEq(logProb, pt.logProb) {
    50  			t.Errorf("Log probability doesnt match for "+name+" at %v. Expected %v. Found %v", pt.loc, pt.logProb, logProb)
    51  		}
    52  		prob := dist.Prob(pt.loc)
    53  		if !absEq(prob, pt.prob) {
    54  			t.Errorf("Probability doesn't match for "+name+" at %v. Expected %v. Found %v", pt.loc, pt.prob, prob)
    55  		}
    56  		cumProb := dist.CDF(pt.loc)
    57  		if !absEq(cumProb, pt.cumProb) {
    58  			t.Errorf("Cumulative Probability doesn't match for "+name+". Expected %v. Found %v", pt.cumProb, cumProb)
    59  		}
    60  		if !absEq(dist.Survival(pt.loc), 1-pt.cumProb) {
    61  			t.Errorf("Survival doesn't match for %v. Expected %v, Found %v", name, 1-pt.cumProb, dist.Survival(pt.loc))
    62  		}
    63  		if pt.prob != 0 {
    64  			if math.Abs(dist.Quantile(pt.cumProb)-pt.loc) > 1e-4 {
    65  				fmt.Println("true =", pt.loc)
    66  				fmt.Println("calculated=", dist.Quantile(pt.cumProb))
    67  				t.Errorf("Quantile doesn't match for "+name+", loc =  %v", pt.loc)
    68  			}
    69  		}
    70  	}
    71  }
    72  
    73  type ConjugateUpdater interface {
    74  	NumParameters() int
    75  	parameters([]Parameter) []Parameter
    76  
    77  	NumSuffStat() int
    78  	SuffStat([]float64, []float64, []float64) float64
    79  	ConjugateUpdate([]float64, float64, []float64)
    80  
    81  	Rand() float64
    82  }
    83  
    84  func testConjugateUpdate(t *testing.T, newFittable func() ConjugateUpdater) {
    85  	for i, test := range []struct {
    86  		samps   []float64
    87  		weights []float64
    88  	}{
    89  		{
    90  			samps:   randn(newFittable(), 10),
    91  			weights: nil,
    92  		},
    93  		{
    94  			samps:   randn(newFittable(), 10),
    95  			weights: ones(10),
    96  		},
    97  		{
    98  			samps:   randn(newFittable(), 10),
    99  			weights: randn(&Exponential{Rate: 1}, 10),
   100  		},
   101  	} {
   102  		// ensure that conjugate produces the same result both incrementally and all at once
   103  		incDist := newFittable()
   104  		stats := make([]float64, incDist.NumSuffStat())
   105  		prior := make([]float64, incDist.NumParameters())
   106  		for j := range test.samps {
   107  			var incWeights, allWeights []float64
   108  			if test.weights != nil {
   109  				incWeights = test.weights[j : j+1]
   110  				allWeights = test.weights[0 : j+1]
   111  			}
   112  			nsInc := incDist.SuffStat(stats, test.samps[j:j+1], incWeights)
   113  			incDist.ConjugateUpdate(stats, nsInc, prior)
   114  
   115  			allDist := newFittable()
   116  			nsAll := allDist.SuffStat(stats, test.samps[0:j+1], allWeights)
   117  			allDist.ConjugateUpdate(stats, nsAll, make([]float64, allDist.NumParameters()))
   118  			if !parametersEqual(incDist.parameters(nil), allDist.parameters(nil), 1e-12) {
   119  				t.Errorf("prior doesn't match after incremental update for (%d, %d). Incremental is %v, all at once is %v", i, j, incDist, allDist)
   120  			}
   121  
   122  			if test.weights == nil {
   123  				onesDist := newFittable()
   124  				nsOnes := onesDist.SuffStat(stats, test.samps[0:j+1], ones(j+1))
   125  				onesDist.ConjugateUpdate(stats, nsOnes, make([]float64, onesDist.NumParameters()))
   126  				if !parametersEqual(onesDist.parameters(nil), incDist.parameters(nil), 1e-14) {
   127  					t.Errorf("nil and uniform weighted prior doesn't match for incremental update for (%d, %d). Uniform weighted is %v, nil is %v", i, j, onesDist, incDist)
   128  				}
   129  				if !parametersEqual(onesDist.parameters(nil), allDist.parameters(nil), 1e-14) {
   130  					t.Errorf("nil and uniform weighted prior doesn't match for all at once update for (%d, %d). Uniform weighted is %v, nil is %v", i, j, onesDist, incDist)
   131  				}
   132  			}
   133  		}
   134  	}
   135  	testSuffStatPanics(t, newFittable)
   136  	testConjugateUpdatePanics(t, newFittable)
   137  }
   138  
   139  func testSuffStatPanics(t *testing.T, newFittable func() ConjugateUpdater) {
   140  	dist := newFittable()
   141  	sample := randn(dist, 10)
   142  	if !panics(func() { dist.SuffStat(make([]float64, dist.NumSuffStat()), sample, make([]float64, len(sample)+1)) }) {
   143  		t.Errorf("Expected panic for mismatch between samples and weights lengths")
   144  	}
   145  	if !panics(func() { dist.SuffStat(make([]float64, dist.NumSuffStat()+1), sample, nil) }) {
   146  		t.Errorf("Expected panic for wrong sufficient statistic length")
   147  	}
   148  }
   149  
   150  func testConjugateUpdatePanics(t *testing.T, newFittable func() ConjugateUpdater) {
   151  	dist := newFittable()
   152  	if !panics(func() {
   153  		dist.ConjugateUpdate(make([]float64, dist.NumSuffStat()+1), 100, make([]float64, dist.NumParameters()))
   154  	}) {
   155  		t.Errorf("Expected panic for wrong sufficient statistic length")
   156  	}
   157  	if !panics(func() {
   158  		dist.ConjugateUpdate(make([]float64, dist.NumSuffStat()), 100, make([]float64, dist.NumParameters()+1))
   159  	}) {
   160  		t.Errorf("Expected panic for wrong prior strength length")
   161  	}
   162  }
   163  
   164  // randn generates a specified number of random samples
   165  func randn(dist Rander, n int) []float64 {
   166  	x := make([]float64, n)
   167  	for i := range x {
   168  		x[i] = dist.Rand()
   169  	}
   170  	return x
   171  }
   172  
   173  func ones(n int) []float64 {
   174  	x := make([]float64, n)
   175  	for i := range x {
   176  		x[i] = 1
   177  	}
   178  	return x
   179  }
   180  
   181  func parametersEqual(p1, p2 []Parameter, tol float64) bool {
   182  	for i, p := range p1 {
   183  		if p.Name != p2[i].Name {
   184  			return false
   185  		}
   186  		if math.Abs(p.Value-p2[i].Value) > tol {
   187  			return false
   188  		}
   189  	}
   190  	return true
   191  }
   192  
   193  type derivParamTester interface {
   194  	LogProb(x float64) float64
   195  	Score(deriv []float64, x float64) []float64
   196  	ScoreInput(x float64) float64
   197  	Quantile(p float64) float64
   198  	NumParameters() int
   199  	parameters([]Parameter) []Parameter
   200  	setParameters([]Parameter)
   201  }
   202  
   203  func testDerivParam(t *testing.T, d derivParamTester) {
   204  	// Tests that the derivative matches for a number of different quantiles
   205  	// along the distribution.
   206  	nTest := 10
   207  	quantiles := make([]float64, nTest)
   208  	floats.Span(quantiles, 0.1, 0.9)
   209  
   210  	scoreInPlace := make([]float64, d.NumParameters())
   211  	fdDerivParam := make([]float64, d.NumParameters())
   212  
   213  	if !panics(func() { d.Score(make([]float64, d.NumParameters()+1), 0) }) {
   214  		t.Errorf("Expected panic for wrong derivative slice length")
   215  	}
   216  	if !panics(func() { d.parameters(make([]Parameter, d.NumParameters()+1)) }) {
   217  		t.Errorf("Expected panic for wrong parameter slice length")
   218  	}
   219  
   220  	initParams := d.parameters(nil)
   221  	tooLongParams := make([]Parameter, len(initParams)+1)
   222  	copy(tooLongParams, initParams)
   223  	if !panics(func() { d.setParameters(tooLongParams) }) {
   224  		t.Errorf("Expected panic for wrong parameter slice length")
   225  	}
   226  	badNameParams := make([]Parameter, len(initParams))
   227  	copy(badNameParams, initParams)
   228  	const badName = "__badName__"
   229  	for i := 0; i < len(initParams); i++ {
   230  		badNameParams[i].Name = badName
   231  		if !panics(func() { d.setParameters(badNameParams) }) {
   232  			t.Errorf("Expected panic for wrong %d-th parameter name", i)
   233  		}
   234  		badNameParams[i].Name = initParams[i].Name
   235  	}
   236  
   237  	init := make([]float64, d.NumParameters())
   238  	for i, v := range initParams {
   239  		init[i] = v.Value
   240  	}
   241  	for _, v := range quantiles {
   242  		d.setParameters(initParams)
   243  		x := d.Quantile(v)
   244  		score := d.Score(scoreInPlace, x)
   245  		if &score[0] != &scoreInPlace[0] {
   246  			t.Errorf("Returned a different derivative slice than passed in. Got %v, want %v", score, scoreInPlace)
   247  		}
   248  		logProbParams := func(p []float64) float64 {
   249  			params := d.parameters(nil)
   250  			for i, v := range p {
   251  				params[i].Value = v
   252  			}
   253  			d.setParameters(params)
   254  			return d.LogProb(x)
   255  		}
   256  		fd.Gradient(fdDerivParam, logProbParams, init, nil)
   257  		if !floats.EqualApprox(scoreInPlace, fdDerivParam, 1e-6) {
   258  			t.Errorf("Score mismatch at x = %g. Want %v, got %v", x, fdDerivParam, scoreInPlace)
   259  		}
   260  		d.setParameters(initParams)
   261  		score2 := d.Score(nil, x)
   262  		if !floats.EqualApprox(score2, scoreInPlace, 1e-14) {
   263  			t.Errorf("Score mismatch when input nil Want %v, got %v", score2, scoreInPlace)
   264  		}
   265  		logProbInput := func(x2 float64) float64 {
   266  			return d.LogProb(x2)
   267  		}
   268  		scoreInput := d.ScoreInput(x)
   269  		fdDerivInput := fd.Derivative(logProbInput, x, nil)
   270  		if !absEqTol(scoreInput, fdDerivInput, 1e-6) {
   271  			t.Errorf("ScoreInput mismatch at x = %g. Want %v, got %v", x, fdDerivInput, scoreInput)
   272  		}
   273  	}
   274  }