github.com/grailbio/base@v0.0.11/bitset/bitset.go (about)

     1  // Copyright 2022 GRAIL, Inc.  All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
// This is similar to github.com/willf/bitset, but with some extraneous
// abstraction removed.  See also simd/count_amd64.go.
     7  //
     8  // ([]byte <-> []uintptr adapters will be added when needed.)
     9  
    10  package bitset
    11  
    12  import (
    13  	"math/bits"
    14  )
    15  
    16  // BitsPerWord is the number of bits in a machine word.
    17  const BitsPerWord = 64
    18  
    19  // Log2BitsPerWord is log_2(BitsPerWord).
    20  const Log2BitsPerWord = uint(6)
    21  
    22  // Set sets the given bit in a []uintptr bitset.
    23  func Set(data []uintptr, bitIdx int) {
    24  	// Unsigned division by a power-of-2 constant compiles to a right-shift,
    25  	// while signed does not due to negative nastiness.
    26  	data[uint(bitIdx)/BitsPerWord] |= 1 << (uint(bitIdx) % BitsPerWord)
    27  }
    28  
    29  // Clear clears the given bit in a []uintptr bitset.
    30  func Clear(data []uintptr, bitIdx int) {
    31  	wordIdx := uint(bitIdx) / BitsPerWord
    32  	data[wordIdx] = data[wordIdx] &^ (1 << (uint(bitIdx) % BitsPerWord))
    33  }
    34  
    35  // Test returns true iff the given bit is set.
    36  func Test(data []uintptr, bitIdx int) bool {
    37  	return (data[uint(bitIdx)/BitsPerWord] & (1 << (uint(bitIdx) % BitsPerWord))) != 0
    38  }
    39  
    40  // SetInterval sets the bits at all positions in [startIdx, limitIdx) in a
    41  // []uintptr bitset.
    42  func SetInterval(data []uintptr, startIdx, limitIdx int) {
    43  	if startIdx >= limitIdx {
    44  		return
    45  	}
    46  	startWordIdx := startIdx >> Log2BitsPerWord
    47  	startBit := uintptr(1) << uint32(startIdx&(BitsPerWord-1))
    48  	limitWordIdx := limitIdx >> Log2BitsPerWord
    49  	limitBit := uintptr(1) << uint32(limitIdx&(BitsPerWord-1))
    50  	if startWordIdx == limitWordIdx {
    51  		// We can't fill all bits from startBit on in the first word, since the
    52  		// limit is also within this word.
    53  		data[startWordIdx] |= limitBit - startBit
    54  		return
    55  	}
    56  	// Fill all bits from startBit on in the first word.
    57  	data[startWordIdx] |= -startBit
    58  	// Fill all bits in intermediate words.
    59  	// (todo: ensure compiler doesn't insert pointless slice bounds-checks on
    60  	// every iteration)
    61  	for wordIdx := startWordIdx + 1; wordIdx < limitWordIdx; wordIdx++ {
    62  		data[wordIdx] = ^uintptr(0)
    63  	}
    64  	// Fill just the bottom bits in the last word, if necessary.
    65  	if limitBit != 1 {
    66  		data[limitWordIdx] |= limitBit - 1
    67  	}
    68  }
    69  
    70  // ClearInterval clears the bits at all positions in [startIdx, limitIdx) in a
    71  // []uintptr bitset.
    72  func ClearInterval(data []uintptr, startIdx, limitIdx int) {
    73  	if startIdx >= limitIdx {
    74  		return
    75  	}
    76  	startWordIdx := startIdx >> Log2BitsPerWord
    77  	startBit := uintptr(1) << uint32(startIdx&(BitsPerWord-1))
    78  	limitWordIdx := limitIdx >> Log2BitsPerWord
    79  	limitBit := uintptr(1) << uint32(limitIdx&(BitsPerWord-1))
    80  	if startWordIdx == limitWordIdx {
    81  		// We can't clear all bits from startBit on in the first word, since the
    82  		// limit is also within this word.
    83  		data[startWordIdx] &= ^(limitBit - startBit)
    84  		return
    85  	}
    86  	// Clear all bits from startBit on in the first word.
    87  	data[startWordIdx] &= startBit - 1
    88  	// Clear all bits in intermediate words.
    89  	for wordIdx := startWordIdx + 1; wordIdx < limitWordIdx; wordIdx++ {
    90  		data[wordIdx] = 0
    91  	}
    92  	// Clear just the bottom bits in the last word, if necessary.
    93  	if limitBit != 1 {
    94  		data[limitWordIdx] &= -limitBit
    95  	}
    96  }
    97  
    98  // NewClearBits creates a []uintptr bitset with capacity for at least nBit
    99  // bits, and all bits clear.
   100  func NewClearBits(nBit int) []uintptr {
   101  	nWord := (nBit + BitsPerWord - 1) / BitsPerWord
   102  	return make([]uintptr, nWord)
   103  }
   104  
   105  // NewSetBits creates a []uintptr bitset with capacity for at least nBit bits,
   106  // and all bits at positions [0, nBit) set.
   107  func NewSetBits(nBit int) []uintptr {
   108  	data := NewClearBits(nBit)
   109  	SetInterval(data, 0, nBit)
   110  	return data
   111  }
   112  
