github.com/gopherd/gonum@v0.0.4/diff/fd/derivative_test.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fd 6 7 import ( 8 "math" 9 "testing" 10 ) 11 12 var xSquared = func(x float64) float64 { return x * x } 13 14 type testPoint struct { 15 f func(float64) float64 16 loc float64 17 fofx float64 18 ans float64 19 } 20 21 var testsFirst = []testPoint{ 22 { 23 f: xSquared, 24 loc: 0, 25 fofx: 0, 26 ans: 0, 27 }, 28 { 29 f: xSquared, 30 loc: 5, 31 fofx: 25, 32 ans: 10, 33 }, 34 { 35 f: xSquared, 36 loc: 2, 37 fofx: 4, 38 ans: 4, 39 }, 40 { 41 f: xSquared, 42 loc: -5, 43 fofx: 25, 44 ans: -10, 45 }, 46 } 47 48 var testsSecond = []testPoint{ 49 { 50 f: xSquared, 51 loc: 0, 52 fofx: 0, 53 ans: 2, 54 }, 55 { 56 f: xSquared, 57 loc: 5, 58 fofx: 25, 59 ans: 2, 60 }, 61 { 62 f: xSquared, 63 loc: 2, 64 fofx: 4, 65 ans: 2, 66 }, 67 { 68 f: xSquared, 69 loc: -5, 70 fofx: 25, 71 ans: 2, 72 }, 73 } 74 75 func testDerivative(t *testing.T, formula Formula, tol float64, tests []testPoint) { 76 for i, test := range tests { 77 78 ans := Derivative(test.f, test.loc, &Settings{ 79 Formula: formula, 80 }) 81 if math.Abs(test.ans-ans) > tol { 82 t.Errorf("Case %v: ans mismatch serial: expected %v, found %v", i, test.ans, ans) 83 } 84 85 ans = Derivative(test.f, test.loc, &Settings{ 86 Formula: formula, 87 OriginKnown: true, 88 OriginValue: test.fofx, 89 }) 90 if math.Abs(test.ans-ans) > tol { 91 t.Errorf("Case %v: ans mismatch serial origin known: expected %v, found %v", i, test.ans, ans) 92 } 93 94 ans = Derivative(test.f, test.loc, &Settings{ 95 Formula: formula, 96 Concurrent: true, 97 }) 98 if math.Abs(test.ans-ans) > tol { 99 t.Errorf("Case %v: ans mismatch concurrent: expected %v, found %v", i, test.ans, ans) 100 } 101 102 ans = Derivative(test.f, test.loc, &Settings{ 103 Formula: formula, 104 OriginKnown: true, 105 OriginValue: test.fofx, 106 Concurrent: true, 107 }) 108 if math.Abs(test.ans-ans) > tol { 109 t.Errorf("Case %v: ans mismatch concurrent: expected %v, found %v", i, test.ans, ans) 110 } 111 } 112 } 113 114 func TestForward(t *testing.T) { 115 t.Parallel() 116 testDerivative(t, Forward, 2e-4, testsFirst) 117 } 118 119 func TestBackward(t *testing.T) { 120 t.Parallel() 121 testDerivative(t, Backward, 2e-4, testsFirst) 122 } 123 124 func TestCentral(t *testing.T) { 125 t.Parallel() 126 testDerivative(t, Central, 1e-6, testsFirst) 127 } 128 129 func TestCentralSecond(t *testing.T) { 130 t.Parallel() 131 testDerivative(t, Central2nd, 1e-3, testsSecond) 132 } 133 134 // TestDerivativeDefault checks that the derivative works when settings is nil 135 // or zero value. 136 func TestDerivativeDefault(t *testing.T) { 137 t.Parallel() 138 tol := 1e-6 139 for i, test := range testsFirst { 140 ans := Derivative(test.f, test.loc, nil) 141 if math.Abs(test.ans-ans) > tol { 142 t.Errorf("Case %v: ans mismatch default: expected %v, found %v", i, test.ans, ans) 143 } 144 145 ans = Derivative(test.f, test.loc, &Settings{}) 146 if math.Abs(test.ans-ans) > tol { 147 t.Errorf("Case %v: ans mismatch zero value: expected %v, found %v", i, test.ans, ans) 148 } 149 } 150 }