github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/optimize/neldermead.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package optimize 6 7 import ( 8 "math" 9 "sort" 10 11 "github.com/jingcheng-WU/gonum/floats" 12 ) 13 14 // nmIterType is a Nelder-Mead evaluation kind 15 type nmIterType int 16 17 const ( 18 nmReflected = iota 19 nmExpanded 20 nmContractedInside 21 nmContractedOutside 22 nmInitialize 23 nmShrink 24 nmMajor 25 ) 26 27 type nmVertexSorter struct { 28 vertices [][]float64 29 values []float64 30 } 31 32 func (n nmVertexSorter) Len() int { 33 return len(n.values) 34 } 35 36 func (n nmVertexSorter) Less(i, j int) bool { 37 return n.values[i] < n.values[j] 38 } 39 40 func (n nmVertexSorter) Swap(i, j int) { 41 n.values[i], n.values[j] = n.values[j], n.values[i] 42 n.vertices[i], n.vertices[j] = n.vertices[j], n.vertices[i] 43 } 44 45 var _ Method = (*NelderMead)(nil) 46 47 // NelderMead is an implementation of the Nelder-Mead simplex algorithm for 48 // gradient-free nonlinear optimization (not to be confused with Danzig's 49 // simplex algorithm for linear programming). The implementation follows the 50 // algorithm described in 51 // 52 // http://epubs.siam.org/doi/pdf/10.1137/S1052623496303470 53 // 54 // If an initial simplex is provided, it is used and initLoc is ignored. If 55 // InitialVertices and InitialValues are both nil, an initial simplex will be 56 // generated automatically using the initial location as one vertex, and each 57 // additional vertex as SimplexSize away in one dimension. 58 // 59 // If the simplex update parameters (Reflection, etc.) 60 // are zero, they will be set automatically based on the dimension according to 61 // the recommendations in 62 // 63 // http://www.webpages.uidaho.edu/~fuchang/res/ANMS.pdf 64 type NelderMead struct { 65 InitialVertices [][]float64 66 InitialValues []float64 67 Reflection float64 // Reflection parameter (>0) 68 Expansion float64 // Expansion parameter (>1) 69 Contraction float64 // Contraction parameter (>0, <1) 70 Shrink float64 // Shrink parameter (>0, <1) 71 SimplexSize float64 // size of auto-constructed initial simplex 72 73 status Status 74 err error 75 76 reflection float64 77 expansion float64 78 contraction float64 79 shrink float64 80 81 vertices [][]float64 // location of the vertices sorted in ascending f 82 values []float64 // function values at the vertices sorted in ascending f 83 centroid []float64 // centroid of all but the worst vertex 84 85 fillIdx int // index for filling the simplex during initialization and shrinking 86 lastIter nmIterType // Last iteration 87 reflectedPoint []float64 // Storage of the reflected point location 88 reflectedValue float64 // Value at the last reflection point 89 } 90 91 func (n *NelderMead) Status() (Status, error) { 92 return n.status, n.err 93 } 94 95 func (*NelderMead) Uses(has Available) (uses Available, err error) { 96 return has.function() 97 } 98 99 func (n *NelderMead) Init(dim, tasks int) int { 100 n.status = NotTerminated 101 n.err = nil 102 return 1 103 } 104 105 func (n *NelderMead) Run(operation chan<- Task, result <-chan Task, tasks []Task) { 106 n.status, n.err = localOptimizer{}.run(n, math.NaN(), operation, result, tasks) 107 close(operation) 108 } 109 110 func (n *NelderMead) initLocal(loc *Location) (Operation, error) { 111 dim := len(loc.X) 112 if cap(n.vertices) < dim+1 { 113 n.vertices = make([][]float64, dim+1) 114 } 115 n.vertices = n.vertices[:dim+1] 116 for i := range n.vertices { 117 n.vertices[i] = resize(n.vertices[i], dim) 118 } 119 n.values = resize(n.values, dim+1) 120 n.centroid = resize(n.centroid, dim) 121 n.reflectedPoint = resize(n.reflectedPoint, dim) 122 123 if n.SimplexSize == 0 { 124 n.SimplexSize = 0.05 125 } 126 127 // Default parameter choices are chosen in a dimension-dependent way 128 // from http://www.webpages.uidaho.edu/~fuchang/res/ANMS.pdf 129 n.reflection = n.Reflection 130 if n.reflection == 0 { 131 n.reflection = 1 132 } 133 n.expansion = n.Expansion 134 if n.expansion == 0 { 135 n.expansion = 1 + 2/float64(dim) 136 if dim == 1 { 137 n.expansion = 2 138 } 139 } 140 n.contraction = n.Contraction 141 if n.contraction == 0 { 142 n.contraction = 0.75 - 1/(2*float64(dim)) 143 if dim == 1 { 144 n.contraction = 0.5 145 } 146 } 147 n.shrink = n.Shrink 148 if n.shrink == 0 { 149 n.shrink = 1 - 1/float64(dim) 150 if dim == 1 { 151 n.shrink = 0.5 152 } 153 } 154 155 if n.InitialVertices != nil { 156 // Initial simplex provided. Copy the locations and values, and sort them. 157 if len(n.InitialVertices) != dim+1 { 158 panic("neldermead: incorrect number of vertices in initial simplex") 159 } 160 if len(n.InitialValues) != dim+1 { 161 panic("neldermead: incorrect number of values in initial simplex") 162 } 163 for i := range n.InitialVertices { 164 if len(n.InitialVertices[i]) != dim { 165 panic("neldermead: vertex size mismatch") 166 } 167 copy(n.vertices[i], n.InitialVertices[i]) 168 } 169 copy(n.values, n.InitialValues) 170 sort.Sort(nmVertexSorter{n.vertices, n.values}) 171 computeCentroid(n.vertices, n.centroid) 172 return n.returnNext(nmMajor, loc) 173 } 174 175 // No simplex provided. Begin initializing initial simplex. First simplex 176 // entry is the initial location, then step 1 in every direction. 177 copy(n.vertices[dim], loc.X) 178 n.values[dim] = loc.F 179 n.fillIdx = 0 180 loc.X[n.fillIdx] += n.SimplexSize 181 n.lastIter = nmInitialize 182 return FuncEvaluation, nil 183 } 184 185 // computeCentroid computes the centroid of all the simplex vertices except the 186 // final one 187 func computeCentroid(vertices [][]float64, centroid []float64) { 188 dim := len(centroid) 189 for i := range centroid { 190 centroid[i] = 0 191 } 192 for i := 0; i < dim; i++ { 193 vertex := vertices[i] 194 for j, v := range vertex { 195 centroid[j] += v 196 } 197 } 198 for i := range centroid { 199 centroid[i] /= float64(dim) 200 } 201 } 202 203 func (n *NelderMead) iterateLocal(loc *Location) (Operation, error) { 204 dim := len(loc.X) 205 switch n.lastIter { 206 case nmInitialize: 207 n.values[n.fillIdx] = loc.F 208 copy(n.vertices[n.fillIdx], loc.X) 209 n.fillIdx++ 210 if n.fillIdx == dim { 211 // Successfully finished building initial simplex. 212 sort.Sort(nmVertexSorter{n.vertices, n.values}) 213 computeCentroid(n.vertices, n.centroid) 214 return n.returnNext(nmMajor, loc) 215 } 216 copy(loc.X, n.vertices[dim]) 217 loc.X[n.fillIdx] += n.SimplexSize 218 return FuncEvaluation, nil 219 case nmMajor: 220 // Nelder Mead iterations start with Reflection step 221 return n.returnNext(nmReflected, loc) 222 case nmReflected: 223 n.reflectedValue = loc.F 224 switch { 225 case loc.F >= n.values[0] && loc.F < n.values[dim-1]: 226 n.replaceWorst(loc.X, loc.F) 227 return n.returnNext(nmMajor, loc) 228 case loc.F < n.values[0]: 229 return n.returnNext(nmExpanded, loc) 230 default: 231 if loc.F < n.values[dim] { 232 return n.returnNext(nmContractedOutside, loc) 233 } 234 return n.returnNext(nmContractedInside, loc) 235 } 236 case nmExpanded: 237 if loc.F < n.reflectedValue { 238 n.replaceWorst(loc.X, loc.F) 239 } else { 240 n.replaceWorst(n.reflectedPoint, n.reflectedValue) 241 } 242 return n.returnNext(nmMajor, loc) 243 case nmContractedOutside: 244 if loc.F <= n.reflectedValue { 245 n.replaceWorst(loc.X, loc.F) 246 return n.returnNext(nmMajor, loc) 247 } 248 n.fillIdx = 1 249 return n.returnNext(nmShrink, loc) 250 case nmContractedInside: 251 if loc.F < n.values[dim] { 252 n.replaceWorst(loc.X, loc.F) 253 return n.returnNext(nmMajor, loc) 254 } 255 n.fillIdx = 1 256 return n.returnNext(nmShrink, loc) 257 case nmShrink: 258 copy(n.vertices[n.fillIdx], loc.X) 259 n.values[n.fillIdx] = loc.F 260 n.fillIdx++ 261 if n.fillIdx != dim+1 { 262 return n.returnNext(nmShrink, loc) 263 } 264 sort.Sort(nmVertexSorter{n.vertices, n.values}) 265 computeCentroid(n.vertices, n.centroid) 266 return n.returnNext(nmMajor, loc) 267 default: 268 panic("unreachable") 269 } 270 } 271 272 // returnNext updates the location based on the iteration type and the current 273 // simplex, and returns the next operation. 274 func (n *NelderMead) returnNext(iter nmIterType, loc *Location) (Operation, error) { 275 n.lastIter = iter 276 switch iter { 277 case nmMajor: 278 // Fill loc with the current best point and value, 279 // and command a convergence check. 280 copy(loc.X, n.vertices[0]) 281 loc.F = n.values[0] 282 return MajorIteration, nil 283 case nmReflected, nmExpanded, nmContractedOutside, nmContractedInside: 284 // x_new = x_centroid + scale * (x_centroid - x_worst) 285 var scale float64 286 switch iter { 287 case nmReflected: 288 scale = n.reflection 289 case nmExpanded: 290 scale = n.reflection * n.expansion 291 case nmContractedOutside: 292 scale = n.reflection * n.contraction 293 case nmContractedInside: 294 scale = -n.contraction 295 } 296 dim := len(loc.X) 297 floats.SubTo(loc.X, n.centroid, n.vertices[dim]) 298 floats.Scale(scale, loc.X) 299 floats.Add(loc.X, n.centroid) 300 if iter == nmReflected { 301 copy(n.reflectedPoint, loc.X) 302 } 303 return FuncEvaluation, nil 304 case nmShrink: 305 // x_shrink = x_best + delta * (x_i + x_best) 306 floats.SubTo(loc.X, n.vertices[n.fillIdx], n.vertices[0]) 307 floats.Scale(n.shrink, loc.X) 308 floats.Add(loc.X, n.vertices[0]) 309 return FuncEvaluation, nil 310 default: 311 panic("unreachable") 312 } 313 } 314 315 // replaceWorst removes the worst location in the simplex and adds the new 316 // {x, f} pair maintaining sorting. 317 func (n *NelderMead) replaceWorst(x []float64, f float64) { 318 dim := len(x) 319 if f >= n.values[dim] { 320 panic("increase in simplex value") 321 } 322 copy(n.vertices[dim], x) 323 n.values[dim] = f 324 325 // Sort the newly-added value. 326 for i := dim - 1; i >= 0; i-- { 327 if n.values[i] < f { 328 break 329 } 330 n.vertices[i], n.vertices[i+1] = n.vertices[i+1], n.vertices[i] 331 n.values[i], n.values[i+1] = n.values[i+1], n.values[i] 332 } 333 334 // Update the location of the centroid. Only one point has been replaced, so 335 // subtract the worst point and add the new one. 336 floats.AddScaled(n.centroid, -1/float64(dim), n.vertices[dim]) 337 floats.AddScaled(n.centroid, 1/float64(dim), x) 338 } 339 340 func (*NelderMead) needs() struct { 341 Gradient bool 342 Hessian bool 343 } { 344 return struct { 345 Gradient bool 346 Hessian bool 347 }{false, false} 348 }