github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/diff/fd/gradient_test.go

// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package fd

import (
	"math"
	"testing"

	"golang.org/x/exp/rand"

	"github.com/jingcheng-WU/gonum/floats"
)

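// Rosenbrock implements the multidimensional Rosenbrock test function,
//
//	f(x) = sum_{i=0}^{n-2} (1-x_i)^2 + 100*(x_{i+1} - x_i^2)^2,
//
// whose gradient is known in closed form, making it a convenient reference
// for checking finite-difference approximations.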
type Rosenbrock struct {
	nDim int
}

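// F returns the value of the Rosenbrock function at x.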
func (r Rosenbrock) F(x []float64) (sum float64) {
	deriv := make([]float64, len(x))
	return r.FDf(x, deriv)
}

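// FDf returns the value of the Rosenbrock function at x and stores the
// analytic gradient in deriv:
//
//	df/dx_i = -2*(1-x_i) - 400*x_i*(x_{i+1} - x_i^2) + 200*(x_i - x_{i-1}^2),
//
// where the first two terms apply for i < n-1 and the last term for i > 0.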
func (r Rosenbrock) FDf(x []float64, deriv []float64) (sum float64) {
	for i := range deriv {
		deriv[i] = 0
	}

	for i := 0; i < len(x)-1; i++ {
		sum += math.Pow(1-x[i], 2) + 100*math.Pow(x[i+1]-math.Pow(x[i], 2), 2)
	}
	for i := 0; i < len(x)-1; i++ {
		deriv[i] += -1 * 2 * (1 - x[i])
		deriv[i] += 2 * 100 * (x[i+1] - math.Pow(x[i], 2)) * (-2 * x[i])
	}
	for i := 1; i < len(x); i++ {
		deriv[i] += 2 * 100 * (x[i] - math.Pow(x[i-1], 2))
	}

	return sum
}

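// TestGradient compares Gradient against the analytic Rosenbrock gradient for
// the Forward and Central formulas in 2 and 40 dimensions, exercising nil and
// preallocated destinations, known origin values, concurrent evaluation, and
// nil and zero-valued Settings. It also checks that x is left unmodified.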
func TestGradient(t *testing.T) {
	t.Parallel()
	rnd := rand.New(rand.NewSource(1))
	for i, test := range []struct {
		nDim    int
		tol     float64
		formula Formula
	}{
		{
			nDim:    2,
			tol:     2e-4,
			formula: Forward,
		},
		{
			nDim:    2,
			tol:     1e-6,
			formula: Central,
		},
		{
			nDim:    40,
			tol:     2e-4,
			formula: Forward,
		},
		{
			nDim:    40,
			tol:     1e-5,
			formula: Central,
		},
	} {
		x := make([]float64, test.nDim)
		for i := range x {
			x[i] = rnd.Float64()
		}
		xcopy := make([]float64, len(x))
		copy(xcopy, x)

		r := Rosenbrock{len(x)}
		trueGradient := make([]float64, len(x))
		r.FDf(x, trueGradient)

		// Try with gradient nil.
		gradient := Gradient(nil, r.F, x, &Settings{
			Formula: test.formula,
		})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch in serial with nil. Want: %v, Got: %v.", i, trueGradient, gradient)
		}
		if !floats.Equal(x, xcopy) {
			t.Errorf("Case %v: x modified during call to gradient in serial with nil.", i)
		}

		// Try with provided gradient.
		for i := range gradient {
			gradient[i] = rnd.Float64()
		}
		Gradient(gradient, r.F, x, &Settings{
			Formula: test.formula,
		})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch in serial. Want: %v, Got: %v.", i, trueGradient, gradient)
		}
		if !floats.Equal(x, xcopy) {
			t.Errorf("Case %v: x modified during call to gradient in serial with non-nil.", i)
		}

		// Try with known value.
		for i := range gradient {
			gradient[i] = rnd.Float64()
		}
		Gradient(gradient, r.F, x, &Settings{
			Formula:     test.formula,
			OriginKnown: true,
			OriginValue: r.F(x),
		})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with known origin in serial. Want: %v, Got: %v.", i, trueGradient, gradient)
		}

		// Try with concurrent evaluation.
		for i := range gradient {
			gradient[i] = rnd.Float64()
		}
		Gradient(gradient, r.F, x, &Settings{
			Formula:    test.formula,
			Concurrent: true,
		})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with unknown origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient)
		}
		if !floats.Equal(x, xcopy) {
			t.Errorf("Case %v: x modified during call to gradient in parallel.", i)
		}

		// Try with concurrent evaluation with origin known.
		for i := range gradient {
			gradient[i] = rnd.Float64()
		}
		Gradient(gradient, r.F, x, &Settings{
			Formula:     test.formula,
			Concurrent:  true,
			OriginKnown: true,
			OriginValue: r.F(x),
		})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with known origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient)
		}

		// Try with nil settings.
		for i := range gradient {
			gradient[i] = rnd.Float64()
		}
		Gradient(gradient, r.F, x, nil)
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with default settings. Want: %v, Got: %v.", i, trueGradient, gradient)
		}

		// Try with zero-valued settings.
		for i := range gradient {
			gradient[i] = rnd.Float64()
		}
		Gradient(gradient, r.F, x, &Settings{})
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with zero settings. Want: %v, Got: %v.", i, trueGradient, gradient)
		}
	}
}

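// Panics reports whether calling fun results in a panic.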
func Panics(fun func()) (b bool) {
	defer func() {
		err := recover()
		if err != nil {
			b = true
		}
	}()
	fun()
	return
}

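// TestGradientPanics checks that Gradient panics when the destination slice
// and x have different lengths.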
func TestGradientPanics(t *testing.T) {
	t.Parallel()
	// Gradient must panic when the lengths of dst and x differ.
	if !Panics(func() {
		Gradient([]float64{0.0}, func(x []float64) float64 { return x[0] * x[0] }, []float64{0.0, 0.0}, nil)
	}) {
		t.Errorf("Gradient did not panic with length mismatch")
	}
}
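
// TestGradientQuadraticSketch is an illustrative addition rather than part of
// the original suite: a minimal sketch of the Gradient API exercised above,
// assuming only calls already used in this file (a nil destination and nil
// Settings allocate and return the finite-difference gradient). The analytic
// gradient of f(x) = x_0^2 + x_1^2 is (2*x_0, 2*x_1).
func TestGradientQuadraticSketch(t *testing.T) {
	t.Parallel()
	f := func(x []float64) float64 { return x[0]*x[0] + x[1]*x[1] }
	x := []float64{1, 2}
	got := Gradient(nil, f, x, nil)
	want := []float64{2, 4}
	if !floats.EqualApprox(got, want, 1e-5) {
		t.Errorf("quadratic gradient mismatch. Want: %v, Got: %v.", want, got)
	}
}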