gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/stat/roc_test.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package stat 6 7 import ( 8 "fmt" 9 "math" 10 "slices" 11 "testing" 12 13 "golang.org/x/exp/rand" 14 15 "gonum.org/v1/gonum/floats" 16 ) 17 18 func TestROC(t *testing.T) { 19 const tol = 1e-14 20 21 cases := []struct { 22 y []float64 23 c []bool 24 w []float64 25 cutoffs []float64 26 wantTPR []float64 27 wantFPR []float64 28 wantThresh []float64 29 }{ 30 // Test cases were informed by using sklearn metrics.roc_curve when 31 // cutoffs is nil, but all test cases (including when cutoffs is not 32 // nil) were calculated manually. 33 // Some differences exist between unweighted ROCs from our function 34 // and metrics.roc_curve which appears to use integer cutoffs in that 35 // case. sklearn also appears to do some magic that trims leading zeros 36 // sometimes. 37 { // 0 38 y: []float64{0, 3, 5, 6, 7.5, 8}, 39 c: []bool{false, true, false, true, true, true}, 40 wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1}, 41 wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1}, 42 wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, 43 }, 44 { // 1 45 y: []float64{0, 3, 5, 6, 7.5, 8}, 46 c: []bool{false, true, false, true, true, true}, 47 w: []float64{4, 1, 6, 3, 2, 2}, 48 wantTPR: []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1}, 49 wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 1}, 50 wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, 51 }, 52 { // 2 53 y: []float64{0, 3, 5, 6, 7.5, 8}, 54 c: []bool{false, true, false, true, true, true}, 55 cutoffs: []float64{-1, 2, 4, 6, 8}, 56 wantTPR: []float64{0.25, 0.75, 0.75, 1, 1}, 57 wantFPR: []float64{0, 0, 0.5, 0.5, 1}, 58 wantThresh: []float64{8, 6, 4, 2, -1}, 59 }, 60 { // 3 61 y: []float64{0, 3, 5, 6, 7.5, 8}, 62 c: []bool{false, true, false, true, true, true}, 63 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 64 wantTPR: []float64{0.25, 0.5, 0.75, 0.75, 0.75, 1, 1, 1, 1}, 65 wantFPR: []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1}, 66 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 67 }, 68 { // 4 69 y: []float64{0, 3, 5, 6, 7.5, 8}, 70 c: []bool{false, true, false, true, true, true}, 71 w: []float64{4, 1, 6, 3, 2, 2}, 72 cutoffs: []float64{-1, 2, 4, 6, 8}, 73 wantTPR: []float64{0.25, 0.875, 0.875, 1, 1}, 74 wantFPR: []float64{0, 0, 0.6, 0.6, 1}, 75 wantThresh: []float64{8, 6, 4, 2, -1}, 76 }, 77 { // 5 78 y: []float64{0, 3, 5, 6, 7.5, 8}, 79 c: []bool{false, true, false, true, true, true}, 80 w: []float64{4, 1, 6, 3, 2, 2}, 81 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 82 wantTPR: []float64{0.25, 0.5, 0.875, 0.875, 0.875, 1, 1, 1, 1}, 83 wantFPR: []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1}, 84 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 85 }, 86 { // 6 87 y: []float64{0, 3, 6, 6, 6, 8}, 88 c: []bool{false, true, false, true, true, true}, 89 wantTPR: []float64{0, 0.25, 0.75, 1, 1}, 90 wantFPR: []float64{0, 0, 0.5, 0.5, 1}, 91 wantThresh: []float64{math.Inf(1), 8, 6, 3, 0}, 92 }, 93 { // 7 94 y: []float64{0, 3, 6, 6, 6, 8}, 95 c: []bool{false, true, false, true, true, true}, 96 w: []float64{4, 1, 6, 3, 2, 2}, 97 wantTPR: []float64{0, 0.25, 0.875, 1, 1}, 98 wantFPR: []float64{0, 0, 0.6, 0.6, 1}, 99 wantThresh: []float64{math.Inf(1), 8, 6, 3, 0}, 100 }, 101 { // 8 102 y: []float64{0, 3, 6, 6, 6, 8}, 103 c: []bool{false, true, false, true, true, true}, 104 cutoffs: []float64{-1, 2, 4, 6, 8}, 105 wantTPR: []float64{0.25, 0.75, 0.75, 1, 1}, 106 wantFPR: []float64{0, 0.5, 0.5, 0.5, 1}, 107 wantThresh: []float64{8, 6, 4, 2, -1}, 108 }, 109 { // 9 110 y: []float64{0, 3, 6, 6, 6, 8}, 111 c: []bool{false, true, false, true, true, true}, 112 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 113 wantTPR: []float64{0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1, 1}, 114 wantFPR: []float64{0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1}, 115 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 116 }, 117 { // 10 118 y: []float64{0, 3, 6, 6, 6, 8}, 119 c: []bool{false, true, false, true, true, true}, 120 w: []float64{4, 1, 6, 3, 2, 2}, 121 cutoffs: []float64{-1, 2, 4, 6, 8}, 122 wantTPR: []float64{0.25, 0.875, 0.875, 1, 1}, 123 wantFPR: []float64{0, 0.6, 0.6, 0.6, 1}, 124 wantThresh: []float64{8, 6, 4, 2, -1}, 125 }, 126 { // 11 127 y: []float64{0, 3, 6, 6, 6, 8}, 128 c: []bool{false, true, false, true, true, true}, 129 w: []float64{4, 1, 6, 3, 2, 2}, 130 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 131 wantTPR: []float64{0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1, 1}, 132 wantFPR: []float64{0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 1}, 133 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 134 }, 135 { // 12 136 y: []float64{0.1, 0.35, 0.4, 0.8}, 137 c: []bool{true, false, true, false}, 138 wantTPR: []float64{0, 0, 0.5, 0.5, 1}, 139 wantFPR: []float64{0, 0.5, 0.5, 1, 1}, 140 wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1}, 141 }, 142 { // 13 143 y: []float64{0.1, 0.35, 0.4, 0.8}, 144 c: []bool{false, false, true, true}, 145 wantTPR: []float64{0, 0.5, 1, 1, 1}, 146 wantFPR: []float64{0, 0, 0, 0.5, 1}, 147 wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1}, 148 }, 149 { // 14 150 y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10}, 151 c: []bool{false, true, false, false, true, true, false}, 152 cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, 153 wantTPR: []float64{0, 0, 0, 0, 1}, 154 wantFPR: []float64{0.25, 0.25, 0.25, 0.25, 1}, 155 wantThresh: []float64{10, 7.5, 5, 2.5, -1}, 156 }, 157 { // 15 158 y: []float64{1, 2}, 159 c: []bool{false, false}, 160 wantTPR: []float64{math.NaN(), math.NaN(), math.NaN()}, 161 wantFPR: []float64{0, 0.5, 1}, 162 wantThresh: []float64{math.Inf(1), 2, 1}, 163 }, 164 { // 16 165 y: []float64{1, 2}, 166 c: []bool{false, false}, 167 cutoffs: []float64{-1, 2}, 168 wantTPR: []float64{math.NaN(), math.NaN()}, 169 wantFPR: []float64{0.5, 1}, 170 wantThresh: []float64{2, -1}, 171 }, 172 { // 17 173 y: []float64{1, 2}, 174 c: []bool{false, false}, 175 cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2}, 176 wantTPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()}, 177 wantFPR: []float64{0.5, 0.5, 0.5, 0.5, 0.5, 1}, 178 wantThresh: []float64{2, 1.8, 1.6, 1.4, 1.2, 0}, 179 }, 180 { // 18 181 y: []float64{1}, 182 c: []bool{false}, 183 wantTPR: []float64{math.NaN(), math.NaN()}, 184 wantFPR: []float64{0, 1}, 185 wantThresh: []float64{math.Inf(1), 1}, 186 }, 187 { // 19 188 y: []float64{1}, 189 c: []bool{false}, 190 cutoffs: []float64{-1, 1}, 191 wantTPR: []float64{math.NaN(), math.NaN()}, 192 wantFPR: []float64{1, 1}, 193 wantThresh: []float64{1, -1}, 194 }, 195 { // 20 196 y: []float64{1}, 197 c: []bool{true}, 198 wantTPR: []float64{0, 1}, 199 wantFPR: []float64{math.NaN(), math.NaN()}, 200 wantThresh: []float64{math.Inf(1), 1}, 201 }, 202 { // 21 203 y: []float64{}, 204 c: []bool{}, 205 wantTPR: nil, 206 wantFPR: nil, 207 wantThresh: nil, 208 }, 209 { // 22 210 y: []float64{}, 211 c: []bool{}, 212 cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, 213 wantTPR: nil, 214 wantFPR: nil, 215 wantThresh: nil, 216 }, 217 { // 23 218 y: []float64{0.1, 0.35, 0.4, 0.8}, 219 c: []bool{true, false, true, false}, 220 cutoffs: []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1}, 221 wantTPR: []float64{0, 0, 0, 0.5, 0.5, 1, 1}, 222 wantFPR: []float64{0, 0, 0.5, 0.5, 1, 1, 1}, 223 wantThresh: []float64{1, 0.9, 0.8, 0.4, 0.35, 0.1, -1}, 224 }, 225 { // 24 226 y: []float64{0.1, 0.35, 0.4, 0.8}, 227 c: []bool{true, false, true, false}, 228 cutoffs: []float64{math.Inf(-1), 0.1, 0.36, 0.8}, 229 wantTPR: []float64{0, 0.5, 1, 1}, 230 wantFPR: []float64{0.5, 0.5, 1, 1}, 231 wantThresh: []float64{0.8, 0.36, 0.1, math.Inf(-1)}, 232 }, 233 { // 25 234 y: []float64{0, 3, 5, 6, 7.5, 8}, 235 c: []bool{false, true, false, true, true, true}, 236 cutoffs: make([]float64, 0, 10), 237 wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1}, 238 wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1}, 239 wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, 240 }, 241 { // 26 242 y: []float64{0.1, 0.35, 0.4, 0.8}, 243 c: []bool{true, false, true, false}, 244 cutoffs: []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1, 1.1, 1.2}, 245 wantTPR: []float64{0, 0, 0, 0, 0, 0.5, 0.5, 1, 1}, 246 wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1, 1, 1}, 247 wantThresh: []float64{1.2, 1.1, 1, 0.9, 0.8, 0.4, 0.35, 0.1, -1}, 248 }, 249 } 250 for i, test := range cases { 251 gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w) 252 if !floats.Same(gotTPR, test.wantTPR) && !floats.EqualApprox(gotTPR, test.wantTPR, tol) { 253 t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR) 254 } 255 if !floats.Same(gotFPR, test.wantFPR) && !floats.EqualApprox(gotFPR, test.wantFPR, tol) { 256 t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR) 257 } 258 if !floats.Same(gotThresh, test.wantThresh) { 259 t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh) 260 } 261 } 262 } 263 264 func TestTOC(t *testing.T) { 265 cases := []struct { 266 c []bool 267 w []float64 268 wantMin []float64 269 wantMax []float64 270 wantTOC []float64 271 }{ 272 { // 0 273 // This is the example given in the paper's supplement. 274 // http://www2.clarku.edu/~rpontius/TOCexample2.xlsx 275 // It is also shown in the WP article. 276 // https://en.wikipedia.org/wiki/Total_operating_characteristic#/media/File:TOC_labeled.png 277 c: []bool{ 278 false, false, false, false, false, false, 279 false, false, false, false, false, false, 280 false, false, true, true, true, true, 281 true, true, true, false, false, true, 282 false, true, false, false, true, false, 283 }, 284 wantMin: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 285 wantMax: []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10}, 286 wantTOC: []float64{0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10}, 287 }, 288 { // 1 289 c: []bool{}, 290 wantMin: nil, 291 wantMax: nil, 292 wantTOC: nil, 293 }, 294 { // 2 295 c: []bool{ 296 true, true, true, true, true, 297 }, 298 wantMin: []float64{0, 1, 2, 3, 4, 5}, 299 wantMax: []float64{0, 1, 2, 3, 4, 5}, 300 wantTOC: []float64{0, 1, 2, 3, 4, 5}, 301 }, 302 { // 3 303 c: []bool{ 304 false, false, false, false, false, 305 }, 306 wantMin: []float64{0, 0, 0, 0, 0, 0}, 307 wantMax: []float64{0, 0, 0, 0, 0, 0}, 308 wantTOC: []float64{0, 0, 0, 0, 0, 0}, 309 }, 310 { // 4 311 c: []bool{false, false, false, true, false, true}, 312 w: []float64{2, 2, 3, 6, 1, 4}, 313 wantMin: []float64{0, 0, 0, 3, 6, 8, 10}, 314 wantMax: []float64{0, 4, 5, 10, 10, 10, 10}, 315 wantTOC: []float64{0, 4, 4, 10, 10, 10, 10}, 316 }, 317 } 318 for i, test := range cases { 319 gotMin, gotTOC, gotMax := TOC(test.c, test.w) 320 if !floats.Same(gotMin, test.wantMin) { 321 t.Errorf("%d: unexpected minimum bound got:%v want:%v", i, gotMin, test.wantMin) 322 } 323 if !floats.Same(gotMax, test.wantMax) { 324 t.Errorf("%d: unexpected maximum bound got:%v want:%v", i, gotMax, test.wantMax) 325 } 326 if !floats.Same(gotTOC, test.wantTOC) { 327 t.Errorf("%d: unexpected TOC got:%v want:%v", i, gotTOC, test.wantTOC) 328 } 329 } 330 } 331 332 func BenchmarkROC(b *testing.B) { 333 sizes := []int{empty, small, medium, large} 334 for _, cutoffsSize := range sizes { 335 for _, ySize := range sizes { 336 classesSize := ySize 337 for _, weightsSize := range slices.Compact([]int{empty, ySize}) { 338 benchmarkROC(b, cutoffsSize, ySize, classesSize, weightsSize) 339 } 340 } 341 } 342 } 343 344 func benchmarkROC(b *testing.B, cutoffsSize int, ySize int, classesSize int, weightsSize int) bool { 345 return b.Run( 346 fmt.Sprintf( 347 "cutoffs=%d,y=%d,classes=%d,weights=%d", 348 cutoffsSize, ySize, classesSize, weightsSize), 349 func(b *testing.B) { 350 src := rand.NewSource(1) 351 352 cutoffs := randomFloats(cutoffsSize, src) 353 slices.Sort(cutoffs) 354 355 y := randomFloats(ySize, src) 356 slices.Sort(y) 357 358 classes := randomBools(classesSize, src) 359 360 var weights []float64 361 if weightsSize != empty { 362 weights = randomFloats(weightsSize, src) 363 } 364 365 b.ResetTimer() 366 for i := 0; i < b.N; i++ { 367 ROC(cutoffs, y, classes, weights) 368 } 369 }) 370 } 371 372 func randomFloats(l int, src rand.Source) []float64 { 373 rnd := rand.New(src) 374 s := make([]float64, l) 375 for i := range s { 376 s[i] = rnd.Float64() 377 } 378 return s 379 } 380 381 func randomBools(l int, src rand.Source) []bool { 382 rnd := rand.New(src) 383 s := make([]bool, l) 384 for i := range s { 385 s[i] = rnd.Int31n(2) == 1 386 } 387 return s 388 }