github.com/gopherd/gonum@v0.0.4/optimize/minimize.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package optimize
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"time"
    11  
    12  	"github.com/gopherd/gonum/floats"
    13  	"github.com/gopherd/gonum/mat"
    14  )
    15  
    16  const (
    17  	nonpositiveDimension string = "optimize: non-positive input dimension"
    18  	negativeTasks        string = "optimize: negative input number of tasks"
    19  )
    20  
    21  func min(a, b int) int {
    22  	if a < b {
    23  		return a
    24  	}
    25  	return b
    26  }
    27  
    28  // Task is a type to communicate between the Method and the outer
    29  // calling script.
    30  type Task struct {
    31  	ID int
    32  	Op Operation
    33  	*Location
    34  }
    35  
    36  // Location represents a location in the optimization procedure.
    37  type Location struct {
    38  	// X is the function input for the location.
    39  	X []float64
    40  	// F is the result of evaluating the function at X.
    41  	F float64
    42  	// Gradient holds the first-order partial derivatives
    43  	// of the function at X.
    44  	// The length of Gradient must match the length of X
    45  	// or be zero. If the capacity of Gradient is less
    46  	// than the length of X, a new slice will be allocated.
    47  	Gradient []float64
    48  	// Hessian holds the second-order partial derivatives
    49  	// of the function at X.
    50  	// The dimensions of Hessian must match the length of X
    51  	// or Hessian must be nil or empty. If Hessian is nil
    52  	// a new mat.SymDense will be allocated, if it is empty
    53  	// it will be resized to match the length of X.
    54  	Hessian *mat.SymDense
    55  }
    56  
