go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/model/regression.go (about)

     1  package model
     2  
     3  import (
     4  	"go-ml.dev/pkg/base/fu"
     5  	"math"
     6  	"reflect"
     7  )
     8  
     9  /*
    10  Regression - the regression metrics factory
    11  */
    12  type Regression struct {
    13  	Error float64 // error goal
    14  }
    15  
    16  /*
    17  New iteration metrics
    18  */
    19  func (m Regression) New(iteration int, subset string) MetricsUpdater {
    20  	return &rgupdater{
    21  		Regression: m,
    22  		iteration:  iteration,
    23  		subset:     subset,
    24  	}
    25  }
    26  
    27  /*
    28  Names is the list of calculating metrics
    29  */
    30  func (m Regression) Names() []string {
    31  	return []string{
    32  		IterationCol,
    33  		SubsetCol,
    34  		ErrorCol,
    35  		LossCol,
    36  		RmseCol,
    37  		MaeCol,
    38  		MeCol,
    39  		TotalCol,
    40  	}
    41  }
    42  
    43  type rgupdater struct {
    44  	Regression
    45  	iteration int
    46  	subset    string
    47  	loss      float64
    48  	error     float64 // sum{|result-label|}
    49  	error1    float64 // sum{result-label}
    50  	error2    float64 // sum{(result-label)^2}
    51  	count     float64
    52  }
    53  
    54  func (m *rgupdater) Complete() (fu.Struct, bool) {
    55  	if m.count > 0 {
    56  		squrederr := m.error2 / m.count
    57  		errsqrt := math.Sqrt(squrederr)
    58  		abserr := m.error / m.count
    59  		meanerr := m.error1 / m.count
    60  		columns := []reflect.Value{
    61  			reflect.ValueOf(m.iteration),
    62  			reflect.ValueOf(m.subset),
    63  			reflect.ValueOf(squrederr),
    64  			reflect.ValueOf(m.loss / m.count),
    65  			reflect.ValueOf(errsqrt),
    66  			reflect.ValueOf(abserr),
    67  			reflect.ValueOf(meanerr),
    68  			reflect.ValueOf(int(m.count)),
    69  		}
    70  		goal := false
    71  		if m.Error > 0 {
    72  			goal = goal || squrederr < m.Error
    73  		}
    74  		return fu.Struct{Names: m.Names(), Columns: columns}, goal
    75  	}
    76  	return fu.
    77  			NaStruct(m.Names(), fu.Float64).
    78  			Set(IterationCol, fu.IntZero).
    79  			Set(SubsetCol, fu.EmptyString),
    80  		false
    81  }
    82  
    83  func error1(a, b []float32) (float64, float64) {
    84  	c := 0.
    85  	m := 0.
    86  	for i, v := range a {
    87  		x := float64(v - b[i])
    88  		c += math.Abs(x)
    89  		m += x
    90  	}
    91  	return c / float64(len(a)), m / float64(len(a))
    92  }
    93  
    94  func error2(a, b []float32) float64 {
    95  	c := 0.
    96  	for i, v := range a {
    97  		q := float64(v - b[i])
    98  		c += q * q
    99  	}
   100  	return c / float64(len(a))
   101  }
   102  
   103  func (m *rgupdater) Update(result, label reflect.Value, loss float64) {
   104  	var e, e1, e2 float64
   105  	if result.Type() == fu.TensorType {
   106  		vr := result.Interface().(fu.Tensor).Floats32()
   107  		if t, ok := label.Interface().(fu.Tensor); ok {
   108  			vl := t.Floats32()
   109  			e, e1 = error1(vr, vl)
   110  			e2 = error2(vr, vl)
   111  		}
   112  	} else {
   113  		r := fu.Cell{result}.Float()
   114  		l := fu.Cell{label}.Float()
   115  		e = math.Abs(r - l)
   116  		e1 = r - l
   117  		e2 = e * e
   118  	}
   119  	m.error += e
   120  	m.error1 += e1
   121  	m.error2 += e2
   122  	m.loss += loss
   123  	m.count++
   124  }