gonum.org/v1/gonum@v0.14.0/stat/roc_test.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package stat 6 7 import ( 8 "math" 9 "testing" 10 11 "gonum.org/v1/gonum/floats" 12 ) 13 14 func TestROC(t *testing.T) { 15 const tol = 1e-14 16 17 cases := []struct { 18 y []float64 19 c []bool 20 w []float64 21 cutoffs []float64 22 wantTPR []float64 23 wantFPR []float64 24 wantThresh []float64 25 }{ 26 // Test cases were informed by using sklearn metrics.roc_curve when 27 // cutoffs is nil, but all test cases (including when cutoffs is not 28 // nil) were calculated manually. 29 // Some differences exist between unweighted ROCs from our function 30 // and metrics.roc_curve which appears to use integer cutoffs in that 31 // case. sklearn also appears to do some magic that trims leading zeros 32 // sometimes. 33 { // 0 34 y: []float64{0, 3, 5, 6, 7.5, 8}, 35 c: []bool{false, true, false, true, true, true}, 36 wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1}, 37 wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1}, 38 wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, 39 }, 40 { // 1 41 y: []float64{0, 3, 5, 6, 7.5, 8}, 42 c: []bool{false, true, false, true, true, true}, 43 w: []float64{4, 1, 6, 3, 2, 2}, 44 wantTPR: []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1}, 45 wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 1}, 46 wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, 47 }, 48 { // 2 49 y: []float64{0, 3, 5, 6, 7.5, 8}, 50 c: []bool{false, true, false, true, true, true}, 51 cutoffs: []float64{-1, 2, 4, 6, 8}, 52 wantTPR: []float64{0.25, 0.75, 0.75, 1, 1}, 53 wantFPR: []float64{0, 0, 0.5, 0.5, 1}, 54 wantThresh: []float64{8, 6, 4, 2, -1}, 55 }, 56 { // 3 57 y: []float64{0, 3, 5, 6, 7.5, 8}, 58 c: []bool{false, true, false, true, true, true}, 59 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 60 wantTPR: []float64{0.25, 0.5, 0.75, 0.75, 0.75, 1, 1, 1, 1}, 61 wantFPR: []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1}, 62 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 63 }, 64 { // 4 65 y: []float64{0, 3, 5, 6, 7.5, 8}, 66 c: []bool{false, true, false, true, true, true}, 67 w: []float64{4, 1, 6, 3, 2, 2}, 68 cutoffs: []float64{-1, 2, 4, 6, 8}, 69 wantTPR: []float64{0.25, 0.875, 0.875, 1, 1}, 70 wantFPR: []float64{0, 0, 0.6, 0.6, 1}, 71 wantThresh: []float64{8, 6, 4, 2, -1}, 72 }, 73 { // 5 74 y: []float64{0, 3, 5, 6, 7.5, 8}, 75 c: []bool{false, true, false, true, true, true}, 76 w: []float64{4, 1, 6, 3, 2, 2}, 77 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 78 wantTPR: []float64{0.25, 0.5, 0.875, 0.875, 0.875, 1, 1, 1, 1}, 79 wantFPR: []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1}, 80 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 81 }, 82 { // 6 83 y: []float64{0, 3, 6, 6, 6, 8}, 84 c: []bool{false, true, false, true, true, true}, 85 wantTPR: []float64{0, 0.25, 0.75, 1, 1}, 86 wantFPR: []float64{0, 0, 0.5, 0.5, 1}, 87 wantThresh: []float64{math.Inf(1), 8, 6, 3, 0}, 88 }, 89 { // 7 90 y: []float64{0, 3, 6, 6, 6, 8}, 91 c: []bool{false, true, false, true, true, true}, 92 w: []float64{4, 1, 6, 3, 2, 2}, 93 wantTPR: []float64{0, 0.25, 0.875, 1, 1}, 94 wantFPR: []float64{0, 0, 0.6, 0.6, 1}, 95 wantThresh: []float64{math.Inf(1), 8, 6, 3, 0}, 96 }, 97 { // 8 98 y: []float64{0, 3, 6, 6, 6, 8}, 99 c: []bool{false, true, false, true, true, true}, 100 cutoffs: []float64{-1, 2, 4, 6, 8}, 101 wantTPR: []float64{0.25, 0.75, 0.75, 1, 1}, 102 wantFPR: []float64{0, 0.5, 0.5, 0.5, 1}, 103 wantThresh: []float64{8, 6, 4, 2, -1}, 104 }, 105 { // 9 106 y: []float64{0, 3, 6, 6, 6, 8}, 107 c: []bool{false, true, false, true, true, true}, 108 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 109 wantTPR: []float64{0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1, 1}, 110 wantFPR: []float64{0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1}, 111 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 112 }, 113 { // 10 114 y: []float64{0, 3, 6, 6, 6, 8}, 115 c: []bool{false, true, false, true, true, true}, 116 w: []float64{4, 1, 6, 3, 2, 2}, 117 cutoffs: []float64{-1, 2, 4, 6, 8}, 118 wantTPR: []float64{0.25, 0.875, 0.875, 1, 1}, 119 wantFPR: []float64{0, 0.6, 0.6, 0.6, 1}, 120 wantThresh: []float64{8, 6, 4, 2, -1}, 121 }, 122 { // 11 123 y: []float64{0, 3, 6, 6, 6, 8}, 124 c: []bool{false, true, false, true, true, true}, 125 w: []float64{4, 1, 6, 3, 2, 2}, 126 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 127 wantTPR: []float64{0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1, 1}, 128 wantFPR: []float64{0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 1}, 129 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 130 }, 131 { // 12 132 y: []float64{0.1, 0.35, 0.4, 0.8}, 133 c: []bool{true, false, true, false}, 134 wantTPR: []float64{0, 0, 0.5, 0.5, 1}, 135 wantFPR: []float64{0, 0.5, 0.5, 1, 1}, 136 wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1}, 137 }, 138 { // 13 139 y: []float64{0.1, 0.35, 0.4, 0.8}, 140 c: []bool{false, false, true, true}, 141 wantTPR: []float64{0, 0.5, 1, 1, 1}, 142 wantFPR: []float64{0, 0, 0, 0.5, 1}, 143 wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1}, 144 }, 145 { // 14 146 y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10}, 147 c: []bool{false, true, false, false, true, true, false}, 148 cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, 149 wantTPR: []float64{0, 0, 0, 0, 1}, 150 wantFPR: []float64{0.25, 0.25, 0.25, 0.25, 1}, 151 wantThresh: []float64{10, 7.5, 5, 2.5, -1}, 152 }, 153 { // 15 154 y: []float64{1, 2}, 155 c: []bool{false, false}, 156 wantTPR: []float64{math.NaN(), math.NaN(), math.NaN()}, 157 wantFPR: []float64{0, 0.5, 1}, 158 wantThresh: []float64{math.Inf(1), 2, 1}, 159 }, 160 { // 16 161 y: []float64{1, 2}, 162 c: []bool{false, false}, 163 cutoffs: []float64{-1, 2}, 164 wantTPR: []float64{math.NaN(), math.NaN()}, 165 wantFPR: []float64{0.5, 1}, 166 wantThresh: []float64{2, -1}, 167 }, 168 { // 17 169 y: []float64{1, 2}, 170 c: []bool{false, false}, 171 cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2}, 172 wantTPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()}, 173 wantFPR: []float64{0.5, 0.5, 0.5, 0.5, 0.5, 1}, 174 wantThresh: []float64{2, 1.8, 1.6, 1.4, 1.2, 0}, 175 }, 176 { // 18 177 y: []float64{1}, 178 c: []bool{false}, 179 wantTPR: []float64{math.NaN(), math.NaN()}, 180 wantFPR: []float64{0, 1}, 181 wantThresh: []float64{math.Inf(1), 1}, 182 }, 183 { // 19 184 y: []float64{1}, 185 c: []bool{false}, 186 cutoffs: []float64{-1, 1}, 187 wantTPR: []float64{math.NaN(), math.NaN()}, 188 wantFPR: []float64{1, 1}, 189 wantThresh: []float64{1, -1}, 190 }, 191 { // 20 192 y: []float64{1}, 193 c: []bool{true}, 194 wantTPR: []float64{0, 1}, 195 wantFPR: []float64{math.NaN(), math.NaN()}, 196 wantThresh: []float64{math.Inf(1), 1}, 197 }, 198 { // 21 199 y: []float64{}, 200 c: []bool{}, 201 wantTPR: nil, 202 wantFPR: nil, 203 wantThresh: nil, 204 }, 205 { // 22 206 y: []float64{}, 207 c: []bool{}, 208 cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, 209 wantTPR: nil, 210 wantFPR: nil, 211 wantThresh: nil, 212 }, 213 { // 23 214 y: []float64{0.1, 0.35, 0.4, 0.8}, 215 c: []bool{true, false, true, false}, 216 cutoffs: []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1}, 217 wantTPR: []float64{0, 0, 0, 0.5, 0.5, 1, 1}, 218 wantFPR: []float64{0, 0, 0.5, 0.5, 1, 1, 1}, 219 wantThresh: []float64{1, 0.9, 0.8, 0.4, 0.35, 0.1, -1}, 220 }, 221 { // 24 222 y: []float64{0.1, 0.35, 0.4, 0.8}, 223 c: []bool{true, false, true, false}, 224 cutoffs: []float64{math.Inf(-1), 0.1, 0.36, 0.8}, 225 wantTPR: []float64{0, 0.5, 1, 1}, 226 wantFPR: []float64{0.5, 0.5, 1, 1}, 227 wantThresh: []float64{0.8, 0.36, 0.1, math.Inf(-1)}, 228 }, 229 { // 25 230 y: []float64{0, 3, 5, 6, 7.5, 8}, 231 c: []bool{false, true, false, true, true, true}, 232 cutoffs: make([]float64, 0, 10), 233 wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1}, 234 wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1}, 235 wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, 236 }, 237 { // 26 238 y: []float64{0.1, 0.35, 0.4, 0.8}, 239 c: []bool{true, false, true, false}, 240 cutoffs: []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1, 1.1, 1.2}, 241 wantTPR: []float64{0, 0, 0, 0, 0, 0.5, 0.5, 1, 1}, 242 wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1, 1, 1}, 243 wantThresh: []float64{1.2, 1.1, 1, 0.9, 0.8, 0.4, 0.35, 0.1, -1}, 244 }, 245 } 246 for i, test := range cases { 247 gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w) 248 if !floats.Same(gotTPR, test.wantTPR) && !floats.EqualApprox(gotTPR, test.wantTPR, tol) { 249 t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR) 250 } 251 if !floats.Same(gotFPR, test.wantFPR) && !floats.EqualApprox(gotFPR, test.wantFPR, tol) { 252 t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR) 253 } 254 if !floats.Same(gotThresh, test.wantThresh) { 255 t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh) 256 } 257 } 258 } 259 260 func TestTOC(t *testing.T) { 261 cases := []struct { 262 c []bool 263 w []float64 264 wantMin []float64 265 wantMax []float64 266 wantTOC []float64 267 }{ 268 { // 0 269 // This is the example given in the paper's supplement. 270 // http://www2.clarku.edu/~rpontius/TOCexample2.xlsx 271 // It is also shown in the WP article. 272 // https://en.wikipedia.org/wiki/Total_operating_characteristic#/media/File:TOC_labeled.png 273 c: []bool{ 274 false, false, false, false, false, false, 275 false, false, false, false, false, false, 276 false, false, true, true, true, true, 277 true, true, true, false, false, true, 278 false, true, false, false, true, false, 279 }, 280 wantMin: []float64{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 281 wantMax: []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10}, 282 wantTOC: []float64{0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10}, 283 }, 284 { // 1 285 c: []bool{}, 286 wantMin: nil, 287 wantMax: nil, 288 wantTOC: nil, 289 }, 290 { // 2 291 c: []bool{ 292 true, true, true, true, true, 293 }, 294 wantMin: []float64{0, 1, 2, 3, 4, 5}, 295 wantMax: []float64{0, 1, 2, 3, 4, 5}, 296 wantTOC: []float64{0, 1, 2, 3, 4, 5}, 297 }, 298 { // 3 299 c: []bool{ 300 false, false, false, false, false, 301 }, 302 wantMin: []float64{0, 0, 0, 0, 0, 0}, 303 wantMax: []float64{0, 0, 0, 0, 0, 0}, 304 wantTOC: []float64{0, 0, 0, 0, 0, 0}, 305 }, 306 { // 4 307 c: []bool{false, false, false, true, false, true}, 308 w: []float64{2, 2, 3, 6, 1, 4}, 309 wantMin: []float64{0, 0, 0, 3, 6, 8, 10}, 310 wantMax: []float64{0, 4, 5, 10, 10, 10, 10}, 311 wantTOC: []float64{0, 4, 4, 10, 10, 10, 10}, 312 }, 313 } 314 for i, test := range cases { 315 gotMin, gotTOC, gotMax := TOC(test.c, test.w) 316 if !floats.Same(gotMin, test.wantMin) { 317 t.Errorf("%d: unexpected minimum bound got:%v want:%v", i, gotMin, test.wantMin) 318 } 319 if !floats.Same(gotMax, test.wantMax) { 320 t.Errorf("%d: unexpected maximum bound got:%v want:%v", i, gotMax, test.wantMax) 321 } 322 if !floats.Same(gotTOC, test.wantTOC) { 323 t.Errorf("%d: unexpected TOC got:%v want:%v", i, gotTOC, test.wantTOC) 324 } 325 } 326 }