github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/optimize/morethuente.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package optimize
     6  
     7  import "math"
     8  
     9  var _ Linesearcher = (*MoreThuente)(nil)
    10  
    11  // MoreThuente is a Linesearcher that finds steps that satisfy both the
    12  // sufficient decrease and curvature conditions (the strong Wolfe conditions).
    13  //
    14  // References:
    15  //  - More, J.J. and D.J. Thuente: Line Search Algorithms with Guaranteed Sufficient
    16  //    Decrease. ACM Transactions on Mathematical Software 20(3) (1994), 286-307
    17  type MoreThuente struct {
    18  	// DecreaseFactor is the constant factor in the sufficient decrease
    19  	// (Armijo) condition.
    20  	// It must be in the interval [0, 1). The default value is 0.
    21  	DecreaseFactor float64
    22  	// CurvatureFactor is the constant factor in the Wolfe conditions. Smaller
    23  	// values result in a more exact line search.
    24  	// A set value must be in the interval (0, 1). If it is zero, it will be
    25  	// defaulted to 0.9.
    26  	CurvatureFactor float64
    27  	// StepTolerance sets the minimum acceptable width for the linesearch
    28  	// interval. If the relative interval length is less than this value,
    29  	// ErrLinesearcherFailure is returned.
    30  	// It must be non-negative. If it is zero, it will be defaulted to 1e-10.
    31  	StepTolerance float64
    32  
    33  	// MinimumStep is the minimum step that the linesearcher will take.
    34  	// It must be non-negative and less than MaximumStep. Defaults to no
    35  	// minimum (a value of 0).
    36  	MinimumStep float64
    37  	// MaximumStep is the maximum step that the linesearcher will take.
    38  	// It must be greater than MinimumStep. If it is zero, it will be defaulted
    39  	// to 1e20.
    40  	MaximumStep float64
    41  
    42  	bracketed bool    // Indicates if a minimum has been bracketed.
    43  	fInit     float64 // Function value at step = 0.
    44  	gInit     float64 // Derivative value at step = 0.
    45  
    46  	// When stage is 1, the algorithm updates the interval given by x and y
    47  	// so that it contains a minimizer of the modified function
    48  	//  psi(step) = f(step) - f(0) - DecreaseFactor * step * f'(0).
    49  	// When stage is 2, the interval is updated so that it contains a minimizer
    50  	// of f.
    51  	stage int
    52  
    53  	step         float64    // Current step.
    54  	lower, upper float64    // Lower and upper bounds on the next step.
    55  	x            float64    // Endpoint of the interval with a lower function value.
    56  	fx, gx       float64    // Data at x.
    57  	y            float64    // The other endpoint.
    58  	fy, gy       float64    // Data at y.
    59  	width        [2]float64 // Width of the interval at two previous iterations.
    60  }
    61  
    62  const (
    63  	mtMinGrowthFactor float64 = 1.1
    64  	mtMaxGrowthFactor float64 = 4
    65  )
    66  
    67  func (mt *MoreThuente) Init(f, g float64, step float64) Operation {
    68  	// Based on the original Fortran code that is available, for example, from
    69  	//  http://ftp.mcs.anl.gov/pub/MINPACK-2/csrch/
    70  	// as part of
    71  	//  MINPACK-2 Project. November 1993.
    72  	//  Argonne National Laboratory and University of Minnesota.
    73  	//  Brett M. Averick, Richard G. Carter, and Jorge J. Moré.
    74  
    75  	if g >= 0 {
    76  		panic("morethuente: initial derivative is non-negative")
    77  	}
    78  	if step <= 0 {
    79  		panic("morethuente: invalid initial step")
    80  	}
    81  
    82  	if mt.CurvatureFactor == 0 {
    83  		mt.CurvatureFactor = 0.9
    84  	}
    85  	if mt.StepTolerance == 0 {
    86  		mt.StepTolerance = 1e-10
    87  	}
    88  	if mt.MaximumStep == 0 {
    89  		mt.MaximumStep = 1e20
    90  	}
    91  
    92  	if mt.MinimumStep < 0 {
    93  		panic("morethuente: minimum step is negative")
    94  	}
    95  	if mt.MaximumStep <= mt.MinimumStep {
    96  		panic("morethuente: maximum step is not greater than minimum step")
    97  	}
    98  	if mt.DecreaseFactor < 0 || mt.DecreaseFactor >= 1 {
    99  		panic("morethuente: invalid decrease factor")
   100  	}
   101  	if mt.CurvatureFactor <= 0 || mt.CurvatureFactor >= 1 {
   102  		panic("morethuente: invalid curvature factor")
   103  	}
   104  	if mt.StepTolerance <= 0 {
   105  		panic("morethuente: step tolerance is not positive")
   106  	}
   107  
   108  	if step < mt.MinimumStep {
   109  		step = mt.MinimumStep
   110  	}
   111  	if step > mt.MaximumStep {
   112  		step = mt.MaximumStep
   113  	}
   114  
   115  	mt.bracketed = false
   116  	mt.stage = 1
   117  	mt.fInit = f
   118  	mt.gInit = g
   119  
   120  	mt.x, mt.fx, mt.gx = 0, f, g
   121  	mt.y, mt.fy, mt.gy = 0, f, g
   122  
   123  	mt.lower = 0
   124  	mt.upper = step + mtMaxGrowthFactor*step
   125  
   126  	mt.width[0] = mt.MaximumStep - mt.MinimumStep
   127  	mt.width[1] = 2 * mt.width[0]
   128  
   129  	mt.step = step
   130  	return FuncEvaluation | GradEvaluation
   131  }
   132  
   133  func (mt *MoreThuente) Iterate(f, g float64) (Operation, float64, error) {
   134  	if mt.stage == 0 {
   135  		panic("morethuente: Init has not been called")
   136  	}
   137  
   138  	gTest := mt.DecreaseFactor * mt.gInit
   139  	fTest := mt.fInit + mt.step*gTest
   140  
   141  	if mt.bracketed {
   142  		if mt.step <= mt.lower || mt.step >= mt.upper || mt.upper-mt.lower <= mt.StepTolerance*mt.upper {
   143  			// step contains the best step found (see below).
   144  			return NoOperation, mt.step, ErrLinesearcherFailure
   145  		}
   146  	}
   147  	if mt.step == mt.MaximumStep && f <= fTest && g <= gTest {
   148  		return NoOperation, mt.step, ErrLinesearcherBound
   149  	}
   150  	if mt.step == mt.MinimumStep && (f > fTest || g >= gTest) {
   151  		return NoOperation, mt.step, ErrLinesearcherFailure
   152  	}
   153  
   154  	// Test for convergence.
   155  	if f <= fTest && math.Abs(g) <= mt.CurvatureFactor*(-mt.gInit) {
   156  		mt.stage = 0
   157  		return MajorIteration, mt.step, nil
   158  	}
   159  
   160  	if mt.stage == 1 && f <= fTest && g >= 0 {
   161  		mt.stage = 2
   162  	}
   163  
   164  	if mt.stage == 1 && f <= mt.fx && f > fTest {
   165  		// Lower function value but the decrease is not sufficient .
   166  
   167  		// Compute values and derivatives of the modified function at step, x, y.
   168  		fm := f - mt.step*gTest
   169  		fxm := mt.fx - mt.x*gTest
   170  		fym := mt.fy - mt.y*gTest
   171  		gm := g - gTest
   172  		gxm := mt.gx - gTest
   173  		gym := mt.gy - gTest
   174  		// Update x, y and step.
   175  		mt.nextStep(fxm, gxm, fym, gym, fm, gm)
   176  		// Recover values and derivates of the non-modified function at x and y.
   177  		mt.fx = fxm + mt.x*gTest
   178  		mt.fy = fym + mt.y*gTest
   179  		mt.gx = gxm + gTest
   180  		mt.gy = gym + gTest
   181  	} else {
   182  		// Update x, y and step.
   183  		mt.nextStep(mt.fx, mt.gx, mt.fy, mt.gy, f, g)
   184  	}
   185  
   186  	if mt.bracketed {
   187  		// Monitor the length of the bracketing interval. If the interval has
   188  		// not been reduced sufficiently after two steps, use bisection to
   189  		// force its length to zero.
   190  		width := mt.y - mt.x
   191  		if math.Abs(width) >= 2.0/3*mt.width[1] {
   192  			mt.step = mt.x + 0.5*width
   193  		}
   194  		mt.width[0], mt.width[1] = math.Abs(width), mt.width[0]
   195  	}
   196  
   197  	if mt.bracketed {
   198  		mt.lower = math.Min(mt.x, mt.y)
   199  		mt.upper = math.Max(mt.x, mt.y)
   200  	} else {
   201  		mt.lower = mt.step + mtMinGrowthFactor*(mt.step-mt.x)
   202  		mt.upper = mt.step + mtMaxGrowthFactor*(mt.step-mt.x)
   203  	}
   204  
   205  	// Force the step to be in [MinimumStep, MaximumStep].
   206  	mt.step = math.Max(mt.MinimumStep, math.Min(mt.step, mt.MaximumStep))
   207  
   208  	if mt.bracketed {
   209  		if mt.step <= mt.lower || mt.step >= mt.upper || mt.upper-mt.lower <= mt.StepTolerance*mt.upper {
   210  			// If further progress is not possible, set step to the best step
   211  			// obtained during the search.
   212  			mt.step = mt.x
   213  		}
   214  	}
   215  
   216  	return FuncEvaluation | GradEvaluation, mt.step, nil
   217  }
   218  
   219  // nextStep computes the next safeguarded step and updates the interval that
   220  // contains a step that satisfies the sufficient decrease and curvature
   221  // conditions.
   222  func (mt *MoreThuente) nextStep(fx, gx, fy, gy, f, g float64) {
   223  	x := mt.x
   224  	y := mt.y
   225  	step := mt.step
   226  
   227  	gNeg := g < 0
   228  	if gx < 0 {
   229  		gNeg = !gNeg
   230  	}
   231  
   232  	var next float64
   233  	var bracketed bool
   234  	switch {
   235  	case f > fx:
   236  		// A higher function value. The minimum is bracketed between x and step.
   237  		// We want the next step to be closer to x because the function value
   238  		// there is lower.
   239  
   240  		theta := 3*(fx-f)/(step-x) + gx + g
   241  		s := math.Max(math.Abs(gx), math.Abs(g))
   242  		s = math.Max(s, math.Abs(theta))
   243  		gamma := s * math.Sqrt((theta/s)*(theta/s)-(gx/s)*(g/s))
   244  		if step < x {
   245  			gamma *= -1
   246  		}
   247  		p := gamma - gx + theta
   248  		q := gamma - gx + gamma + g
   249  		r := p / q
   250  		stpc := x + r*(step-x)
   251  		stpq := x + gx/((fx-f)/(step-x)+gx)/2*(step-x)
   252  
   253  		if math.Abs(stpc-x) < math.Abs(stpq-x) {
   254  			// The cubic step is closer to x than the quadratic step.
   255  			// Take the cubic step.
   256  			next = stpc
   257  		} else {
   258  			// If f is much larger than fx, then the quadratic step may be too
   259  			// close to x. Therefore heuristically take the average of the
   260  			// cubic and quadratic steps.
   261  			next = stpc + (stpq-stpc)/2
   262  		}
   263  		bracketed = true
   264  
   265  	case gNeg:
   266  		// A lower function value and derivatives of opposite sign. The minimum
   267  		// is bracketed between x and step. If we choose a step that is far
   268  		// from step, the next iteration will also likely fall in this case.
   269  
   270  		theta := 3*(fx-f)/(step-x) + gx + g
   271  		s := math.Max(math.Abs(gx), math.Abs(g))
   272  		s = math.Max(s, math.Abs(theta))
   273  		gamma := s * math.Sqrt((theta/s)*(theta/s)-(gx/s)*(g/s))
   274  		if step > x {
   275  			gamma *= -1
   276  		}
   277  		p := gamma - g + theta
   278  		q := gamma - g + gamma + gx
   279  		r := p / q
   280  		stpc := step + r*(x-step)
   281  		stpq := step + g/(g-gx)*(x-step)
   282  
   283  		if math.Abs(stpc-step) > math.Abs(stpq-step) {
   284  			// The cubic step is farther from x than the quadratic step.
   285  			// Take the cubic step.
   286  			next = stpc
   287  		} else {
   288  			// Take the quadratic step.
   289  			next = stpq
   290  		}
   291  		bracketed = true
   292  
   293  	case math.Abs(g) < math.Abs(gx):
   294  		// A lower function value, derivatives of the same sign, and the
   295  		// magnitude of the derivative decreases. Extrapolate function values
   296  		// at x and step so that the next step lies between step and y.
   297  
   298  		theta := 3*(fx-f)/(step-x) + gx + g
   299  		s := math.Max(math.Abs(gx), math.Abs(g))
   300  		s = math.Max(s, math.Abs(theta))
   301  		gamma := s * math.Sqrt(math.Max(0, (theta/s)*(theta/s)-(gx/s)*(g/s)))
   302  		if step > x {
   303  			gamma *= -1
   304  		}
   305  		p := gamma - g + theta
   306  		q := gamma + gx - g + gamma
   307  		r := p / q
   308  		var stpc float64
   309  		switch {
   310  		case r < 0 && gamma != 0:
   311  			stpc = step + r*(x-step)
   312  		case step > x:
   313  			stpc = mt.upper
   314  		default:
   315  			stpc = mt.lower
   316  		}
   317  		stpq := step + g/(g-gx)*(x-step)
   318  
   319  		if mt.bracketed {
   320  			// We are extrapolating so be cautious and take the step that
   321  			// is closer to step.
   322  			if math.Abs(stpc-step) < math.Abs(stpq-step) {
   323  				next = stpc
   324  			} else {
   325  				next = stpq
   326  			}
   327  			// Modify next if it is close to or beyond y.
   328  			if step > x {
   329  				next = math.Min(step+2.0/3*(y-step), next)
   330  			} else {
   331  				next = math.Max(step+2.0/3*(y-step), next)
   332  			}
   333  		} else {
   334  			// Minimum has not been bracketed so take the larger step...
   335  			if math.Abs(stpc-step) > math.Abs(stpq-step) {
   336  				next = stpc
   337  			} else {
   338  				next = stpq
   339  			}
   340  			// ...but within reason.
   341  			next = math.Max(mt.lower, math.Min(next, mt.upper))
   342  		}
   343  
   344  	default:
   345  		// A lower function value, derivatives of the same sign, and the
   346  		// magnitude of the derivative does not decrease. The function seems to
   347  		// decrease rapidly in the direction of the step.
   348  
   349  		switch {
   350  		case mt.bracketed:
   351  			theta := 3*(f-fy)/(y-step) + gy + g
   352  			s := math.Max(math.Abs(gy), math.Abs(g))
   353  			s = math.Max(s, math.Abs(theta))
   354  			gamma := s * math.Sqrt((theta/s)*(theta/s)-(gy/s)*(g/s))
   355  			if step > y {
   356  				gamma *= -1
   357  			}
   358  			p := gamma - g + theta
   359  			q := gamma - g + gamma + gy
   360  			r := p / q
   361  			next = step + r*(y-step)
   362  		case step > x:
   363  			next = mt.upper
   364  		default:
   365  			next = mt.lower
   366  		}
   367  	}
   368  
   369  	if f > fx {
   370  		// x is still the best step.
   371  		mt.y = step
   372  		mt.fy = f
   373  		mt.gy = g
   374  	} else {
   375  		// step is the new best step.
   376  		if gNeg {
   377  			mt.y = x
   378  			mt.fy = fx
   379  			mt.gy = gx
   380  		}
   381  		mt.x = step
   382  		mt.fx = f
   383  		mt.gx = g
   384  	}
   385  	mt.bracketed = bracketed
   386  	mt.step = next
   387  }