github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/util/bitarray/bitarray.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package bitarray
    12  
    13  import (
    14  	"bytes"
    15  	"fmt"
    16  	"math/rand"
    17  	"unsafe"
    18  
    19  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    20  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    21  	"github.com/cockroachdb/errors"
    22  )
    23  
// BitArray implements a bit string of arbitrary length.
//
// This uses a packed encoding (i.e. groups of 64 bits at a time) for
// memory efficiency and speed of bitwise operations (enables use of
// full machine registers for comparisons and logical operations),
// akin to the big.nat type.
//
// There is something fancy needed to handle sorting values properly:
// the last group of bits must be padded right (start on the MSB)
// inside its word to compare properly according to pg semantics.
//
// This type is designed for immutable instances. The functions and
// methods defined below never write to a bit array in-place. Of note,
// the ToWidth() and Next() functions will share the backing array
// between their operand and their result in some cases.
//
// For portability, the size of the backing word is guaranteed to be 64
// bits.
type BitArray struct {
	// words is the backing array.
	//
	// The leftmost bits in the literal representation are placed in the
	// MSB of each word.
	//
	// The last word contains the rightmost bits in the literal
	// representation, right-padded. For example if there are 3 bits
	// to store, they occupy the 3 MSB bits of the last word and the
	// remaining LSB bits (the padding) are zero.
	//
	// The number of stored bits is actually:
	//   0 if len(words) == 0 (lastBitsUsed is then also 0)
	//   otherwise, (len(words)-1)*numBitsPerWord + lastBitsUsed
	//
	// TODO(jutin, nathan): consider using the trick in bytes.Buffer of
	// keeping a static [1]word which words can initially point to, to
	// avoid heap allocations in the common case of small arrays.
	words []word

	// lastBitsUsed is the number of bits in the last word that
	// participate in the value stored. It can only be zero
	// for empty bit arrays; otherwise it's always between 1 and
	// numBitsPerWord.
	//
	// For example:
	// - 0 bits in array: len(words) == 0, lastBitsUsed = 0
	// - 1 bits in array: len(words) == 1, lastBitsUsed = 1
	// - 64 bits in array: len(words) == 1, lastBitsUsed = 64
	// - 65 bits in array: len(words) == 2, lastBitsUsed = 1
	lastBitsUsed uint8
}
    74  
    75  type word = uint64
    76  
    77  const numBytesPerWord = 8
    78  const numBitsPerWord = 64
    79  
    80  // BitLen returns the number of bits stored.
    81  func (d BitArray) BitLen() uint {
    82  	if len(d.words) == 0 {
    83  		return 0
    84  	}
    85  	return d.nonEmptyBitLen()
    86  }
    87  
    88  func (d BitArray) nonEmptyBitLen() uint {
    89  	return uint(len(d.words)-1)*numBitsPerWord + uint(d.lastBitsUsed)
    90  }
    91  
    92  // String implements the fmt.Stringer interface.
    93  func (d BitArray) String() string {
    94  	var buf bytes.Buffer
    95  	d.Format(&buf)
    96  	return buf.String()
    97  }
    98  
    99  // Clone makes a copy of the bit array.
   100  func (d BitArray) Clone() BitArray {
   101  	return BitArray{
   102  		words:        append([]word(nil), d.words...),
   103  		lastBitsUsed: d.lastBitsUsed,
   104  	}
   105  }
   106  
   107  // MakeZeroBitArray creates a bit array with the specified bit size.
   108  func MakeZeroBitArray(bitLen uint) BitArray {
   109  	a, b := EncodingPartsForBitLen(bitLen)
   110  	return mustFromEncodingParts(a, b)
   111  }
   112  
   113  // ToWidth resizes the bit array to the specified size.
   114  // If the specified width is shorter, bits on the right are truncated away.
   115  // If the specified width is larger, zero bits are added on the right.
   116  func (d BitArray) ToWidth(desiredLen uint) BitArray {
   117  	bitlen := d.BitLen()
   118  	if bitlen == desiredLen {
   119  		// Nothing to do; fast path.
   120  		return d
   121  	}
   122  	if desiredLen == 0 {
   123  		// Nothing to do; fast path.
   124  		return BitArray{}
   125  	}
   126  	if desiredLen < bitlen {
   127  		// Destructive, we have to copy.
   128  		words, lastBitsUsed := EncodingPartsForBitLen(desiredLen)
   129  		copy(words, d.words[:len(words)])
   130  		words[len(words)-1] &= (^word(0) << (numBitsPerWord - lastBitsUsed))
   131  		return mustFromEncodingParts(words, lastBitsUsed)
   132  	}
   133  
   134  	// New length is larger.
   135  	numWords, lastBitsUsed := SizesForBitLen(desiredLen)
   136  	var words []word
   137  	if numWords <= uint(cap(d.words)) {
   138  		words = d.words[0:numWords]
   139  	} else {
   140  		words = make([]word, numWords)
   141  		copy(words, d.words)
   142  	}
   143  	return mustFromEncodingParts(words, lastBitsUsed)
   144  }
   145  
   146  // Sizeof returns the size in bytes of the bit array and its components.
   147  func (d BitArray) Sizeof() uintptr {
   148  	return unsafe.Sizeof(d) + uintptr(numBytesPerWord*cap(d.words))
   149  }
   150  
   151  // IsEmpty returns true iff the array is empty.
   152  func (d BitArray) IsEmpty() bool {
   153  	return d.lastBitsUsed == 0
   154  }
   155  
   156  // MakeBitArrayFromInt64 creates a bit array with the specified
   157  // size. The bits from the integer are written to the right of the bit
   158  // array and the sign bit is extended.
   159  func MakeBitArrayFromInt64(bitLen uint, val int64, valWidth uint) BitArray {
   160  	if bitLen == 0 {
   161  		return BitArray{}
   162  	}
   163  	d := MakeZeroBitArray(bitLen)
   164  	if bitLen < valWidth {
   165  		// Fast path, no sign extension to compute.
   166  		d.words[len(d.words)-1] = word(val << (numBitsPerWord - bitLen))
   167  		return d
   168  	}
   169  	if val&(1<<(valWidth-1)) != 0 {
   170  		// Sign extend, fill ones in every word but the last.
   171  		for i := 0; i < len(d.words)-1; i++ {
   172  			d.words[i] = ^word(0)
   173  		}
   174  	}
   175  	// Shift the value to its given number of bits, to position the sign
   176  	// bit to the left.
   177  	val = val << (numBitsPerWord - valWidth)
   178  	// Shift right back with arithmetic shift to extend the sign bit.
   179  	val = val >> (numBitsPerWord - valWidth)
   180  	// Store the right part of the value in the last word.
   181  	d.words[len(d.words)-1] = word(val << (numBitsPerWord - d.lastBitsUsed))
   182  	// Store the left part in the next-to-last word, if any.
   183  	if valWidth > uint(d.lastBitsUsed) {
   184  		d.words[len(d.words)-2] = word(val >> d.lastBitsUsed)
   185  	}
   186  	return d
   187  }
   188  
   189  // AsInt64 returns the int constituted from the rightmost bits in the
   190  // bit array.
   191  func (d BitArray) AsInt64(nbits uint) int64 {
   192  	if d.lastBitsUsed == 0 {
   193  		// Fast path.
   194  		return 0
   195  	}
   196  
   197  	lowPart := d.words[len(d.words)-1] >> (numBitsPerWord - d.lastBitsUsed)
   198  	highPart := word(0)
   199  	if nbits > uint(d.lastBitsUsed) && len(d.words) > 1 {
   200  		highPart = d.words[len(d.words)-2] << d.lastBitsUsed
   201  	}
   202  	combined := lowPart | highPart
   203  	signExtended := int64(combined<<(numBitsPerWord-nbits)) >> (numBitsPerWord - nbits)
   204  	return signExtended
   205  }
   206  
   207  // LeftShiftAny performs a logical left shift, with a possible
   208  // negative count.
   209  // The number of bits to shift can be arbitrarily large (i.e. possibly
   210  // larger than 64 in absolute value).
   211  func (d BitArray) LeftShiftAny(n int64) BitArray {
   212  	bitlen := d.BitLen()
   213  	if n == 0 || bitlen == 0 {
   214  		// Fast path.
   215  		return d
   216  	}
   217  
   218  	r := MakeZeroBitArray(bitlen)
   219  	if (n > 0 && n > int64(bitlen)) || (n < 0 && -n > int64(bitlen)) {
   220  		// Fast path.
   221  		return r
   222  	}
   223  
   224  	if n > 0 {
   225  		// This is a left shift.
   226  		dstWord := uint(0)
   227  		srcWord := uint(uint64(n) / numBitsPerWord)
   228  		srcShift := uint(uint64(n) % numBitsPerWord)
   229  		for i, j := srcWord, dstWord; i < uint(len(d.words)); i++ {
   230  			r.words[j] = d.words[i] << srcShift
   231  			j++
   232  		}
   233  		for i, j := srcWord+1, dstWord; i < uint(len(d.words)); i++ {
   234  			r.words[j] |= d.words[i] >> (numBitsPerWord - srcShift)
   235  			j++
   236  		}
   237  	} else {
   238  		// A right shift.
   239  		n = -n
   240  		srcWord := uint(0)
   241  		dstWord := uint(uint64(n) / numBitsPerWord)
   242  		srcShift := uint(uint64(n) % numBitsPerWord)
   243  		for i, j := srcWord, dstWord; j < uint(len(r.words)); i++ {
   244  			r.words[j] = d.words[i] >> srcShift
   245  			j++
   246  		}
   247  		for i, j := srcWord, dstWord+1; j < uint(len(r.words)); i++ {
   248  			r.words[j] |= d.words[i] << (numBitsPerWord - srcShift)
   249  			j++
   250  		}
   251  		// Erase the trailing bits that are not used any more.
   252  		// See #36606.
   253  		if len(r.words) > 0 {
   254  			r.words[len(r.words)-1] &= ^word(0) << (numBitsPerWord - r.lastBitsUsed)
   255  		}
   256  	}
   257  
   258  	return r
   259  }
   260  
   261  // byteReprs contains the bit representation of the 256 possible
   262  // groups of 8 bits.
   263  var byteReprs = func() (ret [256]string) {
   264  	for i := range ret {
   265  		// Change this format if numBitsPerWord changes.
   266  		ret[i] = fmt.Sprintf("%08b", i)
   267  	}
   268  	return ret
   269  }()
   270  
   271  // Format prints out the bit array to the buffer.
   272  func (d BitArray) Format(buf *bytes.Buffer) {
   273  	bitLen := d.BitLen()
   274  	buf.Grow(int(bitLen))
   275  	for i := uint(0); i < bitLen/numBitsPerWord; i++ {
   276  		w := d.words[i]
   277  		// Change this loop if numBitsPerWord changes.
   278  		buf.WriteString(byteReprs[(w>>56)&0xff])
   279  		buf.WriteString(byteReprs[(w>>48)&0xff])
   280  		buf.WriteString(byteReprs[(w>>40)&0xff])
   281  		buf.WriteString(byteReprs[(w>>32)&0xff])
   282  		buf.WriteString(byteReprs[(w>>24)&0xff])
   283  		buf.WriteString(byteReprs[(w>>16)&0xff])
   284  		buf.WriteString(byteReprs[(w>>8)&0xff])
   285  		buf.WriteString(byteReprs[(w>>0)&0xff])
   286  	}
   287  	remainingBits := bitLen % numBitsPerWord
   288  	if remainingBits > 0 {
   289  		lastWord := d.words[bitLen/numBitsPerWord]
   290  		minShift := numBitsPerWord - 1 - remainingBits
   291  		for i := numBitsPerWord - 1; i > int(minShift); i-- {
   292  			bitVal := (lastWord >> uint(i)) & 1
   293  			buf.WriteByte('0' + byte(bitVal))
   294  		}
   295  	}
   296  }
   297  
   298  // EncodingPartsForBitLen creates a word backing array and the
   299  // "last bits used" value given the given total number of bits.
   300  func EncodingPartsForBitLen(bitLen uint) ([]uint64, uint64) {
   301  	if bitLen == 0 {
   302  		return nil, 0
   303  	}
   304  	numWords, lastBitsUsed := SizesForBitLen(bitLen)
   305  	words := make([]word, numWords)
   306  	return words, lastBitsUsed
   307  }
   308  
   309  // SizesForBitLen computes the number of words and last bits used for
   310  // the requested bit array size.
   311  func SizesForBitLen(bitLen uint) (uint, uint64) {
   312  	// This computes ceil(bitLen / numBitsPerWord).
   313  	numWords := (bitLen + numBitsPerWord - 1) / numBitsPerWord
   314  	lastBitsUsed := uint64(bitLen % numBitsPerWord)
   315  	if lastBitsUsed == 0 {
   316  		lastBitsUsed = numBitsPerWord
   317  	}
   318  	return numWords, lastBitsUsed
   319  }
   320  
   321  // Parse parses a bit array from the specified string.
   322  func Parse(s string) (res BitArray, err error) {
   323  	if len(s) == 0 {
   324  		return res, nil
   325  	}
   326  
   327  	words, lastBitsUsed := EncodingPartsForBitLen(uint(len(s)))
   328  
   329  	// Parse the bits.
   330  	wordIdx := 0
   331  	bitIdx := uint(0)
   332  	curWord := word(0)
   333  	for _, c := range s {
   334  		val := word(c - '0')
   335  		bitVal := val & 1
   336  		if bitVal != val {
   337  			// Note: the prefix "could not parse" is important as it is used
   338  			// to detect parsing errors in tests.
   339  			err := fmt.Errorf(`could not parse string as bit array: "%c" is not a valid binary digit`, c)
   340  			return res, pgerror.WithCandidateCode(err, pgcode.InvalidTextRepresentation)
   341  		}
   342  		curWord |= bitVal << (63 - bitIdx)
   343  		bitIdx = (bitIdx + 1) % numBitsPerWord
   344  		if bitIdx == 0 {
   345  			words[wordIdx] = curWord
   346  			curWord = 0
   347  			wordIdx++
   348  		}
   349  	}
   350  	if bitIdx > 0 {
   351  		// Ensure the last word is stored.
   352  		words[wordIdx] = curWord
   353  	}
   354  
   355  	return FromEncodingParts(words, lastBitsUsed)
   356  }
   357  
   358  // Concat concatenates two bit arrays.
   359  func Concat(lhs, rhs BitArray) BitArray {
   360  	if lhs.lastBitsUsed == 0 {
   361  		return rhs
   362  	}
   363  	if rhs.lastBitsUsed == 0 {
   364  		return lhs
   365  	}
   366  	words := make([]word, (lhs.nonEmptyBitLen()+rhs.nonEmptyBitLen()+numBitsPerWord-1)/numBitsPerWord)
   367  
   368  	// The first bits come from the lhs unchanged.
   369  	copy(words, lhs.words)
   370  	var lastBitsUsed uint8
   371  	if lhs.lastBitsUsed == numBitsPerWord {
   372  		// Fast path. Just concatenate.
   373  		copy(words[len(lhs.words):], rhs.words)
   374  		lastBitsUsed = rhs.lastBitsUsed
   375  	} else {
   376  		// We need to shift all the words in the RHS
   377  		// by the lastBitsUsed of the LHS.
   378  		rhsShift := lhs.lastBitsUsed
   379  		targetWordIdx := len(lhs.words) - 1
   380  		trailingBits := words[targetWordIdx]
   381  		for _, w := range rhs.words {
   382  			headingBits := w >> rhsShift
   383  			combinedBits := trailingBits | headingBits
   384  			words[targetWordIdx] = combinedBits
   385  			targetWordIdx++
   386  			trailingBits = w << (numBitsPerWord - rhsShift)
   387  		}
   388  		lastBitsUsed = lhs.lastBitsUsed + rhs.lastBitsUsed
   389  		if lastBitsUsed > numBitsPerWord {
   390  			// Some bits from the RHS didn't fill a
   391  			// word, we need to fit them in the last word.
   392  			words[targetWordIdx] = trailingBits
   393  		}
   394  
   395  		// Compute the final thing.
   396  		lastBitsUsed %= numBitsPerWord
   397  		if lastBitsUsed == 0 {
   398  			lastBitsUsed = numBitsPerWord
   399  		}
   400  	}
   401  	return BitArray{words: words, lastBitsUsed: lastBitsUsed}
   402  }
   403  
   404  // Not computes the complement of a bit array.
   405  func Not(d BitArray) BitArray {
   406  	res := d.Clone()
   407  	for i, w := range res.words {
   408  		res.words[i] = ^w
   409  	}
   410  	if res.lastBitsUsed > 0 {
   411  		lastWord := len(res.words) - 1
   412  		res.words[lastWord] &= (^word(0) << (numBitsPerWord - res.lastBitsUsed))
   413  	}
   414  	return res
   415  }
   416  
   417  // And computes the logical AND of two bit arrays.
   418  // The caller must ensure they have the same bit size.
   419  func And(lhs, rhs BitArray) BitArray {
   420  	res := lhs.Clone()
   421  	for i, w := range rhs.words {
   422  		res.words[i] &= w
   423  	}
   424  	return res
   425  }
   426  
   427  // Or computes the logical OR of two bit arrays.
   428  // The caller must ensure they have the same bit size.
   429  func Or(lhs, rhs BitArray) BitArray {
   430  	res := lhs.Clone()
   431  	for i, w := range rhs.words {
   432  		res.words[i] |= w
   433  	}
   434  	return res
   435  }
   436  
   437  // Xor computes the logical XOR of two bit arrays.
   438  // The caller must ensure they have the same bit size.
   439  func Xor(lhs, rhs BitArray) BitArray {
   440  	res := lhs.Clone()
   441  	for i, w := range rhs.words {
   442  		res.words[i] ^= w
   443  	}
   444  	return res
   445  }
   446  
   447  // Compare compares two bit arrays. They can have mixed sizes.
   448  func Compare(lhs, rhs BitArray) int {
   449  	n := len(lhs.words)
   450  	if n > len(rhs.words) {
   451  		n = len(rhs.words)
   452  	}
   453  	i := 0
   454  	for ; i < n; i++ {
   455  		lw := lhs.words[i]
   456  		rw := rhs.words[i]
   457  		if lw < rw {
   458  			return -1
   459  		}
   460  		if lw > rw {
   461  			return 1
   462  		}
   463  	}
   464  	if i < len(rhs.words) {
   465  		// lhs is shorter.
   466  		return -1
   467  	}
   468  	if i < len(lhs.words) {
   469  		// rhs is shorter.
   470  		return 1
   471  	}
   472  	// Same length.
   473  	if lhs.lastBitsUsed < rhs.lastBitsUsed {
   474  		return -1
   475  	}
   476  	if lhs.lastBitsUsed > rhs.lastBitsUsed {
   477  		return 1
   478  	}
   479  	return 0
   480  }
   481  
   482  // EncodingParts retrieves the encoding bits from the bit array. The
   483  // words are presented in big-endian order, with the leftmost bits of
   484  // the bitarray (MSB) in the MSB of each word.
   485  func (d BitArray) EncodingParts() ([]uint64, uint64) {
   486  	return d.words, uint64(d.lastBitsUsed)
   487  }
   488  
   489  // FromEncodingParts creates a bit array from the encoding parts.
   490  func FromEncodingParts(words []uint64, lastBitsUsed uint64) (BitArray, error) {
   491  	if lastBitsUsed > numBitsPerWord {
   492  		err := fmt.Errorf("FromEncodingParts: lastBitsUsed must not exceed %d, got %d",
   493  			errors.Safe(numBitsPerWord), errors.Safe(lastBitsUsed))
   494  		return BitArray{}, pgerror.WithCandidateCode(err, pgcode.InvalidParameterValue)
   495  	}
   496  	return BitArray{
   497  		words:        words,
   498  		lastBitsUsed: uint8(lastBitsUsed),
   499  	}, nil
   500  }
   501  
   502  // mustFromEncodingParts is like FromEncodingParts but errors cause a panic.
   503  func mustFromEncodingParts(words []uint64, lastBitsUsed uint64) BitArray {
   504  	ba, err := FromEncodingParts(words, lastBitsUsed)
   505  	if err != nil {
   506  		panic(err)
   507  	}
   508  	return ba
   509  }
   510  
   511  // Rand generates a random bit array of the specified length.
   512  func Rand(rng *rand.Rand, bitLen uint) BitArray {
   513  	d := MakeZeroBitArray(bitLen)
   514  	for i := range d.words {
   515  		d.words[i] = rng.Uint64()
   516  	}
   517  	if len(d.words) > 0 {
   518  		d.words[len(d.words)-1] <<= (numBitsPerWord - d.lastBitsUsed)
   519  	}
   520  	return d
   521  }
   522  
   523  // Next returns the next possible bit array in lexicographic order.
   524  // The backing array of words is shared if possible.
   525  func Next(d BitArray) BitArray {
   526  	if d.lastBitsUsed == 0 {
   527  		return BitArray{words: []word{0}, lastBitsUsed: 1}
   528  	}
   529  	if d.lastBitsUsed < numBitsPerWord {
   530  		res := d
   531  		res.lastBitsUsed++
   532  		return res
   533  	}
   534  	res := BitArray{
   535  		words:        make([]word, len(d.words)+1),
   536  		lastBitsUsed: 1,
   537  	}
   538  	copy(res.words, d.words)
   539  	return res
   540  }
   541  
   542  // GetBitAtIndex extract bit at given index in the BitArray.
   543  func (d BitArray) GetBitAtIndex(index int) (int, error) {
   544  	// Check whether index asked is inside BitArray.
   545  	if index < 0 || uint(index) >= d.BitLen() {
   546  		err := fmt.Errorf("GetBitAtIndex: bit index %d out of valid range (0..%d)", index, int(d.BitLen())-1)
   547  		return 0, pgerror.WithCandidateCode(err, pgcode.ArraySubscript)
   548  	}
   549  	// To extract bit at the given index, we have to determine the
   550  	// position within words array, i.e. index/numBitsPerWord after
   551  	// that checked the bit at residual index.
   552  	if d.words[index/numBitsPerWord]&(word(1)<<(numBitsPerWord-1-uint(index)%numBitsPerWord)) != 0 {
   553  		return 1, nil
   554  	}
   555  	return 0, nil
   556  }
   557  
   558  // SetBitAtIndex returns the BitArray with an updated bit at a given index.
   559  func (d BitArray) SetBitAtIndex(index, toSet int) (BitArray, error) {
   560  	res := d.Clone()
   561  	// Check whether index asked is inside BitArray.
   562  	if index < 0 || uint(index) >= res.BitLen() {
   563  		err := fmt.Errorf("SetBitAtIndex: bit index %d out of valid range (0..%d)", index, int(res.BitLen())-1)
   564  		return BitArray{}, pgerror.WithCandidateCode(err, pgcode.ArraySubscript)
   565  	}
   566  	// To update bit at the given index, we have to determine the
   567  	// position within words array, i.e. index/numBitsPerWord after
   568  	// that updated the bit at residual index.
   569  	// Forcefully making bit at the index to 0.
   570  	res.words[index/numBitsPerWord] &= ^(word(1) << (numBitsPerWord - 1 - uint(index)%numBitsPerWord))
   571  	// Updating value at the index to toSet.
   572  	res.words[index/numBitsPerWord] |= word(toSet) << (numBitsPerWord - 1 - uint(index)%numBitsPerWord)
   573  	return res, nil
   574  }