github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/stat/roc_test.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package stat 6 7 import ( 8 "math" 9 "testing" 10 11 "github.com/jingcheng-WU/gonum/floats" 12 ) 13 14 func TestROC(t *testing.T) { 15 const tol = 1e-14 16 17 cases := []struct { 18 y []float64 19 c []bool 20 w []float64 21 cutoffs []float64 22 wantTPR []float64 23 wantFPR []float64 24 wantThresh []float64 25 }{ 26 // Test cases were informed by using sklearn metrics.roc_curve when 27 // cutoffs is nil, but all test cases (including when cutoffs is not 28 // nil) were calculated manually. 29 // Some differences exist between unweighted ROCs from our function 30 // and metrics.roc_curve which appears to use integer cutoffs in that 31 // case. sklearn also appears to do some magic that trims leading zeros 32 // sometimes. 33 { // 0 34 y: []float64{0, 3, 5, 6, 7.5, 8}, 35 c: []bool{false, true, false, true, true, true}, 36 wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1}, 37 wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1}, 38 wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, 39 }, 40 { // 1 41 y: []float64{0, 3, 5, 6, 7.5, 8}, 42 c: []bool{false, true, false, true, true, true}, 43 w: []float64{4, 1, 6, 3, 2, 2}, 44 wantTPR: []float64{0, 0.25, 0.5, 0.875, 0.875, 1, 1}, 45 wantFPR: []float64{0, 0, 0, 0, 0.6, 0.6, 1}, 46 wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, 47 }, 48 { // 2 49 y: []float64{0, 3, 5, 6, 7.5, 8}, 50 c: []bool{false, true, false, true, true, true}, 51 cutoffs: []float64{-1, 2, 4, 6, 8}, 52 wantTPR: []float64{0.25, 0.75, 0.75, 1, 1}, 53 wantFPR: []float64{0, 0, 0.5, 0.5, 1}, 54 wantThresh: []float64{8, 6, 4, 2, -1}, 55 }, 56 { // 3 57 y: []float64{0, 3, 5, 6, 7.5, 8}, 58 c: []bool{false, true, false, true, true, true}, 59 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 60 wantTPR: []float64{0.25, 0.5, 0.75, 0.75, 0.75, 1, 1, 1, 1}, 61 wantFPR: []float64{0, 0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 1}, 62 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 63 }, 64 { // 4 65 y: []float64{0, 3, 5, 6, 7.5, 8}, 66 c: []bool{false, true, false, true, true, true}, 67 w: []float64{4, 1, 6, 3, 2, 2}, 68 cutoffs: []float64{-1, 2, 4, 6, 8}, 69 wantTPR: []float64{0.25, 0.875, 0.875, 1, 1}, 70 wantFPR: []float64{0, 0, 0.6, 0.6, 1}, 71 wantThresh: []float64{8, 6, 4, 2, -1}, 72 }, 73 { // 5 74 y: []float64{0, 3, 5, 6, 7.5, 8}, 75 c: []bool{false, true, false, true, true, true}, 76 w: []float64{4, 1, 6, 3, 2, 2}, 77 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 78 wantTPR: []float64{0.25, 0.5, 0.875, 0.875, 0.875, 1, 1, 1, 1}, 79 wantFPR: []float64{0, 0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 1}, 80 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 81 }, 82 { // 6 83 y: []float64{0, 3, 6, 6, 6, 8}, 84 c: []bool{false, true, false, true, true, true}, 85 wantTPR: []float64{0, 0.25, 0.75, 1, 1}, 86 wantFPR: []float64{0, 0, 0.5, 0.5, 1}, 87 wantThresh: []float64{math.Inf(1), 8, 6, 3, 0}, 88 }, 89 { // 7 90 y: []float64{0, 3, 6, 6, 6, 8}, 91 c: []bool{false, true, false, true, true, true}, 92 w: []float64{4, 1, 6, 3, 2, 2}, 93 wantTPR: []float64{0, 0.25, 0.875, 1, 1}, 94 wantFPR: []float64{0, 0, 0.6, 0.6, 1}, 95 wantThresh: []float64{math.Inf(1), 8, 6, 3, 0}, 96 }, 97 { // 8 98 y: []float64{0, 3, 6, 6, 6, 8}, 99 c: []bool{false, true, false, true, true, true}, 100 cutoffs: []float64{-1, 2, 4, 6, 8}, 101 wantTPR: []float64{0.25, 0.75, 0.75, 1, 1}, 102 wantFPR: []float64{0, 0.5, 0.5, 0.5, 1}, 103 wantThresh: []float64{8, 6, 4, 2, -1}, 104 }, 105 { // 9 106 y: []float64{0, 3, 6, 6, 6, 8}, 107 c: []bool{false, true, false, true, true, true}, 108 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 109 wantTPR: []float64{0.25, 0.25, 0.75, 0.75, 0.75, 1, 1, 1, 1}, 110 wantFPR: []float64{0, 0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1}, 111 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 112 }, 113 { // 10 114 y: []float64{0, 3, 6, 6, 6, 8}, 115 c: []bool{false, true, false, true, true, true}, 116 w: []float64{4, 1, 6, 3, 2, 2}, 117 cutoffs: []float64{-1, 2, 4, 6, 8}, 118 wantTPR: []float64{0.25, 0.875, 0.875, 1, 1}, 119 wantFPR: []float64{0, 0.6, 0.6, 0.6, 1}, 120 wantThresh: []float64{8, 6, 4, 2, -1}, 121 }, 122 { // 11 123 y: []float64{0, 3, 6, 6, 6, 8}, 124 c: []bool{false, true, false, true, true, true}, 125 w: []float64{4, 1, 6, 3, 2, 2}, 126 cutoffs: []float64{-1, 1, 2, 3, 4, 5, 6, 7, 8}, 127 wantTPR: []float64{0.25, 0.25, 0.875, 0.875, 0.875, 1, 1, 1, 1}, 128 wantFPR: []float64{0, 0, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 1}, 129 wantThresh: []float64{8, 7, 6, 5, 4, 3, 2, 1, -1}, 130 }, 131 { // 12 132 y: []float64{0.1, 0.35, 0.4, 0.8}, 133 c: []bool{true, false, true, false}, 134 wantTPR: []float64{0, 0, 0.5, 0.5, 1}, 135 wantFPR: []float64{0, 0.5, 0.5, 1, 1}, 136 wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1}, 137 }, 138 { // 13 139 y: []float64{0.1, 0.35, 0.4, 0.8}, 140 c: []bool{false, false, true, true}, 141 wantTPR: []float64{0, 0.5, 1, 1, 1}, 142 wantFPR: []float64{0, 0, 0, 0.5, 1}, 143 wantThresh: []float64{math.Inf(1), 0.8, 0.4, 0.35, 0.1}, 144 }, 145 { // 14 146 y: []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 10}, 147 c: []bool{false, true, false, false, true, true, false}, 148 cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, 149 wantTPR: []float64{0, 0, 0, 0, 1}, 150 wantFPR: []float64{0.25, 0.25, 0.25, 0.25, 1}, 151 wantThresh: []float64{10, 7.5, 5, 2.5, -1}, 152 }, 153 { // 15 154 y: []float64{1, 2}, 155 c: []bool{false, false}, 156 wantTPR: []float64{math.NaN(), math.NaN(), math.NaN()}, 157 wantFPR: []float64{0, 0.5, 1}, 158 wantThresh: []float64{math.Inf(1), 2, 1}, 159 }, 160 { // 16 161 y: []float64{1, 2}, 162 c: []bool{false, false}, 163 cutoffs: []float64{-1, 2}, 164 wantTPR: []float64{math.NaN(), math.NaN()}, 165 wantFPR: []float64{0.5, 1}, 166 wantThresh: []float64{2, -1}, 167 }, 168 { // 17 169 y: []float64{1, 2}, 170 c: []bool{false, false}, 171 cutoffs: []float64{0, 1.2, 1.4, 1.6, 1.8, 2}, 172 wantTPR: []float64{math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN(), math.NaN()}, 173 wantFPR: []float64{0.5, 0.5, 0.5, 0.5, 0.5, 1}, 174 wantThresh: []float64{2, 1.8, 1.6, 1.4, 1.2, 0}, 175 }, 176 { // 18 177 y: []float64{1}, 178 c: []bool{false}, 179 wantTPR: []float64{math.NaN(), math.NaN()}, 180 wantFPR: []float64{0, 1}, 181 wantThresh: []float64{math.Inf(1), 1}, 182 }, 183 { // 19 184 y: []float64{1}, 185 c: []bool{false}, 186 cutoffs: []float64{-1, 1}, 187 wantTPR: []float64{math.NaN(), math.NaN()}, 188 wantFPR: []float64{1, 1}, 189 wantThresh: []float64{1, -1}, 190 }, 191 { // 20 192 y: []float64{1}, 193 c: []bool{true}, 194 wantTPR: []float64{0, 1}, 195 wantFPR: []float64{math.NaN(), math.NaN()}, 196 wantThresh: []float64{math.Inf(1), 1}, 197 }, 198 { // 21 199 y: []float64{}, 200 c: []bool{}, 201 wantTPR: nil, 202 wantFPR: nil, 203 wantThresh: nil, 204 }, 205 { // 22 206 y: []float64{}, 207 c: []bool{}, 208 cutoffs: []float64{-1, 2.5, 5, 7.5, 10}, 209 wantTPR: nil, 210 wantFPR: nil, 211 wantThresh: nil, 212 }, 213 { // 23 214 y: []float64{0.1, 0.35, 0.4, 0.8}, 215 c: []bool{true, false, true, false}, 216 cutoffs: []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1}, 217 wantTPR: []float64{0, 0, 0, 0.5, 0.5, 1, 1}, 218 wantFPR: []float64{0, 0, 0.5, 0.5, 1, 1, 1}, 219 wantThresh: []float64{1, 0.9, 0.8, 0.4, 0.35, 0.1, -1}, 220 }, 221 { // 24 222 y: []float64{0.1, 0.35, 0.4, 0.8}, 223 c: []bool{true, false, true, false}, 224 cutoffs: []float64{math.Inf(-1), 0.1, 0.36, 0.8}, 225 wantTPR: []float64{0, 0.5, 1, 1}, 226 wantFPR: []float64{0.5, 0.5, 1, 1}, 227 wantThresh: []float64{0.8, 0.36, 0.1, math.Inf(-1)}, 228 }, 229 { // 25 230 y: []float64{0, 3, 5, 6, 7.5, 8}, 231 c: []bool{false, true, false, true, true, true}, 232 cutoffs: make([]float64, 0, 10), 233 wantTPR: []float64{0, 0.25, 0.5, 0.75, 0.75, 1, 1}, 234 wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1}, 235 wantThresh: []float64{math.Inf(1), 8, 7.5, 6, 5, 3, 0}, 236 }, 237 { // 26 238 y: []float64{0.1, 0.35, 0.4, 0.8}, 239 c: []bool{true, false, true, false}, 240 cutoffs: []float64{-1, 0.1, 0.35, 0.4, 0.8, 0.9, 1, 1.1, 1.2}, 241 wantTPR: []float64{0, 0, 0, 0, 0, 0.5, 0.5, 1, 1}, 242 wantFPR: []float64{0, 0, 0, 0, 0.5, 0.5, 1, 1, 1}, 243 wantThresh: []float64{1.2, 1.1, 1, 0.9, 0.8, 0.4, 0.35, 0.1, -1}, 244 }, 245 } 246 for i, test := range cases { 247 gotTPR, gotFPR, gotThresh := ROC(test.cutoffs, test.y, test.c, test.w) 248 if !floats.Same(gotTPR, test.wantTPR) && !floats.EqualApprox(gotTPR, test.wantTPR, tol) { 249 t.Errorf("%d: unexpected TPR got:%v want:%v", i, gotTPR, test.wantTPR) 250 } 251 if !floats.Same(gotFPR, test.wantFPR) && !floats.EqualApprox(gotFPR, test.wantFPR, tol) { 252 t.Errorf("%d: unexpected FPR got:%v want:%v", i, gotFPR, test.wantFPR) 253 } 254 if !floats.Same(gotThresh, test.wantThresh) { 255 t.Errorf("%d: unexpected thresholds got:%#v want:%v", i, gotThresh, test.wantThresh) 256 } 257 } 258 }