go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/model/training.go (about)

     1  package model
     2  
     3  import (
     4  	"fmt"
     5  	"go-ml.dev/pkg/base/fu"
     6  	"go-ml.dev/pkg/base/fu/lazy"
     7  	"go-ml.dev/pkg/base/tables"
     8  	"go-ml.dev/pkg/iokit"
     9  	"go-ml.dev/pkg/zorros"
    10  	"go-ml.dev/pkg/zorros/zlog"
    11  	"io"
    12  	"reflect"
    13  )
    14  
// Training is the default implementation of the unified training interface.
// Call Workout to start a training session configured by its fields.
type Training struct {
	Iterations   int          // maximum number of iterations; at least one iteration is always run
	Metrics      Metrics      // metrics collected for the train and test subsets of every iteration
	Score        Score        // computes an iteration's score from its train/test metrics; the best iteration is the one with the maximal score
	ScoreHistory int          // window of recent iterations used for early stopping and best-model selection (DefaultScoreHistory when unset)
	ModelFile    iokit.Output // destination for the best model; nil disables model saving
	Verbose      interface{}  // optional progress callback called with one string argument, e.g. func(string)
}
    26  
// training is the session state shared by every workout of one training run
// (the first workout from Training.Workout and all successors from Next).
type training struct {
	Training
	stash *ModelStash // models memorized per iteration by Complete; released by workout.Close
	done  bool        // set by Complete once training has finished. NOTE(review): not read in the visible code
}
    32  
// workout is the per-iteration handle returned by Training.Workout and Next.
type workout struct {
	iteration int            // zero-based index of this iteration
	training  *training      // session state shared with all successors
	perflog   [][2]fu.Struct // {train, test} metrics of every completed iteration
	scorlog   []float64      // score of every completed iteration, parallel to perflog
}
    39  
// DefaultScoreHistory is the score window used when Training.ScoreHistory is
// not set (selected via fu.Fnzi, presumably "first non-zero").
const DefaultScoreHistory = 3
    41  
    42  func (t Training) Workout() Workout {
    43  	x := &training{
    44  		Training: t,
    45  		stash:    NewStash(fu.Fnzi(t.ScoreHistory, DefaultScoreHistory), "model-treaining-*.zip"),
    46  	}
    47  	return &workout{iteration: 0, training: x}
    48  }
    49  
    50  func (w *workout) Close() error {
    51  	return w.training.stash.Close()
    52  }
    53  
    54  func (w *workout) Iteration() int {
    55  	return w.iteration
    56  }
    57  
    58  func (w *workout) TrainMetrics() MetricsUpdater {
    59  	return w.training.Metrics.New(w.iteration, TrainSubset)
    60  }
    61  
    62  func (w *workout) TestMetrics() MetricsUpdater {
    63  	return w.training.Metrics.New(w.iteration, TestSubset)
    64  }
    65  
    66  func (w *workout) report(j int) (report *Report, err error) {
    67  	report = &Report{}
    68  	histlen := fu.Fnzi(w.training.ScoreHistory, DefaultScoreHistory)
    69  	if len(w.perflog) > 0 {
    70  		report.History = tables.Lazy(lazy.Flatn(w.perflog)).LuckyCollect()
    71  		if j == 0 {
    72  			l := fu.Mini(len(w.scorlog), histlen)
    73  			lj := len(w.scorlog) - l
    74  			j = fu.Indmaxd(w.scorlog[lj:]) + lj
    75  		}
    76  		report.TheBest = j
    77  		report.Train = w.perflog[j][0]
    78  		report.Test = w.perflog[j][1]
    79  		report.Score = w.scorlog[j]
    80  		if w.training.ModelFile != nil {
    81  			rd, e := w.training.stash.Reader(j)
    82  			if e != nil {
    83  				err = zorros.Trace(e)
    84  				return
    85  			}
    86  			wh, e := w.training.ModelFile.Create()
    87  			if e != nil {
    88  				err = zorros.Trace(e)
    89  				return
    90  			}
    91  			defer wh.End()
    92  			_, e = io.Copy(wh, rd)
    93  			if e != nil {
    94  				err = zorros.Trace(e)
    95  				return
    96  			}
    97  			if e = wh.Commit(); e != nil {
    98  				err = zorros.Trace(e)
    99  				return
   100  			}
   101  		}
   102  	} else {
   103  		report.History = tables.NewEmpty(w.training.Metrics.Names(), nil)
   104  	}
   105  	return
   106  }
   107  
   108  func (w *workout) Complete(m MemorizeMap, train, test fu.Struct, metricsDone bool) (report *Report, done bool, err error) {
   109  	histlen := fu.Fnzi(w.training.ScoreHistory, DefaultScoreHistory)
   110  	maxiter := fu.Maxi(w.training.Iterations, 1)
   111  	score := w.training.Score(train, test)
   112  	w.scorlog = append(w.scorlog, score)
   113  	w.perflog = append(w.perflog, [2]fu.Struct{train, test})
   114  	if w.training.ModelFile != nil {
   115  		o, e := w.training.stash.Output(w.iteration)
   116  		if e != nil {
   117  			err = zorros.Wrapf(e, "failed to create stash for model: %v", e.Error())
   118  			return
   119  		}
   120  		if err = Memorize(o, m); err != nil {
   121  			return
   122  		}
   123  	}
   124  	if metricsDone {
   125  		w.training.done = true
   126  		done = true
   127  		report, err = w.report(w.iteration)
   128  	} else if w.iteration == maxiter-1 || (w.iteration > histlen && fu.Indmaxd(w.scorlog[len(w.scorlog)-histlen:]) == 0) {
   129  		w.training.done = true
   130  		done = true
   131  		report, err = w.report(0)
   132  	}
   133  	if w.training.Verbose != nil {
   134  		w.Verbose(fmt.Sprintf(
   135  			"[%3d] loss: %.5f/%.5f, error: %.5f/%.5f, score: %.5f",
   136  			w.Iteration(), Loss(train), Loss(test), Error(train), Error(test), score))
   137  	}
   138  	return
   139  }
   140  
   141  func (w *workout) Verbose(s string) {
   142  	if w.training.Verbose != nil {
   143  		vf := reflect.ValueOf(w.training.Verbose)
   144  		vf.Call([]reflect.Value{reflect.ValueOf(s)})
   145  	}
   146  }
   147  
   148  func (w *workout) Next() Workout {
   149  	if w.training == nil {
   150  		//panic(zorros.Panic(zorros.Errorf("training is done")))
   151  		zlog.Warning("training is already done")
   152  		return nil
   153  	}
   154  	return &workout{
   155  		iteration: w.iteration + 1,
   156  		training:  w.training,
   157  		scorlog:   w.scorlog,
   158  		perflog:   w.perflog,
   159  	}
   160  }