github.com/gopherd/gonum@v0.0.4/spatial/vptree/vptree_test.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package vptree
     6  
     7  import (
     8  	"flag"
     9  	"fmt"
    10  	"math"
    11  	"os"
    12  	"reflect"
    13  	"sort"
    14  	"strings"
    15  	"testing"
    16  	"unsafe"
    17  
    18  	"math/rand"
    19  )
    20  
    21  var (
    22  	genDot   = flag.Bool("dot", false, "generate dot code for failing trees")
    23  	dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format")
    24  )
    25  
    26  var (
    27  	// Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572.
    28  	wpData = []Comparable{
    29  		Point{2, 3},
    30  		Point{5, 4},
    31  		Point{9, 6},
    32  		Point{4, 7},
    33  		Point{8, 1},
    34  		Point{7, 2},
    35  	}
    36  )
    37  
    38  var newTests = []struct {
    39  	data   []Comparable
    40  	effort int
    41  	seed   uint64
    42  }{
    43  	{data: wpData, effort: 0, seed: 1},
    44  	{data: wpData, effort: 1, seed: 1},
    45  	{data: wpData, effort: 2, seed: 1},
    46  	{data: wpData, effort: 4, seed: 1},
    47  	{data: wpData, effort: 8, seed: 1},
    48  	{data: []Comparable{Point{2, 3}, Point{5, 4}, Point{9, 6}, Point{5, 4}, Point{8, 1}, Point{7, 2}}, effort: 3, seed: 5555},
    49  }
    50  
    51  func TestNew(t *testing.T) {
    52  	for i, test := range newTests {
    53  		var tree *Tree
    54  		var err error
    55  		var panicked bool
    56  		func() {
    57  			defer func() {
    58  				if r := recover(); r != nil {
    59  					panicked = true
    60  				}
    61  			}()
    62  			tree, err = New(test.data, test.effort, rand.NewSource(test.seed))
    63  		}()
    64  		if panicked {
    65  			t.Errorf("unexpected panic for test %d", i)
    66  			continue
    67  		}
    68  		if err != nil {
    69  			t.Errorf("unexpected error for test %d: %v", i, err)
    70  			continue
    71  		}
    72  
    73  		if !tree.Root.isVPTree() {
    74  			t.Errorf("tree %d is not vp-tree", i)
    75  		}
    76  
    77  		if t.Failed() && *genDot && tree.Len() <= *dotLimit {
    78  			err := dotFile(tree, fmt.Sprintf("TestNew%d", i), "")
    79  			if err != nil {
    80  				t.Fatalf("failed to write DOT file: %v", err)
    81  			}
    82  		}
    83  	}
    84  }
    85  
    86  type compFn func(v, radius float64) bool
    87  
    88  func closer(v, radius float64) bool  { return v <= radius }
    89  func further(v, radius float64) bool { return v >= radius }
    90  
    91  func (n *Node) isVPTree() bool {
    92  	if n == nil {
    93  		return true
    94  	}
    95  	if !n.Closer.isPartitioned(n.Point, closer, n.Radius) {
    96  		return false
    97  	}
    98  	if !n.Further.isPartitioned(n.Point, further, n.Radius) {
    99  		return false
   100  	}
   101  	return n.Closer.isVPTree() && n.Further.isVPTree()
   102  }
   103  
   104  func (n *Node) isPartitioned(vp Comparable, fn compFn, radius float64) bool {
   105  	if n == nil {
   106  		return true
   107  	}
   108  	if n.Closer != nil && !fn(vp.Distance(n.Closer.Point), radius) {
   109  		return false
   110  	}
   111  	if n.Further != nil && !fn(vp.Distance(n.Further.Point), radius) {
   112  		return false
   113  	}
   114  	return n.Closer.isPartitioned(vp, fn, radius) && n.Further.isPartitioned(vp, fn, radius)
   115  }
   116  
   117  func nearest(q Comparable, p []Comparable) (Comparable, float64) {
   118  	min := q.Distance(p[0])
   119  	var r int
   120  	for i := 1; i < len(p); i++ {
   121  		d := q.Distance(p[i])
   122  		if d < min {
   123  			min = d
   124  			r = i
   125  		}
   126  	}
   127  	return p[r], min
   128  }
   129  
   130  func TestNearestRandom(t *testing.T) {
   131  	rnd := rand.New(rand.NewSource(1))
   132  
   133  	const (
   134  		min = 0.0
   135  		max = 1000.0
   136  
   137  		dims    = 4
   138  		setSize = 10000
   139  	)
   140  
   141  	var randData []Comparable
   142  	for i := 0; i < setSize; i++ {
   143  		p := make(Point, dims)
   144  		for j := 0; j < dims; j++ {
   145  			p[j] = (max-min)*rnd.Float64() + min
   146  		}
   147  		randData = append(randData, p)
   148  	}
   149  	tree, err := New(randData, 10, rand.NewSource(1))
   150  	if err != nil {
   151  		t.Fatalf("unexpected error: %v", err)
   152  	}
   153  
   154  	for i := 0; i < setSize; i++ {
   155  		q := make(Point, dims)
   156  		for j := 0; j < dims; j++ {
   157  			q[j] = (max-min)*rnd.Float64() + min
   158  		}
   159  
   160  		got, _ := tree.Nearest(q)
   161  		want, _ := nearest(q, randData)
   162  		if !reflect.DeepEqual(got, want) {
   163  			t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want)
   164  		}
   165  	}
   166  }
   167  
   168  func TestNearest(t *testing.T) {
   169  	tree, err := New(wpData, 3, rand.NewSource(1))
   170  	if err != nil {
   171  		t.Fatalf("unexpected error: %v", err)
   172  	}
   173  	for _, q := range append([]Comparable{
   174  		Point{4, 6},
   175  		// Point{7, 5}, // Omitted because it is ambiguously finds [9 6] or [5 4].
   176  		Point{8, 7},
   177  		Point{6, -5},
   178  		Point{1e5, 1e5},
   179  		Point{1e5, -1e5},
   180  		Point{-1e5, 1e5},
   181  		Point{-1e5, -1e5},
   182  		Point{1e5, 0},
   183  		Point{0, -1e5},
   184  		Point{0, 1e5},
   185  		Point{-1e5, 0},
   186  	}, wpData...) {
   187  		gotP, gotD := tree.Nearest(q)
   188  		wantP, wantD := nearest(q, wpData)
   189  		if !reflect.DeepEqual(gotP, wantP) {
   190  			t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP)
   191  		}
   192  		if gotD != wantD {
   193  			t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD)
   194  		}
   195  	}
   196  }
   197  
   198  func nearestN(n int, q Comparable, p []Comparable) []ComparableDist {
   199  	nk := NewNKeeper(n)
   200  	for i := 0; i < len(p); i++ {
   201  		nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])})
   202  	}
   203  	if len(nk.Heap) == 1 {
   204  		return nk.Heap
   205  	}
   206  	sort.Sort(nk)
   207  	for i, j := 0, len(nk.Heap)-1; i < j; i, j = i+1, j-1 {
   208  		nk.Heap[i], nk.Heap[j] = nk.Heap[j], nk.Heap[i]
   209  	}
   210  	return nk.Heap
   211  }
   212  
   213  func TestNearestSetN(t *testing.T) {
   214  	data := append([]Comparable{
   215  		Point{4, 6},
   216  		Point{7, 5}, // OK here because we collect N.
   217  		Point{8, 7},
   218  		Point{6, -5},
   219  		Point{1e5, 1e5},
   220  		Point{1e5, -1e5},
   221  		Point{-1e5, 1e5},
   222  		Point{-1e5, -1e5},
   223  		Point{1e5, 0},
   224  		Point{0, -1e5},
   225  		Point{0, 1e5},
   226  		Point{-1e5, 0}},
   227  		wpData[:len(wpData)-1]...)
   228  
   229  	tree, err := New(wpData, 3, rand.NewSource(1))
   230  	if err != nil {
   231  		t.Fatalf("unexpected error: %v", err)
   232  	}
   233  	for k := 1; k <= len(wpData); k++ {
   234  		for _, q := range data {
   235  			wantP := nearestN(k, q, wpData)
   236  
   237  			nk := NewNKeeper(k)
   238  			tree.NearestSet(nk, q)
   239  
   240  			var max float64
   241  			wantD := make(map[float64]map[string]struct{})
   242  			for _, p := range wantP {
   243  				if p.Dist > max {
   244  					max = p.Dist
   245  				}
   246  				d, ok := wantD[p.Dist]
   247  				if !ok {
   248  					d = make(map[string]struct{})
   249  				}
   250  				d[fmt.Sprint(p.Comparable)] = struct{}{}
   251  				wantD[p.Dist] = d
   252  			}
   253  			gotD := make(map[float64]map[string]struct{})
   254  			for _, p := range nk.Heap {
   255  				if p.Dist > max {
   256  					t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max)
   257  				}
   258  				d, ok := gotD[p.Dist]
   259  				if !ok {
   260  					d = make(map[string]struct{})
   261  				}
   262  				d[fmt.Sprint(p.Comparable)] = struct{}{}
   263  				gotD[p.Dist] = d
   264  			}
   265  
   266  			// If the available number of slots does not fit all the coequal furthest points
   267  			// we will fail the check. So remove, but check them minimally here.
   268  			if !reflect.DeepEqual(wantD[max], gotD[max]) {
   269  				// The best we can do at this stage is confirm that there are an equal number of matches at this distance.
   270  				if len(gotD[max]) != len(wantD[max]) {
   271  					t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max]))
   272  				}
   273  				delete(wantD, max)
   274  				delete(gotD, max)
   275  			}
   276  
   277  			if !reflect.DeepEqual(gotD, wantD) {
   278  				t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD)
   279  			}
   280  		}
   281  	}
   282  }
   283  
   284  var nearestSetDistTests = []Point{
   285  	{4, 6},
   286  	{7, 5},
   287  	{8, 7},
   288  	{6, -5},
   289  }
   290  
   291  func TestNearestSetDist(t *testing.T) {
   292  	tree, err := New(wpData, 3, rand.NewSource(1))
   293  	if err != nil {
   294  		t.Fatalf("unexpected error: %v", err)
   295  	}
   296  	for i, q := range nearestSetDistTests {
   297  		for d := 1.0; d < 100; d += 0.1 {
   298  			dk := NewDistKeeper(d)
   299  			tree.NearestSet(dk, q)
   300  
   301  			hits := make(map[string]float64)
   302  			for _, p := range wpData {
   303  				hits[fmt.Sprint(p)] = p.Distance(q)
   304  			}
   305  
   306  			for _, p := range dk.Heap {
   307  				var done bool
   308  				if p.Comparable == nil {
   309  					done = true
   310  					continue
   311  				}
   312  				delete(hits, fmt.Sprint(p.Comparable))
   313  				if done {
   314  					t.Error("expectedly finished heap iteration")
   315  					break
   316  				}
   317  				dist := p.Comparable.Distance(q)
   318  				if dist > d {
   319  					t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d)
   320  				}
   321  			}
   322  
   323  			for p, dist := range hits {
   324  				if dist <= d {
   325  					t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d)
   326  				}
   327  			}
   328  		}
   329  	}
   330  }
   331  
   332  func TestDo(t *testing.T) {
   333  	tree, err := New(wpData, 3, rand.NewSource(1))
   334  	if err != nil {
   335  		t.Fatalf("unexpected error: %v", err)
   336  	}
   337  	var got []Point
   338  	fn := func(c Comparable, _ int) (done bool) {
   339  		got = append(got, c.(Point))
   340  		return
   341  	}
   342  	killed := tree.Do(fn)
   343  
   344  	want := make([]Point, len(wpData))
   345  	for i, p := range wpData {
   346  		want[i] = p.(Point)
   347  	}
   348  	sort.Sort(lexical(got))
   349  	sort.Sort(lexical(want))
   350  
   351  	if !reflect.DeepEqual(got, want) {
   352  		t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, want)
   353  	}
   354  	if killed {
   355  		t.Error("tree iteration unexpectedly killed")
   356  	}
   357  }
   358  
   359  type lexical []Point
   360  
   361  func (c lexical) Len() int { return len(c) }
   362  func (c lexical) Less(i, j int) bool {
   363  	a, b := c[i], c[j]
   364  	l := len(a)
   365  	if len(b) < l {
   366  		l = len(b)
   367  	}
   368  	for k, v := range a[:l] {
   369  		if v < b[k] {
   370  			return true
   371  		}
   372  		if v > b[k] {
   373  			return false
   374  		}
   375  	}
   376  	return len(a) < len(b)
   377  }
   378  func (c lexical) Swap(i, j int) { c[i], c[j] = c[j], c[i] }
   379  
   380  func BenchmarkNew(b *testing.B) {
   381  	for _, effort := range []int{0, 10, 100} {
   382  		b.Run(fmt.Sprintf("New:%d", effort), func(b *testing.B) {
   383  			rnd := rand.New(rand.NewSource(1))
   384  			p := make([]Comparable, 1e5)
   385  			for i := range p {
   386  				p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
   387  			}
   388  			b.ResetTimer()
   389  			for i := 0; i < b.N; i++ {
   390  				_, err := New(p, effort, rand.NewSource(1))
   391  				if err != nil {
   392  					b.Fatalf("unexpected error: %v", err)
   393  				}
   394  			}
   395  		})
   396  	}
   397  }
   398  
   399  func Benchmark(b *testing.B) {
   400  	var r Comparable
   401  	var d float64
   402  	queryBenchmarks := []struct {
   403  		name string
   404  		fn   func(data []Comparable, tree *Tree, rnd *rand.Rand) func(*testing.B)
   405  	}{
   406  		{
   407  			name: "NearestBrute", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) {
   408  				return func(b *testing.B) {
   409  					for i := 0; i < b.N; i++ {
   410  						r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data)
   411  					}
   412  					if r == nil {
   413  						b.Error("unexpected nil result")
   414  					}
   415  					if math.IsNaN(d) {
   416  						b.Error("unexpected NaN result")
   417  					}
   418  				}
   419  			},
   420  		},
   421  		{
   422  			name: "NearestBruteN10", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) {
   423  				return func(b *testing.B) {
   424  					var r []ComparableDist
   425  					for i := 0; i < b.N; i++ {
   426  						r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data)
   427  					}
   428  					if len(r) != 10 {
   429  						b.Error("unexpected result length", len(r))
   430  					}
   431  				}
   432  			},
   433  		},
   434  		{
   435  			name: "Nearest", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) {
   436  				return func(b *testing.B) {
   437  					for i := 0; i < b.N; i++ {
   438  						r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()})
   439  					}
   440  					if r == nil {
   441  						b.Error("unexpected nil result")
   442  					}
   443  					if math.IsNaN(d) {
   444  						b.Error("unexpected NaN result")
   445  					}
   446  				}
   447  			},
   448  		},
   449  		{
   450  			name: "NearestSetN10", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) {
   451  				return func(b *testing.B) {
   452  					nk := NewNKeeper(10)
   453  					for i := 0; i < b.N; i++ {
   454  						tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()})
   455  						if nk.Len() != 10 {
   456  							b.Error("unexpected result length")
   457  						}
   458  						nk.Heap = nk.Heap[:1]
   459  						nk.Heap[0] = ComparableDist{Dist: inf}
   460  					}
   461  				}
   462  			},
   463  		},
   464  	}
   465  
   466  	for _, effort := range []int{0, 3, 10, 30, 100, 300} {
   467  		rnd := rand.New(rand.NewSource(1))
   468  		data := make([]Comparable, 1e5)
   469  		for i := range data {
   470  			data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
   471  		}
   472  		tree, err := New(data, effort, rand.NewSource(1))
   473  		if err != nil {
   474  			b.Errorf("unexpected error for effort=%d: %v", effort, err)
   475  			continue
   476  		}
   477  
   478  		if !tree.Root.isVPTree() {
   479  			b.Fatal("tree is not vantage point tree")
   480  		}
   481  
   482  		for i := 0; i < 1e3; i++ {
   483  			q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
   484  			gotP, gotD := tree.Nearest(q)
   485  			wantP, wantD := nearest(q, data)
   486  			if !reflect.DeepEqual(gotP, wantP) {
   487  				b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP)
   488  			}
   489  			if gotD != wantD {
   490  				b.Errorf("unexpected distance for query %.3f: got:%v want:%v", q, gotD, wantD)
   491  			}
   492  		}
   493  
   494  		if b.Failed() && *genDot && tree.Len() <= *dotLimit {
   495  			err := dotFile(tree, "TestBenches", "")
   496  			if err != nil {
   497  				b.Fatalf("failed to write DOT file: %v", err)
   498  			}
   499  			return
   500  		}
   501  
   502  		for _, bench := range queryBenchmarks {
   503  			if strings.Contains(bench.name, "Brute") && effort != 0 {
   504  				continue
   505  			}
   506  			b.Run(fmt.Sprintf("%s:%d", bench.name, effort), bench.fn(data, tree, rnd))
   507  		}
   508  	}
   509  }
   510  
   511  func dot(t *Tree, label string) string {
   512  	if t == nil {
   513  		return ""
   514  	}
   515  	var (
   516  		s      []string
   517  		follow func(*Node)
   518  	)
   519  	follow = func(n *Node) {
   520  		id := uintptr(unsafe.Pointer(n))
   521  		c := fmt.Sprintf("%d[label = \"<Closer> |<Elem> %.3f/%.3f|<Further>\"];",
   522  			id, n.Point, n.Radius)
   523  		if n.Closer != nil {
   524  			c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Closer -> \"%d\":Elem [label=%.3f];",
   525  				id, uintptr(unsafe.Pointer(n.Closer)), n.Point.Distance(n.Closer.Point))
   526  			follow(n.Closer)
   527  		}
   528  		if n.Further != nil {
   529  			c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Further -> \"%d\":Elem [label=%.3f];",
   530  				id, uintptr(unsafe.Pointer(n.Further)), n.Point.Distance(n.Further.Point))
   531  			follow(n.Further)
   532  		}
   533  		s = append(s, c)
   534  	}
   535  	if t.Root != nil {
   536  		follow(t.Root)
   537  	}
   538  	return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n",
   539  		label,
   540  		strings.Join(s, "\n\t"),
   541  	)
   542  }
   543  
   544  func dotFile(t *Tree, label, dotString string) (err error) {
   545  	if t == nil && dotString == "" {
   546  		return
   547  	}
   548  	f, err := os.Create(label + ".dot")
   549  	if err != nil {
   550  		return
   551  	}
   552  	defer f.Close()
   553  	if dotString == "" {
   554  		fmt.Fprint(f, dot(t, label))
   555  	} else {
   556  		fmt.Fprint(f, dotString)
   557  	}
   558  	return
   559  }