gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/spatial/vptree/vptree_test.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package vptree
     6  
     7  import (
     8  	"flag"
     9  	"fmt"
    10  	"math"
    11  	"os"
    12  	"reflect"
    13  	"slices"
    14  	"sort"
    15  	"strings"
    16  	"testing"
    17  	"unsafe"
    18  
    19  	"golang.org/x/exp/rand"
    20  
    21  	"gonum.org/v1/gonum/internal/order"
    22  )
    23  
    24  var (
    25  	genDot   = flag.Bool("dot", false, "generate dot code for failing trees")
    26  	dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format")
    27  )
    28  
    29  var (
    30  	// Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572.
    31  	wpData = []Comparable{
    32  		Point{2, 3},
    33  		Point{5, 4},
    34  		Point{9, 6},
    35  		Point{4, 7},
    36  		Point{8, 1},
    37  		Point{7, 2},
    38  	}
    39  )
    40  
    41  var newTests = []struct {
    42  	data   []Comparable
    43  	effort int
    44  	seed   uint64
    45  }{
    46  	{data: wpData, effort: 0, seed: 1},
    47  	{data: wpData, effort: 1, seed: 1},
    48  	{data: wpData, effort: 2, seed: 1},
    49  	{data: wpData, effort: 4, seed: 1},
    50  	{data: wpData, effort: 8, seed: 1},
    51  	{data: []Comparable{Point{2, 3}, Point{5, 4}, Point{9, 6}, Point{5, 4}, Point{8, 1}, Point{7, 2}}, effort: 3, seed: 5555},
    52  }
    53  
    54  func TestNew(t *testing.T) {
    55  	for i, test := range newTests {
    56  		var tree *Tree
    57  		var err error
    58  		var panicked bool
    59  		func() {
    60  			defer func() {
    61  				if r := recover(); r != nil {
    62  					panicked = true
    63  				}
    64  			}()
    65  			tree, err = New(test.data, test.effort, rand.NewSource(test.seed))
    66  		}()
    67  		if panicked {
    68  			t.Errorf("unexpected panic for test %d", i)
    69  			continue
    70  		}
    71  		if err != nil {
    72  			t.Errorf("unexpected error for test %d: %v", i, err)
    73  			continue
    74  		}
    75  
    76  		if !tree.Root.isVPTree() {
    77  			t.Errorf("tree %d is not vp-tree", i)
    78  		}
    79  
    80  		if t.Failed() && *genDot && tree.Len() <= *dotLimit {
    81  			err := dotFile(tree, fmt.Sprintf("TestNew%d", i), "")
    82  			if err != nil {
    83  				t.Fatalf("failed to write DOT file: %v", err)
    84  			}
    85  		}
    86  	}
    87  }
    88  
    89  type compFn func(v, radius float64) bool
    90  
    91  func closer(v, radius float64) bool  { return v <= radius }
    92  func further(v, radius float64) bool { return v >= radius }
    93  
    94  func (n *Node) isVPTree() bool {
    95  	if n == nil {
    96  		return true
    97  	}
    98  	if !n.Closer.isPartitioned(n.Point, closer, n.Radius) {
    99  		return false
   100  	}
   101  	if !n.Further.isPartitioned(n.Point, further, n.Radius) {
   102  		return false
   103  	}
   104  	return n.Closer.isVPTree() && n.Further.isVPTree()
   105  }
   106  
   107  func (n *Node) isPartitioned(vp Comparable, fn compFn, radius float64) bool {
   108  	if n == nil {
   109  		return true
   110  	}
   111  	if n.Closer != nil && !fn(vp.Distance(n.Closer.Point), radius) {
   112  		return false
   113  	}
   114  	if n.Further != nil && !fn(vp.Distance(n.Further.Point), radius) {
   115  		return false
   116  	}
   117  	return n.Closer.isPartitioned(vp, fn, radius) && n.Further.isPartitioned(vp, fn, radius)
   118  }
   119  
   120  func nearest(q Comparable, p []Comparable) (Comparable, float64) {
   121  	min := q.Distance(p[0])
   122  	var r int
   123  	for i := 1; i < len(p); i++ {
   124  		d := q.Distance(p[i])
   125  		if d < min {
   126  			min = d
   127  			r = i
   128  		}
   129  	}
   130  	return p[r], min
   131  }
   132  
   133  func TestNearestRandom(t *testing.T) {
   134  	rnd := rand.New(rand.NewSource(1))
   135  
   136  	const (
   137  		min = 0.0
   138  		max = 1000.0
   139  
   140  		dims    = 4
   141  		setSize = 10000
   142  	)
   143  
   144  	var randData []Comparable
   145  	for i := 0; i < setSize; i++ {
   146  		p := make(Point, dims)
   147  		for j := 0; j < dims; j++ {
   148  			p[j] = (max-min)*rnd.Float64() + min
   149  		}
   150  		randData = append(randData, p)
   151  	}
   152  	tree, err := New(randData, 10, rand.NewSource(1))
   153  	if err != nil {
   154  		t.Fatalf("unexpected error: %v", err)
   155  	}
   156  
   157  	for i := 0; i < setSize; i++ {
   158  		q := make(Point, dims)
   159  		for j := 0; j < dims; j++ {
   160  			q[j] = (max-min)*rnd.Float64() + min
   161  		}
   162  
   163  		got, _ := tree.Nearest(q)
   164  		want, _ := nearest(q, randData)
   165  		if !reflect.DeepEqual(got, want) {
   166  			t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want)
   167  		}
   168  	}
   169  }
   170  
   171  func TestNearest(t *testing.T) {
   172  	tree, err := New(wpData, 3, rand.NewSource(1))
   173  	if err != nil {
   174  		t.Fatalf("unexpected error: %v", err)
   175  	}
   176  	for _, q := range append([]Comparable{
   177  		Point{4, 6},
   178  		// Point{7, 5}, // Omitted because it is ambiguously finds [9 6] or [5 4].
   179  		Point{8, 7},
   180  		Point{6, -5},
   181  		Point{1e5, 1e5},
   182  		Point{1e5, -1e5},
   183  		Point{-1e5, 1e5},
   184  		Point{-1e5, -1e5},
   185  		Point{1e5, 0},
   186  		Point{0, -1e5},
   187  		Point{0, 1e5},
   188  		Point{-1e5, 0},
   189  	}, wpData...) {
   190  		gotP, gotD := tree.Nearest(q)
   191  		wantP, wantD := nearest(q, wpData)
   192  		if !reflect.DeepEqual(gotP, wantP) {
   193  			t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP)
   194  		}
   195  		if gotD != wantD {
   196  			t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD)
   197  		}
   198  	}
   199  }
   200  
   201  func nearestN(n int, q Comparable, p []Comparable) []ComparableDist {
   202  	nk := NewNKeeper(n)
   203  	for i := 0; i < len(p); i++ {
   204  		nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])})
   205  	}
   206  	if len(nk.Heap) == 1 {
   207  		return nk.Heap
   208  	}
   209  	sort.Sort(nk)
   210  	slices.Reverse(nk.Heap)
   211  	return nk.Heap
   212  }
   213  
   214  func TestNearestSetN(t *testing.T) {
   215  	data := append([]Comparable{
   216  		Point{4, 6},
   217  		Point{7, 5}, // OK here because we collect N.
   218  		Point{8, 7},
   219  		Point{6, -5},
   220  		Point{1e5, 1e5},
   221  		Point{1e5, -1e5},
   222  		Point{-1e5, 1e5},
   223  		Point{-1e5, -1e5},
   224  		Point{1e5, 0},
   225  		Point{0, -1e5},
   226  		Point{0, 1e5},
   227  		Point{-1e5, 0}},
   228  		wpData[:len(wpData)-1]...)
   229  
   230  	tree, err := New(wpData, 3, rand.NewSource(1))
   231  	if err != nil {
   232  		t.Fatalf("unexpected error: %v", err)
   233  	}
   234  	for k := 1; k <= len(wpData); k++ {
   235  		for _, q := range data {
   236  			wantP := nearestN(k, q, wpData)
   237  
   238  			nk := NewNKeeper(k)
   239  			tree.NearestSet(nk, q)
   240  
   241  			var max float64
   242  			wantD := make(map[float64]map[string]struct{})
   243  			for _, p := range wantP {
   244  				if p.Dist > max {
   245  					max = p.Dist
   246  				}
   247  				d, ok := wantD[p.Dist]
   248  				if !ok {
   249  					d = make(map[string]struct{})
   250  				}
   251  				d[fmt.Sprint(p.Comparable)] = struct{}{}
   252  				wantD[p.Dist] = d
   253  			}
   254  			gotD := make(map[float64]map[string]struct{})
   255  			for _, p := range nk.Heap {
   256  				if p.Dist > max {
   257  					t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max)
   258  				}
   259  				d, ok := gotD[p.Dist]
   260  				if !ok {
   261  					d = make(map[string]struct{})
   262  				}
   263  				d[fmt.Sprint(p.Comparable)] = struct{}{}
   264  				gotD[p.Dist] = d
   265  			}
   266  
   267  			// If the available number of slots does not fit all the coequal furthest points
   268  			// we will fail the check. So remove, but check them minimally here.
   269  			if !reflect.DeepEqual(wantD[max], gotD[max]) {
   270  				// The best we can do at this stage is confirm that there are an equal number of matches at this distance.
   271  				if len(gotD[max]) != len(wantD[max]) {
   272  					t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max]))
   273  				}
   274  				delete(wantD, max)
   275  				delete(gotD, max)
   276  			}
   277  
   278  			if !reflect.DeepEqual(gotD, wantD) {
   279  				t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD)
   280  			}
   281  		}
   282  	}
   283  }
   284  
   285  var nearestSetDistTests = []Point{
   286  	{4, 6},
   287  	{7, 5},
   288  	{8, 7},
   289  	{6, -5},
   290  }
   291  
   292  func TestNearestSetDist(t *testing.T) {
   293  	tree, err := New(wpData, 3, rand.NewSource(1))
   294  	if err != nil {
   295  		t.Fatalf("unexpected error: %v", err)
   296  	}
   297  	for i, q := range nearestSetDistTests {
   298  		for d := 1.0; d < 100; d += 0.1 {
   299  			dk := NewDistKeeper(d)
   300  			tree.NearestSet(dk, q)
   301  
   302  			hits := make(map[string]float64)
   303  			for _, p := range wpData {
   304  				hits[fmt.Sprint(p)] = p.Distance(q)
   305  			}
   306  
   307  			for _, p := range dk.Heap {
   308  				var done bool
   309  				if p.Comparable == nil {
   310  					done = true
   311  					continue
   312  				}
   313  				delete(hits, fmt.Sprint(p.Comparable))
   314  				if done {
   315  					t.Error("expectedly finished heap iteration")
   316  					break
   317  				}
   318  				dist := p.Comparable.Distance(q)
   319  				if dist > d {
   320  					t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d)
   321  				}
   322  			}
   323  
   324  			for p, dist := range hits {
   325  				if dist <= d {
   326  					t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d)
   327  				}
   328  			}
   329  		}
   330  	}
   331  }
   332  
   333  func TestDo(t *testing.T) {
   334  	tree, err := New(wpData, 3, rand.NewSource(1))
   335  	if err != nil {
   336  		t.Fatalf("unexpected error: %v", err)
   337  	}
   338  	var got []Point
   339  	fn := func(c Comparable, _ int) (done bool) {
   340  		got = append(got, c.(Point))
   341  		return
   342  	}
   343  	killed := tree.Do(fn)
   344  
   345  	want := make([]Point, len(wpData))
   346  	for i, p := range wpData {
   347  		want[i] = p.(Point)
   348  	}
   349  	order.BySliceValues(got)
   350  	order.BySliceValues(want)
   351  
   352  	if !reflect.DeepEqual(got, want) {
   353  		t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, want)
   354  	}
   355  	if killed {
   356  		t.Error("tree iteration unexpectedly killed")
   357  	}
   358  }
   359  
   360  func BenchmarkNew(b *testing.B) {
   361  	for _, effort := range []int{0, 10, 100} {
   362  		b.Run(fmt.Sprintf("New:%d", effort), func(b *testing.B) {
   363  			rnd := rand.New(rand.NewSource(1))
   364  			p := make([]Comparable, 1e5)
   365  			for i := range p {
   366  				p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
   367  			}
   368  			b.ResetTimer()
   369  			for i := 0; i < b.N; i++ {
   370  				_, err := New(p, effort, rand.NewSource(1))
   371  				if err != nil {
   372  					b.Fatalf("unexpected error: %v", err)
   373  				}
   374  			}
   375  		})
   376  	}
   377  }
   378  
