go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/model/training.go (about) 1 package model 2 3 import ( 4 "fmt" 5 "go-ml.dev/pkg/base/fu" 6 "go-ml.dev/pkg/base/fu/lazy" 7 "go-ml.dev/pkg/base/tables" 8 "go-ml.dev/pkg/iokit" 9 "go-ml.dev/pkg/zorros" 10 "go-ml.dev/pkg/zorros/zlog" 11 "io" 12 "reflect" 13 ) 14 15 /* 16 Training is the default implementation of unified training interface 17 */ 18 type Training struct { 19 Iterations int // maximum iterations 20 Metrics Metrics // evaluating metrics 21 Score Score // score function 22 ScoreHistory int // possible count of forehead training with lower score 23 ModelFile iokit.Output // file to store final model 24 Verbose interface{} // print function func(string) 25 } 26 27 type training struct { 28 Training 29 stash *ModelStash 30 done bool 31 } 32 33 type workout struct { 34 iteration int 35 training *training 36 perflog [][2]fu.Struct 37 scorlog []float64 38 } 39 40 const DefaultScoreHistory = 3 41 42 func (t Training) Workout() Workout { 43 x := &training{ 44 Training: t, 45 stash: NewStash(fu.Fnzi(t.ScoreHistory, DefaultScoreHistory), "model-treaining-*.zip"), 46 } 47 return &workout{iteration: 0, training: x} 48 } 49 50 func (w *workout) Close() error { 51 return w.training.stash.Close() 52 } 53 54 func (w *workout) Iteration() int { 55 return w.iteration 56 } 57 58 func (w *workout) TrainMetrics() MetricsUpdater { 59 return w.training.Metrics.New(w.iteration, TrainSubset) 60 } 61 62 func (w *workout) TestMetrics() MetricsUpdater { 63 return w.training.Metrics.New(w.iteration, TestSubset) 64 } 65 66 func (w *workout) report(j int) (report *Report, err error) { 67 report = &Report{} 68 histlen := fu.Fnzi(w.training.ScoreHistory, DefaultScoreHistory) 69 if len(w.perflog) > 0 { 70 report.History = tables.Lazy(lazy.Flatn(w.perflog)).LuckyCollect() 71 if j == 0 { 72 l := fu.Mini(len(w.scorlog), histlen) 73 lj := len(w.scorlog) - l 74 j = fu.Indmaxd(w.scorlog[lj:]) + lj 75 } 76 report.TheBest = j 77 report.Train = w.perflog[j][0] 78 report.Test = w.perflog[j][1] 79 report.Score = w.scorlog[j] 80 if w.training.ModelFile != nil { 81 rd, e := w.training.stash.Reader(j) 82 if e != nil { 83 err = zorros.Trace(e) 84 return 85 } 86 wh, e := w.training.ModelFile.Create() 87 if e != nil { 88 err = zorros.Trace(e) 89 return 90 } 91 defer wh.End() 92 _, e = io.Copy(wh, rd) 93 if e != nil { 94 err = zorros.Trace(e) 95 return 96 } 97 if e = wh.Commit(); e != nil { 98 err = zorros.Trace(e) 99 return 100 } 101 } 102 } else { 103 report.History = tables.NewEmpty(w.training.Metrics.Names(), nil) 104 } 105 return 106 } 107 108 func (w *workout) Complete(m MemorizeMap, train, test fu.Struct, metricsDone bool) (report *Report, done bool, err error) { 109 histlen := fu.Fnzi(w.training.ScoreHistory, DefaultScoreHistory) 110 maxiter := fu.Maxi(w.training.Iterations, 1) 111 score := w.training.Score(train, test) 112 w.scorlog = append(w.scorlog, score) 113 w.perflog = append(w.perflog, [2]fu.Struct{train, test}) 114 if w.training.ModelFile != nil { 115 o, e := w.training.stash.Output(w.iteration) 116 if e != nil { 117 err = zorros.Wrapf(e, "failed to create stash for model: %v", e.Error()) 118 return 119 } 120 if err = Memorize(o, m); err != nil { 121 return 122 } 123 } 124 if metricsDone { 125 w.training.done = true 126 done = true 127 report, err = w.report(w.iteration) 128 } else if w.iteration == maxiter-1 || (w.iteration > histlen && fu.Indmaxd(w.scorlog[len(w.scorlog)-histlen:]) == 0) { 129 w.training.done = true 130 done = true 131 report, err = w.report(0) 132 } 133 if w.training.Verbose != nil { 134 w.Verbose(fmt.Sprintf( 135 "[%3d] loss: %.5f/%.5f, error: %.5f/%.5f, score: %.5f", 136 w.Iteration(), Loss(train), Loss(test), Error(train), Error(test), score)) 137 } 138 return 139 } 140 141 func (w *workout) Verbose(s string) { 142 if w.training.Verbose != nil { 143 vf := reflect.ValueOf(w.training.Verbose) 144 vf.Call([]reflect.Value{reflect.ValueOf(s)}) 145 } 146 } 147 148 func (w *workout) Next() Workout { 149 if w.training == nil { 150 //panic(zorros.Panic(zorros.Errorf("training is done"))) 151 zlog.Warning("training is already done") 152 return nil 153 } 154 return &workout{ 155 iteration: w.iteration + 1, 156 training: w.training, 157 scorlog: w.scorlog, 158 perflog: w.perflog, 159 } 160 }