gonum.org/v1/gonum@v0.14.0/spatial/vptree/vptree_test.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package vptree 6 7 import ( 8 "flag" 9 "fmt" 10 "math" 11 "os" 12 "reflect" 13 "sort" 14 "strings" 15 "testing" 16 "unsafe" 17 18 "golang.org/x/exp/rand" 19 ) 20 21 var ( 22 genDot = flag.Bool("dot", false, "generate dot code for failing trees") 23 dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format") 24 ) 25 26 var ( 27 // Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572. 28 wpData = []Comparable{ 29 Point{2, 3}, 30 Point{5, 4}, 31 Point{9, 6}, 32 Point{4, 7}, 33 Point{8, 1}, 34 Point{7, 2}, 35 } 36 ) 37 38 var newTests = []struct { 39 data []Comparable 40 effort int 41 seed uint64 42 }{ 43 {data: wpData, effort: 0, seed: 1}, 44 {data: wpData, effort: 1, seed: 1}, 45 {data: wpData, effort: 2, seed: 1}, 46 {data: wpData, effort: 4, seed: 1}, 47 {data: wpData, effort: 8, seed: 1}, 48 {data: []Comparable{Point{2, 3}, Point{5, 4}, Point{9, 6}, Point{5, 4}, Point{8, 1}, Point{7, 2}}, effort: 3, seed: 5555}, 49 } 50 51 func TestNew(t *testing.T) { 52 for i, test := range newTests { 53 var tree *Tree 54 var err error 55 var panicked bool 56 func() { 57 defer func() { 58 if r := recover(); r != nil { 59 panicked = true 60 } 61 }() 62 tree, err = New(test.data, test.effort, rand.NewSource(test.seed)) 63 }() 64 if panicked { 65 t.Errorf("unexpected panic for test %d", i) 66 continue 67 } 68 if err != nil { 69 t.Errorf("unexpected error for test %d: %v", i, err) 70 continue 71 } 72 73 if !tree.Root.isVPTree() { 74 t.Errorf("tree %d is not vp-tree", i) 75 } 76 77 if t.Failed() && *genDot && tree.Len() <= *dotLimit { 78 err := dotFile(tree, fmt.Sprintf("TestNew%d", i), "") 79 if err != nil { 80 t.Fatalf("failed to write DOT file: %v", err) 81 } 82 } 83 } 84 } 85 86 type compFn func(v, radius float64) bool 87 88 func closer(v, radius float64) bool { return v <= radius } 89 func further(v, radius float64) bool { return v >= radius } 90 91 func (n *Node) isVPTree() bool { 92 if n == nil { 93 return true 94 } 95 if !n.Closer.isPartitioned(n.Point, closer, n.Radius) { 96 return false 97 } 98 if !n.Further.isPartitioned(n.Point, further, n.Radius) { 99 return false 100 } 101 return n.Closer.isVPTree() && n.Further.isVPTree() 102 } 103 104 func (n *Node) isPartitioned(vp Comparable, fn compFn, radius float64) bool { 105 if n == nil { 106 return true 107 } 108 if n.Closer != nil && !fn(vp.Distance(n.Closer.Point), radius) { 109 return false 110 } 111 if n.Further != nil && !fn(vp.Distance(n.Further.Point), radius) { 112 return false 113 } 114 return n.Closer.isPartitioned(vp, fn, radius) && n.Further.isPartitioned(vp, fn, radius) 115 } 116 117 func nearest(q Comparable, p []Comparable) (Comparable, float64) { 118 min := q.Distance(p[0]) 119 var r int 120 for i := 1; i < len(p); i++ { 121 d := q.Distance(p[i]) 122 if d < min { 123 min = d 124 r = i 125 } 126 } 127 return p[r], min 128 } 129 130 func TestNearestRandom(t *testing.T) { 131 rnd := rand.New(rand.NewSource(1)) 132 133 const ( 134 min = 0.0 135 max = 1000.0 136 137 dims = 4 138 setSize = 10000 139 ) 140 141 var randData []Comparable 142 for i := 0; i < setSize; i++ { 143 p := make(Point, dims) 144 for j := 0; j < dims; j++ { 145 p[j] = (max-min)*rnd.Float64() + min 146 } 147 randData = append(randData, p) 148 } 149 tree, err := New(randData, 10, rand.NewSource(1)) 150 if err != nil { 151 t.Fatalf("unexpected error: %v", err) 152 } 153 154 for i := 0; i < setSize; i++ { 155 q := make(Point, dims) 156 for j := 0; j < dims; j++ { 157 q[j] = (max-min)*rnd.Float64() + min 158 } 159 160 got, _ := tree.Nearest(q) 161 want, _ := nearest(q, randData) 162 if !reflect.DeepEqual(got, want) { 163 t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want) 164 } 165 } 166 } 167 168 func TestNearest(t *testing.T) { 169 tree, err := New(wpData, 3, rand.NewSource(1)) 170 if err != nil { 171 t.Fatalf("unexpected error: %v", err) 172 } 173 for _, q := range append([]Comparable{ 174 Point{4, 6}, 175 // Point{7, 5}, // Omitted because it is ambiguously finds [9 6] or [5 4]. 176 Point{8, 7}, 177 Point{6, -5}, 178 Point{1e5, 1e5}, 179 Point{1e5, -1e5}, 180 Point{-1e5, 1e5}, 181 Point{-1e5, -1e5}, 182 Point{1e5, 0}, 183 Point{0, -1e5}, 184 Point{0, 1e5}, 185 Point{-1e5, 0}, 186 }, wpData...) { 187 gotP, gotD := tree.Nearest(q) 188 wantP, wantD := nearest(q, wpData) 189 if !reflect.DeepEqual(gotP, wantP) { 190 t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) 191 } 192 if gotD != wantD { 193 t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD) 194 } 195 } 196 } 197 198 func nearestN(n int, q Comparable, p []Comparable) []ComparableDist { 199 nk := NewNKeeper(n) 200 for i := 0; i < len(p); i++ { 201 nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])}) 202 } 203 if len(nk.Heap) == 1 { 204 return nk.Heap 205 } 206 sort.Sort(nk) 207 for i, j := 0, len(nk.Heap)-1; i < j; i, j = i+1, j-1 { 208 nk.Heap[i], nk.Heap[j] = nk.Heap[j], nk.Heap[i] 209 } 210 return nk.Heap 211 } 212 213 func TestNearestSetN(t *testing.T) { 214 data := append([]Comparable{ 215 Point{4, 6}, 216 Point{7, 5}, // OK here because we collect N. 217 Point{8, 7}, 218 Point{6, -5}, 219 Point{1e5, 1e5}, 220 Point{1e5, -1e5}, 221 Point{-1e5, 1e5}, 222 Point{-1e5, -1e5}, 223 Point{1e5, 0}, 224 Point{0, -1e5}, 225 Point{0, 1e5}, 226 Point{-1e5, 0}}, 227 wpData[:len(wpData)-1]...) 228 229 tree, err := New(wpData, 3, rand.NewSource(1)) 230 if err != nil { 231 t.Fatalf("unexpected error: %v", err) 232 } 233 for k := 1; k <= len(wpData); k++ { 234 for _, q := range data { 235 wantP := nearestN(k, q, wpData) 236 237 nk := NewNKeeper(k) 238 tree.NearestSet(nk, q) 239 240 var max float64 241 wantD := make(map[float64]map[string]struct{}) 242 for _, p := range wantP { 243 if p.Dist > max { 244 max = p.Dist 245 } 246 d, ok := wantD[p.Dist] 247 if !ok { 248 d = make(map[string]struct{}) 249 } 250 d[fmt.Sprint(p.Comparable)] = struct{}{} 251 wantD[p.Dist] = d 252 } 253 gotD := make(map[float64]map[string]struct{}) 254 for _, p := range nk.Heap { 255 if p.Dist > max { 256 t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max) 257 } 258 d, ok := gotD[p.Dist] 259 if !ok { 260 d = make(map[string]struct{}) 261 } 262 d[fmt.Sprint(p.Comparable)] = struct{}{} 263 gotD[p.Dist] = d 264 } 265 266 // If the available number of slots does not fit all the coequal furthest points 267 // we will fail the check. So remove, but check them minimally here. 268 if !reflect.DeepEqual(wantD[max], gotD[max]) { 269 // The best we can do at this stage is confirm that there are an equal number of matches at this distance. 270 if len(gotD[max]) != len(wantD[max]) { 271 t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max])) 272 } 273 delete(wantD, max) 274 delete(gotD, max) 275 } 276 277 if !reflect.DeepEqual(gotD, wantD) { 278 t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD) 279 } 280 } 281 } 282 } 283 284 var nearestSetDistTests = []Point{ 285 {4, 6}, 286 {7, 5}, 287 {8, 7}, 288 {6, -5}, 289 } 290 291 func TestNearestSetDist(t *testing.T) { 292 tree, err := New(wpData, 3, rand.NewSource(1)) 293 if err != nil { 294 t.Fatalf("unexpected error: %v", err) 295 } 296 for i, q := range nearestSetDistTests { 297 for d := 1.0; d < 100; d += 0.1 { 298 dk := NewDistKeeper(d) 299 tree.NearestSet(dk, q) 300 301 hits := make(map[string]float64) 302 for _, p := range wpData { 303 hits[fmt.Sprint(p)] = p.Distance(q) 304 } 305 306 for _, p := range dk.Heap { 307 var done bool 308 if p.Comparable == nil { 309 done = true 310 continue 311 } 312 delete(hits, fmt.Sprint(p.Comparable)) 313 if done { 314 t.Error("expectedly finished heap iteration") 315 break 316 } 317 dist := p.Comparable.Distance(q) 318 if dist > d { 319 t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d) 320 } 321 } 322 323 for p, dist := range hits { 324 if dist <= d { 325 t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d) 326 } 327 } 328 } 329 } 330 } 331 332 func TestDo(t *testing.T) { 333 tree, err := New(wpData, 3, rand.NewSource(1)) 334 if err != nil { 335 t.Fatalf("unexpected error: %v", err) 336 } 337 var got []Point 338 fn := func(c Comparable, _ int) (done bool) { 339 got = append(got, c.(Point)) 340 return 341 } 342 killed := tree.Do(fn) 343 344 want := make([]Point, len(wpData)) 345 for i, p := range wpData { 346 want[i] = p.(Point) 347 } 348 sort.Sort(lexical(got)) 349 sort.Sort(lexical(want)) 350 351 if !reflect.DeepEqual(got, want) { 352 t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, want) 353 } 354 if killed { 355 t.Error("tree iteration unexpectedly killed") 356 } 357 } 358 359 type lexical []Point 360 361 func (c lexical) Len() int { return len(c) } 362 func (c lexical) Less(i, j int) bool { 363 a, b := c[i], c[j] 364 l := len(a) 365 if len(b) < l { 366 l = len(b) 367 } 368 for k, v := range a[:l] { 369 if v < b[k] { 370 return true 371 } 372 if v > b[k] { 373 return false 374 } 375 } 376 return len(a) < len(b) 377 } 378 func (c lexical) Swap(i, j int) { c[i], c[j] = c[j], c[i] } 379 380 func BenchmarkNew(b *testing.B) { 381 for _, effort := range []int{0, 10, 100} { 382 b.Run(fmt.Sprintf("New:%d", effort), func(b *testing.B) { 383 rnd := rand.New(rand.NewSource(1)) 384 p := make([]Comparable, 1e5) 385 for i := range p { 386 p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 387 } 388 b.ResetTimer() 389 for i := 0; i < b.N; i++ { 390 _, err := New(p, effort, rand.NewSource(1)) 391 if err != nil { 392 b.Fatalf("unexpected error: %v", err) 393 } 394 } 395 }) 396 } 397 } 398 399 func Benchmark(b *testing.B) { 400 var r Comparable 401 var d float64 402 queryBenchmarks := []struct { 403 name string 404 fn func(data []Comparable, tree *Tree, rnd *rand.Rand) func(*testing.B) 405 }{ 406 { 407 name: "NearestBrute", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) { 408 return func(b *testing.B) { 409 for i := 0; i < b.N; i++ { 410 r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) 411 } 412 if r == nil { 413 b.Error("unexpected nil result") 414 } 415 if math.IsNaN(d) { 416 b.Error("unexpected NaN result") 417 } 418 } 419 }, 420 }, 421 { 422 name: "NearestBruteN10", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) { 423 return func(b *testing.B) { 424 var r []ComparableDist 425 for i := 0; i < b.N; i++ { 426 r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) 427 } 428 if len(r) != 10 { 429 b.Error("unexpected result length", len(r)) 430 } 431 } 432 }, 433 }, 434 { 435 name: "Nearest", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) { 436 return func(b *testing.B) { 437 for i := 0; i < b.N; i++ { 438 r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) 439 } 440 if r == nil { 441 b.Error("unexpected nil result") 442 } 443 if math.IsNaN(d) { 444 b.Error("unexpected NaN result") 445 } 446 } 447 }, 448 }, 449 { 450 name: "NearestSetN10", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) { 451 return func(b *testing.B) { 452 nk := NewNKeeper(10) 453 for i := 0; i < b.N; i++ { 454 tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) 455 if nk.Len() != 10 { 456 b.Error("unexpected result length") 457 } 458 nk.Heap = nk.Heap[:1] 459 nk.Heap[0] = ComparableDist{Dist: inf} 460 } 461 } 462 }, 463 }, 464 } 465 466 for _, effort := range []int{0, 3, 10, 30, 100, 300} { 467 rnd := rand.New(rand.NewSource(1)) 468 data := make([]Comparable, 1e5) 469 for i := range data { 470 data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 471 } 472 tree, err := New(data, effort, rand.NewSource(1)) 473 if err != nil { 474 b.Errorf("unexpected error for effort=%d: %v", effort, err) 475 continue 476 } 477 478 if !tree.Root.isVPTree() { 479 b.Fatal("tree is not vantage point tree") 480 } 481 482 for i := 0; i < 1e3; i++ { 483 q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 484 gotP, gotD := tree.Nearest(q) 485 wantP, wantD := nearest(q, data) 486 if !reflect.DeepEqual(gotP, wantP) { 487 b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) 488 } 489 if gotD != wantD { 490 b.Errorf("unexpected distance for query %.3f: got:%v want:%v", q, gotD, wantD) 491 } 492 } 493 494 if b.Failed() && *genDot && tree.Len() <= *dotLimit { 495 err := dotFile(tree, "TestBenches", "") 496 if err != nil { 497 b.Fatalf("failed to write DOT file: %v", err) 498 } 499 return 500 } 501 502 for _, bench := range queryBenchmarks { 503 if strings.Contains(bench.name, "Brute") && effort != 0 { 504 continue 505 } 506 b.Run(fmt.Sprintf("%s:%d", bench.name, effort), bench.fn(data, tree, rnd)) 507 } 508 } 509 } 510 511 func dot(t *Tree, label string) string { 512 if t == nil { 513 return "" 514 } 515 var ( 516 s []string 517 follow func(*Node) 518 ) 519 follow = func(n *Node) { 520 id := uintptr(unsafe.Pointer(n)) 521 c := fmt.Sprintf("%d[label = \"<Closer> |<Elem> %.3f/%.3f|<Further>\"];", 522 id, n.Point, n.Radius) 523 if n.Closer != nil { 524 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Closer -> \"%d\":Elem [label=%.3f];", 525 id, uintptr(unsafe.Pointer(n.Closer)), n.Point.Distance(n.Closer.Point)) 526 follow(n.Closer) 527 } 528 if n.Further != nil { 529 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Further -> \"%d\":Elem [label=%.3f];", 530 id, uintptr(unsafe.Pointer(n.Further)), n.Point.Distance(n.Further.Point)) 531 follow(n.Further) 532 } 533 s = append(s, c) 534 } 535 if t.Root != nil { 536 follow(t.Root) 537 } 538 return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n", 539 label, 540 strings.Join(s, "\n\t"), 541 ) 542 } 543 544 func dotFile(t *Tree, label, dotString string) (err error) { 545 if t == nil && dotString == "" { 546 return 547 } 548 f, err := os.Create(label + ".dot") 549 if err != nil { 550 return 551 } 552 defer f.Close() 553 if dotString == "" { 554 fmt.Fprint(f, dot(t, label)) 555 } else { 556 fmt.Fprint(f, dotString) 557 } 558 return 559 }