github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/optimize/local.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package optimize 6 7 import ( 8 "math" 9 10 "github.com/jingcheng-WU/gonum/floats" 11 ) 12 13 // localOptimizer is a helper type for running an optimization using a LocalMethod. 14 type localOptimizer struct{} 15 16 // run controls the optimization run for a localMethod. The calling method 17 // must close the operation channel at the conclusion of the optimization. This 18 // provides a happens before relationship between the return of status and the 19 // closure of operation, and thus a call to method.Status (if necessary). 20 func (l localOptimizer) run(method localMethod, gradThresh float64, operation chan<- Task, result <-chan Task, tasks []Task) (Status, error) { 21 // Local methods start with a fully-specified initial location. 22 task := tasks[0] 23 task = l.initialLocation(operation, result, task, method) 24 if task.Op == PostIteration { 25 l.finish(operation, result) 26 return NotTerminated, nil 27 } 28 status, err := l.checkStartingLocation(task, gradThresh) 29 if err != nil { 30 l.finishMethodDone(operation, result, task) 31 return status, err 32 } 33 34 // Send a major iteration with the starting location. 35 task.Op = MajorIteration 36 operation <- task 37 task = <-result 38 if task.Op == PostIteration { 39 l.finish(operation, result) 40 return NotTerminated, nil 41 } 42 op, err := method.initLocal(task.Location) 43 if err != nil { 44 l.finishMethodDone(operation, result, task) 45 return Failure, err 46 } 47 task.Op = op 48 operation <- task 49 Loop: 50 for { 51 r := <-result 52 switch r.Op { 53 case PostIteration: 54 break Loop 55 case MajorIteration: 56 // The last operation was a MajorIteration. Check if the gradient 57 // is below the threshold. 58 if status := l.checkGradientConvergence(r.Gradient, gradThresh); status != NotTerminated { 59 l.finishMethodDone(operation, result, task) 60 return GradientThreshold, nil 61 } 62 fallthrough 63 default: 64 op, err := method.iterateLocal(r.Location) 65 if err != nil { 66 l.finishMethodDone(operation, result, r) 67 return Failure, err 68 } 69 r.Op = op 70 operation <- r 71 } 72 } 73 l.finish(operation, result) 74 return NotTerminated, nil 75 } 76 77 // initialOperation returns the Operation needed to fill the initial location 78 // based on the needs of the method and the values already supplied. 79 func (localOptimizer) initialOperation(task Task, n needser) Operation { 80 var newOp Operation 81 op := task.Op 82 if op&FuncEvaluation == 0 { 83 newOp |= FuncEvaluation 84 } 85 needs := n.needs() 86 if needs.Gradient && op&GradEvaluation == 0 { 87 newOp |= GradEvaluation 88 } 89 if needs.Hessian && op&HessEvaluation == 0 { 90 newOp |= HessEvaluation 91 } 92 return newOp 93 } 94 95 // initialLocation fills the initial location based on the needs of the method. 96 // The task passed to initialLocation should be the first task sent in RunGlobal. 97 func (l localOptimizer) initialLocation(operation chan<- Task, result <-chan Task, task Task, needs needser) Task { 98 task.Op = l.initialOperation(task, needs) 99 operation <- task 100 return <-result 101 } 102 103 func (l localOptimizer) checkStartingLocation(task Task, gradThresh float64) (Status, error) { 104 if math.IsInf(task.F, 1) || math.IsNaN(task.F) { 105 return Failure, ErrFunc(task.F) 106 } 107 for i, v := range task.Gradient { 108 if math.IsInf(v, 0) || math.IsNaN(v) { 109 return Failure, ErrGrad{Grad: v, Index: i} 110 } 111 } 112 status := l.checkGradientConvergence(task.Gradient, gradThresh) 113 return status, nil 114 } 115 116 func (localOptimizer) checkGradientConvergence(gradient []float64, gradThresh float64) Status { 117 if gradient == nil || math.IsNaN(gradThresh) { 118 return NotTerminated 119 } 120 if gradThresh == 0 { 121 gradThresh = defaultGradientAbsTol 122 } 123 if norm := floats.Norm(gradient, math.Inf(1)); norm < gradThresh { 124 return GradientThreshold 125 } 126 return NotTerminated 127 } 128 129 // finish completes the channel operations to finish an optimization. 130 func (localOptimizer) finish(operation chan<- Task, result <-chan Task) { 131 // Guarantee that result is closed before operation is closed. 132 for range result { 133 } 134 } 135 136 // finishMethodDone sends a MethodDone signal on operation, reads the result, 137 // and completes the channel operations to finish an optimization. 138 func (l localOptimizer) finishMethodDone(operation chan<- Task, result <-chan Task, task Task) { 139 task.Op = MethodDone 140 operation <- task 141 task = <-result 142 if task.Op != PostIteration { 143 panic("optimize: task should have returned post iteration") 144 } 145 l.finish(operation, result) 146 }