gonum.org/v1/gonum@v0.14.0/spatial/vptree/vptree.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package vptree 6 7 import ( 8 "container/heap" 9 "errors" 10 "math" 11 "sort" 12 13 "golang.org/x/exp/rand" 14 15 "gonum.org/v1/gonum/stat" 16 ) 17 18 // Comparable is the element interface for values stored in a vp-tree. 19 type Comparable interface { 20 // Distance returns the distance between the receiver and the 21 // parameter. The returned distance must satisfy the properties 22 // of distances in a metric space. 23 // 24 // - a.Distance(a) == 0 25 // - a.Distance(b) >= 0 26 // - a.Distance(b) == b.Distance(a) 27 // - a.Distance(b) <= a.Distance(c)+c.Distance(b) 28 // 29 Distance(Comparable) float64 30 } 31 32 // Point represents a point in a Euclidean k-d space that satisfies the Comparable 33 // interface. 34 type Point []float64 35 36 // Distance returns the Euclidean distance between c and the receiver. The concrete 37 // type of c must be Point. 38 func (p Point) Distance(c Comparable) float64 { 39 q := c.(Point) 40 var sum float64 41 for dim, c := range p { 42 d := c - q[dim] 43 sum += d * d 44 } 45 return math.Sqrt(sum) 46 } 47 48 // Node holds a single point value in a vantage point tree. 49 type Node struct { 50 Point Comparable 51 Radius float64 52 Closer *Node 53 Further *Node 54 } 55 56 // Tree implements a vantage point tree creation and nearest neighbor search. 57 type Tree struct { 58 Root *Node 59 Count int 60 } 61 62 // New returns a vantage point tree constructed from the values in p. The effort 63 // parameter specifies how much work should be put into optimizing the choice of 64 // vantage point. If effort is one or less, random vantage points are chosen. 65 // The order of elements in p will be altered after New returns. The src parameter 66 // provides the source of randomness for vantage point selection. If src is nil 67 // global rand package functions are used. Points in p must not be infinitely 68 // distant. 69 func New(p []Comparable, effort int, src rand.Source) (t *Tree, err error) { 70 var intn func(int) int 71 var shuf func(n int, swap func(i, j int)) 72 if src == nil { 73 intn = rand.Intn 74 shuf = rand.Shuffle 75 } else { 76 rnd := rand.New(src) 77 intn = rnd.Intn 78 shuf = rnd.Shuffle 79 } 80 b := builder{work: make([]float64, len(p)), intn: intn, shuf: shuf} 81 82 defer func() { 83 switch r := recover(); r { 84 case nil: 85 case pointAtInfinity: 86 t = nil 87 err = pointAtInfinity 88 default: 89 panic(r) 90 } 91 }() 92 93 t = &Tree{ 94 Root: b.build(p, effort), 95 Count: len(p), 96 } 97 return t, nil 98 } 99 100 var pointAtInfinity = errors.New("vptree: point at infinity") 101 102 // builder performs vp-tree construction as described for the simple vp-tree 103 // algorithm in http://pnylab.com/papers/vptree/vptree.pdf. 104 type builder struct { 105 work []float64 106 intn func(n int) int 107 shuf func(n int, swap func(i, j int)) 108 } 109 110 func (b *builder) build(s []Comparable, effort int) *Node { 111 if len(s) <= 1 { 112 if len(s) == 0 { 113 return nil 114 } 115 return &Node{Point: s[0]} 116 } 117 n := Node{Point: b.selectVantage(s, effort)} 118 radius, closer, further := b.partition(n.Point, s) 119 n.Radius = radius 120 n.Closer = b.build(closer, effort) 121 n.Further = b.build(further, effort) 122 return &n 123 } 124 125 func (b *builder) selectVantage(s []Comparable, effort int) Comparable { 126 if effort <= 1 { 127 return s[b.intn(len(s))] 128 } 129 if effort > len(s) { 130 effort = len(s) 131 } 132 var best Comparable 133 bestVar := -1.0 134 b.work = b.work[:effort] 135 choices := b.random(effort, s) 136 for _, p := range choices { 137 for i, q := range choices { 138 d := p.Distance(q) 139 if math.IsInf(d, 0) { 140 panic(pointAtInfinity) 141 } 142 b.work[i] = d 143 } 144 variance := stat.Variance(b.work, nil) 145 if variance > bestVar { 146 best, bestVar = p, variance 147 } 148 } 149 if best == nil { 150 // This should never be reached. 151 panic("vptree: could not find vantage point") 152 } 153 return best 154 } 155 156 func (b *builder) random(n int, s []Comparable) []Comparable { 157 if n >= len(s) { 158 n = len(s) 159 } 160 b.shuf(len(s), func(i, j int) { s[i], s[j] = s[j], s[i] }) 161 return s[:n] 162 } 163 164 func (b *builder) partition(v Comparable, s []Comparable) (radius float64, closer, further []Comparable) { 165 b.work = b.work[:len(s)] 166 for i, p := range s { 167 d := v.Distance(p) 168 if math.IsInf(d, 0) { 169 panic(pointAtInfinity) 170 } 171 b.work[i] = d 172 } 173 sort.Sort(byDist{dists: b.work, points: s}) 174 175 // Note that this does not conform exactly to the description 176 // in the paper which specifies d(p, s) < mu for L; in cases 177 // where the median element has a lower indexed element with 178 // the same distance from the vantage point, L will include a 179 // d(p, s) == mu. 180 // The additional work required to satisfy the algorithm is 181 // not worth doing as it has no effect on the correctness or 182 // performance of the algorithm. 183 radius = b.work[len(b.work)/2] 184 185 if len(b.work) > 1 { 186 // Remove vantage if it is present. 187 closer = s[1 : len(b.work)/2] 188 } 189 further = s[len(b.work)/2:] 190 return radius, closer, further 191 } 192 193 type byDist struct { 194 dists []float64 195 points []Comparable 196 } 197 198 func (c byDist) Len() int { return len(c.dists) } 199 func (c byDist) Less(i, j int) bool { return c.dists[i] < c.dists[j] } 200 func (c byDist) Swap(i, j int) { 201 c.dists[i], c.dists[j] = c.dists[j], c.dists[i] 202 c.points[i], c.points[j] = c.points[j], c.points[i] 203 } 204 205 // Len returns the number of elements in the tree. 206 func (t *Tree) Len() int { return t.Count } 207 208 var inf = math.Inf(1) 209 210 // Nearest returns the nearest value to the query and the distance between them. 211 func (t *Tree) Nearest(q Comparable) (Comparable, float64) { 212 if t.Root == nil { 213 return nil, inf 214 } 215 n, dist := t.Root.search(q, inf) 216 if n == nil { 217 return nil, inf 218 } 219 return n.Point, dist 220 } 221 222 func (n *Node) search(q Comparable, dist float64) (*Node, float64) { 223 if n == nil { 224 return nil, inf 225 } 226 227 d := q.Distance(n.Point) 228 dist = math.Min(dist, d) 229 230 bn := n 231 if d < n.Radius { 232 cn, cd := n.Closer.search(q, dist) 233 if cd < dist { 234 bn, dist = cn, cd 235 } 236 if d+dist >= n.Radius { 237 fn, fd := n.Further.search(q, dist) 238 if fd < dist { 239 bn, dist = fn, fd 240 } 241 } 242 } else { 243 fn, fd := n.Further.search(q, dist) 244 if fd < dist { 245 bn, dist = fn, fd 246 } 247 if d-dist <= n.Radius { 248 cn, cd := n.Closer.search(q, dist) 249 if cd < dist { 250 bn, dist = cn, cd 251 } 252 } 253 } 254 255 return bn, dist 256 } 257 258 // ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable 259 // is used to mark the end of the heap, so clients should not store nil values except for 260 // this purpose. 261 type ComparableDist struct { 262 Comparable Comparable 263 Dist float64 264 } 265 266 // Heap is a max heap sorted on Dist. 267 type Heap []ComparableDist 268 269 func (h *Heap) Max() ComparableDist { return (*h)[0] } 270 func (h *Heap) Len() int { return len(*h) } 271 func (h *Heap) Less(i, j int) bool { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist } 272 func (h *Heap) Swap(i, j int) { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] } 273 func (h *Heap) Push(x interface{}) { (*h) = append(*h, x.(ComparableDist)) } 274 func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i } 275 276 // NKeeper is a Keeper that retains the n best ComparableDists that have been passed to Keep. 277 type NKeeper struct { 278 Heap 279 } 280 281 // NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The 282 // returned NKeeper is able to retain at most n values. 283 func NewNKeeper(n int) *NKeeper { 284 k := NKeeper{make(Heap, 1, n)} 285 k.Heap[0].Dist = inf 286 return &k 287 } 288 289 // Keep adds c to the heap if its distance is less than the maximum value of the heap. If adding 290 // c would increase the size of the heap beyond the initial maximum length, the maximum value of 291 // the heap is dropped. 292 func (k *NKeeper) Keep(c ComparableDist) { 293 if c.Dist <= k.Heap[0].Dist { // Favour later finds to displace sentinel. 294 if len(k.Heap) == cap(k.Heap) { 295 heap.Pop(k) 296 } 297 heap.Push(k, c) 298 } 299 } 300 301 // DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the 302 // query that it is called to Keep. 303 type DistKeeper struct { 304 Heap 305 } 306 307 // NewDistKeeper returns an DistKeeper with the maximum value of the heap set to d. 308 func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} } 309 310 // Keep adds c to the heap if its distance is less than or equal to the max value of the heap. 311 func (k *DistKeeper) Keep(c ComparableDist) { 312 if c.Dist <= k.Heap[0].Dist { 313 heap.Push(k, c) 314 } 315 } 316 317 // Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type. 318 // vantage point search is guided by the distance stored in the max value of the heap. 319 type Keeper interface { 320 Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap. 321 Max() ComparableDist // Max returns the maximum element of the Keeper. 322 heap.Interface 323 } 324 325 // NearestSet finds the nearest values to the query accepted by the provided Keeper, k. 326 // k must be able to return a ComparableDist specifying the maximum acceptable distance 327 // when Max() is called, and retains the results of the search in min sorted order after 328 // the call to NearestSet returns. 329 // If a sentinel ComparableDist with a nil Comparable is used by the Keeper to mark the 330 // maximum distance, NearestSet will remove it before returning. 331 func (t *Tree) NearestSet(k Keeper, q Comparable) { 332 if t.Root == nil { 333 return 334 } 335 t.Root.searchSet(q, k) 336 337 // Check whether we have retained a sentinel 338 // and flag removal if we have. 339 removeSentinel := k.Len() != 0 && k.Max().Comparable == nil 340 341 sort.Sort(sort.Reverse(k)) 342 343 // This abuses the interface to drop the max. 344 // It is reasonable to do this because we know 345 // that the maximum value will now be at element 346 // zero, which is removed by the Pop method. 347 if removeSentinel { 348 k.Pop() 349 } 350 } 351 352 func (n *Node) searchSet(q Comparable, k Keeper) { 353 if n == nil { 354 return 355 } 356 357 k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)}) 358 359 d := q.Distance(n.Point) 360 if d < n.Radius { 361 n.Closer.searchSet(q, k) 362 if d+k.Max().Dist >= n.Radius { 363 n.Further.searchSet(q, k) 364 } 365 } else { 366 n.Further.searchSet(q, k) 367 if d-k.Max().Dist <= n.Radius { 368 n.Closer.searchSet(q, k) 369 } 370 } 371 } 372 373 // Operation is a function that operates on a Comparable. The bounding volume and tree depth 374 // of the point is also provided. If done is returned true, the Operation is indicating that no 375 // further work needs to be done and so the Do function should traverse no further. 376 type Operation func(Comparable, int) (done bool) 377 378 // Do performs fn on all values stored in the tree. A boolean is returned indicating whether the 379 // Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort 380 // relationships, future tree operation behaviors are undefined. 381 func (t *Tree) Do(fn Operation) bool { 382 if t.Root == nil { 383 return false 384 } 385 return t.Root.do(fn, 0) 386 } 387 388 func (n *Node) do(fn Operation, depth int) (done bool) { 389 if n.Closer != nil { 390 done = n.Closer.do(fn, depth+1) 391 if done { 392 return 393 } 394 } 395 done = fn(n.Point, depth) 396 if done { 397 return 398 } 399 if n.Further != nil { 400 done = n.Further.do(fn, depth+1) 401 } 402 return 403 }