gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/spatial/kdtree/kdtree_test.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package kdtree 6 7 import ( 8 "flag" 9 "fmt" 10 "math" 11 "os" 12 "reflect" 13 "slices" 14 "sort" 15 "strings" 16 "testing" 17 "unsafe" 18 19 "golang.org/x/exp/rand" 20 ) 21 22 var ( 23 genDot = flag.Bool("dot", false, "generate dot code for failing trees") 24 dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format") 25 ) 26 27 var ( 28 // Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572. 29 wpData = Points{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}} 30 nbWpData = nbPoints{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}} 31 wpBound = &Bounding{Point{2, 1}, Point{9, 7}} 32 ) 33 34 var newTests = []struct { 35 data Interface 36 bounding bool 37 wantBounds *Bounding 38 }{ 39 {data: wpData, bounding: false, wantBounds: nil}, 40 {data: nbWpData, bounding: false, wantBounds: nil}, 41 {data: wpData, bounding: true, wantBounds: wpBound}, 42 {data: nbWpData, bounding: true, wantBounds: nil}, 43 } 44 45 func TestNew(t *testing.T) { 46 for i, test := range newTests { 47 var tree *Tree 48 var panicked bool 49 func() { 50 defer func() { 51 if r := recover(); r != nil { 52 panicked = true 53 } 54 }() 55 tree = New(test.data, test.bounding) 56 }() 57 if panicked { 58 t.Errorf("unexpected panic for test %d", i) 59 continue 60 } 61 62 if !tree.Root.isKDTree() { 63 t.Errorf("tree %d is not k-d tree", i) 64 } 65 66 switch data := test.data.(type) { 67 case Points: 68 for _, p := range data { 69 if !tree.Contains(p) { 70 t.Errorf("failed to find point %.3f in test %d", p, i) 71 } 72 } 73 case nbPoints: 74 for _, p := range data { 75 if !tree.Contains(p) { 76 t.Errorf("failed to find point %.3f in test %d", p, i) 77 } 78 } 79 default: 80 t.Fatalf("bad test: unknown data type: %T", test.data) 81 } 82 83 if !reflect.DeepEqual(tree.Root.Bounding, test.wantBounds) { 84 t.Errorf("unexpected bounding box for test %d with data type %T: got:%v want:%v", 85 i, test.data, tree.Root.Bounding, test.wantBounds) 86 } 87 88 if t.Failed() && *genDot && tree.Len() <= *dotLimit { 89 err := dotFile(tree, fmt.Sprintf("TestNew%T", test.data), "") 90 if err != nil { 91 t.Fatalf("failed to write DOT file: %v", err) 92 } 93 } 94 } 95 } 96 97 var insertTests = []struct { 98 data Interface 99 insert []Comparable 100 wantBounds *Bounding 101 }{ 102 { 103 data: wpData, 104 insert: []Comparable{Point{0, 0}, Point{10, 10}}, 105 wantBounds: &Bounding{Point{0, 0}, Point{10, 10}}, 106 }, 107 { 108 data: nbWpData, 109 insert: []Comparable{nbPoint{0, 0}, nbPoint{10, 10}}, 110 wantBounds: nil, 111 }, 112 } 113 114 func TestInsert(t *testing.T) { 115 for i, test := range insertTests { 116 tree := New(test.data, true) 117 for _, v := range test.insert { 118 tree.Insert(v, true) 119 } 120 121 if !tree.Root.isKDTree() { 122 t.Errorf("tree %d is not k-d tree", i) 123 } 124 125 if !reflect.DeepEqual(tree.Root.Bounding, test.wantBounds) { 126 t.Errorf("unexpected bounding box for test %d with data type %T: got:%v want:%v", 127 i, test.data, tree.Root.Bounding, test.wantBounds) 128 } 129 130 if t.Failed() && *genDot && tree.Len() <= *dotLimit { 131 err := dotFile(tree, fmt.Sprintf("TestInsert%T", test.data), "") 132 if err != nil { 133 t.Fatalf("failed to write DOT file: %v", err) 134 } 135 } 136 } 137 } 138 139 type compFn func(float64) bool 140 141 func left(v float64) bool { return v <= 0 } 142 func right(v float64) bool { return !left(v) } 143 144 func (n *Node) isKDTree() bool { 145 if n == nil { 146 return true 147 } 148 d := n.Point.Dims() 149 // Together these define the property of minimal orthogonal bounding. 150 if !(n.isContainedBy(n.Bounding) && n.Bounding.planesHaveCoincidentPointsIn(n, [2][]bool{make([]bool, d), make([]bool, d)})) { 151 return false 152 } 153 if !n.Left.isPartitioned(n.Point, left, n.Plane) { 154 return false 155 } 156 if !n.Right.isPartitioned(n.Point, right, n.Plane) { 157 return false 158 } 159 return n.Left.isKDTree() && n.Right.isKDTree() 160 } 161 162 func (n *Node) isPartitioned(pivot Comparable, fn compFn, plane Dim) bool { 163 if n == nil { 164 return true 165 } 166 if n.Left != nil && fn(pivot.Compare(n.Left.Point, plane)) { 167 return false 168 } 169 if n.Right != nil && fn(pivot.Compare(n.Right.Point, plane)) { 170 return false 171 } 172 return n.Left.isPartitioned(pivot, fn, plane) && n.Right.isPartitioned(pivot, fn, plane) 173 } 174 175 func (n *Node) isContainedBy(b *Bounding) bool { 176 if n == nil { 177 return true 178 } 179 if !b.Contains(n.Point) { 180 return false 181 } 182 return n.Left.isContainedBy(b) && n.Right.isContainedBy(b) 183 } 184 185 func (b *Bounding) planesHaveCoincidentPointsIn(n *Node, tight [2][]bool) bool { 186 if b == nil { 187 return true 188 } 189 if n == nil { 190 return true 191 } 192 193 b.planesHaveCoincidentPointsIn(n.Left, tight) 194 b.planesHaveCoincidentPointsIn(n.Right, tight) 195 196 var ok = true 197 for i := range tight { 198 for d := 0; d < n.Point.Dims(); d++ { 199 if c := n.Point.Compare(b.Min, Dim(d)); c == 0 { 200 tight[i][d] = true 201 } 202 ok = ok && tight[i][d] 203 } 204 } 205 return ok 206 } 207 208 func nearest(q Point, p Points) (Point, float64) { 209 min := q.Distance(p[0]) 210 var r int 211 for i := 1; i < p.Len(); i++ { 212 d := q.Distance(p[i]) 213 if d < min { 214 min = d 215 r = i 216 } 217 } 218 return p[r], min 219 } 220 221 func TestNearestRandom(t *testing.T) { 222 rnd := rand.New(rand.NewSource(1)) 223 224 const ( 225 min = 0.0 226 max = 1000.0 227 228 dims = 4 229 setSize = 10000 230 ) 231 232 var randData Points 233 for i := 0; i < setSize; i++ { 234 p := make(Point, dims) 235 for j := 0; j < dims; j++ { 236 p[j] = (max-min)*rnd.Float64() + min 237 } 238 randData = append(randData, p) 239 } 240 tree := New(randData, false) 241 242 for i := 0; i < setSize; i++ { 243 q := make(Point, dims) 244 for j := 0; j < dims; j++ { 245 q[j] = (max-min)*rnd.Float64() + min 246 } 247 248 got, _ := tree.Nearest(q) 249 want, _ := nearest(q, randData) 250 if !reflect.DeepEqual(got, want) { 251 t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want) 252 } 253 } 254 } 255 256 func TestNearest(t *testing.T) { 257 tree := New(wpData, false) 258 for _, q := range append([]Point{ 259 {4, 6}, 260 {7, 5}, 261 {8, 7}, 262 {6, -5}, 263 {1e5, 1e5}, 264 {1e5, -1e5}, 265 {-1e5, 1e5}, 266 {-1e5, -1e5}, 267 {1e5, 0}, 268 {0, -1e5}, 269 {0, 1e5}, 270 {-1e5, 0}, 271 }, wpData...) { 272 gotP, gotD := tree.Nearest(q) 273 wantP, wantD := nearest(q, wpData) 274 if !reflect.DeepEqual(gotP, wantP) { 275 t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) 276 } 277 if gotD != wantD { 278 t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD) 279 } 280 } 281 } 282 283 func nearestN(n int, q Point, p Points) []ComparableDist { 284 nk := NewNKeeper(n) 285 for i := 0; i < p.Len(); i++ { 286 nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])}) 287 } 288 if len(nk.Heap) == 1 { 289 return nk.Heap 290 } 291 sort.Sort(nk) 292 slices.Reverse(nk.Heap) 293 return nk.Heap 294 } 295 296 func TestNearestSetN(t *testing.T) { 297 data := append([]Point{ 298 {4, 6}, 299 {7, 5}, 300 {8, 7}, 301 {6, -5}, 302 {1e5, 1e5}, 303 {1e5, -1e5}, 304 {-1e5, 1e5}, 305 {-1e5, -1e5}, 306 {1e5, 0}, 307 {0, -1e5}, 308 {0, 1e5}, 309 {-1e5, 0}}, 310 wpData[:len(wpData)-1]...) 311 312 tree := New(wpData, false) 313 for k := 1; k <= len(wpData); k++ { 314 for _, q := range data { 315 wantP := nearestN(k, q, wpData) 316 317 nk := NewNKeeper(k) 318 tree.NearestSet(nk, q) 319 320 var max float64 321 wantD := make(map[float64]map[string]struct{}) 322 for _, p := range wantP { 323 if p.Dist > max { 324 max = p.Dist 325 } 326 d, ok := wantD[p.Dist] 327 if !ok { 328 d = make(map[string]struct{}) 329 } 330 d[fmt.Sprint(p.Comparable)] = struct{}{} 331 wantD[p.Dist] = d 332 } 333 gotD := make(map[float64]map[string]struct{}) 334 for _, p := range nk.Heap { 335 if p.Dist > max { 336 t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max) 337 } 338 d, ok := gotD[p.Dist] 339 if !ok { 340 d = make(map[string]struct{}) 341 } 342 d[fmt.Sprint(p.Comparable)] = struct{}{} 343 gotD[p.Dist] = d 344 } 345 346 // If the available number of slots does not fit all the coequal furthest points 347 // we will fail the check. So remove, but check them minimally here. 348 if !reflect.DeepEqual(wantD[max], gotD[max]) { 349 // The best we can do at this stage is confirm that there are an equal number of matches at this distance. 350 if len(gotD[max]) != len(wantD[max]) { 351 t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max])) 352 } 353 delete(wantD, max) 354 delete(gotD, max) 355 } 356 357 if !reflect.DeepEqual(gotD, wantD) { 358 t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD) 359 } 360 } 361 } 362 } 363 364 var nearestSetDistTests = []Point{ 365 {4, 6}, 366 {7, 5}, 367 {8, 7}, 368 {6, -5}, 369 } 370 371 func TestNearestSetDist(t *testing.T) { 372 tree := New(wpData, false) 373 for i, q := range nearestSetDistTests { 374 for d := 1.0; d < 100; d += 0.1 { 375 dk := NewDistKeeper(d) 376 tree.NearestSet(dk, q) 377 378 hits := make(map[string]float64) 379 for _, p := range wpData { 380 hits[fmt.Sprint(p)] = p.Distance(q) 381 } 382 383 for _, p := range dk.Heap { 384 var done bool 385 if p.Comparable == nil { 386 done = true 387 continue 388 } 389 delete(hits, fmt.Sprint(p.Comparable)) 390 if done { 391 t.Error("expectedly finished heap iteration") 392 break 393 } 394 dist := p.Comparable.Distance(q) 395 if dist > d { 396 t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d) 397 } 398 } 399 400 for p, dist := range hits { 401 if dist <= d { 402 t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d) 403 } 404 } 405 } 406 } 407 } 408 409 func TestDo(t *testing.T) { 410 tree := New(wpData, false) 411 var got Points 412 fn := func(c Comparable, _ *Bounding, _ int) (done bool) { 413 got = append(got, c.(Point)) 414 return 415 } 416 killed := tree.Do(fn) 417 if !reflect.DeepEqual(got, wpData) { 418 t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, wpData) 419 } 420 if killed { 421 t.Error("tree iteration unexpectedly killed") 422 } 423 } 424 425 var doBoundedTests = []struct { 426 bounds *Bounding 427 want Points 428 }{ 429 { 430 bounds: nil, 431 want: wpData, 432 }, 433 { 434 bounds: &Bounding{Point{0, 0}, Point{10, 10}}, 435 want: wpData, 436 }, 437 { 438 bounds: &Bounding{Point{3, 4}, Point{10, 10}}, 439 want: Points{Point{5, 4}, Point{4, 7}, Point{9, 6}}, 440 }, 441 { 442 bounds: &Bounding{Point{3, 3}, Point{10, 10}}, 443 want: Points{Point{5, 4}, Point{4, 7}, Point{9, 6}}, 444 }, 445 { 446 bounds: &Bounding{Point{0, 0}, Point{6, 5}}, 447 want: Points{Point{2, 3}, Point{5, 4}}, 448 }, 449 { 450 bounds: &Bounding{Point{5, 2}, Point{7, 4}}, 451 want: Points{Point{5, 4}, Point{7, 2}}, 452 }, 453 { 454 bounds: &Bounding{Point{2, 2}, Point{7, 4}}, 455 want: Points{Point{2, 3}, Point{5, 4}, Point{7, 2}}, 456 }, 457 { 458 bounds: &Bounding{Point{2, 3}, Point{9, 6}}, 459 want: Points{Point{2, 3}, Point{5, 4}, Point{9, 6}}, 460 }, 461 { 462 bounds: &Bounding{Point{7, 2}, Point{7, 2}}, 463 want: Points{Point{7, 2}}, 464 }, 465 } 466 467 func TestDoBounded(t *testing.T) { 468 for _, test := range doBoundedTests { 469 tree := New(wpData, false) 470 var got Points 471 fn := func(c Comparable, _ *Bounding, _ int) (done bool) { 472 got = append(got, c.(Point)) 473 return 474 } 475 killed := tree.DoBounded(test.bounds, fn) 476 if !reflect.DeepEqual(got, test.want) { 477 t.Errorf("unexpected result from bounded tree iteration: got:%v want:%v", got, test.want) 478 } 479 if killed { 480 t.Error("tree iteration unexpectedly killed") 481 } 482 } 483 } 484 485 func BenchmarkNew(b *testing.B) { 486 rnd := rand.New(rand.NewSource(1)) 487 p := make(Points, 1e5) 488 for i := range p { 489 p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 490 } 491 b.ResetTimer() 492 for i := 0; i < b.N; i++ { 493 _ = New(p, false) 494 } 495 } 496 497 func BenchmarkNewBounds(b *testing.B) { 498 rnd := rand.New(rand.NewSource(1)) 499 p := make(Points, 1e5) 500 for i := range p { 501 p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 502 } 503 b.ResetTimer() 504 for i := 0; i < b.N; i++ { 505 _ = New(p, true) 506 } 507 } 508 509 func BenchmarkInsert(b *testing.B) { 510 rnd := rand.New(rand.NewSource(1)) 511 t := &Tree{} 512 for i := 0; i < b.N; i++ { 513 t.Insert(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, false) 514 } 515 } 516 517 func BenchmarkInsertBounds(b *testing.B) { 518 rnd := rand.New(rand.NewSource(1)) 519 t := &Tree{} 520 for i := 0; i < b.N; i++ { 521 t.Insert(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, true) 522 } 523 } 524 525 func Benchmark(b *testing.B) { 526 rnd := rand.New(rand.NewSource(1)) 527 data := make(Points, 1e2) 528 for i := range data { 529 data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 530 } 531 tree := New(data, true) 532 533 if !tree.Root.isKDTree() { 534 b.Fatal("tree is not k-d tree") 535 } 536 537 for i := 0; i < 1e3; i++ { 538 q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 539 gotP, gotD := tree.Nearest(q) 540 wantP, wantD := nearest(q, data) 541 if !reflect.DeepEqual(gotP, wantP) { 542 b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) 543 } 544 if gotD != wantD { 545 b.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD) 546 } 547 } 548 549 if b.Failed() && *genDot && tree.Len() <= *dotLimit { 550 err := dotFile(tree, "TestBenches", "") 551 if err != nil { 552 b.Fatalf("failed to write DOT file: %v", err) 553 } 554 return 555 } 556 557 var r Comparable 558 var d float64 559 queryBenchmarks := []struct { 560 name string 561 fn func(*testing.B) 562 }{ 563 { 564 name: "Nearest", fn: func(b *testing.B) { 565 for i := 0; i < b.N; i++ { 566 r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) 567 } 568 if r == nil { 569 b.Error("unexpected nil result") 570 } 571 if math.IsNaN(d) { 572 b.Error("unexpected NaN result") 573 } 574 }, 575 }, 576 { 577 name: "NearestBrute", fn: func(b *testing.B) { 578 for i := 0; i < b.N; i++ { 579 r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) 580 } 581 if r == nil { 582 b.Error("unexpected nil result") 583 } 584 if math.IsNaN(d) { 585 b.Error("unexpected NaN result") 586 } 587 }, 588 }, 589 { 590 name: "NearestSetN10", fn: func(b *testing.B) { 591 nk := NewNKeeper(10) 592 for i := 0; i < b.N; i++ { 593 tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) 594 if nk.Len() != 10 { 595 b.Error("unexpected result length") 596 } 597 nk.Heap = nk.Heap[:1] 598 nk.Heap[0] = ComparableDist{Dist: inf} 599 } 600 }, 601 }, 602 { 603 name: "NearestBruteN10", fn: func(b *testing.B) { 604 var r []ComparableDist 605 for i := 0; i < b.N; i++ { 606 r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) 607 } 608 if len(r) != 10 { 609 b.Error("unexpected result length", len(r)) 610 } 611 }, 612 }, 613 } 614 for _, bench := range queryBenchmarks { 615 b.Run(bench.name, bench.fn) 616 } 617 } 618 619 func dot(t *Tree, label string) string { 620 if t == nil { 621 return "" 622 } 623 var ( 624 s []string 625 follow func(*Node) 626 ) 627 follow = func(n *Node) { 628 id := uintptr(unsafe.Pointer(n)) 629 c := fmt.Sprintf("%d[label = \"<Left> |<Elem> %s/%.3f\\n%.3f|<Right>\"];", 630 id, n, n.Point.(Point)[n.Plane], *n.Bounding) 631 if n.Left != nil { 632 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Left -> \"%d\":Elem;", 633 id, uintptr(unsafe.Pointer(n.Left))) 634 follow(n.Left) 635 } 636 if n.Right != nil { 637 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Right -> \"%d\":Elem;", 638 id, uintptr(unsafe.Pointer(n.Right))) 639 follow(n.Right) 640 } 641 s = append(s, c) 642 } 643 if t.Root != nil { 644 follow(t.Root) 645 } 646 return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n", 647 label, 648 strings.Join(s, "\n\t"), 649 ) 650 } 651 652 func dotFile(t *Tree, label, dotString string) (err error) { 653 if t == nil && dotString == "" { 654 return 655 } 656 f, err := os.Create(label + ".dot") 657 if err != nil { 658 return 659 } 660 defer f.Close() 661 if dotString == "" { 662 fmt.Fprint(f, dot(t, label)) 663 } else { 664 fmt.Fprint(f, dotString) 665 } 666 return 667 }