github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/bitarray/bitarray.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package bitarray 12 13 import ( 14 "bytes" 15 "fmt" 16 "math/rand" 17 "unsafe" 18 19 "github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgcode" 20 "github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgerror" 21 "github.com/cockroachdb/errors" 22 ) 23 24 // BitArray implements a bit string of arbitrary length. 25 // 26 // This uses a packed encoding (i.e. groups of 64 bits at a time) for 27 // memory efficiency and speed of bitwise operations (enables use of 28 // full machine registers for comparisons and logical operations), 29 // akin to the big.nat type. 30 // 31 // There is something fancy needed to handle sorting values properly: 32 // the last group of bits must be padded right (start on the MSB) 33 // inside its word to compare properly according to pg semantics. 34 // 35 // This type is designed for immutable instances. The functions and 36 // methods defined below never write to a bit array in-place. Of note, 37 // the ToWidth() and Next() functions will share the backing array 38 // between their operand and their result in some cases. 39 // 40 // For portability, the size of the backing word is guaranteed to be 64 41 // bits. 42 type BitArray struct { 43 // words is the backing array. 44 // 45 // The leftmost bits in the literal representation are placed in the 46 // MSB of each word. 47 // 48 // The last word contain the rightmost bits in the literal 49 // representation, right-padded. For example if there are 3 bits 50 // to store, the 3 MSB bits of the last word will be set and the 51 // remaining LSB bits will be set to zero. 52 // 53 // The number of stored bits is actually: 54 // 0 if lastBitsUsed = 0 or len(word) == 0 55 // otherwise, (len(words)-1)*numBitsPerWord + lastBitsUsed 56 // 57 // TODO(jutin, nathan): consider using the trick in bytes.Buffer of 58 // keeping a static [1]word which word can initially point to to 59 // avoid heap allocations in the common case of small arrays. 60 words []word 61 62 // lastBitsUsed is the number of bits in the last word that 63 // participate in the value stored. It can only be zero 64 // for empty bit arrays; otherwise it's always between 1 and 65 // numBitsPerWord. 66 // 67 // For example: 68 // - 0 bits in array: len(words) == 0, lastBitsUsed = 0 69 // - 1 bits in array: len(words) == 1, lastBitsUsed = 1 70 // - 64 bits in array: len(words) == 1, lastBitsUsed = 64 71 // - 65 bits in array: len(words) == 2, lastBitsUsed = 1 72 lastBitsUsed uint8 73 } 74 75 type word = uint64 76 77 const numBytesPerWord = 8 78 const numBitsPerWord = 64 79 80 // BitLen returns the number of bits stored. 81 func (d BitArray) BitLen() uint { 82 if len(d.words) == 0 { 83 return 0 84 } 85 return d.nonEmptyBitLen() 86 } 87 88 func (d BitArray) nonEmptyBitLen() uint { 89 return uint(len(d.words)-1)*numBitsPerWord + uint(d.lastBitsUsed) 90 } 91 92 // String implements the fmt.Stringer interface. 93 func (d BitArray) String() string { 94 var buf bytes.Buffer 95 d.Format(&buf) 96 return buf.String() 97 } 98 99 // Clone makes a copy of the bit array. 100 func (d BitArray) Clone() BitArray { 101 return BitArray{ 102 words: append([]word(nil), d.words...), 103 lastBitsUsed: d.lastBitsUsed, 104 } 105 } 106 107 // MakeZeroBitArray creates a bit array with the specified bit size. 108 func MakeZeroBitArray(bitLen uint) BitArray { 109 a, b := EncodingPartsForBitLen(bitLen) 110 return mustFromEncodingParts(a, b) 111 } 112 113 // ToWidth resizes the bit array to the specified size. 114 // If the specified width is shorter, bits on the right are truncated away. 115 // If the specified width is larger, zero bits are added on the right. 116 func (d BitArray) ToWidth(desiredLen uint) BitArray { 117 bitlen := d.BitLen() 118 if bitlen == desiredLen { 119 // Nothing to do; fast path. 120 return d 121 } 122 if desiredLen == 0 { 123 // Nothing to do; fast path. 124 return BitArray{} 125 } 126 if desiredLen < bitlen { 127 // Destructive, we have to copy. 128 words, lastBitsUsed := EncodingPartsForBitLen(desiredLen) 129 copy(words, d.words[:len(words)]) 130 words[len(words)-1] &= (^word(0) << (numBitsPerWord - lastBitsUsed)) 131 return mustFromEncodingParts(words, lastBitsUsed) 132 } 133 134 // New length is larger. 135 numWords, lastBitsUsed := SizesForBitLen(desiredLen) 136 var words []word 137 if numWords <= uint(cap(d.words)) { 138 words = d.words[0:numWords] 139 } else { 140 words = make([]word, numWords) 141 copy(words, d.words) 142 } 143 return mustFromEncodingParts(words, lastBitsUsed) 144 } 145 146 // Sizeof returns the size in bytes of the bit array and its components. 147 func (d BitArray) Sizeof() uintptr { 148 return unsafe.Sizeof(d) + uintptr(numBytesPerWord*cap(d.words)) 149 } 150 151 // IsEmpty returns true iff the array is empty. 152 func (d BitArray) IsEmpty() bool { 153 return d.lastBitsUsed == 0 154 } 155 156 // MakeBitArrayFromInt64 creates a bit array with the specified 157 // size. The bits from the integer are written to the right of the bit 158 // array and the sign bit is extended. 159 func MakeBitArrayFromInt64(bitLen uint, val int64, valWidth uint) BitArray { 160 if bitLen == 0 { 161 return BitArray{} 162 } 163 d := MakeZeroBitArray(bitLen) 164 if bitLen < valWidth { 165 // Fast path, no sign extension to compute. 166 d.words[len(d.words)-1] = word(val << (numBitsPerWord - bitLen)) 167 return d 168 } 169 if val&(1<<(valWidth-1)) != 0 { 170 // Sign extend, fill ones in every word but the last. 171 for i := 0; i < len(d.words)-1; i++ { 172 d.words[i] = ^word(0) 173 } 174 } 175 // Shift the value to its given number of bits, to position the sign 176 // bit to the left. 177 val = val << (numBitsPerWord - valWidth) 178 // Shift right back with arithmetic shift to extend the sign bit. 179 val = val >> (numBitsPerWord - valWidth) 180 // Store the right part of the value in the last word. 181 d.words[len(d.words)-1] = word(val << (numBitsPerWord - d.lastBitsUsed)) 182 // Store the left part in the next-to-last word, if any. 183 if valWidth > uint(d.lastBitsUsed) { 184 d.words[len(d.words)-2] = word(val >> d.lastBitsUsed) 185 } 186 return d 187 } 188 189 // AsInt64 returns the int constituted from the rightmost bits in the 190 // bit array. 191 func (d BitArray) AsInt64(nbits uint) int64 { 192 if d.lastBitsUsed == 0 { 193 // Fast path. 194 return 0 195 } 196 197 lowPart := d.words[len(d.words)-1] >> (numBitsPerWord - d.lastBitsUsed) 198 highPart := word(0) 199 if nbits > uint(d.lastBitsUsed) && len(d.words) > 1 { 200 highPart = d.words[len(d.words)-2] << d.lastBitsUsed 201 } 202 combined := lowPart | highPart 203 signExtended := int64(combined<<(numBitsPerWord-nbits)) >> (numBitsPerWord - nbits) 204 return signExtended 205 } 206 207 // LeftShiftAny performs a logical left shift, with a possible 208 // negative count. 209 // The number of bits to shift can be arbitrarily large (i.e. possibly 210 // larger than 64 in absolute value). 211 func (d BitArray) LeftShiftAny(n int64) BitArray { 212 bitlen := d.BitLen() 213 if n == 0 || bitlen == 0 { 214 // Fast path. 215 return d 216 } 217 218 r := MakeZeroBitArray(bitlen) 219 if (n > 0 && n > int64(bitlen)) || (n < 0 && -n > int64(bitlen)) { 220 // Fast path. 221 return r 222 } 223 224 if n > 0 { 225 // This is a left shift. 226 dstWord := uint(0) 227 srcWord := uint(uint64(n) / numBitsPerWord) 228 srcShift := uint(uint64(n) % numBitsPerWord) 229 for i, j := srcWord, dstWord; i < uint(len(d.words)); i++ { 230 r.words[j] = d.words[i] << srcShift 231 j++ 232 } 233 for i, j := srcWord+1, dstWord; i < uint(len(d.words)); i++ { 234 r.words[j] |= d.words[i] >> (numBitsPerWord - srcShift) 235 j++ 236 } 237 } else { 238 // A right shift. 239 n = -n 240 srcWord := uint(0) 241 dstWord := uint(uint64(n) / numBitsPerWord) 242 srcShift := uint(uint64(n) % numBitsPerWord) 243 for i, j := srcWord, dstWord; j < uint(len(r.words)); i++ { 244 r.words[j] = d.words[i] >> srcShift 245 j++ 246 } 247 for i, j := srcWord, dstWord+1; j < uint(len(r.words)); i++ { 248 r.words[j] |= d.words[i] << (numBitsPerWord - srcShift) 249 j++ 250 } 251 // Erase the trailing bits that are not used any more. 252 // See #36606. 253 if len(r.words) > 0 { 254 r.words[len(r.words)-1] &= ^word(0) << (numBitsPerWord - r.lastBitsUsed) 255 } 256 } 257 258 return r 259 } 260 261 // byteReprs contains the bit representation of the 256 possible 262 // groups of 8 bits. 263 var byteReprs = func() (ret [256]string) { 264 for i := range ret { 265 // Change this format if numBitsPerWord changes. 266 ret[i] = fmt.Sprintf("%08b", i) 267 } 268 return ret 269 }() 270 271 // Format prints out the bit array to the buffer. 272 func (d BitArray) Format(buf *bytes.Buffer) { 273 bitLen := d.BitLen() 274 buf.Grow(int(bitLen)) 275 for i := uint(0); i < bitLen/numBitsPerWord; i++ { 276 w := d.words[i] 277 // Change this loop if numBitsPerWord changes. 278 buf.WriteString(byteReprs[(w>>56)&0xff]) 279 buf.WriteString(byteReprs[(w>>48)&0xff]) 280 buf.WriteString(byteReprs[(w>>40)&0xff]) 281 buf.WriteString(byteReprs[(w>>32)&0xff]) 282 buf.WriteString(byteReprs[(w>>24)&0xff]) 283 buf.WriteString(byteReprs[(w>>16)&0xff]) 284 buf.WriteString(byteReprs[(w>>8)&0xff]) 285 buf.WriteString(byteReprs[(w>>0)&0xff]) 286 } 287 remainingBits := bitLen % numBitsPerWord 288 if remainingBits > 0 { 289 lastWord := d.words[bitLen/numBitsPerWord] 290 minShift := numBitsPerWord - 1 - remainingBits 291 for i := numBitsPerWord - 1; i > int(minShift); i-- { 292 bitVal := (lastWord >> uint(i)) & 1 293 buf.WriteByte('0' + byte(bitVal)) 294 } 295 } 296 } 297 298 // EncodingPartsForBitLen creates a word backing array and the 299 // "last bits used" value given the given total number of bits. 300 func EncodingPartsForBitLen(bitLen uint) ([]uint64, uint64) { 301 if bitLen == 0 { 302 return nil, 0 303 } 304 numWords, lastBitsUsed := SizesForBitLen(bitLen) 305 words := make([]word, numWords) 306 return words, lastBitsUsed 307 } 308 309 // SizesForBitLen computes the number of words and last bits used for 310 // the requested bit array size. 311 func SizesForBitLen(bitLen uint) (uint, uint64) { 312 // This computes ceil(bitLen / numBitsPerWord). 313 numWords := (bitLen + numBitsPerWord - 1) / numBitsPerWord 314 lastBitsUsed := uint64(bitLen % numBitsPerWord) 315 if lastBitsUsed == 0 { 316 lastBitsUsed = numBitsPerWord 317 } 318 return numWords, lastBitsUsed 319 } 320 321 // Parse parses a bit array from the specified string. 322 func Parse(s string) (res BitArray, err error) { 323 if len(s) == 0 { 324 return res, nil 325 } 326 327 if s[0] == 'x' || s[0] == 'X' { 328 return parseFromHex(s[1:]) 329 } 330 return parseFromBinary(s) 331 } 332 333 func parseFromBinary(s string) (res BitArray, err error) { 334 words, lastBitsUsed := EncodingPartsForBitLen(uint(len(s))) 335 336 // Parse the bits. 337 wordIdx := 0 338 bitIdx := uint(0) 339 curWord := word(0) 340 for _, c := range s { 341 val := word(c - '0') 342 bitVal := val & 1 343 if bitVal != val { 344 // Note: the prefix "could not parse" is important as it is used 345 // to detect parsing errors in tests. 346 err := fmt.Errorf(`could not parse string as bit array: "%c" is not a valid binary digit`, c) 347 return res, pgerror.WithCandidateCode(err, pgcode.InvalidTextRepresentation) 348 } 349 curWord |= bitVal << (63 - bitIdx) 350 bitIdx = (bitIdx + 1) % numBitsPerWord 351 if bitIdx == 0 { 352 words[wordIdx] = curWord 353 curWord = 0 354 wordIdx++ 355 } 356 } 357 if bitIdx > 0 { 358 // Ensure the last word is stored. 359 words[wordIdx] = curWord 360 } 361 362 return FromEncodingParts(words, lastBitsUsed) 363 } 364 365 func parseFromHex(s string) (res BitArray, err error) { 366 words, lastBitsUsed := EncodingPartsForBitLen(uint(len(s)) * 4) 367 368 // Parse the bits. 369 wordIdx := 0 370 bitIdx := uint(0) 371 curWord := word(0) 372 for _, c := range s { 373 var bitVal word 374 if c >= '0' && c <= '9' { 375 bitVal = word(c - '0') 376 } else if c >= 'a' && c <= 'f' { 377 bitVal = word(c-'a') + 10 378 } else if c >= 'A' && c <= 'F' { 379 bitVal = word(c-'A') + 10 380 } else { 381 // Note: the prefix "could not parse" is important as it is used 382 // to detect parsing errors in tests. 383 err := fmt.Errorf(`could not parse string as bit array: "%c" is not a valid hexadecimal digit`, c) 384 return res, pgerror.WithCandidateCode(err, pgcode.InvalidTextRepresentation) 385 } 386 curWord |= bitVal << (60 - bitIdx) 387 bitIdx = (bitIdx + 4) % numBitsPerWord 388 if bitIdx == 0 { 389 words[wordIdx] = curWord 390 curWord = 0 391 wordIdx++ 392 } 393 } 394 if bitIdx > 0 { 395 // Ensure the last word is stored. 396 words[wordIdx] = curWord 397 } 398 399 return FromEncodingParts(words, lastBitsUsed) 400 } 401 402 // Concat concatenates two bit arrays. 403 func Concat(lhs, rhs BitArray) BitArray { 404 if lhs.lastBitsUsed == 0 { 405 return rhs 406 } 407 if rhs.lastBitsUsed == 0 { 408 return lhs 409 } 410 words := make([]word, (lhs.nonEmptyBitLen()+rhs.nonEmptyBitLen()+numBitsPerWord-1)/numBitsPerWord) 411 412 // The first bits come from the lhs unchanged. 413 copy(words, lhs.words) 414 var lastBitsUsed uint8 415 if lhs.lastBitsUsed == numBitsPerWord { 416 // Fast path. Just concatenate. 417 copy(words[len(lhs.words):], rhs.words) 418 lastBitsUsed = rhs.lastBitsUsed 419 } else { 420 // We need to shift all the words in the RHS 421 // by the lastBitsUsed of the LHS. 422 rhsShift := lhs.lastBitsUsed 423 targetWordIdx := len(lhs.words) - 1 424 trailingBits := words[targetWordIdx] 425 for _, w := range rhs.words { 426 headingBits := w >> rhsShift 427 combinedBits := trailingBits | headingBits 428 words[targetWordIdx] = combinedBits 429 targetWordIdx++ 430 trailingBits = w << (numBitsPerWord - rhsShift) 431 } 432 lastBitsUsed = lhs.lastBitsUsed + rhs.lastBitsUsed 433 if lastBitsUsed > numBitsPerWord { 434 // Some bits from the RHS didn't fill a 435 // word, we need to fit them in the last word. 436 words[targetWordIdx] = trailingBits 437 } 438 439 // Compute the final thing. 440 lastBitsUsed %= numBitsPerWord 441 if lastBitsUsed == 0 { 442 lastBitsUsed = numBitsPerWord 443 } 444 } 445 return BitArray{words: words, lastBitsUsed: lastBitsUsed} 446 } 447 448 // Not computes the complement of a bit array. 449 func Not(d BitArray) BitArray { 450 res := d.Clone() 451 for i, w := range res.words { 452 res.words[i] = ^w 453 } 454 if res.lastBitsUsed > 0 { 455 lastWord := len(res.words) - 1 456 res.words[lastWord] &= (^word(0) << (numBitsPerWord - res.lastBitsUsed)) 457 } 458 return res 459 } 460 461 // And computes the logical AND of two bit arrays. 462 // The caller must ensure they have the same bit size. 463 func And(lhs, rhs BitArray) BitArray { 464 res := lhs.Clone() 465 for i, w := range rhs.words { 466 res.words[i] &= w 467 } 468 return res 469 } 470 471 // Or computes the logical OR of two bit arrays. 472 // The caller must ensure they have the same bit size. 473 func Or(lhs, rhs BitArray) BitArray { 474 res := lhs.Clone() 475 for i, w := range rhs.words { 476 res.words[i] |= w 477 } 478 return res 479 } 480 481 // Xor computes the logical XOR of two bit arrays. 482 // The caller must ensure they have the same bit size. 483 func Xor(lhs, rhs BitArray) BitArray { 484 res := lhs.Clone() 485 for i, w := range rhs.words { 486 res.words[i] ^= w 487 } 488 return res 489 } 490 491 // Compare compares two bit arrays. They can have mixed sizes. 492 func Compare(lhs, rhs BitArray) int { 493 n := len(lhs.words) 494 if n > len(rhs.words) { 495 n = len(rhs.words) 496 } 497 i := 0 498 for ; i < n; i++ { 499 lw := lhs.words[i] 500 rw := rhs.words[i] 501 if lw < rw { 502 return -1 503 } 504 if lw > rw { 505 return 1 506 } 507 } 508 if i < len(rhs.words) { 509 // lhs is shorter. 510 return -1 511 } 512 if i < len(lhs.words) { 513 // rhs is shorter. 514 return 1 515 } 516 // Same length. 517 if lhs.lastBitsUsed < rhs.lastBitsUsed { 518 return -1 519 } 520 if lhs.lastBitsUsed > rhs.lastBitsUsed { 521 return 1 522 } 523 return 0 524 } 525 526 // EncodingParts retrieves the encoding bits from the bit array. The 527 // words are presented in big-endian order, with the leftmost bits of 528 // the bitarray (MSB) in the MSB of each word. 529 func (d BitArray) EncodingParts() ([]uint64, uint64) { 530 return d.words, uint64(d.lastBitsUsed) 531 } 532 533 // FromEncodingParts creates a bit array from the encoding parts. 534 func FromEncodingParts(words []uint64, lastBitsUsed uint64) (BitArray, error) { 535 if lastBitsUsed > numBitsPerWord { 536 err := fmt.Errorf("FromEncodingParts: lastBitsUsed must not exceed %d, got %d", 537 errors.Safe(numBitsPerWord), errors.Safe(lastBitsUsed)) 538 return BitArray{}, pgerror.WithCandidateCode(err, pgcode.InvalidParameterValue) 539 } 540 return BitArray{ 541 words: words, 542 lastBitsUsed: uint8(lastBitsUsed), 543 }, nil 544 } 545 546 // mustFromEncodingParts is like FromEncodingParts but errors cause a panic. 547 func mustFromEncodingParts(words []uint64, lastBitsUsed uint64) BitArray { 548 ba, err := FromEncodingParts(words, lastBitsUsed) 549 if err != nil { 550 panic(err) 551 } 552 return ba 553 } 554 555 // Rand generates a random bit array of the specified length. 556 func Rand(rng *rand.Rand, bitLen uint) BitArray { 557 d := MakeZeroBitArray(bitLen) 558 for i := range d.words { 559 d.words[i] = rng.Uint64() 560 } 561 if len(d.words) > 0 { 562 d.words[len(d.words)-1] <<= (numBitsPerWord - d.lastBitsUsed) 563 } 564 return d 565 } 566 567 // Next returns the next possible bit array in lexicographic order. 568 // The backing array of words is shared if possible. 569 func Next(d BitArray) BitArray { 570 if d.lastBitsUsed == 0 { 571 return BitArray{words: []word{0}, lastBitsUsed: 1} 572 } 573 if d.lastBitsUsed < numBitsPerWord { 574 res := d 575 res.lastBitsUsed++ 576 return res 577 } 578 res := BitArray{ 579 words: make([]word, len(d.words)+1), 580 lastBitsUsed: 1, 581 } 582 copy(res.words, d.words) 583 return res 584 } 585 586 // GetBitAtIndex extract bit at given index in the BitArray. 587 func (d BitArray) GetBitAtIndex(index int) (int, error) { 588 // Check whether index asked is inside BitArray. 589 if index < 0 || uint(index) >= d.BitLen() { 590 err := fmt.Errorf("GetBitAtIndex: bit index %d out of valid range (0..%d)", index, int(d.BitLen())-1) 591 return 0, pgerror.WithCandidateCode(err, pgcode.ArraySubscript) 592 } 593 // To extract bit at the given index, we have to determine the 594 // position within words array, i.e. index/numBitsPerWord after 595 // that checked the bit at residual index. 596 if d.words[index/numBitsPerWord]&(word(1)<<(numBitsPerWord-1-uint(index)%numBitsPerWord)) != 0 { 597 return 1, nil 598 } 599 return 0, nil 600 } 601 602 // SetBitAtIndex returns the BitArray with an updated bit at a given index. 603 func (d BitArray) SetBitAtIndex(index, toSet int) (BitArray, error) { 604 res := d.Clone() 605 // Check whether index asked is inside BitArray. 606 if index < 0 || uint(index) >= res.BitLen() { 607 err := fmt.Errorf("SetBitAtIndex: bit index %d out of valid range (0..%d)", index, int(res.BitLen())-1) 608 return BitArray{}, pgerror.WithCandidateCode(err, pgcode.ArraySubscript) 609 } 610 // To update bit at the given index, we have to determine the 611 // position within words array, i.e. index/numBitsPerWord after 612 // that updated the bit at residual index. 613 // Forcefully making bit at the index to 0. 614 res.words[index/numBitsPerWord] &= ^(word(1) << (numBitsPerWord - 1 - uint(index)%numBitsPerWord)) 615 // Updating value at the index to toSet. 616 res.words[index/numBitsPerWord] |= word(toSet) << (numBitsPerWord - 1 - uint(index)%numBitsPerWord) 617 return res, nil 618 } 619 620 // AsUInt64 returns the uint64 constituted from the rightmost bits in the 621 // bit array. 622 func (d *BitArray) AsUInt64() uint64 { 623 if len(d.words) == 0 { 624 return 0 625 } 626 627 lowPart := d.words[len(d.words)-1] >> (numBitsPerWord - d.lastBitsUsed) 628 highPart := word(0) 629 if len(d.words) > 1 { 630 highPart = d.words[len(d.words)-2] << d.lastBitsUsed 631 } 632 return lowPart | highPart 633 }