github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/optimize/local.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package optimize
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/jingcheng-WU/gonum/floats"
    11  )
    12  
// localOptimizer is a helper type for running an optimization using a LocalMethod.
// It is stateless (an empty struct); all coordination happens through the
// operation/result channels passed to its methods.
type localOptimizer struct{}
    15  
    16  // run controls the optimization run for a localMethod. The calling method
    17  // must close the operation channel at the conclusion of the optimization. This
    18  // provides a happens before relationship between the return of status and the
    19  // closure of operation, and thus a call to method.Status (if necessary).
    20  func (l localOptimizer) run(method localMethod, gradThresh float64, operation chan<- Task, result <-chan Task, tasks []Task) (Status, error) {
    21  	// Local methods start with a fully-specified initial location.
    22  	task := tasks[0]
    23  	task = l.initialLocation(operation, result, task, method)
    24  	if task.Op == PostIteration {
    25  		l.finish(operation, result)
    26  		return NotTerminated, nil
    27  	}
    28  	status, err := l.checkStartingLocation(task, gradThresh)
    29  	if err != nil {
    30  		l.finishMethodDone(operation, result, task)
    31  		return status, err
    32  	}
    33  
    34  	// Send a major iteration with the starting location.
    35  	task.Op = MajorIteration
    36  	operation <- task
    37  	task = <-result
    38  	if task.Op == PostIteration {
    39  		l.finish(operation, result)
    40  		return NotTerminated, nil
    41  	}
    42  	op, err := method.initLocal(task.Location)
    43  	if err != nil {
    44  		l.finishMethodDone(operation, result, task)
    45  		return Failure, err
    46  	}
    47  	task.Op = op
    48  	operation <- task
    49  Loop:
    50  	for {
    51  		r := <-result
    52  		switch r.Op {
    53  		case PostIteration:
    54  			break Loop
    55  		case MajorIteration:
    56  			// The last operation was a MajorIteration. Check if the gradient
    57  			// is below the threshold.
    58  			if status := l.checkGradientConvergence(r.Gradient, gradThresh); status != NotTerminated {
    59  				l.finishMethodDone(operation, result, task)
    60  				return GradientThreshold, nil
    61  			}
    62  			fallthrough
    63  		default:
    64  			op, err := method.iterateLocal(r.Location)
    65  			if err != nil {
    66  				l.finishMethodDone(operation, result, r)
    67  				return Failure, err
    68  			}
    69  			r.Op = op
    70  			operation <- r
    71  		}
    72  	}
    73  	l.finish(operation, result)
    74  	return NotTerminated, nil
    75  }
    76  
    77  // initialOperation returns the Operation needed to fill the initial location
    78  // based on the needs of the method and the values already supplied.
    79  func (localOptimizer) initialOperation(task Task, n needser) Operation {
    80  	var newOp Operation
    81  	op := task.Op
    82  	if op&FuncEvaluation == 0 {
    83  		newOp |= FuncEvaluation
    84  	}
    85  	needs := n.needs()
    86  	if needs.Gradient && op&GradEvaluation == 0 {
    87  		newOp |= GradEvaluation
    88  	}
    89  	if needs.Hessian && op&HessEvaluation == 0 {
    90  		newOp |= HessEvaluation
    91  	}
    92  	return newOp
    93  }
    94  
    95  // initialLocation fills the initial location based on the needs of the method.
    96  // The task passed to initialLocation should be the first task sent in RunGlobal.
    97  func (l localOptimizer) initialLocation(operation chan<- Task, result <-chan Task, task Task, needs needser) Task {
    98  	task.Op = l.initialOperation(task, needs)
    99  	operation <- task
   100  	return <-result
   101  }
   102  
   103  func (l localOptimizer) checkStartingLocation(task Task, gradThresh float64) (Status, error) {
   104  	if math.IsInf(task.F, 1) || math.IsNaN(task.F) {
   105  		return Failure, ErrFunc(task.F)
   106  	}
   107  	for i, v := range task.Gradient {
   108  		if math.IsInf(v, 0) || math.IsNaN(v) {
   109  			return Failure, ErrGrad{Grad: v, Index: i}
   110  		}
   111  	}
   112  	status := l.checkGradientConvergence(task.Gradient, gradThresh)
   113  	return status, nil
   114  }
   115  
   116  func (localOptimizer) checkGradientConvergence(gradient []float64, gradThresh float64) Status {
   117  	if gradient == nil || math.IsNaN(gradThresh) {
   118  		return NotTerminated
   119  	}
   120  	if gradThresh == 0 {
   121  		gradThresh = defaultGradientAbsTol
   122  	}
   123  	if norm := floats.Norm(gradient, math.Inf(1)); norm < gradThresh {
   124  		return GradientThreshold
   125  	}
   126  	return NotTerminated
   127  }
   128  
   129  // finish completes the channel operations to finish an optimization.
   130  func (localOptimizer) finish(operation chan<- Task, result <-chan Task) {
   131  	// Guarantee that result is closed before operation is closed.
   132  	for range result {
   133  	}
   134  }
   135  
   136  // finishMethodDone sends a MethodDone signal on operation, reads the result,
   137  // and completes the channel operations to finish an optimization.
   138  func (l localOptimizer) finishMethodDone(operation chan<- Task, result <-chan Task, task Task) {
   139  	task.Op = MethodDone
   140  	operation <- task
   141  	task = <-result
   142  	if task.Op != PostIteration {
   143  		panic("optimize: task should have returned post iteration")
   144  	}
   145  	l.finish(operation, result)
   146  }