github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/diff/fd/crosslaplacian_test.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fd
     6  
     7  import (
     8  	"testing"
     9  
    10  	"github.com/jingcheng-WU/gonum/floats/scalar"
    11  	"github.com/jingcheng-WU/gonum/mat"
    12  )
    13  
    14  type CrossLaplacianTester interface {
    15  	Func(x, y []float64) float64
    16  	CrossLaplacian(x, y []float64) float64
    17  }
    18  
    19  type WrapperCL struct {
    20  	Tester HessianTester
    21  }
    22  
    23  func (WrapperCL) constructZ(x, y []float64) []float64 {
    24  	z := make([]float64, len(x)+len(y))
    25  	copy(z, x)
    26  	copy(z[len(x):], y)
    27  	return z
    28  }
    29  
    30  func (w WrapperCL) Func(x, y []float64) float64 {
    31  	z := w.constructZ(x, y)
    32  	return w.Tester.Func(z)
    33  }
    34  
    35  func (w WrapperCL) CrossLaplacian(x, y []float64) float64 {
    36  	z := w.constructZ(x, y)
    37  	hess := mat.NewSymDense(len(z), nil)
    38  	w.Tester.Hess(hess, z)
    39  	// The CrossLaplacian is the trace of the off-diagonal block of the Hessian.
    40  	var l float64
    41  	for i := 0; i < len(x); i++ {
    42  		l += hess.At(i, i+len(x))
    43  	}
    44  	return l
    45  }
    46  
    47  func TestCrossLaplacian(t *testing.T) {
    48  	t.Parallel()
    49  	for cas, test := range []struct {
    50  		l        CrossLaplacianTester
    51  		x, y     []float64
    52  		settings *Settings
    53  		tol      float64
    54  	}{
    55  		{
    56  			l:   WrapperCL{Watson{}},
    57  			x:   []float64{0.2, 0.3},
    58  			y:   []float64{0.1, 0.4},
    59  			tol: 1e-3,
    60  		},
    61  		{
    62  			l:   WrapperCL{Watson{}},
    63  			x:   []float64{2, 3, 1},
    64  			y:   []float64{1, 4, 1},
    65  			tol: 1e-3,
    66  		},
    67  		{
    68  			l:   WrapperCL{ConstFunc(6)},
    69  			x:   []float64{2, -3, 1},
    70  			y:   []float64{1, 4, -5},
    71  			tol: 1e-6,
    72  		},
    73  		{
    74  			l:   WrapperCL{LinearFunc{w: []float64{10, 6, -1, 5}, c: 5}},
    75  			x:   []float64{3, 1},
    76  			y:   []float64{8, 6},
    77  			tol: 1e-6,
    78  		},
    79  		{
    80  			l: WrapperCL{QuadFunc{
    81  				a: mat.NewSymDense(4, []float64{
    82  					10, 2, 1, 9,
    83  					2, 5, -3, 4,
    84  					1, -3, 6, 2,
    85  					9, 4, 2, -14,
    86  				}),
    87  				b: mat.NewVecDense(4, []float64{3, -2, -1, 4}),
    88  				c: 5,
    89  			}},
    90  			x:   []float64{-1.6, -3},
    91  			y:   []float64{1.8, 3.4},
    92  			tol: 1e-6,
    93  		},
    94  	} {
    95  		got := CrossLaplacian(test.l.Func, test.x, test.y, test.settings)
    96  		want := test.l.CrossLaplacian(test.x, test.y)
    97  		if !scalar.EqualWithinAbsOrRel(got, want, test.tol, test.tol) {
    98  			t.Errorf("Cas %d: CrossLaplacian mismatch serial. got %v, want %v", cas, got, want)
    99  		}
   100  
   101  		// Test that concurrency works.
   102  		settings := test.settings
   103  		if settings == nil {
   104  			settings = &Settings{}
   105  		}
   106  		settings.Concurrent = true
   107  		got2 := CrossLaplacian(test.l.Func, test.x, test.y, settings)
   108  		if !scalar.EqualWithinAbsOrRel(got, got2, 1e-6, 1e-6) {
   109  			t.Errorf("Cas %d: Laplacian mismatch. got %v, want %v", cas, got2, got)
   110  		}
   111  	}
   112  }