github.com/biogo/store@v0.0.0-20201120204734-aad293a2328f/kdtree/kdtree.go (about) 1 // Copyright ©2012 The bíogo Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package kdtree implements a k-d tree. 6 package kdtree 7 8 import ( 9 "container/heap" 10 "fmt" 11 "math" 12 "sort" 13 ) 14 15 type Interface interface { 16 // Index returns the ith element of the list of points. 17 Index(i int) Comparable 18 19 // Len returns the length of the list. 20 Len() int 21 22 // Pivot partitions the list based on the dimension specified. 23 Pivot(Dim) int 24 25 // Slice returns a slice of the list. 26 Slice(start, end int) Interface 27 } 28 29 // An Bounder returns a bounding volume containing the list of points. Bounds may return nil. 30 type Bounder interface { 31 Bounds() *Bounding 32 } 33 34 type bounder interface { 35 Interface 36 Bounder 37 } 38 39 // A Dim is an index into a point's coordinates. 40 type Dim int 41 42 // A Comparable is the element interface for values stored in a k-d tree. 43 type Comparable interface { 44 // Compare returns the shortest translation of the plane through b with 45 // normal vector along dimension d to the parallel plane through a. 46 // 47 // Given c = a.Compare(b, d): 48 // c = a_d - b_d 49 // 50 Compare(Comparable, Dim) float64 51 52 // Dims returns the number of dimensions described in the Comparable. 53 Dims() int 54 55 // Distance returns the squared Euclidean distance between the receiver and 56 // the parameter. 57 Distance(Comparable) float64 58 } 59 60 // An Extender is a Comparable that can increase a bounding volume to include the 61 // point represented by the Comparable. 62 type Extender interface { 63 Comparable 64 65 // Extend returns a bounding box that has been extended to include the 66 // receiver. Extend may return nil. 67 Extend(*Bounding) *Bounding 68 } 69 70 // A Bounding represents a volume bounding box. 71 type Bounding [2]Comparable 72 73 // Contains returns whether c is within the volume of the Bounding. A nil Bounding 74 // returns true. 75 func (b *Bounding) Contains(c Comparable) bool { 76 if b == nil { 77 return true 78 } 79 for d := Dim(0); d < Dim(c.Dims()); d++ { 80 if c.Compare(b[0], d) < 0 || c.Compare(b[1], d) > 0 { 81 return false 82 } 83 } 84 return true 85 } 86 87 // A Node holds a single point value in a k-d tree. 88 type Node struct { 89 Point Comparable 90 Plane Dim 91 Left, Right *Node 92 *Bounding 93 } 94 95 func (n *Node) String() string { 96 if n == nil { 97 return "<nil>" 98 } 99 return fmt.Sprintf("%.3f %d", n.Point, n.Plane) 100 } 101 102 // A Tree implements a k-d tree creation and nearest neighbour search. 103 type Tree struct { 104 Root *Node 105 Count int 106 } 107 108 // New returns a k-d tree constructed from the values in p. If p is a Bounder and 109 // bounding is true, bounds are determined for each node. 110 func New(p Interface, bounding bool) *Tree { 111 if p, ok := p.(bounder); ok && bounding { 112 return &Tree{ 113 Root: buildBounded(p, 0, bounding), 114 Count: p.Len(), 115 } 116 } 117 return &Tree{ 118 Root: build(p, 0), 119 Count: p.Len(), 120 } 121 } 122 123 func build(p Interface, plane Dim) *Node { 124 if p.Len() == 0 { 125 return nil 126 } 127 128 piv := p.Pivot(plane) 129 d := p.Index(piv) 130 np := (plane + 1) % Dim(d.Dims()) 131 132 return &Node{ 133 Point: d, 134 Plane: plane, 135 Left: build(p.Slice(0, piv), np), 136 Right: build(p.Slice(piv+1, p.Len()), np), 137 Bounding: nil, 138 } 139 } 140 141 func buildBounded(p bounder, plane Dim, bounding bool) *Node { 142 if p.Len() == 0 { 143 return nil 144 } 145 146 piv := p.Pivot(plane) 147 d := p.Index(piv) 148 np := (plane + 1) % Dim(d.Dims()) 149 150 b := p.Bounds() 151 return &Node{ 152 Point: d, 153 Plane: plane, 154 Left: buildBounded(p.Slice(0, piv).(bounder), np, bounding), 155 Right: buildBounded(p.Slice(piv+1, p.Len()).(bounder), np, bounding), 156 Bounding: b, 157 } 158 } 159 160 // Insert adds a point to the tree, updating the bounding volumes if bounding is 161 // true, and the tree is empty or the tree already has bounding volumes stored, 162 // and c is an Extender. No rebalancing of the tree is performed. 163 func (t *Tree) Insert(c Comparable, bounding bool) { 164 t.Count++ 165 if t.Root != nil { 166 bounding = t.Root.Bounding != nil 167 } 168 if c, ok := c.(Extender); ok && bounding { 169 t.Root = t.Root.insertBounded(c, 0, bounding) 170 return 171 } else if !ok && t.Root != nil { 172 // If we are not rebounding, mark the tree as non-bounded. 173 t.Root.Bounding = nil 174 } 175 t.Root = t.Root.insert(c, 0) 176 } 177 178 func (n *Node) insert(c Comparable, d Dim) *Node { 179 if n == nil { 180 return &Node{ 181 Point: c, 182 Plane: d, 183 Bounding: nil, 184 } 185 } 186 187 d = (n.Plane + 1) % Dim(c.Dims()) 188 if c.Compare(n.Point, n.Plane) <= 0 { 189 n.Left = n.Left.insert(c, d) 190 } else { 191 n.Right = n.Right.insert(c, d) 192 } 193 194 return n 195 } 196 197 func (n *Node) insertBounded(c Extender, d Dim, bounding bool) *Node { 198 if n == nil { 199 var b *Bounding 200 if bounding { 201 b = c.Extend(b) 202 } 203 return &Node{ 204 Point: c, 205 Plane: d, 206 Bounding: b, 207 } 208 } 209 210 if bounding { 211 n.Bounding = c.Extend(n.Bounding) 212 } 213 d = (n.Plane + 1) % Dim(c.Dims()) 214 if c.Compare(n.Point, n.Plane) <= 0 { 215 n.Left = n.Left.insertBounded(c, d, bounding) 216 } else { 217 n.Right = n.Right.insertBounded(c, d, bounding) 218 } 219 220 return n 221 } 222 223 // Len returns the number of elements in the tree. 224 func (t *Tree) Len() int { return t.Count } 225 226 // Contains returns whether a Comparable is in the bounds of the tree. If no bounding has 227 // been constructed Contains returns true. 228 func (t *Tree) Contains(c Comparable) bool { 229 if t.Root.Bounding == nil { 230 return true 231 } 232 return t.Root.Contains(c) 233 } 234 235 var inf = math.Inf(1) 236 237 // Nearest returns the nearest value to the query and the distance between them. 238 func (t *Tree) Nearest(q Comparable) (Comparable, float64) { 239 if t.Root == nil { 240 return nil, inf 241 } 242 n, dist := t.Root.search(q, inf) 243 if n == nil { 244 return nil, inf 245 } 246 return n.Point, dist 247 } 248 249 func (n *Node) search(q Comparable, dist float64) (*Node, float64) { 250 if n == nil { 251 return nil, inf 252 } 253 254 c := q.Compare(n.Point, n.Plane) 255 dist = math.Min(dist, q.Distance(n.Point)) 256 257 bn := n 258 if c <= 0 { 259 ln, ld := n.Left.search(q, dist) 260 if ld < dist { 261 dist = ld 262 bn = ln 263 } 264 if c*c < dist { 265 rn, rd := n.Right.search(q, dist) 266 if rd < dist { 267 bn, dist = rn, rd 268 } 269 } 270 return bn, dist 271 } 272 rn, rd := n.Right.search(q, dist) 273 if rd < dist { 274 dist = rd 275 bn = rn 276 } 277 if c*c < dist { 278 ln, ld := n.Left.search(q, dist) 279 if ld < dist { 280 bn, dist = ln, ld 281 } 282 } 283 return bn, dist 284 } 285 286 // ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable 287 // is used to mark the end of the heap, so clients should not store nil values except for 288 // this purpose. 289 type ComparableDist struct { 290 Comparable Comparable 291 Dist float64 292 } 293 294 // Heap is a max heap sorted on Dist. 295 type Heap []ComparableDist 296 297 func (h *Heap) Max() ComparableDist { return (*h)[0] } 298 func (h *Heap) Len() int { return len(*h) } 299 func (h *Heap) Less(i, j int) bool { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist } 300 func (h *Heap) Swap(i, j int) { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] } 301 func (h *Heap) Push(x interface{}) { (*h) = append(*h, x.(ComparableDist)) } 302 func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i } 303 304 // NKeeper is a Keeper that retains the n best ComparableDists that it is called to Keep. 305 type NKeeper struct { 306 Heap 307 } 308 309 // NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The 310 // returned NKeeper is able to retain at most n values. 311 func NewNKeeper(n int) *NKeeper { 312 k := NKeeper{make(Heap, 1, n)} 313 k.Heap[0].Dist = inf 314 return &k 315 } 316 317 // Keep add c to the heap if its distance is less than the maximum value of the heap. If adding 318 // c would increase the size of the heap beyond the initial maximum length, the maximum value of 319 // the heap is dropped. 320 func (k *NKeeper) Keep(c ComparableDist) { 321 if c.Dist < k.Heap[0].Dist { 322 if len(k.Heap) == cap(k.Heap) { 323 heap.Pop(k) 324 } 325 heap.Push(k, c) 326 } 327 } 328 329 // DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the 330 // query that it is called to Keep. 331 type DistKeeper struct { 332 Heap 333 } 334 335 // NewDistKeeper returns an DistKeeper with the max value of the heap set to d. 336 func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} } 337 338 // Keep adds c to the heap if its distance is less than or equal to the max value of the heap. 339 func (k *DistKeeper) Keep(c ComparableDist) { 340 if c.Dist <= k.Heap[0].Dist { 341 heap.Push(k, c) 342 } 343 } 344 345 // Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type. 346 // kd search is guided by the distance stored in the max value of the heap. 347 type Keeper interface { 348 Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap. 349 Max() ComparableDist // Max returns the maximum element of the Keeper. 350 heap.Interface 351 } 352 353 // NearestSet finds the nearest values to the query accepted by the provided Keeper, k. 354 // k must be able to return a ComparableDist specifying the maximum acceptable distance 355 // when Max() is called, and retains the results of the search in min sorted order after 356 // the call to NearestSet returns. 357 func (t *Tree) NearestSet(k Keeper, q Comparable) { 358 if t.Root == nil { 359 return 360 } 361 t.Root.searchSet(q, k) 362 363 // Check whether we have retained a sentinel 364 // and flag removal if we have. 365 removeSentinel := k.Len() != 0 && k.Max().Comparable == nil 366 367 sort.Sort(sort.Reverse(k)) 368 369 // This abuses the interface to drop the max. 370 // It is reasonable to do this because we know 371 // that the maximum value will now be at element 372 // zero, which is removed by the Pop method. 373 if removeSentinel { 374 k.Pop() 375 } 376 } 377 378 func (n *Node) searchSet(q Comparable, k Keeper) { 379 if n == nil { 380 return 381 } 382 383 c := q.Compare(n.Point, n.Plane) 384 k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)}) 385 if c <= 0 { 386 n.Left.searchSet(q, k) 387 if c*c <= k.Max().Dist { 388 n.Right.searchSet(q, k) 389 } 390 return 391 } 392 n.Right.searchSet(q, k) 393 if c*c <= k.Max().Dist { 394 n.Left.searchSet(q, k) 395 } 396 return 397 } 398 399 // An Operation is a function that operates on a Comparable. The bounding volume and tree depth 400 // of the point is also provided. If done is returned true, the Operation is indicating that no 401 // further work needs to be done and so the Do function should traverse no further. 402 type Operation func(Comparable, *Bounding, int) (done bool) 403 404 // Do performs fn on all values stored in the tree. A boolean is returned indicating whether the 405 // Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort 406 // relationships, future tree operation behaviors are undefined. 407 func (t *Tree) Do(fn Operation) bool { 408 if t.Root == nil { 409 return false 410 } 411 return t.Root.do(fn, 0) 412 } 413 414 func (n *Node) do(fn Operation, depth int) (done bool) { 415 if n.Left != nil { 416 done = n.Left.do(fn, depth+1) 417 if done { 418 return 419 } 420 } 421 done = fn(n.Point, n.Bounding, depth) 422 if done { 423 return 424 } 425 if n.Right != nil { 426 done = n.Right.do(fn, depth+1) 427 } 428 return 429 } 430 431 // DoBounded performs fn on all values stored in the tree that are within the specified bound. 432 // If b is nil, the result is the same as a Do. A boolean is returned indicating whether the 433 // DoBounded traversal was interrupted by an Operation returning true. If fn alters stored 434 // values' sort relationships future tree operation behaviors are undefined. 435 func (t *Tree) DoBounded(fn Operation, b *Bounding) bool { 436 if t.Root == nil { 437 return false 438 } 439 if b == nil { 440 return t.Root.do(fn, 0) 441 } 442 return t.Root.doBounded(fn, b, 0) 443 } 444 445 func (n *Node) doBounded(fn Operation, b *Bounding, depth int) (done bool) { 446 if n.Left != nil && b[0].Compare(n.Point, n.Plane) < 0 { 447 done = n.Left.doBounded(fn, b, depth+1) 448 if done { 449 return 450 } 451 } 452 if b.Contains(n.Point) { 453 done = fn(n.Point, b, depth) 454 if done { 455 return 456 } 457 } 458 if n.Right != nil && 0 < b[1].Compare(n.Point, n.Plane) { 459 done = n.Right.doBounded(fn, b, depth+1) 460 } 461 return 462 }