github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/spatial/vptree/vptree_test.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package vptree 6 7 import ( 8 "flag" 9 "fmt" 10 "math" 11 "os" 12 "reflect" 13 "sort" 14 "strings" 15 "testing" 16 "unsafe" 17 18 "golang.org/x/exp/rand" 19 ) 20 21 var ( 22 genDot = flag.Bool("dot", false, "generate dot code for failing trees") 23 dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format") 24 ) 25 26 var ( 27 // Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572. 28 wpData = []Comparable{ 29 Point{2, 3}, 30 Point{5, 4}, 31 Point{9, 6}, 32 Point{4, 7}, 33 Point{8, 1}, 34 Point{7, 2}, 35 } 36 ) 37 38 var newTests = []struct { 39 data []Comparable 40 effort int 41 }{ 42 {data: wpData, effort: 0}, 43 {data: wpData, effort: 1}, 44 {data: wpData, effort: 2}, 45 {data: wpData, effort: 4}, 46 {data: wpData, effort: 8}, 47 } 48 49 func TestNew(t *testing.T) { 50 for i, test := range newTests { 51 var tree *Tree 52 var err error 53 var panicked bool 54 func() { 55 defer func() { 56 if r := recover(); r != nil { 57 panicked = true 58 } 59 }() 60 tree, err = New(test.data, test.effort, rand.NewSource(1)) 61 }() 62 if panicked { 63 t.Errorf("unexpected panic for test %d", i) 64 continue 65 } 66 if err != nil { 67 t.Errorf("unexpected error for test %d: %v", i, err) 68 continue 69 } 70 71 if !tree.Root.isVPTree() { 72 t.Errorf("tree %d is not vp-tree", i) 73 } 74 75 if t.Failed() && *genDot && tree.Len() <= *dotLimit { 76 err := dotFile(tree, fmt.Sprintf("TestNew%d", i), "") 77 if err != nil { 78 t.Fatalf("failed to write DOT file: %v", err) 79 } 80 } 81 } 82 } 83 84 type compFn func(v, radius float64) bool 85 86 func closer(v, radius float64) bool { return v <= radius } 87 func further(v, radius float64) bool { return v >= radius } 88 89 func (n *Node) isVPTree() bool { 90 if n == nil { 91 return true 92 } 93 if !n.Closer.isPartitioned(n.Point, closer, n.Radius) { 94 return false 95 } 96 if !n.Further.isPartitioned(n.Point, further, n.Radius) { 97 return false 98 } 99 return n.Closer.isVPTree() && n.Further.isVPTree() 100 } 101 102 func (n *Node) isPartitioned(vp Comparable, fn compFn, radius float64) bool { 103 if n == nil { 104 return true 105 } 106 if n.Closer != nil && !fn(vp.Distance(n.Closer.Point), radius) { 107 return false 108 } 109 if n.Further != nil && !fn(vp.Distance(n.Further.Point), radius) { 110 return false 111 } 112 return n.Closer.isPartitioned(vp, fn, radius) && n.Further.isPartitioned(vp, fn, radius) 113 } 114 115 func nearest(q Comparable, p []Comparable) (Comparable, float64) { 116 min := q.Distance(p[0]) 117 var r int 118 for i := 1; i < len(p); i++ { 119 d := q.Distance(p[i]) 120 if d < min { 121 min = d 122 r = i 123 } 124 } 125 return p[r], min 126 } 127 128 func TestNearestRandom(t *testing.T) { 129 rnd := rand.New(rand.NewSource(1)) 130 131 const ( 132 min = 0.0 133 max = 1000.0 134 135 dims = 4 136 setSize = 10000 137 ) 138 139 var randData []Comparable 140 for i := 0; i < setSize; i++ { 141 p := make(Point, dims) 142 for j := 0; j < dims; j++ { 143 p[j] = (max-min)*rnd.Float64() + min 144 } 145 randData = append(randData, p) 146 } 147 tree, err := New(randData, 10, rand.NewSource(1)) 148 if err != nil { 149 t.Fatalf("unexpected error: %v", err) 150 } 151 152 for i := 0; i < setSize; i++ { 153 q := make(Point, dims) 154 for j := 0; j < dims; j++ { 155 q[j] = (max-min)*rnd.Float64() + min 156 } 157 158 got, _ := tree.Nearest(q) 159 want, _ := nearest(q, randData) 160 if !reflect.DeepEqual(got, want) { 161 t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want) 162 } 163 } 164 } 165 166 func TestNearest(t *testing.T) { 167 tree, err := New(wpData, 3, rand.NewSource(1)) 168 if err != nil { 169 t.Fatalf("unexpected error: %v", err) 170 } 171 for _, q := range append([]Comparable{ 172 Point{4, 6}, 173 // Point{7, 5}, // Omitted because it is ambiguously finds [9 6] or [5 4]. 174 Point{8, 7}, 175 Point{6, -5}, 176 Point{1e5, 1e5}, 177 Point{1e5, -1e5}, 178 Point{-1e5, 1e5}, 179 Point{-1e5, -1e5}, 180 Point{1e5, 0}, 181 Point{0, -1e5}, 182 Point{0, 1e5}, 183 Point{-1e5, 0}, 184 }, wpData...) { 185 gotP, gotD := tree.Nearest(q) 186 wantP, wantD := nearest(q, wpData) 187 if !reflect.DeepEqual(gotP, wantP) { 188 t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) 189 } 190 if gotD != wantD { 191 t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD) 192 } 193 } 194 } 195 196 func nearestN(n int, q Comparable, p []Comparable) []ComparableDist { 197 nk := NewNKeeper(n) 198 for i := 0; i < len(p); i++ { 199 nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])}) 200 } 201 if len(nk.Heap) == 1 { 202 return nk.Heap 203 } 204 sort.Sort(nk) 205 for i, j := 0, len(nk.Heap)-1; i < j; i, j = i+1, j-1 { 206 nk.Heap[i], nk.Heap[j] = nk.Heap[j], nk.Heap[i] 207 } 208 return nk.Heap 209 } 210 211 func TestNearestSetN(t *testing.T) { 212 data := append([]Comparable{ 213 Point{4, 6}, 214 Point{7, 5}, // OK here because we collect N. 215 Point{8, 7}, 216 Point{6, -5}, 217 Point{1e5, 1e5}, 218 Point{1e5, -1e5}, 219 Point{-1e5, 1e5}, 220 Point{-1e5, -1e5}, 221 Point{1e5, 0}, 222 Point{0, -1e5}, 223 Point{0, 1e5}, 224 Point{-1e5, 0}}, 225 wpData[:len(wpData)-1]...) 226 227 tree, err := New(wpData, 3, rand.NewSource(1)) 228 if err != nil { 229 t.Fatalf("unexpected error: %v", err) 230 } 231 for k := 1; k <= len(wpData); k++ { 232 for _, q := range data { 233 wantP := nearestN(k, q, wpData) 234 235 nk := NewNKeeper(k) 236 tree.NearestSet(nk, q) 237 238 var max float64 239 wantD := make(map[float64]map[string]struct{}) 240 for _, p := range wantP { 241 if p.Dist > max { 242 max = p.Dist 243 } 244 d, ok := wantD[p.Dist] 245 if !ok { 246 d = make(map[string]struct{}) 247 } 248 d[fmt.Sprint(p.Comparable)] = struct{}{} 249 wantD[p.Dist] = d 250 } 251 gotD := make(map[float64]map[string]struct{}) 252 for _, p := range nk.Heap { 253 if p.Dist > max { 254 t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max) 255 } 256 d, ok := gotD[p.Dist] 257 if !ok { 258 d = make(map[string]struct{}) 259 } 260 d[fmt.Sprint(p.Comparable)] = struct{}{} 261 gotD[p.Dist] = d 262 } 263 264 // If the available number of slots does not fit all the coequal furthest points 265 // we will fail the check. So remove, but check them minimally here. 266 if !reflect.DeepEqual(wantD[max], gotD[max]) { 267 // The best we can do at this stage is confirm that there are an equal number of matches at this distance. 268 if len(gotD[max]) != len(wantD[max]) { 269 t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max])) 270 } 271 delete(wantD, max) 272 delete(gotD, max) 273 } 274 275 if !reflect.DeepEqual(gotD, wantD) { 276 t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD) 277 } 278 } 279 } 280 } 281 282 var nearestSetDistTests = []Point{ 283 {4, 6}, 284 {7, 5}, 285 {8, 7}, 286 {6, -5}, 287 } 288 289 func TestNearestSetDist(t *testing.T) { 290 tree, err := New(wpData, 3, rand.NewSource(1)) 291 if err != nil { 292 t.Fatalf("unexpected error: %v", err) 293 } 294 for i, q := range nearestSetDistTests { 295 for d := 1.0; d < 100; d += 0.1 { 296 dk := NewDistKeeper(d) 297 tree.NearestSet(dk, q) 298 299 hits := make(map[string]float64) 300 for _, p := range wpData { 301 hits[fmt.Sprint(p)] = p.Distance(q) 302 } 303 304 for _, p := range dk.Heap { 305 var done bool 306 if p.Comparable == nil { 307 done = true 308 continue 309 } 310 delete(hits, fmt.Sprint(p.Comparable)) 311 if done { 312 t.Error("expectedly finished heap iteration") 313 break 314 } 315 dist := p.Comparable.Distance(q) 316 if dist > d { 317 t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d) 318 } 319 } 320 321 for p, dist := range hits { 322 if dist <= d { 323 t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d) 324 } 325 } 326 } 327 } 328 } 329 330 func TestDo(t *testing.T) { 331 tree, err := New(wpData, 3, rand.NewSource(1)) 332 if err != nil { 333 t.Fatalf("unexpected error: %v", err) 334 } 335 var got []Point 336 fn := func(c Comparable, _ int) (done bool) { 337 got = append(got, c.(Point)) 338 return 339 } 340 killed := tree.Do(fn) 341 342 want := make([]Point, len(wpData)) 343 for i, p := range wpData { 344 want[i] = p.(Point) 345 } 346 sort.Sort(lexical(got)) 347 sort.Sort(lexical(want)) 348 349 if !reflect.DeepEqual(got, want) { 350 t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, want) 351 } 352 if killed { 353 t.Error("tree iteration unexpectedly killed") 354 } 355 } 356 357 type lexical []Point 358 359 func (c lexical) Len() int { return len(c) } 360 func (c lexical) Less(i, j int) bool { 361 a, b := c[i], c[j] 362 l := len(a) 363 if len(b) < l { 364 l = len(b) 365 } 366 for k, v := range a[:l] { 367 if v < b[k] { 368 return true 369 } 370 if v > b[k] { 371 return false 372 } 373 } 374 return len(a) < len(b) 375 } 376 func (c lexical) Swap(i, j int) { c[i], c[j] = c[j], c[i] } 377 378 func BenchmarkNew(b *testing.B) { 379 for _, effort := range []int{0, 10, 100} { 380 b.Run(fmt.Sprintf("New:%d", effort), func(b *testing.B) { 381 rnd := rand.New(rand.NewSource(1)) 382 p := make([]Comparable, 1e5) 383 for i := range p { 384 p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 385 } 386 b.ResetTimer() 387 for i := 0; i < b.N; i++ { 388 _, err := New(p, effort, rand.NewSource(1)) 389 if err != nil { 390 b.Fatalf("unexpected error: %v", err) 391 } 392 } 393 }) 394 } 395 } 396 397 func Benchmark(b *testing.B) { 398 var r Comparable 399 var d float64 400 queryBenchmarks := []struct { 401 name string 402 fn func(data []Comparable, tree *Tree, rnd *rand.Rand) func(*testing.B) 403 }{ 404 { 405 name: "NearestBrute", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) { 406 return func(b *testing.B) { 407 for i := 0; i < b.N; i++ { 408 r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) 409 } 410 if r == nil { 411 b.Error("unexpected nil result") 412 } 413 if math.IsNaN(d) { 414 b.Error("unexpected NaN result") 415 } 416 } 417 }, 418 }, 419 { 420 name: "NearestBruteN10", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) { 421 return func(b *testing.B) { 422 var r []ComparableDist 423 for i := 0; i < b.N; i++ { 424 r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) 425 } 426 if len(r) != 10 { 427 b.Error("unexpected result length", len(r)) 428 } 429 } 430 }, 431 }, 432 { 433 name: "Nearest", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) { 434 return func(b *testing.B) { 435 for i := 0; i < b.N; i++ { 436 r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) 437 } 438 if r == nil { 439 b.Error("unexpected nil result") 440 } 441 if math.IsNaN(d) { 442 b.Error("unexpected NaN result") 443 } 444 } 445 }, 446 }, 447 { 448 name: "NearestSetN10", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) { 449 return func(b *testing.B) { 450 nk := NewNKeeper(10) 451 for i := 0; i < b.N; i++ { 452 tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) 453 if nk.Len() != 10 { 454 b.Error("unexpected result length") 455 } 456 nk.Heap = nk.Heap[:1] 457 nk.Heap[0] = ComparableDist{Dist: inf} 458 } 459 } 460 }, 461 }, 462 } 463 464 for _, effort := range []int{0, 3, 10, 30, 100, 300} { 465 rnd := rand.New(rand.NewSource(1)) 466 data := make([]Comparable, 1e5) 467 for i := range data { 468 data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 469 } 470 tree, err := New(data, effort, rand.NewSource(1)) 471 if err != nil { 472 b.Errorf("unexpected error for effort=%d: %v", effort, err) 473 continue 474 } 475 476 if !tree.Root.isVPTree() { 477 b.Fatal("tree is not vantage point tree") 478 } 479 480 for i := 0; i < 1e3; i++ { 481 q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 482 gotP, gotD := tree.Nearest(q) 483 wantP, wantD := nearest(q, data) 484 if !reflect.DeepEqual(gotP, wantP) { 485 b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) 486 } 487 if gotD != wantD { 488 b.Errorf("unexpected distance for query %.3f: got:%v want:%v", q, gotD, wantD) 489 } 490 } 491 492 if b.Failed() && *genDot && tree.Len() <= *dotLimit { 493 err := dotFile(tree, "TestBenches", "") 494 if err != nil { 495 b.Fatalf("failed to write DOT file: %v", err) 496 } 497 return 498 } 499 500 for _, bench := range queryBenchmarks { 501 if strings.Contains(bench.name, "Brute") && effort != 0 { 502 continue 503 } 504 b.Run(fmt.Sprintf("%s:%d", bench.name, effort), bench.fn(data, tree, rnd)) 505 } 506 } 507 } 508 509 func dot(t *Tree, label string) string { 510 if t == nil { 511 return "" 512 } 513 var ( 514 s []string 515 follow func(*Node) 516 ) 517 follow = func(n *Node) { 518 id := uintptr(unsafe.Pointer(n)) 519 c := fmt.Sprintf("%d[label = \"<Closer> |<Elem> %.3f/%.3f|<Further>\"];", 520 id, n.Point, n.Radius) 521 if n.Closer != nil { 522 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Closer -> \"%d\":Elem [label=%.3f];", 523 id, uintptr(unsafe.Pointer(n.Closer)), n.Point.Distance(n.Closer.Point)) 524 follow(n.Closer) 525 } 526 if n.Further != nil { 527 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Further -> \"%d\":Elem [label=%.3f];", 528 id, uintptr(unsafe.Pointer(n.Further)), n.Point.Distance(n.Further.Point)) 529 follow(n.Further) 530 } 531 s = append(s, c) 532 } 533 if t.Root != nil { 534 follow(t.Root) 535 } 536 return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n", 537 label, 538 strings.Join(s, "\n\t"), 539 ) 540 } 541 542 func dotFile(t *Tree, label, dotString string) (err error) { 543 if t == nil && dotString == "" { 544 return 545 } 546 f, err := os.Create(label + ".dot") 547 if err != nil { 548 return 549 } 550 defer f.Close() 551 if dotString == "" { 552 fmt.Fprint(f, dot(t, label)) 553 } else { 554 fmt.Fprint(f, dotString) 555 } 556 return 557 }