github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/roc.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package stat
     6  
     7  import (
     8  	"math"
     9  	"sort"
    10  )
    11  
    12  // ROC returns paired false positive rate (FPR) and true positive rate
    13  // (TPR) values corresponding to cutoff points on the receiver operator
    14  // characteristic (ROC) curve obtained when y is treated as a binary
    15  // classifier for classes with weights. The cutoff thresholds used to
    16  // calculate the ROC are returned in thresh such that tpr[i] and fpr[i]
    17  // are the true and false positive rates for y >= thresh[i].
    18  //
    19  // The input y and cutoffs must be sorted, and values in y must correspond
    20  // to values in classes and weights. SortWeightedLabeled can be used to
    21  // sort y together with classes and weights.
    22  //
    23  // For a given cutoff value, observations corresponding to entries in y
    24  // greater than the cutoff value are classified as true, while those
    25  // less than or equal to the cutoff value are classified as false. These
    26  // assigned class labels are compared with the true values in the classes
    27  // slice and used to calculate the FPR and TPR.
    28  //
    29  // If weights is nil, all weights are treated as 1. If weights is not nil
    30  // it must have the same length as y and classes, otherwise ROC will panic.
    31  //
    32  // If cutoffs is nil or empty, all possible cutoffs are calculated,
    33  // resulting in fpr and tpr having length one greater than the number of
    34  // unique values in y. Otherwise fpr and tpr will be returned with the
    35  // same length as cutoffs. floats.Span can be used to generate equally
    36  // spaced cutoffs.
    37  //
    38  // More details about ROC curves are available at
    39  // https://en.wikipedia.org/wiki/Receiver_operating_characteristic
    40  func ROC(cutoffs, y []float64, classes []bool, weights []float64) (tpr, fpr, thresh []float64) {
    41  	if len(y) != len(classes) {
    42  		panic("stat: slice length mismatch")
    43  	}
    44  	if weights != nil && len(y) != len(weights) {
    45  		panic("stat: slice length mismatch")
    46  	}
    47  	if !sort.Float64sAreSorted(y) {
    48  		panic("stat: input must be sorted ascending")
    49  	}
    50  	if !sort.Float64sAreSorted(cutoffs) {
    51  		panic("stat: cutoff values must be sorted ascending")
    52  	}
    53  	if len(y) == 0 {
    54  		return nil, nil, nil
    55  	}
    56  	if len(cutoffs) == 0 {
    57  		if cutoffs == nil || cap(cutoffs) < len(y)+1 {
    58  			cutoffs = make([]float64, len(y)+1)
    59  		} else {
    60  			cutoffs = cutoffs[:len(y)+1]
    61  		}
    62  		// Choose all possible cutoffs for unique values in y.
    63  		bin := 0
    64  		cutoffs[bin] = y[0]
    65  		for i, u := range y[1:] {
    66  			if u == y[i] {
    67  				continue
    68  			}
    69  			bin++
    70  			cutoffs[bin] = u
    71  		}
    72  		cutoffs[bin+1] = math.Inf(1)
    73  		cutoffs = cutoffs[:bin+2]
    74  	} else {
    75  		// Don't mutate the provided cutoffs.
    76  		tmp := cutoffs
    77  		cutoffs = make([]float64, len(cutoffs))
    78  		copy(cutoffs, tmp)
    79  	}
    80  
    81  	tpr = make([]float64, len(cutoffs))
    82  	fpr = make([]float64, len(cutoffs))
    83  	var bin int
    84  	var nPos, nNeg float64
    85  	for i, u := range classes {
    86  		// Update the bin until it matches the next y value
    87  		// skipping empty bins.
    88  		for bin < len(cutoffs)-1 && y[i] >= cutoffs[bin] {
    89  			bin++
    90  			tpr[bin] = tpr[bin-1]
    91  			fpr[bin] = fpr[bin-1]
    92  		}
    93  		posWeight, negWeight := 1.0, 0.0
    94  		if weights != nil {
    95  			posWeight = weights[i]
    96  		}
    97  		if !u {
    98  			posWeight, negWeight = negWeight, posWeight
    99  		}
   100  		nPos += posWeight
   101  		nNeg += negWeight
   102  		// Count false negatives (in tpr) and true negatives (in fpr).
   103  		if y[i] < cutoffs[bin] {
   104  			tpr[bin] += posWeight
   105  			fpr[bin] += negWeight
   106  		}
   107  	}
   108  
   109  	invNeg := 1 / nNeg
   110  	invPos := 1 / nPos
   111  	// Convert negative counts to TPR and FPR.
   112  	// Bins beyond the maximum value in y are skipped
   113  	// leaving these fpr and tpr elements as zero.
   114  	for i := range tpr[:bin+1] {
   115  		// Prevent fused float operations by
   116  		// making explicit float64 conversions.
   117  		tpr[i] = 1 - float64(tpr[i]*invPos)
   118  		fpr[i] = 1 - float64(fpr[i]*invNeg)
   119  	}
   120  	for i, j := 0, len(tpr)-1; i < j; i, j = i+1, j-1 {
   121  		tpr[i], tpr[j] = tpr[j], tpr[i]
   122  		fpr[i], fpr[j] = fpr[j], fpr[i]
   123  	}
   124  	for i, j := 0, len(cutoffs)-1; i < j; i, j = i+1, j-1 {
   125  		cutoffs[i], cutoffs[j] = cutoffs[j], cutoffs[i]
   126  	}
   127  
   128  	return tpr, fpr, cutoffs
   129  }