github.com/gopherd/gonum@v0.0.4/diff/fd/crosslaplacian_test.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fd 6 7 import ( 8 "testing" 9 10 "github.com/gopherd/gonum/floats/scalar" 11 "github.com/gopherd/gonum/mat" 12 ) 13 14 type CrossLaplacianTester interface { 15 Func(x, y []float64) float64 16 CrossLaplacian(x, y []float64) float64 17 } 18 19 type WrapperCL struct { 20 Tester HessianTester 21 } 22 23 func (WrapperCL) constructZ(x, y []float64) []float64 { 24 z := make([]float64, len(x)+len(y)) 25 copy(z, x) 26 copy(z[len(x):], y) 27 return z 28 } 29 30 func (w WrapperCL) Func(x, y []float64) float64 { 31 z := w.constructZ(x, y) 32 return w.Tester.Func(z) 33 } 34 35 func (w WrapperCL) CrossLaplacian(x, y []float64) float64 { 36 z := w.constructZ(x, y) 37 hess := mat.NewSymDense(len(z), nil) 38 w.Tester.Hess(hess, z) 39 // The CrossLaplacian is the trace of the off-diagonal block of the Hessian. 40 var l float64 41 for i := 0; i < len(x); i++ { 42 l += hess.At(i, i+len(x)) 43 } 44 return l 45 } 46 47 func TestCrossLaplacian(t *testing.T) { 48 t.Parallel() 49 for cas, test := range []struct { 50 l CrossLaplacianTester 51 x, y []float64 52 settings *Settings 53 tol float64 54 }{ 55 { 56 l: WrapperCL{Watson{}}, 57 x: []float64{0.2, 0.3}, 58 y: []float64{0.1, 0.4}, 59 tol: 1e-3, 60 }, 61 { 62 l: WrapperCL{Watson{}}, 63 x: []float64{2, 3, 1}, 64 y: []float64{1, 4, 1}, 65 tol: 1e-3, 66 }, 67 { 68 l: WrapperCL{ConstFunc(6)}, 69 x: []float64{2, -3, 1}, 70 y: []float64{1, 4, -5}, 71 tol: 1e-6, 72 }, 73 { 74 l: WrapperCL{LinearFunc{w: []float64{10, 6, -1, 5}, c: 5}}, 75 x: []float64{3, 1}, 76 y: []float64{8, 6}, 77 tol: 1e-6, 78 }, 79 { 80 l: WrapperCL{QuadFunc{ 81 a: mat.NewSymDense(4, []float64{ 82 10, 2, 1, 9, 83 2, 5, -3, 4, 84 1, -3, 6, 2, 85 9, 4, 2, -14, 86 }), 87 b: mat.NewVecDense(4, []float64{3, -2, -1, 4}), 88 c: 5, 89 }}, 90 x: []float64{-1.6, -3}, 91 y: []float64{1.8, 3.4}, 92 tol: 1e-6, 93 }, 94 } { 95 got := CrossLaplacian(test.l.Func, test.x, test.y, test.settings) 96 want := test.l.CrossLaplacian(test.x, test.y) 97 if !scalar.EqualWithinAbsOrRel(got, want, test.tol, test.tol) { 98 t.Errorf("Cas %d: CrossLaplacian mismatch serial. got %v, want %v", cas, got, want) 99 } 100 101 // Test that concurrency works. 102 settings := test.settings 103 if settings == nil { 104 settings = &Settings{} 105 } 106 settings.Concurrent = true 107 got2 := CrossLaplacian(test.l.Func, test.x, test.y, settings) 108 if !scalar.EqualWithinAbsOrRel(got, got2, 1e-6, 1e-6) { 109 t.Errorf("Cas %d: Laplacian mismatch. got %v, want %v", cas, got2, got) 110 } 111 } 112 }