gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/optimize/minimize.go

// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package optimize

import (
	"fmt"
	"math"
	"time"

	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/mat"
)

const (
	nonpositiveDimension string = "optimize: non-positive input dimension"
	negativeTasks        string = "optimize: negative input number of tasks"
)

// Task is a type to communicate between the Method and the outer
// calling routine.
type Task struct {
	ID int
	Op Operation
	*Location
}

// Location represents a location in the optimization procedure.
type Location struct {
	// X is the function input for the location.
	X []float64
	// F is the result of evaluating the function at X.
	F float64
	// Gradient holds the first-order partial derivatives
	// of the function at X.
	// The length of Gradient must match the length of X
	// or be zero. If the capacity of Gradient is less
	// than the length of X, a new slice will be allocated.
	Gradient []float64
	// Hessian holds the second-order partial derivatives
	// of the function at X.
	// The dimensions of Hessian must match the length of X
	// or Hessian must be nil or empty. If Hessian is nil,
	// a new mat.SymDense will be allocated; if it is empty,
	// it will be resized to match the length of X.
	Hessian *mat.SymDense
}

// Method is a type which can search for an optimum of an objective function.
type Method interface {
	// Init initializes the method for optimization. The inputs are
	// the problem dimension and the number of available concurrent tasks.
	//
	// Init returns the number of concurrent processes to use, which must be
	// less than or equal to tasks.
	Init(dim, tasks int) (concurrent int)
	// Run runs an optimization. The method sends Tasks on
	// the operation channel (for performing function evaluations, major
	// iterations, etc.). The result of the tasks will be returned on result.
	// See the documentation for Operation types for the possible operations.
	//
	// The caller of Run will signal the termination of the optimization
	// (i.e. convergence from user settings) by sending a task with a PostIteration
	// Op field on result. More tasks may still be sent on operation after this
	// occurs, but only MajorIteration operations will still be conducted
	// appropriately. Thus, it cannot be guaranteed that all Evaluations sent
	// on operation will be evaluated; however, if an Evaluation is started,
	// the results of that evaluation will be sent on result.
	//
	// The Method must read from the result channel until it is closed.
	// During this, the Method may want to send new MajorIteration(s) on
	// operation. The Method must then close operation and return from Run.
	// These steps must establish a "happens-before" relationship between result
	// being closed (externally) and Run closing operation, for example
	// by using a range loop to read from result even if no results are expected.
	//
	// The last parameter to Run is a slice of tasks with length equal to
	// the return from Init. Task has an ID field which may be
	// set and modified by Method, and must not be modified by the caller.
	// The first element of tasks contains information about the initial location.
	// The Location.X field is always valid. The Operation field specifies which
	// other values of Location are known. If Operation == NoOperation, none of
	// the values should be used, otherwise the Evaluation operations will be
	// composed to specify the valid fields. Methods are free to use or
	// ignore these values.
	//
	// Successful execution of an Operation may require the Method to modify
	// fields of a Location. MajorIteration calls will not modify the values in
	// the Location, but Evaluation operations will. Methods are encouraged to
	// leave Location fields untouched to allow memory re-use. If a Method needs
	// to retain data from a Location, it should keep the reference and set the
	// respective field to nil -- Methods should not allocate Location memory
	// themselves.
	//
	// A Method may have its own specific convergence criteria, which can
	// be communicated using a MethodDone operation. This will trigger a
	// PostIteration to be sent on result, and the MethodDone task will not be
	// returned on result. The Method must implement Statuser, and the
	// call to Status must return a Status other than NotTerminated.
	//
	// The operation and result channels are guaranteed to have a buffer length
	// equal to the return from Init.
	Run(operation chan<- Task, result <-chan Task, tasks []Task)
	// Uses checks if the Method is suited to the optimization problem. The
	// input is the available functions in Problem to call, and the returns are
	// the functions which may be used and an error if there is a mismatch
	// between the Problem and the Method's capabilities.
	Uses(has Available) (uses Available, err error)
}
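
// What follows is an illustrative sketch, not part of the package API, of a
// Method that follows the Run protocol documented above. The type sketchMethod
// and its step-halving update rule are hypothetical; the sketch exists only to
// show the channel handshake: request evaluations, report MajorIterations,
// and, on PostIteration, drain result before closing operation.
type sketchMethod struct {
	step float64
}

func (m *sketchMethod) Init(dim, tasks int) int {
	if m.step == 0 {
		m.step = 1
	}
	return 1 // Use a single concurrent task.
}

func (m *sketchMethod) Uses(has Available) (Available, error) {
	return Available{}, nil // Only the objective function is used.
}

func (m *sketchMethod) Run(operation chan<- Task, result <-chan Task, tasks []Task) {
	task := tasks[0]
	task.Op = FuncEvaluation
	operation <- task
	for t := range result {
		switch t.Op {
		case PostIteration:
			// The caller has signaled termination: read result until it is
			// closed, establishing the required happens-before relationship,
			// then close operation and return.
			for range result {
			}
			close(operation)
			return
		case MajorIteration:
			// The previous point was recorded; propose the next candidate.
			t.X[0] += m.step
			m.step /= 2
			t.Op = FuncEvaluation
			operation <- t
		default:
			// An evaluation result; report it as a major iteration.
			t.Op = MajorIteration
			operation <- t
		}
	}
	close(operation)
}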

// Minimize uses an optimizer to search for a minimum of a function. A
// maximization problem can be transformed into a minimization problem by
// multiplying the function by -1.
//
// The first argument represents the problem to be minimized. Its fields are
// routines that evaluate the objective function, gradient, and other
// quantities related to the problem. The objective function, p.Func, must not
// be nil. The optimization method used may require other fields to be non-nil
// as specified by the method's Uses function. Minimize will panic if these
// are not met. The method can be determined automatically from the supplied
// problem, as described below.
//
// If p.Status is not nil, it is called before every evaluation. If the
// returned Status is other than NotTerminated or if the error is not nil, the
// optimization run is terminated.
//
// The second argument specifies the initial location for the optimization.
// Some Methods do not require an initial location, but initX must still be
// specified to fix the dimension of the optimization problem.
//
// The third argument contains the settings for the minimization. If settings
// is nil, the zero value will be used; see the documentation of the Settings
// type for more information, and see the warning below. All settings will be
// honored for all Methods, even if that setting is counter-productive to the
// method. Minimize cannot guarantee strict adherence to the evaluation bounds
// specified when performing concurrent evaluations and updates.
//
// The final argument is the optimization method to use. If method == nil, then
// an appropriate default is chosen based on the properties of the other arguments
// (dimension, gradient-free or gradient-based, etc.). If method is not nil,
// Minimize panics if the Problem is not consistent with the Method (Uses
// returns an error).
//
// Minimize returns a Result struct and any error that occurred. See the
// documentation of Result for more information.
//
// See the documentation for Method for the details on implementing a method.
//
// Be aware that the default settings of Minimize are to accurately find the
// minimum. For certain functions and optimization methods, this can take many
// function evaluations. The Settings input struct can be used to limit this,
// for example by modifying the maximum function evaluations or gradient tolerance.
func Minimize(p Problem, initX []float64, settings *Settings, method Method) (*Result, error) {
	startTime := time.Now()
	if method == nil {
		method = getDefaultMethod(&p)
	}
	if settings == nil {
		settings = &Settings{}
	}
	stats := &Stats{}
	dim := len(initX)
	err := checkOptimization(p, dim, settings.Recorder)
	if err != nil {
		return nil, err
	}

	optLoc := newLocation(dim) // This must have an allocated X field.
	optLoc.F = math.Inf(1)

	initOp, initLoc := getInitLocation(dim, initX, settings.InitValues)

	converger := settings.Converger
	if converger == nil {
		converger = defaultFunctionConverge()
	}
	converger.Init(dim)

	stats.Runtime = time.Since(startTime)

	// Send initial location to Recorder
	if settings.Recorder != nil {
		err = settings.Recorder.Record(optLoc, InitIteration, stats)
		if err != nil {
			return nil, err
		}
	}

	// Run optimization
	var status Status
	status, err = minimize(&p, method, settings, converger, stats, initOp, initLoc, optLoc, startTime)

	// Cleanup and collect results
	if settings.Recorder != nil && err == nil {
		err = settings.Recorder.Record(optLoc, PostIteration, stats)
	}
	stats.Runtime = time.Since(startTime)
	return &Result{
		Location: *optLoc,
		Stats:    *stats,
		Status:   status,
	}, err
}
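
// exampleMinimizeUsage is an illustrative sketch, not part of the package, of
// calling Minimize on the sphere function. The objective and starting point
// are hypothetical. Passing nil settings and a nil method selects the zero
// Settings value and a default Method (gradient-free here, since only p.Func
// is supplied).
func exampleMinimizeUsage() {
	p := Problem{
		Func: func(x []float64) float64 {
			var sum float64
			for _, v := range x {
				sum += v * v
			}
			return sum
		},
	}
	result, err := Minimize(p, []float64{1, 2, 3}, nil, nil)
	if err != nil {
		fmt.Println("minimization failed:", err)
		return
	}
	fmt.Println(result.Status, result.F, result.X)
}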

func getDefaultMethod(p *Problem) Method {
	if p.Grad != nil {
		return &LBFGS{}
	}
	return &NelderMead{}
}

// minimize performs an optimization. minimize updates stats and optLoc, and
// returns the final Status and error.
func minimize(prob *Problem, method Method, settings *Settings, converger Converger, stats *Stats, initOp Operation, initLoc, optLoc *Location, startTime time.Time) (Status, error) {
	dim := len(optLoc.X)
	nTasks := settings.Concurrent
	if nTasks == 0 {
		nTasks = 1
	}
	has := availFromProblem(*prob)
	_, initErr := method.Uses(has)
	if initErr != nil {
		panic(fmt.Sprintf("optimize: specified method inconsistent with Problem: %v", initErr))
	}
	newNTasks := method.Init(dim, nTasks)
	if newNTasks > nTasks {
		panic("optimize: too many tasks returned by Method")
	}
	nTasks = newNTasks

	// Launch the method. The method communicates tasks using the operations
	// channel, and results is used to return the evaluated results.
	operations := make(chan Task, nTasks)
	results := make(chan Task, nTasks)
	go func() {
		tasks := make([]Task, nTasks)
		tasks[0].Location = initLoc
		tasks[0].Op = initOp
		for i := 1; i < len(tasks); i++ {
			tasks[i].Location = newLocation(dim)
		}
		method.Run(operations, results, tasks)
	}()

	// Algorithmic Overview:
	// There are three pieces to performing a concurrent optimization,
	// the distributor, the workers, and the stats combiner. At a high level,
	// the distributor reads in tasks sent by method, sending evaluations to the
	// workers, and forwarding other operations to the stats combiner. The workers
	// read these forwarded evaluation tasks, evaluate the relevant parts of Problem
	// and forward the results on to the stats combiner. The stats combiner reads
	// in results from the workers, as well as tasks from the distributor, and
	// uses them to update optimization statistics (function evaluations, etc.)
	// and to check optimization convergence.
	//
	// The complicated part is correctly shutting down the optimization. The
	// procedure is as follows. First, the stats combiner closes done and sends
	// a PostIteration to the method. The distributor then reads that done has
	// been closed, and closes the channel with the workers. At this point, no
	// more evaluation operations will be executed. As the workers finish their
	// evaluations, they forward the results onto the stats combiner, and then
	// signal their shutdown to the stats combiner. When all workers have successfully
	// finished, the stats combiner closes the results channel, signaling to the
	// method that all results have been collected. At this point, the method
	// may send MajorIteration(s) to update an optimum location based on these
	// last returned results, and then the method will close the operations channel.
	// The Method must ensure that the closing of results happens before the
	// closing of operations in order to ensure proper shutdown order.
	// Now that no more tasks will be commanded by the method, the distributor
	// closes statsChan, and with no more statistics to update the optimization
	// concludes.

	workerChan := make(chan Task) // Delegate tasks to the workers.
	statsChan := make(chan Task)  // Send evaluation updates.
	done := make(chan struct{})   // Communicate the optimization is done.

	// Read tasks from the method and distribute as appropriate.
	distributor := func() {
		for {
			select {
			case task := <-operations:
				switch task.Op {
				case InitIteration:
					panic("optimize: Method returned InitIteration")
				case PostIteration:
					panic("optimize: Method returned PostIteration")
				case NoOperation, MajorIteration, MethodDone:
					statsChan <- task
				default:
					if !task.Op.isEvaluation() {
						panic("optimize: expecting evaluation operation")
					}
					workerChan <- task
				}
			case <-done:
				// No more evaluations will be sent, shut down the workers, and
				// read the final tasks.
				close(workerChan)
				for task := range operations {
					if task.Op == MajorIteration {
						statsChan <- task
					}
				}
				close(statsChan)
				return
			}
		}
	}
	go distributor()

	// Evaluate the Problem concurrently.
	worker := func() {
		x := make([]float64, dim)
		for task := range workerChan {
			evaluate(prob, task.Location, task.Op, x)
			statsChan <- task
		}
		// Signal successful worker completion.
		statsChan <- Task{Op: signalDone}
	}
	for i := 0; i < nTasks; i++ {
		go worker()
	}

	var (
		workersDone int // effective wait group for the workers
		status      Status
		err         error
		finalStatus Status
		finalError  error
	)

	// Update optimization statistics and check convergence.
	var methodDone bool
	for task := range statsChan {
		switch task.Op {
		default:
			if !task.Op.isEvaluation() {
				panic("optimize: evaluation task expected")
			}
			updateEvaluationStats(stats, task.Op)
			status, err = checkEvaluationLimits(prob, stats, settings)
		case signalDone:
			workersDone++
			if workersDone == nTasks {
				close(results)
			}
			continue
		case NoOperation:
			// Just send the task back.
		case MajorIteration:
			status = performMajorIteration(optLoc, task.Location, stats, converger, startTime, settings)
		case MethodDone:
			methodDone = true
			status = MethodConverge
		}
		if settings.Recorder != nil && status == NotTerminated && err == nil {
			stats.Runtime = time.Since(startTime)
			// Allow err to be overwritten if the Recorder fails.
			err = settings.Recorder.Record(task.Location, task.Op, stats)
			if err != nil {
				status = Failure
			}
		}
		// If this is the first termination status, trigger the conclusion of
		// the optimization.
		if status != NotTerminated || err != nil {
			select {
			case <-done:
			default:
				finalStatus = status
				finalError = err
				results <- Task{
					Op: PostIteration,
				}
				close(done)
			}
		}

		// Send the result back to the Method if there are still active workers.
		if workersDone != nTasks && task.Op != MethodDone {
			results <- task
		}
	}
	// This code block is here rather than above to ensure Status() is not called
	// before Method.Run closes operations.
	if methodDone {
		statuser, ok := method.(Statuser)
		if !ok {
			panic("optimize: method returned MethodDone but is not a Statuser")
		}
		finalStatus, finalError = statuser.Status()
		if finalStatus == NotTerminated {
			panic("optimize: method returned MethodDone but a NotTerminated status")
		}
	}
	return finalStatus, finalError
}

func defaultFunctionConverge() *FunctionConverge {
	return &FunctionConverge{
		Absolute:   1e-10,
		Iterations: 100,
	}
}
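
// exampleCustomConverger is an illustrative sketch, not part of the package,
// of overriding the default convergence test above via Settings.Converger.
// The tolerances chosen here are hypothetical.
func exampleCustomConverger() *Settings {
	return &Settings{
		Converger: &FunctionConverge{
			Absolute:   1e-6,
			Relative:   1e-8,
			Iterations: 50,
		},
	}
}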

// newLocation allocates a new location structure with an X field of the
// appropriate size.
func newLocation(dim int) *Location {
	return &Location{
		X: make([]float64, dim),
	}
}

// getInitLocation checks the validity of initX and initValues and returns
// the initial values as a *Location.
func getInitLocation(dim int, initX []float64, initValues *Location) (Operation, *Location) {
	loc := newLocation(dim)
	if initX == nil {
		if initValues != nil {
			panic("optimize: initValues is non-nil but no initial location specified")
		}
		return NoOperation, loc
	}
	copy(loc.X, initX)
	if initValues == nil {
		return NoOperation, loc
	}
	if initValues.X != nil {
		panic("optimize: location specified in InitValues (only use InitX)")
	}
	loc.F = initValues.F
	op := FuncEvaluation
	if initValues.Gradient != nil {
		if len(initValues.Gradient) != dim {
			panic("optimize: initial gradient does not match problem dimension")
		}
		loc.Gradient = initValues.Gradient
		op |= GradEvaluation
	}
	if initValues.Hessian != nil {
		if initValues.Hessian.SymmetricDim() != dim {
			panic("optimize: initial Hessian does not match problem dimension")
		}
		loc.Hessian = initValues.Hessian
		op |= HessEvaluation
	}
	return op, loc
}
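
// The following sketch, not part of the package, shows how initial values
// supplied through Settings.InitValues translate into the Operation returned
// by getInitLocation: with only F set, the returned op is FuncEvaluation,
// signaling that loc.F is already valid. The inputs are hypothetical.
func exampleInitValues() {
	op, loc := getInitLocation(2, []float64{1, 1}, &Location{F: 2})
	fmt.Println(op == FuncEvaluation, loc.F) // true 2
}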

func checkOptimization(p Problem, dim int, recorder Recorder) error {
	if p.Func == nil {
		panic(badProblem)
	}
	if dim <= 0 {
		panic(nonpositiveDimension)
	}
	if p.Status != nil {
		_, err := p.Status()
		if err != nil {
			return err
		}
	}
	if recorder != nil {
		err := recorder.Init()
		if err != nil {
			return err
		}
	}
	return nil
}

// evaluate evaluates the routines specified by the Operation at loc.X, and stores
// the answer into loc. loc.X is copied into x before evaluating in order to
// prevent the routines from modifying it.
func evaluate(p *Problem, loc *Location, op Operation, x []float64) {
	if !op.isEvaluation() {
		panic(fmt.Sprintf("optimize: invalid evaluation %v", op))
	}
	copy(x, loc.X)
	if op&FuncEvaluation != 0 {
		loc.F = p.Func(x)
	}
	if op&GradEvaluation != 0 {
		// Make sure we have a destination in which to place the gradient.
		if len(loc.Gradient) == 0 {
			if cap(loc.Gradient) < len(x) {
				loc.Gradient = make([]float64, len(x))
			} else {
				loc.Gradient = loc.Gradient[:len(x)]
			}
		}
		p.Grad(loc.Gradient, x)
	}
	if op&HessEvaluation != 0 {
		// Make sure we have a destination in which to place the Hessian.
		switch {
		case loc.Hessian == nil:
			loc.Hessian = mat.NewSymDense(len(x), nil)
		case loc.Hessian.IsEmpty():
			loc.Hessian.ReuseAsSym(len(x))
		}
		p.Hess(loc.Hessian, x)
	}
}
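
// The following sketch, not part of the package, shows how evaluation
// Operations compose as a bitmask: a single evaluate call can fill in both
// the function value and the gradient. The quadratic problem is hypothetical.
func exampleComposedEvaluation() {
	p := &Problem{
		Func: func(x []float64) float64 { return x[0] * x[0] },
		Grad: func(grad, x []float64) { grad[0] = 2 * x[0] },
	}
	loc := newLocation(1)
	loc.X[0] = 3
	x := make([]float64, 1)
	evaluate(p, loc, FuncEvaluation|GradEvaluation, x)
	fmt.Println(loc.F, loc.Gradient[0]) // 9 6
}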

// updateEvaluationStats updates the statistics based on the operation.
func updateEvaluationStats(stats *Stats, op Operation) {
	if op&FuncEvaluation != 0 {
		stats.FuncEvaluations++
	}
	if op&GradEvaluation != 0 {
		stats.GradEvaluations++
	}
	if op&HessEvaluation != 0 {
		stats.HessEvaluations++
	}
}

// checkLocationConvergence checks if the current optimal location satisfies
// any of the convergence criteria based on the function location.
//
// checkLocationConvergence returns NotTerminated if the Location does not satisfy
// the convergence criteria given by settings. Otherwise a corresponding status is
// returned.
// Unlike checkEvaluationLimits, checkLocationConvergence is called only at
// MajorIterations.
func checkLocationConvergence(loc *Location, settings *Settings, converger Converger) Status {
	if math.IsInf(loc.F, -1) {
		return FunctionNegativeInfinity
	}
	if loc.Gradient != nil && settings.GradientThreshold > 0 {
		norm := floats.Norm(loc.Gradient, math.Inf(1))
		if norm < settings.GradientThreshold {
			return GradientThreshold
		}
	}
	return converger.Converged(loc)
}

// checkEvaluationLimits checks the optimization limits after an evaluation
// Operation. It checks the number of evaluations (of various kinds) and checks
// the status of the Problem, if applicable.
func checkEvaluationLimits(p *Problem, stats *Stats, settings *Settings) (Status, error) {
	if p.Status != nil {
		status, err := p.Status()
		if err != nil || status != NotTerminated {
			return status, err
		}
	}
	if settings.FuncEvaluations > 0 && stats.FuncEvaluations >= settings.FuncEvaluations {
		return FunctionEvaluationLimit, nil
	}
	if settings.GradEvaluations > 0 && stats.GradEvaluations >= settings.GradEvaluations {
		return GradientEvaluationLimit, nil
	}
	if settings.HessEvaluations > 0 && stats.HessEvaluations >= settings.HessEvaluations {
		return HessianEvaluationLimit, nil
	}
	return NotTerminated, nil
}

// checkIterationLimits checks the limits on iterations affected by MajorIteration.
func checkIterationLimits(loc *Location, stats *Stats, settings *Settings) Status {
	if settings.MajorIterations > 0 && stats.MajorIterations >= settings.MajorIterations {
		return IterationLimit
	}
	if settings.Runtime > 0 && stats.Runtime >= settings.Runtime {
		return RuntimeLimit
	}
	return NotTerminated
}
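
// exampleLimitSettings is an illustrative sketch, not part of the package, of
// bounding the work an optimization may do via the limits checked above. The
// particular limits are hypothetical; note that with concurrent evaluations
// the evaluation counts may overshoot slightly, as documented on Minimize.
func exampleLimitSettings() *Settings {
	return &Settings{
		MajorIterations: 200,
		FuncEvaluations: 1000,
		Runtime:         time.Minute,
	}
}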

// performMajorIteration does all of the steps needed to perform a MajorIteration.
// It increments the iteration count, updates the optimal location, and checks
// the necessary convergence criteria.
func performMajorIteration(optLoc, loc *Location, stats *Stats, converger Converger, startTime time.Time, settings *Settings) Status {
	optLoc.F = loc.F
	copy(optLoc.X, loc.X)
	if loc.Gradient == nil {
		optLoc.Gradient = nil
	} else {
		if optLoc.Gradient == nil {
			optLoc.Gradient = make([]float64, len(loc.Gradient))
		}
		copy(optLoc.Gradient, loc.Gradient)
	}
	stats.MajorIterations++
	stats.Runtime = time.Since(startTime)
	status := checkLocationConvergence(optLoc, settings, converger)
	if status != NotTerminated {
		return status
	}
	return checkIterationLimits(optLoc, stats, settings)
}