go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/model/classificaction.go (about) 1 package model 2 3 import ( 4 "go-ml.dev/pkg/base/fu" 5 "reflect" 6 ) 7 8 /* 9 Classification metrics factory 10 */ 11 type Classification struct { 12 Accuracy float64 // accuracy goal 13 Error float64 // error goal 14 Confidence float32 // threshold for binary classification 15 } 16 17 /* 18 Names is the list of calculating metrics 19 */ 20 func (m Classification) Names() []string { 21 return []string{ 22 IterationCol, 23 SubsetCol, 24 ErrorCol, 25 LossCol, 26 AccuracyCol, 27 SensitivityCol, 28 PrecisionCol, 29 F1ScoreCol, 30 CorrectCol, 31 TotalCol, 32 } 33 } 34 35 /* 36 New metrics updater for the given iteration and subset 37 */ 38 func (m Classification) New(iteration int, subset string) MetricsUpdater { 39 return &cfupdater{ 40 Classification: m, 41 iteration: iteration, 42 subset: subset, 43 lIncorrect: map[int]float64{}, 44 rIncorrect: map[int]float64{}, 45 cCorrect: map[int]float64{}, 46 } 47 } 48 49 type cfupdater struct { 50 Classification 51 iteration int 52 subset string 53 correct float64 54 loss float64 55 lIncorrect map[int]float64 56 rIncorrect map[int]float64 57 cCorrect map[int]float64 58 count float64 59 } 60 61 func (m *cfupdater) Update(result, label reflect.Value, loss float64) { 62 l := fu.Cell{label}.Int() 63 y := 0 64 if result.Type() == fu.TensorType { 65 v := result.Interface().(fu.Tensor) 66 y = v.HotOne() 67 } else { 68 if m.Confidence > 0 { 69 x := fu.Cell{result}.Real() 70 if x > m.Confidence { 71 y = 1 72 } 73 } else { 74 y = fu.Cell{result}.Int() 75 } 76 } 77 if l == y { 78 m.correct++ 79 m.cCorrect[y] = m.cCorrect[y] + 1 80 } else { 81 m.lIncorrect[l] = m.lIncorrect[l] + 1 82 m.rIncorrect[y] = m.rIncorrect[y] + 1 83 } 84 m.loss += loss 85 m.count++ 86 } 87 88 func (m *cfupdater) Complete() (fu.Struct, bool) { 89 if m.count > 0 { 90 acc := m.correct / m.count 91 cno := float64(len(m.cCorrect)) 92 var sensitivity, precision, cerr float64 93 for i, v := range m.cCorrect { 94 sensitivity += v / (v + m.lIncorrect[i]) // false negative 95 precision += v / (v + m.rIncorrect[i]) // false positive 96 cerr += (m.rIncorrect[i] + m.lIncorrect[i]) / m.count 97 } 98 sensitivity /= cno 99 precision /= cno 100 cerr /= cno 101 f1 := 2 * precision * sensitivity / (precision + sensitivity) 102 columns := []reflect.Value{ 103 reflect.ValueOf(m.iteration), 104 reflect.ValueOf(m.subset), 105 reflect.ValueOf(cerr), 106 reflect.ValueOf(m.loss / m.count), 107 reflect.ValueOf(acc), 108 reflect.ValueOf(sensitivity), 109 reflect.ValueOf(precision), 110 reflect.ValueOf(f1), 111 reflect.ValueOf(int(m.correct)), 112 reflect.ValueOf(int(m.count)), 113 } 114 goal := false 115 if m.Accuracy > 0 { 116 goal = goal || acc > m.Accuracy 117 } 118 if m.Error > 0 { 119 goal = goal || cerr < m.Error 120 } 121 return fu.Struct{Names: m.Names(), Columns: columns}, goal 122 } 123 return fu. 124 NaStruct(m.Names(), fu.Float64). 125 Set(IterationCol, fu.IntZero). 126 Set(SubsetCol, fu.EmptyString), 127 false 128 }