// Source: gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/stat/roc.go

// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package stat

import (
	"math"
	"slices"
	"sort"
)

// ROC returns paired false positive rate (FPR) and true positive rate
// (TPR) values corresponding to cutoff points on the receiver operator
// characteristic (ROC) curve obtained when y is treated as a binary
// classifier for classes with weights. The cutoff thresholds used to
// calculate the ROC are returned in thresh such that tpr[i] and fpr[i]
// are the true and false positive rates for y >= thresh[i].
//
// The input y and cutoffs must be sorted in ascending order, and values
// in y must correspond to values in classes and weights.
// SortWeightedLabeled can be used to sort y together with classes and
// weights.
//
// For a given cutoff value, observations corresponding to entries in y
// greater than or equal to the cutoff value are classified as true, while
// those less than the cutoff value are classified as false. These
// assigned class labels are compared with the true values in the classes
// slice and used to calculate the FPR and TPR.
//
// If weights is nil, all weights are treated as 1. If weights is not nil
// it must have the same length as y and classes, otherwise ROC will panic.
//
// If cutoffs is nil or empty, all possible cutoffs are calculated,
// resulting in fpr and tpr having length one greater than the number of
// unique values in y. Otherwise fpr and tpr will be returned with the
// same length as cutoffs. floats.Span can be used to generate equally
// spaced cutoffs.
//
// The returned tpr, fpr and thresh are ordered by decreasing threshold.
// ROC panics if the lengths of y and classes differ, or if y or cutoffs
// is not sorted in ascending order. If y is empty, all returned slices
// are nil.
//
// More details about ROC curves are available at
// https://en.wikipedia.org/wiki/Receiver_operating_characteristic
func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
	if len(y) != len(classes) {
		panic("stat: slice length mismatch")
	}
	if weights != nil && len(y) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if !sort.Float64sAreSorted(y) {
		panic("stat: input must be sorted ascending")
	}
	if !sort.Float64sAreSorted(cutoffs) {
		panic("stat: cutoff values must be sorted ascending")
	}
	if len(y) == 0 {
		return nil, nil, nil
	}
	if len(cutoffs) == 0 {
		// A zero-length cutoffs with enough capacity is reused as
		// the backing store for the generated thresholds.
		if cutoffs == nil || cap(cutoffs) < len(y)+1 {
			cutoffs = make([]float64, len(y)+1)
		} else {
			cutoffs = cutoffs[:len(y)+1]
		}
		// Choose all possible cutoffs for unique values in y.
		// y[i] is the element preceding u because the range
		// is over y[1:].
		bin := 0
		cutoffs[bin] = y[0]
		for i, u := range y[1:] {
			if u == y[i] {
				continue
			}
			bin++
			cutoffs[bin] = u
		}
		// The final +Inf cutoff classifies every observation as false.
		cutoffs[bin+1] = math.Inf(1)
		cutoffs = cutoffs[:bin+2]
	} else {
		// Don't mutate the provided cutoffs; they are reversed
		// in place before being returned.
		cutoffs = slices.Clone(cutoffs)
	}

	tpr = make([]float64, len(cutoffs))
	fpr = make([]float64, len(cutoffs))
	var bin int
	var nPos, nNeg float64
	for i, u := range classes {
		// Update the bin until it matches the next y value
		// skipping empty bins. Counts are cumulative, so each
		// new bin starts from the previous bin's totals.
		for bin < len(cutoffs)-1 && y[i] >= cutoffs[bin] {
			bin++
			tpr[bin] = tpr[bin-1]
			fpr[bin] = fpr[bin-1]
		}
		posWeight, negWeight := 1.0, 0.0
		if weights != nil {
			posWeight = weights[i]
		}
		if !u {
			// A false observation contributes its weight to
			// the negative tally instead.
			posWeight, negWeight = negWeight, posWeight
		}
		nPos += posWeight
		nNeg += negWeight
		// Count false negatives (in tpr) and true negatives (in fpr).
		if y[i] < cutoffs[bin] {
			tpr[bin] += posWeight
			fpr[bin] += negWeight
		}
	}

	// If the total positive (or negative) weight is zero the
	// corresponding inverse is +Inf and the converted rates for the
	// visited bins are NaN.
	invNeg := 1 / nNeg
	invPos := 1 / nPos
	// Convert negative counts to TPR and FPR.
	// Bins beyond the maximum value in y are skipped
	// leaving these fpr and tpr elements as zero.
	for i := range tpr[:bin+1] {
		// Prevent fused float operations by
		// making explicit float64 conversions.
		tpr[i] = 1 - float64(tpr[i]*invPos)
		fpr[i] = 1 - float64(fpr[i]*invNeg)
	}
	// Results were accumulated by ascending cutoff; return them
	// by descending cutoff.
	slices.Reverse(tpr)
	slices.Reverse(fpr)
	slices.Reverse(cutoffs)

	return tpr, fpr, cutoffs
}

