github.com/gopherd/gonum@v0.0.4/stat/roc_test.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package stat
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"github.com/gopherd/gonum/floats"
    12  )
    13  
    14  func TestROC(t *testing.T) {
    15  	const tol = 1e-14
    16  
    17  	cases := []struct {
    18  		y          []float64
    19  		c          []bool
    20  		w          []float64
    21  		cutoffs    []float64
    22  		wantTPR    []float64
    23  		wantFPR    []float64
    24  		wantThresh []float64
    25  	}{
    26  		// Test cases were informed by using sklearn metrics.roc_curve when
    27  		// cutoffs is nil, but all test cases (including when cutoffs is not
    28  		// nil) were calculated manually.
    29  		// Some differences exist between unweighted ROCs from our function
    30  		// and metrics.roc_curve which appears to use integer cutoffs in that
    31  		// case. sklearn also appears to do some magic that trims leading zeros
    32  		// sometimes.
    33  		{ // 0
    34  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    35  			c:          []bool{false, true, false, true, true, true},
    36  			wantTPR:    []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
    37  			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1},
    38  			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
    39  		},
    40  		{ // 1
    41  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    42  			c:          []bool{false, true, false, true, true, true},
    43  			w:          []float64{4, 1, 6, 3, 2, 2},
    44  			wantTPR:    []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1},
    45  			wantFPR:    []float64{0, 0, 0, 0, 0.6, 0.6, 1},
    46  			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
    47  		},
    48  		{ // 2
    49  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    50  			c:          []bool{false, true, false, true, true, true},
    51  			cutoffs:    []float64{-1, 2, 4, 6, 8},
    52  			wantTPR:    []float64{0.25, 0.75, 0.75, 1, 1},
    53  			wantFPR:    []float64{0, 0, 0.5, 0.5, 1},
    54  			wantThresh: []float64{8, 6, 4, 2, -1},
    55  		},
    56  		{ // 3
    57  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    58  			c:          []bool{false, true, false, true, true, true},
    59  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
    60  			wantTPR:    []float64{0.25, 0.5, 0.75, 0.75, 0.75, 1, 1, 1, 1},
    61  			wantFPR:    []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
    62  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
    63  		},
    64  		{ // 4
    65  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    66  			c:          []bool{false, true, false, true, true, true},
    67  			w:          []float64{4, 1, 6, 3, 2, 2},
    68  			cutoffs:    []float64{-1, 2, 4, 6, 8},
    69  			wantTPR:    []float64{0.25, 0.875, 0.875, 1, 1},
    70  			wantFPR:    []float64{0, 0, 0.6, 0.6, 1},
    71  			wantThresh: []float64{8, 6, 4, 2, -1},
    72  		},
    73  		{ // 5
    74  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    75  			c:          []bool{false, true, false, true, true, true},
    76  			w:          []float64{4, 1, 6, 3, 2, 2},
    77  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
    78  			wantTPR:    []float64{0.25, 0.5, 0.875, 0.875, 0.875, 1, 1, 1, 1},
    79  			wantFPR:    []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
    80  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
    81  		},
    82  		{ // 6
    83  			y:          []float64{0, 3, 6, 6, 6, 8},
    84  			c:          []bool{false, true, false, true, true, true},
    85  			wantTPR:    []float64{0, 0.25, 0.75, 1, 1},
    86  			wantFPR:    []float64{0, 0, 0.5, 0.5, 1},
    87  			wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
    88  		},
    89  		{ // 7
    90  			y:          []float64{0, 3, 6, 6, 6, 8},
    91  			c:          []bool{false, true, false, true, true, true},
    92  			w:          []float64{4, 1, 6, 3, 2, 2},
    93  			wantTPR:    []float64{0, 0.25, 0.875, 1, 1},
    94  			wantFPR:    []float64{0, 0, 0.6, 0.6, 1},
    95  			wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
    96  		},
    97  		{ // 8
    98  			y:          []float64{0, 3, 6, 6, 6, 8},
    99  			c:          []bool{false, true, false, true, true, true},
   100  			cutoffs:    []float64{-1, 2, 4, 6, 8},
   101  			wantTPR:    []float64{0.25, 0.75, 0.75, 1, 1},
   102  			wantFPR:    []float64{0, 0.5, 0.5, 0.5, 1},
   103  			wantThresh: []float64{8, 6, 4, 2, -1},
   104  		},
   105  		{ // 9
   106  			y:          []float64{0, 3, 6, 6, 6, 8},
   107  			c:          []bool{false, true, false, true, true, true},
   108  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
   109  			wantTPR:    []float64{0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1, 1},
   110  			wantFPR:    []float64{0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
   111  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
   112  		},
   113  		{ // 10
   114  			y:          []float64{0, 3, 6, 6, 6, 8},
   115  			c:          []bool{false, true, false, true, true, true},
   116  			w:          []float64{4, 1, 6, 3, 2, 2},
   117  			cutoffs:    []float64{-1, 2, 4, 6, 8},
   118  			wantTPR:    []float64{0.25, 0.875, 0.875, 1, 1},
   119  			wantFPR:    []float64{0, 0.6, 0.6, 0.6, 1},
   120  			wantThresh: []float64{8, 6, 4, 2, -1},
   121  		},
   122  		{ // 11
   123  			y:          []float64{0, 3, 6, 6, 6, 8},
   124  			c:          []bool{false, true, false, true, true, true},
   125  			w:          []float64{4, 1, 6, 3, 2, 2},
   126  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
   127  			wantTPR:    []float64{0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1, 1},
   128  			wantFPR:    []float64{0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
   129  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
   130  		},
   131  		{ // 12
   132  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   133  			c:          []bool{true, false, true, false},
   134  			wantTPR:    []float64{0, 0, 0.5, 0.5, 1},
   135  			wantFPR:    []float64{0, 0.5, 0.5, 1, 1},
   136  			wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
   137  		},
   138  		{ // 13
   139  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   140  			c:          []bool{false, false, true, true},
   141  			wantTPR:    []float64{0, 0.5, 1, 1, 1},
   142  			wantFPR:    []float64{0, 0, 0, 0.5, 1},
   143  			wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
   144  		},
   145  		{ // 14
   146  			y:          []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
   147  			c:          []bool{false, true, false, false, true, true, false},
   148  			cutoffs:    []float64{-1, 2.5, 5, 7.5, 10},
   149  			wantTPR:    []float64{0, 0, 0, 0, 1},
   150  			wantFPR:    []float64{0.25, 0.25, 0.25, 0.25, 1},
   151  			wantThresh: []float64{10, 7.5, 5, 2.5, -1},
   152  		},
   153  		{ // 15
   154  			y:          []float64{1, 2},
   155  			c:          []bool{false, false},
   156  			wantTPR:    []float64{math.NaN(), math.NaN(), math.NaN()},
   157  			wantFPR:    []float64{0, 0.5, 1},
   158  			wantThresh: []float64{math.Inf(1), 2, 1},
   159  		},
   160  		{ // 16
   161  			y:          []float64{1, 2},
   162  			c:          []bool{false, false},
   163  			cutoffs:    []float64{-1, 2},
   164  			wantTPR:    []float64{math.NaN(), math.NaN()},
   165  			wantFPR:    []float64{0.5, 1},
   166  			wantThresh: []float64{2, -1},
   167  		},
   168  		{ // 17
   169  			y:          []float64{1, 2},
   170  			c:          []bool{false, false},
   171  			cutoffs:    []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
   172  			wantTPR:    []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
   173  			wantFPR:    []float64{0.5, 0.5, 0.5, 0.5, 0.5, 1},
   174  			wantThresh: []float64{2, 1.8, 1.6, 1.4, 1.2, 0},
   175  		},
   176  		{ // 18
   177  			y:          []float64{1},
   178  			c:          []bool{false},
   179  			wantTPR:    []float64{math.NaN(), math.NaN()},
   180  			wantFPR:    []float64{0, 1},
   181  			wantThresh: []float64{math.Inf(1), 1},
   182  		},
   183  		{ // 19
   184  			y:          []float64{1},
   185  			c:          []bool{false},
   186  			cutoffs:    []float64{-1, 1},
   187  			wantTPR:    []float64{math.NaN(), math.NaN()},
   188  			wantFPR:    []float64{1, 1},
   189  			wantThresh: []float64{1, -1},
   190  		},
   191  		{ // 20
   192  			y:          []float64{1},
   193  			c:          []bool{true},
   194  			wantTPR:    []float64{0, 1},
   195  			wantFPR:    []float64{math.NaN(), math.NaN()},
   196  			wantThresh: []float64{math.Inf(1), 1},
   197  		},
   198  		{ // 21
   199  			y:          []float64{},
   200  			c:          []bool{},
   201  			wantTPR:    nil,
   202  			wantFPR:    nil,
   203  			wantThresh: nil,
   204  		},
   205  		{ // 22
   206  			y:          []float64{},
   207  			c:          []bool{},
   208  			cutoffs:    []float64{-1, 2.5, 5, 7.5, 10},
   209  			wantTPR:    nil,
   210  			wantFPR:    nil,
   211  			wantThresh: nil,
   212  		},
   213  		{ // 23
   214  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   215  			c:          []bool{true, false, true, false},
   216  			cutoffs:    []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1},
   217  			wantTPR:    []float64{0, 0, 0, 0.5, 0.5, 1, 1},
   218  			wantFPR:    []float64{0, 0, 0.5, 0.5, 1, 1, 1},
   219  			wantThresh: []float64{1, 0.9, 0.8, 0.4, 0.35, 0.1, -1},
   220  		},
   221  		{ // 24
   222  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   223  			c:          []bool{true, false, true, false},
   224  			cutoffs:    []float64{math.Inf(-1), 0.1, 0.36, 0.8},
   225  			wantTPR:    []float64{0, 0.5, 1, 1},
   226  			wantFPR:    []float64{0.5, 0.5, 1, 1},
   227  			wantThresh: []float64{0.8, 0.36, 0.1, math.Inf(-1)},
   228  		},
   229  		{ // 25
   230  			y:          []float64{0, 3, 5, 6, 7.5, 8},
   231  			c:          []bool{false, true, false, true, true, true},
   232  			cutoffs:    make([]float64, 0, 10),
   233  			wantTPR:    []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
   234  			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1},
   235  			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
   236  		},
   237  		{ // 26
   238  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   239  			c:          []bool{true, false, true, false},
   240  			cutoffs:    []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1, 1.1, 1.2},
   241  			wantTPR:    []float64{0, 0, 0, 0, 0, 0.5, 0.5, 1, 1},
   242  			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1, 1, 1},
   243  			wantThresh: []float64{1.2, 1.1, 1, 0.9, 0.8, 0.4, 0.35, 0.1, -1},
   244  		},
   245  	}
   246  	for i, test := range cases {
   247  		gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w)
   248  		if !floats.Same(gotTPR, test.wantTPR) && !floats.EqualApprox(gotTPR, test.wantTPR, tol) {
   249  			t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR)
   250  		}
   251  		if !floats.Same(gotFPR, test.wantFPR) && !floats.EqualApprox(gotFPR, test.wantFPR, tol) {
   252  			t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR)
   253  		}
   254  		if !floats.Same(gotThresh, test.wantThresh) {
   255  			t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh)
   256  		}
   257  	}
   258  }
   259  
   260  func TestTOC(t *testing.T) {
   261  	cases := []struct {
   262  		c       []bool
   263  		w       []float64
   264  		wantMin []float64
   265  		wantMax []float64
   266  		wantTOC []float64
   267  	}{
   268  		{ // 0
   269  			// This is the example given in the paper's supplement.
   270  			// http://www2.clarku.edu/~rpontius/TOCexample2.xlsx
   271  			// It is also shown in the WP article.
   272  			// https://en.wikipedia.org/wiki/Total_operating_characteristic#/media/File:TOC_labeled.png
   273  			c: []bool{
   274  				false, false, false, false, false, false,
   275  				false, false, false, false, false, false,
   276  				false, false, true, true, true, true,
   277  				true, true, true, false, false, true,
   278  				false, true, false, false, true, false,
   279  			},
   280  			wantMin: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
   281  			wantMax: []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10},
   282  			wantTOC: []float64{0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10},
   283  		},
   284  		{ // 1
   285  			c:       []bool{},
   286  			wantMin: nil,
   287  			wantMax: nil,
   288  			wantTOC: nil,
   289  		},
   290  		{ // 2
   291  			c: []bool{
   292  				true, true, true, true, true,
   293  			},
   294  			wantMin: []float64{0, 1, 2, 3, 4, 5},
   295  			wantMax: []float64{0, 1, 2, 3, 4, 5},
   296  			wantTOC: []float64{0, 1, 2, 3, 4, 5},
   297  		},
   298  		{ // 3
   299  			c: []bool{
   300  				false, false, false, false, false,
   301  			},
   302  			wantMin: []float64{0, 0, 0, 0, 0, 0},
   303  			wantMax: []float64{0, 0, 0, 0, 0, 0},
   304  			wantTOC: []float64{0, 0, 0, 0, 0, 0},
   305  		},
   306  		{ // 4
   307  			c:       []bool{false, false, false, true, false, true},
   308  			w:       []float64{2, 2, 3, 6, 1, 4},
   309  			wantMin: []float64{0, 0, 0, 3, 6, 8, 10},
   310  			wantMax: []float64{0, 4, 5, 10, 10, 10, 10},
   311  			wantTOC: []float64{0, 4, 4, 10, 10, 10, 10},
   312  		},
   313  	}
   314  	for i, test := range cases {
   315  		gotMin, gotTOC, gotMax := TOC(test.c, test.w)
   316  		if !floats.Same(gotMin, test.wantMin) {
   317  			t.Errorf("%d: unexpected minimum bound got:%v want:%v", i, gotMin, test.wantMin)
   318  		}
   319  		if !floats.Same(gotMax, test.wantMax) {
   320  			t.Errorf("%d: unexpected maximum bound got:%v want:%v", i, gotMax, test.wantMax)
   321  		}
   322  		if !floats.Same(gotTOC, test.wantTOC) {
   323  			t.Errorf("%d: unexpected TOC got:%v want:%v", i, gotTOC, test.wantTOC)
   324  		}
   325  	}
   326  }