github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/optimize/functions/validate.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package functions
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"github.com/jingcheng-WU/gonum/diff/fd"
    12  	"github.com/jingcheng-WU/gonum/floats"
    13  )
    14  
    15  // function represents an objective function.
    16  type function interface {
    17  	Func(x []float64) float64
    18  }
    19  
    20  type gradient interface {
    21  	Grad(grad, x []float64) []float64
    22  }
    23  
    24  // minimumer is an objective function that can also provide information about
    25  // its minima.
    26  type minimumer interface {
    27  	function
    28  
    29  	// Minima returns _known_ minima of the function.
    30  	Minima() []Minimum
    31  }
    32  
    33  // Minimum represents information about an optimal location of a function.
    34  type Minimum struct {
    35  	// X is the location of the minimum. X may not be nil.
    36  	X []float64
    37  	// F is the value of the objective function at X.
    38  	F float64
    39  	// Global indicates if the location is a global minimum.
    40  	Global bool
    41  }
    42  
    43  type funcTest struct {
    44  	X []float64
    45  
    46  	// F is the expected function value at X.
    47  	F float64
    48  	// Gradient is the expected gradient at X. If nil, it is not evaluated.
    49  	Gradient []float64
    50  }
    51  
    52  // TODO(vladimir-ch): Decide and implement an exported testing function:
    53  // func Test(f Function, ??? ) ??? {
    54  // }
    55  
    56  const (
    57  	defaultTol       = 1e-12
    58  	defaultGradTol   = 1e-9
    59  	defaultFDGradTol = 1e-5
    60  )
    61  
    62  // testFunction checks that the function can evaluate itself (and its gradient)
    63  // correctly.
    64  func testFunction(f function, ftests []funcTest, t *testing.T) {
    65  	// Make a copy of tests because we may append to the slice.
    66  	tests := make([]funcTest, len(ftests))
    67  	copy(tests, ftests)
    68  
    69  	// Get information about the function.
    70  	fMinima, isMinimumer := f.(minimumer)
    71  	fGradient, isGradient := f.(gradient)
    72  
    73  	// If the function is a Minimumer, append its minima to the tests.
    74  	if isMinimumer {
    75  		for _, minimum := range fMinima.Minima() {
    76  			// Allocate gradient only if the function can evaluate it.
    77  			var grad []float64
    78  			if isGradient {
    79  				grad = make([]float64, len(minimum.X))
    80  			}
    81  			tests = append(tests, funcTest{
    82  				X:        minimum.X,
    83  				F:        minimum.F,
    84  				Gradient: grad,
    85  			})
    86  		}
    87  	}
    88  
    89  	for i, test := range tests {
    90  		F := f.Func(test.X)
    91  
    92  		// Check that the function value is as expected.
    93  		if math.Abs(F-test.F) > defaultTol {
    94  			t.Errorf("Test #%d: function value given by Func is incorrect. Want: %v, Got: %v",
    95  				i, test.F, F)
    96  		}
    97  
    98  		if test.Gradient == nil {
    99  			continue
   100  		}
   101  
   102  		// Evaluate the finite difference gradient.
   103  		fdGrad := fd.Gradient(nil, f.Func, test.X, &fd.Settings{
   104  			Formula: fd.Central,
   105  			Step:    1e-6,
   106  		})
   107  
   108  		// Check that the finite difference and expected gradients match.
   109  		if !floats.EqualApprox(fdGrad, test.Gradient, defaultFDGradTol) {
   110  			dist := floats.Distance(fdGrad, test.Gradient, math.Inf(1))
   111  			t.Errorf("Test #%d: numerical and expected gradients do not match. |fdGrad - WantGrad|_∞ = %v",
   112  				i, dist)
   113  		}
   114  
   115  		// If the function is a Gradient, check that it computes the gradient correctly.
   116  		if isGradient {
   117  			grad := make([]float64, len(test.Gradient))
   118  			fGradient.Grad(grad, test.X)
   119  
   120  			if !floats.EqualApprox(grad, test.Gradient, defaultGradTol) {
   121  				dist := floats.Distance(grad, test.Gradient, math.Inf(1))
   122  				t.Errorf("Test #%d: gradient given by Grad is incorrect. |grad - WantGrad|_∞ = %v",
   123  					i, dist)
   124  			}
   125  		}
   126  	}
   127  }