github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/spatial/kdtree/kdtree.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package kdtree
     6  
     7  import (
     8  	"container/heap"
     9  	"fmt"
    10  	"math"
    11  	"sort"
    12  )
    13  
// Interface is the set of methods required for construction of efficiently
// searchable k-d trees. A k-d tree may be constructed without using the
// Interface type, but it is likely to have reduced search performance.
type Interface interface {
	// Index returns the ith element of the list of points.
	Index(i int) Comparable

	// Len returns the length of the list.
	Len() int

	// Pivot partitions the list based on the dimension specified and
	// returns the index of the pivot element. During construction the
	// element at the returned index becomes the point of a node, the
	// elements before it form the left subtree and the elements after
	// it form the right subtree.
	Pivot(Dim) int

	// Slice returns a slice of the list using zero-based half
	// open indexing equivalent to built-in slice indexing.
	Slice(start, end int) Interface
}
    31  
// Bounder is implemented by lists of points that can report a bounding
// volume containing all of their points. Bounds may return nil.
type Bounder interface {
	Bounds() *Bounding
}
    36  
// bounder is an Interface that is also a Bounder. New only records per-node
// bounds for lists that satisfy bounder, and the values returned by Slice on
// such a list are asserted to satisfy bounder as well.
type bounder interface {
	Interface
	Bounder
}
    41  
// Dim is an index into a point's coordinates. Dimensions are counted from
// zero up to, but not including, the value returned by Comparable.Dims.
type Dim int
    44  
// Comparable is the element interface for values stored in a k-d tree.
type Comparable interface {
	// Compare returns the signed distance of a from the plane passing through
	// b and perpendicular to the dimension d.
	//
	// Given c = a.Compare(b, d):
	//  c = a_d - b_d
	//
	// The tree squares this value and compares it against values returned
	// by Distance when deciding whether a subtree can be skipped.
	Compare(Comparable, Dim) float64

	// Dims returns the number of dimensions described in the Comparable.
	Dims() int

	// Distance returns the squared Euclidean distance between the receiver and
	// the parameter.
	Distance(Comparable) float64
}
    62  
// Extender is a Comparable that can increase a bounding volume to include the
// point represented by the Comparable.
type Extender interface {
	Comparable

	// Extend returns a bounding box that has been extended to include the
	// receiver. Extend may return nil. The tree calls Extend with a nil
	// Bounding when creating the bounds for a newly inserted leaf.
	Extend(*Bounding) *Bounding
}
    72  
// Bounding represents a volume bounding box. Min and Max are opposite
// corners of the box; Contains treats both corners as inclusive limits.
type Bounding struct {
	Min, Max Comparable
}
    77  
    78  // Contains returns whether c is within the volume of the Bounding. A nil Bounding
    79  // returns true.
    80  func (b *Bounding) Contains(c Comparable) bool {
    81  	if b == nil {
    82  		return true
    83  	}
    84  	for d := Dim(0); d < Dim(c.Dims()); d++ {
    85  		if c.Compare(b.Min, d) < 0 || 0 < c.Compare(b.Max, d) {
    86  			return false
    87  		}
    88  	}
    89  	return true
    90  }
    91  
// Node holds a single point value in a k-d tree.
type Node struct {
	// Point is the value stored at this node.
	Point Comparable
	// Plane is the dimension along which this node splits its subtrees.
	Plane Dim
	// Left and Right are the child subtrees; either may be nil.
	Left, Right *Node
	// Bounding is the bounding volume recorded for this node. It is nil
	// when bounds are not being tracked.
	*Bounding
}
    99  
   100  func (n *Node) String() string {
   101  	if n == nil {
   102  		return "<nil>"
   103  	}
   104  	return fmt.Sprintf("%.3f %d", n.Point, n.Plane)
   105  }
   106  
