github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/optimize/cmaes_test.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package optimize
     6  
     7  import (
     8  	"errors"
     9  	"math"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  
    14  	"github.com/jingcheng-WU/gonum/floats"
    15  	"github.com/jingcheng-WU/gonum/mat"
    16  	"github.com/jingcheng-WU/gonum/optimize/functions"
    17  )
    18  
    19  type functionThresholdConverger struct {
    20  	Threshold float64
    21  }
    22  
    23  func (functionThresholdConverger) Init(dim int) {}
    24  
    25  func (f functionThresholdConverger) Converged(loc *Location) Status {
    26  	if loc.F < f.Threshold {
    27  		return FunctionThreshold
    28  	}
    29  	return NotTerminated
    30  }
    31  
    32  type cmaTestCase struct {
    33  	dim      int
    34  	problem  Problem
    35  	method   *CmaEsChol
    36  	initX    []float64
    37  	settings *Settings
    38  	good     func(result *Result, err error, concurrent int) error
    39  }
    40  
    41  func cmaTestCases() []cmaTestCase {
    42  	localMinMean := []float64{2.2, -2.2}
    43  	s := mat.NewSymDense(2, []float64{0.01, 0, 0, 0.01})
    44  	var localMinChol mat.Cholesky
    45  	localMinChol.Factorize(s)
    46  	return []cmaTestCase{
    47  		{
    48  			// Test that can find a small value.
    49  			dim: 10,
    50  			problem: Problem{
    51  				Func: functions.ExtendedRosenbrock{}.Func,
    52  			},
    53  			method: &CmaEsChol{
    54  				StopLogDet: math.NaN(),
    55  			},
    56  			settings: &Settings{
    57  				Converger: functionThresholdConverger{0.01},
    58  			},
    59  			good: func(result *Result, err error, concurrent int) error {
    60  				if result.Status != FunctionThreshold {
    61  					return errors.New("result not function threshold")
    62  				}
    63  				if result.F > 0.01 {
    64  					return errors.New("result not sufficiently small")
    65  				}
    66  				return nil
    67  			},
    68  		},
    69  		{
    70  			// Test that can stop when the covariance gets small.
    71  			// For this case, also test that it is really at a minimum.
    72  			dim: 2,
    73  			problem: Problem{
    74  				Func: functions.ExtendedRosenbrock{}.Func,
    75  			},
    76  			method: &CmaEsChol{},
    77  			settings: &Settings{
    78  				Converger: NeverTerminate{},
    79  			},
    80  			good: func(result *Result, err error, concurrent int) error {
    81  				if result.Status != MethodConverge {
    82  					return errors.New("result not method converge")
    83  				}
    84  				if result.F > 1e-12 {
    85  					return errors.New("minimum not found")
    86  				}
    87  				return nil
    88  			},
    89  		},
    90  		{
    91  			// Test that population works properly and it stops after a certain
    92  			// number of iterations.
    93  			dim: 3,
    94  			problem: Problem{
    95  				Func: functions.ExtendedRosenbrock{}.Func,
    96  			},
    97  			method: &CmaEsChol{
    98  				Population: 100,
    99  				ForgetBest: true, // Otherwise may get an update at the end.
   100  			},
   101  			settings: &Settings{
   102  				MajorIterations: 10,
   103  				Converger:       NeverTerminate{},
   104  			},
   105  			good: func(result *Result, err error, concurrent int) error {
   106  				if result.Status != IterationLimit {
   107  					return errors.New("result not iteration limit")
   108  				}
   109  				threshLower := 10
   110  				threshUpper := 10
   111  				if concurrent != 0 {
   112  					// Could have one more from final update.
   113  					threshUpper++
   114  				}
   115  				if result.MajorIterations < threshLower || result.MajorIterations > threshUpper {
   116  					return errors.New("wrong number of iterations")
   117  				}
   118  				return nil
   119  			},
   120  		},
   121  		{
   122  			// Test that work stops with some number of function evaluations.
   123  			dim: 5,
   124  			problem: Problem{
   125  				Func: functions.ExtendedRosenbrock{}.Func,
   126  			},
   127  			method: &CmaEsChol{
   128  				Population: 100,
   129  			},
   130  			settings: &Settings{
   131  				FuncEvaluations: 250, // Somewhere in the middle of an iteration.
   132  				Converger:       NeverTerminate{},
   133  			},
   134  			good: func(result *Result, err error, concurrent int) error {
   135  				if result.Status != FunctionEvaluationLimit {
   136  					return errors.New("result not function evaluations")
   137  				}
   138  				threshLower := 250
   139  				threshUpper := 251
   140  				if concurrent != 0 {
   141  					threshUpper = threshLower + concurrent
   142  				}
   143  				if result.FuncEvaluations < threshLower {
   144  					return errors.New("too few function evaluations")
   145  				}
   146  				if result.FuncEvaluations > threshUpper {
   147  					return errors.New("too many function evaluations")
   148  				}
   149  				return nil
   150  			},
   151  		},
   152  		{
   153  			// Test that the global minimum is found with the right initialization.
   154  			dim: 2,
   155  			problem: Problem{
   156  				Func: functions.Rastrigin{}.Func,
   157  			},
   158  			method: &CmaEsChol{
   159  				Population: 100, // Increase the population size to reduce noise.
   160  			},
   161  			settings: &Settings{
   162  				Converger: NeverTerminate{},
   163  			},
   164  			good: func(result *Result, err error, concurrent int) error {
   165  				if result.Status != MethodConverge {
   166  					return errors.New("result not method converge")
   167  				}
   168  				if !floats.EqualApprox(result.X, []float64{0, 0}, 1e-6) {
   169  					return errors.New("global minimum not found")
   170  				}
   171  				return nil
   172  			},
   173  		},
   174  		{
   175  			// Test that a local minimum is found (with a different initialization).
   176  			dim: 2,
   177  			problem: Problem{
   178  				Func: functions.Rastrigin{}.Func,
   179  			},
   180  			initX: localMinMean,
   181  			method: &CmaEsChol{
   182  				Population:   100, // Increase the population size to reduce noise.
   183  				InitCholesky: &localMinChol,
   184  				ForgetBest:   true, // So that if it accidentally finds a better place we still converge to the minimum.
   185  			},
   186  			settings: &Settings{
   187  				Converger: NeverTerminate{},
   188  			},
   189  			good: func(result *Result, err error, concurrent int) error {
   190  				if result.Status != MethodConverge {
   191  					return errors.New("result not method converge")
   192  				}
   193  				if !floats.EqualApprox(result.X, []float64{2, -2}, 3e-2) {
   194  					return errors.New("local minimum not found")
   195  				}
   196  				return nil
   197  			},
   198  		},
   199  	}
   200  }
   201  
   202  func TestCmaEsChol(t *testing.T) {
   203  	t.Parallel()
   204  	for i, test := range cmaTestCases() {
   205  		src := rand.New(rand.NewSource(1))
   206  		method := test.method
   207  		method.Src = src
   208  		initX := test.initX
   209  		if initX == nil {
   210  			initX = make([]float64, test.dim)
   211  		}
   212  		// Run and check that the expected termination occurs.
   213  		result, err := Minimize(test.problem, initX, test.settings, method)
   214  		if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil {
   215  			t.Errorf("cas %d: %v", i, testErr)
   216  		}
   217  
   218  		// Run a second time to make sure there are no residual effects
   219  		result, err = Minimize(test.problem, initX, test.settings, method)
   220  		if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil {
   221  			t.Errorf("cas %d second: %v", i, testErr)
   222  		}
   223  
   224  		// Test the problem in parallel.
   225  		test.settings.Concurrent = 5
   226  		result, err = Minimize(test.problem, initX, test.settings, method)
   227  		if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil {
   228  			t.Errorf("cas %d concurrent: %v", i, testErr)
   229  		}
   230  		test.settings.Concurrent = 0
   231  	}
   232  }