github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/intervalmap/intervalmap.go (about)

// Package intervalmap stores a set of (potentially overlapping) intervals.  It
// supports searching for intervals that overlap a user-provided interval.
//
// The implementation uses a 1-D version of a k-d tree built with a randomized
// surface-area heuristic
// (http://www.sci.utah.edu/~wald/Publications/2007/ParallelBVHBuild/fastbuild.pdf).
     7  package intervalmap
     8  
     9  //go:generate ../gtl/generate_randomized_freepool.py --output=search_freepool --prefix=searcher --PREFIX=searcher -DELEM=*searcher --package=intervalmap
    10  
    11  import (
    12  	"bytes"
    13  	"encoding/gob"
    14  	"fmt"
    15  	"math"
    16  	"math/rand"
    17  	"runtime"
    18  	"unsafe"
    19  
    20  	"github.com/Schaudge/grailbase/log"
    21  	"github.com/Schaudge/grailbase/must"
    22  )
    23  
// Key is the type for interval boundaries.
type Key = int64

// Interval defines a half-open interval, [Start, Limit).  An interval whose
// Start >= Limit is empty (see Empty).
type Interval struct {
	// Start is included.
	Start Key
	// Limit is excluded.
	Limit Key
}

// emptyInterval is the identity element for Interval.Span: its Start is the
// largest and its Limit the smallest possible Key, so it is Empty() and
// spanning it with any interval x yields x.  keyRange starts from it.
var emptyInterval = Interval{math.MaxInt64, math.MinInt64}
    36  
    37  func min(x, y Key) Key {
    38  	if x < y {
    39  		return x
    40  	}
    41  	return y
    42  }
    43  
    44  func max(x, y Key) Key {
    45  	if x < y {
    46  		return y
    47  	}
    48  	return x
    49  }
    50  
    51  // Intersects checks if (i∩j) != ∅
    52  func (i Interval) Intersects(j Interval) bool {
    53  	return i.Limit > j.Start && j.Limit > i.Start
    54  }
    55  
    56  // Intersect computes i ∩ j.
    57  func (i Interval) Intersect(j Interval) Interval {
    58  	minKey := max(i.Start, j.Start)
    59  	maxKey := min(i.Limit, j.Limit)
    60  	return Interval{minKey, maxKey}
    61  }
    62  
    63  // Empty checks if the interval is empty.
    64  func (i Interval) Empty() bool { return i.Start >= i.Limit }
    65  
    66  // Span computes a minimal interval that spans over both i and j.  If either i
    67  // or j is an empty set, this function returns the other set.
    68  func (i Interval) Span(j Interval) Interval {
    69  	switch {
    70  	case i.Empty():
    71  		return j
    72  	case j.Empty():
    73  		return i
    74  	default:
    75  		return Interval{min(i.Start, j.Start), max(i.Limit, j.Limit)}
    76  	}
    77  }
    78  
    79  const (
    80  	maxEntsInNode = 16 // max size of node.ents.
    81  )
    82  
// Entry represents one interval stored in (or returned from) the map,
// together with a caller-supplied payload.
type Entry struct {
	// Interval defines a half-open interval, [Start,Limit)
	Interval Interval
	// Data is an arbitrary user-defined payload; this package never inspects
	// it.  T.MarshalBinary gob-encodes each Entry, so for that to work the
	// concrete type of Data presumably must be registered with gob.Register
	// (per encoding/gob's rules for interface values) — confirm for callers.
	Data interface{}
}
    90  
// entry is the tree's private copy of a caller-supplied Entry.
type entry struct {
	// Entry must remain the first field: addEntry converts an *entry to an
	// *Entry via unsafe.Pointer, which is only valid while Entry is at
	// offset 0 of this struct.
	Entry
	// id is the entry's index in the slice given to New (dense sequence
	// number 0, 1, 2, ...); searchers use it to index their hits table.
	id int
}
    95  
