github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/spatial/kdtree/kdtree_test.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package kdtree 6 7 import ( 8 "flag" 9 "fmt" 10 "math" 11 "os" 12 "reflect" 13 "sort" 14 "strings" 15 "testing" 16 "unsafe" 17 18 "golang.org/x/exp/rand" 19 ) 20 21 var ( 22 genDot = flag.Bool("dot", false, "generate dot code for failing trees") 23 dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format") 24 ) 25 26 var ( 27 // Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572. 28 wpData = Points{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}} 29 nbWpData = nbPoints{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}} 30 wpBound = &Bounding{Point{2, 1}, Point{9, 7}} 31 ) 32 33 var newTests = []struct { 34 data Interface 35 bounding bool 36 wantBounds *Bounding 37 }{ 38 {data: wpData, bounding: false, wantBounds: nil}, 39 {data: nbWpData, bounding: false, wantBounds: nil}, 40 {data: wpData, bounding: true, wantBounds: wpBound}, 41 {data: nbWpData, bounding: true, wantBounds: nil}, 42 } 43 44 func TestNew(t *testing.T) { 45 for i, test := range newTests { 46 var tree *Tree 47 var panicked bool 48 func() { 49 defer func() { 50 if r := recover(); r != nil { 51 panicked = true 52 } 53 }() 54 tree = New(test.data, test.bounding) 55 }() 56 if panicked { 57 t.Errorf("unexpected panic for test %d", i) 58 continue 59 } 60 61 if !tree.Root.isKDTree() { 62 t.Errorf("tree %d is not k-d tree", i) 63 } 64 65 switch data := test.data.(type) { 66 case Points: 67 for _, p := range data { 68 if !tree.Contains(p) { 69 t.Errorf("failed to find point %.3f in test %d", p, i) 70 } 71 } 72 case nbPoints: 73 for _, p := range data { 74 if !tree.Contains(p) { 75 t.Errorf("failed to find point %.3f in test %d", p, i) 76 } 77 } 78 default: 79 t.Fatalf("bad test: unknown data type: %T", test.data) 80 } 81 82 if !reflect.DeepEqual(tree.Root.Bounding, test.wantBounds) { 83 t.Errorf("unexpected bounding box for test %d with data type %T: got:%v want:%v", 84 i, test.data, tree.Root.Bounding, test.wantBounds) 85 } 86 87 if t.Failed() && *genDot && tree.Len() <= *dotLimit { 88 err := dotFile(tree, fmt.Sprintf("TestNew%T", test.data), "") 89 if err != nil { 90 t.Fatalf("failed to write DOT file: %v", err) 91 } 92 } 93 } 94 } 95 96 var insertTests = []struct { 97 data Interface 98 insert []Comparable 99 wantBounds *Bounding 100 }{ 101 { 102 data: wpData, 103 insert: []Comparable{Point{0, 0}, Point{10, 10}}, 104 wantBounds: &Bounding{Point{0, 0}, Point{10, 10}}, 105 }, 106 { 107 data: nbWpData, 108 insert: []Comparable{nbPoint{0, 0}, nbPoint{10, 10}}, 109 wantBounds: nil, 110 }, 111 } 112 113 func TestInsert(t *testing.T) { 114 for i, test := range insertTests { 115 tree := New(test.data, true) 116 for _, v := range test.insert { 117 tree.Insert(v, true) 118 } 119 120 if !tree.Root.isKDTree() { 121 t.Errorf("tree %d is not k-d tree", i) 122 } 123 124 if !reflect.DeepEqual(tree.Root.Bounding, test.wantBounds) { 125 t.Errorf("unexpected bounding box for test %d with data type %T: got:%v want:%v", 126 i, test.data, tree.Root.Bounding, test.wantBounds) 127 } 128 129 if t.Failed() && *genDot && tree.Len() <= *dotLimit { 130 err := dotFile(tree, fmt.Sprintf("TestInsert%T", test.data), "") 131 if err != nil { 132 t.Fatalf("failed to write DOT file: %v", err) 133 } 134 } 135 } 136 } 137 138 type compFn func(float64) bool 139 140 func left(v float64) bool { return v <= 0 } 141 func right(v float64) bool { return !left(v) } 142 143 func (n *Node) isKDTree() bool { 144 if n == nil { 145 return true 146 } 147 d := n.Point.Dims() 148 // Together these define the property of minimal orthogonal bounding. 149 if !(n.isContainedBy(n.Bounding) && n.Bounding.planesHaveCoincidentPointsIn(n, [2][]bool{make([]bool, d), make([]bool, d)})) { 150 return false 151 } 152 if !n.Left.isPartitioned(n.Point, left, n.Plane) { 153 return false 154 } 155 if !n.Right.isPartitioned(n.Point, right, n.Plane) { 156 return false 157 } 158 return n.Left.isKDTree() && n.Right.isKDTree() 159 } 160 161 func (n *Node) isPartitioned(pivot Comparable, fn compFn, plane Dim) bool { 162 if n == nil { 163 return true 164 } 165 if n.Left != nil && fn(pivot.Compare(n.Left.Point, plane)) { 166 return false 167 } 168 if n.Right != nil && fn(pivot.Compare(n.Right.Point, plane)) { 169 return false 170 } 171 return n.Left.isPartitioned(pivot, fn, plane) && n.Right.isPartitioned(pivot, fn, plane) 172 } 173 174 func (n *Node) isContainedBy(b *Bounding) bool { 175 if n == nil { 176 return true 177 } 178 if !b.Contains(n.Point) { 179 return false 180 } 181 return n.Left.isContainedBy(b) && n.Right.isContainedBy(b) 182 } 183 184 func (b *Bounding) planesHaveCoincidentPointsIn(n *Node, tight [2][]bool) bool { 185 if b == nil { 186 return true 187 } 188 if n == nil { 189 return true 190 } 191 192 b.planesHaveCoincidentPointsIn(n.Left, tight) 193 b.planesHaveCoincidentPointsIn(n.Right, tight) 194 195 var ok = true 196 for i := range tight { 197 for d := 0; d < n.Point.Dims(); d++ { 198 if c := n.Point.Compare(b.Min, Dim(d)); c == 0 { 199 tight[i][d] = true 200 } 201 ok = ok && tight[i][d] 202 } 203 } 204 return ok 205 } 206 207 func nearest(q Point, p Points) (Point, float64) { 208 min := q.Distance(p[0]) 209 var r int 210 for i := 1; i < p.Len(); i++ { 211 d := q.Distance(p[i]) 212 if d < min { 213 min = d 214 r = i 215 } 216 } 217 return p[r], min 218 } 219 220 func TestNearestRandom(t *testing.T) { 221 rnd := rand.New(rand.NewSource(1)) 222 223 const ( 224 min = 0.0 225 max = 1000.0 226 227 dims = 4 228 setSize = 10000 229 ) 230 231 var randData Points 232 for i := 0; i < setSize; i++ { 233 p := make(Point, dims) 234 for j := 0; j < dims; j++ { 235 p[j] = (max-min)*rnd.Float64() + min 236 } 237 randData = append(randData, p) 238 } 239 tree := New(randData, false) 240 241 for i := 0; i < setSize; i++ { 242 q := make(Point, dims) 243 for j := 0; j < dims; j++ { 244 q[j] = (max-min)*rnd.Float64() + min 245 } 246 247 got, _ := tree.Nearest(q) 248 want, _ := nearest(q, randData) 249 if !reflect.DeepEqual(got, want) { 250 t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want) 251 } 252 } 253 } 254 255 func TestNearest(t *testing.T) { 256 tree := New(wpData, false) 257 for _, q := range append([]Point{ 258 {4, 6}, 259 {7, 5}, 260 {8, 7}, 261 {6, -5}, 262 {1e5, 1e5}, 263 {1e5, -1e5}, 264 {-1e5, 1e5}, 265 {-1e5, -1e5}, 266 {1e5, 0}, 267 {0, -1e5}, 268 {0, 1e5}, 269 {-1e5, 0}, 270 }, wpData...) { 271 gotP, gotD := tree.Nearest(q) 272 wantP, wantD := nearest(q, wpData) 273 if !reflect.DeepEqual(gotP, wantP) { 274 t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) 275 } 276 if gotD != wantD { 277 t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD) 278 } 279 } 280 } 281 282 func nearestN(n int, q Point, p Points) []ComparableDist { 283 nk := NewNKeeper(n) 284 for i := 0; i < p.Len(); i++ { 285 nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])}) 286 } 287 if len(nk.Heap) == 1 { 288 return nk.Heap 289 } 290 sort.Sort(nk) 291 for i, j := 0, len(nk.Heap)-1; i < j; i, j = i+1, j-1 { 292 nk.Heap[i], nk.Heap[j] = nk.Heap[j], nk.Heap[i] 293 } 294 return nk.Heap 295 } 296 297 func TestNearestSetN(t *testing.T) { 298 data := append([]Point{ 299 {4, 6}, 300 {7, 5}, 301 {8, 7}, 302 {6, -5}, 303 {1e5, 1e5}, 304 {1e5, -1e5}, 305 {-1e5, 1e5}, 306 {-1e5, -1e5}, 307 {1e5, 0}, 308 {0, -1e5}, 309 {0, 1e5}, 310 {-1e5, 0}}, 311 wpData[:len(wpData)-1]...) 312 313 tree := New(wpData, false) 314 for k := 1; k <= len(wpData); k++ { 315 for _, q := range data { 316 wantP := nearestN(k, q, wpData) 317 318 nk := NewNKeeper(k) 319 tree.NearestSet(nk, q) 320 321 var max float64 322 wantD := make(map[float64]map[string]struct{}) 323 for _, p := range wantP { 324 if p.Dist > max { 325 max = p.Dist 326 } 327 d, ok := wantD[p.Dist] 328 if !ok { 329 d = make(map[string]struct{}) 330 } 331 d[fmt.Sprint(p.Comparable)] = struct{}{} 332 wantD[p.Dist] = d 333 } 334 gotD := make(map[float64]map[string]struct{}) 335 for _, p := range nk.Heap { 336 if p.Dist > max { 337 t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max) 338 } 339 d, ok := gotD[p.Dist] 340 if !ok { 341 d = make(map[string]struct{}) 342 } 343 d[fmt.Sprint(p.Comparable)] = struct{}{} 344 gotD[p.Dist] = d 345 } 346 347 // If the available number of slots does not fit all the coequal furthest points 348 // we will fail the check. So remove, but check them minimally here. 349 if !reflect.DeepEqual(wantD[max], gotD[max]) { 350 // The best we can do at this stage is confirm that there are an equal number of matches at this distance. 351 if len(gotD[max]) != len(wantD[max]) { 352 t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max])) 353 } 354 delete(wantD, max) 355 delete(gotD, max) 356 } 357 358 if !reflect.DeepEqual(gotD, wantD) { 359 t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD) 360 } 361 } 362 } 363 } 364 365 var nearestSetDistTests = []Point{ 366 {4, 6}, 367 {7, 5}, 368 {8, 7}, 369 {6, -5}, 370 } 371 372 func TestNearestSetDist(t *testing.T) { 373 tree := New(wpData, false) 374 for i, q := range nearestSetDistTests { 375 for d := 1.0; d < 100; d += 0.1 { 376 dk := NewDistKeeper(d) 377 tree.NearestSet(dk, q) 378 379 hits := make(map[string]float64) 380 for _, p := range wpData { 381 hits[fmt.Sprint(p)] = p.Distance(q) 382 } 383 384 for _, p := range dk.Heap { 385 var done bool 386 if p.Comparable == nil { 387 done = true 388 continue 389 } 390 delete(hits, fmt.Sprint(p.Comparable)) 391 if done { 392 t.Error("expectedly finished heap iteration") 393 break 394 } 395 dist := p.Comparable.Distance(q) 396 if dist > d { 397 t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d) 398 } 399 } 400 401 for p, dist := range hits { 402 if dist <= d { 403 t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d) 404 } 405 } 406 } 407 } 408 } 409 410 func TestDo(t *testing.T) { 411 tree := New(wpData, false) 412 var got Points 413 fn := func(c Comparable, _ *Bounding, _ int) (done bool) { 414 got = append(got, c.(Point)) 415 return 416 } 417 killed := tree.Do(fn) 418 if !reflect.DeepEqual(got, wpData) { 419 t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, wpData) 420 } 421 if killed { 422 t.Error("tree iteration unexpectedly killed") 423 } 424 } 425 426 var doBoundedTests = []struct { 427 bounds *Bounding 428 want Points 429 }{ 430 { 431 bounds: nil, 432 want: wpData, 433 }, 434 { 435 bounds: &Bounding{Point{0, 0}, Point{10, 10}}, 436 want: wpData, 437 }, 438 { 439 bounds: &Bounding{Point{3, 4}, Point{10, 10}}, 440 want: Points{Point{5, 4}, Point{4, 7}, Point{9, 6}}, 441 }, 442 { 443 bounds: &Bounding{Point{3, 3}, Point{10, 10}}, 444 want: Points{Point{5, 4}, Point{4, 7}, Point{9, 6}}, 445 }, 446 { 447 bounds: &Bounding{Point{0, 0}, Point{6, 5}}, 448 want: Points{Point{2, 3}, Point{5, 4}}, 449 }, 450 { 451 bounds: &Bounding{Point{5, 2}, Point{7, 4}}, 452 want: Points{Point{5, 4}, Point{7, 2}}, 453 }, 454 { 455 bounds: &Bounding{Point{2, 2}, Point{7, 4}}, 456 want: Points{Point{2, 3}, Point{5, 4}, Point{7, 2}}, 457 }, 458 { 459 bounds: &Bounding{Point{2, 3}, Point{9, 6}}, 460 want: Points{Point{2, 3}, Point{5, 4}, Point{9, 6}}, 461 }, 462 { 463 bounds: &Bounding{Point{7, 2}, Point{7, 2}}, 464 want: Points{Point{7, 2}}, 465 }, 466 } 467 468 func TestDoBounded(t *testing.T) { 469 for _, test := range doBoundedTests { 470 tree := New(wpData, false) 471 var got Points 472 fn := func(c Comparable, _ *Bounding, _ int) (done bool) { 473 got = append(got, c.(Point)) 474 return 475 } 476 killed := tree.DoBounded(test.bounds, fn) 477 if !reflect.DeepEqual(got, test.want) { 478 t.Errorf("unexpected result from bounded tree iteration: got:%v want:%v", got, test.want) 479 } 480 if killed { 481 t.Error("tree iteration unexpectedly killed") 482 } 483 } 484 } 485 486 func BenchmarkNew(b *testing.B) { 487 rnd := rand.New(rand.NewSource(1)) 488 p := make(Points, 1e5) 489 for i := range p { 490 p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 491 } 492 b.ResetTimer() 493 for i := 0; i < b.N; i++ { 494 _ = New(p, false) 495 } 496 } 497 498 func BenchmarkNewBounds(b *testing.B) { 499 rnd := rand.New(rand.NewSource(1)) 500 p := make(Points, 1e5) 501 for i := range p { 502 p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 503 } 504 b.ResetTimer() 505 for i := 0; i < b.N; i++ { 506 _ = New(p, true) 507 } 508 } 509 510 func BenchmarkInsert(b *testing.B) { 511 rnd := rand.New(rand.NewSource(1)) 512 t := &Tree{} 513 for i := 0; i < b.N; i++ { 514 t.Insert(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, false) 515 } 516 } 517 518 func BenchmarkInsertBounds(b *testing.B) { 519 rnd := rand.New(rand.NewSource(1)) 520 t := &Tree{} 521 for i := 0; i < b.N; i++ { 522 t.Insert(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, true) 523 } 524 } 525 526 func Benchmark(b *testing.B) { 527 rnd := rand.New(rand.NewSource(1)) 528 data := make(Points, 1e2) 529 for i := range data { 530 data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 531 } 532 tree := New(data, true) 533 534 if !tree.Root.isKDTree() { 535 b.Fatal("tree is not k-d tree") 536 } 537 538 for i := 0; i < 1e3; i++ { 539 q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 540 gotP, gotD := tree.Nearest(q) 541 wantP, wantD := nearest(q, data) 542 if !reflect.DeepEqual(gotP, wantP) { 543 b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) 544 } 545 if gotD != wantD { 546 b.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD) 547 } 548 } 549 550 if b.Failed() && *genDot && tree.Len() <= *dotLimit { 551 err := dotFile(tree, "TestBenches", "") 552 if err != nil { 553 b.Fatalf("failed to write DOT file: %v", err) 554 } 555 return 556 } 557 558 var r Comparable 559 var d float64 560 queryBenchmarks := []struct { 561 name string 562 fn func(*testing.B) 563 }{ 564 { 565 name: "Nearest", fn: func(b *testing.B) { 566 for i := 0; i < b.N; i++ { 567 r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) 568 } 569 if r == nil { 570 b.Error("unexpected nil result") 571 } 572 if math.IsNaN(d) { 573 b.Error("unexpected NaN result") 574 } 575 }, 576 }, 577 { 578 name: "NearestBrute", fn: func(b *testing.B) { 579 for i := 0; i < b.N; i++ { 580 r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) 581 } 582 if r == nil { 583 b.Error("unexpected nil result") 584 } 585 if math.IsNaN(d) { 586 b.Error("unexpected NaN result") 587 } 588 }, 589 }, 590 { 591 name: "NearestSetN10", fn: func(b *testing.B) { 592 nk := NewNKeeper(10) 593 for i := 0; i < b.N; i++ { 594 tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) 595 if nk.Len() != 10 { 596 b.Error("unexpected result length") 597 } 598 nk.Heap = nk.Heap[:1] 599 nk.Heap[0] = ComparableDist{Dist: inf} 600 } 601 }, 602 }, 603 { 604 name: "NearestBruteN10", fn: func(b *testing.B) { 605 var r []ComparableDist 606 for i := 0; i < b.N; i++ { 607 r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) 608 } 609 if len(r) != 10 { 610 b.Error("unexpected result length", len(r)) 611 } 612 }, 613 }, 614 } 615 for _, bench := range queryBenchmarks { 616 b.Run(bench.name, bench.fn) 617 } 618 } 619 620 func dot(t *Tree, label string) string { 621 if t == nil { 622 return "" 623 } 624 var ( 625 s []string 626 follow func(*Node) 627 ) 628 follow = func(n *Node) { 629 id := uintptr(unsafe.Pointer(n)) 630 c := fmt.Sprintf("%d[label = \"<Left> |<Elem> %s/%.3f\\n%.3f|<Right>\"];", 631 id, n, n.Point.(Point)[n.Plane], *n.Bounding) 632 if n.Left != nil { 633 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Left -> \"%d\":Elem;", 634 id, uintptr(unsafe.Pointer(n.Left))) 635 follow(n.Left) 636 } 637 if n.Right != nil { 638 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Right -> \"%d\":Elem;", 639 id, uintptr(unsafe.Pointer(n.Right))) 640 follow(n.Right) 641 } 642 s = append(s, c) 643 } 644 if t.Root != nil { 645 follow(t.Root) 646 } 647 return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n", 648 label, 649 strings.Join(s, "\n\t"), 650 ) 651 } 652 653 func dotFile(t *Tree, label, dotString string) (err error) { 654 if t == nil && dotString == "" { 655 return 656 } 657 f, err := os.Create(label + ".dot") 658 if err != nil { 659 return 660 } 661 defer f.Close() 662 if dotString == "" { 663 fmt.Fprint(f, dot(t, label)) 664 } else { 665 fmt.Fprint(f, dotString) 666 } 667 return 668 }