go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/model/model.go (about)

     1  package model
     2  
     3  import (
     4  	"go-ml.dev/pkg/base/fu"
     5  	"go-ml.dev/pkg/base/tables"
     6  	"go-ml.dev/pkg/iokit"
     7  	"go-ml.dev/pkg/zorros"
     8  	"io"
     9  	"path/filepath"
    10  	"reflect"
    11  )
    12  
    13  /*
    14  HungryModel is an ML algorithm grows from a data to predict something
    15  Needs to be fattened by Feed method to fit.
    16  */
    17  type HungryModel interface {
    18  	Feed(Dataset) FatModel
    19  }
    20  
    21  /*
    22  Report is an ML training report
    23  */
    24  type Report struct {
    25  	History     *tables.Table // all iterations history
    26  	TheBest     int           // the best iteration
    27  	Test, Train fu.Struct     // the best iteration metrics
    28  	Score       float64       // the best score
    29  }
    30  
    31  /*
    32  Workout is a training iteration abstraction
    33  */
    34  type Workout interface {
    35  	Iteration() int
    36  	TrainMetrics() MetricsUpdater
    37  	TestMetrics() MetricsUpdater
    38  	Complete(m MemorizeMap, train, test fu.Struct, metricsDone bool) (*Report, bool, error)
    39  	Next() Workout
    40  	Verbose(string)
    41  }
    42  
    43  /*
    44  UnifiedTraining is an interface allowing to write any logging/staging backend for ML training
    45  */
    46  type UnifiedTraining interface {
    47  	// Workout returns the first iteration workout
    48  	Workout() Workout
    49  }
    50  
    51  /*
    52  FatModel is fattened model (a training function of model instance bounded to a dataset)
    53  */
    54  type FatModel func(workout Workout) (*Report, error)
    55  
    56  /*
    57  Train a fattened (Fat) model
    58  */
    59  func (f FatModel) Train(training UnifiedTraining) (*Report, error) {
    60  	w := training.Workout()
    61  	if c, ok := w.(io.Closer); ok {
    62  		defer c.Close()
    63  	}
    64  	return f(w)
    65  }
    66  
    67  /*
    68  LuckyTrain trains fattened (Fat) model and trows any occurred errors as a panic
    69  */
    70  func (f FatModel) LuckyTrain(training UnifiedTraining) *Report {
    71  	m, err := f.Train(training)
    72  	if err != nil {
    73  		panic(zorros.Panic(err))
    74  	}
    75  	return m
    76  }
    77  
    78  /*
    79  PredictionModel is a predictor interface
    80  */
    81  type PredictionModel interface {
    82  	// Features model uses when maps features
    83  	// the same as Features in the training dataset
    84  	Features() []string
    85  	// Column name model adds to result table when maps features.
    86  	// By default it's 'Predicted'
    87  	Predicted() string
    88  	// Returns new table with all original columns except features
    89  	// adding one new column with prediction
    90  	FeaturesMapper(batchSize int) (tables.FeaturesMapper, error)
    91  }
    92  
    93  /*
    94  GpuPredictionModel is a prediction interface able to use GPU
    95  */
    96  type GpuPredictionModel interface {
    97  	PredictionModel
    98  	// Gpu changes context of prediction backend to gpu enabled
    99  	// it's a recommendation only, if GPU is not available or it's impossible to use it
   100  	// the cpu will be used instead
   101  	Gpu(...int) PredictionModel
   102  }
   103  
   104  func Path(s string) string {
   105  	if filepath.IsAbs(s) {
   106  		return s
   107  	}
   108  	return iokit.CacheFile(filepath.Join("go-ml", "Models", s))
   109  }
   110  
   // Params is a set of named hyper-parameters used by hyper-parameter
   // optimization to generate a new model.
   type Params map[string]float64

   // Get returns the value of the parameter called name when p contains it,
   // and dflt otherwise. A stored zero value is returned as-is (not replaced
   // by dflt), and calling Get on a nil Params is safe.
   func (p Params) Get(name string, dflt float64) float64 {
   	value, found := p[name]
   	if !found {
   		return dflt
   	}
   	return value
   }
   125  
   126  func (p Params) Apply(m map[string]reflect.Value) {
   127  	for k, v := range p {
   128  		ref, ok := m[k]
   129  		if !ok {
   130  			panic(zorros.Panic(zorros.Errorf("model does not have field `%v`", k)))
   131  		}
   132  		ref.Elem().Set(fu.Convert(reflect.ValueOf(v), false, ref.Type().Elem()))
   133  	}
   134  }