// NonzeroWordScanner iterates over and clears the set bits in a bitset, with
// the somewhat unusual precondition that the number of nonzero words is known
// in advance.  The 'BitsetScanner' name is being reserved for a scanner which
// expects the number of set bits to be known instead.
//
// Note that, when many bits are set, a more complicated double-loop based
// around a function like willf/bitset.NextSetMany() has ~40% less overhead (at
// least with Go 1.10 on a Mac), and you can do even better with manual
// inlining of the iteration logic.  As a consequence, it shouldn't be used
// when the bit iteration/clearing process is actually the dominant
// computational cost (and neither should NextSetMany(); manual inlining is
// 2-6x better without much more code, see bitsetManualInlineSubtask() in
// bitset_test.go for an example).  However, it's a good choice everywhere
// else, outperforming the other scanners I'm aware of with similar ease of
// use, and maybe a future Go version will inline it properly.
type NonzeroWordScanner struct {
	// data is the original bitset.  It is modified in place: Next zeroes each
	// word once it has exhausted that word's set bits.
	data []uintptr
	// bitIdxOffset is BitsPerWord times the current data[] array index.
	bitIdxOffset int
	// bitWord is data[bitIdxOffset / BitsPerWord], with already-iterated-over
	// bits cleared.
	bitWord uintptr
	// nNonzeroWord is the number of nonzero words remaining in data[],
	// counting the current word.
	nNonzeroWord int
}
   139  
   140  // NewNonzeroWordScanner returns a NonzeroWordScanner for the given bitset,
   141  // along with the position of the first bit.  (This interface has been chosen
   142  // to make for loops with properly-scoped variables easy to write.)
   143  //
   144  // The bitset is expected to be nonempty; otherwise this will crash the program
   145  // with an out-of-bounds slice access.  Similarly, if nNonzeroWord is larger
   146  // than the actual number of nonzero words, or initially <= 0, the standard for
   147  // loop will crash the program.  (If nNonzeroWord is smaller but >0, the last
   148  // nonzero words will be ignored.)
   149  func NewNonzeroWordScanner(data []uintptr, nNonzeroWord int) (NonzeroWordScanner, int) {
   150  	for wordIdx := 0; ; wordIdx++ {
   151  		bitWord := data[wordIdx]
   152  		if bitWord != 0 {
   153  			bitIdxOffset := wordIdx * BitsPerWord
   154  			return NonzeroWordScanner{
   155  				data:         data,
   156  				bitIdxOffset: bitIdxOffset,
   157  				bitWord:      bitWord & (bitWord - 1),
   158  				nNonzeroWord: nNonzeroWord,
   159  			}, bits.TrailingZeros64(uint64(bitWord)) + bitIdxOffset
   160  		}
   161  	}
   162  }
   163  
   164  // Next returns the position of the next set bit, or -1 if there aren't any.
   165  func (s *NonzeroWordScanner) Next() int {
   166  	bitWord := s.bitWord
   167  	if bitWord == 0 {
   168  		wordIdx := int(uint(s.bitIdxOffset) / BitsPerWord)
   169  		s.data[wordIdx] = 0
   170  		s.nNonzeroWord--
   171  		if s.nNonzeroWord == 0 {
   172  			// All words with set bits are accounted for, we can exit early.
   173  			// This is deliberately == 0 instead of <= 0 since it'll only be less
   174  			// than zero if there's a bug in the caller.  We want to crash with an
   175  			// out-of-bounds access in that case.
   176  			return -1
   177  		}
   178  		for {
   179  			wordIdx++
   180  			bitWord = s.data[wordIdx]
   181  			if bitWord != 0 {
   182  				break
   183  			}
   184  		}
   185  		s.bitIdxOffset = wordIdx * BitsPerWord
   186  	}
   187  	s.bitWord = bitWord & (bitWord - 1)
   188  	return bits.TrailingZeros64(uint64(bitWord)) + s.bitIdxOffset
   189  }