github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/optimize/neldermead.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package optimize
     6  
     7  import (
     8  	"math"
     9  	"sort"
    10  
    11  	"github.com/jingcheng-WU/gonum/floats"
    12  )
    13  
    14  // nmIterType is a Nelder-Mead evaluation kind
    15  type nmIterType int
    16  
    17  const (
    18  	nmReflected = iota
    19  	nmExpanded
    20  	nmContractedInside
    21  	nmContractedOutside
    22  	nmInitialize
    23  	nmShrink
    24  	nmMajor
    25  )
    26  
    27  type nmVertexSorter struct {
    28  	vertices [][]float64
    29  	values   []float64
    30  }
    31  
    32  func (n nmVertexSorter) Len() int {
    33  	return len(n.values)
    34  }
    35  
    36  func (n nmVertexSorter) Less(i, j int) bool {
    37  	return n.values[i] < n.values[j]
    38  }
    39  
    40  func (n nmVertexSorter) Swap(i, j int) {
    41  	n.values[i], n.values[j] = n.values[j], n.values[i]
    42  	n.vertices[i], n.vertices[j] = n.vertices[j], n.vertices[i]
    43  }
    44  
    45  var _ Method = (*NelderMead)(nil)
    46  
    47  // NelderMead is an implementation of the Nelder-Mead simplex algorithm for
    48  // gradient-free nonlinear optimization (not to be confused with Danzig's
    49  // simplex algorithm for linear programming). The implementation follows the
    50  // algorithm described in
    51  //
    52  //  http://epubs.siam.org/doi/pdf/10.1137/S1052623496303470
    53  //
    54  // If an initial simplex is provided, it is used and initLoc is ignored. If
    55  // InitialVertices and InitialValues are both nil, an initial simplex will be
    56  // generated automatically using the initial location as one vertex, and each
    57  // additional vertex as SimplexSize away in one dimension.
    58  //
    59  // If the simplex update parameters (Reflection, etc.)
    60  // are zero, they will be set automatically based on the dimension according to
    61  // the recommendations in
    62  //
    63  //  http://www.webpages.uidaho.edu/~fuchang/res/ANMS.pdf
    64  type NelderMead struct {
    65  	InitialVertices [][]float64
    66  	InitialValues   []float64
    67  	Reflection      float64 // Reflection parameter (>0)
    68  	Expansion       float64 // Expansion parameter (>1)
    69  	Contraction     float64 // Contraction parameter (>0, <1)
    70  	Shrink          float64 // Shrink parameter (>0, <1)
    71  	SimplexSize     float64 // size of auto-constructed initial simplex
    72  
    73  	status Status
    74  	err    error
    75  
    76  	reflection  float64
    77  	expansion   float64
    78  	contraction float64
    79  	shrink      float64
    80  
    81  	vertices [][]float64 // location of the vertices sorted in ascending f
    82  	values   []float64   // function values at the vertices sorted in ascending f
    83  	centroid []float64   // centroid of all but the worst vertex
    84  
    85  	fillIdx        int        // index for filling the simplex during initialization and shrinking
    86  	lastIter       nmIterType // Last iteration
    87  	reflectedPoint []float64  // Storage of the reflected point location
    88  	reflectedValue float64    // Value at the last reflection point
    89  }
    90  
    91  func (n *NelderMead) Status() (Status, error) {
    92  	return n.status, n.err
    93  }
    94  
    95  func (*NelderMead) Uses(has Available) (uses Available, err error) {
    96  	return has.function()
    97  }
    98  
    99  func (n *NelderMead) Init(dim, tasks int) int {
   100  	n.status = NotTerminated
   101  	n.err = nil
   102  	return 1
   103  }
   104  
   105  func (n *NelderMead) Run(operation chan<- Task, result <-chan Task, tasks []Task) {
   106  	n.status, n.err = localOptimizer{}.run(n, math.NaN(), operation, result, tasks)
   107  	close(operation)
   108  }
   109  
   110  func (n *NelderMead) initLocal(loc *Location) (Operation, error) {
   111  	dim := len(loc.X)
   112  	if cap(n.vertices) < dim+1 {
   113  		n.vertices = make([][]float64, dim+1)
   114  	}
   115  	n.vertices = n.vertices[:dim+1]
   116  	for i := range n.vertices {
   117  		n.vertices[i] = resize(n.vertices[i], dim)
   118  	}
   119  	n.values = resize(n.values, dim+1)
   120  	n.centroid = resize(n.centroid, dim)
   121  	n.reflectedPoint = resize(n.reflectedPoint, dim)
   122  
   123  	if n.SimplexSize == 0 {
   124  		n.SimplexSize = 0.05
   125  	}
   126  
   127  	// Default parameter choices are chosen in a dimension-dependent way
   128  	// from http://www.webpages.uidaho.edu/~fuchang/res/ANMS.pdf
   129  	n.reflection = n.Reflection
   130  	if n.reflection == 0 {
   131  		n.reflection = 1
   132  	}
   133  	n.expansion = n.Expansion
   134  	if n.expansion == 0 {
   135  		n.expansion = 1 + 2/float64(dim)
   136  		if dim == 1 {
   137  			n.expansion = 2
   138  		}
   139  	}
   140  	n.contraction = n.Contraction
   141  	if n.contraction == 0 {
   142  		n.contraction = 0.75 - 1/(2*float64(dim))
   143  		if dim == 1 {
   144  			n.contraction = 0.5
   145  		}
   146  	}
   147  	n.shrink = n.Shrink
   148  	if n.shrink == 0 {
   149  		n.shrink = 1 - 1/float64(dim)
   150  		if dim == 1 {
   151  			n.shrink = 0.5
   152  		}
   153  	}
   154  
   155  	if n.InitialVertices != nil {
   156  		// Initial simplex provided. Copy the locations and values, and sort them.
   157  		if len(n.InitialVertices) != dim+1 {
   158  			panic("neldermead: incorrect number of vertices in initial simplex")
   159  		}
   160  		if len(n.InitialValues) != dim+1 {
   161  			panic("neldermead: incorrect number of values in initial simplex")
   162  		}
   163  		for i := range n.InitialVertices {
   164  			if len(n.InitialVertices[i]) != dim {
   165  				panic("neldermead: vertex size mismatch")
   166  			}
   167  			copy(n.vertices[i], n.InitialVertices[i])
   168  		}
   169  		copy(n.values, n.InitialValues)
   170  		sort.Sort(nmVertexSorter{n.vertices, n.values})
   171  		computeCentroid(n.vertices, n.centroid)
   172  		return n.returnNext(nmMajor, loc)
   173  	}
   174  
   175  	// No simplex provided. Begin initializing initial simplex. First simplex
   176  	// entry is the initial location, then step 1 in every direction.
   177  	copy(n.vertices[dim], loc.X)
   178  	n.values[dim] = loc.F
   179  	n.fillIdx = 0
   180  	loc.X[n.fillIdx] += n.SimplexSize
   181  	n.lastIter = nmInitialize
   182  	return FuncEvaluation, nil
   183  }
   184  
   185  // computeCentroid computes the centroid of all the simplex vertices except the
   186  // final one
   187  func computeCentroid(vertices [][]float64, centroid []float64) {
   188  	dim := len(centroid)
   189  	for i := range centroid {
   190  		centroid[i] = 0
   191  	}
   192  	for i := 0; i < dim; i++ {
   193  		vertex := vertices[i]
   194  		for j, v := range vertex {
   195  			centroid[j] += v
   196  		}
   197  	}
   198  	for i := range centroid {
   199  		centroid[i] /= float64(dim)
   200  	}
   201  }
   202  
   203  func (n *NelderMead) iterateLocal(loc *Location) (Operation, error) {
   204  	dim := len(loc.X)
   205  	switch n.lastIter {
   206  	case nmInitialize:
   207  		n.values[n.fillIdx] = loc.F
   208  		copy(n.vertices[n.fillIdx], loc.X)
   209  		n.fillIdx++
   210  		if n.fillIdx == dim {
   211  			// Successfully finished building initial simplex.
   212  			sort.Sort(nmVertexSorter{n.vertices, n.values})
   213  			computeCentroid(n.vertices, n.centroid)
   214  			return n.returnNext(nmMajor, loc)
   215  		}
   216  		copy(loc.X, n.vertices[dim])
   217  		loc.X[n.fillIdx] += n.SimplexSize
   218  		return FuncEvaluation, nil
   219  	case nmMajor:
   220  		// Nelder Mead iterations start with Reflection step
   221  		return n.returnNext(nmReflected, loc)
   222  	case nmReflected:
   223  		n.reflectedValue = loc.F
   224  		switch {
   225  		case loc.F >= n.values[0] && loc.F < n.values[dim-1]:
   226  			n.replaceWorst(loc.X, loc.F)
   227  			return n.returnNext(nmMajor, loc)
   228  		case loc.F < n.values[0]:
   229  			return n.returnNext(nmExpanded, loc)
   230  		default:
   231  			if loc.F < n.values[dim] {
   232  				return n.returnNext(nmContractedOutside, loc)
   233  			}
   234  			return n.returnNext(nmContractedInside, loc)
   235  		}
   236  	case nmExpanded:
   237  		if loc.F < n.reflectedValue {
   238  			n.replaceWorst(loc.X, loc.F)
   239  		} else {
   240  			n.replaceWorst(n.reflectedPoint, n.reflectedValue)
   241  		}
   242  		return n.returnNext(nmMajor, loc)
   243  	case nmContractedOutside:
   244  		if loc.F <= n.reflectedValue {
   245  			n.replaceWorst(loc.X, loc.F)
   246  			return n.returnNext(nmMajor, loc)
   247  		}
   248  		n.fillIdx = 1
   249  		return n.returnNext(nmShrink, loc)
   250  	case nmContractedInside:
   251  		if loc.F < n.values[dim] {
   252  			n.replaceWorst(loc.X, loc.F)
   253  			return n.returnNext(nmMajor, loc)
   254  		}
   255  		n.fillIdx = 1
   256  		return n.returnNext(nmShrink, loc)
   257  	case nmShrink:
   258  		copy(n.vertices[n.fillIdx], loc.X)
   259  		n.values[n.fillIdx] = loc.F
   260  		n.fillIdx++
   261  		if n.fillIdx != dim+1 {
   262  			return n.returnNext(nmShrink, loc)
   263  		}
   264  		sort.Sort(nmVertexSorter{n.vertices, n.values})
   265  		computeCentroid(n.vertices, n.centroid)
   266  		return n.returnNext(nmMajor, loc)
   267  	default:
   268  		panic("unreachable")
   269  	}
   270  }
   271  
   272  // returnNext updates the location based on the iteration type and the current
   273  // simplex, and returns the next operation.
   274  func (n *NelderMead) returnNext(iter nmIterType, loc *Location) (Operation, error) {
   275  	n.lastIter = iter
   276  	switch iter {
   277  	case nmMajor:
   278  		// Fill loc with the current best point and value,
   279  		// and command a convergence check.
   280  		copy(loc.X, n.vertices[0])
   281  		loc.F = n.values[0]
   282  		return MajorIteration, nil
   283  	case nmReflected, nmExpanded, nmContractedOutside, nmContractedInside:
   284  		// x_new = x_centroid + scale * (x_centroid - x_worst)
   285  		var scale float64
   286  		switch iter {
   287  		case nmReflected:
   288  			scale = n.reflection
   289  		case nmExpanded:
   290  			scale = n.reflection * n.expansion
   291  		case nmContractedOutside:
   292  			scale = n.reflection * n.contraction
   293  		case nmContractedInside:
   294  			scale = -n.contraction
   295  		}
   296  		dim := len(loc.X)
   297  		floats.SubTo(loc.X, n.centroid, n.vertices[dim])
   298  		floats.Scale(scale, loc.X)
   299  		floats.Add(loc.X, n.centroid)
   300  		if iter == nmReflected {
   301  			copy(n.reflectedPoint, loc.X)
   302  		}
   303  		return FuncEvaluation, nil
   304  	case nmShrink:
   305  		// x_shrink = x_best + delta * (x_i + x_best)
   306  		floats.SubTo(loc.X, n.vertices[n.fillIdx], n.vertices[0])
   307  		floats.Scale(n.shrink, loc.X)
   308  		floats.Add(loc.X, n.vertices[0])
   309  		return FuncEvaluation, nil
   310  	default:
   311  		panic("unreachable")
   312  	}
   313  }
   314  
   315  // replaceWorst removes the worst location in the simplex and adds the new
   316  // {x, f} pair maintaining sorting.
   317  func (n *NelderMead) replaceWorst(x []float64, f float64) {
   318  	dim := len(x)
   319  	if f >= n.values[dim] {
   320  		panic("increase in simplex value")
   321  	}
   322  	copy(n.vertices[dim], x)
   323  	n.values[dim] = f
   324  
   325  	// Sort the newly-added value.
   326  	for i := dim - 1; i >= 0; i-- {
   327  		if n.values[i] < f {
   328  			break
   329  		}
   330  		n.vertices[i], n.vertices[i+1] = n.vertices[i+1], n.vertices[i]
   331  		n.values[i], n.values[i+1] = n.values[i+1], n.values[i]
   332  	}
   333  
   334  	// Update the location of the centroid. Only one point has been replaced, so
   335  	// subtract the worst point and add the new one.
   336  	floats.AddScaled(n.centroid, -1/float64(dim), n.vertices[dim])
   337  	floats.AddScaled(n.centroid, 1/float64(dim), x)
   338  }
   339  
   340  func (*NelderMead) needs() struct {
   341  	Gradient bool
   342  	Hessian  bool
   343  } {
   344  	return struct {
   345  		Gradient bool
   346  		Hessian  bool
   347  	}{false, false}
   348  }