github.com/vipernet-xyz/tm@v0.34.24/libs/bits/bit_array.go (about)

     1  package bits
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"regexp"
     7  	"strings"
     8  	"sync"
     9  
    10  	tmmath "github.com/vipernet-xyz/tm/libs/math"
    11  	tmrand "github.com/vipernet-xyz/tm/libs/rand"
    12  	tmprotobits "github.com/vipernet-xyz/tm/proto/tendermint/libs/bits"
    13  )
    14  
    15  // BitArray is a thread-safe implementation of a bit array.
    16  type BitArray struct {
    17  	mtx   sync.Mutex
    18  	Bits  int      `json:"bits"`  // NOTE: persisted via reflect, must be exported
    19  	Elems []uint64 `json:"elems"` // NOTE: persisted via reflect, must be exported
    20  }
    21  
    22  // NewBitArray returns a new bit array.
    23  // It returns nil if the number of bits is zero.
    24  func NewBitArray(bits int) *BitArray {
    25  	if bits <= 0 {
    26  		return nil
    27  	}
    28  	return &BitArray{
    29  		Bits:  bits,
    30  		Elems: make([]uint64, (bits+63)/64),
    31  	}
    32  }
    33  
    34  // Size returns the number of bits in the bitarray
    35  func (bA *BitArray) Size() int {
    36  	if bA == nil {
    37  		return 0
    38  	}
    39  	return bA.Bits
    40  }
    41  
    42  // GetIndex returns the bit at index i within the bit array.
    43  // The behavior is undefined if i >= bA.Bits
    44  func (bA *BitArray) GetIndex(i int) bool {
    45  	if bA == nil {
    46  		return false
    47  	}
    48  	bA.mtx.Lock()
    49  	defer bA.mtx.Unlock()
    50  	return bA.getIndex(i)
    51  }
    52  
    53  func (bA *BitArray) getIndex(i int) bool {
    54  	if i >= bA.Bits {
    55  		return false
    56  	}
    57  	return bA.Elems[i/64]&(uint64(1)<<uint(i%64)) > 0
    58  }
    59  
    60  // SetIndex sets the bit at index i within the bit array.
    61  // The behavior is undefined if i >= bA.Bits
    62  func (bA *BitArray) SetIndex(i int, v bool) bool {
    63  	if bA == nil {
    64  		return false
    65  	}
    66  	bA.mtx.Lock()
    67  	defer bA.mtx.Unlock()
    68  	return bA.setIndex(i, v)
    69  }
    70  
    71  func (bA *BitArray) setIndex(i int, v bool) bool {
    72  	if i >= bA.Bits {
    73  		return false
    74  	}
    75  	if v {
    76  		bA.Elems[i/64] |= (uint64(1) << uint(i%64))
    77  	} else {
    78  		bA.Elems[i/64] &= ^(uint64(1) << uint(i%64))
    79  	}
    80  	return true
    81  }
    82  
    83  // Copy returns a copy of the provided bit array.
    84  func (bA *BitArray) Copy() *BitArray {
    85  	if bA == nil {
    86  		return nil
    87  	}
    88  	bA.mtx.Lock()
    89  	defer bA.mtx.Unlock()
    90  	return bA.copy()
    91  }
    92  
    93  func (bA *BitArray) copy() *BitArray {
    94  	c := make([]uint64, len(bA.Elems))
    95  	copy(c, bA.Elems)
    96  	return &BitArray{
    97  		Bits:  bA.Bits,
    98  		Elems: c,
    99  	}
   100  }
   101  
   102  func (bA *BitArray) copyBits(bits int) *BitArray {
   103  	c := make([]uint64, (bits+63)/64)
   104  	copy(c, bA.Elems)
   105  	return &BitArray{
   106  		Bits:  bits,
   107  		Elems: c,
   108  	}
   109  }
   110  
   111  // Or returns a bit array resulting from a bitwise OR of the two bit arrays.
   112  // If the two bit-arrys have different lengths, Or right-pads the smaller of the two bit-arrays with zeroes.
   113  // Thus the size of the return value is the maximum of the two provided bit arrays.
   114  func (bA *BitArray) Or(o *BitArray) *BitArray {
   115  	if bA == nil && o == nil {
   116  		return nil
   117  	}
   118  	if bA == nil && o != nil {
   119  		return o.Copy()
   120  	}
   121  	if o == nil {
   122  		return bA.Copy()
   123  	}
   124  	bA.mtx.Lock()
   125  	o.mtx.Lock()
   126  	c := bA.copyBits(tmmath.MaxInt(bA.Bits, o.Bits))
   127  	smaller := tmmath.MinInt(len(bA.Elems), len(o.Elems))
   128  	for i := 0; i < smaller; i++ {
   129  		c.Elems[i] |= o.Elems[i]
   130  	}
   131  	bA.mtx.Unlock()
   132  	o.mtx.Unlock()
   133  	return c
   134  }
   135  
   136  // And returns a bit array resulting from a bitwise AND of the two bit arrays.
   137  // If the two bit-arrys have different lengths, this truncates the larger of the two bit-arrays from the right.
   138  // Thus the size of the return value is the minimum of the two provided bit arrays.
   139  func (bA *BitArray) And(o *BitArray) *BitArray {
   140  	if bA == nil || o == nil {
   141  		return nil
   142  	}
   143  	bA.mtx.Lock()
   144  	o.mtx.Lock()
   145  	defer func() {
   146  		bA.mtx.Unlock()
   147  		o.mtx.Unlock()
   148  	}()
   149  	return bA.and(o)
   150  }
   151  
   152  func (bA *BitArray) and(o *BitArray) *BitArray {
   153  	c := bA.copyBits(tmmath.MinInt(bA.Bits, o.Bits))
   154  	for i := 0; i < len(c.Elems); i++ {
   155  		c.Elems[i] &= o.Elems[i]
   156  	}
   157  	return c
   158  }
   159  
   160  // Not returns a bit array resulting from a bitwise Not of the provided bit array.
   161  func (bA *BitArray) Not() *BitArray {
   162  	if bA == nil {
   163  		return nil // Degenerate
   164  	}
   165  	bA.mtx.Lock()
   166  	defer bA.mtx.Unlock()
   167  	return bA.not()
   168  }
   169  
   170  func (bA *BitArray) not() *BitArray {
   171  	c := bA.copy()
   172  	for i := 0; i < len(c.Elems); i++ {
   173  		c.Elems[i] = ^c.Elems[i]
   174  	}
   175  	return c
   176  }
   177  
   178  // Sub subtracts the two bit-arrays bitwise, without carrying the bits.
   179  // Note that carryless subtraction of a - b is (a and not b).
   180  // The output is the same as bA, regardless of o's size.
   181  // If bA is longer than o, o is right padded with zeroes
   182  func (bA *BitArray) Sub(o *BitArray) *BitArray {
   183  	if bA == nil || o == nil {
   184  		// TODO: Decide if we should do 1's complement here?
   185  		return nil
   186  	}
   187  	bA.mtx.Lock()
   188  	o.mtx.Lock()
   189  	// output is the same size as bA
   190  	c := bA.copyBits(bA.Bits)
   191  	// Only iterate to the minimum size between the two.
   192  	// If o is longer, those bits are ignored.
   193  	// If bA is longer, then skipping those iterations is equivalent
   194  	// to right padding with 0's
   195  	smaller := tmmath.MinInt(len(bA.Elems), len(o.Elems))
   196  	for i := 0; i < smaller; i++ {
   197  		// &^ is and not in golang
   198  		c.Elems[i] &^= o.Elems[i]
   199  	}
   200  	bA.mtx.Unlock()
   201  	o.mtx.Unlock()
   202  	return c
   203  }
   204  
   205  // IsEmpty returns true iff all bits in the bit array are 0
   206  func (bA *BitArray) IsEmpty() bool {
   207  	if bA == nil {
   208  		return true // should this be opposite?
   209  	}
   210  	bA.mtx.Lock()
   211  	defer bA.mtx.Unlock()
   212  	for _, e := range bA.Elems {
   213  		if e > 0 {
   214  			return false
   215  		}
   216  	}
   217  	return true
   218  }
   219  
   220  // IsFull returns true iff all bits in the bit array are 1.
   221  func (bA *BitArray) IsFull() bool {
   222  	if bA == nil {
   223  		return true
   224  	}
   225  	bA.mtx.Lock()
   226  	defer bA.mtx.Unlock()
   227  
   228  	// Check all elements except the last
   229  	for _, elem := range bA.Elems[:len(bA.Elems)-1] {
   230  		if (^elem) != 0 {
   231  			return false
   232  		}
   233  	}
   234  
   235  	// Check that the last element has (lastElemBits) 1's
   236  	lastElemBits := (bA.Bits+63)%64 + 1
   237  	lastElem := bA.Elems[len(bA.Elems)-1]
   238  	return (lastElem+1)&((uint64(1)<<uint(lastElemBits))-1) == 0
   239  }
   240  
   241  // PickRandom returns a random index for a set bit in the bit array.
   242  // If there is no such value, it returns 0, false.
   243  // It uses the global randomness in `random.go` to get this index.
   244  func (bA *BitArray) PickRandom() (int, bool) {
   245  	if bA == nil {
   246  		return 0, false
   247  	}
   248  
   249  	bA.mtx.Lock()
   250  	trueIndices := bA.getTrueIndices()
   251  	bA.mtx.Unlock()
   252  
   253  	if len(trueIndices) == 0 { // no bits set to true
   254  		return 0, false
   255  	}
   256  
   257  	return trueIndices[tmrand.Intn(len(trueIndices))], true
   258  }
   259  
   260  func (bA *BitArray) getTrueIndices() []int {
   261  	trueIndices := make([]int, 0, bA.Bits)
   262  	curBit := 0
   263  	numElems := len(bA.Elems)
   264  	// set all true indices
   265  	for i := 0; i < numElems-1; i++ {
   266  		elem := bA.Elems[i]
   267  		if elem == 0 {
   268  			curBit += 64
   269  			continue
   270  		}
   271  		for j := 0; j < 64; j++ {
   272  			if (elem & (uint64(1) << uint64(j))) > 0 {
   273  				trueIndices = append(trueIndices, curBit)
   274  			}
   275  			curBit++
   276  		}
   277  	}
   278  	// handle last element
   279  	lastElem := bA.Elems[numElems-1]
   280  	numFinalBits := bA.Bits - curBit
   281  	for i := 0; i < numFinalBits; i++ {
   282  		if (lastElem & (uint64(1) << uint64(i))) > 0 {
   283  			trueIndices = append(trueIndices, curBit)
   284  		}
   285  		curBit++
   286  	}
   287  	return trueIndices
   288  }
   289  
   290  // String returns a string representation of BitArray: BA{<bit-string>},
   291  // where <bit-string> is a sequence of 'x' (1) and '_' (0).
   292  // The <bit-string> includes spaces and newlines to help people.
   293  // For a simple sequence of 'x' and '_' characters with no spaces or newlines,
   294  // see the MarshalJSON() method.
   295  // Example: "BA{_x_}" or "nil-BitArray" for nil.
   296  func (bA *BitArray) String() string {
   297  	return bA.StringIndented("")
   298  }
   299  
   300  // StringIndented returns the same thing as String(), but applies the indent
   301  // at every 10th bit, and twice at every 50th bit.
   302  func (bA *BitArray) StringIndented(indent string) string {
   303  	if bA == nil {
   304  		return "nil-BitArray"
   305  	}
   306  	bA.mtx.Lock()
   307  	defer bA.mtx.Unlock()
   308  	return bA.stringIndented(indent)
   309  }
   310  
   311  func (bA *BitArray) stringIndented(indent string) string {
   312  	lines := []string{}
   313  	bits := ""
   314  	for i := 0; i < bA.Bits; i++ {
   315  		if bA.getIndex(i) {
   316  			bits += "x"
   317  		} else {
   318  			bits += "_"
   319  		}
   320  		if i%100 == 99 {
   321  			lines = append(lines, bits)
   322  			bits = ""
   323  		}
   324  		if i%10 == 9 {
   325  			bits += indent
   326  		}
   327  		if i%50 == 49 {
   328  			bits += indent
   329  		}
   330  	}
   331  	if len(bits) > 0 {
   332  		lines = append(lines, bits)
   333  	}
   334  	return fmt.Sprintf("BA{%v:%v}", bA.Bits, strings.Join(lines, indent))
   335  }
   336  
   337  // Bytes returns the byte representation of the bits within the bitarray.
   338  func (bA *BitArray) Bytes() []byte {
   339  	bA.mtx.Lock()
   340  	defer bA.mtx.Unlock()
   341  
   342  	numBytes := (bA.Bits + 7) / 8
   343  	bytes := make([]byte, numBytes)
   344  	for i := 0; i < len(bA.Elems); i++ {
   345  		elemBytes := [8]byte{}
   346  		binary.LittleEndian.PutUint64(elemBytes[:], bA.Elems[i])
   347  		copy(bytes[i*8:], elemBytes[:])
   348  	}
   349  	return bytes
   350  }
   351  
   352  // Update sets the bA's bits to be that of the other bit array.
   353  // The copying begins from the begin of both bit arrays.
   354  func (bA *BitArray) Update(o *BitArray) {
   355  	if bA == nil || o == nil {
   356  		return
   357  	}
   358  
   359  	bA.mtx.Lock()
   360  	o.mtx.Lock()
   361  	copy(bA.Elems, o.Elems)
   362  	o.mtx.Unlock()
   363  	bA.mtx.Unlock()
   364  }
   365  
   366  // MarshalJSON implements json.Marshaler interface by marshaling bit array
   367  // using a custom format: a string of '-' or 'x' where 'x' denotes the 1 bit.
   368  func (bA *BitArray) MarshalJSON() ([]byte, error) {
   369  	if bA == nil {
   370  		return []byte("null"), nil
   371  	}
   372  
   373  	bA.mtx.Lock()
   374  	defer bA.mtx.Unlock()
   375  
   376  	bits := `"`
   377  	for i := 0; i < bA.Bits; i++ {
   378  		if bA.getIndex(i) {
   379  			bits += `x`
   380  		} else {
   381  			bits += `_`
   382  		}
   383  	}
   384  	bits += `"`
   385  	return []byte(bits), nil
   386  }
   387  
   388  var bitArrayJSONRegexp = regexp.MustCompile(`\A"([_x]*)"\z`)
   389  
   390  // UnmarshalJSON implements json.Unmarshaler interface by unmarshaling a custom
   391  // JSON description.
   392  func (bA *BitArray) UnmarshalJSON(bz []byte) error {
   393  	b := string(bz)
   394  	if b == "null" {
   395  		// This is required e.g. for encoding/json when decoding
   396  		// into a pointer with pre-allocated BitArray.
   397  		bA.Bits = 0
   398  		bA.Elems = nil
   399  		return nil
   400  	}
   401  
   402  	// Validate 'b'.
   403  	match := bitArrayJSONRegexp.FindStringSubmatch(b)
   404  	if match == nil {
   405  		return fmt.Errorf("bitArray in JSON should be a string of format %q but got %s", bitArrayJSONRegexp.String(), b)
   406  	}
   407  	bits := match[1]
   408  
   409  	// Construct new BitArray and copy over.
   410  	numBits := len(bits)
   411  	bA2 := NewBitArray(numBits)
   412  	for i := 0; i < numBits; i++ {
   413  		if bits[i] == 'x' {
   414  			bA2.SetIndex(i, true)
   415  		}
   416  	}
   417  	*bA = *bA2 //nolint:govet
   418  	return nil
   419  }
   420  
   421  // ToProto converts BitArray to protobuf
   422  func (bA *BitArray) ToProto() *tmprotobits.BitArray {
   423  	if bA == nil || len(bA.Elems) == 0 {
   424  		return nil
   425  	}
   426  
   427  	return &tmprotobits.BitArray{
   428  		Bits:  int64(bA.Bits),
   429  		Elems: bA.Elems,
   430  	}
   431  }
   432  
   433  // FromProto sets a protobuf BitArray to the given pointer.
   434  func (bA *BitArray) FromProto(protoBitArray *tmprotobits.BitArray) {
   435  	if protoBitArray == nil {
   436  		bA = nil
   437  		return
   438  	}
   439  
   440  	bA.Bits = int(protoBitArray.Bits)
   441  	if len(protoBitArray.Elems) > 0 {
   442  		bA.Elems = protoBitArray.Elems
   443  	}
   444  }