github.com/gopherd/gonum@v0.0.4/diff/fd/hessian_test.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fd 6 7 import ( 8 "testing" 9 10 "github.com/gopherd/gonum/mat" 11 ) 12 13 type HessianTester interface { 14 Func(x []float64) float64 15 Grad(grad, x []float64) 16 Hess(dst mat.MutableSymmetric, x []float64) 17 } 18 19 type hessianTestCase struct { 20 h HessianTester 21 x []float64 22 settings *Settings 23 tol float64 24 } 25 26 var _hessianTestCases = []hessianTestCase{ 27 { 28 h: Watson{}, 29 x: []float64{0.2, 0.3, 0.1, 0.4}, 30 tol: 1e-3, 31 }, 32 { 33 h: Watson{}, 34 x: []float64{2, 3, 1, 4}, 35 tol: 1e-3, 36 settings: &Settings{ 37 Step: 1e-5, 38 Formula: Central, 39 }, 40 }, 41 { 42 h: Watson{}, 43 x: []float64{2, 3, 1}, 44 tol: 1e-3, 45 settings: &Settings{ 46 OriginKnown: true, 47 OriginValue: 7606.529501201192, 48 }, 49 }, 50 { 51 h: ConstFunc(5), 52 x: []float64{1, 9}, 53 tol: 1e-16, 54 }, 55 { 56 h: LinearFunc{w: []float64{10, 6, -1}, c: 5}, 57 x: []float64{3, 1, 8}, 58 tol: 1e-6, 59 }, 60 { 61 h: QuadFunc{ 62 a: mat.NewSymDense(3, []float64{ 63 10, 2, 1, 64 2, 5, -3, 65 1, -3, 6, 66 }), 67 b: mat.NewVecDense(3, []float64{3, -2, -1}), 68 c: 5, 69 }, 70 x: []float64{-1.6, -3, 2}, 71 tol: 1e-6, 72 }, 73 } 74 75 func hessianTestCases() []hessianTestCase { 76 xs := []hessianTestCase{} 77 for _, test := range _hessianTestCases { 78 n := test 79 if test.settings != nil { 80 clone := *test.settings 81 n.settings = &clone 82 } 83 xs = append(xs, n) 84 } 85 return xs 86 } 87 88 func TestHessian(t *testing.T) { 89 t.Parallel() 90 for cas, test := range hessianTestCases() { 91 n := len(test.x) 92 var got mat.SymDense 93 Hessian(&got, test.h.Func, test.x, test.settings) 94 want := mat.NewSymDense(n, nil) 95 test.h.Hess(want, test.x) 96 if !mat.EqualApprox(&got, want, test.tol) { 97 t.Errorf("Cas %d: Hessian mismatch\ngot=\n%0.4v\nwant=\n%0.4v\n", cas, mat.Formatted(&got), mat.Formatted(want)) 98 } 99 100 // Test that concurrency works. 101 settings := test.settings 102 if settings == nil { 103 settings = &Settings{} 104 } 105 settings.Concurrent = true 106 var got2 mat.SymDense 107 Hessian(&got2, test.h.Func, test.x, settings) 108 if !mat.EqualApprox(&got, &got2, 1e-5) { 109 t.Errorf("Cas %d: Hessian mismatch concurrent\ngot=\n%0.6v\nwant=\n%0.6v\n", cas, mat.Formatted(&got2), mat.Formatted(&got)) 110 } 111 } 112 }