github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/spatial/vptree/vptree_test.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package vptree
     6  
     7  import (
     8  	"flag"
     9  	"fmt"
    10  	"math"
    11  	"os"
    12  	"reflect"
    13  	"sort"
    14  	"strings"
    15  	"testing"
    16  	"unsafe"
    17  
    18  	"golang.org/x/exp/rand"
    19  )
    20  
    21  var (
    22  	genDot   = flag.Bool("dot", false, "generate dot code for failing trees")
    23  	dotLimit = flag.Int("dotmax", 100, "specify maximum size for tree output for dot format")
    24  )
    25  
    26  var (
    27  	// Using example from WP article: https://en.wikipedia.org/w/index.php?title=K-d_tree&oldid=887573572.
    28  	wpData = []Comparable{
    29  		Point{2, 3},
    30  		Point{5, 4},
    31  		Point{9, 6},
    32  		Point{4, 7},
    33  		Point{8, 1},
    34  		Point{7, 2},
    35  	}
    36  )
    37  
    38  var newTests = []struct {
    39  	data   []Comparable
    40  	effort int
    41  }{
    42  	{data: wpData, effort: 0},
    43  	{data: wpData, effort: 1},
    44  	{data: wpData, effort: 2},
    45  	{data: wpData, effort: 4},
    46  	{data: wpData, effort: 8},
    47  }
    48  
    49  func TestNew(t *testing.T) {
    50  	for i, test := range newTests {
    51  		var tree *Tree
    52  		var err error
    53  		var panicked bool
    54  		func() {
    55  			defer func() {
    56  				if r := recover(); r != nil {
    57  					panicked = true
    58  				}
    59  			}()
    60  			tree, err = New(test.data, test.effort, rand.NewSource(1))
    61  		}()
    62  		if panicked {
    63  			t.Errorf("unexpected panic for test %d", i)
    64  			continue
    65  		}
    66  		if err != nil {
    67  			t.Errorf("unexpected error for test %d: %v", i, err)
    68  			continue
    69  		}
    70  
    71  		if !tree.Root.isVPTree() {
    72  			t.Errorf("tree %d is not vp-tree", i)
    73  		}
    74  
    75  		if t.Failed() && *genDot && tree.Len() <= *dotLimit {
    76  			err := dotFile(tree, fmt.Sprintf("TestNew%d", i), "")
    77  			if err != nil {
    78  				t.Fatalf("failed to write DOT file: %v", err)
    79  			}
    80  		}
    81  	}
    82  }
    83  
    84  type compFn func(v, radius float64) bool
    85  
    86  func closer(v, radius float64) bool  { return v <= radius }
    87  func further(v, radius float64) bool { return v >= radius }
    88  
    89  func (n *Node) isVPTree() bool {
    90  	if n == nil {
    91  		return true
    92  	}
    93  	if !n.Closer.isPartitioned(n.Point, closer, n.Radius) {
    94  		return false
    95  	}
    96  	if !n.Further.isPartitioned(n.Point, further, n.Radius) {
    97  		return false
    98  	}
    99  	return n.Closer.isVPTree() && n.Further.isVPTree()
   100  }
   101  
   102  func (n *Node) isPartitioned(vp Comparable, fn compFn, radius float64) bool {
   103  	if n == nil {
   104  		return true
   105  	}
   106  	if n.Closer != nil && !fn(vp.Distance(n.Closer.Point), radius) {
   107  		return false
   108  	}
   109  	if n.Further != nil && !fn(vp.Distance(n.Further.Point), radius) {
   110  		return false
   111  	}
   112  	return n.Closer.isPartitioned(vp, fn, radius) && n.Further.isPartitioned(vp, fn, radius)
   113  }
   114  
   115  func nearest(q Comparable, p []Comparable) (Comparable, float64) {
   116  	min := q.Distance(p[0])
   117  	var r int
   118  	for i := 1; i < len(p); i++ {
   119  		d := q.Distance(p[i])
   120  		if d < min {
   121  			min = d
   122  			r = i
   123  		}
   124  	}
   125  	return p[r], min
   126  }
   127  
   128  func TestNearestRandom(t *testing.T) {
   129  	rnd := rand.New(rand.NewSource(1))
   130  
   131  	const (
   132  		min = 0.0
   133  		max = 1000.0
   134  
   135  		dims    = 4
   136  		setSize = 10000
   137  	)
   138  
   139  	var randData []Comparable
   140  	for i := 0; i < setSize; i++ {
   141  		p := make(Point, dims)
   142  		for j := 0; j < dims; j++ {
   143  			p[j] = (max-min)*rnd.Float64() + min
   144  		}
   145  		randData = append(randData, p)
   146  	}
   147  	tree, err := New(randData, 10, rand.NewSource(1))
   148  	if err != nil {
   149  		t.Fatalf("unexpected error: %v", err)
   150  	}
   151  
   152  	for i := 0; i < setSize; i++ {
   153  		q := make(Point, dims)
   154  		for j := 0; j < dims; j++ {
   155  			q[j] = (max-min)*rnd.Float64() + min
   156  		}
   157  
   158  		got, _ := tree.Nearest(q)
   159  		want, _ := nearest(q, randData)
   160  		if !reflect.DeepEqual(got, want) {
   161  			t.Fatalf("unexpected result from query %d %.3f: got:%.3f want:%.3f", i, q, got, want)
   162  		}
   163  	}
   164  }
   165  
   166  func TestNearest(t *testing.T) {
   167  	tree, err := New(wpData, 3, rand.NewSource(1))
   168  	if err != nil {
   169  		t.Fatalf("unexpected error: %v", err)
   170  	}
   171  	for _, q := range append([]Comparable{
   172  		Point{4, 6},
   173  		// Point{7, 5}, // Omitted because it is ambiguously finds [9 6] or [5 4].
   174  		Point{8, 7},
   175  		Point{6, -5},
   176  		Point{1e5, 1e5},
   177  		Point{1e5, -1e5},
   178  		Point{-1e5, 1e5},
   179  		Point{-1e5, -1e5},
   180  		Point{1e5, 0},
   181  		Point{0, -1e5},
   182  		Point{0, 1e5},
   183  		Point{-1e5, 0},
   184  	}, wpData...) {
   185  		gotP, gotD := tree.Nearest(q)
   186  		wantP, wantD := nearest(q, wpData)
   187  		if !reflect.DeepEqual(gotP, wantP) {
   188  			t.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP)
   189  		}
   190  		if gotD != wantD {
   191  			t.Errorf("unexpected distance for query %.3f : got:%v want:%v", q, gotD, wantD)
   192  		}
   193  	}
   194  }
   195  
   196  func nearestN(n int, q Comparable, p []Comparable) []ComparableDist {
   197  	nk := NewNKeeper(n)
   198  	for i := 0; i < len(p); i++ {
   199  		nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])})
   200  	}
   201  	if len(nk.Heap) == 1 {
   202  		return nk.Heap
   203  	}
   204  	sort.Sort(nk)
   205  	for i, j := 0, len(nk.Heap)-1; i < j; i, j = i+1, j-1 {
   206  		nk.Heap[i], nk.Heap[j] = nk.Heap[j], nk.Heap[i]
   207  	}
   208  	return nk.Heap
   209  }
   210  
   211  func TestNearestSetN(t *testing.T) {
   212  	data := append([]Comparable{
   213  		Point{4, 6},
   214  		Point{7, 5}, // OK here because we collect N.
   215  		Point{8, 7},
   216  		Point{6, -5},
   217  		Point{1e5, 1e5},
   218  		Point{1e5, -1e5},
   219  		Point{-1e5, 1e5},
   220  		Point{-1e5, -1e5},
   221  		Point{1e5, 0},
   222  		Point{0, -1e5},
   223  		Point{0, 1e5},
   224  		Point{-1e5, 0}},
   225  		wpData[:len(wpData)-1]...)
   226  
   227  	tree, err := New(wpData, 3, rand.NewSource(1))
   228  	if err != nil {
   229  		t.Fatalf("unexpected error: %v", err)
   230  	}
   231  	for k := 1; k <= len(wpData); k++ {
   232  		for _, q := range data {
   233  			wantP := nearestN(k, q, wpData)
   234  
   235  			nk := NewNKeeper(k)
   236  			tree.NearestSet(nk, q)
   237  
   238  			var max float64
   239  			wantD := make(map[float64]map[string]struct{})
   240  			for _, p := range wantP {
   241  				if p.Dist > max {
   242  					max = p.Dist
   243  				}
   244  				d, ok := wantD[p.Dist]
   245  				if !ok {
   246  					d = make(map[string]struct{})
   247  				}
   248  				d[fmt.Sprint(p.Comparable)] = struct{}{}
   249  				wantD[p.Dist] = d
   250  			}
   251  			gotD := make(map[float64]map[string]struct{})
   252  			for _, p := range nk.Heap {
   253  				if p.Dist > max {
   254  					t.Errorf("unexpected distance for point %.3f: got:%v want:<=%v", p.Comparable, p.Dist, max)
   255  				}
   256  				d, ok := gotD[p.Dist]
   257  				if !ok {
   258  					d = make(map[string]struct{})
   259  				}
   260  				d[fmt.Sprint(p.Comparable)] = struct{}{}
   261  				gotD[p.Dist] = d
   262  			}
   263  
   264  			// If the available number of slots does not fit all the coequal furthest points
   265  			// we will fail the check. So remove, but check them minimally here.
   266  			if !reflect.DeepEqual(wantD[max], gotD[max]) {
   267  				// The best we can do at this stage is confirm that there are an equal number of matches at this distance.
   268  				if len(gotD[max]) != len(wantD[max]) {
   269  					t.Errorf("unexpected number of maximal distance points: got:%d want:%d", len(gotD[max]), len(wantD[max]))
   270  				}
   271  				delete(wantD, max)
   272  				delete(gotD, max)
   273  			}
   274  
   275  			if !reflect.DeepEqual(gotD, wantD) {
   276  				t.Errorf("unexpected result for k=%d query %.3f: got:%v want:%v", k, q, gotD, wantD)
   277  			}
   278  		}
   279  	}
   280  }
   281  
   282  var nearestSetDistTests = []Point{
   283  	{4, 6},
   284  	{7, 5},
   285  	{8, 7},
   286  	{6, -5},
   287  }
   288  
   289  func TestNearestSetDist(t *testing.T) {
   290  	tree, err := New(wpData, 3, rand.NewSource(1))
   291  	if err != nil {
   292  		t.Fatalf("unexpected error: %v", err)
   293  	}
   294  	for i, q := range nearestSetDistTests {
   295  		for d := 1.0; d < 100; d += 0.1 {
   296  			dk := NewDistKeeper(d)
   297  			tree.NearestSet(dk, q)
   298  
   299  			hits := make(map[string]float64)
   300  			for _, p := range wpData {
   301  				hits[fmt.Sprint(p)] = p.Distance(q)
   302  			}
   303  
   304  			for _, p := range dk.Heap {
   305  				var done bool
   306  				if p.Comparable == nil {
   307  					done = true
   308  					continue
   309  				}
   310  				delete(hits, fmt.Sprint(p.Comparable))
   311  				if done {
   312  					t.Error("expectedly finished heap iteration")
   313  					break
   314  				}
   315  				dist := p.Comparable.Distance(q)
   316  				if dist > d {
   317  					t.Errorf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d)
   318  				}
   319  			}
   320  
   321  			for p, dist := range hits {
   322  				if dist <= d {
   323  					t.Errorf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d)
   324  				}
   325  			}
   326  		}
   327  	}
   328  }
   329  
   330  func TestDo(t *testing.T) {
   331  	tree, err := New(wpData, 3, rand.NewSource(1))
   332  	if err != nil {
   333  		t.Fatalf("unexpected error: %v", err)
   334  	}
   335  	var got []Point
   336  	fn := func(c Comparable, _ int) (done bool) {
   337  		got = append(got, c.(Point))
   338  		return
   339  	}
   340  	killed := tree.Do(fn)
   341  
   342  	want := make([]Point, len(wpData))
   343  	for i, p := range wpData {
   344  		want[i] = p.(Point)
   345  	}
   346  	sort.Sort(lexical(got))
   347  	sort.Sort(lexical(want))
   348  
   349  	if !reflect.DeepEqual(got, want) {
   350  		t.Errorf("unexpected result from tree iteration: got:%v want:%v", got, want)
   351  	}
   352  	if killed {
   353  		t.Error("tree iteration unexpectedly killed")
   354  	}
   355  }
   356  
   357  type lexical []Point
   358  
   359  func (c lexical) Len() int { return len(c) }
   360  func (c lexical) Less(i, j int) bool {
   361  	a, b := c[i], c[j]
   362  	l := len(a)
   363  	if len(b) < l {
   364  		l = len(b)
   365  	}
   366  	for k, v := range a[:l] {
   367  		if v < b[k] {
   368  			return true
   369  		}
   370  		if v > b[k] {
   371  			return false
   372  		}
   373  	}
   374  	return len(a) < len(b)
   375  }
   376  func (c lexical) Swap(i, j int) { c[i], c[j] = c[j], c[i] }
   377  
   378  func BenchmarkNew(b *testing.B) {
   379  	for _, effort := range []int{0, 10, 100} {
   380  		b.Run(fmt.Sprintf("New:%d", effort), func(b *testing.B) {
   381  			rnd := rand.New(rand.NewSource(1))
   382  			p := make([]Comparable, 1e5)
   383  			for i := range p {
   384  				p[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
   385  			}
   386  			b.ResetTimer()
   387  			for i := 0; i < b.N; i++ {
   388  				_, err := New(p, effort, rand.NewSource(1))
   389  				if err != nil {
   390  					b.Fatalf("unexpected error: %v", err)
   391  				}
   392  			}
   393  		})
   394  	}
   395  }
   396  
   397  func Benchmark(b *testing.B) {
   398  	var r Comparable
   399  	var d float64
   400  	queryBenchmarks := []struct {
   401  		name string
   402  		fn   func(data []Comparable, tree *Tree, rnd *rand.Rand) func(*testing.B)
   403  	}{
   404  		{
   405  			name: "NearestBrute", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) {
   406  				return func(b *testing.B) {
   407  					for i := 0; i < b.N; i++ {
   408  						r, d = nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data)
   409  					}
   410  					if r == nil {
   411  						b.Error("unexpected nil result")
   412  					}
   413  					if math.IsNaN(d) {
   414  						b.Error("unexpected NaN result")
   415  					}
   416  				}
   417  			},
   418  		},
   419  		{
   420  			name: "NearestBruteN10", fn: func(data []Comparable, _ *Tree, rnd *rand.Rand) func(b *testing.B) {
   421  				return func(b *testing.B) {
   422  					var r []ComparableDist
   423  					for i := 0; i < b.N; i++ {
   424  						r = nearestN(10, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}, data)
   425  					}
   426  					if len(r) != 10 {
   427  						b.Error("unexpected result length", len(r))
   428  					}
   429  				}
   430  			},
   431  		},
   432  		{
   433  			name: "Nearest", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) {
   434  				return func(b *testing.B) {
   435  					for i := 0; i < b.N; i++ {
   436  						r, d = tree.Nearest(Point{rnd.Float64(), rnd.Float64(), rnd.Float64()})
   437  					}
   438  					if r == nil {
   439  						b.Error("unexpected nil result")
   440  					}
   441  					if math.IsNaN(d) {
   442  						b.Error("unexpected NaN result")
   443  					}
   444  				}
   445  			},
   446  		},
   447  		{
   448  			name: "NearestSetN10", fn: func(_ []Comparable, tree *Tree, rnd *rand.Rand) func(b *testing.B) {
   449  				return func(b *testing.B) {
   450  					nk := NewNKeeper(10)
   451  					for i := 0; i < b.N; i++ {
   452  						tree.NearestSet(nk, Point{rnd.Float64(), rnd.Float64(), rnd.Float64()})
   453  						if nk.Len() != 10 {
   454  							b.Error("unexpected result length")
   455  						}
   456  						nk.Heap = nk.Heap[:1]
   457  						nk.Heap[0] = ComparableDist{Dist: inf}
   458  					}
   459  				}
   460  			},
   461  		},
   462  	}
   463  
   464  	for _, effort := range []int{0, 3, 10, 30, 100, 300} {
   465  		rnd := rand.New(rand.NewSource(1))
   466  		data := make([]Comparable, 1e5)
   467  		for i := range data {
   468  			data[i] = Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
   469  		}
   470  		tree, err := New(data, effort, rand.NewSource(1))
   471  		if err != nil {
   472  			b.Errorf("unexpected error for effort=%d: %v", effort, err)
   473  			continue
   474  		}
   475  
   476  		if !tree.Root.isVPTree() {
   477  			b.Fatal("tree is not vantage point tree")
   478  		}
   479  
   480  		for i := 0; i < 1e3; i++ {
   481  			q := Point{rnd.Float64(), rnd.Float64(), rnd.Float64()}
   482  			gotP, gotD := tree.Nearest(q)
   483  			wantP, wantD := nearest(q, data)
   484  			if !reflect.DeepEqual(gotP, wantP) {
   485  				b.Errorf("unexpected result for query %.3f: got:%.3f want:%.3f", q, gotP, wantP)
   486  			}
   487  			if gotD != wantD {
   488  				b.Errorf("unexpected distance for query %.3f: got:%v want:%v", q, gotD, wantD)
   489  			}
   490  		}
   491  
   492  		if b.Failed() && *genDot && tree.Len() <= *dotLimit {
   493  			err := dotFile(tree, "TestBenches", "")
   494  			if err != nil {
   495  				b.Fatalf("failed to write DOT file: %v", err)
   496  			}
   497  			return
   498  		}
   499  
   500  		for _, bench := range queryBenchmarks {
   501  			if strings.Contains(bench.name, "Brute") && effort != 0 {
   502  				continue
   503  			}
   504  			b.Run(fmt.Sprintf("%s:%d", bench.name, effort), bench.fn(data, tree, rnd))
   505  		}
   506  	}
   507  }
   508  
   509  func dot(t *Tree, label string) string {
   510  	if t == nil {
   511  		return ""
   512  	}
   513  	var (
   514  		s      []string
   515  		follow func(*Node)
   516  	)
   517  	follow = func(n *Node) {
   518  		id := uintptr(unsafe.Pointer(n))
   519  		c := fmt.Sprintf("%d[label = \"<Closer> |<Elem> %.3f/%.3f|<Further>\"];",
   520  			id, n.Point, n.Radius)
   521  		if n.Closer != nil {
   522  			c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Closer -> \"%d\":Elem [label=%.3f];",
   523  				id, uintptr(unsafe.Pointer(n.Closer)), n.Point.Distance(n.Closer.Point))
   524  			follow(n.Closer)
   525  		}
   526  		if n.Further != nil {
   527  			c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Further -> \"%d\":Elem [label=%.3f];",
   528  				id, uintptr(unsafe.Pointer(n.Further)), n.Point.Distance(n.Further.Point))
   529  			follow(n.Further)
   530  		}
   531  		s = append(s, c)
   532  	}
   533  	if t.Root != nil {
   534  		follow(t.Root)
   535  	}
   536  	return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n",
   537  		label,
   538  		strings.Join(s, "\n\t"),
   539  	)
   540  }
   541  
   542  func dotFile(t *Tree, label, dotString string) (err error) {
   543  	if t == nil && dotString == "" {
   544  		return
   545  	}
   546  	f, err := os.Create(label + ".dot")
   547  	if err != nil {
   548  		return
   549  	}
   550  	defer f.Close()
   551  	if dotString == "" {
   552  		fmt.Fprint(f, dot(t, label))
   553  	} else {
   554  		fmt.Fprint(f, dotString)
   555  	}
   556  	return
   557  }