gonum.org/v1/gonum@v0.14.0/optimize/gradientdescent.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package optimize 6 7 import "gonum.org/v1/gonum/floats" 8 9 var ( 10 _ Method = (*GradientDescent)(nil) 11 _ localMethod = (*GradientDescent)(nil) 12 _ NextDirectioner = (*GradientDescent)(nil) 13 ) 14 15 // GradientDescent implements the steepest descent optimization method that 16 // performs successive steps along the direction of the negative gradient. 17 type GradientDescent struct { 18 // Linesearcher selects suitable steps along the descent direction. 19 // If Linesearcher is nil, a reasonable default will be chosen. 20 Linesearcher Linesearcher 21 // StepSizer determines the initial step size along each direction. 22 // If StepSizer is nil, a reasonable default will be chosen. 23 StepSizer StepSizer 24 // GradStopThreshold sets the threshold for stopping if the gradient norm 25 // gets too small. If GradStopThreshold is 0 it is defaulted to 1e-12, and 26 // if it is NaN the setting is not used. 27 GradStopThreshold float64 28 29 ls *LinesearchMethod 30 31 status Status 32 err error 33 } 34 35 func (g *GradientDescent) Status() (Status, error) { 36 return g.status, g.err 37 } 38 39 func (*GradientDescent) Uses(has Available) (uses Available, err error) { 40 return has.gradient() 41 } 42 43 func (g *GradientDescent) Init(dim, tasks int) int { 44 g.status = NotTerminated 45 g.err = nil 46 return 1 47 } 48 49 func (g *GradientDescent) Run(operation chan<- Task, result <-chan Task, tasks []Task) { 50 g.status, g.err = localOptimizer{}.run(g, g.GradStopThreshold, operation, result, tasks) 51 close(operation) 52 } 53 54 func (g *GradientDescent) initLocal(loc *Location) (Operation, error) { 55 if g.Linesearcher == nil { 56 g.Linesearcher = &Backtracking{} 57 } 58 if g.StepSizer == nil { 59 g.StepSizer = &QuadraticStepSize{} 60 } 61 62 if g.ls == nil { 63 g.ls = &LinesearchMethod{} 64 } 65 g.ls.Linesearcher = g.Linesearcher 66 g.ls.NextDirectioner = g 67 68 return g.ls.Init(loc) 69 } 70 71 func (g *GradientDescent) iterateLocal(loc *Location) (Operation, error) { 72 return g.ls.Iterate(loc) 73 } 74 75 func (g *GradientDescent) InitDirection(loc *Location, dir []float64) (stepSize float64) { 76 copy(dir, loc.Gradient) 77 floats.Scale(-1, dir) 78 return g.StepSizer.Init(loc, dir) 79 } 80 81 func (g *GradientDescent) NextDirection(loc *Location, dir []float64) (stepSize float64) { 82 copy(dir, loc.Gradient) 83 floats.Scale(-1, dir) 84 return g.StepSizer.StepSize(loc, dir) 85 } 86 87 func (*GradientDescent) needs() struct { 88 Gradient bool 89 Hessian bool 90 } { 91 return struct { 92 Gradient bool 93 Hessian bool 94 }{true, false} 95 }