github.com/gopherd/gonum@v0.0.4/diff/fd/derivative_test.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fd
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  )
    11  
    12  var xSquared = func(x float64) float64 { return x * x }
    13  
    14  type testPoint struct {
    15  	f    func(float64) float64
    16  	loc  float64
    17  	fofx float64
    18  	ans  float64
    19  }
    20  
    21  var testsFirst = []testPoint{
    22  	{
    23  		f:    xSquared,
    24  		loc:  0,
    25  		fofx: 0,
    26  		ans:  0,
    27  	},
    28  	{
    29  		f:    xSquared,
    30  		loc:  5,
    31  		fofx: 25,
    32  		ans:  10,
    33  	},
    34  	{
    35  		f:    xSquared,
    36  		loc:  2,
    37  		fofx: 4,
    38  		ans:  4,
    39  	},
    40  	{
    41  		f:    xSquared,
    42  		loc:  -5,
    43  		fofx: 25,
    44  		ans:  -10,
    45  	},
    46  }
    47  
    48  var testsSecond = []testPoint{
    49  	{
    50  		f:    xSquared,
    51  		loc:  0,
    52  		fofx: 0,
    53  		ans:  2,
    54  	},
    55  	{
    56  		f:    xSquared,
    57  		loc:  5,
    58  		fofx: 25,
    59  		ans:  2,
    60  	},
    61  	{
    62  		f:    xSquared,
    63  		loc:  2,
    64  		fofx: 4,
    65  		ans:  2,
    66  	},
    67  	{
    68  		f:    xSquared,
    69  		loc:  -5,
    70  		fofx: 25,
    71  		ans:  2,
    72  	},
    73  }
    74  
    75  func testDerivative(t *testing.T, formula Formula, tol float64, tests []testPoint) {
    76  	for i, test := range tests {
    77  
    78  		ans := Derivative(test.f, test.loc, &Settings{
    79  			Formula: formula,
    80  		})
    81  		if math.Abs(test.ans-ans) > tol {
    82  			t.Errorf("Case %v: ans mismatch serial: expected %v, found %v", i, test.ans, ans)
    83  		}
    84  
    85  		ans = Derivative(test.f, test.loc, &Settings{
    86  			Formula:     formula,
    87  			OriginKnown: true,
    88  			OriginValue: test.fofx,
    89  		})
    90  		if math.Abs(test.ans-ans) > tol {
    91  			t.Errorf("Case %v: ans mismatch serial origin known: expected %v, found %v", i, test.ans, ans)
    92  		}
    93  
    94  		ans = Derivative(test.f, test.loc, &Settings{
    95  			Formula:    formula,
    96  			Concurrent: true,
    97  		})
    98  		if math.Abs(test.ans-ans) > tol {
    99  			t.Errorf("Case %v: ans mismatch concurrent: expected %v, found %v", i, test.ans, ans)
   100  		}
   101  
   102  		ans = Derivative(test.f, test.loc, &Settings{
   103  			Formula:     formula,
   104  			OriginKnown: true,
   105  			OriginValue: test.fofx,
   106  			Concurrent:  true,
   107  		})
   108  		if math.Abs(test.ans-ans) > tol {
   109  			t.Errorf("Case %v: ans mismatch concurrent: expected %v, found %v", i, test.ans, ans)
   110  		}
   111  	}
   112  }
   113  
   114  func TestForward(t *testing.T) {
   115  	t.Parallel()
   116  	testDerivative(t, Forward, 2e-4, testsFirst)
   117  }
   118  
   119  func TestBackward(t *testing.T) {
   120  	t.Parallel()
   121  	testDerivative(t, Backward, 2e-4, testsFirst)
   122  }
   123  
   124  func TestCentral(t *testing.T) {
   125  	t.Parallel()
   126  	testDerivative(t, Central, 1e-6, testsFirst)
   127  }
   128  
   129  func TestCentralSecond(t *testing.T) {
   130  	t.Parallel()
   131  	testDerivative(t, Central2nd, 1e-3, testsSecond)
   132  }
   133  
   134  // TestDerivativeDefault checks that the derivative works when settings is nil
   135  // or zero value.
   136  func TestDerivativeDefault(t *testing.T) {
   137  	t.Parallel()
   138  	tol := 1e-6
   139  	for i, test := range testsFirst {
   140  		ans := Derivative(test.f, test.loc, nil)
   141  		if math.Abs(test.ans-ans) > tol {
   142  			t.Errorf("Case %v: ans mismatch default: expected %v, found %v", i, test.ans, ans)
   143  		}
   144  
   145  		ans = Derivative(test.f, test.loc, &Settings{})
   146  		if math.Abs(test.ans-ans) > tol {
   147  			t.Errorf("Case %v: ans mismatch zero value: expected %v, found %v", i, test.ans, ans)
   148  		}
   149  	}
   150  }