github.com/gopherd/gonum@v0.0.4/optimize/gradientdescent.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package optimize
     6  
     7  import "github.com/gopherd/gonum/floats"
     8  
     9  var (
    10  	_ Method          = (*GradientDescent)(nil)
    11  	_ localMethod     = (*GradientDescent)(nil)
    12  	_ NextDirectioner = (*GradientDescent)(nil)
    13  )
    14  
// GradientDescent implements the steepest descent optimization method that
// performs successive steps along the direction of the negative gradient.
//
// At every iteration the search direction is the negated gradient at the
// current location. The StepSizer proposes the initial step size along that
// direction, and the Linesearcher then selects the step actually taken.
type GradientDescent struct {
	// Linesearcher selects suitable steps along the descent direction.
	// If Linesearcher is nil, a reasonable default will be chosen.
	// (initLocal currently assigns &Backtracking{} into this field.)
	Linesearcher Linesearcher
	// StepSizer determines the initial step size along each direction.
	// If StepSizer is nil, a reasonable default will be chosen.
	// (initLocal currently assigns &QuadraticStepSize{} into this field.)
	StepSizer StepSizer
	// GradStopThreshold sets the threshold for stopping if the gradient norm
	// gets too small. If GradStopThreshold is 0 it is defaulted to 1e-12, and
	// if it is NaN the setting is not used.
	GradStopThreshold float64

	// ls is the line-search driver. It is allocated lazily by initLocal and
	// reused on later runs; g itself is installed as its NextDirectioner.
	ls *LinesearchMethod

	// status and err hold the outcome of the most recent Run. Init resets
	// them, and Status reports them.
	status Status
	err    error
}
    34  
    35  func (g *GradientDescent) Status() (Status, error) {
    36  	return g.status, g.err
    37  }
    38  
    39  func (*GradientDescent) Uses(has Available) (uses Available, err error) {
    40  	return has.gradient()
    41  }
    42  
    43  func (g *GradientDescent) Init(dim, tasks int) int {
    44  	g.status = NotTerminated
    45  	g.err = nil
    46  	return 1
    47  }
    48  
    49  func (g *GradientDescent) Run(operation chan<- Task, result <-chan Task, tasks []Task) {
    50  	g.status, g.err = localOptimizer{}.run(g, g.GradStopThreshold, operation, result, tasks)
    51  	close(operation)
    52  }
    53  
    54  func (g *GradientDescent) initLocal(loc *Location) (Operation, error) {
    55  	if g.Linesearcher == nil {
    56  		g.Linesearcher = &Backtracking{}
    57  	}
    58  	if g.StepSizer == nil {
    59  		g.StepSizer = &QuadraticStepSize{}
    60  	}
    61  
    62  	if g.ls == nil {
    63  		g.ls = &LinesearchMethod{}
    64  	}
    65  	g.ls.Linesearcher = g.Linesearcher
    66  	g.ls.NextDirectioner = g
    67  
    68  	return g.ls.Init(loc)
    69  }
    70  
    71  func (g *GradientDescent) iterateLocal(loc *Location) (Operation, error) {
    72  	return g.ls.Iterate(loc)
    73  }
    74  
    75  func (g *GradientDescent) InitDirection(loc *Location, dir []float64) (stepSize float64) {
    76  	copy(dir, loc.Gradient)
    77  	floats.Scale(-1, dir)
    78  	return g.StepSizer.Init(loc, dir)
    79  }
    80  
    81  func (g *GradientDescent) NextDirection(loc *Location, dir []float64) (stepSize float64) {
    82  	copy(dir, loc.Gradient)
    83  	floats.Scale(-1, dir)
    84  	return g.StepSizer.StepSize(loc, dir)
    85  }
    86  
    87  func (*GradientDescent) needs() struct {
    88  	Gradient bool
    89  	Hessian  bool
    90  } {
    91  	return struct {
    92  		Gradient bool
    93  		Hessian  bool
    94  	}{true, false}
    95  }