github.com/gopherd/gonum@v0.0.4/spatial/vptree/vptree.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package vptree
     6  
     7  import (
     8  	"container/heap"
     9  	"errors"
    10  	"math"
    11  	"sort"
    12  
    13  	"math/rand"
    14  
    15  	"github.com/gopherd/gonum/stat"
    16  )
    17  
    18  // Comparable is the element interface for values stored in a vp-tree.
    19  type Comparable interface {
    20  	// Distance returns the distance between the receiver and the
    21  	// parameter. The returned distance must satisfy the properties
    22  	// of distances in a metric space.
    23  	//
    24  	// - a.Distance(a) == 0
    25  	// - a.Distance(b) >= 0
    26  	// - a.Distance(b) == b.Distance(a)
    27  	// - a.Distance(b) <= a.Distance(c)+c.Distance(b)
    28  	//
    29  	Distance(Comparable) float64
    30  }
    31  
    32  // Point represents a point in a Euclidean k-d space that satisfies the Comparable
    33  // interface.
    34  type Point []float64
    35  
    36  // Distance returns the Euclidean distance between c and the receiver. The concrete
    37  // type of c must be Point.
    38  func (p Point) Distance(c Comparable) float64 {
    39  	q := c.(Point)
    40  	var sum float64
    41  	for dim, c := range p {
    42  		d := c - q[dim]
    43  		sum += d * d
    44  	}
    45  	return math.Sqrt(sum)
    46  }
    47  
    48  // Node holds a single point value in a vantage point tree.
    49  type Node struct {
    50  	Point   Comparable
    51  	Radius  float64
    52  	Closer  *Node
    53  	Further *Node
    54  }
    55  
    56  // Tree implements a vantage point tree creation and nearest neighbor search.
    57  type Tree struct {
    58  	Root  *Node
    59  	Count int
    60  }
    61  
    62  // New returns a vantage point tree constructed from the values in p. The effort
    63  // parameter specifies how much work should be put into optimizing the choice of
    64  // vantage point. If effort is one or less, random vantage points are chosen.
    65  // The order of elements in p will be altered after New returns. The src parameter
    66  // provides the source of randomness for vantage point selection. If src is nil
    67  // global rand package functions are used. Points in p must not be infinitely
    68  // distant.
    69  func New(p []Comparable, effort int, src rand.Source) (t *Tree, err error) {
    70  	var intn func(int) int
    71  	var shuf func(n int, swap func(i, j int))
    72  	if src == nil {
    73  		intn = rand.Intn
    74  		shuf = rand.Shuffle
    75  	} else {
    76  		rnd := rand.New(src)
    77  		intn = rnd.Intn
    78  		shuf = rnd.Shuffle
    79  	}
    80  	b := builder{work: make([]float64, len(p)), intn: intn, shuf: shuf}
    81  
    82  	defer func() {
    83  		switch r := recover(); r {
    84  		case nil:
    85  		case pointAtInfinity:
    86  			t = nil
    87  			err = pointAtInfinity
    88  		default:
    89  			panic(r)
    90  		}
    91  	}()
    92  
    93  	t = &Tree{
    94  		Root:  b.build(p, effort),
    95  		Count: len(p),
    96  	}
    97  	return t, nil
    98  }
    99  
   100  var pointAtInfinity = errors.New("vptree: point at infinity")
   101  
   102  // builder performs vp-tree construction as described for the simple vp-tree
   103  // algorithm in http://pnylab.com/papers/vptree/vptree.pdf.
   104  type builder struct {
   105  	work []float64
   106  	intn func(n int) int
   107  	shuf func(n int, swap func(i, j int))
   108  }
   109  
   110  func (b *builder) build(s []Comparable, effort int) *Node {
   111  	if len(s) <= 1 {
   112  		if len(s) == 0 {
   113  			return nil
   114  		}
   115  		return &Node{Point: s[0]}
   116  	}
   117  	n := Node{Point: b.selectVantage(s, effort)}
   118  	radius, closer, further := b.partition(n.Point, s)
   119  	n.Radius = radius
   120  	n.Closer = b.build(closer, effort)
   121  	n.Further = b.build(further, effort)
   122  	return &n
   123  }
   124  
   125  func (b *builder) selectVantage(s []Comparable, effort int) Comparable {
   126  	if effort <= 1 {
   127  		return s[b.intn(len(s))]
   128  	}
   129  	if effort > len(s) {
   130  		effort = len(s)
   131  	}
   132  	var best Comparable
   133  	bestVar := -1.0
   134  	b.work = b.work[:effort]
   135  	choices := b.random(effort, s)
   136  	for _, p := range choices {
   137  		for i, q := range choices {
   138  			d := p.Distance(q)
   139  			if math.IsInf(d, 0) {
   140  				panic(pointAtInfinity)
   141  			}
   142  			b.work[i] = d
   143  		}
   144  		variance := stat.Variance(b.work, nil)
   145  		if variance > bestVar {
   146  			best, bestVar = p, variance
   147  		}
   148  	}
   149  	if best == nil {
   150  		// This should never be reached.
   151  		panic("vptree: could not find vantage point")
   152  	}
   153  	return best
   154  }
   155  
   156  func (b *builder) random(n int, s []Comparable) []Comparable {
   157  	if n >= len(s) {
   158  		n = len(s)
   159  	}
   160  	b.shuf(len(s), func(i, j int) { s[i], s[j] = s[j], s[i] })
   161  	return s[:n]
   162  }
   163  
   164  func (b *builder) partition(v Comparable, s []Comparable) (radius float64, closer, further []Comparable) {
   165  	b.work = b.work[:len(s)]
   166  	for i, p := range s {
   167  		d := v.Distance(p)
   168  		if math.IsInf(d, 0) {
   169  			panic(pointAtInfinity)
   170  		}
   171  		b.work[i] = d
   172  	}
   173  	sort.Sort(byDist{dists: b.work, points: s})
   174  
   175  	// Note that this does not conform exactly to the description
   176  	// in the paper which specifies d(p, s) < mu for L; in cases
   177  	// where the median element has a lower indexed element with
   178  	// the same distance from the vantage point, L will include a
   179  	// d(p, s) == mu.
   180  	// The additional work required to satisfy the algorithm is
   181  	// not worth doing as it has no effect on the correctness or
   182  	// performance of the algorithm.
   183  	radius = b.work[len(b.work)/2]
   184  
   185  	if len(b.work) > 1 {
   186  		// Remove vantage if it is present.
   187  		closer = s[1 : len(b.work)/2]
   188  	}
   189  	further = s[len(b.work)/2:]
   190  	return radius, closer, further
   191  }
   192  
   193  type byDist struct {
   194  	dists  []float64
   195  	points []Comparable
   196  }
   197  
   198  func (c byDist) Len() int           { return len(c.dists) }
   199  func (c byDist) Less(i, j int) bool { return c.dists[i] < c.dists[j] }
   200  func (c byDist) Swap(i, j int) {
   201  	c.dists[i], c.dists[j] = c.dists[j], c.dists[i]
   202  	c.points[i], c.points[j] = c.points[j], c.points[i]
   203  }
   204  
   205  // Len returns the number of elements in the tree.
   206  func (t *Tree) Len() int { return t.Count }
   207  
   208  var inf = math.Inf(1)
   209  
   210  // Nearest returns the nearest value to the query and the distance between them.
   211  func (t *Tree) Nearest(q Comparable) (Comparable, float64) {
   212  	if t.Root == nil {
   213  		return nil, inf
   214  	}
   215  	n, dist := t.Root.search(q, inf)
   216  	if n == nil {
   217  		return nil, inf
   218  	}
   219  	return n.Point, dist
   220  }
   221  
   222  func (n *Node) search(q Comparable, dist float64) (*Node, float64) {
   223  	if n == nil {
   224  		return nil, inf
   225  	}
   226  
   227  	d := q.Distance(n.Point)
   228  	dist = math.Min(dist, d)
   229  
   230  	bn := n
   231  	if d < n.Radius {
   232  		cn, cd := n.Closer.search(q, dist)
   233  		if cd < dist {
   234  			bn, dist = cn, cd
   235  		}
   236  		if d+dist >= n.Radius {
   237  			fn, fd := n.Further.search(q, dist)
   238  			if fd < dist {
   239  				bn, dist = fn, fd
   240  			}
   241  		}
   242  	} else {
   243  		fn, fd := n.Further.search(q, dist)
   244  		if fd < dist {
   245  			bn, dist = fn, fd
   246  		}
   247  		if d-dist <= n.Radius {
   248  			cn, cd := n.Closer.search(q, dist)
   249  			if cd < dist {
   250  				bn, dist = cn, cd
   251  			}
   252  		}
   253  	}
   254  
   255  	return bn, dist
   256  }
   257  
   258  // ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable
   259  // is used to mark the end of the heap, so clients should not store nil values except for
   260  // this purpose.
   261  type ComparableDist struct {
   262  	Comparable Comparable
   263  	Dist       float64
   264  }
   265  
   266  // Heap is a max heap sorted on Dist.
   267  type Heap []ComparableDist
   268  
   269  func (h *Heap) Max() ComparableDist  { return (*h)[0] }
   270  func (h *Heap) Len() int             { return len(*h) }
   271  func (h *Heap) Less(i, j int) bool   { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist }
   272  func (h *Heap) Swap(i, j int)        { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] }
   273  func (h *Heap) Push(x interface{})   { (*h) = append(*h, x.(ComparableDist)) }
   274  func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i }
   275  
   276  // NKeeper is a Keeper that retains the n best ComparableDists that have been passed to Keep.
   277  type NKeeper struct {
   278  	Heap
   279  }
   280  
   281  // NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The
   282  // returned NKeeper is able to retain at most n values.
   283  func NewNKeeper(n int) *NKeeper {
   284  	k := NKeeper{make(Heap, 1, n)}
   285  	k.Heap[0].Dist = inf
   286  	return &k
   287  }
   288  
   289  // Keep adds c to the heap if its distance is less than the maximum value of the heap. If adding
   290  // c would increase the size of the heap beyond the initial maximum length, the maximum value of
   291  // the heap is dropped.
   292  func (k *NKeeper) Keep(c ComparableDist) {
   293  	if c.Dist <= k.Heap[0].Dist { // Favour later finds to displace sentinel.
   294  		if len(k.Heap) == cap(k.Heap) {
   295  			heap.Pop(k)
   296  		}
   297  		heap.Push(k, c)
   298  	}
   299  }
   300  
   301  // DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the
   302  // query that it is called to Keep.
   303  type DistKeeper struct {
   304  	Heap
   305  }
   306  
   307  // NewDistKeeper returns an DistKeeper with the maximum value of the heap set to d.
   308  func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} }
   309  
   310  // Keep adds c to the heap if its distance is less than or equal to the max value of the heap.
   311  func (k *DistKeeper) Keep(c ComparableDist) {
   312  	if c.Dist <= k.Heap[0].Dist {
   313  		heap.Push(k, c)
   314  	}
   315  }
   316  
   317  // Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type.
   318  // vantage point search is guided by the distance stored in the max value of the heap.
   319  type Keeper interface {
   320  	Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap.
   321  	Max() ComparableDist // Max returns the maximum element of the Keeper.
   322  	heap.Interface
   323  }
   324  
   325  // NearestSet finds the nearest values to the query accepted by the provided Keeper, k.
   326  // k must be able to return a ComparableDist specifying the maximum acceptable distance
   327  // when Max() is called, and retains the results of the search in min sorted order after
   328  // the call to NearestSet returns.
   329  // If a sentinel ComparableDist with a nil Comparable is used by the Keeper to mark the
   330  // maximum distance, NearestSet will remove it before returning.
   331  func (t *Tree) NearestSet(k Keeper, q Comparable) {
   332  	if t.Root == nil {
   333  		return
   334  	}
   335  	t.Root.searchSet(q, k)
   336  
   337  	// Check whether we have retained a sentinel
   338  	// and flag removal if we have.
   339  	removeSentinel := k.Len() != 0 && k.Max().Comparable == nil
   340  
   341  	sort.Sort(sort.Reverse(k))
   342  
   343  	// This abuses the interface to drop the max.
   344  	// It is reasonable to do this because we know
   345  	// that the maximum value will now be at element
   346  	// zero, which is removed by the Pop method.
   347  	if removeSentinel {
   348  		k.Pop()
   349  	}
   350  }
   351  
   352  func (n *Node) searchSet(q Comparable, k Keeper) {
   353  	if n == nil {
   354  		return
   355  	}
   356  
   357  	k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)})
   358  
   359  	d := q.Distance(n.Point)
   360  	if d < n.Radius {
   361  		n.Closer.searchSet(q, k)
   362  		if d+k.Max().Dist >= n.Radius {
   363  			n.Further.searchSet(q, k)
   364  		}
   365  	} else {
   366  		n.Further.searchSet(q, k)
   367  		if d-k.Max().Dist <= n.Radius {
   368  			n.Closer.searchSet(q, k)
   369  		}
   370  	}
   371  }
   372  
   373  // Operation is a function that operates on a Comparable. The bounding volume and tree depth
   374  // of the point is also provided. If done is returned true, the Operation is indicating that no
   375  // further work needs to be done and so the Do function should traverse no further.
   376  type Operation func(Comparable, int) (done bool)
   377  
   378  // Do performs fn on all values stored in the tree. A boolean is returned indicating whether the
   379  // Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort
   380  // relationships, future tree operation behaviors are undefined.
   381  func (t *Tree) Do(fn Operation) bool {
   382  	if t.Root == nil {
   383  		return false
   384  	}
   385  	return t.Root.do(fn, 0)
   386  }
   387  
   388  func (n *Node) do(fn Operation, depth int) (done bool) {
   389  	if n.Closer != nil {
   390  		done = n.Closer.do(fn, depth+1)
   391  		if done {
   392  			return
   393  		}
   394  	}
   395  	done = fn(n.Point, depth)
   396  	if done {
   397  		return
   398  	}
   399  	if n.Further != nil {
   400  		done = n.Further.do(fn, depth+1)
   401  	}
   402  	return
   403  }