github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/optimize/linesearch.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package optimize
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/jingcheng-WU/gonum/floats"
    11  )
    12  
// LinesearchMethod represents an abstract optimization method in which a
// function is optimized through successive line search optimizations.
// It couples a NextDirectioner, which proposes a search direction and an
// initial step for each major iteration, with a Linesearcher, which locates
// an acceptable step length along that direction.
type LinesearchMethod struct {
	// NextDirectioner specifies the search direction of each linesearch.
	NextDirectioner NextDirectioner
	// Linesearcher performs a linesearch along the search direction.
	Linesearcher Linesearcher

	x   []float64 // Starting point for the current iteration.
	dir []float64 // Search direction for the current iteration.

	first     bool      // Indicator of the first iteration.
	nextMajor bool      // Indicates that MajorIteration must be commanded at the next call to Iterate.
	eval      Operation // Indicator of valid fields in Location.

	lastStep float64   // Step taken from x in the previous call to Iterate.
	lastOp   Operation // Operation returned from the previous call to Iterate.
}
    31  
    32  func (ls *LinesearchMethod) Init(loc *Location) (Operation, error) {
    33  	if loc.Gradient == nil {
    34  		panic("linesearch: gradient is nil")
    35  	}
    36  
    37  	dim := len(loc.X)
    38  	ls.x = resize(ls.x, dim)
    39  	ls.dir = resize(ls.dir, dim)
    40  
    41  	ls.first = true
    42  	ls.nextMajor = false
    43  
    44  	// Indicate that all fields of loc are valid.
    45  	ls.eval = FuncEvaluation | GradEvaluation
    46  	if loc.Hessian != nil {
    47  		ls.eval |= HessEvaluation
    48  	}
    49  
    50  	ls.lastStep = math.NaN()
    51  	ls.lastOp = NoOperation
    52  
    53  	return ls.initNextLinesearch(loc)
    54  }
    55  