// Method is a type which can search for an optimum of an objective function.
type Method interface {
	// Init initializes the method for optimization. The inputs are
	// the problem dimension and number of available concurrent tasks.
	//
	// Init returns the number of concurrent processes to use, which must be
	// at least 1 (the first task carries the initial location) and less than
	// or equal to tasks.
	Init(dim, tasks int) (concurrent int)
	// Run runs an optimization. The method sends Tasks on
	// the operation channel (for performing function evaluations, major
	// iterations, etc.). The results of the tasks will be returned on result.
	// See the documentation for Operation types for the possible operations.
	//
	// The caller of Run will signal the termination of the optimization
	// (i.e. convergence from user settings) by sending a task with a PostIteration
	// Op field on result. More tasks may still be sent on operation after this
	// occurs, but only MajorIteration operations will still be conducted
	// appropriately. Thus, it cannot be guaranteed that all Evaluations sent
	// on operation will be evaluated; however, if an Evaluation is started,
	// the results of that evaluation will be sent on result.
	//
	// The Method must read from the result channel until it is closed.
	// During this, the Method may want to send new MajorIteration(s) on
	// operation. Method then must close operation, and return from Run.
	// These steps must establish a "happens-before" relationship between result
	// being closed (externally) and Run closing operation, for example
	// by using a range loop to read from result even if no results are expected.
	//
	// The last parameter to Run is a slice of tasks with length equal to
	// the return from Init. Task has an ID field which may be
	// set and modified by Method, and must not be modified by the caller.
	// The first element of tasks contains information about the initial location.
	// The Location.X field is always valid. The Op field specifies which
	// other values of Location are known. If Op == NoOperation, none of
	// the values should be used, otherwise the Evaluation operations will be
	// composed to specify the valid fields. Methods are free to use or
	// ignore these values.
	//
	// Successful execution of an Operation may require the Method to modify
	// fields of a Location. MajorIteration calls will not modify the values in
	// the Location, but Evaluation operations will. Methods are encouraged to
	// leave Location fields untouched to allow memory re-use. If data needs to
	// be stored, the respective field should be set to nil -- Methods should
	// not allocate Location memory themselves.
	//
	// Method may have its own specific convergence criteria, which can
	// be communicated using a MethodDone operation. This will trigger a
	// PostIteration to be sent on result, and the MethodDone task will not be
	// returned on result. The Method must implement Statuser, and the
	// call to Status must return a Status other than NotTerminated.
	//
	// The operation and result channels are guaranteed to have a buffer
	// length equal to the return from Init.
	Run(operation chan<- Task, result <-chan Task, tasks []Task)
	// Uses checks if the Method is suited to the optimization problem. The
	// input is the available functions in Problem to call, and the returns are
	// the functions which may be used and an error if there is a mismatch
	// between the Problem and the Method's capabilities.
	Uses(has Available) (uses Available, err error)
}
   117  
   118  // Minimize uses an optimizer to search for a minimum of a function. A
   119  // maximization problem can be transformed into a minimization problem by
   120  // multiplying the function by -1.
   121  //
   122  // The first argument represents the problem to be minimized. Its fields are
   123  // routines that evaluate the objective function, gradient, and other
   124  // quantities related to the problem. The objective function, p.Func, must not
   125  // be nil. The optimization method used may require other fields to be non-nil
   126  // as specified by method.Needs. Minimize will panic if these are not met. The
   127  // method can be determined automatically from the supplied problem which is
   128  // described below.
   129  //
   130  // If p.Status is not nil, it is called before every evaluation. If the
   131  // returned Status is other than NotTerminated or if the error is not nil, the
   132  // optimization run is terminated.
   133  //
   134  // The second argument specifies the initial location for the optimization.
   135  // Some Methods do not require an initial location, but initX must still be
   136  // specified for the dimension of the optimization problem.
   137  //
   138  // The third argument contains the settings for the minimization. If settings
   139  // is nil, the zero value will be used, see the documentation of the Settings
   140  // type for more information, and see the warning below. All settings will be
   141  // honored for all Methods, even if that setting is counter-productive to the
   142  // method. Minimize cannot guarantee strict adherence to the evaluation bounds
   143  // specified when performing concurrent evaluations and updates.
   144  //
   145  // The final argument is the optimization method to use. If method == nil, then
   146  // an appropriate default is chosen based on the properties of the other arguments
   147  // (dimension, gradient-free or gradient-based, etc.). If method is not nil,
   148  // Minimize panics if the Problem is not consistent with the Method (Uses
   149  // returns an error).
   150  //
   151  // Minimize returns a Result struct and any error that occurred. See the
   152  // documentation of Result for more information.
   153  //
   154  // See the documentation for Method for the details on implementing a method.
   155  //
   156  // Be aware that the default settings of Minimize are to accurately find the
   157  // minimum. For certain functions and optimization methods, this can take many
   158  // function evaluations. The Settings input struct can be used to limit this,
   159  // for example by modifying the maximum function evaluations or gradient tolerance.
   160  func Minimize(p Problem, initX []float64, settings *Settings, method Method) (*Result, error) {
   161  	startTime := time.Now()
   162  	if method == nil {
   163  		method = getDefaultMethod(&p)
   164  	}
   165  	if settings == nil {
   166  		settings = &Settings{}
   167  	}
   168  	stats := &Stats{}
   169  	dim := len(initX)
   170  	err := checkOptimization(p, dim, settings.Recorder)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  
   175  	optLoc := newLocation(dim) // This must have an allocated X field.
   176  	optLoc.F = math.Inf(1)
   177  
   178  	initOp, initLoc := getInitLocation(dim, initX, settings.InitValues)
   179  
   180  	converger := settings.Converger
   181  	if converger == nil {
   182  		converger = defaultFunctionConverge()
   183  	}
   184  	converger.Init(dim)
   185  
   186  	stats.Runtime = time.Since(startTime)
   187  
   188  	// Send initial location to Recorder
   189  	if settings.Recorder != nil {
   190  		err = settings.Recorder.Record(optLoc, InitIteration, stats)
   191  		if err != nil {
   192  			return nil, err
   193  		}
   194  	}
   195  
   196  	// Run optimization
   197  	var status Status
   198  	status, err = minimize(&p, method, settings, converger, stats, initOp, initLoc, optLoc, startTime)
   199  
   200  	// Cleanup and collect results
   201  	if settings.Recorder != nil && err == nil {
   202  		err = settings.Recorder.Record(optLoc, PostIteration, stats)
   203  	}
   204  	stats.Runtime = time.Since(startTime)
   205  	return &Result{
   206  		Location: *optLoc,
   207  		Stats:    *stats,
   208  		Status:   status,
   209  	}, err
   210  }
   211  
   212  func getDefaultMethod(p *Problem) Method {
   213  	if p.Grad != nil {
   214  		return &LBFGS{}
   215  	}
   216  	return &NelderMead{}
   217  }
   218  
   219  // minimize performs an optimization. minimize updates the settings and optLoc,
   220  // and returns the final Status and error.
   221  func minimize(prob *Problem, method Method, settings *Settings, converger Converger, stats *Stats, initOp Operation, initLoc, optLoc *Location, startTime time.Time) (Status, error) {
   222  	dim := len(optLoc.X)
   223  	nTasks := settings.Concurrent
   224  	if nTasks == 0 {
   225  		nTasks = 1
   226  	}
   227  	has := availFromProblem(*prob)
   228  	_, initErr := method.Uses(has)
   229  	if initErr != nil {
   230  		panic(fmt.Sprintf("optimize: specified method inconsistent with Problem: %v", initErr))
   231  	}
   232  	newNTasks := method.Init(dim, nTasks)
   233  	if newNTasks > nTasks {
   234  		panic("optimize: too many tasks returned by Method")
   235  	}
   236  	nTasks = newNTasks
   237  
   238  	// Launch the method. The method communicates tasks using the operations
   239  	// channel, and results is used to return the evaluated results.
   240  	operations := make(chan Task, nTasks)
   241  	results := make(chan Task, nTasks)
   242  	go func() {
   243  		tasks := make([]Task, nTasks)
   244  		tasks[0].Location = initLoc
   245  		tasks[0].Op = initOp
   246  		for i := 1; i < len(tasks); i++ {
   247  			tasks[i].Location = newLocation(dim)
   248  		}
   249  		method.Run(operations, results, tasks)
   250  	}()
   251  
   252  	// Algorithmic Overview:
   253  	// There are three pieces to performing a concurrent optimization,
   254  	// the distributor, the workers, and the stats combiner. At a high level,
   255  	// the distributor reads in tasks sent by method, sending evaluations to the
   256  	// workers, and forwarding other operations to the statsCombiner. The workers
   257  	// read these forwarded evaluation tasks, evaluate the relevant parts of Problem
   258  	// and forward the results on to the stats combiner. The stats combiner reads
   259  	// in results from the workers, as well as tasks from the distributor, and
   260  	// uses them to update optimization statistics (function evaluations, etc.)
   261  	// and to check optimization convergence.
   262  	//
   263  	// The complicated part is correctly shutting down the optimization. The
   264  	// procedure is as follows. First, the stats combiner closes done and sends
   265  	// a PostIteration to the method. The distributor then reads that done has
   266  	// been closed, and closes the channel with the workers. At this point, no
   267  	// more evaluation operations will be executed. As the workers finish their
   268  	// evaluations, they forward the results onto the stats combiner, and then
   269  	// signal their shutdown to the stats combiner. When all workers have successfully
   270  	// finished, the stats combiner closes the results channel, signaling to the
   271  	// method that all results have been collected. At this point, the method
   272  	// may send MajorIteration(s) to update an optimum location based on these
   273  	// last returned results, and then the method will close the operations channel.
   274  	// The Method must ensure that the closing of results happens before the
   275  	// closing of operations in order to ensure proper shutdown order.
   276  	// Now that no more tasks will be commanded by the method, the distributor
   277  	// closes statsChan, and with no more statistics to update the optimization
   278  	// concludes.
   279  
   280  	workerChan := make(chan Task) // Delegate tasks to the workers.
   281  	statsChan := make(chan Task)  // Send evaluation updates.
   282  	done := make(chan struct{})   // Communicate the optimization is done.
   283  
   284  	// Read tasks from the method and distribute as appropriate.
   285  	distributor := func() {
   286  		for {
   287  			select {
   288  			case task := <-operations:
   289  				switch task.Op {
   290  				case InitIteration:
   291  					panic("optimize: Method returned InitIteration")
   292  				case PostIteration:
   293  					panic("optimize: Method returned PostIteration")
   294  				case NoOperation, MajorIteration, MethodDone:
   295  					statsChan <- task
   296  				default:
   297  					if !task.Op.isEvaluation() {
   298  						panic("optimize: expecting evaluation operation")
   299  					}
   300  					workerChan <- task
   301  				}
   302  			case <-done:
   303  				// No more evaluations will be sent, shut down the workers, and
   304  				// read the final tasks.
   305  				close(workerChan)
   306  				for task := range operations {
   307  					if task.Op == MajorIteration {
   308  						statsChan <- task
   309  					}
   310  				}
   311  				close(statsChan)
   312  				return
   313  			}
   314  		}
   315  	}
   316  	go distributor()
   317  
   318  	// Evaluate the Problem concurrently.
   319  	worker := func() {
   320  		x := make([]float64, dim)
   321  		for task := range workerChan {
   322  			evaluate(prob, task.Location, task.Op, x)
   323  			statsChan <- task
   324  		}
   325  		// Signal successful worker completion.
   326  		statsChan <- Task{Op: signalDone}
   327  	}
   328  	for i := 0; i < nTasks; i++ {
   329  		go worker()
   330  	}
   331  
   332  	var (
   333  		workersDone int // effective wg for the workers
   334  		status      Status
   335  		err         error
   336  		finalStatus Status
   337  		finalError  error
   338  	)
   339  
   340  	// Update optimization statistics and check convergence.
   341  	var methodDone bool
   342  	for task := range statsChan {
   343  		switch task.Op {
   344  		default:
   345  			if !task.Op.isEvaluation() {
   346  				panic("minimize: evaluation task expected")
   347  			}
   348  			updateEvaluationStats(stats, task.Op)
   349  			status, err = checkEvaluationLimits(prob, stats, settings)
   350  		case signalDone:
   351  			workersDone++
   352  			if workersDone == nTasks {
   353  				close(results)
   354  			}
   355  			continue
   356  		case NoOperation:
   357  			// Just send the task back.
   358  		case MajorIteration:
   359  			status = performMajorIteration(optLoc, task.Location, stats, converger, startTime, settings)
   360  		case MethodDone:
   361  			methodDone = true
   362  			status = MethodConverge
   363  		}
   364  		if settings.Recorder != nil && status == NotTerminated && err == nil {
   365  			stats.Runtime = time.Since(startTime)
   366  			// Allow err to be overloaded if the Recorder fails.
   367  			err = settings.Recorder.Record(task.Location, task.Op, stats)
   368  			if err != nil {
   369  				status = Failure
   370  			}
   371  		}
   372  		// If this is the first termination status, trigger the conclusion of
   373  		// the optimization.
   374  		if status != NotTerminated || err != nil {
   375  			select {
   376  			case <-done:
   377  			default:
   378  				finalStatus = status
   379  				finalError = err
   380  				results <- Task{
   381  					Op: PostIteration,
   382  				}
   383  				close(done)
   384  			}
   385  		}
   386  
   387  		// Send the result back to the Problem if there are still active workers.
   388  		if workersDone != nTasks && task.Op != MethodDone {
   389  			results <- task
   390  		}
   391  	}
   392  	// This code block is here rather than above to ensure Status() is not called
   393  	// before Method.Run closes operations.
   394  	if methodDone {
   395  		statuser, ok := method.(Statuser)
   396  		if !ok {
   397  			panic("optimize: method returned MethodDone but is not a Statuser")
   398  		}
   399  		finalStatus, finalError = statuser.Status()
   400  		if finalStatus == NotTerminated {
   401  			panic("optimize: method returned MethodDone but a NotTerminated status")
   402  		}
   403  	}
   404  	return finalStatus, finalError
   405  }
   406  
   407  func defaultFunctionConverge() *FunctionConverge {
   408  	return &FunctionConverge{
   409  		Absolute:   1e-10,
   410  		Iterations: 100,
   411  	}
   412  }
   413  
   414  // newLocation allocates a new location structure with an X field of the
   415  // appropriate size.
   416  func newLocation(dim int) *Location {
   417  	return &Location{
   418  		X: make([]float64, dim),
   419  	}
   420  }
   421  
   422  // getInitLocation checks the validity of initLocation and initOperation and
   423  // returns the initial values as a *Location.
   424  func getInitLocation(dim int, initX []float64, initValues *Location) (Operation, *Location) {
   425  	loc := newLocation(dim)
   426  	if initX == nil {
   427  		if initValues != nil {
   428  			panic("optimize: initValues is non-nil but no initial location specified")
   429  		}
   430  		return NoOperation, loc
   431  	}
   432  	copy(loc.X, initX)
   433  	if initValues == nil {
   434  		return NoOperation, loc
   435  	} else {
   436  		if initValues.X != nil {
   437  			panic("optimize: location specified in InitValues (only use InitX)")
   438  		}
   439  	}
   440  	loc.F = initValues.F
   441  	op := FuncEvaluation
   442  	if initValues.Gradient != nil {
   443  		if len(initValues.Gradient) != dim {
   444  			panic("optimize: initial gradient does not match problem dimension")
   445  		}
   446  		loc.Gradient = initValues.Gradient
   447  		op |= GradEvaluation
   448  	}
   449  	if initValues.Hessian != nil {
   450  		if initValues.Hessian.SymmetricDim() != dim {
   451  			panic("optimize: initial Hessian does not match problem dimension")
   452  		}
   453  		loc.Hessian = initValues.Hessian
   454  		op |= HessEvaluation
   455  	}
   456  	return op, loc
   457  }
   458  
   459  func checkOptimization(p Problem, dim int, recorder Recorder) error {
   460  	if p.Func == nil {
   461  		panic(badProblem)
   462  	}
   463  	if dim <= 0 {
   464  		panic("optimize: impossible problem dimension")
   465  	}
   466  	if p.Status != nil {
   467  		_, err := p.Status()
   468  		if err != nil {
   469  			return err
   470  		}
   471  	}
   472  	if recorder != nil {
   473  		err := recorder.Init()
   474  		if err != nil {
   475  			return err
   476  		}
   477  	}
   478  	return nil
   479  }
   480  
   481  // evaluate evaluates the routines specified by the Operation at loc.X, and stores
   482  // the answer into loc. loc.X is copied into x before evaluating in order to
   483  // prevent the routines from modifying it.
   484  func evaluate(p *Problem, loc *Location, op Operation, x []float64) {
   485  	if !op.isEvaluation() {
   486  		panic(fmt.Sprintf("optimize: invalid evaluation %v", op))
   487  	}
   488  	copy(x, loc.X)
   489  	if op&FuncEvaluation != 0 {
   490  		loc.F = p.Func(x)
   491  	}
   492  	if op&GradEvaluation != 0 {
   493  		// Make sure we have a destination in which to place the gradient.
   494  		if len(loc.Gradient) == 0 {
   495  			if cap(loc.Gradient) < len(x) {
   496  				loc.Gradient = make([]float64, len(x))
   497  			} else {
   498  				loc.Gradient = loc.Gradient[:len(x)]
   499  			}
   500  		}
   501  		p.Grad(loc.Gradient, x)
   502  	}
   503  	if op&HessEvaluation != 0 {
   504  		// Make sure we have a destination in which to place the Hessian.
   505  		switch {
   506  		case loc.Hessian == nil:
   507  			loc.Hessian = mat.NewSymDense(len(x), nil)
   508  		case loc.Hessian.IsEmpty():
   509  			loc.Hessian.ReuseAsSym(len(x))
   510  		}
   511  		p.Hess(loc.Hessian, x)
   512  	}
   513  }
   514  
   515  // updateEvaluationStats updates the statistics based on the operation.
   516  func updateEvaluationStats(stats *Stats, op Operation) {
   517  	if op&FuncEvaluation != 0 {
   518  		stats.FuncEvaluations++
   519  	}
   520  	if op&GradEvaluation != 0 {
   521  		stats.GradEvaluations++
   522  	}
   523  	if op&HessEvaluation != 0 {
   524  		stats.HessEvaluations++
   525  	}
   526  }
   527  
   528  // checkLocationConvergence checks if the current optimal location satisfies
   529  // any of the convergence criteria based on the function location.
   530  //
   531  // checkLocationConvergence returns NotTerminated if the Location does not satisfy
   532  // the convergence criteria given by settings. Otherwise a corresponding status is
   533  // returned.
   534  // Unlike checkLimits, checkConvergence is called only at MajorIterations.
   535  func checkLocationConvergence(loc *Location, settings *Settings, converger Converger) Status {
   536  	if math.IsInf(loc.F, -1) {
   537  		return FunctionNegativeInfinity
   538  	}
   539  	if loc.Gradient != nil && settings.GradientThreshold > 0 {
   540  		norm := floats.Norm(loc.Gradient, math.Inf(1))
   541  		if norm < settings.GradientThreshold {
   542  			return GradientThreshold
   543  		}
   544  	}
   545  	return converger.Converged(loc)
   546  }
   547  
   548  // checkEvaluationLimits checks the optimization limits after an evaluation
   549  // Operation. It checks the number of evaluations (of various kinds) and checks
   550  // the status of the Problem, if applicable.
   551  func checkEvaluationLimits(p *Problem, stats *Stats, settings *Settings) (Status, error) {
   552  	if p.Status != nil {
   553  		status, err := p.Status()
   554  		if err != nil || status != NotTerminated {
   555  			return status, err
   556  		}
   557  	}
   558  	if settings.FuncEvaluations > 0 && stats.FuncEvaluations >= settings.FuncEvaluations {
   559  		return FunctionEvaluationLimit, nil
   560  	}
   561  	if settings.GradEvaluations > 0 && stats.GradEvaluations >= settings.GradEvaluations {
   562  		return GradientEvaluationLimit, nil
   563  	}
   564  	if settings.HessEvaluations > 0 && stats.HessEvaluations >= settings.HessEvaluations {
   565  		return HessianEvaluationLimit, nil
   566  	}
   567  	return NotTerminated, nil
   568  }
   569  
   570  // checkIterationLimits checks the limits on iterations affected by MajorIteration.
   571  func checkIterationLimits(loc *Location, stats *Stats, settings *Settings) Status {
   572  	if settings.MajorIterations > 0 && stats.MajorIterations >= settings.MajorIterations {
   573  		return IterationLimit
   574  	}
   575  	if settings.Runtime > 0 && stats.Runtime >= settings.Runtime {
   576  		return RuntimeLimit
   577  	}
   578  	return NotTerminated
   579  }
   580  
   581  // performMajorIteration does all of the steps needed to perform a MajorIteration.
   582  // It increments the iteration count, updates the optimal location, and checks
   583  // the necessary convergence criteria.
   584  func performMajorIteration(optLoc, loc *Location, stats *Stats, converger Converger, startTime time.Time, settings *Settings) Status {
   585  	optLoc.F = loc.F
   586  	copy(optLoc.X, loc.X)
   587  	if loc.Gradient == nil {
   588  		optLoc.Gradient = nil
   589  	} else {
   590  		if optLoc.Gradient == nil {
   591  			optLoc.Gradient = make([]float64, len(loc.Gradient))
   592  		}
   593  		copy(optLoc.Gradient, loc.Gradient)
   594  	}
   595  	stats.MajorIterations++
   596  	stats.Runtime = time.Since(startTime)
   597  	status := checkLocationConvergence(optLoc, settings, converger)
   598  	if status != NotTerminated {
   599  		return status
   600  	}
   601  	return checkIterationLimits(optLoc, stats, settings)
   602  }