github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/diff/fd/gradient_test.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fd 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "github.com/jingcheng-WU/gonum/floats" 14 ) 15 16 type Rosenbrock struct { 17 nDim int 18 } 19 20 func (r Rosenbrock) F(x []float64) (sum float64) { 21 deriv := make([]float64, len(x)) 22 return r.FDf(x, deriv) 23 } 24 25 func (r Rosenbrock) FDf(x []float64, deriv []float64) (sum float64) { 26 for i := range deriv { 27 deriv[i] = 0 28 } 29 30 for i := 0; i < len(x)-1; i++ { 31 sum += math.Pow(1-x[i], 2) + 100*math.Pow(x[i+1]-math.Pow(x[i], 2), 2) 32 } 33 for i := 0; i < len(x)-1; i++ { 34 deriv[i] += -1 * 2 * (1 - x[i]) 35 deriv[i] += 2 * 100 * (x[i+1] - math.Pow(x[i], 2)) * (-2 * x[i]) 36 } 37 for i := 1; i < len(x); i++ { 38 deriv[i] += 2 * 100 * (x[i] - math.Pow(x[i-1], 2)) 39 } 40 41 return sum 42 } 43 44 func TestGradient(t *testing.T) { 45 t.Parallel() 46 rnd := rand.New(rand.NewSource(1)) 47 for i, test := range []struct { 48 nDim int 49 tol float64 50 formula Formula 51 }{ 52 { 53 nDim: 2, 54 tol: 2e-4, 55 formula: Forward, 56 }, 57 { 58 nDim: 2, 59 tol: 1e-6, 60 formula: Central, 61 }, 62 { 63 nDim: 40, 64 tol: 2e-4, 65 formula: Forward, 66 }, 67 { 68 nDim: 40, 69 tol: 1e-5, 70 formula: Central, 71 }, 72 } { 73 x := make([]float64, test.nDim) 74 for i := range x { 75 x[i] = rnd.Float64() 76 } 77 xcopy := make([]float64, len(x)) 78 copy(xcopy, x) 79 80 r := Rosenbrock{len(x)} 81 trueGradient := make([]float64, len(x)) 82 r.FDf(x, trueGradient) 83 84 // Try with gradient nil. 85 gradient := Gradient(nil, r.F, x, &Settings{ 86 Formula: test.formula, 87 }) 88 if !floats.EqualApprox(gradient, trueGradient, test.tol) { 89 t.Errorf("Case %v: gradient mismatch in serial with nil. Want: %v, Got: %v.", i, trueGradient, gradient) 90 } 91 if !floats.Equal(x, xcopy) { 92 t.Errorf("Case %v: x modified during call to gradient in serial with nil.", i) 93 } 94 95 // Try with provided gradient. 96 for i := range gradient { 97 gradient[i] = rnd.Float64() 98 } 99 Gradient(gradient, r.F, x, &Settings{ 100 Formula: test.formula, 101 }) 102 if !floats.EqualApprox(gradient, trueGradient, test.tol) { 103 t.Errorf("Case %v: gradient mismatch in serial. Want: %v, Got: %v.", i, trueGradient, gradient) 104 } 105 if !floats.Equal(x, xcopy) { 106 t.Errorf("Case %v: x modified during call to gradient in serial with non-nil.", i) 107 } 108 109 // Try with known value. 110 for i := range gradient { 111 gradient[i] = rnd.Float64() 112 } 113 Gradient(gradient, r.F, x, &Settings{ 114 Formula: test.formula, 115 OriginKnown: true, 116 OriginValue: r.F(x), 117 }) 118 if !floats.EqualApprox(gradient, trueGradient, test.tol) { 119 t.Errorf("Case %v: gradient mismatch with known origin in serial. Want: %v, Got: %v.", i, trueGradient, gradient) 120 } 121 122 // Try with concurrent evaluation. 123 for i := range gradient { 124 gradient[i] = rnd.Float64() 125 } 126 Gradient(gradient, r.F, x, &Settings{ 127 Formula: test.formula, 128 Concurrent: true, 129 }) 130 if !floats.EqualApprox(gradient, trueGradient, test.tol) { 131 t.Errorf("Case %v: gradient mismatch with unknown origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient) 132 } 133 if !floats.Equal(x, xcopy) { 134 t.Errorf("Case %v: x modified during call to gradient in parallel", i) 135 } 136 137 // Try with concurrent evaluation with origin known. 138 for i := range gradient { 139 gradient[i] = rnd.Float64() 140 } 141 Gradient(gradient, r.F, x, &Settings{ 142 Formula: test.formula, 143 Concurrent: true, 144 OriginKnown: true, 145 OriginValue: r.F(x), 146 }) 147 if !floats.EqualApprox(gradient, trueGradient, test.tol) { 148 t.Errorf("Case %v: gradient mismatch with known origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient) 149 } 150 151 // Try with nil settings. 152 for i := range gradient { 153 gradient[i] = rnd.Float64() 154 } 155 Gradient(gradient, r.F, x, nil) 156 if !floats.EqualApprox(gradient, trueGradient, test.tol) { 157 t.Errorf("Case %v: gradient mismatch with default settings. Want: %v, Got: %v.", i, trueGradient, gradient) 158 } 159 160 // Try with zero-valued settings. 161 for i := range gradient { 162 gradient[i] = rnd.Float64() 163 } 164 Gradient(gradient, r.F, x, &Settings{}) 165 if !floats.EqualApprox(gradient, trueGradient, test.tol) { 166 t.Errorf("Case %v: gradient mismatch with zero settings. Want: %v, Got: %v.", i, trueGradient, gradient) 167 } 168 } 169 } 170 171 func Panics(fun func()) (b bool) { 172 defer func() { 173 err := recover() 174 if err != nil { 175 b = true 176 } 177 }() 178 fun() 179 return 180 } 181 182 func TestGradientPanics(t *testing.T) { 183 t.Parallel() 184 // Test that it panics 185 if !Panics(func() { 186 Gradient([]float64{0.0}, func(x []float64) float64 { return x[0] * x[0] }, []float64{0.0, 0.0}, nil) 187 }) { 188 t.Errorf("Gradient did not panic with length mismatch") 189 } 190 }