gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/spatial/vptree/vptree_test.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package vptree 6 7 import ( 8 "flag" 9 "fmt" 10 "math" 11 "os" 12 "reflect" 13 "slices" 14 "sort" 15 "strings" 16 "testing" 17 "unsafe" 18 19 "golang.org/x/exp/rand" 20 21 "gonum.org/v1/gonum/internal/order" 22 ) 23 24 var ( 25 genDot = flag.Bool("dot", false, "generate dot code for failing trees") 26 dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format") 27 ) 28 29 var ( 30 // Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572. 31 wpData = []Comparable{ 32 Point{2, 3}, 33 Point{5, 4}, 34 Point{9, 6}, 35 Point{4, 7}, 36 Point{8, 1}, 37 Point{7, 2}, 38 } 39 ) 40 41 var newTests = []struct { 42 data []Comparable 43 effort int 44 seed uint64 45 }{ 46 {data: wpData, effort: 0, seed: 1}, 47 {data: wpData, effort: 1, seed: 1}, 48 {data: wpData, effort: 2, seed: 1}, 49 {data: wpData, effort: 4, seed: 1}, 50 {data: wpData, effort: 8, seed: 1}, 51 {data: []Comparable{Point{2, 3}, Point{5, 4}, Point{9, 6}, Point{5, 4}, Point{8, 1}, Point{7, 2}}, effort: 3, seed: 5555}, 52 } 53 54 func TestNew(t *testing.T) { 55 for i, test := range newTests { 56 var tree *Tree 57 var err error 58 var panicked bool 59 func() { 60 defer func() { 61 if r := recover(); r != nil { 62 panicked = true 63 } 64 }() 65 tree, err = New(test.data, test.effort, rand.NewSource(test.seed)) 66 }() 67 if panicked { 68 t.Errorf("unexpected panic for test %d", i) 69 continue 70 } 71 if err != nil { 72 t.Errorf("unexpected error for test %d: %v", i, err) 73 continue 74 } 75 76 if !tree.Root.isVPTree() { 77 t.Errorf("tree %d is not vp-tree", i) 78 } 79 80 if t.Failed() && *genDot && tree.Len() <= *dotLimit { 81 err := dotFile(tree, fmt.Sprintf("TestNew%d", i), "") 82 if err != nil { 83 t.Fatalf("failed to write DOT file: %v", err) 84 } 85 } 86 } 87 } 88 89 type compFn func(v, radius float64) bool 90 91 func closer(v, radius float64) bool { return v <= radius } 92 func further(v, radius float64) bool { return v >= radius } 93 94 func (n *Node) isVPTree() bool { 95 if n == nil { 96 return true 97 } 98 if !n.Closer.isPartitioned(n.Point, closer, n.Radius) { 99 return false 100 } 101 if !n.Further.isPartitioned(n.Point, further, n.Radius) { 102 return false 103 } 104 return n.Closer.isVPTree() && n.Further.isVPTree() 105 } 106 107 func (n *Node) isPartitioned(vp Comparable, fn compFn, radius float64) bool { 108 if n == nil { 109 return true 110 } 111 if n.Closer != nil && !fn(vp.Distance(n.Closer.Point), radius) { 112 return false 113 } 114 if n.Further != nil && !fn(vp.Distance(n.Further.Point), radius) { 115 return false 116 } 117 return n.Closer.isPartitioned(vp, fn, radius) && n.Further.isPartitioned(vp, fn, radius) 118 } 119 120 func nearest(q Comparable, p []Comparable) (Comparable, float64) { 121 min := q.Distance(p[0]) 122 var r int 123 for i := 1; i < len(p); i++ { 124 d := q.Distance(p[i]) 125 if d < min { 126 min = d 127 r = i 128 } 129 } 130 return p[r], min 131 } 132 133 func TestNearestRandom(t *testing.T) { 134 rnd := rand.New(rand.NewSource(1)) 135 136 const ( 137 min = 0.0 138 max = 1000.0 139 140 dims = 4 141 setSize = 10000 142 ) 143 144 var randData []Comparable 145 for i := 0; i < setSize; i++ { 146 p := make(Point, dims) 147 for j := 0; j < dims; j++ { 148 p[j] = (max-min)*rnd.Float64() + min 149 } 150 randData = append(randData, p) 151 } 152 tree, err := New(randData, 10, rand.NewSource(1)) 153 if err != nil { 154 t.Fatalf("unexpected error: %v", err) 155 } 156 157 for i := 0; i < setSize; i++ { 158 q := make(Point, dims) 159 for j := 0; j < dims; j++ { 160 q[j] = (max-min)*rnd.Float64() + min 161 } 162 163 got, _ := tree.Nearest(q) 164 want, _ := nearest(q, randData) 165 if !reflect.DeepEqual(got, want) { 166 t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want) 167 } 168 } 169 } 170 171 func TestNearest(t *testing.T) { 172 tree, err := New(wpData, 3, rand.NewSource(1)) 173 if err != nil { 174 t.Fatalf("unexpected error: %v", err) 175 } 176 for _, q := range append([]Comparable{ 177 Point{4, 6}, 178 // Point{7, 5}, // Omitted because it is ambiguously finds [9 6] or [5 4]. 179 Point{8, 7}, 180 Point{6, -5}, 181 Point{1e5, 1e5}, 182 Point{1e5, -1e5}, 183 Point{-1e5, 1e5}, 184 Point{-1e5, -1e5}, 185 Point{1e5, 0}, 186 Point{0, -1e5}, 187 Point{0, 1e5}, 188 Point{-1e5, 0}, 189 }, wpData...) { 190 gotP, gotD := tree.Nearest(q) 191 wantP, wantD := nearest(q, wpData) 192 if !reflect.DeepEqual(gotP, wantP) { 193 t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) 194 } 195 if gotD != wantD { 196 t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD) 197 } 198 } 199 } 200 201 func nearestN(n int, q Comparable, p []Comparable) []ComparableDist { 202 nk := NewNKeeper(n) 203 for i := 0; i < len(p); i++ { 204 nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])}) 205 } 206 if len(nk.Heap) == 1 { 207 return nk.Heap 208 } 209 sort.Sort(nk) 210 slices.Reverse(nk.Heap) 211 return nk.Heap 212 } 213 214 func TestNearestSetN(t *testing.T) { 215 data := append([]Comparable{ 216 Point{4, 6}, 217 Point{7, 5}, // OK here because we collect N. 218 Point{8, 7}, 219 Point{6, -5}, 220 Point{1e5, 1e5}, 221 Point{1e5, -1e5}, 222 Point{-1e5, 1e5}, 223 Point{-1e5, -1e5}, 224 Point{1e5, 0}, 225 Point{0, -1e5}, 226 Point{0, 1e5}, 227 Point{-1e5, 0}}, 228 wpData[:len(wpData)-1]...) 229 230 tree, err := New(wpData, 3, rand.NewSource(1)) 231 if err != nil { 232 t.Fatalf("unexpected error: %v", err) 233 } 234 for k := 1; k <= len(wpData); k++ { 235 for _, q := range data { 236 wantP := nearestN(k, q, wpData) 237 238 nk := NewNKeeper(k) 239 tree.NearestSet(nk, q) 240 241 var max float64 242 wantD := make(map[float64]map[string]struct{}) 243 for _, p := range wantP { 244 if p.Dist > max { 245 max = p.Dist 246 } 247 d, ok := wantD[p.Dist] 248 if !ok { 249 d = make(map[string]struct{}) 250 } 251 d[fmt.Sprint(p.Comparable)] = struct{}{} 252 wantD[p.Dist] = d 253 } 254 gotD := make(map[float64]map[string]struct{}) 255 for _, p := range nk.Heap { 256 if p.Dist > max { 257 t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max) 258 } 259 d, ok := gotD[p.Dist] 260 if !ok { 261 d = make(map[string]struct{}) 262 } 263 d[fmt.Sprint(p.Comparable)] = struct{}{} 264 gotD[p.Dist] = d 265 } 266 267 // If the available number of slots does not fit all the coequal furthest points 268 // we will fail the check. So remove, but check them minimally here. 269 if !reflect.DeepEqual(wantD[max], gotD[max]) { 270 // The best we can do at this stage is confirm that there are an equal number of matches at this distance. 271 if len(gotD[max]) != len(wantD[max]) { 272 t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max])) 273 } 274 delete(wantD, max) 275 delete(gotD, max) 276 } 277 278 if !reflect.DeepEqual(gotD, wantD) { 279 t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD) 280 } 281 } 282 } 283 } 284 285 var nearestSetDistTests = []Point{ 286 {4, 6}, 287 {7, 5}, 288 {8, 7}, 289 {6, -5}, 290 } 291 292 func TestNearestSetDist(t *testing.T) { 293 tree, err := New(wpData, 3, rand.NewSource(1)) 294 if err != nil { 295 t.Fatalf("unexpected error: %v", err) 296 } 297 for i, q := range nearestSetDistTests { 298 for d := 1.0; d < 100; d += 0.1 { 299 dk := NewDistKeeper(d) 300 tree.NearestSet(dk, q) 301 302 hits := make(map[string]float64) 303 for _, p := range wpData { 304 hits[fmt.Sprint(p)] = p.Distance(q) 305 } 306 307 for _, p := range dk.Heap { 308 var done bool 309 if p.Comparable == nil { 310 done = true 311 continue 312 } 313 delete(hits, fmt.Sprint(p.Comparable)) 314 if done { 315 t.Error("expectedly finished heap iteration") 316 break 317 } 318 dist := p.Comparable.Distance(q) 319 if dist > d { 320 t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d) 321 } 322 } 323 324 for p, dist := range hits { 325 if dist <= d { 326 t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d) 327 } 328 } 329 } 330 } 331 } 332 333 func TestDo(t *testing.T) { 334 tree, err := New(wpData, 3, rand.NewSource(1)) 335 if err != nil { 336 t.Fatalf("unexpected error: %v", err) 337 } 338 var got []Point 339 fn := func(c Comparable, _ int) (done bool) { 340 got = append(got, c.(Point)) 341 return 342 } 343 killed := tree.Do(fn) 344 345 want := make([]Point, len(wpData)) 346 for i, p := range wpData { 347 want[i] = p.(Point) 348 } 349 order.BySliceValues(got) 350 order.BySliceValues(want) 351 352 if !reflect.DeepEqual(got, want) { 353 t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, want) 354 } 355 if killed { 356 t.Error("tree iteration unexpectedly killed") 357 } 358 } 359 360 func BenchmarkNew(b *testing.B) { 361 for _, effort := range []int{0, 10, 100} { 362 b.Run(fmt.Sprintf("New:%d", effort), func(b *testing.B) { 363 rnd := rand.New(rand.NewSource(1)) 364 p := make([]Comparable, 1e5) 365 for i := range p { 366 p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 367 } 368 b.ResetTimer() 369 for i := 0; i < b.N; i++ { 370 _, err := New(p, effort, rand.NewSource(1)) 371 if err != nil { 372 b.Fatalf("unexpected error: %v", err) 373 } 374 } 375 }) 376 } 377 } 378 379 func Benchmark(b *testing.B) { 380 var r Comparable 381 var d float64 382 queryBenchmarks := []struct { 383 name string 384 fn func(data []Comparable, tree *Tree, rnd *rand.Rand) func(*testing.B) 385 }{ 386 { 387 name: "NearestBrute", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) { 388 return func(b *testing.B) { 389 for i := 0; i < b.N; i++ { 390 r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) 391 } 392 if r == nil { 393 b.Error("unexpected nil result") 394 } 395 if math.IsNaN(d) { 396 b.Error("unexpected NaN result") 397 } 398 } 399 }, 400 }, 401 { 402 name: "NearestBruteN10", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) { 403 return func(b *testing.B) { 404 var r []ComparableDist 405 for i := 0; i < b.N; i++ { 406 r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data) 407 } 408 if len(r) != 10 { 409 b.Error("unexpected result length", len(r)) 410 } 411 } 412 }, 413 }, 414 { 415 name: "Nearest", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) { 416 return func(b *testing.B) { 417 for i := 0; i < b.N; i++ { 418 r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) 419 } 420 if r == nil { 421 b.Error("unexpected nil result") 422 } 423 if math.IsNaN(d) { 424 b.Error("unexpected NaN result") 425 } 426 } 427 }, 428 }, 429 { 430 name: "NearestSetN10", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) { 431 return func(b *testing.B) { 432 nk := NewNKeeper(10) 433 for i := 0; i < b.N; i++ { 434 tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}) 435 if nk.Len() != 10 { 436 b.Error("unexpected result length") 437 } 438 nk.Heap = nk.Heap[:1] 439 nk.Heap[0] = ComparableDist{Dist: inf} 440 } 441 } 442 }, 443 }, 444 } 445 446 for _, effort := range []int{0, 3, 10, 30, 100, 300} { 447 rnd := rand.New(rand.NewSource(1)) 448 data := make([]Comparable, 1e5) 449 for i := range data { 450 data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 451 } 452 tree, err := New(data, effort, rand.NewSource(1)) 453 if err != nil { 454 b.Errorf("unexpected error for effort=%d: %v", effort, err) 455 continue 456 } 457 458 if !tree.Root.isVPTree() { 459 b.Fatal("tree is not vantage point tree") 460 } 461 462 for i := 0; i < 1e3; i++ { 463 q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()} 464 gotP, gotD := tree.Nearest(q) 465 wantP, wantD := nearest(q, data) 466 if !reflect.DeepEqual(gotP, wantP) { 467 b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP) 468 } 469 if gotD != wantD { 470 b.Errorf("unexpected distance for query %.3f: got:%v want:%v", q, gotD, wantD) 471 } 472 } 473 474 if b.Failed() && *genDot && tree.Len() <= *dotLimit { 475 err := dotFile(tree, "TestBenches", "") 476 if err != nil { 477 b.Fatalf("failed to write DOT file: %v", err) 478 } 479 return 480 } 481 482 for _, bench := range queryBenchmarks { 483 if strings.Contains(bench.name, "Brute") && effort != 0 { 484 continue 485 } 486 b.Run(fmt.Sprintf("%s:%d", bench.name, effort), bench.fn(data, tree, rnd)) 487 } 488 } 489 } 490 491 func dot(t *Tree, label string) string { 492 if t == nil { 493 return "" 494 } 495 var ( 496 s []string 497 follow func(*Node) 498 ) 499 follow = func(n *Node) { 500 id := uintptr(unsafe.Pointer(n)) 501 c := fmt.Sprintf("%d[label = \"<Closer> |<Elem> %.3f/%.3f|<Further>\"];", 502 id, n.Point, n.Radius) 503 if n.Closer != nil { 504 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Closer -> \"%d\":Elem [label=%.3f];", 505 id, uintptr(unsafe.Pointer(n.Closer)), n.Point.Distance(n.Closer.Point)) 506 follow(n.Closer) 507 } 508 if n.Further != nil { 509 c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Further -> \"%d\":Elem [label=%.3f];", 510 id, uintptr(unsafe.Pointer(n.Further)), n.Point.Distance(n.Further.Point)) 511 follow(n.Further) 512 } 513 s = append(s, c) 514 } 515 if t.Root != nil { 516 follow(t.Root) 517 } 518 return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n", 519 label, 520 strings.Join(s, "\n\t"), 521 ) 522 } 523 524 func dotFile(t *Tree, label, dotString string) (err error) { 525 if t == nil && dotString == "" { 526 return 527 } 528 f, err := os.Create(label + ".dot") 529 if err != nil { 530 return 531 } 532 defer f.Close() 533 if dotString == "" { 534 fmt.Fprint(f, dot(t, label)) 535 } else { 536 fmt.Fprint(f, dotString) 537 } 538 return 539 }