// Tree implements a k-d tree creation and nearest neighbor search.
type Tree struct {
	// Root is the top node of the tree; it is nil for an empty tree.
	Root *Node
	// Count is the number of points held by the tree. It is set by New and
	// incremented by Insert.
	Count int
}
   112  
   113  // New returns a k-d tree constructed from the values in p. If p is a Bounder and
   114  // bounding is true, bounds are determined for each node.
   115  // The ordering of elements in p may be altered after New returns.
   116  func New(p Interface, bounding bool) *Tree {
   117  	if p, ok := p.(bounder); ok && bounding {
   118  		return &Tree{
   119  			Root:  buildBounded(p, 0, bounding),
   120  			Count: p.Len(),
   121  		}
   122  	}
   123  	return &Tree{
   124  		Root:  build(p, 0),
   125  		Count: p.Len(),
   126  	}
   127  }
   128  
   129  func build(p Interface, plane Dim) *Node {
   130  	if p.Len() == 0 {
   131  		return nil
   132  	}
   133  
   134  	piv := p.Pivot(plane)
   135  	d := p.Index(piv)
   136  	np := (plane + 1) % Dim(d.Dims())
   137  
   138  	return &Node{
   139  		Point:    d,
   140  		Plane:    plane,
   141  		Left:     build(p.Slice(0, piv), np),
   142  		Right:    build(p.Slice(piv+1, p.Len()), np),
   143  		Bounding: nil,
   144  	}
   145  }
   146  
   147  func buildBounded(p bounder, plane Dim, bounding bool) *Node {
   148  	if p.Len() == 0 {
   149  		return nil
   150  	}
   151  
   152  	piv := p.Pivot(plane)
   153  	d := p.Index(piv)
   154  	np := (plane + 1) % Dim(d.Dims())
   155  
   156  	b := p.Bounds()
   157  	return &Node{
   158  		Point:    d,
   159  		Plane:    plane,
   160  		Left:     buildBounded(p.Slice(0, piv).(bounder), np, bounding),
   161  		Right:    buildBounded(p.Slice(piv+1, p.Len()).(bounder), np, bounding),
   162  		Bounding: b,
   163  	}
   164  }
   165  
   166  // Insert adds a point to the tree, updating the bounding volumes if bounding is
   167  // true, and the tree is empty or the tree already has bounding volumes stored,
   168  // and c is an Extender. No rebalancing of the tree is performed.
   169  func (t *Tree) Insert(c Comparable, bounding bool) {
   170  	t.Count++
   171  	if t.Root != nil {
   172  		bounding = t.Root.Bounding != nil
   173  	}
   174  	if c, ok := c.(Extender); ok && bounding {
   175  		t.Root = t.Root.insertBounded(c, 0, bounding)
   176  		return
   177  	} else if !ok && t.Root != nil {
   178  		// If we are not rebounding, mark the tree as non-bounded.
   179  		t.Root.Bounding = nil
   180  	}
   181  	t.Root = t.Root.insert(c, 0)
   182  }
   183  
   184  func (n *Node) insert(c Comparable, d Dim) *Node {
   185  	if n == nil {
   186  		return &Node{
   187  			Point:    c,
   188  			Plane:    d,
   189  			Bounding: nil,
   190  		}
   191  	}
   192  
   193  	d = (n.Plane + 1) % Dim(c.Dims())
   194  	if c.Compare(n.Point, n.Plane) <= 0 {
   195  		n.Left = n.Left.insert(c, d)
   196  	} else {
   197  		n.Right = n.Right.insert(c, d)
   198  	}
   199  
   200  	return n
   201  }
   202  
   203  func (n *Node) insertBounded(c Extender, d Dim, bounding bool) *Node {
   204  	if n == nil {
   205  		var b *Bounding
   206  		if bounding {
   207  			b = c.Extend(b)
   208  		}
   209  		return &Node{
   210  			Point:    c,
   211  			Plane:    d,
   212  			Bounding: b,
   213  		}
   214  	}
   215  
   216  	if bounding {
   217  		n.Bounding = c.Extend(n.Bounding)
   218  	}
   219  	d = (n.Plane + 1) % Dim(c.Dims())
   220  	if c.Compare(n.Point, n.Plane) <= 0 {
   221  		n.Left = n.Left.insertBounded(c, d, bounding)
   222  	} else {
   223  		n.Right = n.Right.insertBounded(c, d, bounding)
   224  	}
   225  
   226  	return n
   227  }
   228  
   229  // Len returns the number of elements in the tree.
   230  func (t *Tree) Len() int { return t.Count }
   231  
   232  // Contains returns whether a Comparable is in the bounds of the tree. If no bounding has
   233  // been constructed Contains returns true.
   234  func (t *Tree) Contains(c Comparable) bool {
   235  	if t.Root.Bounding == nil {
   236  		return true
   237  	}
   238  	return t.Root.Contains(c)
   239  }
   240  
