github.com/gopherd/gonum@v0.0.4/optimize/printer.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package optimize 6 7 import ( 8 "fmt" 9 "io" 10 "math" 11 "os" 12 "time" 13 14 "github.com/gopherd/gonum/floats" 15 ) 16 17 var printerHeadings = [...]string{ 18 "Iter", 19 "Runtime", 20 "FuncEvals", 21 "Func", 22 "GradEvals", 23 "|Gradient|∞", 24 "HessEvals", 25 } 26 27 const ( 28 printerBaseTmpl = "%9v %16v %9v %22v" // Base template for headings and values that are always printed. 29 printerGradTmpl = " %9v %22v" // Appended to base template when loc.Gradient != nil. 30 printerHessTmpl = " %9v" // Appended to base template when loc.Hessian != nil. 31 ) 32 33 var _ Recorder = (*Printer)(nil) 34 35 // Printer writes column-format output to the specified writer as the optimization 36 // progresses. By default, it writes to os.Stdout. 37 type Printer struct { 38 Writer io.Writer 39 HeadingInterval int 40 ValueInterval time.Duration 41 42 lastHeading int 43 lastValue time.Time 44 } 45 46 func NewPrinter() *Printer { 47 return &Printer{ 48 Writer: os.Stdout, 49 HeadingInterval: 30, 50 ValueInterval: 500 * time.Millisecond, 51 } 52 } 53 54 func (p *Printer) Init() error { 55 p.lastHeading = p.HeadingInterval // So the headings are printed the first time. 56 p.lastValue = time.Now().Add(-p.ValueInterval) // So the values are printed the first time. 57 return nil 58 } 59 60 func (p *Printer) Record(loc *Location, op Operation, stats *Stats) error { 61 if op != MajorIteration && op != InitIteration && op != PostIteration { 62 return nil 63 } 64 65 // Print values always on PostIteration or when ValueInterval has elapsed. 66 printValues := time.Since(p.lastValue) > p.ValueInterval || op == PostIteration 67 if !printValues { 68 // Return early if not printing anything. 69 return nil 70 } 71 72 // Print heading when HeadingInterval lines have been printed, but never on PostIteration. 73 printHeading := p.lastHeading >= p.HeadingInterval && op != PostIteration 74 if printHeading { 75 p.lastHeading = 1 76 } else { 77 p.lastHeading++ 78 } 79 80 if printHeading { 81 headings := "\n" + fmt.Sprintf(printerBaseTmpl, printerHeadings[0], printerHeadings[1], printerHeadings[2], printerHeadings[3]) 82 if loc.Gradient != nil { 83 headings += fmt.Sprintf(printerGradTmpl, printerHeadings[4], printerHeadings[5]) 84 } 85 if loc.Hessian != nil { 86 headings += fmt.Sprintf(printerHessTmpl, printerHeadings[6]) 87 } 88 _, err := fmt.Fprintln(p.Writer, headings) 89 if err != nil { 90 return err 91 } 92 } 93 94 values := fmt.Sprintf(printerBaseTmpl, stats.MajorIterations, stats.Runtime, stats.FuncEvaluations, loc.F) 95 if loc.Gradient != nil { 96 values += fmt.Sprintf(printerGradTmpl, stats.GradEvaluations, floats.Norm(loc.Gradient, math.Inf(1))) 97 } 98 if loc.Hessian != nil { 99 values += fmt.Sprintf(printerHessTmpl, stats.HessEvaluations) 100 } 101 _, err := fmt.Fprintln(p.Writer, values) 102 if err != nil { 103 return err 104 } 105 106 p.lastValue = time.Now() 107 return nil 108 }