github.com/gopherd/gonum@v0.0.4/stat/roc.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package stat 6 7 import ( 8 "math" 9 "sort" 10 ) 11 12 // ROC returns paired false positive rate (FPR) and true positive rate 13 // (TPR) values corresponding to cutoff points on the receiver operator 14 // characteristic (ROC) curve obtained when y is treated as a binary 15 // classifier for classes with weights. The cutoff thresholds used to 16 // calculate the ROC are returned in thresh such that tpr[i] and fpr[i] 17 // are the true and false positive rates for y >= thresh[i]. 18 // 19 // The input y and cutoffs must be sorted, and values in y must correspond 20 // to values in classes and weights. SortWeightedLabeled can be used to 21 // sort y together with classes and weights. 22 // 23 // For a given cutoff value, observations corresponding to entries in y 24 // greater than the cutoff value are classified as true, while those 25 // less than or equal to the cutoff value are classified as false. These 26 // assigned class labels are compared with the true values in the classes 27 // slice and used to calculate the FPR and TPR. 28 // 29 // If weights is nil, all weights are treated as 1. If weights is not nil 30 // it must have the same length as y and classes, otherwise ROC will panic. 31 // 32 // If cutoffs is nil or empty, all possible cutoffs are calculated, 33 // resulting in fpr and tpr having length one greater than the number of 34 // unique values in y. Otherwise fpr and tpr will be returned with the 35 // same length as cutoffs. floats.Span can be used to generate equally 36 // spaced cutoffs. 37 // 38 // More details about ROC curves are available at 39 // https://en.wikipedia.org/wiki/Receiver_operating_characteristic 40 func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) { 41 if len(y) != len(classes) { 42 panic("stat: slice length mismatch") 43 } 44 if weights != nil && len(y) != len(weights) { 45 panic("stat: slice length mismatch") 46 } 47 if !sort.Float64sAreSorted(y) { 48 panic("stat: input must be sorted ascending") 49 } 50 if !sort.Float64sAreSorted(cutoffs) { 51 panic("stat: cutoff values must be sorted ascending") 52 } 53 if len(y) == 0 { 54 return nil, nil, nil 55 } 56 if len(cutoffs) == 0 { 57 if cutoffs == nil || cap(cutoffs) < len(y)+1 { 58 cutoffs = make([]float64, len(y)+1) 59 } else { 60 cutoffs = cutoffs[:len(y)+1] 61 } 62 // Choose all possible cutoffs for unique values in y. 63 bin := 0 64 cutoffs[bin] = y[0] 65 for i, u := range y[1:] { 66 if u == y[i] { 67 continue 68 } 69 bin++ 70 cutoffs[bin] = u 71 } 72 cutoffs[bin+1] = math.Inf(1) 73 cutoffs = cutoffs[:bin+2] 74 } else { 75 // Don't mutate the provided cutoffs. 76 tmp := cutoffs 77 cutoffs = make([]float64, len(cutoffs)) 78 copy(cutoffs, tmp) 79 } 80 81 tpr = make([]float64, len(cutoffs)) 82 fpr = make([]float64, len(cutoffs)) 83 var bin int 84 var nPos, nNeg float64 85 for i, u := range classes { 86 // Update the bin until it matches the next y value 87 // skipping empty bins. 88 for bin < len(cutoffs)-1 && y[i] >= cutoffs[bin] { 89 bin++ 90 tpr[bin] = tpr[bin-1] 91 fpr[bin] = fpr[bin-1] 92 } 93 posWeight, negWeight := 1.0, 0.0 94 if weights != nil { 95 posWeight = weights[i] 96 } 97 if !u { 98 posWeight, negWeight = negWeight, posWeight 99 } 100 nPos += posWeight 101 nNeg += negWeight 102 // Count false negatives (in tpr) and true negatives (in fpr). 103 if y[i] < cutoffs[bin] { 104 tpr[bin] += posWeight 105 fpr[bin] += negWeight 106 } 107 } 108 109 invNeg := 1 / nNeg 110 invPos := 1 / nPos 111 // Convert negative counts to TPR and FPR. 112 // Bins beyond the maximum value in y are skipped 113 // leaving these fpr and tpr elements as zero. 114 for i := range tpr[:bin+1] { 115 // Prevent fused float operations by 116 // making explicit float64 conversions. 117 tpr[i] = 1 - float64(tpr[i]*invPos) 118 fpr[i] = 1 - float64(fpr[i]*invNeg) 119 } 120 for i, j := 0, len(tpr)-1; i < j; i, j = i+1, j-1 { 121 tpr[i], tpr[j] = tpr[j], tpr[i] 122 fpr[i], fpr[j] = fpr[j], fpr[i] 123 } 124 for i, j := 0, len(cutoffs)-1; i < j; i, j = i+1, j-1 { 125 cutoffs[i], cutoffs[j] = cutoffs[j], cutoffs[i] 126 } 127 128 return tpr, fpr, cutoffs 129 } 130 131 // TOC returns the Total Operating Characteristic for the classes provided 132 // and the minimum and maximum bounds for the TOC. 133 // 134 // The input y values that correspond to classes and weights must be sorted 135 // in ascending order. classes[i] is the class of value y[i] and weights[i] 136 // is the weight of y[i]. SortWeightedLabeled can be used to sort classes 137 // together with weights by the rank variable, i+1. 138 // 139 // The returned ntp values can be interpreted as the number of true positives 140 // where values above the given rank are assigned class true for each given 141 // rank from 1 to len(classes). 142 // ntp_i = sum_{j ≥ len(ntp)-1 - i} [ classes_j ] * weights_j, where [x] = 1 if x else 0. 143 // The values of min and max provide the minimum and maximum possible number 144 // of false values for the set of classes. The first element of ntp, min and 145 // max are always zero as this corresponds to assigning all data class false 146 // and the last elements are always weighted sum of classes as this corresponds 147 // to assigning every data class true. For len(classes) != 0, the lengths of 148 // min, ntp and max are len(classes)+1. 149 // 150 // If weights is nil, all weights are treated as 1. When weights are not nil, 151 // the calculation of min and max allows for partial assignment of single data 152 // points. If weights is not nil it must have the same length as classes, 153 // otherwise TOC will panic. 154 // 155 // More details about TOC curves are available at 156 // https://en.wikipedia.org/wiki/Total_operating_characteristic 157 func TOC(classes []bool, weights []float64) (min, ntp, max []float64) { 158 if weights != nil && len(classes) != len(weights) { 159 panic("stat: slice length mismatch") 160 } 161 if len(classes) == 0 { 162 return nil, nil, nil 163 } 164 165 ntp = make([]float64, len(classes)+1) 166 min = make([]float64, len(ntp)) 167 max = make([]float64, len(ntp)) 168 if weights == nil { 169 for i := range ntp[1:] { 170 ntp[i+1] = ntp[i] 171 if classes[len(classes)-i-1] { 172 ntp[i+1]++ 173 } 174 } 175 totalPositive := ntp[len(ntp)-1] 176 for i := range ntp { 177 min[i] = math.Max(0, totalPositive-float64(len(classes)-i)) 178 max[i] = math.Min(totalPositive, float64(i)) 179 } 180 return min, ntp, max 181 } 182 183 cumw := max // Reuse max for cumulative weight. Update its elements last. 184 for i := range ntp[1:] { 185 ntp[i+1] = ntp[i] 186 w := weights[len(weights)-i-1] 187 cumw[i+1] = cumw[i] + w 188 if classes[len(classes)-i-1] { 189 ntp[i+1] += w 190 } 191 } 192 totw := cumw[len(cumw)-1] 193 totalPositive := ntp[len(ntp)-1] 194 for i := range ntp { 195 min[i] = math.Max(0, totalPositive-(totw-cumw[i])) 196 max[i] = math.Min(totalPositive, cumw[i]) 197 } 198 return min, ntp, max 199 }