gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/spatial/kdtree/kdtree.go (about) 1 // Copyright ©2019 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package kdtree 6 7 import ( 8 "container/heap" 9 "fmt" 10 "math" 11 "sort" 12 ) 13 14 // Interface is the set of methods required for construction of efficiently 15 // searchable k-d trees. A k-d tree may be constructed without using the 16 // Interface type, but it is likely to have reduced search performance. 17 type Interface interface { 18 // Index returns the ith element of the list of points. 19 Index(i int) Comparable 20 21 // Len returns the length of the list. 22 Len() int 23 24 // Pivot partitions the list based on the dimension specified. 25 Pivot(Dim) int 26 27 // Slice returns a slice of the list using zero-based half 28 // open indexing equivalent to built-in slice indexing. 29 Slice(start, end int) Interface 30 } 31 32 // Bounder returns a bounding volume containing the list of points. Bounds may return nil. 33 type Bounder interface { 34 Bounds() *Bounding 35 } 36 37 type bounder interface { 38 Interface 39 Bounder 40 } 41 42 // Dim is an index into a point's coordinates. 43 type Dim int 44 45 // Comparable is the element interface for values stored in a k-d tree. 46 type Comparable interface { 47 // Compare returns the signed distance of a from the plane passing through 48 // b and perpendicular to the dimension d. 49 // 50 // Given c = a.Compare(b, d): 51 // c = a_d - b_d 52 // 53 Compare(Comparable, Dim) float64 54 55 // Dims returns the number of dimensions described in the Comparable. 56 Dims() int 57 58 // Distance returns the squared Euclidean distance between the receiver and 59 // the parameter. 60 Distance(Comparable) float64 61 } 62 63 // Extender is a Comparable that can increase a bounding volume to include the 64 // point represented by the Comparable. 65 type Extender interface { 66 Comparable 67 68 // Extend returns a bounding box that has been extended to include the 69 // receiver. Extend may return nil. 70 Extend(*Bounding) *Bounding 71 } 72 73 // Bounding represents a volume bounding box. 74 type Bounding struct { 75 Min, Max Comparable 76 } 77 78 // Contains returns whether c is within the volume of the Bounding. A nil Bounding 79 // returns true. 80 func (b *Bounding) Contains(c Comparable) bool { 81 if b == nil { 82 return true 83 } 84 for d := Dim(0); d < Dim(c.Dims()); d++ { 85 if c.Compare(b.Min, d) < 0 || 0 < c.Compare(b.Max, d) { 86 return false 87 } 88 } 89 return true 90 } 91 92 // Node holds a single point value in a k-d tree. 93 type Node struct { 94 Point Comparable 95 Plane Dim 96 Left, Right *Node 97 *Bounding 98 } 99 100 func (n *Node) String() string { 101 if n == nil { 102 return "<nil>" 103 } 104 return fmt.Sprintf("%.3f %d", n.Point, n.Plane) 105 } 106 107 // Tree implements a k-d tree creation and nearest neighbor search. 108 type Tree struct { 109 Root *Node 110 Count int 111 } 112 113 // New returns a k-d tree constructed from the values in p. If p is a Bounder and 114 // bounding is true, bounds are determined for each node. 115 // The ordering of elements in p may be altered after New returns. 116 func New(p Interface, bounding bool) *Tree { 117 if p, ok := p.(bounder); ok && bounding { 118 return &Tree{ 119 Root: buildBounded(p, 0, bounding), 120 Count: p.Len(), 121 } 122 } 123 return &Tree{ 124 Root: build(p, 0), 125 Count: p.Len(), 126 } 127 } 128 129 func build(p Interface, plane Dim) *Node { 130 if p.Len() == 0 { 131 return nil 132 } 133 134 piv := p.Pivot(plane) 135 d := p.Index(piv) 136 np := (plane + 1) % Dim(d.Dims()) 137 138 return &Node{ 139 Point: d, 140 Plane: plane, 141 Left: build(p.Slice(0, piv), np), 142 Right: build(p.Slice(piv+1, p.Len()), np), 143 Bounding: nil, 144 } 145 } 146 147 func buildBounded(p bounder, plane Dim, bounding bool) *Node { 148 if p.Len() == 0 { 149 return nil 150 } 151 152 piv := p.Pivot(plane) 153 d := p.Index(piv) 154 np := (plane + 1) % Dim(d.Dims()) 155 156 b := p.Bounds() 157 return &Node{ 158 Point: d, 159 Plane: plane, 160 Left: buildBounded(p.Slice(0, piv).(bounder), np, bounding), 161 Right: buildBounded(p.Slice(piv+1, p.Len()).(bounder), np, bounding), 162 Bounding: b, 163 } 164 } 165 166 // Insert adds a point to the tree, updating the bounding volumes if bounding is 167 // true, and the tree is empty or the tree already has bounding volumes stored, 168 // and c is an Extender. No rebalancing of the tree is performed. 169 func (t *Tree) Insert(c Comparable, bounding bool) { 170 t.Count++ 171 if t.Root != nil { 172 bounding = t.Root.Bounding != nil 173 } 174 if c, ok := c.(Extender); ok && bounding { 175 t.Root = t.Root.insertBounded(c, 0, bounding) 176 return 177 } else if !ok && t.Root != nil { 178 // If we are not rebounding, mark the tree as non-bounded. 179 t.Root.Bounding = nil 180 } 181 t.Root = t.Root.insert(c, 0) 182 } 183 184 func (n *Node) insert(c Comparable, d Dim) *Node { 185 if n == nil { 186 return &Node{ 187 Point: c, 188 Plane: d, 189 Bounding: nil, 190 } 191 } 192 193 d = (n.Plane + 1) % Dim(c.Dims()) 194 if c.Compare(n.Point, n.Plane) <= 0 { 195 n.Left = n.Left.insert(c, d) 196 } else { 197 n.Right = n.Right.insert(c, d) 198 } 199 200 return n 201 } 202 203 func (n *Node) insertBounded(c Extender, d Dim, bounding bool) *Node { 204 if n == nil { 205 var b *Bounding 206 if bounding { 207 b = c.Extend(b) 208 } 209 return &Node{ 210 Point: c, 211 Plane: d, 212 Bounding: b, 213 } 214 } 215 216 if bounding { 217 n.Bounding = c.Extend(n.Bounding) 218 } 219 d = (n.Plane + 1) % Dim(c.Dims()) 220 if c.Compare(n.Point, n.Plane) <= 0 { 221 n.Left = n.Left.insertBounded(c, d, bounding) 222 } else { 223 n.Right = n.Right.insertBounded(c, d, bounding) 224 } 225 226 return n 227 } 228 229 // Len returns the number of elements in the tree. 230 func (t *Tree) Len() int { return t.Count } 231 232 // Contains returns whether a Comparable is in the bounds of the tree. If no bounding has 233 // been constructed Contains returns true. 234 func (t *Tree) Contains(c Comparable) bool { 235 if t.Root.Bounding == nil { 236 return true 237 } 238 return t.Root.Contains(c) 239 } 240 241 var inf = math.Inf(1) 242 243 // Nearest returns the nearest value to the query and the distance between them. 244 func (t *Tree) Nearest(q Comparable) (Comparable, float64) { 245 if t.Root == nil { 246 return nil, inf 247 } 248 n, dist := t.Root.search(q, inf) 249 if n == nil { 250 return nil, inf 251 } 252 return n.Point, dist 253 } 254 255 func (n *Node) search(q Comparable, dist float64) (*Node, float64) { 256 if n == nil { 257 return nil, inf 258 } 259 260 c := q.Compare(n.Point, n.Plane) 261 dist = math.Min(dist, q.Distance(n.Point)) 262 263 bn := n 264 if c <= 0 { 265 ln, ld := n.Left.search(q, dist) 266 if ld < dist { 267 bn, dist = ln, ld 268 } 269 if c*c < dist { 270 rn, rd := n.Right.search(q, dist) 271 if rd < dist { 272 bn, dist = rn, rd 273 } 274 } 275 return bn, dist 276 } 277 rn, rd := n.Right.search(q, dist) 278 if rd < dist { 279 bn, dist = rn, rd 280 } 281 if c*c < dist { 282 ln, ld := n.Left.search(q, dist) 283 if ld < dist { 284 bn, dist = ln, ld 285 } 286 } 287 return bn, dist 288 } 289 290 // ComparableDist holds a Comparable and a distance to a specific query. A nil Comparable 291 // is used to mark the end of the heap, so clients should not store nil values except for 292 // this purpose. 293 type ComparableDist struct { 294 Comparable Comparable 295 Dist float64 296 } 297 298 // Heap is a max heap sorted on Dist. 299 type Heap []ComparableDist 300 301 func (h *Heap) Max() ComparableDist { return (*h)[0] } 302 func (h *Heap) Len() int { return len(*h) } 303 func (h *Heap) Less(i, j int) bool { return (*h)[i].Comparable == nil || (*h)[i].Dist > (*h)[j].Dist } 304 func (h *Heap) Swap(i, j int) { (*h)[i], (*h)[j] = (*h)[j], (*h)[i] } 305 func (h *Heap) Push(x interface{}) { (*h) = append(*h, x.(ComparableDist)) } 306 func (h *Heap) Pop() (i interface{}) { i, *h = (*h)[len(*h)-1], (*h)[:len(*h)-1]; return i } 307 308 // NKeeper is a Keeper that retains the n best ComparableDists that have been passed to Keep. 309 type NKeeper struct { 310 Heap 311 } 312 313 // NewNKeeper returns an NKeeper with the max value of the heap set to infinite distance. The 314 // returned NKeeper is able to retain at most n values. 315 func NewNKeeper(n int) *NKeeper { 316 k := NKeeper{make(Heap, 1, n)} 317 k.Heap[0].Dist = inf 318 return &k 319 } 320 321 // Keep adds c to the heap if its distance is less than the maximum value of the heap. If adding 322 // c would increase the size of the heap beyond the initial maximum length, the maximum value of 323 // the heap is dropped. 324 func (k *NKeeper) Keep(c ComparableDist) { 325 if c.Dist <= k.Heap[0].Dist { // Favour later finds to displace sentinel. 326 if len(k.Heap) == cap(k.Heap) { 327 // Changing the value of the element at index 0 and then calling Fix is equivalent to, 328 // but less expensive than, calling Pop() followed by a Push of the new value. 329 k.Heap[0] = c 330 heap.Fix(k, 0) 331 } else { 332 heap.Push(k, c) 333 } 334 } 335 } 336 337 // DistKeeper is a Keeper that retains the ComparableDists within the specified distance of the 338 // query that it is called to Keep. 339 type DistKeeper struct { 340 Heap 341 } 342 343 // NewDistKeeper returns an DistKeeper with the maximum value of the heap set to d. 344 func NewDistKeeper(d float64) *DistKeeper { return &DistKeeper{Heap{{Dist: d}}} } 345 346 // Keep adds c to the heap if its distance is less than or equal to the max value of the heap. 347 func (k *DistKeeper) Keep(c ComparableDist) { 348 if c.Dist <= k.Heap[0].Dist { 349 heap.Push(k, c) 350 } 351 } 352 353 // Keeper implements a conditional max heap sorted on the Dist field of the ComparableDist type. 354 // kd search is guided by the distance stored in the max value of the heap. 355 type Keeper interface { 356 Keep(ComparableDist) // Keep conditionally pushes the provided ComparableDist onto the heap. 357 Max() ComparableDist // Max returns the maximum element of the Keeper. 358 heap.Interface 359 } 360 361 // NearestSet finds the nearest values to the query accepted by the provided Keeper, k. 362 // k must be able to return a ComparableDist specifying the maximum acceptable distance 363 // when Max() is called, and retains the results of the search in min sorted order after 364 // the call to NearestSet returns. 365 // If a sentinel ComparableDist with a nil Comparable is used by the Keeper to mark the 366 // maximum distance, NearestSet will remove it before returning. 367 func (t *Tree) NearestSet(k Keeper, q Comparable) { 368 if t.Root == nil { 369 return 370 } 371 t.Root.searchSet(q, k) 372 373 // Check whether we have retained a sentinel 374 // and flag removal if we have. 375 removeSentinel := k.Len() != 0 && k.Max().Comparable == nil 376 377 sort.Sort(sort.Reverse(k)) 378 379 // This abuses the interface to drop the max. 380 // It is reasonable to do this because we know 381 // that the maximum value will now be at element 382 // zero, which is removed by the Pop method. 383 if removeSentinel { 384 k.Pop() 385 } 386 } 387 388 func (n *Node) searchSet(q Comparable, k Keeper) { 389 if n == nil { 390 return 391 } 392 393 c := q.Compare(n.Point, n.Plane) 394 k.Keep(ComparableDist{Comparable: n.Point, Dist: q.Distance(n.Point)}) 395 if c <= 0 { 396 n.Left.searchSet(q, k) 397 if c*c <= k.Max().Dist { 398 n.Right.searchSet(q, k) 399 } 400 return 401 } 402 n.Right.searchSet(q, k) 403 if c*c <= k.Max().Dist { 404 n.Left.searchSet(q, k) 405 } 406 } 407 408 // Operation is a function that operates on a Comparable. The bounding volume and tree depth 409 // of the point is also provided. If done is returned true, the Operation is indicating that no 410 // further work needs to be done and so the Do function should traverse no further. 411 type Operation func(Comparable, *Bounding, int) (done bool) 412 413 // Do performs fn on all values stored in the tree. A boolean is returned indicating whether the 414 // Do traversal was interrupted by an Operation returning true. If fn alters stored values' sort 415 // relationships, future tree operation behaviors are undefined. 416 func (t *Tree) Do(fn Operation) bool { 417 if t.Root == nil { 418 return false 419 } 420 return t.Root.do(fn, 0) 421 } 422 423 func (n *Node) do(fn Operation, depth int) (done bool) { 424 if n.Left != nil { 425 done = n.Left.do(fn, depth+1) 426 if done { 427 return 428 } 429 } 430 done = fn(n.Point, n.Bounding, depth) 431 if done { 432 return 433 } 434 if n.Right != nil { 435 done = n.Right.do(fn, depth+1) 436 } 437 return 438 } 439 440 // DoBounded performs fn on all values stored in the tree that are within the specified bound. 441 // If b is nil, the result is the same as a Do. A boolean is returned indicating whether the 442 // DoBounded traversal was interrupted by an Operation returning true. If fn alters stored 443 // values' sort relationships future tree operation behaviors are undefined. 444 func (t *Tree) DoBounded(b *Bounding, fn Operation) bool { 445 if t.Root == nil { 446 return false 447 } 448 if b == nil { 449 return t.Root.do(fn, 0) 450 } 451 return t.Root.doBounded(fn, b, 0) 452 } 453 454 func (n *Node) doBounded(fn Operation, b *Bounding, depth int) (done bool) { 455 if n.Left != nil && b.Min.Compare(n.Point, n.Plane) < 0 { 456 done = n.Left.doBounded(fn, b, depth+1) 457 if done { 458 return 459 } 460 } 461 if b.Contains(n.Point) { 462 done = fn(n.Point, n.Bounding, depth) 463 if done { 464 return 465 } 466 } 467 if n.Right != nil && 0 < b.Max.Compare(n.Point, n.Plane) { 468 done = n.Right.doBounded(fn, b, depth+1) 469 } 470 return 471 }