github.com/gopherd/gonum@v0.0.4/optimize/printer.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package optimize
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"math"
    11  	"os"
    12  	"time"
    13  
    14  	"github.com/gopherd/gonum/floats"
    15  )
    16  
    17  var printerHeadings = [...]string{
    18  	"Iter",
    19  	"Runtime",
    20  	"FuncEvals",
    21  	"Func",
    22  	"GradEvals",
    23  	"|Gradient|∞",
    24  	"HessEvals",
    25  }
    26  
    27  const (
    28  	printerBaseTmpl = "%9v  %16v  %9v  %22v" // Base template for headings and values that are always printed.
    29  	printerGradTmpl = "  %9v  %22v"          // Appended to base template when loc.Gradient != nil.
    30  	printerHessTmpl = "  %9v"                // Appended to base template when loc.Hessian != nil.
    31  )
    32  
// Compile-time check that *Printer implements Recorder.
var _ Recorder = (*Printer)(nil)

// Printer is a Recorder that writes column-format progress output to Writer
// as the optimization progresses. NewPrinter returns a Printer configured to
// write to os.Stdout.
type Printer struct {
	// Writer receives the formatted heading and value lines.
	Writer io.Writer

	// HeadingInterval is the number of value lines printed between repeated
	// column headings. Headings are never printed on PostIteration.
	HeadingInterval int

	// ValueInterval is the minimum time between value lines; Record calls
	// arriving sooner print nothing. The PostIteration line is always printed.
	ValueInterval time.Duration

	// lastHeading counts value lines printed since the most recent heading.
	lastHeading int

	// lastValue is the time the most recent value line was printed.
	lastValue time.Time
}
    45  
    46  func NewPrinter() *Printer {
    47  	return &Printer{
    48  		Writer:          os.Stdout,
    49  		HeadingInterval: 30,
    50  		ValueInterval:   500 * time.Millisecond,
    51  	}
    52  }
    53  
    54  func (p *Printer) Init() error {
    55  	p.lastHeading = p.HeadingInterval              // So the headings are printed the first time.
    56  	p.lastValue = time.Now().Add(-p.ValueInterval) // So the values are printed the first time.
    57  	return nil
    58  }
    59  
    60  func (p *Printer) Record(loc *Location, op Operation, stats *Stats) error {
    61  	if op != MajorIteration && op != InitIteration && op != PostIteration {
    62  		return nil
    63  	}
    64  
    65  	// Print values always on PostIteration or when ValueInterval has elapsed.
    66  	printValues := time.Since(p.lastValue) > p.ValueInterval || op == PostIteration
    67  	if !printValues {
    68  		// Return early if not printing anything.
    69  		return nil
    70  	}
    71  
    72  	// Print heading when HeadingInterval lines have been printed, but never on PostIteration.
    73  	printHeading := p.lastHeading >= p.HeadingInterval && op != PostIteration
    74  	if printHeading {
    75  		p.lastHeading = 1
    76  	} else {
    77  		p.lastHeading++
    78  	}
    79  
    80  	if printHeading {
    81  		headings := "\n" + fmt.Sprintf(printerBaseTmpl, printerHeadings[0], printerHeadings[1], printerHeadings[2], printerHeadings[3])
    82  		if loc.Gradient != nil {
    83  			headings += fmt.Sprintf(printerGradTmpl, printerHeadings[4], printerHeadings[5])
    84  		}
    85  		if loc.Hessian != nil {
    86  			headings += fmt.Sprintf(printerHessTmpl, printerHeadings[6])
    87  		}
    88  		_, err := fmt.Fprintln(p.Writer, headings)
    89  		if err != nil {
    90  			return err
    91  		}
    92  	}
    93  
    94  	values := fmt.Sprintf(printerBaseTmpl, stats.MajorIterations, stats.Runtime, stats.FuncEvaluations, loc.F)
    95  	if loc.Gradient != nil {
    96  		values += fmt.Sprintf(printerGradTmpl, stats.GradEvaluations, floats.Norm(loc.Gradient, math.Inf(1)))
    97  	}
    98  	if loc.Hessian != nil {
    99  		values += fmt.Sprintf(printerHessTmpl, stats.HessEvaluations)
   100  	}
   101  	_, err := fmt.Fprintln(p.Writer, values)
   102  	if err != nil {
   103  		return err
   104  	}
   105  
   106  	p.lastValue = time.Now()
   107  	return nil
   108  }