gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/stat/roc_test.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package stat
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"slices"
    11  	"testing"
    12  
    13  	"golang.org/x/exp/rand"
    14  
    15  	"gonum.org/v1/gonum/floats"
    16  )
    17  
    18  func TestROC(t *testing.T) {
    19  	const tol = 1e-14
    20  
    21  	cases := []struct {
    22  		y          []float64
    23  		c          []bool
    24  		w          []float64
    25  		cutoffs    []float64
    26  		wantTPR    []float64
    27  		wantFPR    []float64
    28  		wantThresh []float64
    29  	}{
    30  		// Test cases were informed by using sklearn metrics.roc_curve when
    31  		// cutoffs is nil, but all test cases (including when cutoffs is not
    32  		// nil) were calculated manually.
    33  		// Some differences exist between unweighted ROCs from our function
    34  		// and metrics.roc_curve which appears to use integer cutoffs in that
    35  		// case. sklearn also appears to do some magic that trims leading zeros
    36  		// sometimes.
    37  		{ // 0
    38  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    39  			c:          []bool{false, true, false, true, true, true},
    40  			wantTPR:    []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
    41  			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1},
    42  			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
    43  		},
    44  		{ // 1
    45  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    46  			c:          []bool{false, true, false, true, true, true},
    47  			w:          []float64{4, 1, 6, 3, 2, 2},
    48  			wantTPR:    []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1},
    49  			wantFPR:    []float64{0, 0, 0, 0, 0.6, 0.6, 1},
    50  			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
    51  		},
    52  		{ // 2
    53  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    54  			c:          []bool{false, true, false, true, true, true},
    55  			cutoffs:    []float64{-1, 2, 4, 6, 8},
    56  			wantTPR:    []float64{0.25, 0.75, 0.75, 1, 1},
    57  			wantFPR:    []float64{0, 0, 0.5, 0.5, 1},
    58  			wantThresh: []float64{8, 6, 4, 2, -1},
    59  		},
    60  		{ // 3
    61  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    62  			c:          []bool{false, true, false, true, true, true},
    63  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
    64  			wantTPR:    []float64{0.25, 0.5, 0.75, 0.75, 0.75, 1, 1, 1, 1},
    65  			wantFPR:    []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
    66  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
    67  		},
    68  		{ // 4
    69  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    70  			c:          []bool{false, true, false, true, true, true},
    71  			w:          []float64{4, 1, 6, 3, 2, 2},
    72  			cutoffs:    []float64{-1, 2, 4, 6, 8},
    73  			wantTPR:    []float64{0.25, 0.875, 0.875, 1, 1},
    74  			wantFPR:    []float64{0, 0, 0.6, 0.6, 1},
    75  			wantThresh: []float64{8, 6, 4, 2, -1},
    76  		},
    77  		{ // 5
    78  			y:          []float64{0, 3, 5, 6, 7.5, 8},
    79  			c:          []bool{false, true, false, true, true, true},
    80  			w:          []float64{4, 1, 6, 3, 2, 2},
    81  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
    82  			wantTPR:    []float64{0.25, 0.5, 0.875, 0.875, 0.875, 1, 1, 1, 1},
    83  			wantFPR:    []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
    84  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
    85  		},
    86  		{ // 6
    87  			y:          []float64{0, 3, 6, 6, 6, 8},
    88  			c:          []bool{false, true, false, true, true, true},
    89  			wantTPR:    []float64{0, 0.25, 0.75, 1, 1},
    90  			wantFPR:    []float64{0, 0, 0.5, 0.5, 1},
    91  			wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
    92  		},
    93  		{ // 7
    94  			y:          []float64{0, 3, 6, 6, 6, 8},
    95  			c:          []bool{false, true, false, true, true, true},
    96  			w:          []float64{4, 1, 6, 3, 2, 2},
    97  			wantTPR:    []float64{0, 0.25, 0.875, 1, 1},
    98  			wantFPR:    []float64{0, 0, 0.6, 0.6, 1},
    99  			wantThresh: []float64{math.Inf(1), 8, 6, 3, 0},
   100  		},
   101  		{ // 8
   102  			y:          []float64{0, 3, 6, 6, 6, 8},
   103  			c:          []bool{false, true, false, true, true, true},
   104  			cutoffs:    []float64{-1, 2, 4, 6, 8},
   105  			wantTPR:    []float64{0.25, 0.75, 0.75, 1, 1},
   106  			wantFPR:    []float64{0, 0.5, 0.5, 0.5, 1},
   107  			wantThresh: []float64{8, 6, 4, 2, -1},
   108  		},
   109  		{ // 9
   110  			y:          []float64{0, 3, 6, 6, 6, 8},
   111  			c:          []bool{false, true, false, true, true, true},
   112  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
   113  			wantTPR:    []float64{0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1, 1},
   114  			wantFPR:    []float64{0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1},
   115  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
   116  		},
   117  		{ // 10
   118  			y:          []float64{0, 3, 6, 6, 6, 8},
   119  			c:          []bool{false, true, false, true, true, true},
   120  			w:          []float64{4, 1, 6, 3, 2, 2},
   121  			cutoffs:    []float64{-1, 2, 4, 6, 8},
   122  			wantTPR:    []float64{0.25, 0.875, 0.875, 1, 1},
   123  			wantFPR:    []float64{0, 0.6, 0.6, 0.6, 1},
   124  			wantThresh: []float64{8, 6, 4, 2, -1},
   125  		},
   126  		{ // 11
   127  			y:          []float64{0, 3, 6, 6, 6, 8},
   128  			c:          []bool{false, true, false, true, true, true},
   129  			w:          []float64{4, 1, 6, 3, 2, 2},
   130  			cutoffs:    []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8},
   131  			wantTPR:    []float64{0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1, 1},
   132  			wantFPR:    []float64{0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 1},
   133  			wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1},
   134  		},
   135  		{ // 12
   136  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   137  			c:          []bool{true, false, true, false},
   138  			wantTPR:    []float64{0, 0, 0.5, 0.5, 1},
   139  			wantFPR:    []float64{0, 0.5, 0.5, 1, 1},
   140  			wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
   141  		},
   142  		{ // 13
   143  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   144  			c:          []bool{false, false, true, true},
   145  			wantTPR:    []float64{0, 0.5, 1, 1, 1},
   146  			wantFPR:    []float64{0, 0, 0, 0.5, 1},
   147  			wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1},
   148  		},
   149  		{ // 14
   150  			y:          []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10},
   151  			c:          []bool{false, true, false, false, true, true, false},
   152  			cutoffs:    []float64{-1, 2.5, 5, 7.5, 10},
   153  			wantTPR:    []float64{0, 0, 0, 0, 1},
   154  			wantFPR:    []float64{0.25, 0.25, 0.25, 0.25, 1},
   155  			wantThresh: []float64{10, 7.5, 5, 2.5, -1},
   156  		},
   157  		{ // 15
   158  			y:          []float64{1, 2},
   159  			c:          []bool{false, false},
   160  			wantTPR:    []float64{math.NaN(), math.NaN(), math.NaN()},
   161  			wantFPR:    []float64{0, 0.5, 1},
   162  			wantThresh: []float64{math.Inf(1), 2, 1},
   163  		},
   164  		{ // 16
   165  			y:          []float64{1, 2},
   166  			c:          []bool{false, false},
   167  			cutoffs:    []float64{-1, 2},
   168  			wantTPR:    []float64{math.NaN(), math.NaN()},
   169  			wantFPR:    []float64{0.5, 1},
   170  			wantThresh: []float64{2, -1},
   171  		},
   172  		{ // 17
   173  			y:          []float64{1, 2},
   174  			c:          []bool{false, false},
   175  			cutoffs:    []float64{0, 1.2, 1.4, 1.6, 1.8, 2},
   176  			wantTPR:    []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()},
   177  			wantFPR:    []float64{0.5, 0.5, 0.5, 0.5, 0.5, 1},
   178  			wantThresh: []float64{2, 1.8, 1.6, 1.4, 1.2, 0},
   179  		},
   180  		{ // 18
   181  			y:          []float64{1},
   182  			c:          []bool{false},
   183  			wantTPR:    []float64{math.NaN(), math.NaN()},
   184  			wantFPR:    []float64{0, 1},
   185  			wantThresh: []float64{math.Inf(1), 1},
   186  		},
   187  		{ // 19
   188  			y:          []float64{1},
   189  			c:          []bool{false},
   190  			cutoffs:    []float64{-1, 1},
   191  			wantTPR:    []float64{math.NaN(), math.NaN()},
   192  			wantFPR:    []float64{1, 1},
   193  			wantThresh: []float64{1, -1},
   194  		},
   195  		{ // 20
   196  			y:          []float64{1},
   197  			c:          []bool{true},
   198  			wantTPR:    []float64{0, 1},
   199  			wantFPR:    []float64{math.NaN(), math.NaN()},
   200  			wantThresh: []float64{math.Inf(1), 1},
   201  		},
   202  		{ // 21
   203  			y:          []float64{},
   204  			c:          []bool{},
   205  			wantTPR:    nil,
   206  			wantFPR:    nil,
   207  			wantThresh: nil,
   208  		},
   209  		{ // 22
   210  			y:          []float64{},
   211  			c:          []bool{},
   212  			cutoffs:    []float64{-1, 2.5, 5, 7.5, 10},
   213  			wantTPR:    nil,
   214  			wantFPR:    nil,
   215  			wantThresh: nil,
   216  		},
   217  		{ // 23
   218  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   219  			c:          []bool{true, false, true, false},
   220  			cutoffs:    []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1},
   221  			wantTPR:    []float64{0, 0, 0, 0.5, 0.5, 1, 1},
   222  			wantFPR:    []float64{0, 0, 0.5, 0.5, 1, 1, 1},
   223  			wantThresh: []float64{1, 0.9, 0.8, 0.4, 0.35, 0.1, -1},
   224  		},
   225  		{ // 24
   226  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   227  			c:          []bool{true, false, true, false},
   228  			cutoffs:    []float64{math.Inf(-1), 0.1, 0.36, 0.8},
   229  			wantTPR:    []float64{0, 0.5, 1, 1},
   230  			wantFPR:    []float64{0.5, 0.5, 1, 1},
   231  			wantThresh: []float64{0.8, 0.36, 0.1, math.Inf(-1)},
   232  		},
   233  		{ // 25
   234  			y:          []float64{0, 3, 5, 6, 7.5, 8},
   235  			c:          []bool{false, true, false, true, true, true},
   236  			cutoffs:    make([]float64, 0, 10),
   237  			wantTPR:    []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1},
   238  			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1},
   239  			wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0},
   240  		},
   241  		{ // 26
   242  			y:          []float64{0.1, 0.35, 0.4, 0.8},
   243  			c:          []bool{true, false, true, false},
   244  			cutoffs:    []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1, 1.1, 1.2},
   245  			wantTPR:    []float64{0, 0, 0, 0, 0, 0.5, 0.5, 1, 1},
   246  			wantFPR:    []float64{0, 0, 0, 0, 0.5, 0.5, 1, 1, 1},
   247  			wantThresh: []float64{1.2, 1.1, 1, 0.9, 0.8, 0.4, 0.35, 0.1, -1},
   248  		},
   249  	}
   250  	for i, test := range cases {
   251  		gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w)
   252  		if !floats.Same(gotTPR, test.wantTPR) && !floats.EqualApprox(gotTPR, test.wantTPR, tol) {
   253  			t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR)
   254  		}
   255  		if !floats.Same(gotFPR, test.wantFPR) && !floats.EqualApprox(gotFPR, test.wantFPR, tol) {
   256  			t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR)
   257  		}
   258  		if !floats.Same(gotThresh, test.wantThresh) {
   259  			t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh)
   260  		}
   261  	}
   262  }
   263  
   264  func TestTOC(t *testing.T) {
   265  	cases := []struct {
   266  		c       []bool
   267  		w       []float64
   268  		wantMin []float64
   269  		wantMax []float64
   270  		wantTOC []float64
   271  	}{
   272  		{ // 0
   273  			// This is the example given in the paper's supplement.
   274  			// http://www2.clarku.edu/~rpontius/TOCexample2.xlsx
   275  			// It is also shown in the WP article.
   276  			// https://en.wikipedia.org/wiki/Total_operating_characteristic#/media/File:TOC_labeled.png
   277  			c: []bool{
   278  				false, false, false, false, false, false,
   279  				false, false, false, false, false, false,
   280  				false, false, true, true, true, true,
   281  				true, true, true, false, false, true,
   282  				false, true, false, false, true, false,
   283  			},
   284  			wantMin: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
   285  			wantMax: []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10},
   286  			wantTOC: []float64{0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10},
   287  		},
   288  		{ // 1
   289  			c:       []bool{},
   290  			wantMin: nil,
   291  			wantMax: nil,
   292  			wantTOC: nil,
   293  		},
   294  		{ // 2
   295  			c: []bool{
   296  				true, true, true, true, true,
   297  			},
   298  			wantMin: []float64{0, 1, 2, 3, 4, 5},
   299  			wantMax: []float64{0, 1, 2, 3, 4, 5},
   300  			wantTOC: []float64{0, 1, 2, 3, 4, 5},
   301  		},
   302  		{ // 3
   303  			c: []bool{
   304  				false, false, false, false, false,
   305  			},
   306  			wantMin: []float64{0, 0, 0, 0, 0, 0},
   307  			wantMax: []float64{0, 0, 0, 0, 0, 0},
   308  			wantTOC: []float64{0, 0, 0, 0, 0, 0},
   309  		},
   310  		{ // 4
   311  			c:       []bool{false, false, false, true, false, true},
   312  			w:       []float64{2, 2, 3, 6, 1, 4},
   313  			wantMin: []float64{0, 0, 0, 3, 6, 8, 10},
   314  			wantMax: []float64{0, 4, 5, 10, 10, 10, 10},
   315  			wantTOC: []float64{0, 4, 4, 10, 10, 10, 10},
   316  		},
   317  	}
   318  	for i, test := range cases {
   319  		gotMin, gotTOC, gotMax := TOC(test.c, test.w)
   320  		if !floats.Same(gotMin, test.wantMin) {
   321  			t.Errorf("%d: unexpected minimum bound got:%v want:%v", i, gotMin, test.wantMin)
   322  		}
   323  		if !floats.Same(gotMax, test.wantMax) {
   324  			t.Errorf("%d: unexpected maximum bound got:%v want:%v", i, gotMax, test.wantMax)
   325  		}
   326  		if !floats.Same(gotTOC, test.wantTOC) {
   327  			t.Errorf("%d: unexpected TOC got:%v want:%v", i, gotTOC, test.wantTOC)
   328  		}
   329  	}
   330  }
   331  
   332  func BenchmarkROC(b *testing.B) {
   333  	sizes := []int{empty, small, medium, large}
   334  	for _, cutoffsSize := range sizes {
   335  		for _, ySize := range sizes {
   336  			classesSize := ySize
   337  			for _, weightsSize := range slices.Compact([]int{empty, ySize}) {
   338  				benchmarkROC(b, cutoffsSize, ySize, classesSize, weightsSize)
   339  			}
   340  		}
   341  	}
   342  }
   343  
   344  func benchmarkROC(b *testing.B, cutoffsSize int, ySize int, classesSize int, weightsSize int) bool {
   345  	return b.Run(
   346  		fmt.Sprintf(
   347  			"cutoffs=%d,y=%d,classes=%d,weights=%d",
   348  			cutoffsSize, ySize, classesSize, weightsSize),
   349  		func(b *testing.B) {
   350  			src := rand.NewSource(1)
   351  
   352  			cutoffs := randomFloats(cutoffsSize, src)
   353  			slices.Sort(cutoffs)
   354  
   355  			y := randomFloats(ySize, src)
   356  			slices.Sort(y)
   357  
   358  			classes := randomBools(classesSize, src)
   359  
   360  			var weights []float64
   361  			if weightsSize != empty {
   362  				weights = randomFloats(weightsSize, src)
   363  			}
   364  
   365  			b.ResetTimer()
   366  			for i := 0; i < b.N; i++ {
   367  				ROC(cutoffs, y, classes, weights)
   368  			}
   369  		})
   370  }
   371  
   372  func randomFloats(l int, src rand.Source) []float64 {
   373  	rnd := rand.New(src)
   374  	s := make([]float64, l)
   375  	for i := range s {
   376  		s[i] = rnd.Float64()
   377  	}
   378  	return s
   379  }
   380  
   381  func randomBools(l int, src rand.Source) []bool {
   382  	rnd := rand.New(src)
   383  	s := make([]bool, l)
   384  	for i := range s {
   385  		s[i] = rnd.Int31n(2) == 1
   386  	}
   387  	return s
   388  }