github.com/biogo/store@v0.0.0-20201120204734-aad293a2328f/kdtree/kdtree_test.go (about) 1 // Copyright ©2012 The bíogo Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package kdtree 6 7 import ( 8 "flag" 9 "fmt" 10 "math/rand" 11 "os" 12 "reflect" 13 "sort" 14 "strings" 15 "testing" 16 "unsafe" 17 18 "gopkg.in/check.v1" 19 ) 20 21 var ( 22 genDot = flag.Bool("dot", false, "Generate dot code for failing trees.") 23 dotLimit = flag.Int("dotmax", 100, "Maximum size for tree output for dot format.") 24 ) 25 26 func Test(t *testing.T) { check.TestingT(t) } 27 28 type S struct{} 29 30 var _ = check.Suite(&S{}) 31 32 var ( 33 // Using example from WP article. 34 wpData = Points{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}} 35 nbWpData = nbPoints{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}} 36 wpBound = &Bounding{Point{2, 1}, Point{9, 7}} 37 bData = func(i int) Points { 38 p := make(Points, i) 39 for i := range p { 40 p[i] = Point{rand.Float64(), rand.Float64(), rand.Float64()} 41 } 42 return p 43 }(1e2) 44 bTree = New(bData, true) 45 ) 46 47 func (s *S) TestNew(c *check.C) { 48 for i, test := range []struct { 49 data Interface 50 bounding bool 51 bounds *Bounding 52 }{ 53 {wpData, false, nil}, 54 {nbWpData, false, nil}, 55 {wpData, true, wpBound}, 56 {nbWpData, true, nil}, 57 } { 58 var t *Tree 59 NewTreePanics := func() (panicked bool) { 60 defer func() { 61 if r := recover(); r != nil { 62 panicked = true 63 } 64 }() 65 t = New(test.data, test.bounding) 66 return 67 } 68 c.Check(NewTreePanics(), check.Equals, false) 69 c.Check(t.Root.isKDTree(), check.Equals, true) 70 switch data := test.data.(type) { 71 case Points: 72 for _, p := range data { 73 c.Check(t.Contains(p), check.Equals, true) 74 } 75 case nbPoints: 76 for _, p := range data { 77 c.Check(t.Contains(p), check.Equals, true) 78 } 79 } 80 c.Check(t.Root.Bounding, check.DeepEquals, test.bounds, 81 check.Commentf("Test %d. %T %v", i, test.data, test.bounding)) 82 if c.Failed() && *genDot && t.Len() <= *dotLimit { 83 err := dotFile(t, fmt.Sprintf("TestNew%T", test.data), "") 84 if err != nil { 85 c.Errorf("Dot file write failed: %v", err) 86 } 87 } 88 } 89 } 90 91 func (s *S) TestInsert(c *check.C) { 92 for i, test := range []struct { 93 data Interface 94 insert []Comparable 95 bounds *Bounding 96 }{ 97 { 98 wpData, 99 []Comparable{Point{0, 0}, Point{10, 10}}, 100 &Bounding{Point{0, 0}, Point{10, 10}}, 101 }, 102 { 103 nbWpData, 104 []Comparable{nbPoint{0, 0}, nbPoint{10, 10}}, 105 nil, 106 }, 107 } { 108 t := New(test.data, true) 109 for _, v := range test.insert { 110 t.Insert(v, true) 111 } 112 c.Check(t.Root.isKDTree(), check.Equals, true) 113 c.Check(t.Root.Bounding, check.DeepEquals, test.bounds, 114 check.Commentf("Test %d. %T", i, test.data)) 115 if c.Failed() && *genDot && t.Len() <= *dotLimit { 116 err := dotFile(t, fmt.Sprintf("TestInsert%T", test.data), "") 117 if err != nil { 118 c.Errorf("Dot file write failed: %v", err) 119 } 120 } 121 } 122 } 123 124 type compFn func(float64) bool 125 126 func left(v float64) bool { return v <= 0 } 127 func right(v float64) bool { return !left(v) } 128 129 func (n *Node) isKDTree() bool { 130 if n == nil { 131 return true 132 } 133 d := n.Point.Dims() 134 // Together these define the property of minimal orthogonal bounding. 135 if !(n.isContainedBy(n.Bounding) && n.Bounding.planesHaveCoincidentPointsIn(n, [2][]bool{make([]bool, d), make([]bool, d)})) { 136 return false 137 } 138 if !n.Left.isPartitioned(n.Point, left, n.Plane) { 139 return false 140 } 141 if !n.Right.isPartitioned(n.Point, right, n.Plane) { 142 return false 143 } 144 return n.Left.isKDTree() && n.Right.isKDTree() 145 } 146 147 func (n *Node) isPartitioned(pivot Comparable, fn compFn, plane Dim) bool { 148 if n == nil { 149 return true 150 } 151 if n.Left != nil && fn(pivot.Compare(n.Left.Point, plane)) { 152 return false 153 } 154 if n.Right != nil && fn(pivot.Compare(n.Right.Point, plane)) { 155 return false 156 } 157 return n.Left.isPartitioned(pivot, fn, plane) && n.Right.isPartitioned(pivot, fn, plane) 158 } 159 160 func (n *Node) isContainedBy(b *Bounding) bool { 161 if n == nil { 162 return true 163 } 164 if !b.Contains(n.Point) { 165 return false 166 } 167 return n.Left.isContainedBy(b) && n.Right.isContainedBy(b) 168 } 169 170 func (b *Bounding) planesHaveCoincidentPointsIn(n *Node, tight [2][]bool) bool { 171 if b == nil { 172 return true 173 } 174 if n == nil { 175 return true 176 } 177 178 b.planesHaveCoincidentPointsIn(n.Left, tight) 179 b.planesHaveCoincidentPointsIn(n.Right, tight) 180 181 var ok = true 182 for i := range tight { 183 for d := 0; d < n.Point.Dims(); d++ { 184 if c := n.Point.Compare(b[0], Dim(d)); c == 0 { 185 tight[i][d] = true 186 } 187 ok = ok && tight[i][d] 188 } 189 } 190 return ok 191 } 192 193 func nearest(q Point, p Points) (Point, float64) { 194 min := q.Distance(p[0]) 195 var r int 196 for i := 1; i < p.Len(); i++ { 197 d := q.Distance(p[i]) 198 if d < min { 199 min = d 200 r = i 201 } 202 } 203 return p[r], min 204 } 205 206 func (s *S) TestNearestRandom(c *check.C) { 207 const ( 208 min = 0. 209 max = 1000. 210 211 dims = 4 212 setSize = 10000 213 ) 214 215 var randData Points 216 for i := 0; i < setSize; i++ { 217 p := make(Point, dims) 218 for j := 0; j < dims; j++ { 219 p[j] = (max-min)*rand.Float64() + min 220 } 221 randData = append(randData, p) 222 } 223 t := New(randData, false) 224 225 for i := 0; i < setSize; i++ { 226 q := make(Point, dims) 227 for j := 0; j < dims; j++ { 228 q[j] = (max-min)*rand.Float64() + min 229 } 230 231 p, _ := t.Nearest(q) 232 ep, _ := nearest(q, randData) 233 c.Assert(p, check.DeepEquals, ep, check.Commentf("Test %d: query %.3f expects %.3f", i, q, ep)) 234 } 235 } 236 237 func (s *S) TestNearest(c *check.C) { 238 t := New(wpData, false) 239 for i, q := range append([]Point{ 240 {4, 6}, 241 {7, 5}, 242 {8, 7}, 243 {6, -5}, 244 {1e5, 1e5}, 245 {1e5, -1e5}, 246 {-1e5, 1e5}, 247 {-1e5, -1e5}, 248 {1e5, 0}, 249 {0, -1e5}, 250 {0, 1e5}, 251 {-1e5, 0}, 252 }, wpData...) { 253 p, d := t.Nearest(q) 254 ep, ed := nearest(q, wpData) 255 c.Check(p, check.DeepEquals, ep, check.Commentf("Test %d: query %.3f expects %.3f", i, q, ep)) 256 c.Check(d, check.Equals, ed) 257 } 258 } 259 260 func nearestN(n int, q Point, p Points) []ComparableDist { 261 nk := NewNKeeper(n) 262 for i := 0; i < p.Len(); i++ { 263 nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])}) 264 } 265 if len(nk.Heap) == 1 { 266 return nk.Heap 267 } 268 sort.Sort(nk) 269 for i, j := 0, len(nk.Heap)-1; i < j; i, j = i+1, j-1 { 270 nk.Heap[i], nk.Heap[j] = nk.Heap[j], nk.Heap[i] 271 } 272 return nk.Heap 273 } 274 275 func (s *S) TestNearestSetN(c *check.C) { 276 t := New(wpData, false) 277 in := append([]Point{ 278 {4, 6}, 279 {7, 5}, 280 {8, 7}, 281 {6, -5}, 282 {1e5, 1e5}, 283 {1e5, -1e5}, 284 {-1e5, 1e5}, 285 {-1e5, -1e5}, 286 {1e5, 0}, 287 {0, -1e5}, 288 {0, 1e5}, 289 {-1e5, 0}}, wpData[:len(wpData)-1]...) 290 for k := 1; k <= len(wpData); k++ { 291 for i, q := range in { 292 ep := nearestN(k, q, wpData) 293 nk := NewNKeeper(k) 294 t.NearestSet(nk, q) 295 296 var max float64 297 ed := make(map[float64]map[string]struct{}) 298 for _, p := range ep { 299 if p.Dist > max { 300 max = p.Dist 301 } 302 d, ok := ed[p.Dist] 303 if !ok { 304 d = make(map[string]struct{}) 305 } 306 d[fmt.Sprint(p.Comparable)] = struct{}{} 307 ed[p.Dist] = d 308 } 309 kd := make(map[float64]map[string]struct{}) 310 for _, p := range nk.Heap { 311 c.Check(max >= p.Dist, check.Equals, true) 312 d, ok := kd[p.Dist] 313 if !ok { 314 d = make(map[string]struct{}) 315 } 316 d[fmt.Sprint(p.Comparable)] = struct{}{} 317 kd[p.Dist] = d 318 } 319 320 // If the available number of slots does not fit all the coequal furthest points 321 // we will fail the check. So remove, but check them minimally here. 322 if !reflect.DeepEqual(ed[max], kd[max]) { 323 // The best we can do at this stage is confirm that there are an equal number of matches at this distance. 324 c.Check(len(ed[max]), check.Equals, len(kd[max])) 325 delete(ed, max) 326 delete(kd, max) 327 } 328 329 c.Check(kd, check.DeepEquals, ed, check.Commentf("Test k=%d %d: query %.3f expects %.3f", k, i, q, ep)) 330 } 331 } 332 } 333 334 func (s *S) TestNearestSetDist(c *check.C) { 335 t := New(wpData, false) 336 for i, q := range []Point{ 337 {4, 6}, 338 {7, 5}, 339 {8, 7}, 340 {6, -5}, 341 } { 342 for d := 1.; d < 100; d += 0.1 { 343 dk := NewDistKeeper(d) 344 t.NearestSet(dk, q) 345 346 hits := make(map[string]float64) 347 for _, p := range wpData { 348 hits[fmt.Sprint(p)] = p.Distance(q) 349 } 350 351 for _, p := range dk.Heap { 352 var finished bool 353 if p.Comparable != nil { 354 delete(hits, fmt.Sprint(p.Comparable)) 355 c.Check(finished, check.Equals, false) 356 dist := p.Comparable.Distance(q) 357 c.Check(dist <= d, check.Equals, true, check.Commentf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d)) 358 } else { 359 finished = true 360 } 361 } 362 363 for p, dist := range hits { 364 c.Check(dist > d, check.Equals, true, check.Commentf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d)) 365 } 366 } 367 } 368 } 369 370 func (s *S) TestDo(c *check.C) { 371 var result Points 372 t := New(wpData, false) 373 f := func(c Comparable, _ *Bounding, _ int) (done bool) { 374 result = append(result, c.(Point)) 375 return 376 } 377 killed := t.Do(f) 378 c.Check(result, check.DeepEquals, wpData) 379 c.Check(killed, check.Equals, false) 380 } 381 382 func (s *S) TestDoBounded(c *check.C) { 383 for _, test := range []struct { 384 bounds *Bounding 385 result Points 386 }{ 387 { 388 nil, 389 wpData, 390 }, 391 { 392 &Bounding{Point{0, 0}, Point{10, 10}}, 393 wpData, 394 }, 395 { 396 &Bounding{Point{3, 4}, Point{10, 10}}, 397 Points{Point{5, 4}, Point{4, 7}, Point{9, 6}}, 398 }, 399 { 400 &Bounding{Point{3, 3}, Point{10, 10}}, 401 Points{Point{5, 4}, Point{4, 7}, Point{9, 6}}, 402 }, 403 { 404 &Bounding{Point{0, 0}, Point{6, 5}}, 405 Points{Point{2, 3}, Point{5, 4}}, 406 }, 407 { 408 &Bounding{Point{5, 2}, Point{7, 4}}, 409 Points{Point{5, 4}, Point{7, 2}}, 410 }, 411 { 412 &Bounding{Point{2, 2}, Point{7, 4}}, 413 Points{Point{2, 3}, Point{5, 4}, Point{7, 2}}, 414 }, 415 { 416 &Bounding{Point{2, 3}, Point{9, 6}}, 417 Points{Point{2, 3}, Point{5, 4}, Point{9, 6}}, 418 }, 419 { 420 &Bounding{Point{7, 2}, Point{7, 2}}, 421 Points{Point{7, 2}}, 422 }, 423 } { 424 var result Points 425 t := New(wpData, false) 426 f := func(c Comparable, _ *Bounding, _ int) (done bool) { 427 result = append(result, c.(Point)) 428 return 429 } 430 killed := t.DoBounded(f, test.bounds) 431 c.Check(result, check.DeepEquals, test.result) 432 c.Check(killed, check.Equals, false) 433 } 434 } 435 436 func BenchmarkNew(b *testing.B) { 437 b.StopTimer() 438 p := make(Points, 1e5) 439 for i := range p { 440 p[i] = Point{rand.Float64(), rand.Float64(), rand.Float64()} 441 } 442 b.StartTimer() 443 for i := 0; i < b.N; i++ { 444 _ = New(p, false) 445 } 446 } 447 448 func BenchmarkNewBounds(b *testing.B) { 449 b.StopTimer() 450 p := make(Points, 1e5) 451 for i := range p { 452 p[i] = Point{rand.Float64(), rand.Float64(), rand.Float64()} 453 } 454 b.StartTimer() 455 for i := 0; i < b.N; i++ { 456 _ = New(p, true) 457 } 458 } 459 460 func BenchmarkInsert(b *testing.B) { 461 rand.Seed(1) 462 t := &Tree{} 463 for i := 0; i < b.N; i++ { 464 t.Insert(Point{rand.Float64(), rand.Float64(), rand.Float64()}, false) 465 } 466 } 467 468 func BenchmarkInsertBounds(b *testing.B) { 469 rand.Seed(1) 470 t := &Tree{} 471 for i := 0; i < b.N; i++ { 472 t.Insert(Point{rand.Float64(), rand.Float64(), rand.Float64()}, true) 473 } 474 } 475 476 func (s *S) TestBenches(c *check.C) { 477 c.Check(bTree.Root.isKDTree(), check.Equals, true) 478 for i := 0; i < 1e3; i++ { 479 q := Point{rand.Float64(), rand.Float64(), rand.Float64()} 480 p, d := bTree.Nearest(q) 481 ep, ed := nearest(q, bData) 482 c.Check(p, check.DeepEquals, ep, check.Commentf("Test %d: query %.3f expects %.3f", i, q, ep)) 483 c.Check(d, check.Equals, ed) 484 } 485 if c.Failed() && *genDot && bTree.Len() <= *dotLimit { 486 err := dotFile(bTree, "TestBenches", "") 487 if err != nil { 488 c.Errorf("Dot file write failed: %v", err) 489 } 490 } 491 } 492 493 func BenchmarkNearest(b *testing.B) { 494 var ( 495 r Comparable 496 d float64 497 ) 498 for i := 0; i < b.N; i++ { 499 r, d = bTree.Nearest(Point{rand.Float64(), rand.Float64(), rand.Float64()}) 500 } 501 _, _ = r, d 502 } 503 504 func BenchmarkNearBrute(b *testing.B) { 505 var ( 506 r Comparable 507 d float64 508 ) 509 for i := 0; i < b.N; i++ { 510 r, d = nearest(Point{rand.Float64(), rand.Float64(), rand.Float64()}, bData) 511 } 512 _, _ = r, d 513 } 514 515 func BenchmarkNearestSetN10(b *testing.B) { 516 var nk = NewNKeeper(10) 517 for i := 0; i < b.N; i++ { 518 bTree.NearestSet(nk, Point{rand.Float64(), rand.Float64(), rand.Float64()}) 519 nk.Heap = nk.Heap[:1] 520 nk.Heap[0] = ComparableDist{Comparable: nil, Dist: inf} 521 } 522 } 523 524 func BenchmarkNearBruteN10(b *testing.B) { 525 var r []ComparableDist 526 for i := 0; i < b.N; i++ { 527 r = nearestN(10, Point{rand.Float64(), rand.Float64(), rand.Float64()}, bData) 528 } 529 _ = r 530 } 531 532 func dot(t *Tree, label string) string { 533 if t == nil { 534 return "" 535 } 536 var ( 537 s []string 538 follow func(*Node) 539 ) 540 follow = func(n *Node) { 541 id := uintptr(unsafe.Pointer(n)) 542 c := fmt.Sprintf("%d[label = \"<Left> |<Elem> %s/%.3f\\n%.3f|<Right>\"];", 543 id, n, n.Point.(Point)[n.Plane], *n.Bounding) 544 if n.Left != nil { 545 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Left -> \"%d\":Elem;", 546 id, uintptr(unsafe.Pointer(n.Left))) 547 follow(n.Left) 548 } 549 if n.Right != nil { 550 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Right -> \"%d\":Elem;", 551 id, uintptr(unsafe.Pointer(n.Right))) 552 follow(n.Right) 553 } 554 s = append(s, c) 555 } 556 if t.Root != nil { 557 follow(t.Root) 558 } 559 return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n", 560 label, 561 strings.Join(s, "\n\t"), 562 ) 563 } 564 565 func dotFile(t *Tree, label, dotString string) (err error) { 566 if t == nil && dotString == "" { 567 return 568 } 569 f, err := os.Create(label + ".dot") 570 if err != nil { 571 return 572 } 573 defer f.Close() 574 if dotString == "" { 575 fmt.Fprintf(f, dot(t, label)) 576 } else { 577 fmt.Fprintf(f, dotString) 578 } 579 return 580 }