github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/optimize/backtracking.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package optimize
     6  
     7  const (
     8  	defaultBacktrackingContraction = 0.5
     9  	defaultBacktrackingDecrease    = 1e-4
    10  	minimumBacktrackingStepSize    = 1e-20
    11  )
    12  
    13  var _ Linesearcher = (*Backtracking)(nil)
    14  
    15  // Backtracking is a Linesearcher that uses backtracking to find a point that
    16  // satisfies the Armijo condition with the given decrease factor. If the Armijo
    17  // condition has not been met, the step size is decreased by ContractionFactor.
    18  //
    19  // The Armijo condition only requires the gradient at the beginning of each
    20  // major iteration (not at successive step locations), and so Backtracking may
    21  // be a good linesearch for functions with expensive gradients. Backtracking is
    22  // not appropriate for optimizers that require the Wolfe conditions to be met,
    23  // such as BFGS.
    24  //
    25  // Both DecreaseFactor and ContractionFactor must be between zero and one, and
    26  // Backtracking will panic otherwise. If either DecreaseFactor or
    27  // ContractionFactor are zero, it will be set to a reasonable default.
    28  type Backtracking struct {
    29  	DecreaseFactor    float64 // Constant factor in the sufficient decrease (Armijo) condition.
    30  	ContractionFactor float64 // Step size multiplier at each iteration (step *= ContractionFactor).
    31  
    32  	stepSize float64
    33  	initF    float64
    34  	initG    float64
    35  
    36  	lastOp Operation
    37  }
    38  
    39  func (b *Backtracking) Init(f, g float64, step float64) Operation {
    40  	if step <= 0 {
    41  		panic("backtracking: bad step size")
    42  	}
    43  	if g >= 0 {
    44  		panic("backtracking: initial derivative is non-negative")
    45  	}
    46  
    47  	if b.ContractionFactor == 0 {
    48  		b.ContractionFactor = defaultBacktrackingContraction
    49  	}
    50  	if b.DecreaseFactor == 0 {
    51  		b.DecreaseFactor = defaultBacktrackingDecrease
    52  	}
    53  	if b.ContractionFactor <= 0 || b.ContractionFactor >= 1 {
    54  		panic("backtracking: ContractionFactor must be between 0 and 1")
    55  	}
    56  	if b.DecreaseFactor <= 0 || b.DecreaseFactor >= 1 {
    57  		panic("backtracking: DecreaseFactor must be between 0 and 1")
    58  	}
    59  
    60  	b.stepSize = step
    61  	b.initF = f
    62  	b.initG = g
    63  
    64  	b.lastOp = FuncEvaluation
    65  	return b.lastOp
    66  }
    67  
    68  func (b *Backtracking) Iterate(f, _ float64) (Operation, float64, error) {
    69  	if b.lastOp != FuncEvaluation {
    70  		panic("backtracking: Init has not been called")
    71  	}
    72  
    73  	if ArmijoConditionMet(f, b.initF, b.initG, b.stepSize, b.DecreaseFactor) {
    74  		b.lastOp = MajorIteration
    75  		return b.lastOp, b.stepSize, nil
    76  	}
    77  	b.stepSize *= b.ContractionFactor
    78  	if b.stepSize < minimumBacktrackingStepSize {
    79  		b.lastOp = NoOperation
    80  		return b.lastOp, b.stepSize, ErrLinesearcherFailure
    81  	}
    82  	b.lastOp = FuncEvaluation
    83  	return b.lastOp, b.stepSize, nil
    84  }