// inf is positive infinity. It is used as the initial search bound, as the
// distance reported when no point is found, and as the sentinel distance of
// a new NKeeper.
var inf = math.Inf(1)
   242  
   243  // Nearest returns the nearest value to the query and the distance between them.
   244  func (t *Tree) Nearest(q Comparable) (Comparable, float64) {
   245  	if t.Root == nil {
   246  		return nil, inf
   247  	}
   248  	n, dist := t.Root.search(q, inf)
   249  	if n == nil {
   250  		return nil, inf
   251  	}
   252  	return n.Point, dist
   253  }
   254  
   255  func (n *Node) search(q Comparable, dist float64) (*Node, float64) {
   256  	if n == nil {
   257  		return nil, inf
   258  	}
   259  
   260  	c := q.Compare(n.Point, n.Plane)
   261  	dist = math.Min(dist, q.Distance(n.Point))
   262  
   263  	bn := n
   264  	if c <= 0 {
   265  		ln, ld := n.Left.search(q, dist)
   266  		if ld < dist {
   267  			bn, dist = ln, ld
   268  		}
   269  		if c*c < dist {
   270  			rn, rd := n.Right.search(q, dist)
   271  			if rd < dist {
   272  				bn, dist = rn, rd
   273  			}
   274  		}
   275  		return bn, dist
   276  	}
   277  	rn, rd := n.Right.search(q, dist)
   278  	if rd < dist {
   279  		bn, dist = rn, rd
   280  	}
   281  	if c*c < dist {
   282  		ln, ld := n.Left.search(q, dist)
   283  		if ld < dist {
   284  			bn, dist = ln, ld
   285  		}
   286  	}
   287  	return bn, dist
   288  }
   289  
// ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable
// is used to mark the end of the heap, so clients should not store nil values except for
// this purpose.
type ComparableDist struct {
	// Comparable is the stored value, or nil for a sentinel entry.
	Comparable Comparable
	// Dist is the distance from the query, or the limiting distance for a
	// sentinel entry.
	Dist float64
}
   297  
   298  // Heap is a max heap sorted on Dist.
   299  type Heap []ComparableDist
   300  
   301  func (h *Heap) Max() ComparableDist  { return (*h)[0] }
   302  func (h *Heap) Len() int             { return len(*h) }
   303  func (h *Heap) Less(i, j int) bool   { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist }
   304  func (h *Heap) Swap(i, j int)        { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] }
   305  func (h *Heap) Push(x interface{})   { (*h) = append(*h, x.(ComparableDist)) }
   306  func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i }
   307  
// NKeeper is a Keeper that retains the n best ComparableDists that have been passed to Keep.
// The limit n is the capacity of the embedded Heap, as set by NewNKeeper.
type NKeeper struct {
	Heap
}
   312  
   313  // NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The
   314  // returned NKeeper is able to retain at most n values.
   315  func NewNKeeper(n int) *NKeeper {
   316  	k := NKeeper{make(Heap, 1, n)}
   317  	k.Heap[0].Dist = inf
   318  	return &k
   319  }
   320  
   321  // Keep adds c to the heap if its distance is less than the maximum value of the heap. If adding
   322  // c would increase the size of the heap beyond the initial maximum length, the maximum value of
   323  // the heap is dropped.
   324  func (k *NKeeper) Keep(c ComparableDist) {
   325  	if c.Dist <= k.Heap[0].Dist { // Favour later finds to displace sentinel.
   326  		if len(k.Heap) == cap(k.Heap) {
   327  			heap.Pop(k)
   328  		}
   329  		heap.Push(k, c)
   330  	}
   331  }
   332  
// DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the
// query that it is called to Keep. The distance limit is held by a sentinel entry placed in
// the embedded Heap by NewDistKeeper.
type DistKeeper struct {
	Heap
}
   338  
   339  // NewDistKeeper returns an DistKeeper with the maximum value of the heap set to d.
   340  func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} }
   341  
   342  // Keep adds c to the heap if its distance is less than or equal to the max value of the heap.
   343  func (k *DistKeeper) Keep(c ComparableDist) {
   344  	if c.Dist <= k.Heap[0].Dist {
   345  		heap.Push(k, c)
   346  	}
   347  }
   348  
// Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type.
// kd search is guided by the distance stored in the max value of the heap: a subtree is only
// visited when the squared distance from the query to its splitting plane does not exceed
// Max().Dist.
type Keeper interface {
	Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap.
	Max() ComparableDist // Max returns the maximum element of the Keeper.
	heap.Interface
}
   356  
   357  // NearestSet finds the nearest values to the query accepted by the provided Keeper, k.
   358  // k must be able to return a ComparableDist specifying the maximum acceptable distance
   359  // when Max() is called, and retains the results of the search in min sorted order after
   360  // the call to NearestSet returns.
   361  // If a sentinel ComparableDist with a nil Comparable is used by the Keeper to mark the
   362  // maximum distance, NearestSet will remove it before returning.
   363  func (t *Tree) NearestSet(k Keeper, q Comparable) {
   364  	if t.Root == nil {
   365  		return
   366  	}
   367  	t.Root.searchSet(q, k)
   368  
   369  	// Check whether we have retained a sentinel
   370  	// and flag removal if we have.
   371  	removeSentinel := k.Len() != 0 && k.Max().Comparable == nil
   372  
   373  	sort.Sort(sort.Reverse(k))
   374  
   375  	// This abuses the interface to drop the max.
   376  	// It is reasonable to do this because we know
   377  	// that the maximum value will now be at element
   378  	// zero, which is removed by the Pop method.
   379  	if removeSentinel {
   380  		k.Pop()
   381  	}
   382  }
   383  
   384  func (n *Node) searchSet(q Comparable, k Keeper) {
   385  	if n == nil {
   386  		return
   387  	}
   388  
   389  	c := q.Compare(n.Point, n.Plane)
   390  	k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)})
   391  	if c <= 0 {
   392  		n.Left.searchSet(q, k)
   393  		if c*c <= k.Max().Dist {
   394  			n.Right.searchSet(q, k)
   395  		}
   396  		return
   397  	}
   398  	n.Right.searchSet(q, k)
   399  	if c*c <= k.Max().Dist {
   400  		n.Left.searchSet(q, k)
   401  	}
   402  }
   403  
// Operation is a function that operates on a Comparable. The bounding volume and tree depth
// of the point is also provided; the root of the tree is at depth zero and the bounding
// volume may be nil. If done is returned true, the Operation is indicating that no
// further work needs to be done and so the Do function should traverse no further.
type Operation func(Comparable, *Bounding, int) (done bool)
   408  
   409  // Do performs fn on all values stored in the tree. A boolean is returned indicating whether the
   410  // Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort
   411  // relationships, future tree operation behaviors are undefined.
   412  func (t *Tree) Do(fn Operation) bool {
   413  	if t.Root == nil {
   414  		return false
   415  	}
   416  	return t.Root.do(fn, 0)
   417  }
   418  
   419  func (n *Node) do(fn Operation, depth int) (done bool) {
   420  	if n.Left != nil {
   421  		done = n.Left.do(fn, depth+1)
   422  		if done {
   423  			return
   424  		}
   425  	}
   426  	done = fn(n.Point, n.Bounding, depth)
   427  	if done {
   428  		return
   429  	}
   430  	if n.Right != nil {
   431  		done = n.Right.do(fn, depth+1)
   432  	}
   433  	return
   434  }
   435  
   436  // DoBounded performs fn on all values stored in the tree that are within the specified bound.
   437  // If b is nil, the result is the same as a Do. A boolean is returned indicating whether the
   438  // DoBounded traversal was interrupted by an Operation returning true. If fn alters stored
   439  // values' sort relationships future tree operation behaviors are undefined.
   440  func (t *Tree) DoBounded(b *Bounding, fn Operation) bool {
   441  	if t.Root == nil {
   442  		return false
   443  	}
   444  	if b == nil {
   445  		return t.Root.do(fn, 0)
   446  	}
   447  	return t.Root.doBounded(fn, b, 0)
   448  }
   449  
   450  func (n *Node) doBounded(fn Operation, b *Bounding, depth int) (done bool) {
   451  	if n.Left != nil && b.Min.Compare(n.Point, n.Plane) < 0 {
   452  		done = n.Left.doBounded(fn, b, depth+1)
   453  		if done {
   454  			return
   455  		}
   456  	}
   457  	if b.Contains(n.Point) {
   458  		done = fn(n.Point, n.Bounding, depth)
   459  		if done {
   460  			return
   461  		}
   462  	}
   463  	if n.Right != nil && 0 < b.Max.Compare(n.Point, n.Plane) {
   464  		done = n.Right.doBounded(fn, b, depth+1)
   465  	}
   466  	return
   467  }