// node represents one node in the 1-D k-d tree.
type node struct {
	// bounds is the key range searched at this node; get and any clip the
	// query to it and stop as soon as the clipped query is empty.
	bounds Interval
	// left and right are the children.  node.init sets both (interior node)
	// or neither (leaf).
	left, right *node
	// ents holds the entries of a leaf node.  Nonempty iff left == nil &&
	// right == nil, except for the root of a tree built from zero entries,
	// which is a leaf with no ents.
	ents []*entry
	// label is the path from the root, one 'L' or 'R' per level ("" for the
	// root).  It is used for debug logging, and node.init derives the node's
	// depth from len(label) for TreeStats.
	label string
}
   103  
// TreeStats shows tree-wide stats, collected while the tree is built.
type TreeStats struct {
	// Nodes is the total number of tree nodes.
	Nodes int
	// LeafNodes is the total number of leaf nodes that hold at least one
	// entry.
	//
	// Invariant: LeafNodes <= Nodes (equal when the root is the only node).
	LeafNodes int
	// MaxDepth is the max depth of the tree; the root has depth 0.
	MaxDepth int
	// MaxLeafNodeSize is the maximum len(node.ents) of all nodes in the tree.
	// New initializes it to -1, and it stays -1 if no node holds entries.
	MaxLeafNodeSize int
	// TotalLeafDepth is the sum of depth of all leaf nodes counted in
	// LeafNodes.
	TotalLeafDepth int
	// TotalLeafNodeSize is the sum of len(node.ents) of all leaf nodes.
	TotalLeafNodeSize int
}
   121  
// T represents the intervalmap. It must be created using New() (or populated
// by UnmarshalBinary, which sets every field).
type T struct {
	// root is the root of the tree; it is stored by value.
	root node
	// stats is computed while building the tree and returned by Stats.
	stats TreeStats
	// pool recycles per-search scratch state (searchers) for Get and Any.
	pool *searcherFreePool
}
   128  
   129  // New creates a new tree with the given set of entries.  The intervals may
   130  // overlap, and they need not be sorted.
   131  func New(ents []Entry) *T {
   132  	entsCopy := make([]entry, len(ents))
   133  	for i := range ents {
   134  		entsCopy[i] = entry{Entry: ents[i], id: i}
   135  	}
   136  	ients := make([]*entry, len(ents))
   137  	for i := range entsCopy {
   138  		ients[i] = &entsCopy[i]
   139  	}
   140  	r := rand.New(rand.NewSource(0))
   141  	t := &T{}
   142  	t.stats.MaxDepth = -1
   143  	t.stats.MaxLeafNodeSize = -1
   144  	t.root.init("", ients, keyRange(ients), r, &t.stats)
   145  	t.pool = newSearcherFreePool(t, len(ents))
   146  	return t
   147  }
   148  
   149  func newSearcherFreePool(t *T, nEnt int) *searcherFreePool {
   150  	return NewsearcherFreePool(func() *searcher {
   151  		return &searcher{
   152  			tree: t,
   153  			hits: make([]uint32, nEnt),
   154  		}
   155  	}, runtime.NumCPU()*2)
   156  }
   157  
