github.com/biogo/store@v0.0.0-20201120204734-aad293a2328f/kdtree/kdtree_test.go (about)

     1  // Copyright ©2012 The bíogo Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package kdtree
     6  
     7  import (
     8  	"flag"
     9  	"fmt"
    10  	"math/rand"
    11  	"os"
    12  	"reflect"
    13  	"sort"
    14  	"strings"
    15  	"testing"
    16  	"unsafe"
    17  
    18  	"gopkg.in/check.v1"
    19  )
    20  
    21  var (
    22  	genDot   = flag.Bool("dot", false, "Generate dot code for failing trees.")
    23  	dotLimit = flag.Int("dotmax", 100, "Maximum size for tree output for dot format.")
    24  )
    25  
    26  func Test(t *testing.T) { check.TestingT(t) }
    27  
    28  type S struct{}
    29  
    30  var _ = check.Suite(&S{})
    31  
    32  var (
    33  	// Using example from WP article.
    34  	wpData   = Points{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}}
    35  	nbWpData = nbPoints{{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}}
    36  	wpBound  = &Bounding{Point{2, 1}, Point{9, 7}}
    37  	bData    = func(i int) Points {
    38  		p := make(Points, i)
    39  		for i := range p {
    40  			p[i] = Point{rand.Float64(), rand.Float64(), rand.Float64()}
    41  		}
    42  		return p
    43  	}(1e2)
    44  	bTree = New(bData, true)
    45  )
    46  
    47  func (s *S) TestNew(c *check.C) {
    48  	for i, test := range []struct {
    49  		data     Interface
    50  		bounding bool
    51  		bounds   *Bounding
    52  	}{
    53  		{wpData, false, nil},
    54  		{nbWpData, false, nil},
    55  		{wpData, true, wpBound},
    56  		{nbWpData, true, nil},
    57  	} {
    58  		var t *Tree
    59  		NewTreePanics := func() (panicked bool) {
    60  			defer func() {
    61  				if r := recover(); r != nil {
    62  					panicked = true
    63  				}
    64  			}()
    65  			t = New(test.data, test.bounding)
    66  			return
    67  		}
    68  		c.Check(NewTreePanics(), check.Equals, false)
    69  		c.Check(t.Root.isKDTree(), check.Equals, true)
    70  		switch data := test.data.(type) {
    71  		case Points:
    72  			for _, p := range data {
    73  				c.Check(t.Contains(p), check.Equals, true)
    74  			}
    75  		case nbPoints:
    76  			for _, p := range data {
    77  				c.Check(t.Contains(p), check.Equals, true)
    78  			}
    79  		}
    80  		c.Check(t.Root.Bounding, check.DeepEquals, test.bounds,
    81  			check.Commentf("Test %d. %T %v", i, test.data, test.bounding))
    82  		if c.Failed() && *genDot && t.Len() <= *dotLimit {
    83  			err := dotFile(t, fmt.Sprintf("TestNew%T", test.data), "")
    84  			if err != nil {
    85  				c.Errorf("Dot file write failed: %v", err)
    86  			}
    87  		}
    88  	}
    89  }
    90  
    91  func (s *S) TestInsert(c *check.C) {
    92  	for i, test := range []struct {
    93  		data   Interface
    94  		insert []Comparable
    95  		bounds *Bounding
    96  	}{
    97  		{
    98  			wpData,
    99  			[]Comparable{Point{0, 0}, Point{10, 10}},
   100  			&Bounding{Point{0, 0}, Point{10, 10}},
   101  		},
   102  		{
   103  			nbWpData,
   104  			[]Comparable{nbPoint{0, 0}, nbPoint{10, 10}},
   105  			nil,
   106  		},
   107  	} {
   108  		t := New(test.data, true)
   109  		for _, v := range test.insert {
   110  			t.Insert(v, true)
   111  		}
   112  		c.Check(t.Root.isKDTree(), check.Equals, true)
   113  		c.Check(t.Root.Bounding, check.DeepEquals, test.bounds,
   114  			check.Commentf("Test %d. %T", i, test.data))
   115  		if c.Failed() && *genDot && t.Len() <= *dotLimit {
   116  			err := dotFile(t, fmt.Sprintf("TestInsert%T", test.data), "")
   117  			if err != nil {
   118  				c.Errorf("Dot file write failed: %v", err)
   119  			}
   120  		}
   121  	}
   122  }
   123  
   124  type compFn func(float64) bool
   125  
   126  func left(v float64) bool  { return v <= 0 }
   127  func right(v float64) bool { return !left(v) }
   128  
   129  func (n *Node) isKDTree() bool {
   130  	if n == nil {
   131  		return true
   132  	}
   133  	d := n.Point.Dims()
   134  	// Together these define the property of minimal orthogonal bounding.
   135  	if !(n.isContainedBy(n.Bounding) && n.Bounding.planesHaveCoincidentPointsIn(n, [2][]bool{make([]bool, d), make([]bool, d)})) {
   136  		return false
   137  	}
   138  	if !n.Left.isPartitioned(n.Point, left, n.Plane) {
   139  		return false
   140  	}
   141  	if !n.Right.isPartitioned(n.Point, right, n.Plane) {
   142  		return false
   143  	}
   144  	return n.Left.isKDTree() && n.Right.isKDTree()
   145  }
   146  
   147  func (n *Node) isPartitioned(pivot Comparable, fn compFn, plane Dim) bool {
   148  	if n == nil {
   149  		return true
   150  	}
   151  	if n.Left != nil && fn(pivot.Compare(n.Left.Point, plane)) {
   152  		return false
   153  	}
   154  	if n.Right != nil && fn(pivot.Compare(n.Right.Point, plane)) {
   155  		return false
   156  	}
   157  	return n.Left.isPartitioned(pivot, fn, plane) && n.Right.isPartitioned(pivot, fn, plane)
   158  }
   159  
   160  func (n *Node) isContainedBy(b *Bounding) bool {
   161  	if n == nil {
   162  		return true
   163  	}
   164  	if !b.Contains(n.Point) {
   165  		return false
   166  	}
   167  	return n.Left.isContainedBy(b) && n.Right.isContainedBy(b)
   168  }
   169  
   170  func (b *Bounding) planesHaveCoincidentPointsIn(n *Node, tight [2][]bool) bool {
   171  	if b == nil {
   172  		return true
   173  	}
   174  	if n == nil {
   175  		return true
   176  	}
   177  
   178  	b.planesHaveCoincidentPointsIn(n.Left, tight)
   179  	b.planesHaveCoincidentPointsIn(n.Right, tight)
   180  
   181  	var ok = true
   182  	for i := range tight {
   183  		for d := 0; d < n.Point.Dims(); d++ {
   184  			if c := n.Point.Compare(b[0], Dim(d)); c == 0 {
   185  				tight[i][d] = true
   186  			}
   187  			ok = ok && tight[i][d]
   188  		}
   189  	}
   190  	return ok
   191  }
   192  
   193  func nearest(q Point, p Points) (Point, float64) {
   194  	min := q.Distance(p[0])
   195  	var r int
   196  	for i := 1; i < p.Len(); i++ {
   197  		d := q.Distance(p[i])
   198  		if d < min {
   199  			min = d
   200  			r = i
   201  		}
   202  	}
   203  	return p[r], min
   204  }
   205  
   206  func (s *S) TestNearestRandom(c *check.C) {
   207  	const (
   208  		min = 0.
   209  		max = 1000.
   210  
   211  		dims    = 4
   212  		setSize = 10000
   213  	)
   214  
   215  	var randData Points
   216  	for i := 0; i < setSize; i++ {
   217  		p := make(Point, dims)
   218  		for j := 0; j < dims; j++ {
   219  			p[j] = (max-min)*rand.Float64() + min
   220  		}
   221  		randData = append(randData, p)
   222  	}
   223  	t := New(randData, false)
   224  
   225  	for i := 0; i < setSize; i++ {
   226  		q := make(Point, dims)
   227  		for j := 0; j < dims; j++ {
   228  			q[j] = (max-min)*rand.Float64() + min
   229  		}
   230  
   231  		p, _ := t.Nearest(q)
   232  		ep, _ := nearest(q, randData)
   233  		c.Assert(p, check.DeepEquals, ep, check.Commentf("Test %d: query %.3f expects %.3f", i, q, ep))
   234  	}
   235  }
   236  
   237  func (s *S) TestNearest(c *check.C) {
   238  	t := New(wpData, false)
   239  	for i, q := range append([]Point{
   240  		{4, 6},
   241  		{7, 5},
   242  		{8, 7},
   243  		{6, -5},
   244  		{1e5, 1e5},
   245  		{1e5, -1e5},
   246  		{-1e5, 1e5},
   247  		{-1e5, -1e5},
   248  		{1e5, 0},
   249  		{0, -1e5},
   250  		{0, 1e5},
   251  		{-1e5, 0},
   252  	}, wpData...) {
   253  		p, d := t.Nearest(q)
   254  		ep, ed := nearest(q, wpData)
   255  		c.Check(p, check.DeepEquals, ep, check.Commentf("Test %d: query %.3f expects %.3f", i, q, ep))
   256  		c.Check(d, check.Equals, ed)
   257  	}
   258  }
   259  
   260  func nearestN(n int, q Point, p Points) []ComparableDist {
   261  	nk := NewNKeeper(n)
   262  	for i := 0; i < p.Len(); i++ {
   263  		nk.Keep(ComparableDist{Comparable: p[i], Dist: q.Distance(p[i])})
   264  	}
   265  	if len(nk.Heap) == 1 {
   266  		return nk.Heap
   267  	}
   268  	sort.Sort(nk)
   269  	for i, j := 0, len(nk.Heap)-1; i < j; i, j = i+1, j-1 {
   270  		nk.Heap[i], nk.Heap[j] = nk.Heap[j], nk.Heap[i]
   271  	}
   272  	return nk.Heap
   273  }
   274  
   275  func (s *S) TestNearestSetN(c *check.C) {
   276  	t := New(wpData, false)
   277  	in := append([]Point{
   278  		{4, 6},
   279  		{7, 5},
   280  		{8, 7},
   281  		{6, -5},
   282  		{1e5, 1e5},
   283  		{1e5, -1e5},
   284  		{-1e5, 1e5},
   285  		{-1e5, -1e5},
   286  		{1e5, 0},
   287  		{0, -1e5},
   288  		{0, 1e5},
   289  		{-1e5, 0}}, wpData[:len(wpData)-1]...)
   290  	for k := 1; k <= len(wpData); k++ {
   291  		for i, q := range in {
   292  			ep := nearestN(k, q, wpData)
   293  			nk := NewNKeeper(k)
   294  			t.NearestSet(nk, q)
   295  
   296  			var max float64
   297  			ed := make(map[float64]map[string]struct{})
   298  			for _, p := range ep {
   299  				if p.Dist > max {
   300  					max = p.Dist
   301  				}
   302  				d, ok := ed[p.Dist]
   303  				if !ok {
   304  					d = make(map[string]struct{})
   305  				}
   306  				d[fmt.Sprint(p.Comparable)] = struct{}{}
   307  				ed[p.Dist] = d
   308  			}
   309  			kd := make(map[float64]map[string]struct{})
   310  			for _, p := range nk.Heap {
   311  				c.Check(max >= p.Dist, check.Equals, true)
   312  				d, ok := kd[p.Dist]
   313  				if !ok {
   314  					d = make(map[string]struct{})
   315  				}
   316  				d[fmt.Sprint(p.Comparable)] = struct{}{}
   317  				kd[p.Dist] = d
   318  			}
   319  
   320  			// If the available number of slots does not fit all the coequal furthest points
   321  			// we will fail the check. So remove, but check them minimally here.
   322  			if !reflect.DeepEqual(ed[max], kd[max]) {
   323  				// The best we can do at this stage is confirm that there are an equal number of matches at this distance.
   324  				c.Check(len(ed[max]), check.Equals, len(kd[max]))
   325  				delete(ed, max)
   326  				delete(kd, max)
   327  			}
   328  
   329  			c.Check(kd, check.DeepEquals, ed, check.Commentf("Test k=%d %d: query %.3f expects %.3f", k, i, q, ep))
   330  		}
   331  	}
   332  }
   333  
   334  func (s *S) TestNearestSetDist(c *check.C) {
   335  	t := New(wpData, false)
   336  	for i, q := range []Point{
   337  		{4, 6},
   338  		{7, 5},
   339  		{8, 7},
   340  		{6, -5},
   341  	} {
   342  		for d := 1.; d < 100; d += 0.1 {
   343  			dk := NewDistKeeper(d)
   344  			t.NearestSet(dk, q)
   345  
   346  			hits := make(map[string]float64)
   347  			for _, p := range wpData {
   348  				hits[fmt.Sprint(p)] = p.Distance(q)
   349  			}
   350  
   351  			for _, p := range dk.Heap {
   352  				var finished bool
   353  				if p.Comparable != nil {
   354  					delete(hits, fmt.Sprint(p.Comparable))
   355  					c.Check(finished, check.Equals, false)
   356  					dist := p.Comparable.Distance(q)
   357  					c.Check(dist <= d, check.Equals, true, check.Commentf("Test %d: query %v found %v expect %.3f <= %.3f", i, q, p, dist, d))
   358  				} else {
   359  					finished = true
   360  				}
   361  			}
   362  
   363  			for p, dist := range hits {
   364  				c.Check(dist > d, check.Equals, true, check.Commentf("Test %d: query %v missed %v expect %.3f > %.3f", i, q, p, dist, d))
   365  			}
   366  		}
   367  	}
   368  }
   369  
   370  func (s *S) TestDo(c *check.C) {
   371  	var result Points
   372  	t := New(wpData, false)
   373  	f := func(c Comparable, _ *Bounding, _ int) (done bool) {
   374  		result = append(result, c.(Point))
   375  		return
   376  	}
   377  	killed := t.Do(f)
   378  	c.Check(result, check.DeepEquals, wpData)
   379  	c.Check(killed, check.Equals, false)
   380  }
   381  
   382  func (s *S) TestDoBounded(c *check.C) {
   383  	for _, test := range []struct {
   384  		bounds *Bounding
   385  		result Points
   386  	}{
   387  		{
   388  			nil,
   389  			wpData,
   390  		},
   391  		{
   392  			&Bounding{Point{0, 0}, Point{10, 10}},
   393  			wpData,
   394  		},
   395  		{
   396  			&Bounding{Point{3, 4}, Point{10, 10}},
   397  			Points{Point{5, 4}, Point{4, 7}, Point{9, 6}},
   398  		},
   399  		{
   400  			&Bounding{Point{3, 3}, Point{10, 10}},
   401  			Points{Point{5, 4}, Point{4, 7}, Point{9, 6}},
   402  		},
   403  		{
   404  			&Bounding{Point{0, 0}, Point{6, 5}},
   405  			Points{Point{2, 3}, Point{5, 4}},
   406  		},
   407  		{
   408  			&Bounding{Point{5, 2}, Point{7, 4}},
   409  			Points{Point{5, 4}, Point{7, 2}},
   410  		},
   411  		{
   412  			&Bounding{Point{2, 2}, Point{7, 4}},
   413  			Points{Point{2, 3}, Point{5, 4}, Point{7, 2}},
   414  		},
   415  		{
   416  			&Bounding{Point{2, 3}, Point{9, 6}},
   417  			Points{Point{2, 3}, Point{5, 4}, Point{9, 6}},
   418  		},
   419  		{
   420  			&Bounding{Point{7, 2}, Point{7, 2}},
   421  			Points{Point{7, 2}},
   422  		},
   423  	} {
   424  		var result Points
   425  		t := New(wpData, false)
   426  		f := func(c Comparable, _ *Bounding, _ int) (done bool) {
   427  			result = append(result, c.(Point))
   428  			return
   429  		}
   430  		killed := t.DoBounded(f, test.bounds)
   431  		c.Check(result, check.DeepEquals, test.result)
   432  		c.Check(killed, check.Equals, false)
   433  	}
   434  }
   435  
   436  func BenchmarkNew(b *testing.B) {
   437  	b.StopTimer()
   438  	p := make(Points, 1e5)
   439  	for i := range p {
   440  		p[i] = Point{rand.Float64(), rand.Float64(), rand.Float64()}
   441  	}
   442  	b.StartTimer()
   443  	for i := 0; i < b.N; i++ {
   444  		_ = New(p, false)
   445  	}
   446  }
   447  
   448  func BenchmarkNewBounds(b *testing.B) {
   449  	b.StopTimer()
   450  	p := make(Points, 1e5)
   451  	for i := range p {
   452  		p[i] = Point{rand.Float64(), rand.Float64(), rand.Float64()}
   453  	}
   454  	b.StartTimer()
   455  	for i := 0; i < b.N; i++ {
   456  		_ = New(p, true)
   457  	}
   458  }
   459  
   460  func BenchmarkInsert(b *testing.B) {
   461  	rand.Seed(1)
   462  	t := &Tree{}
   463  	for i := 0; i < b.N; i++ {
   464  		t.Insert(Point{rand.Float64(), rand.Float64(), rand.Float64()}, false)
   465  	}
   466  }
   467  
   468  func BenchmarkInsertBounds(b *testing.B) {
   469  	rand.Seed(1)
   470  	t := &Tree{}
   471  	for i := 0; i < b.N; i++ {
   472  		t.Insert(Point{rand.Float64(), rand.Float64(), rand.Float64()}, true)
   473  	}
   474  }
   475  
   476  func (s *S) TestBenches(c *check.C) {
   477  	c.Check(bTree.Root.isKDTree(), check.Equals, true)
   478  	for i := 0; i < 1e3; i++ {
   479  		q := Point{rand.Float64(), rand.Float64(), rand.Float64()}
   480  		p, d := bTree.Nearest(q)
   481  		ep, ed := nearest(q, bData)
   482  		c.Check(p, check.DeepEquals, ep, check.Commentf("Test %d: query %.3f expects %.3f", i, q, ep))
   483  		c.Check(d, check.Equals, ed)
   484  	}
   485  	if c.Failed() && *genDot && bTree.Len() <= *dotLimit {
   486  		err := dotFile(bTree, "TestBenches", "")
   487  		if err != nil {
   488  			c.Errorf("Dot file write failed: %v", err)
   489  		}
   490  	}
   491  }
   492  
   493  func BenchmarkNearest(b *testing.B) {
   494  	var (
   495  		r Comparable
   496  		d float64
   497  	)
   498  	for i := 0; i < b.N; i++ {
   499  		r, d = bTree.Nearest(Point{rand.Float64(), rand.Float64(), rand.Float64()})
   500  	}
   501  	_, _ = r, d
   502  }
   503  
   504  func BenchmarkNearBrute(b *testing.B) {
   505  	var (
   506  		r Comparable
   507  		d float64
   508  	)
   509  	for i := 0; i < b.N; i++ {
   510  		r, d = nearest(Point{rand.Float64(), rand.Float64(), rand.Float64()}, bData)
   511  	}
   512  	_, _ = r, d
   513  }
   514  
   515  func BenchmarkNearestSetN10(b *testing.B) {
   516  	var nk = NewNKeeper(10)
   517  	for i := 0; i < b.N; i++ {
   518  		bTree.NearestSet(nk, Point{rand.Float64(), rand.Float64(), rand.Float64()})
   519  		nk.Heap = nk.Heap[:1]
   520  		nk.Heap[0] = ComparableDist{Comparable: nil, Dist: inf}
   521  	}
   522  }
   523  
   524  func BenchmarkNearBruteN10(b *testing.B) {
   525  	var r []ComparableDist
   526  	for i := 0; i < b.N; i++ {
   527  		r = nearestN(10, Point{rand.Float64(), rand.Float64(), rand.Float64()}, bData)
   528  	}
   529  	_ = r
   530  }
   531  
   532  func dot(t *Tree, label string) string {
   533  	if t == nil {
   534  		return ""
   535  	}
   536  	var (
   537  		s      []string
   538  		follow func(*Node)
   539  	)
   540  	follow = func(n *Node) {
   541  		id := uintptr(unsafe.Pointer(n))
   542  		c := fmt.Sprintf("%d[label = \"<Left> |<Elem> %s/%.3f\\n%.3f|<Right>\"];",
   543  			id, n, n.Point.(Point)[n.Plane], *n.Bounding)
   544  		if n.Left != nil {
   545  			c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Left -> \"%d\":Elem;",
   546  				id, uintptr(unsafe.Pointer(n.Left)))
   547  			follow(n.Left)
   548  		}
   549  		if n.Right != nil {
   550  			c += fmt.Sprintf("\n\t\tedge [arrowhead=normal]; \"%d\":Right -> \"%d\":Elem;",
   551  				id, uintptr(unsafe.Pointer(n.Right)))
   552  			follow(n.Right)
   553  		}
   554  		s = append(s, c)
   555  	}
   556  	if t.Root != nil {
   557  		follow(t.Root)
   558  	}
   559  	return fmt.Sprintf("digraph %s {\n\tnode [shape=record,height=0.1];\n\t%s\n}\n",
   560  		label,
   561  		strings.Join(s, "\n\t"),
   562  	)
   563  }
   564  
   565  func dotFile(t *Tree, label, dotString string) (err error) {
   566  	if t == nil && dotString == "" {
   567  		return
   568  	}
   569  	f, err := os.Create(label + ".dot")
   570  	if err != nil {
   571  		return
   572  	}
   573  	defer f.Close()
   574  	if dotString == "" {
   575  		fmt.Fprintf(f, dot(t, label))
   576  	} else {
   577  		fmt.Fprintf(f, dotString)
   578  	}
   579  	return
   580  }