github.com/biogo/store@v0.0.0-20201120204734-aad293a2328f/kdtree/kdtree.go (about)

     1  // Copyright ©2012 The bíogo Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package kdtree implements a k-d tree.
     6  package kdtree
     7  
     8  import (
     9  	"container/heap"
    10  	"fmt"
    11  	"math"
    12  	"sort"
    13  )
    14  
    15  type Interface interface {
    16  	// Index returns the ith element of the list of points.
    17  	Index(i int) Comparable
    18  
    19  	// Len returns the length of the list.
    20  	Len() int
    21  
    22  	// Pivot partitions the list based on the dimension specified.
    23  	Pivot(Dim) int
    24  
    25  	// Slice returns a slice of the list.
    26  	Slice(start, end int) Interface
    27  }
    28  
    29  // An Bounder returns a bounding volume containing the list of points. Bounds may return nil.
    30  type Bounder interface {
    31  	Bounds() *Bounding
    32  }
    33  
    34  type bounder interface {
    35  	Interface
    36  	Bounder
    37  }
    38  
    39  // A Dim is an index into a point's coordinates.
    40  type Dim int
    41  
    42  // A Comparable is the element interface for values stored in a k-d tree.
    43  type Comparable interface {
    44  	// Compare returns the shortest translation of the plane through b with
    45  	// normal vector along dimension d to the parallel plane through a.
    46  	//
    47  	// Given c = a.Compare(b, d):
    48  	//  c = a_d - b_d
    49  	//
    50  	Compare(Comparable, Dim) float64
    51  
    52  	// Dims returns the number of dimensions described in the Comparable.
    53  	Dims() int
    54  
    55  	// Distance returns the squared Euclidean distance between the receiver and
    56  	// the parameter.
    57  	Distance(Comparable) float64
    58  }
    59  
    60  // An Extender is a Comparable that can increase a bounding volume to include the
    61  // point represented by the Comparable.
    62  type Extender interface {
    63  	Comparable
    64  
    65  	// Extend returns a bounding box that has been extended to include the
    66  	// receiver. Extend may return nil.
    67  	Extend(*Bounding) *Bounding
    68  }
    69  
    70  // A Bounding represents a volume bounding box.
    71  type Bounding [2]Comparable
    72  
    73  // Contains returns whether c is within the volume of the Bounding. A nil Bounding
    74  // returns true.
    75  func (b *Bounding) Contains(c Comparable) bool {
    76  	if b == nil {
    77  		return true
    78  	}
    79  	for d := Dim(0); d < Dim(c.Dims()); d++ {
    80  		if c.Compare(b[0], d) < 0 || c.Compare(b[1], d) > 0 {
    81  			return false
    82  		}
    83  	}
    84  	return true
    85  }
    86  
    87  // A Node holds a single point value in a k-d tree.
    88  type Node struct {
    89  	Point       Comparable
    90  	Plane       Dim
    91  	Left, Right *Node
    92  	*Bounding
    93  }
    94  
    95  func (n *Node) String() string {
    96  	if n == nil {
    97  		return "<nil>"
    98  	}
    99  	return fmt.Sprintf("%.3f %d", n.Point, n.Plane)
   100  }
   101  
   102  // A Tree implements a k-d tree creation and nearest neighbour search.
   103  type Tree struct {
   104  	Root  *Node
   105  	Count int
   106  }
   107  
   108  // New returns a k-d tree constructed from the values in p. If p is a Bounder and
   109  // bounding is true, bounds are determined for each node.
   110  func New(p Interface, bounding bool) *Tree {
   111  	if p, ok := p.(bounder); ok && bounding {
   112  		return &Tree{
   113  			Root:  buildBounded(p, 0, bounding),
   114  			Count: p.Len(),
   115  		}
   116  	}
   117  	return &Tree{
   118  		Root:  build(p, 0),
   119  		Count: p.Len(),
   120  	}
   121  }
   122  
   123  func build(p Interface, plane Dim) *Node {
   124  	if p.Len() == 0 {
   125  		return nil
   126  	}
   127  
   128  	piv := p.Pivot(plane)
   129  	d := p.Index(piv)
   130  	np := (plane + 1) % Dim(d.Dims())
   131  
   132  	return &Node{
   133  		Point:    d,
   134  		Plane:    plane,
   135  		Left:     build(p.Slice(0, piv), np),
   136  		Right:    build(p.Slice(piv+1, p.Len()), np),
   137  		Bounding: nil,
   138  	}
   139  }
   140  
   141  func buildBounded(p bounder, plane Dim, bounding bool) *Node {
   142  	if p.Len() == 0 {
   143  		return nil
   144  	}
   145  
   146  	piv := p.Pivot(plane)
   147  	d := p.Index(piv)
   148  	np := (plane + 1) % Dim(d.Dims())
   149  
   150  	b := p.Bounds()
   151  	return &Node{
   152  		Point:    d,
   153  		Plane:    plane,
   154  		Left:     buildBounded(p.Slice(0, piv).(bounder), np, bounding),
   155  		Right:    buildBounded(p.Slice(piv+1, p.Len()).(bounder), np, bounding),
   156  		Bounding: b,
   157  	}
   158  }
   159  
   160  // Insert adds a point to the tree, updating the bounding volumes if bounding is
   161  // true, and the tree is empty or the tree already has bounding volumes stored,
   162  // and c is an Extender. No rebalancing of the tree is performed.
   163  func (t *Tree) Insert(c Comparable, bounding bool) {
   164  	t.Count++
   165  	if t.Root != nil {
   166  		bounding = t.Root.Bounding != nil
   167  	}
   168  	if c, ok := c.(Extender); ok && bounding {
   169  		t.Root = t.Root.insertBounded(c, 0, bounding)
   170  		return
   171  	} else if !ok && t.Root != nil {
   172  		// If we are not rebounding, mark the tree as non-bounded.
   173  		t.Root.Bounding = nil
   174  	}
   175  	t.Root = t.Root.insert(c, 0)
   176  }
   177  
   178  func (n *Node) insert(c Comparable, d Dim) *Node {
   179  	if n == nil {
   180  		return &Node{
   181  			Point:    c,
   182  			Plane:    d,
   183  			Bounding: nil,
   184  		}
   185  	}
   186  
   187  	d = (n.Plane + 1) % Dim(c.Dims())
   188  	if c.Compare(n.Point, n.Plane) <= 0 {
   189  		n.Left = n.Left.insert(c, d)
   190  	} else {
   191  		n.Right = n.Right.insert(c, d)
   192  	}
   193  
   194  	return n
   195  }
   196  
   197  func (n *Node) insertBounded(c Extender, d Dim, bounding bool) *Node {
   198  	if n == nil {
   199  		var b *Bounding
   200  		if bounding {
   201  			b = c.Extend(b)
   202  		}
   203  		return &Node{
   204  			Point:    c,
   205  			Plane:    d,
   206  			Bounding: b,
   207  		}
   208  	}
   209  
   210  	if bounding {
   211  		n.Bounding = c.Extend(n.Bounding)
   212  	}
   213  	d = (n.Plane + 1) % Dim(c.Dims())
   214  	if c.Compare(n.Point, n.Plane) <= 0 {
   215  		n.Left = n.Left.insertBounded(c, d, bounding)
   216  	} else {
   217  		n.Right = n.Right.insertBounded(c, d, bounding)
   218  	}
   219  
   220  	return n
   221  }
   222  
   223  // Len returns the number of elements in the tree.
   224  func (t *Tree) Len() int { return t.Count }
   225  
   226  // Contains returns whether a Comparable is in the bounds of the tree. If no bounding has
   227  // been constructed Contains returns true.
   228  func (t *Tree) Contains(c Comparable) bool {
   229  	if t.Root.Bounding == nil {
   230  		return true
   231  	}
   232  	return t.Root.Contains(c)
   233  }
   234  
   235  var inf = math.Inf(1)
   236  
   237  // Nearest returns the nearest value to the query and the distance between them.
   238  func (t *Tree) Nearest(q Comparable) (Comparable, float64) {
   239  	if t.Root == nil {
   240  		return nil, inf
   241  	}
   242  	n, dist := t.Root.search(q, inf)
   243  	if n == nil {
   244  		return nil, inf
   245  	}
   246  	return n.Point, dist
   247  }
   248  
   249  func (n *Node) search(q Comparable, dist float64) (*Node, float64) {
   250  	if n == nil {
   251  		return nil, inf
   252  	}
   253  
   254  	c := q.Compare(n.Point, n.Plane)
   255  	dist = math.Min(dist, q.Distance(n.Point))
   256  
   257  	bn := n
   258  	if c <= 0 {
   259  		ln, ld := n.Left.search(q, dist)
   260  		if ld < dist {
   261  			dist = ld
   262  			bn = ln
   263  		}
   264  		if c*c < dist {
   265  			rn, rd := n.Right.search(q, dist)
   266  			if rd < dist {
   267  				bn, dist = rn, rd
   268  			}
   269  		}
   270  		return bn, dist
   271  	}
   272  	rn, rd := n.Right.search(q, dist)
   273  	if rd < dist {
   274  		dist = rd
   275  		bn = rn
   276  	}
   277  	if c*c < dist {
   278  		ln, ld := n.Left.search(q, dist)
   279  		if ld < dist {
   280  			bn, dist = ln, ld
   281  		}
   282  	}
   283  	return bn, dist
   284  }
   285  
   286  // ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable
   287  // is used to mark the end of the heap, so clients should not store nil values except for
   288  // this purpose.
   289  type ComparableDist struct {
   290  	Comparable Comparable
   291  	Dist       float64
   292  }
   293  
   294  // Heap is a max heap sorted on Dist.
   295  type Heap []ComparableDist
   296  
   297  func (h *Heap) Max() ComparableDist  { return (*h)[0] }
   298  func (h *Heap) Len() int             { return len(*h) }
   299  func (h *Heap) Less(i, j int) bool   { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist }
   300  func (h *Heap) Swap(i, j int)        { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] }
   301  func (h *Heap) Push(x interface{})   { (*h) = append(*h, x.(ComparableDist)) }
   302  func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i }
   303  
   304  // NKeeper is a Keeper that retains the n best ComparableDists that it is called to Keep.
   305  type NKeeper struct {
   306  	Heap
   307  }
   308  
   309  // NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The
   310  // returned NKeeper is able to retain at most n values.
   311  func NewNKeeper(n int) *NKeeper {
   312  	k := NKeeper{make(Heap, 1, n)}
   313  	k.Heap[0].Dist = inf
   314  	return &k
   315  }
   316  
   317  // Keep add c to the heap if its distance is less than the maximum value of the heap. If adding
   318  // c would increase the size of the heap beyond the initial maximum length, the maximum value of
   319  // the heap is dropped.
   320  func (k *NKeeper) Keep(c ComparableDist) {
   321  	if c.Dist < k.Heap[0].Dist {
   322  		if len(k.Heap) == cap(k.Heap) {
   323  			heap.Pop(k)
   324  		}
   325  		heap.Push(k, c)
   326  	}
   327  }
   328  
   329  // DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the
   330  // query that it is called to Keep.
   331  type DistKeeper struct {
   332  	Heap
   333  }
   334  
   335  // NewDistKeeper returns an DistKeeper with the max value of the heap set to d.
   336  func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} }
   337  
   338  // Keep adds c to the heap if its distance is less than or equal to the max value of the heap.
   339  func (k *DistKeeper) Keep(c ComparableDist) {
   340  	if c.Dist <= k.Heap[0].Dist {
   341  		heap.Push(k, c)
   342  	}
   343  }
   344  
   345  // Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type.
   346  // kd search is guided by the distance stored in the max value of the heap.
   347  type Keeper interface {
   348  	Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap.
   349  	Max() ComparableDist // Max returns the maximum element of the Keeper.
   350  	heap.Interface
   351  }
   352  
   353  // NearestSet finds the nearest values to the query accepted by the provided Keeper, k.
   354  // k must be able to return a ComparableDist specifying the maximum acceptable distance
   355  // when Max() is called, and retains the results of the search in min sorted order after
   356  // the call to NearestSet returns.
   357  func (t *Tree) NearestSet(k Keeper, q Comparable) {
   358  	if t.Root == nil {
   359  		return
   360  	}
   361  	t.Root.searchSet(q, k)
   362  
   363  	// Check whether we have retained a sentinel
   364  	// and flag removal if we have.
   365  	removeSentinel := k.Len() != 0 && k.Max().Comparable == nil
   366  
   367  	sort.Sort(sort.Reverse(k))
   368  
   369  	// This abuses the interface to drop the max.
   370  	// It is reasonable to do this because we know
   371  	// that the maximum value will now be at element
   372  	// zero, which is removed by the Pop method.
   373  	if removeSentinel {
   374  		k.Pop()
   375  	}
   376  }
   377  
   378  func (n *Node) searchSet(q Comparable, k Keeper) {
   379  	if n == nil {
   380  		return
   381  	}
   382  
   383  	c := q.Compare(n.Point, n.Plane)
   384  	k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)})
   385  	if c <= 0 {
   386  		n.Left.searchSet(q, k)
   387  		if c*c <= k.Max().Dist {
   388  			n.Right.searchSet(q, k)
   389  		}
   390  		return
   391  	}
   392  	n.Right.searchSet(q, k)
   393  	if c*c <= k.Max().Dist {
   394  		n.Left.searchSet(q, k)
   395  	}
   396  	return
   397  }
   398  
   399  // An Operation is a function that operates on a Comparable. The bounding volume and tree depth
   400  // of the point is also provided. If done is returned true, the Operation is indicating that no
   401  // further work needs to be done and so the Do function should traverse no further.
   402  type Operation func(Comparable, *Bounding, int) (done bool)
   403  
   404  // Do performs fn on all values stored in the tree. A boolean is returned indicating whether the
   405  // Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort
   406  // relationships, future tree operation behaviors are undefined.
   407  func (t *Tree) Do(fn Operation) bool {
   408  	if t.Root == nil {
   409  		return false
   410  	}
   411  	return t.Root.do(fn, 0)
   412  }
   413  
   414  func (n *Node) do(fn Operation, depth int) (done bool) {
   415  	if n.Left != nil {
   416  		done = n.Left.do(fn, depth+1)
   417  		if done {
   418  			return
   419  		}
   420  	}
   421  	done = fn(n.Point, n.Bounding, depth)
   422  	if done {
   423  		return
   424  	}
   425  	if n.Right != nil {
   426  		done = n.Right.do(fn, depth+1)
   427  	}
   428  	return
   429  }
   430  
   431  // DoBounded performs fn on all values stored in the tree that are within the specified bound.
   432  // If b is nil, the result is the same as a Do. A boolean is returned indicating whether the
   433  // DoBounded traversal was interrupted by an Operation returning true. If fn alters stored
   434  // values' sort relationships future tree operation behaviors are undefined.
   435  func (t *Tree) DoBounded(fn Operation, b *Bounding) bool {
   436  	if t.Root == nil {
   437  		return false
   438  	}
   439  	if b == nil {
   440  		return t.Root.do(fn, 0)
   441  	}
   442  	return t.Root.doBounded(fn, b, 0)
   443  }
   444  
   445  func (n *Node) doBounded(fn Operation, b *Bounding, depth int) (done bool) {
   446  	if n.Left != nil && b[0].Compare(n.Point, n.Plane) < 0 {
   447  		done = n.Left.doBounded(fn, b, depth+1)
   448  		if done {
   449  			return
   450  		}
   451  	}
   452  	if b.Contains(n.Point) {
   453  		done = fn(n.Point, b, depth)
   454  		if done {
   455  			return
   456  		}
   457  	}
   458  	if n.Right != nil && 0 < b[1].Compare(n.Point, n.Plane) {
   459  		done = n.Right.doBounded(fn, b, depth+1)
   460  	}
   461  	return
   462  }