// searcher keeps state needed during one search episode.  It is owned by one
// goroutine.
type searcher struct {
	// tree is the tree this searcher was allocated for.
	tree *T
	// searchID increments on every search.
	searchID uint32
	// hits[i] == searchID if the entry with id i has already been visited in
	// the current search.  Values left by earlier searches never equal the
	// current searchID, so hits need not be cleared between searches as long
	// as searchID does not wrap around (T.Get and T.Any stop reusing a
	// searcher once its searchID reaches math.MaxUint32).
	hits []uint32
}
   165  
   166  func (s *searcher) visit(i int) bool {
   167  	if s.hits[i] != s.searchID {
   168  		s.hits[i] = s.searchID
   169  		return true
   170  	}
   171  	return false
   172  }
   173  
   174  // Stats returns tree-wide stats.
   175  func (t *T) Stats() TreeStats { return t.stats }
   176  
   177  // Get finds all the entries that intersect the given interval and return them
   178  // in *ents.
   179  func (t *T) Get(interval Interval, ents *[]*Entry) {
   180  	s := t.pool.Get()
   181  	s.searchID++
   182  	*ents = (*ents)[:0]
   183  	t.root.get(interval, ents, s)
   184  	if s.searchID < math.MaxUint32 {
   185  		t.pool.Put(s)
   186  	}
   187  }
   188  
   189  // Any checks if any of the entries intersect the given interval.
   190  func (t *T) Any(interval Interval) bool {
   191  	s := t.pool.Get()
   192  	s.searchID++
   193  	found := t.root.any(interval, s)
   194  	if s.searchID < math.MaxUint32 {
   195  		t.pool.Put(s)
   196  	}
   197  	return found
   198  }
   199  
   200  func keyRange(ents []*entry) Interval {
   201  	i := emptyInterval
   202  	for _, e := range ents {
   203  		i = i.Span(e.Interval)
   204  	}
   205  	return i
   206  }
   207  
   208  const maxSample = 8
   209  
   210  // randomSample picks maxSample random elements from ents[]. It shuffles ents[]
   211  // in place.
   212  func randomSample(ents []*entry, r *rand.Rand) []*entry {
   213  	if len(ents) <= maxSample {
   214  		return ents
   215  	}
   216  	shuffleFirstN := func(n int) { // Fisher-Yates shuffle
   217  		for i := 0; i < n-1; i++ {
   218  			j := i + r.Intn(len(ents)-i)
   219  			ents[i], ents[j] = ents[j], ents[i]
   220  		}
   221  	}
   222  	n := maxSample
   223  	if len(ents)-n < n {
   224  		// When maxSample < len(n) < maxSample*2, it's faster to compute the
   225  		// complement set.
   226  		n = len(ents) - n
   227  		shuffleFirstN(len(ents) - n)
   228  		return ents[n:]
   229  	}
   230  	shuffleFirstN(n)
   231  	return ents[:n]
   232  }
   233  
   234  // This function splits interval "bounds" into two balanced subintervals,
   235  // [bounds.Start, mid) and [mid, bounds.Limit). left (right) will store a subset
   236  // of ents[] that fits in the first (second, resp) subinterval. Note that an
   237  // entry in ents[] may belong to both left and right, if the entry spans over
   238  // the midpoint.
   239  //
   240  // Ok=false if this function fails to find a good split point.
   241  func split(label string, ents []*entry, bounds Interval, r *rand.Rand) (mid Key, left []*entry, right []*entry, ok bool) {
   242  	// A good interval split point is guaranteed to be at one of the interval
   243  	// endpoints.  To bound the compute time, we sample up to 16 intervals in
   244  	// ents[], and examine their endpoints one by one.
   245  	sample := randomSample(ents, r)
   246  	sampleRange := keyRange(sample).Intersect(bounds)
   247  	log.Debug.Printf("%s: Split %+v, %d ents", label, sampleRange, len(ents))
   248  	if sampleRange.Empty() {
   249  		panic(sample)
   250  	}
   251  	var (
   252  		candidates [maxSample * 2]Key
   253  		nCandidate int
   254  	)
   255  	for i, e := range sample {
   256  		candidates[i*2] = e.Interval.Start
   257  		candidates[i*2+1] = e.Interval.Limit
   258  		nCandidate += 2
   259  	}
   260  
   261  	// splitAt splits ents[] into two subsets, assuming bounds is split at mid.
   262  	splitAt := func(ents []*entry, mid Key, left, right *[]*entry) {
   263  		*left = (*left)[:0]
   264  		*right = (*right)[:0]
   265  		for _, e := range ents {
   266  			if e.Interval.Intersects(Interval{bounds.Start, mid}) {
   267  				*left = append(*left, e)
   268  			}
   269  			if e.Interval.Intersects(Interval{mid, bounds.Limit}) {
   270  				*right = append(*right, e)
   271  			}
   272  		}
   273  	}
   274  
   275  	// Compute the cost of splitting at each of candidates[].
   276  	// We use the surface-area heuristics. The best explanation is in
   277  	// the following paper:
   278  	//
   279  	// Ingo Wald, Realtime ray tracing and interactive global illumination,
   280  	// http://www.sci.utah.edu/~wald/Publications/2004/PhD/phd.pdf
   281  	//
   282  	// The basic idea is the following:
   283  	//
   284  	// - Assume we split the parent interval [s, e) into two intervals
   285  	//   [s,m) and [m,e)
   286  	//
   287  	// - The cost C(x) of searching a subinterval x is roughly
   288  	//    C(x) = (length of x) * (# of entries that intersect x).
   289  	//
   290  	//    The first term is the probability that a query hits the subinterval, and
   291  	//    the 2nd term is the cost of searching inside the subinterval.
   292  	//
   293  	//    This assumes that a query is distributed uniformly over the domain (in
   294  	//    our case, [-maxint32, maxint32].
   295  	//
   296  	// - The best split point is m that minimizes C([s,m)) + C([m,e))
   297  	minCost := math.MaxFloat64
   298  	var minMid Key
   299  	var minLeft, minRight []*entry
   300  	var tmpLeft, tmpRight []*entry
   301  
   302  	for _, mid := range candidates[:nCandidate] {
   303  		splitAt(ents, mid, &tmpLeft, &tmpRight)
   304  		if len(tmpLeft) == 0 || len(tmpRight) == 0 {
   305  			continue
   306  		}
   307  		cost := float64(len(tmpLeft))*float64(mid-sampleRange.Start) +
   308  			float64(len(tmpRight))*float64(sampleRange.Limit-mid)
   309  		if cost < minCost {
   310  			minMid = mid
   311  			minLeft, tmpLeft = tmpLeft, minLeft
   312  			minRight, tmpRight = tmpRight, minRight
   313  			minCost = cost
   314  		}
   315  	}
   316  	if minCost == math.MaxFloat64 || len(minLeft) == len(ents) || len(minRight) == len(ents) {
   317  		return
   318  	}
   319  	mid = minMid
   320  	left = minLeft
   321  	right = minRight
   322  	ok = true
   323  	return
   324  }
   325  
   326  func (n *node) init(label string, ents []*entry, bounds Interval, r *rand.Rand, stats *TreeStats) {
   327  	defer func() {
   328  		// Update the stats.
   329  		stats.Nodes++
   330  		depth := len(n.label)
   331  		if depth > stats.MaxDepth {
   332  			stats.MaxDepth = depth
   333  		}
   334  		if e := len(n.ents); e > 0 { // Leaf node
   335  			stats.LeafNodes++
   336  			stats.TotalLeafNodeSize += e
   337  			stats.TotalLeafDepth += depth
   338  			if e > stats.MaxLeafNodeSize {
   339  				stats.MaxLeafNodeSize = e
   340  			}
   341  		}
   342  	}()
   343  
   344  	n.label = label
   345  	n.bounds = bounds
   346  	if len(ents) <= maxEntsInNode {
   347  		n.ents = ents
   348  		return
   349  	}
   350  	mid, left, right, ok := split(n.label, ents, bounds, r)
   351  	if !ok {
   352  		n.ents = ents
   353  		return
   354  	}
   355  	n.left = &node{}
   356  
   357  	leftInterval := Interval{n.bounds.Start, mid}
   358  	leftKR := keyRange(left)
   359  	log.Debug.Printf("%v (bounds %v): left %v %v %v", n.label, n.bounds, leftKR, leftInterval, leftKR.Intersect(leftInterval))
   360  	n.left.init(label+"L", left, leftKR.Intersect(leftInterval), r, stats)
   361  	n.right = &node{}
   362  	n.right.init(label+"R", right, keyRange(right).Intersect(Interval{mid, n.bounds.Limit}), r, stats)
   363  }
   364  
   365  func addEntry(ents *[]*Entry, e *entry, s *searcher) {
   366  	if s.visit(e.id) {
   367  		*ents = append(*ents, (*Entry)(unsafe.Pointer(e)))
   368  	}
   369  }
   370  
   371  func (n *node) get(interval Interval, ents *[]*Entry, s *searcher) {
   372  	interval = interval.Intersect(n.bounds)
   373  	if interval.Empty() {
   374  		return
   375  	}
   376  	if len(n.ents) > 0 { // Leaf node
   377  		for _, e := range n.ents {
   378  			if interval.Intersects(e.Interval) {
   379  				addEntry(ents, e, s)
   380  			}
   381  		}
   382  		return
   383  	}
   384  	n.left.get(interval, ents, s)
   385  	n.right.get(interval, ents, s)
   386  }
   387  
   388  func (n *node) any(interval Interval, s *searcher) bool {
   389  	interval = interval.Intersect(n.bounds)
   390  	if interval.Empty() {
   391  		return false
   392  	}
   393  	if len(n.ents) > 0 { // Leaf node
   394  		for _, e := range n.ents {
   395  			if interval.Intersects(e.Interval) {
   396  				return true
   397  			}
   398  		}
   399  		return false
   400  	}
   401  	found := n.left.any(interval, s)
   402  	if !found {
   403  		found = n.right.any(interval, s)
   404  	}
   405  	return found
   406  }
   407  
   408  // GOB support
   409  
   410  const gobFormatVersion = 1
   411  
   412  // MarshalBinary implements encoding.BinaryMarshaler interface.  It allows T to
   413  // be encoded and decoded using Gob.
   414  func (t *T) MarshalBinary() (data []byte, err error) {
   415  	buf := bytes.Buffer{}
   416  	e := gob.NewEncoder(&buf)
   417  	must.Nil(e.Encode(gobFormatVersion))
   418  	marshalNode(e, &t.root)
   419  	must.Nil(e.Encode(t.stats))
   420  	return buf.Bytes(), nil
   421  }
   422  
   423  func marshalNode(e *gob.Encoder, n *node) {
   424  	if n == nil {
   425  		must.Nil(e.Encode(false))
   426  		return
   427  	}
   428  	must.Nil(e.Encode(true))
   429  	must.Nil(e.Encode(n.bounds))
   430  	marshalNode(e, n.left)
   431  	marshalNode(e, n.right)
   432  	must.Nil(e.Encode(len(n.ents)))
   433  	for _, ent := range n.ents {
   434  		must.Nil(e.Encode(ent.Entry))
   435  		must.Nil(e.Encode(ent.id))
   436  	}
   437  	must.Nil(e.Encode(n.label))
   438  }
   439  
   440  // UnmarshalBinary implements encoding.BinaryUnmarshaler interface.
   441  // It allows T to be encoded and decoded using Gob.
   442  func (t *T) UnmarshalBinary(data []byte) error {
   443  	buf := bytes.NewReader(data)
   444  	d := gob.NewDecoder(buf)
   445  	var version int
   446  	if err := d.Decode(&version); err != nil {
   447  		return err
   448  	}
   449  	if version != gobFormatVersion {
   450  		return fmt.Errorf("gob decode: got version %d, want %d", version, gobFormatVersion)
   451  	}
   452  	var (
   453  		maxid = -1
   454  		err   error
   455  		root  *node
   456  	)
   457  	if root, err = unmarshalNode(d, &maxid); err != nil {
   458  		return err
   459  	}
   460  	t.root = *root
   461  	if err := d.Decode(&t.stats); err != nil {
   462  		return err
   463  	}
   464  	t.pool = newSearcherFreePool(t, maxid+1)
   465  	return nil
   466  }
   467  
   468  func unmarshalNode(d *gob.Decoder, maxid *int) (*node, error) {
   469  	var (
   470  		exist bool
   471  		err   error
   472  	)
   473  	if err = d.Decode(&exist); err != nil {
   474  		return nil, err
   475  	}
   476  	if !exist {
   477  		return nil, nil
   478  	}
   479  	n := &node{}
   480  	if err := d.Decode(&n.bounds); err != nil {
   481  		return nil, err
   482  	}
   483  	if n.left, err = unmarshalNode(d, maxid); err != nil {
   484  		return nil, err
   485  	}
   486  	if n.right, err = unmarshalNode(d, maxid); err != nil {
   487  		return nil, err
   488  	}
   489  	var nEnt int
   490  	if err := d.Decode(&nEnt); err != nil {
   491  		return nil, err
   492  	}
   493  	n.ents = make([]*entry, nEnt)
   494  	for i := 0; i < nEnt; i++ {
   495  		n.ents[i] = &entry{}
   496  		if err := d.Decode(&n.ents[i].Entry); err != nil {
   497  			return nil, err
   498  		}
   499  		if err := d.Decode(&n.ents[i].id); err != nil {
   500  			return nil, err
   501  		}
   502  		if n.ents[i].id > *maxid {
   503  			*maxid = n.ents[i].id
   504  		}
   505  	}
   506  	if err := d.Decode(&n.label); err != nil {
   507  		return nil, err
   508  	}
   509  	return n, nil
   510  }