gonum.org/v1/gonum@v0.14.0/optimize/cmaes_test.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package optimize 6 7 import ( 8 "errors" 9 "math" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "gonum.org/v1/gonum/floats" 15 "gonum.org/v1/gonum/mat" 16 "gonum.org/v1/gonum/optimize/functions" 17 ) 18 19 type functionThresholdConverger struct { 20 Threshold float64 21 } 22 23 func (functionThresholdConverger) Init(dim int) {} 24 25 func (f functionThresholdConverger) Converged(loc *Location) Status { 26 if loc.F < f.Threshold { 27 return FunctionThreshold 28 } 29 return NotTerminated 30 } 31 32 type cmaTestCase struct { 33 dim int 34 problem Problem 35 method *CmaEsChol 36 initX []float64 37 settings *Settings 38 good func(result *Result, err error, concurrent int) error 39 } 40 41 func cmaTestCases() []cmaTestCase { 42 localMinMean := []float64{2.2, -2.2} 43 s := mat.NewSymDense(2, []float64{0.01, 0, 0, 0.01}) 44 var localMinChol mat.Cholesky 45 localMinChol.Factorize(s) 46 return []cmaTestCase{ 47 { 48 // Test that can find a small value. 49 dim: 10, 50 problem: Problem{ 51 Func: functions.ExtendedRosenbrock{}.Func, 52 }, 53 method: &CmaEsChol{ 54 StopLogDet: math.NaN(), 55 }, 56 settings: &Settings{ 57 Converger: functionThresholdConverger{0.01}, 58 }, 59 good: func(result *Result, err error, concurrent int) error { 60 if result.Status != FunctionThreshold { 61 return errors.New("result not function threshold") 62 } 63 if result.F > 0.01 { 64 return errors.New("result not sufficiently small") 65 } 66 return nil 67 }, 68 }, 69 { 70 // Test that can stop when the covariance gets small. 71 // For this case, also test that it is really at a minimum. 72 dim: 2, 73 problem: Problem{ 74 Func: functions.ExtendedRosenbrock{}.Func, 75 }, 76 method: &CmaEsChol{}, 77 settings: &Settings{ 78 Converger: NeverTerminate{}, 79 }, 80 good: func(result *Result, err error, concurrent int) error { 81 if result.Status != MethodConverge { 82 return errors.New("result not method converge") 83 } 84 if result.F > 1e-12 { 85 return errors.New("minimum not found") 86 } 87 return nil 88 }, 89 }, 90 { 91 // Test that population works properly and it stops after a certain 92 // number of iterations. 93 dim: 3, 94 problem: Problem{ 95 Func: functions.ExtendedRosenbrock{}.Func, 96 }, 97 method: &CmaEsChol{ 98 Population: 100, 99 ForgetBest: true, // Otherwise may get an update at the end. 100 }, 101 settings: &Settings{ 102 MajorIterations: 10, 103 Converger: NeverTerminate{}, 104 }, 105 good: func(result *Result, err error, concurrent int) error { 106 if result.Status != IterationLimit { 107 return errors.New("result not iteration limit") 108 } 109 threshLower := 10 110 threshUpper := 10 111 if concurrent != 0 { 112 // Could have one more from final update. 113 threshUpper++ 114 } 115 if result.MajorIterations < threshLower || result.MajorIterations > threshUpper { 116 return errors.New("wrong number of iterations") 117 } 118 return nil 119 }, 120 }, 121 { 122 // Test that work stops with some number of function evaluations. 123 dim: 5, 124 problem: Problem{ 125 Func: functions.ExtendedRosenbrock{}.Func, 126 }, 127 method: &CmaEsChol{ 128 Population: 100, 129 }, 130 settings: &Settings{ 131 FuncEvaluations: 250, // Somewhere in the middle of an iteration. 132 Converger: NeverTerminate{}, 133 }, 134 good: func(result *Result, err error, concurrent int) error { 135 if result.Status != FunctionEvaluationLimit { 136 return errors.New("result not function evaluations") 137 } 138 threshLower := 250 139 threshUpper := 251 140 if concurrent != 0 { 141 threshUpper = threshLower + concurrent 142 } 143 if result.FuncEvaluations < threshLower { 144 return errors.New("too few function evaluations") 145 } 146 if result.FuncEvaluations > threshUpper { 147 return errors.New("too many function evaluations") 148 } 149 return nil 150 }, 151 }, 152 { 153 // Test that the global minimum is found with the right initialization. 154 dim: 2, 155 problem: Problem{ 156 Func: functions.Rastrigin{}.Func, 157 }, 158 method: &CmaEsChol{ 159 Population: 100, // Increase the population size to reduce noise. 160 }, 161 settings: &Settings{ 162 Converger: NeverTerminate{}, 163 }, 164 good: func(result *Result, err error, concurrent int) error { 165 if result.Status != MethodConverge { 166 return errors.New("result not method converge") 167 } 168 if !floats.EqualApprox(result.X, []float64{0, 0}, 1e-6) { 169 return errors.New("global minimum not found") 170 } 171 return nil 172 }, 173 }, 174 { 175 // Test that a local minimum is found (with a different initialization). 176 dim: 2, 177 problem: Problem{ 178 Func: functions.Rastrigin{}.Func, 179 }, 180 initX: localMinMean, 181 method: &CmaEsChol{ 182 Population: 100, // Increase the population size to reduce noise. 183 InitCholesky: &localMinChol, 184 ForgetBest: true, // So that if it accidentally finds a better place we still converge to the minimum. 185 }, 186 settings: &Settings{ 187 Converger: NeverTerminate{}, 188 }, 189 good: func(result *Result, err error, concurrent int) error { 190 if result.Status != MethodConverge { 191 return errors.New("result not method converge") 192 } 193 if !floats.EqualApprox(result.X, []float64{2, -2}, 3e-2) { 194 return errors.New("local minimum not found") 195 } 196 return nil 197 }, 198 }, 199 } 200 } 201 202 func TestCmaEsChol(t *testing.T) { 203 t.Parallel() 204 for i, test := range cmaTestCases() { 205 src := rand.New(rand.NewSource(1)) 206 method := test.method 207 method.Src = src 208 initX := test.initX 209 if initX == nil { 210 initX = make([]float64, test.dim) 211 } 212 // Run and check that the expected termination occurs. 213 result, err := Minimize(test.problem, initX, test.settings, method) 214 if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil { 215 t.Errorf("cas %d: %v", i, testErr) 216 } 217 218 // Run a second time to make sure there are no residual effects 219 result, err = Minimize(test.problem, initX, test.settings, method) 220 if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil { 221 t.Errorf("cas %d second: %v", i, testErr) 222 } 223 224 // Test the problem in parallel. 225 test.settings.Concurrent = 5 226 result, err = Minimize(test.problem, initX, test.settings, method) 227 if testErr := test.good(result, err, test.settings.Concurrent); testErr != nil { 228 t.Errorf("cas %d concurrent: %v", i, testErr) 229 } 230 test.settings.Concurrent = 0 231 } 232 }