github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/intervalmap/intervalmap.go (about) 1 // Package intervalmap stores a set of (potentially overlapping) intervals. It 2 // supports searching for intervals that overlap user-provided interval. 3 // 4 // The implementation uses an 1-D version of Kd tree with randomized 5 // surface-area heuristic 6 // (http://www.sci.utah.edu/~wald/Publications/2007/ParallelBVHBuild/fastbuild.pdf). 7 package intervalmap 8 9 //go:generate ../gtl/generate_randomized_freepool.py --output=search_freepool --prefix=searcher --PREFIX=searcher -DELEM=*searcher --package=intervalmap 10 11 import ( 12 "bytes" 13 "encoding/gob" 14 "fmt" 15 "math" 16 "math/rand" 17 "runtime" 18 "unsafe" 19 20 "github.com/Schaudge/grailbase/log" 21 "github.com/Schaudge/grailbase/must" 22 ) 23 24 // Key is the type for interval boundaries. 25 type Key = int64 26 27 // Interval defines an half-open interval, [Start, Limit). 28 type Interval struct { 29 // Start is included 30 Start Key 31 // Limit is excluded. 32 Limit Key 33 } 34 35 var emptyInterval = Interval{math.MaxInt64, math.MinInt64} 36 37 func min(x, y Key) Key { 38 if x < y { 39 return x 40 } 41 return y 42 } 43 44 func max(x, y Key) Key { 45 if x < y { 46 return y 47 } 48 return x 49 } 50 51 // Intersects checks if (i∩j) != ∅ 52 func (i Interval) Intersects(j Interval) bool { 53 return i.Limit > j.Start && j.Limit > i.Start 54 } 55 56 // Intersect computes i ∩ j. 57 func (i Interval) Intersect(j Interval) Interval { 58 minKey := max(i.Start, j.Start) 59 maxKey := min(i.Limit, j.Limit) 60 return Interval{minKey, maxKey} 61 } 62 63 // Empty checks if the interval is empty. 64 func (i Interval) Empty() bool { return i.Start >= i.Limit } 65 66 // Span computes a minimal interval that spans over both i and j. If either i 67 // or j is an empty set, this function returns the other set. 68 func (i Interval) Span(j Interval) Interval { 69 switch { 70 case i.Empty(): 71 return j 72 case j.Empty(): 73 return i 74 default: 75 return Interval{min(i.Start, j.Start), max(i.Limit, j.Limit)} 76 } 77 } 78 79 const ( 80 maxEntsInNode = 16 // max size of node.ents. 81 ) 82 83 // Entry represents one interval. 84 type Entry struct { 85 // Interval defines a half-open interval, [Start,Limit) 86 Interval Interval 87 // Data is an arbitrary user-defined payload 88 Data interface{} 89 } 90 91 type entry struct { 92 Entry 93 id int // dense sequence number 0, 1, 2, ... 94 } 95 96 // node represents one node in Kdtree. 97 type node struct { 98 bounds Interval // interval covered by this node. 99 left, right *node // children. Maybe nil. 100 ents []*entry // Nonempty iff. left=nil&&right=nil. 101 label string // for debugging only. 102 } 103 104 // TreeStats shows tree-wide stats. 105 type TreeStats struct { 106 // Nodes is the total number of tree nodes. 107 Nodes int 108 // Nodes is the total number of leaf nodes. 109 // 110 // Invariant: LeafNodes < Nodes 111 LeafNodes int 112 // MaxDepth is the max depth of the tree. 113 MaxDepth int 114 // MaxLeafNodeSize is the maximum len(node.ents) of all nodes in the tree. 115 MaxLeafNodeSize int 116 // TotalLeafDepth is the sum of depth of all leaf nodes. 117 TotalLeafDepth int 118 // TotalLeafDepth is the sum of len(node.ents) of all leaf nodes. 119 TotalLeafNodeSize int 120 } 121 122 // T represents the intervalmap. It must be created using New(). 123 type T struct { 124 root node 125 stats TreeStats 126 pool *searcherFreePool 127 } 128 129 // New creates a new tree with the given set of entries. The intervals may 130 // overlap, and they need not be sorted. 131 func New(ents []Entry) *T { 132 entsCopy := make([]entry, len(ents)) 133 for i := range ents { 134 entsCopy[i] = entry{Entry: ents[i], id: i} 135 } 136 ients := make([]*entry, len(ents)) 137 for i := range entsCopy { 138 ients[i] = &entsCopy[i] 139 } 140 r := rand.New(rand.NewSource(0)) 141 t := &T{} 142 t.stats.MaxDepth = -1 143 t.stats.MaxLeafNodeSize = -1 144 t.root.init("", ients, keyRange(ients), r, &t.stats) 145 t.pool = newSearcherFreePool(t, len(ents)) 146 return t 147 } 148 149 func newSearcherFreePool(t *T, nEnt int) *searcherFreePool { 150 return NewsearcherFreePool(func() *searcher { 151 return &searcher{ 152 tree: t, 153 hits: make([]uint32, nEnt), 154 } 155 }, runtime.NumCPU()*2) 156 } 157 158 // searcher keeps state needed during one search episode. It is owned by one 159 // goroutine. 160 type searcher struct { 161 tree *T 162 searchID uint32 // increments on every search 163 hits []uint32 // hits[i] == searchID if the i'th entry has already been visited 164 } 165 166 func (s *searcher) visit(i int) bool { 167 if s.hits[i] != s.searchID { 168 s.hits[i] = s.searchID 169 return true 170 } 171 return false 172 } 173 174 // Stats returns tree-wide stats. 175 func (t *T) Stats() TreeStats { return t.stats } 176 177 // Get finds all the entries that intersect the given interval and return them 178 // in *ents. 179 func (t *T) Get(interval Interval, ents *[]*Entry) { 180 s := t.pool.Get() 181 s.searchID++ 182 *ents = (*ents)[:0] 183 t.root.get(interval, ents, s) 184 if s.searchID < math.MaxUint32 { 185 t.pool.Put(s) 186 } 187 } 188 189 // Any checks if any of the entries intersect the given interval. 190 func (t *T) Any(interval Interval) bool { 191 s := t.pool.Get() 192 s.searchID++ 193 found := t.root.any(interval, s) 194 if s.searchID < math.MaxUint32 { 195 t.pool.Put(s) 196 } 197 return found 198 } 199 200 func keyRange(ents []*entry) Interval { 201 i := emptyInterval 202 for _, e := range ents { 203 i = i.Span(e.Interval) 204 } 205 return i 206 } 207 208 const maxSample = 8 209 210 // randomSample picks maxSample random elements from ents[]. It shuffles ents[] 211 // in place. 212 func randomSample(ents []*entry, r *rand.Rand) []*entry { 213 if len(ents) <= maxSample { 214 return ents 215 } 216 shuffleFirstN := func(n int) { // Fisher-Yates shuffle 217 for i := 0; i < n-1; i++ { 218 j := i + r.Intn(len(ents)-i) 219 ents[i], ents[j] = ents[j], ents[i] 220 } 221 } 222 n := maxSample 223 if len(ents)-n < n { 224 // When maxSample < len(n) < maxSample*2, it's faster to compute the 225 // complement set. 226 n = len(ents) - n 227 shuffleFirstN(len(ents) - n) 228 return ents[n:] 229 } 230 shuffleFirstN(n) 231 return ents[:n] 232 } 233 234 // This function splits interval "bounds" into two balanced subintervals, 235 // [bounds.Start, mid) and [mid, bounds.Limit). left (right) will store a subset 236 // of ents[] that fits in the first (second, resp) subinterval. Note that an 237 // entry in ents[] may belong to both left and right, if the entry spans over 238 // the midpoint. 239 // 240 // Ok=false if this function fails to find a good split point. 241 func split(label string, ents []*entry, bounds Interval, r *rand.Rand) (mid Key, left []*entry, right []*entry, ok bool) { 242 // A good interval split point is guaranteed to be at one of the interval 243 // endpoints. To bound the compute time, we sample up to 16 intervals in 244 // ents[], and examine their endpoints one by one. 245 sample := randomSample(ents, r) 246 sampleRange := keyRange(sample).Intersect(bounds) 247 log.Debug.Printf("%s: Split %+v, %d ents", label, sampleRange, len(ents)) 248 if sampleRange.Empty() { 249 panic(sample) 250 } 251 var ( 252 candidates [maxSample * 2]Key 253 nCandidate int 254 ) 255 for i, e := range sample { 256 candidates[i*2] = e.Interval.Start 257 candidates[i*2+1] = e.Interval.Limit 258 nCandidate += 2 259 } 260 261 // splitAt splits ents[] into two subsets, assuming bounds is split at mid. 262 splitAt := func(ents []*entry, mid Key, left, right *[]*entry) { 263 *left = (*left)[:0] 264 *right = (*right)[:0] 265 for _, e := range ents { 266 if e.Interval.Intersects(Interval{bounds.Start, mid}) { 267 *left = append(*left, e) 268 } 269 if e.Interval.Intersects(Interval{mid, bounds.Limit}) { 270 *right = append(*right, e) 271 } 272 } 273 } 274 275 // Compute the cost of splitting at each of candidates[]. 276 // We use the surface-area heuristics. The best explanation is in 277 // the following paper: 278 // 279 // Ingo Wald, Realtime ray tracing and interactive global illumination, 280 // http://www.sci.utah.edu/~wald/Publications/2004/PhD/phd.pdf 281 // 282 // The basic idea is the following: 283 // 284 // - Assume we split the parent interval [s, e) into two intervals 285 // [s,m) and [m,e) 286 // 287 // - The cost C(x) of searching a subinterval x is roughly 288 // C(x) = (length of x) * (# of entries that intersect x). 289 // 290 // The first term is the probability that a query hits the subinterval, and 291 // the 2nd term is the cost of searching inside the subinterval. 292 // 293 // This assumes that a query is distributed uniformly over the domain (in 294 // our case, [-maxint32, maxint32]. 295 // 296 // - The best split point is m that minimizes C([s,m)) + C([m,e)) 297 minCost := math.MaxFloat64 298 var minMid Key 299 var minLeft, minRight []*entry 300 var tmpLeft, tmpRight []*entry 301 302 for _, mid := range candidates[:nCandidate] { 303 splitAt(ents, mid, &tmpLeft, &tmpRight) 304 if len(tmpLeft) == 0 || len(tmpRight) == 0 { 305 continue 306 } 307 cost := float64(len(tmpLeft))*float64(mid-sampleRange.Start) + 308 float64(len(tmpRight))*float64(sampleRange.Limit-mid) 309 if cost < minCost { 310 minMid = mid 311 minLeft, tmpLeft = tmpLeft, minLeft 312 minRight, tmpRight = tmpRight, minRight 313 minCost = cost 314 } 315 } 316 if minCost == math.MaxFloat64 || len(minLeft) == len(ents) || len(minRight) == len(ents) { 317 return 318 } 319 mid = minMid 320 left = minLeft 321 right = minRight 322 ok = true 323 return 324 } 325 326 func (n *node) init(label string, ents []*entry, bounds Interval, r *rand.Rand, stats *TreeStats) { 327 defer func() { 328 // Update the stats. 329 stats.Nodes++ 330 depth := len(n.label) 331 if depth > stats.MaxDepth { 332 stats.MaxDepth = depth 333 } 334 if e := len(n.ents); e > 0 { // Leaf node 335 stats.LeafNodes++ 336 stats.TotalLeafNodeSize += e 337 stats.TotalLeafDepth += depth 338 if e > stats.MaxLeafNodeSize { 339 stats.MaxLeafNodeSize = e 340 } 341 } 342 }() 343 344 n.label = label 345 n.bounds = bounds 346 if len(ents) <= maxEntsInNode { 347 n.ents = ents 348 return 349 } 350 mid, left, right, ok := split(n.label, ents, bounds, r) 351 if !ok { 352 n.ents = ents 353 return 354 } 355 n.left = &node{} 356 357 leftInterval := Interval{n.bounds.Start, mid} 358 leftKR := keyRange(left) 359 log.Debug.Printf("%v (bounds %v): left %v %v %v", n.label, n.bounds, leftKR, leftInterval, leftKR.Intersect(leftInterval)) 360 n.left.init(label+"L", left, leftKR.Intersect(leftInterval), r, stats) 361 n.right = &node{} 362 n.right.init(label+"R", right, keyRange(right).Intersect(Interval{mid, n.bounds.Limit}), r, stats) 363 } 364 365 func addEntry(ents *[]*Entry, e *entry, s *searcher) { 366 if s.visit(e.id) { 367 *ents = append(*ents, (*Entry)(unsafe.Pointer(e))) 368 } 369 } 370 371 func (n *node) get(interval Interval, ents *[]*Entry, s *searcher) { 372 interval = interval.Intersect(n.bounds) 373 if interval.Empty() { 374 return 375 } 376 if len(n.ents) > 0 { // Leaf node 377 for _, e := range n.ents { 378 if interval.Intersects(e.Interval) { 379 addEntry(ents, e, s) 380 } 381 } 382 return 383 } 384 n.left.get(interval, ents, s) 385 n.right.get(interval, ents, s) 386 } 387 388 func (n *node) any(interval Interval, s *searcher) bool { 389 interval = interval.Intersect(n.bounds) 390 if interval.Empty() { 391 return false 392 } 393 if len(n.ents) > 0 { // Leaf node 394 for _, e := range n.ents { 395 if interval.Intersects(e.Interval) { 396 return true 397 } 398 } 399 return false 400 } 401 found := n.left.any(interval, s) 402 if !found { 403 found = n.right.any(interval, s) 404 } 405 return found 406 } 407 408 // GOB support 409 410 const gobFormatVersion = 1 411 412 // MarshalBinary implements encoding.BinaryMarshaler interface. It allows T to 413 // be encoded and decoded using Gob. 414 func (t *T) MarshalBinary() (data []byte, err error) { 415 buf := bytes.Buffer{} 416 e := gob.NewEncoder(&buf) 417 must.Nil(e.Encode(gobFormatVersion)) 418 marshalNode(e, &t.root) 419 must.Nil(e.Encode(t.stats)) 420 return buf.Bytes(), nil 421 } 422 423 func marshalNode(e *gob.Encoder, n *node) { 424 if n == nil { 425 must.Nil(e.Encode(false)) 426 return 427 } 428 must.Nil(e.Encode(true)) 429 must.Nil(e.Encode(n.bounds)) 430 marshalNode(e, n.left) 431 marshalNode(e, n.right) 432 must.Nil(e.Encode(len(n.ents))) 433 for _, ent := range n.ents { 434 must.Nil(e.Encode(ent.Entry)) 435 must.Nil(e.Encode(ent.id)) 436 } 437 must.Nil(e.Encode(n.label)) 438 } 439 440 // UnmarshalBinary implements encoding.BinaryUnmarshaler interface. 441 // It allows T to be encoded and decoded using Gob. 442 func (t *T) UnmarshalBinary(data []byte) error { 443 buf := bytes.NewReader(data) 444 d := gob.NewDecoder(buf) 445 var version int 446 if err := d.Decode(&version); err != nil { 447 return err 448 } 449 if version != gobFormatVersion { 450 return fmt.Errorf("gob decode: got version %d, want %d", version, gobFormatVersion) 451 } 452 var ( 453 maxid = -1 454 err error 455 root *node 456 ) 457 if root, err = unmarshalNode(d, &maxid); err != nil { 458 return err 459 } 460 t.root = *root 461 if err := d.Decode(&t.stats); err != nil { 462 return err 463 } 464 t.pool = newSearcherFreePool(t, maxid+1) 465 return nil 466 } 467 468 func unmarshalNode(d *gob.Decoder, maxid *int) (*node, error) { 469 var ( 470 exist bool 471 err error 472 ) 473 if err = d.Decode(&exist); err != nil { 474 return nil, err 475 } 476 if !exist { 477 return nil, nil 478 } 479 n := &node{} 480 if err := d.Decode(&n.bounds); err != nil { 481 return nil, err 482 } 483 if n.left, err = unmarshalNode(d, maxid); err != nil { 484 return nil, err 485 } 486 if n.right, err = unmarshalNode(d, maxid); err != nil { 487 return nil, err 488 } 489 var nEnt int 490 if err := d.Decode(&nEnt); err != nil { 491 return nil, err 492 } 493 n.ents = make([]*entry, nEnt) 494 for i := 0; i < nEnt; i++ { 495 n.ents[i] = &entry{} 496 if err := d.Decode(&n.ents[i].Entry); err != nil { 497 return nil, err 498 } 499 if err := d.Decode(&n.ents[i].id); err != nil { 500 return nil, err 501 } 502 if n.ents[i].id > *maxid { 503 *maxid = n.ents[i].id 504 } 505 } 506 if err := d.Decode(&n.label); err != nil { 507 return nil, err 508 } 509 return n, nil 510 }