// Iterate advances the optimization by one step: it either continues the
// current linesearch, starts the next one, or announces a MajorIteration.
// ls.lastOp records what was requested of the caller on the previous call,
// and ls.eval accumulates which fields of loc currently hold valid data.
func (ls *LinesearchMethod) Iterate(loc *Location) (Operation, error) {
	switch ls.lastOp {
	case NoOperation:
		// TODO(vladimir-ch): Either Init has not been called, or the caller is
		// trying to resume the optimization run after Iterate previously
		// returned with an error. Decide what is the proper thing to do. See also #125.

	case MajorIteration:
		// The previous updated location did not converge the full
		// optimization. Initialize a new Linesearch.
		return ls.initNextLinesearch(loc)

	default:
		// Update the indicator of valid fields of loc.
		ls.eval |= ls.lastOp

		if ls.nextMajor {
			ls.nextMajor = false

			// Linesearcher previously finished, and the invalid fields of loc
			// have now been validated. Announce MajorIteration.
			ls.lastOp = MajorIteration
			return ls.lastOp, nil
		}
	}

	// Continue the linesearch.

	// Collect the objective value and projected gradient at loc; fields that
	// have not been evaluated yet are passed to the Linesearcher as NaN.
	f := math.NaN()
	if ls.eval&FuncEvaluation != 0 {
		f = loc.F
	}
	projGrad := math.NaN()
	if ls.eval&GradEvaluation != 0 {
		projGrad = floats.Dot(loc.Gradient, ls.dir)
	}
	op, step, err := ls.Linesearcher.Iterate(f, projGrad)
	if err != nil {
		return ls.error(err)
	}

	switch op {
	case MajorIteration:
		// Linesearch has been finished.

		// Request any evaluations still missing from loc before the
		// MajorIteration can be announced.
		ls.lastOp = complementEval(loc, ls.eval)
		if ls.lastOp == NoOperation {
			// loc is complete, MajorIteration can be declared directly.
			ls.lastOp = MajorIteration
		} else {
			// Declare MajorIteration on the next call to Iterate.
			ls.nextMajor = true
		}

	case FuncEvaluation, GradEvaluation, FuncEvaluation | GradEvaluation:
		if step != ls.lastStep {
			// We are moving to a new location, and not, say, evaluating extra
			// information at the current location.

			// Compute the next evaluation point and store it in loc.X.
			floats.AddScaledTo(loc.X, ls.x, step, ls.dir)
			if floats.Equal(ls.x, loc.X) {
				// Step size has become so small that the next evaluation point is
				// indistinguishable from the starting point for the current
				// iteration due to rounding errors.
				return ls.error(ErrNoProgress)
			}
			ls.lastStep = step
			ls.eval = NoOperation // Indicate all invalid fields of loc.
		}
		ls.lastOp = op

	default:
		panic("linesearch: Linesearcher returned invalid operation")
	}

	return ls.lastOp, nil
}
   134  
   135  func (ls *LinesearchMethod) error(err error) (Operation, error) {
   136  	ls.lastOp = NoOperation
   137  	return ls.lastOp, err
   138  }
   139  
   140  // initNextLinesearch initializes the next linesearch using the previous
   141  // complete location stored in loc. It fills loc.X and returns an evaluation
   142  // to be performed at loc.X.
   143  func (ls *LinesearchMethod) initNextLinesearch(loc *Location) (Operation, error) {
   144  	copy(ls.x, loc.X)
   145  
   146  	var step float64
   147  	if ls.first {
   148  		ls.first = false
   149  		step = ls.NextDirectioner.InitDirection(loc, ls.dir)
   150  	} else {
   151  		step = ls.NextDirectioner.NextDirection(loc, ls.dir)
   152  	}
   153  
   154  	projGrad := floats.Dot(loc.Gradient, ls.dir)
   155  	if projGrad >= 0 {
   156  		return ls.error(ErrNonDescentDirection)
   157  	}
   158  
   159  	op := ls.Linesearcher.Init(loc.F, projGrad, step)
   160  	switch op {
   161  	case FuncEvaluation, GradEvaluation, FuncEvaluation | GradEvaluation:
   162  	default:
   163  		panic("linesearch: Linesearcher returned invalid operation")
   164  	}
   165  
   166  	floats.AddScaledTo(loc.X, ls.x, step, ls.dir)
   167  	if floats.Equal(ls.x, loc.X) {
   168  		// Step size is so small that the next evaluation point is
   169  		// indistinguishable from the starting point for the current iteration
   170  		// due to rounding errors.
   171  		return ls.error(ErrNoProgress)
   172  	}
   173  
   174  	ls.lastStep = step
   175  	ls.eval = NoOperation // Invalidate all fields of loc.
   176  
   177  	ls.lastOp = op
   178  	return ls.lastOp, nil
   179  }
   180  
   181  // ArmijoConditionMet returns true if the Armijo condition (aka sufficient
   182  // decrease) has been met. Under normal conditions, the following should be
   183  // true, though this is not enforced:
   184  //  - initGrad < 0
   185  //  - step > 0
   186  //  - 0 < decrease < 1
   187  func ArmijoConditionMet(currObj, initObj, initGrad, step, decrease float64) bool {
   188  	return currObj <= initObj+decrease*step*initGrad
   189  }
   190  
   191  // StrongWolfeConditionsMet returns true if the strong Wolfe conditions have been met.
   192  // The strong Wolfe conditions ensure sufficient decrease in the function
   193  // value, and sufficient decrease in the magnitude of the projected gradient.
   194  // Under normal conditions, the following should be true, though this is not
   195  // enforced:
   196  //  - initGrad < 0
   197  //  - step > 0
   198  //  - 0 <= decrease < curvature < 1
   199  func StrongWolfeConditionsMet(currObj, currGrad, initObj, initGrad, step, decrease, curvature float64) bool {
   200  	if currObj > initObj+decrease*step*initGrad {
   201  		return false
   202  	}
   203  	return math.Abs(currGrad) < curvature*math.Abs(initGrad)
   204  }
   205  
   206  // WeakWolfeConditionsMet returns true if the weak Wolfe conditions have been met.
   207  // The weak Wolfe conditions ensure sufficient decrease in the function value,
   208  // and sufficient decrease in the value of the projected gradient. Under normal
   209  // conditions, the following should be true, though this is not enforced:
   210  //  - initGrad < 0
   211  //  - step > 0
   212  //  - 0 <= decrease < curvature< 1
   213  func WeakWolfeConditionsMet(currObj, currGrad, initObj, initGrad, step, decrease, curvature float64) bool {
   214  	if currObj > initObj+decrease*step*initGrad {
   215  		return false
   216  	}
   217  	return currGrad >= curvature*initGrad
   218  }