go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/model/classificaction.go (about)

     1  package model
     2  
import (
	"reflect"
	"sort"

	"go-ml.dev/pkg/base/fu"
)
     7  
     8  /*
     9  Classification metrics factory
    10  */
    11  type Classification struct {
    12  	Accuracy   float64 // accuracy goal
    13  	Error      float64 // error goal
    14  	Confidence float32 // threshold for binary classification
    15  }
    16  
    17  /*
    18  Names is the list of calculating metrics
    19  */
    20  func (m Classification) Names() []string {
    21  	return []string{
    22  		IterationCol,
    23  		SubsetCol,
    24  		ErrorCol,
    25  		LossCol,
    26  		AccuracyCol,
    27  		SensitivityCol,
    28  		PrecisionCol,
    29  		F1ScoreCol,
    30  		CorrectCol,
    31  		TotalCol,
    32  	}
    33  }
    34  
    35  /*
    36  New metrics updater for the given iteration and subset
    37  */
    38  func (m Classification) New(iteration int, subset string) MetricsUpdater {
    39  	return &cfupdater{
    40  		Classification: m,
    41  		iteration:      iteration,
    42  		subset:         subset,
    43  		lIncorrect:     map[int]float64{},
    44  		rIncorrect:     map[int]float64{},
    45  		cCorrect:       map[int]float64{},
    46  	}
    47  }
    48  
    49  type cfupdater struct {
    50  	Classification
    51  	iteration  int
    52  	subset     string
    53  	correct    float64
    54  	loss       float64
    55  	lIncorrect map[int]float64
    56  	rIncorrect map[int]float64
    57  	cCorrect   map[int]float64
    58  	count      float64
    59  }
    60  
    61  func (m *cfupdater) Update(result, label reflect.Value, loss float64) {
    62  	l := fu.Cell{label}.Int()
    63  	y := 0
    64  	if result.Type() == fu.TensorType {
    65  		v := result.Interface().(fu.Tensor)
    66  		y = v.HotOne()
    67  	} else {
    68  		if m.Confidence > 0 {
    69  			x := fu.Cell{result}.Real()
    70  			if x > m.Confidence {
    71  				y = 1
    72  			}
    73  		} else {
    74  			y = fu.Cell{result}.Int()
    75  		}
    76  	}
    77  	if l == y {
    78  		m.correct++
    79  		m.cCorrect[y] = m.cCorrect[y] + 1
    80  	} else {
    81  		m.lIncorrect[l] = m.lIncorrect[l] + 1
    82  		m.rIncorrect[y] = m.rIncorrect[y] + 1
    83  	}
    84  	m.loss += loss
    85  	m.count++
    86  }
    87  
    88  func (m *cfupdater) Complete() (fu.Struct, bool) {
    89  	if m.count > 0 {
    90  		acc := m.correct / m.count
    91  		cno := float64(len(m.cCorrect))
    92  		var sensitivity, precision, cerr float64
    93  		for i, v := range m.cCorrect {
    94  			sensitivity += v / (v + m.lIncorrect[i]) // false negative
    95  			precision += v / (v + m.rIncorrect[i])   // false positive
    96  			cerr += (m.rIncorrect[i] + m.lIncorrect[i]) / m.count
    97  		}
    98  		sensitivity /= cno
    99  		precision /= cno
   100  		cerr /= cno
   101  		f1 := 2 * precision * sensitivity / (precision + sensitivity)
   102  		columns := []reflect.Value{
   103  			reflect.ValueOf(m.iteration),
   104  			reflect.ValueOf(m.subset),
   105  			reflect.ValueOf(cerr),
   106  			reflect.ValueOf(m.loss / m.count),
   107  			reflect.ValueOf(acc),
   108  			reflect.ValueOf(sensitivity),
   109  			reflect.ValueOf(precision),
   110  			reflect.ValueOf(f1),
   111  			reflect.ValueOf(int(m.correct)),
   112  			reflect.ValueOf(int(m.count)),
   113  		}
   114  		goal := false
   115  		if m.Accuracy > 0 {
   116  			goal = goal || acc > m.Accuracy
   117  		}
   118  		if m.Error > 0 {
   119  			goal = goal || cerr < m.Error
   120  		}
   121  		return fu.Struct{Names: m.Names(), Columns: columns}, goal
   122  	}
   123  	return fu.
   124  			NaStruct(m.Names(), fu.Float64).
   125  			Set(IterationCol, fu.IntZero).
   126  			Set(SubsetCol, fu.EmptyString),
   127  		false
   128  }