gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/stat/roc.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package stat
     6  
     7  import (
     8  	"math"
     9  	"slices"
    10  	"sort"
    11  )
    12  
    13  // ROC returns paired false positive rate (FPR) and true positive rate
    14  // (TPR) values corresponding to cutoff points on the receiver operator
    15  // characteristic (ROC) curve obtained when y is treated as a binary
    16  // classifier for classes with weights. The cutoff thresholds used to
    17  // calculate the ROC are returned in thresh such that tpr[i] and fpr[i]
    18  // are the true and false positive rates for y >= thresh[i].
    19  //
    20  // The input y and cutoffs must be sorted, and values in y must correspond
    21  // to values in classes and weights. SortWeightedLabeled can be used to
    22  // sort y together with classes and weights.
    23  //
    24  // For a given cutoff value, observations corresponding to entries in y
    25  // greater than the cutoff value are classified as true, while those
    26  // less than or equal to the cutoff value are classified as false. These
    27  // assigned class labels are compared with the true values in the classes
    28  // slice and used to calculate the FPR and TPR.
    29  //
    30  // If weights is nil, all weights are treated as 1. If weights is not nil
    31  // it must have the same length as y and classes, otherwise ROC will panic.
    32  //
    33  // If cutoffs is nil or empty, all possible cutoffs are calculated,
    34  // resulting in fpr and tpr having length one greater than the number of
    35  // unique values in y. Otherwise fpr and tpr will be returned with the
    36  // same length as cutoffs. floats.Span can be used to generate equally
    37  // spaced cutoffs.
    38  //
    39  // More details about ROC curves are available at
    40  // https://en.wikipedia.org/wiki/Receiver_operating_characteristic
    41  func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
    42  	if len(y) != len(classes) {
    43  		panic("stat: slice length mismatch")
    44  	}
    45  	if weights != nil && len(y) != len(weights) {
    46  		panic("stat: slice length mismatch")
    47  	}
    48  	if !sort.Float64sAreSorted(y) {
    49  		panic("stat: input must be sorted ascending")
    50  	}
    51  	if !sort.Float64sAreSorted(cutoffs) {
    52  		panic("stat: cutoff values must be sorted ascending")
    53  	}
    54  	if len(y) == 0 {
    55  		return nil, nil, nil
    56  	}
    57  	if len(cutoffs) == 0 {
    58  		if cutoffs == nil || cap(cutoffs) < len(y)+1 {
    59  			cutoffs = make([]float64, len(y)+1)
    60  		} else {
    61  			cutoffs = cutoffs[:len(y)+1]
    62  		}
    63  		// Choose all possible cutoffs for unique values in y.
    64  		bin := 0
    65  		cutoffs[bin] = y[0]
    66  		for i, u := range y[1:] {
    67  			if u == y[i] {
    68  				continue
    69  			}
    70  			bin++
    71  			cutoffs[bin] = u
    72  		}
    73  		cutoffs[bin+1] = math.Inf(1)
    74  		cutoffs = cutoffs[:bin+2]
    75  	} else {
    76  		// Don't mutate the provided cutoffs.
    77  		tmp := cutoffs
    78  		cutoffs = make([]float64, len(cutoffs))
    79  		copy(cutoffs, tmp)
    80  	}
    81  
    82  	tpr = make([]float64, len(cutoffs))
    83  	fpr = make([]float64, len(cutoffs))
    84  	var bin int
    85  	var nPos, nNeg float64
    86  	for i, u := range classes {
    87  		// Update the bin until it matches the next y value
    88  		// skipping empty bins.
    89  		for bin < len(cutoffs)-1 && y[i] >= cutoffs[bin] {
    90  			bin++
    91  			tpr[bin] = tpr[bin-1]
    92  			fpr[bin] = fpr[bin-1]
    93  		}
    94  		posWeight, negWeight := 1.0, 0.0
    95  		if weights != nil {
    96  			posWeight = weights[i]
    97  		}
    98  		if !u {
    99  			posWeight, negWeight = negWeight, posWeight
   100  		}
   101  		nPos += posWeight
   102  		nNeg += negWeight
   103  		// Count false negatives (in tpr) and true negatives (in fpr).
   104  		if y[i] < cutoffs[bin] {
   105  			tpr[bin] += posWeight
   106  			fpr[bin] += negWeight
   107  		}
   108  	}
   109  
   110  	invNeg := 1 / nNeg
   111  	invPos := 1 / nPos
   112  	// Convert negative counts to TPR and FPR.
   113  	// Bins beyond the maximum value in y are skipped
   114  	// leaving these fpr and tpr elements as zero.
   115  	for i := range tpr[:bin+1] {
   116  		// Prevent fused float operations by
   117  		// making explicit float64 conversions.
   118  		tpr[i] = 1 - float64(tpr[i]*invPos)
   119  		fpr[i] = 1 - float64(fpr[i]*invNeg)
   120  	}
   121  	slices.Reverse(tpr)
   122  	slices.Reverse(fpr)
   123  	slices.Reverse(cutoffs)
   124  
   125  	return tpr, fpr, cutoffs
   126  }
   127  
   128  // TOC returns the Total Operating Characteristic for the classes provided
   129  // and the minimum and maximum bounds for the TOC.
   130  //
   131  // The input y values that correspond to classes and weights must be sorted
   132  // in ascending order. classes[i] is the class of value y[i] and weights[i]
   133  // is the weight of y[i]. SortWeightedLabeled can be used to sort classes
   134  // together with weights by the rank variable, i+1.
   135  //
   136  // The returned ntp values can be interpreted as the number of true positives
   137  // where values above the given rank are assigned class true for each given
   138  // rank from 1 to len(classes).
   139  //
   140  //	ntp_i = sum_{j ≥ len(ntp)-1 - i} [ classes_j ] * weights_j, where [x] = 1 if x else 0.
   141  //
   142  // The values of min and max provide the minimum and maximum possible number
   143  // of false values for the set of classes. The first element of ntp, min and
   144  // max are always zero as this corresponds to assigning all data class false
   145  // and the last elements are always weighted sum of classes as this corresponds
   146  // to assigning every data class true. For len(classes) != 0, the lengths of
   147  // min, ntp and max are len(classes)+1.
   148  //
   149  // If weights is nil, all weights are treated as 1. When weights are not nil,
   150  // the calculation of min and max allows for partial assignment of single data
   151  // points. If weights is not nil it must have the same length as classes,
   152  // otherwise TOC will panic.
   153  //
   154  // More details about TOC curves are available at
   155  // https://en.wikipedia.org/wiki/Total_operating_characteristic
   156  func TOC(classes []bool, weights []float64) (min, ntp, max []float64) {
   157  	if weights != nil && len(classes) != len(weights) {
   158  		panic("stat: slice length mismatch")
   159  	}
   160  	if len(classes) == 0 {
   161  		return nil, nil, nil
   162  	}
   163  
   164  	ntp = make([]float64, len(classes)+1)
   165  	min = make([]float64, len(ntp))
   166  	max = make([]float64, len(ntp))
   167  	if weights == nil {
   168  		for i := range ntp[1:] {
   169  			ntp[i+1] = ntp[i]
   170  			if classes[len(classes)-i-1] {
   171  				ntp[i+1]++
   172  			}
   173  		}
   174  		totalPositive := ntp[len(ntp)-1]
   175  		for i := range ntp {
   176  			min[i] = math.Max(0, totalPositive-float64(len(classes)-i))
   177  			max[i] = math.Min(totalPositive, float64(i))
   178  		}
   179  		return min, ntp, max
   180  	}
   181  
   182  	cumw := max // Reuse max for cumulative weight. Update its elements last.
   183  	for i := range ntp[1:] {
   184  		ntp[i+1] = ntp[i]
   185  		w := weights[len(weights)-i-1]
   186  		cumw[i+1] = cumw[i] + w
   187  		if classes[len(classes)-i-1] {
   188  			ntp[i+1] += w
   189  		}
   190  	}
   191  	totw := cumw[len(cumw)-1]
   192  	totalPositive := ntp[len(ntp)-1]
   193  	for i := range ntp {
   194  		min[i] = math.Max(0, totalPositive-(totw-cumw[i]))
   195  		max[i] = math.Min(totalPositive, cumw[i])
   196  	}
   197  	return min, ntp, max
   198  }