github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/roc_test.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package stat
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"github.com/jingcheng-WU/gonum/floats"
    12  )
    13  
    14  func TestROC(t *testing.T) {
    15  	const tol = 1e-14
    16  
    17  	cases := []struct {
    18  		y          []float64
    19  		c          []bool
    20  		w          []float64
    21  		cutoffs    []float64
    22  		wantTPR    []float64
    23  		wantFPR    []float64
    24  		wantThresh []float64
    25  	}{
    26  		// Test cases were informed by using sklearn metrics.roc_curve when
    27  		// cutoffs is nil, but all test cases (including when cutoffs is not
    28  		// nil) were calculated manually.
    29  		// Some differences exist between unweighted ROCs from our function
    30  		// and metrics.roc_curve which appears to use integer cutoffs in that
    31  		// case. sklearn also appears to do some magic that trims leading zeros
    32  		// sometimes.
    33  		{ // 0
    34  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    35  			c:          []bool{false, true, false, true, true, true},
    36  			wantTPR:    []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
    37  			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1},
    38  			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
    39  		},
    40  		{ // 1
    41  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    42  			c:          []bool{false, true, false, true, true, true},
    43  			w:          []float64{4, 1, 6, 3, 2, 2},
    44  			wantTPR:    []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1},
    45  			wantFPR:    []float64{0, 0, 0, 0, 0.6, 0.6, 1},
    46  			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
    47  		},
    48  		{ // 2
    49  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    50  			c:          []bool{false, true, false, true, true, true},
    51  			cutoffs:    []float64{-1, 2, 4, 6, 8},
    52  			wantTPR:    []float64{0.25, 0.75, 0.75, 1, 1},
    53  			wantFPR:    []float64{0, 0, 0.5, 0.5, 1},
    54  			wantThresh: []float64{8, 6, 4, 2, -1},
    55  		},
    56  		{ // 3
    57  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    58  			c:          []bool{false, true, false, true, true, true},
    59  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
    60  			wantTPR:    []float64{0.25, 0.5, 0.75, 0.75, 0.75, 1, 1, 1, 1},
    61  			wantFPR:    []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
    62  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
    63  		},
    64  		{ // 4
    65  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    66  			c:          []bool{false, true, false, true, true, true},
    67  			w:          []float64{4, 1, 6, 3, 2, 2},
    68  			cutoffs:    []float64{-1, 2, 4, 6, 8},
    69  			wantTPR:    []float64{0.25, 0.875, 0.875, 1, 1},
    70  			wantFPR:    []float64{0, 0, 0.6, 0.6, 1},
    71  			wantThresh: []float64{8, 6, 4, 2, -1},
    72  		},
    73  		{ // 5
    74  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    75  			c:          []bool{false, true, false, true, true, true},
    76  			w:          []float64{4, 1, 6, 3, 2, 2},
    77  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
    78  			wantTPR:    []float64{0.25, 0.5, 0.875, 0.875, 0.875, 1, 1, 1, 1},
    79  			wantFPR:    []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
    80  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
    81  		},
    82  		{ // 6
    83  			y:          []float64{0, 3, 6, 6, 6, 8},
    84  			c:          []bool{false, true, false, true, true, true},
    85  			wantTPR:    []float64{0, 0.25, 0.75, 1, 1},
    86  			wantFPR:    []float64{0, 0, 0.5, 0.5, 1},
    87  			wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
    88  		},
    89  		{ // 7
    90  			y:          []float64{0, 3, 6, 6, 6, 8},
    91  			c:          []bool{false, true, false, true, true, true},
    92  			w:          []float64{4, 1, 6, 3, 2, 2},
    93  			wantTPR:    []float64{0, 0.25, 0.875, 1, 1},
    94  			wantFPR:    []float64{0, 0, 0.6, 0.6, 1},
    95  			wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
    96  		},
    97  		{ // 8
    98  			y:          []float64{0, 3, 6, 6, 6, 8},
    99  			c:          []bool{false, true, false, true, true, true},
   100  			cutoffs:    []float64{-1, 2, 4, 6, 8},
   101  			wantTPR:    []float64{0.25, 0.75, 0.75, 1, 1},
   102  			wantFPR:    []float64{0, 0.5, 0.5, 0.5, 1},
   103  			wantThresh: []float64{8, 6, 4, 2, -1},
   104  		},
   105  		{ // 9
   106  			y:          []float64{0, 3, 6, 6, 6, 8},
   107  			c:          []bool{false, true, false, true, true, true},
   108  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
   109  			wantTPR:    []float64{0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1, 1},
   110  			wantFPR:    []float64{0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
   111  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
   112  		},
   113  		{ // 10
   114  			y:          []float64{0, 3, 6, 6, 6, 8},
   115  			c:          []bool{false, true, false, true, true, true},
   116  			w:          []float64{4, 1, 6, 3, 2, 2},
   117  			cutoffs:    []float64{-1, 2, 4, 6, 8},
   118  			wantTPR:    []float64{0.25, 0.875, 0.875, 1, 1},
   119  			wantFPR:    []float64{0, 0.6, 0.6, 0.6, 1},
   120  			wantThresh: []float64{8, 6, 4, 2, -1},
   121  		},
   122  		{ // 11
   123  			y:          []float64{0, 3, 6, 6, 6, 8},
   124  			c:          []bool{false, true, false, true, true, true},
   125  			w:          []float64{4, 1, 6, 3, 2, 2},
   126  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
   127  			wantTPR:    []float64{0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1, 1},
   128  			wantFPR:    []float64{0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
   129  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
   130  		},
   131  		{ // 12
   132  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   133  			c:          []bool{true, false, true, false},
   134  			wantTPR:    []float64{0, 0, 0.5, 0.5, 1},
   135  			wantFPR:    []float64{0, 0.5, 0.5, 1, 1},
   136  			wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
   137  		},
   138  		{ // 13
   139  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   140  			c:          []bool{false, false, true, true},
   141  			wantTPR:    []float64{0, 0.5, 1, 1, 1},
   142  			wantFPR:    []float64{0, 0, 0, 0.5, 1},
   143  			wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
   144  		},
   145  		{ // 14
   146  			y:          []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
   147  			c:          []bool{false, true, false, false, true, true, false},
   148  			cutoffs:    []float64{-1, 2.5, 5, 7.5, 10},
   149  			wantTPR:    []float64{0, 0, 0, 0, 1},
   150  			wantFPR:    []float64{0.25, 0.25, 0.25, 0.25, 1},
   151  			wantThresh: []float64{10, 7.5, 5, 2.5, -1},
   152  		},
   153  		{ // 15
   154  			y:          []float64{1, 2},
   155  			c:          []bool{false, false},
   156  			wantTPR:    []float64{math.NaN(), math.NaN(), math.NaN()},
   157  			wantFPR:    []float64{0, 0.5, 1},
   158  			wantThresh: []float64{math.Inf(1), 2, 1},
   159  		},
   160  		{ // 16
   161  			y:          []float64{1, 2},
   162  			c:          []bool{false, false},
   163  			cutoffs:    []float64{-1, 2},
   164  			wantTPR:    []float64{math.NaN(), math.NaN()},
   165  			wantFPR:    []float64{0.5, 1},
   166  			wantThresh: []float64{2, -1},
   167  		},
   168  		{ // 17
   169  			y:          []float64{1, 2},
   170  			c:          []bool{false, false},
   171  			cutoffs:    []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
   172  			wantTPR:    []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
   173  			wantFPR:    []float64{0.5, 0.5, 0.5, 0.5, 0.5, 1},
   174  			wantThresh: []float64{2, 1.8, 1.6, 1.4, 1.2, 0},
   175  		},
   176  		{ // 18
   177  			y:          []float64{1},
   178  			c:          []bool{false},
   179  			wantTPR:    []float64{math.NaN(), math.NaN()},
   180  			wantFPR:    []float64{0, 1},
   181  			wantThresh: []float64{math.Inf(1), 1},
   182  		},
   183  		{ // 19
   184  			y:          []float64{1},
   185  			c:          []bool{false},
   186  			cutoffs:    []float64{-1, 1},
   187  			wantTPR:    []float64{math.NaN(), math.NaN()},
   188  			wantFPR:    []float64{1, 1},
   189  			wantThresh: []float64{1, -1},
   190  		},
   191  		{ // 20
   192  			y:          []float64{1},
   193  			c:          []bool{true},
   194  			wantTPR:    []float64{0, 1},
   195  			wantFPR:    []float64{math.NaN(), math.NaN()},
   196  			wantThresh: []float64{math.Inf(1), 1},
   197  		},
   198  		{ // 21
   199  			y:          []float64{},
   200  			c:          []bool{},
   201  			wantTPR:    nil,
   202  			wantFPR:    nil,
   203  			wantThresh: nil,
   204  		},
   205  		{ // 22
   206  			y:          []float64{},
   207  			c:          []bool{},
   208  			cutoffs:    []float64{-1, 2.5, 5, 7.5, 10},
   209  			wantTPR:    nil,
   210  			wantFPR:    nil,
   211  			wantThresh: nil,
   212  		},
   213  		{ // 23
   214  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   215  			c:          []bool{true, false, true, false},
   216  			cutoffs:    []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1},
   217  			wantTPR:    []float64{0, 0, 0, 0.5, 0.5, 1, 1},
   218  			wantFPR:    []float64{0, 0, 0.5, 0.5, 1, 1, 1},
   219  			wantThresh: []float64{1, 0.9, 0.8, 0.4, 0.35, 0.1, -1},
   220  		},
   221  		{ // 24
   222  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   223  			c:          []bool{true, false, true, false},
   224  			cutoffs:    []float64{math.Inf(-1), 0.1, 0.36, 0.8},
   225  			wantTPR:    []float64{0, 0.5, 1, 1},
   226  			wantFPR:    []float64{0.5, 0.5, 1, 1},
   227  			wantThresh: []float64{0.8, 0.36, 0.1, math.Inf(-1)},
   228  		},
   229  		{ // 25
   230  			y:          []float64{0, 3, 5, 6, 7.5, 8},
   231  			c:          []bool{false, true, false, true, true, true},
   232  			cutoffs:    make([]float64, 0, 10),
   233  			wantTPR:    []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
   234  			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1},
   235  			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
   236  		},
   237  		{ // 26
   238  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   239  			c:          []bool{true, false, true, false},
   240  			cutoffs:    []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1, 1.1, 1.2},
   241  			wantTPR:    []float64{0, 0, 0, 0, 0, 0.5, 0.5, 1, 1},
   242  			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1, 1, 1},
   243  			wantThresh: []float64{1.2, 1.1, 1, 0.9, 0.8, 0.4, 0.35, 0.1, -1},
   244  		},
   245  	}
   246  	for i, test := range cases {
   247  		gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w)
   248  		if !floats.Same(gotTPR, test.wantTPR) && !floats.EqualApprox(gotTPR, test.wantTPR, tol) {
   249  			t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR)
   250  		}
   251  		if !floats.Same(gotFPR, test.wantFPR) && !floats.EqualApprox(gotFPR, test.wantFPR, tol) {
   252  			t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR)
   253  		}
   254  		if !floats.Same(gotThresh, test.wantThresh) {
   255  			t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh)
   256  		}
   257  	}
   258  }