go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/model/model.go (about) 1 package model 2 3 import ( 4 "go-ml.dev/pkg/base/fu" 5 "go-ml.dev/pkg/base/tables" 6 "go-ml.dev/pkg/iokit" 7 "go-ml.dev/pkg/zorros" 8 "io" 9 "path/filepath" 10 "reflect" 11 ) 12 13 /* 14 HungryModel is an ML algorithm grows from a data to predict something 15 Needs to be fattened by Feed method to fit. 16 */ 17 type HungryModel interface { 18 Feed(Dataset) FatModel 19 } 20 21 /* 22 Report is an ML training report 23 */ 24 type Report struct { 25 History *tables.Table // all iterations history 26 TheBest int // the best iteration 27 Test, Train fu.Struct // the best iteration metrics 28 Score float64 // the best score 29 } 30 31 /* 32 Workout is a training iteration abstraction 33 */ 34 type Workout interface { 35 Iteration() int 36 TrainMetrics() MetricsUpdater 37 TestMetrics() MetricsUpdater 38 Complete(m MemorizeMap, train, test fu.Struct, metricsDone bool) (*Report, bool, error) 39 Next() Workout 40 Verbose(string) 41 } 42 43 /* 44 UnifiedTraining is an interface allowing to write any logging/staging backend for ML training 45 */ 46 type UnifiedTraining interface { 47 // Workout returns the first iteration workout 48 Workout() Workout 49 } 50 51 /* 52 FatModel is fattened model (a training function of model instance bounded to a dataset) 53 */ 54 type FatModel func(workout Workout) (*Report, error) 55 56 /* 57 Train a fattened (Fat) model 58 */ 59 func (f FatModel) Train(training UnifiedTraining) (*Report, error) { 60 w := training.Workout() 61 if c, ok := w.(io.Closer); ok { 62 defer c.Close() 63 } 64 return f(w) 65 } 66 67 /* 68 LuckyTrain trains fattened (Fat) model and trows any occurred errors as a panic 69 */ 70 func (f FatModel) LuckyTrain(training UnifiedTraining) *Report { 71 m, err := f.Train(training) 72 if err != nil { 73 panic(zorros.Panic(err)) 74 } 75 return m 76 } 77 78 /* 79 PredictionModel is a predictor interface 80 */ 81 type PredictionModel interface { 82 // Features model uses when maps features 83 // the same as Features in the training dataset 84 Features() []string 85 // Column name model adds to result table when maps features. 86 // By default it's 'Predicted' 87 Predicted() string 88 // Returns new table with all original columns except features 89 // adding one new column with prediction 90 FeaturesMapper(batchSize int) (tables.FeaturesMapper, error) 91 } 92 93 /* 94 GpuPredictionModel is a prediction interface able to use GPU 95 */ 96 type GpuPredictionModel interface { 97 PredictionModel 98 // Gpu changes context of prediction backend to gpu enabled 99 // it's a recommendation only, if GPU is not available or it's impossible to use it 100 // the cpu will be used instead 101 Gpu(...int) PredictionModel 102 } 103 104 func Path(s string) string { 105 if filepath.IsAbs(s) { 106 return s 107 } 108 return iokit.CacheFile(filepath.Join("go-ml", "Models", s)) 109 } 110 111 /* 112 Params is a set of hyper-parameters used by hyper-parameter optimization to generate new model 113 */ 114 type Params map[string]float64 115 116 /* 117 Get value of the parameter by name if exists and dflt value otherwise 118 */ 119 func (p Params) Get(name string, dflt float64) float64 { 120 if v, ok := p[name]; ok { 121 return v 122 } 123 return dflt 124 } 125 126 func (p Params) Apply(m map[string]reflect.Value) { 127 for k, v := range p { 128 ref, ok := m[k] 129 if !ok { 130 panic(zorros.Panic(zorros.Errorf("model does not have field `%v`", k))) 131 } 132 ref.Elem().Set(fu.Convert(reflect.ValueOf(v), false, ref.Type().Elem())) 133 } 134 }