// TOC returns the Total Operating Characteristic for the classes provided
// and the minimum and maximum bounds for the TOC.
//
// The input y values that correspond to classes and weights must be sorted
// in ascending order. classes[i] is the class of value y[i] and weights[i]
// is the weight of y[i]. SortWeightedLabeled can be used to sort classes
// together with weights by the rank variable, i+1.
//
// The returned ntp values can be interpreted as the number of true positives
// where values above the given rank are assigned class true for each given
// rank from 1 to len(classes).
//
//	ntp_i = sum_{j ≥ len(ntp)-1 - i} [ classes_j ] * weights_j, where [x] = 1 if x else 0.
//
// The values of min and max provide the minimum and maximum possible values
// of ntp at each rank for the set of classes. The first element of ntp, min and
// max are always zero as this corresponds to assigning all data class false
// and the last elements are always weighted sum of classes as this corresponds
// to assigning every data class true. For len(classes) != 0, the lengths of
// min, ntp and max are len(classes)+1.
//
// If weights is nil, all weights are treated as 1.
// When weights are not nil, the calculation of min and max allows for
// partial assignment of single data points. If weights is not nil it must
// have the same length as classes, otherwise TOC will panic. If classes is
// empty, all returned slices are nil.
//
// More details about TOC curves are available at
// https://en.wikipedia.org/wiki/Total_operating_characteristic
func TOC(classes []bool, weights []float64) (min, ntp, max []float64) {
	if weights != nil && len(classes) != len(weights) {
		panic("stat: slice length mismatch")
	}
	if len(classes) == 0 {
		return nil, nil, nil
	}

	ntp = make([]float64, len(classes)+1)
	min = make([]float64, len(ntp))
	max = make([]float64, len(ntp))
	if weights == nil {
		// Accumulate true counts walking classes from the last
		// element to the first, so ntp[i] covers the final i entries.
		for i := range ntp[1:] {
			ntp[i+1] = ntp[i]
			if classes[len(classes)-i-1] {
				ntp[i+1]++
			}
		}
		totalPositive := ntp[len(ntp)-1]
		for i := range ntp {
			// With i entries assigned true, at most len(classes)-i
			// positives can be left out, and at most i can be hit.
			min[i] = math.Max(0, totalPositive-float64(len(classes)-i))
			max[i] = math.Min(totalPositive, float64(i))
		}
		return min, ntp, max
	}

	cumw := max // Reuse max for cumulative weight. Update its elements last.
	for i := range ntp[1:] {
		ntp[i+1] = ntp[i]
		w := weights[len(weights)-i-1]
		cumw[i+1] = cumw[i] + w
		if classes[len(classes)-i-1] {
			ntp[i+1] += w
		}
	}
	totw := cumw[len(cumw)-1]
	totalPositive := ntp[len(ntp)-1]
	for i := range ntp {
		// cumw aliases max, so min[i] must read cumw[i] before
		// max[i] overwrites it.
		min[i] = math.Max(0, totalPositive-(totw-cumw[i]))
		max[i] = math.Min(totalPositive, cumw[i])
	}
	return min, ntp, max
}