// Benchmark builds trees over 1e5 random 3-D points for a range of vantage
// point selection efforts, and runs the sub-benchmarks in queryBenchmarks
// against each tree.
//
// Before benchmarking, each tree is checked with isVPTree. Then 1e3 random
// queries are compared against the brute-force nearest helper. If validation
// has failed, -dot is set and the tree is small enough, the tree is written
// to TestBenches.dot and the benchmark returns.
//
// NOTE(review): a single rnd per effort drives data generation, the
// validation queries and the sub-benchmark queries. Query points therefore
// depend on execution order.
func Benchmark(b *testing.B) {
	// r and d receive query results inside the benchmark loops and are
	// checked afterwards, presumably so the calls cannot be optimized away.
	var r Comparable
	var d float64
	// Each entry pairs a sub-benchmark name with a constructor that returns
	// the sub-benchmark for a given data set and tree. rnd supplies the
	// query points.
	queryBenchmarks := []struct {
		name string
		fn   func(data []Comparable, tree *Tree, rnd *rand.Rand) func(*testing.B)
	}{
		{
			name: "NearestBrute", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) {
				return func(b *testing.B) {
					for i := 0; i < b.N; i++ {
						r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data)
					}
					if r == nil {
						b.Error("unexpected nil result")
					}
					if math.IsNaN(d) {
						b.Error("unexpected NaN result")
					}
				}
			},
		},
		{
			name: "NearestBruteN10", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) {
				return func(b *testing.B) {
					// This local r shadows the outer Comparable sink.
					var r []ComparableDist
					for i := 0; i < b.N; i++ {
						r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data)
					}
					if len(r) != 10 {
						b.Error("unexpected result length", len(r))
					}
				}
			},
		},
		{
			name: "Nearest", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) {
				return func(b *testing.B) {
					for i := 0; i < b.N; i++ {
						r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()})
					}
					if r == nil {
						b.Error("unexpected nil result")
					}
					if math.IsNaN(d) {
						b.Error("unexpected NaN result")
					}
				}
			},
		},
		{
			name: "NearestSetN10", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) {
				return func(b *testing.B) {
					nk := NewNKeeper(10)
					for i := 0; i < b.N; i++ {
						tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()})
						if nk.Len() != 10 {
							b.Error("unexpected result length")
						}
						// Reset the keeper for reuse: truncate to one entry and
						// make it a nil-Comparable entry at distance inf (inf is
						// defined outside this file's visible portion).
						nk.Heap = nk.Heap[:1]
						nk.Heap[0] = ComparableDist{Dist: inf}
					}
				}
			},
		},
	}

	for _, effort := range []int{0, 3, 10, 30, 100, 300} {
		rnd := rand.New(rand.NewSource(1))
		data := make([]Comparable, 1e5)
		for i := range data {
			data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
		}
		tree, err := New(data, effort, rand.NewSource(1))
		if err != nil {
			b.Errorf("unexpected error for effort=%d: %v", effort, err)
			continue
		}

		if !tree.Root.isVPTree() {
			b.Fatal("tree is not vantage point tree")
		}

		// Validate the tree against brute force before timing anything.
		for i := 0; i < 1e3; i++ {
			q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
			gotP, gotD := tree.Nearest(q)
			wantP, wantD := nearest(q, data)
			if !reflect.DeepEqual(gotP, wantP) {
				b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP)
			}
			if gotD != wantD {
				b.Errorf("unexpected distance for query %.3f: got:%v want:%v", q, gotD, wantD)
			}
		}

		// NOTE(review): with the default -dotmax of 100, this presumably never
		// triggers for these 1e5-point trees.
		if b.Failed() && *genDot && tree.Len() <= *dotLimit {
			err := dotFile(tree, "TestBenches", "")
			if err != nil {
				b.Fatalf("failed to write DOT file: %v", err)
			}
			return
		}

		for _, bench := range queryBenchmarks {
			// Brute-force benchmarks ignore the tree, so run them only once
			// (for effort 0).
			if strings.Contains(bench.name, "Brute") && effort != 0 {
				continue
			}
			b.Run(fmt.Sprintf("%s:%d", bench.name, effort), bench.fn(data, tree, rnd))
		}
	}
}
   490  
   491  func dot(t *Tree, label string) string {
   492  	if t == nil {
   493  		return ""
   494  	}
   495  	var (
   496  		s      []string
   497  		follow func(*Node)
   498  	)
   499  	follow = func(n *Node) {
   500  		id := uintptr(unsafe.Pointer(n))
   501  		c := fmt.Sprintf("%d[label = \"<Closer> |<Elem> %.3f/%.3f|<Further>\"];",
   502  			id, n.Point, n.Radius)
   503  		if n.Closer != nil {
   504  			c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Closer -> \"%d\":Elem [label=%.3f];",
   505  				id, uintptr(unsafe.Pointer(n.Closer)), n.Point.Distance(n.Closer.Point))
   506  			follow(n.Closer)
   507  		}
   508  		if n.Further != nil {
   509  			c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Further -> \"%d\":Elem [label=%.3f];",
   510  				id, uintptr(unsafe.Pointer(n.Further)), n.Point.Distance(n.Further.Point))
   511  			follow(n.Further)
   512  		}
   513  		s = append(s, c)
   514  	}
   515  	if t.Root != nil {
   516  		follow(t.Root)
   517  	}
   518  	return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n",
   519  		label,
   520  		strings.Join(s, "\n\t"),
   521  	)
   522  }
   523  
   524  func dotFile(t *Tree, label, dotString string) (err error) {
   525  	if t == nil && dotString == "" {
   526  		return
   527  	}
   528  	f, err := os.Create(label + ".dot")
   529  	if err != nil {
   530  		return
   531  	}
   532  	defer f.Close()
   533  	if dotString == "" {
   534  		fmt.Fprint(f, dot(t, label))
   535  	} else {
   536  		fmt.Fprint(f, dotString)
   537  	}
   538  	return
   539  }