github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/intsets/sparse.go (about)

     1  // Copyright 2022 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package intsets
    12  
    13  // Sparse is a set of integers. It is not thread-safe. It must be copied with
    14  // the Copy method.
    15  //
    16  // Sparse is implemented as a linked list of blocks, each containing an offset
    17  // and a bitmap. A block with offset=o contains an integer o+b if the b-th bit
    18  // of the bitmap is set. Block offsets are always divisible by smallCutoff.
    19  //
    20  // For example, here is a diagram of the set {0, 1, 128, 129, 512}, where
    21  // each block is denoted by {offset, bitmap}:
    22  //
    23  //	{0, ..011} ---> {128, ..011} ---> {512, ..001}
    24  //
    25  // Sparse is inspired by golang.org/x/tools/container/intsets. Sparse implements
    26  // a smaller API, providing only the methods required by Fast. The omission of a
    27  // Max method allows us to use a singly-linked list here instead of a
    28  // circular, doubly-linked list.
    29  type Sparse struct {
    30  	root block
    31  }
    32  
    33  // block is a node in a singly-linked list with an offset and a bitmap. A block
    34  // with offset=o contains an integer o+b if the b-th bit of the bitmap is set.
    35  type block struct {
    36  	offset int
    37  	bits   bitmap
    38  	next   *block
    39  }
    40  
    41  const (
    42  	// MaxInt is the maximum integer that can be stored in a set.
    43  	MaxInt = int(^uint(0) >> 1)
    44  	// MinInt is the minimum integer that can be stored in a set.
    45  	MinInt = -MaxInt - 1
    46  
    47  	smallCutoffMask = smallCutoff - 1
    48  )
    49  
    50  func init() {
    51  	if smallCutoff == 0 || (smallCutoff&smallCutoffMask) != 0 {
    52  		panic("smallCutoff must be a power of two; see offset and bit")
    53  	}
    54  }
    55  
    56  // offset returns the block offset for the given integer.
    57  // Note: Bitwise AND NOT only works here because smallCutoff is a power of two.
    58  //
    59  //gcassert:inline
    60  func offset(i int) int {
    61  	return i &^ smallCutoffMask
    62  }
    63  
    64  // bit returns the bit within a block that should be set for the given integer.
    65  // Note: Bitwise AND only works here because smallCutoff is a power of two.
    66  //
    67  //gcassert:inline
    68  func bit(i int) int {
    69  	return i & smallCutoffMask
    70  }
    71  
    72  // empty returns true if the block is empty, i.e., none of its bits have been
    73  // set.
    74  //
    75  //gcassert:inline
    76  func (s block) empty() bool {
    77  	return s.bits == bitmap{}
    78  }
    79  
    80  // insertBlock inserts a block after prev and returns it. If prev is nil, a
    81  // block is inserted at the front of the list.
    82  func (s *Sparse) insertBlock(prev *block) *block {
    83  	if s.Empty() {
    84  		return &s.root
    85  	}
    86  	if prev == nil {
    87  		// Insert a new block at the front of the list.
    88  		second := s.root
    89  		s.root = block{}
    90  		s.root.next = &second
    91  		return &s.root
    92  	}
    93  	// Insert a new block in the middle of the list.
    94  	n := block{}
    95  	n.next = prev.next
    96  	prev.next = &n
    97  	return &n
    98  }
    99  
   100  // removeBlock removes a block from the list. prev must be the block before b.
   101  func (s *Sparse) removeBlock(prev, b *block) *block {
   102  	if prev == nil {
   103  		if b.next == nil {
   104  			s.root = block{}
   105  			return nil
   106  		}
   107  		s.root.offset = b.next.offset
   108  		s.root.bits = b.next.bits
   109  		s.root.next = b.next.next
   110  		return &s.root
   111  	}
   112  	prev.next = prev.next.next
   113  	return prev.next
   114  }
   115  
   116  // Clear empties the set.
   117  func (s *Sparse) Clear() {
   118  	s.root = block{}
   119  }
   120  
   121  // Add adds an integer to the set.
   122  func (s *Sparse) Add(i int) {
   123  	o := offset(i)
   124  	b := bit(i)
   125  	var last *block
   126  	for sb := &s.root; sb != nil && sb.offset <= o; sb = sb.next {
   127  		if sb.offset == o {
   128  			sb.bits.Set(b)
   129  			return
   130  		}
   131  		last = sb
   132  	}
   133  	n := s.insertBlock(last)
   134  	n.offset = o
   135  	n.bits.Set(b)
   136  }
   137  
   138  // Remove removes an integer from the set.
   139  func (s *Sparse) Remove(i int) {
   140  	o := offset(i)
   141  	b := bit(i)
   142  	var last *block
   143  	for sb := &s.root; sb != nil && sb.offset <= o; sb = sb.next {
   144  		if sb.offset == o {
   145  			sb.bits.Unset(b)
   146  			if sb.empty() {
   147  				s.removeBlock(last, sb)
   148  			}
   149  			return
   150  		}
   151  		last = sb
   152  	}
   153  }
   154  
   155  // Contains returns true if the set contains the given integer.
   156  func (s Sparse) Contains(i int) bool {
   157  	o := offset(i)
   158  	b := bit(i)
   159  	for sb := &s.root; sb != nil && sb.offset <= o; sb = sb.next {
   160  		if sb.offset == o {
   161  			return sb.bits.IsSet(b)
   162  		}
   163  	}
   164  	return false
   165  }
   166  
   167  // Empty returns true if the set contains no integers.
   168  func (s Sparse) Empty() bool {
   169  	return s.root.empty()
   170  }
   171  
   172  // Len returns the number of integers in the set.
   173  func (s Sparse) Len() int {
   174  	l := 0
   175  	for sb := &s.root; sb != nil; sb = sb.next {
   176  		l += sb.bits.OnesCount()
   177  	}
   178  	return l
   179  }
   180  
   181  // LowerBound returns the smallest element >= startVal, or MaxInt if there is no
   182  // such element.
   183  func (s *Sparse) LowerBound(startVal int) int {
   184  	if s.Empty() {
   185  		return MaxInt
   186  	}
   187  	o := offset(startVal)
   188  	b := bit(startVal)
   189  	for sb := &s.root; sb != nil; sb = sb.next {
   190  		if sb.offset > o {
   191  			v, _ := sb.bits.Next(0)
   192  			return v + sb.offset
   193  		}
   194  		if sb.offset == o {
   195  			if v, ok := sb.bits.Next(b); ok {
   196  				return v + sb.offset
   197  			}
   198  		}
   199  	}
   200  	return MaxInt
   201  }
   202  
   203  // Min returns the minimum value in the set. If the set is empty, MaxInt is
   204  // returned.
   205  func (s *Sparse) Min() int {
   206  	if s.Empty() {
   207  		return MaxInt
   208  	}
   209  	b := s.root
   210  	v, _ := b.bits.Next(0)
   211  	return v + b.offset
   212  }
   213  
   214  // Copy sets the receiver to a copy of rhs, which can then be modified
   215  // independently.
   216  func (s *Sparse) Copy(rhs *Sparse) {
   217  	var last *block
   218  	sb := &s.root
   219  	rb := &rhs.root
   220  	for rb != nil {
   221  		if sb == nil {
   222  			sb = s.insertBlock(last)
   223  		}
   224  		sb.offset = rb.offset
   225  		sb.bits = rb.bits
   226  		last = sb
   227  		sb = sb.next
   228  		rb = rb.next
   229  	}
   230  	if last != nil {
   231  		last.next = nil
   232  	}
   233  }
   234  
   235  // UnionWith adds all the elements from rhs to this set.
   236  func (s *Sparse) UnionWith(rhs *Sparse) {
   237  	if rhs.Empty() {
   238  		return
   239  	}
   240  
   241  	var last *block
   242  	sb := &s.root
   243  	rb := &rhs.root
   244  	for rb != nil {
   245  		if sb != nil && sb.offset == rb.offset {
   246  			sb.bits.UnionWith(rb.bits)
   247  			rb = rb.next
   248  		} else if sb == nil || sb.offset > rb.offset {
   249  			sb = s.insertBlock(last)
   250  			sb.offset = rb.offset
   251  			sb.bits = rb.bits
   252  			rb = rb.next
   253  		}
   254  		last = sb
   255  		sb = sb.next
   256  	}
   257  }
   258  
   259  // IntersectionWith removes any elements not in rhs from this set.
   260  func (s *Sparse) IntersectionWith(rhs *Sparse) {
   261  	var last *block
   262  	sb := &s.root
   263  	rb := &rhs.root
   264  	for sb != nil && rb != nil {
   265  		switch {
   266  		case sb.offset > rb.offset:
   267  			rb = rb.next
   268  		case sb.offset < rb.offset:
   269  			sb = s.removeBlock(last, sb)
   270  		default:
   271  			sb.bits.IntersectionWith(rb.bits)
   272  			if !sb.empty() {
   273  				// If sb is not empty, then advance sb and last.
   274  				//
   275  				// If sb is empty, we advance neither sb nor last so that the
   276  				// empty sb will be removed in the next iteration of the loop
   277  				// (the sb.offset < rb.offset case), or after the loop (see the
   278  				// comment below).
   279  				last = sb
   280  				sb = sb.next
   281  			}
   282  			rb = rb.next
   283  		}
   284  	}
   285  	if sb == &s.root {
   286  		// This is a special case that only happens when all the following are
   287  		// true:
   288  		//
   289  		//   1. Either s or rhs has a single block.
   290  		//   2. The first blocks of s and rhs have matching offsets.
   291  		//   3. The intersection of the first blocks of s and rhs yields an
   292  		//      empty block.
   293  		//
   294  		// In this case, the root block would not have been removed in the loop,
   295  		// and it may have a non-zero offset and a non-nil next block, so we
   296  		// clear it here.
   297  		s.root = block{}
   298  	}
   299  	if last != nil {
   300  		// At this point, last is a pointer to the last block in s that we've
   301  		// intersected with a block in rhs. If there are no remaining blocks in
   302  		// s, then last.next will be nil. If there are no remaining blocks in
   303  		// rhs, then we must remove any blocks after last. Unconditionally
   304  		// clearing last.next works in both cases.
   305  		last.next = nil
   306  	}
   307  }
   308  
   309  // Intersects returns true if s has any elements in common with rhs.
   310  func (s *Sparse) Intersects(rhs *Sparse) bool {
   311  	sb := &s.root
   312  	rb := &rhs.root
   313  	for sb != nil && rb != nil {
   314  		switch {
   315  		case sb.offset > rb.offset:
   316  			rb = rb.next
   317  		case sb.offset < rb.offset:
   318  			sb = sb.next
   319  		default:
   320  			if sb.bits.Intersects(rb.bits) {
   321  				return true
   322  			}
   323  			sb = sb.next
   324  			rb = rb.next
   325  		}
   326  	}
   327  	return false
   328  }
   329  
   330  // DifferenceWith removes any elements in rhs from this set.
   331  func (s *Sparse) DifferenceWith(rhs *Sparse) {
   332  	var last *block
   333  	sb := &s.root
   334  	rb := &rhs.root
   335  	for sb != nil && rb != nil {
   336  		switch {
   337  		case sb.offset > rb.offset:
   338  			rb = rb.next
   339  		case sb.offset < rb.offset:
   340  			last = sb
   341  			sb = sb.next
   342  		default:
   343  			sb.bits.DifferenceWith(rb.bits)
   344  			if sb.empty() {
   345  				sb = s.removeBlock(last, sb)
   346  			} else {
   347  				last = sb
   348  				sb = sb.next
   349  			}
   350  			rb = rb.next
   351  		}
   352  	}
   353  }
   354  
   355  // Equals returns true if the two sets are identical.
   356  func (s *Sparse) Equals(rhs *Sparse) bool {
   357  	sb := &s.root
   358  	rb := &rhs.root
   359  	for sb != nil && rb != nil {
   360  		if sb.offset != rb.offset || sb.bits != rb.bits {
   361  			return false
   362  		}
   363  		sb = sb.next
   364  		rb = rb.next
   365  	}
   366  	return sb == nil && rb == nil
   367  }
   368  
   369  // SubsetOf returns true if rhs contains all the elements in s.
   370  func (s *Sparse) SubsetOf(rhs *Sparse) bool {
   371  	if s.Empty() {
   372  		return true
   373  	}
   374  	sb := &s.root
   375  	rb := &rhs.root
   376  	for sb != nil && rb != nil {
   377  		if sb.offset > rb.offset {
   378  			rb = rb.next
   379  			continue
   380  		}
   381  		if sb.offset < rb.offset {
   382  			return false
   383  		}
   384  		if !sb.bits.SubsetOf(rb.bits) {
   385  			return false
   386  		}
   387  		sb = sb.next
   388  		rb = rb.next
   389  	}
   390  	return sb == nil
   391  }