github.com/gopherd/gonum@v0.0.4/diff/fd/hessian_test.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fd
     6  
     7  import (
     8  	"testing"
     9  
    10  	"github.com/gopherd/gonum/mat"
    11  )
    12  
    13  type HessianTester interface {
    14  	Func(x []float64) float64
    15  	Grad(grad, x []float64)
    16  	Hess(dst mat.MutableSymmetric, x []float64)
    17  }
    18  
    19  type hessianTestCase struct {
    20  	h        HessianTester
    21  	x        []float64
    22  	settings *Settings
    23  	tol      float64
    24  }
    25  
    26  var _hessianTestCases = []hessianTestCase{
    27  	{
    28  		h:   Watson{},
    29  		x:   []float64{0.2, 0.3, 0.1, 0.4},
    30  		tol: 1e-3,
    31  	},
    32  	{
    33  		h:   Watson{},
    34  		x:   []float64{2, 3, 1, 4},
    35  		tol: 1e-3,
    36  		settings: &Settings{
    37  			Step:    1e-5,
    38  			Formula: Central,
    39  		},
    40  	},
    41  	{
    42  		h:   Watson{},
    43  		x:   []float64{2, 3, 1},
    44  		tol: 1e-3,
    45  		settings: &Settings{
    46  			OriginKnown: true,
    47  			OriginValue: 7606.529501201192,
    48  		},
    49  	},
    50  	{
    51  		h:   ConstFunc(5),
    52  		x:   []float64{1, 9},
    53  		tol: 1e-16,
    54  	},
    55  	{
    56  		h:   LinearFunc{w: []float64{10, 6, -1}, c: 5},
    57  		x:   []float64{3, 1, 8},
    58  		tol: 1e-6,
    59  	},
    60  	{
    61  		h: QuadFunc{
    62  			a: mat.NewSymDense(3, []float64{
    63  				10, 2, 1,
    64  				2, 5, -3,
    65  				1, -3, 6,
    66  			}),
    67  			b: mat.NewVecDense(3, []float64{3, -2, -1}),
    68  			c: 5,
    69  		},
    70  		x:   []float64{-1.6, -3, 2},
    71  		tol: 1e-6,
    72  	},
    73  }
    74  
    75  func hessianTestCases() []hessianTestCase {
    76  	xs := []hessianTestCase{}
    77  	for _, test := range _hessianTestCases {
    78  		n := test
    79  		if test.settings != nil {
    80  			clone := *test.settings
    81  			n.settings = &clone
    82  		}
    83  		xs = append(xs, n)
    84  	}
    85  	return xs
    86  }
    87  
    88  func TestHessian(t *testing.T) {
    89  	t.Parallel()
    90  	for cas, test := range hessianTestCases() {
    91  		n := len(test.x)
    92  		var got mat.SymDense
    93  		Hessian(&got, test.h.Func, test.x, test.settings)
    94  		want := mat.NewSymDense(n, nil)
    95  		test.h.Hess(want, test.x)
    96  		if !mat.EqualApprox(&got, want, test.tol) {
    97  			t.Errorf("Cas %d: Hessian mismatch\ngot=\n%0.4v\nwant=\n%0.4v\n", cas, mat.Formatted(&got), mat.Formatted(want))
    98  		}
    99  
   100  		// Test that concurrency works.
   101  		settings := test.settings
   102  		if settings == nil {
   103  			settings = &Settings{}
   104  		}
   105  		settings.Concurrent = true
   106  		var got2 mat.SymDense
   107  		Hessian(&got2, test.h.Func, test.x, settings)
   108  		if !mat.EqualApprox(&got, &got2, 1e-5) {
   109  			t.Errorf("Cas %d: Hessian mismatch concurrent\ngot=\n%0.6v\nwant=\n%0.6v\n", cas, mat.Formatted(&got2), mat.Formatted(&got))
   110  		}
   111  	}
   112  }