github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/libs/bits/bit_array.go (about)

     1  package bits
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"fmt"
     7  	"math"
     8  	"math/rand"
     9  	"regexp"
    10  	"strings"
    11  	"sync"
    12  
    13  	tmmath "github.com/ari-anchor/sei-tendermint/libs/math"
    14  	tmprotobits "github.com/ari-anchor/sei-tendermint/proto/tendermint/libs/bits"
    15  )
    16  
// BitArray is a thread-safe implementation of a bit array.
//
// Bit i is stored in Elems[i/64] at bit position i%64, counting from the
// least-significant bit. Exported methods such as GetIndex and SetIndex lock
// mtx and then delegate to the unexported helpers (getIndex, setIndex, ...),
// which are written to run with the lock already held.
type BitArray struct {
	// mtx serializes access to Bits and Elems in the locking methods.
	mtx   sync.Mutex
	Bits  int      `json:"bits"`  // NOTE: persisted via reflect, must be exported
	Elems []uint64 `json:"elems"` // NOTE: persisted via reflect, must be exported
}
    23  
    24  // NewBitArray returns a new bit array.
    25  // It returns nil if the number of bits is zero.
    26  func NewBitArray(bits int) *BitArray {
    27  	if bits <= 0 {
    28  		return nil
    29  	}
    30  	bA := &BitArray{}
    31  	bA.reset(bits)
    32  	return bA
    33  }
    34  
    35  // reset changes size of BitArray to `bits` and re-allocates (zeroed) data buffer
    36  func (bA *BitArray) reset(bits int) {
    37  	bA.mtx.Lock()
    38  	defer bA.mtx.Unlock()
    39  
    40  	bA.Bits = bits
    41  	if bits == 0 {
    42  		bA.Elems = nil
    43  	} else {
    44  		bA.Elems = make([]uint64, numElems(bits))
    45  	}
    46  }
    47  
    48  // Size returns the number of bits in the bitarray
    49  func (bA *BitArray) Size() int {
    50  	if bA == nil {
    51  		return 0
    52  	}
    53  	return bA.Bits
    54  }
    55  
    56  // GetIndex returns the bit at index i within the bit array.
    57  // The behavior is undefined if i >= bA.Bits
    58  func (bA *BitArray) GetIndex(i int) bool {
    59  	if bA == nil {
    60  		return false
    61  	}
    62  	bA.mtx.Lock()
    63  	defer bA.mtx.Unlock()
    64  	return bA.getIndex(i)
    65  }
    66  
    67  func (bA *BitArray) getIndex(i int) bool {
    68  	if i >= bA.Bits {
    69  		return false
    70  	}
    71  	return bA.Elems[i/64]&(uint64(1)<<uint(i%64)) > 0
    72  }
    73  
    74  // SetIndex sets the bit at index i within the bit array.
    75  // This method returns false if i is out of range of the BitArray.
    76  func (bA *BitArray) SetIndex(i int, v bool) bool {
    77  	if bA == nil {
    78  		return false
    79  	}
    80  	bA.mtx.Lock()
    81  	defer bA.mtx.Unlock()
    82  	return bA.setIndex(i, v)
    83  }
    84  
    85  func (bA *BitArray) setIndex(i int, v bool) bool {
    86  	if i < 0 || i >= bA.Bits {
    87  		return false
    88  	}
    89  	if v {
    90  		bA.Elems[i/64] |= (uint64(1) << uint(i%64))
    91  	} else {
    92  		bA.Elems[i/64] &= ^(uint64(1) << uint(i%64))
    93  	}
    94  	return true
    95  }
    96  
    97  // Copy returns a copy of the provided bit array.
    98  func (bA *BitArray) Copy() *BitArray {
    99  	if bA == nil {
   100  		return nil
   101  	}
   102  	bA.mtx.Lock()
   103  	defer bA.mtx.Unlock()
   104  	return bA.copy()
   105  }
   106  
   107  func (bA *BitArray) copy() *BitArray {
   108  	c := make([]uint64, len(bA.Elems))
   109  	copy(c, bA.Elems)
   110  	return &BitArray{
   111  		Bits:  bA.Bits,
   112  		Elems: c,
   113  	}
   114  }
   115  
   116  func (bA *BitArray) copyBits(bits int) *BitArray {
   117  	c := make([]uint64, numElems(bits))
   118  	copy(c, bA.Elems)
   119  	return &BitArray{
   120  		Bits:  bits,
   121  		Elems: c,
   122  	}
   123  }
   124  
   125  // Or returns a bit array resulting from a bitwise OR of the two bit arrays.
   126  // If the two bit-arrys have different lengths, Or right-pads the smaller of the two bit-arrays with zeroes.
   127  // Thus the size of the return value is the maximum of the two provided bit arrays.
   128  func (bA *BitArray) Or(o *BitArray) *BitArray {
   129  	if bA == nil && o == nil {
   130  		return nil
   131  	}
   132  	if bA == nil && o != nil {
   133  		return o.Copy()
   134  	}
   135  	if o == nil {
   136  		return bA.Copy()
   137  	}
   138  	bA.mtx.Lock()
   139  	o.mtx.Lock()
   140  	c := bA.copyBits(tmmath.MaxInt(bA.Bits, o.Bits))
   141  	smaller := tmmath.MinInt(len(bA.Elems), len(o.Elems))
   142  	for i := 0; i < smaller; i++ {
   143  		c.Elems[i] |= o.Elems[i]
   144  	}
   145  	bA.mtx.Unlock()
   146  	o.mtx.Unlock()
   147  	return c
   148  }
   149  
   150  // And returns a bit array resulting from a bitwise AND of the two bit arrays.
   151  // If the two bit-arrys have different lengths, this truncates the larger of the two bit-arrays from the right.
   152  // Thus the size of the return value is the minimum of the two provided bit arrays.
   153  func (bA *BitArray) And(o *BitArray) *BitArray {
   154  	if bA == nil || o == nil {
   155  		return nil
   156  	}
   157  	bA.mtx.Lock()
   158  	o.mtx.Lock()
   159  	defer func() {
   160  		bA.mtx.Unlock()
   161  		o.mtx.Unlock()
   162  	}()
   163  	return bA.and(o)
   164  }
   165  
   166  func (bA *BitArray) and(o *BitArray) *BitArray {
   167  	c := bA.copyBits(tmmath.MinInt(bA.Bits, o.Bits))
   168  	for i := 0; i < len(c.Elems); i++ {
   169  		c.Elems[i] &= o.Elems[i]
   170  	}
   171  	return c
   172  }
   173  
   174  // Not returns a bit array resulting from a bitwise Not of the provided bit array.
   175  func (bA *BitArray) Not() *BitArray {
   176  	if bA == nil {
   177  		return nil // Degenerate
   178  	}
   179  	bA.mtx.Lock()
   180  	defer bA.mtx.Unlock()
   181  	return bA.not()
   182  }
   183  
   184  func (bA *BitArray) not() *BitArray {
   185  	c := bA.copy()
   186  	for i := 0; i < len(c.Elems); i++ {
   187  		c.Elems[i] = ^c.Elems[i]
   188  	}
   189  	return c
   190  }
   191  
   192  // Sub subtracts the two bit-arrays bitwise, without carrying the bits.
   193  // Note that carryless subtraction of a - b is (a and not b).
   194  // The output is the same as bA, regardless of o's size.
   195  // If bA is longer than o, o is right padded with zeroes
   196  func (bA *BitArray) Sub(o *BitArray) *BitArray {
   197  	if bA == nil || o == nil {
   198  		// TODO: Decide if we should do 1's complement here?
   199  		return nil
   200  	}
   201  	bA.mtx.Lock()
   202  	o.mtx.Lock()
   203  	// output is the same size as bA
   204  	c := bA.copyBits(bA.Bits)
   205  	// Only iterate to the minimum size between the two.
   206  	// If o is longer, those bits are ignored.
   207  	// If bA is longer, then skipping those iterations is equivalent
   208  	// to right padding with 0's
   209  	smaller := tmmath.MinInt(len(bA.Elems), len(o.Elems))
   210  	for i := 0; i < smaller; i++ {
   211  		// &^ is and not in golang
   212  		c.Elems[i] &^= o.Elems[i]
   213  	}
   214  	bA.mtx.Unlock()
   215  	o.mtx.Unlock()
   216  	return c
   217  }
   218  
   219  // IsEmpty returns true iff all bits in the bit array are 0
   220  func (bA *BitArray) IsEmpty() bool {
   221  	if bA == nil {
   222  		return true // should this be opposite?
   223  	}
   224  	bA.mtx.Lock()
   225  	defer bA.mtx.Unlock()
   226  	for _, e := range bA.Elems {
   227  		if e > 0 {
   228  			return false
   229  		}
   230  	}
   231  	return true
   232  }
   233  
   234  // IsFull returns true iff all bits in the bit array are 1.
   235  func (bA *BitArray) IsFull() bool {
   236  	if bA == nil {
   237  		return true
   238  	}
   239  	bA.mtx.Lock()
   240  	defer bA.mtx.Unlock()
   241  
   242  	// Check all elements except the last
   243  	for _, elem := range bA.Elems[:len(bA.Elems)-1] {
   244  		if (^elem) != 0 {
   245  			return false
   246  		}
   247  	}
   248  
   249  	// Check that the last element has (lastElemBits) 1's
   250  	lastElemBits := (bA.Bits+63)%64 + 1
   251  	lastElem := bA.Elems[len(bA.Elems)-1]
   252  	return (lastElem+1)&((uint64(1)<<uint(lastElemBits))-1) == 0
   253  }
   254  
   255  // PickRandom returns a random index for a set bit in the bit array.
   256  // If there is no such value, it returns 0, false.
   257  // It uses math/rand's global randomness Source to get this index.
   258  func (bA *BitArray) PickRandom() (int, bool) {
   259  	if bA == nil {
   260  		return 0, false
   261  	}
   262  
   263  	bA.mtx.Lock()
   264  	trueIndices := bA.getTrueIndices()
   265  	bA.mtx.Unlock()
   266  
   267  	if len(trueIndices) == 0 { // no bits set to true
   268  		return 0, false
   269  	}
   270  
   271  	// NOTE: using the default math/rand might result in somewhat
   272  	// amount of determinism here. It would be possible to use
   273  	// rand.New(rand.NewSeed(time.Now().Unix())).Intn() to
   274  	// counteract this possibility if it proved to be material.
   275  	//
   276  	// nolint:gosec // G404: Use of weak random number generator
   277  	return trueIndices[rand.Intn(len(trueIndices))], true
   278  }
   279  
   280  func (bA *BitArray) getTrueIndices() []int {
   281  	trueIndices := make([]int, 0, bA.Bits)
   282  	curBit := 0
   283  	numElems := len(bA.Elems)
   284  	// set all true indices
   285  	for i := 0; i < numElems-1; i++ {
   286  		elem := bA.Elems[i]
   287  		if elem == 0 {
   288  			curBit += 64
   289  			continue
   290  		}
   291  		for j := 0; j < 64; j++ {
   292  			if (elem & (uint64(1) << uint64(j))) > 0 {
   293  				trueIndices = append(trueIndices, curBit)
   294  			}
   295  			curBit++
   296  		}
   297  	}
   298  	// handle last element
   299  	lastElem := bA.Elems[numElems-1]
   300  	numFinalBits := bA.Bits - curBit
   301  	for i := 0; i < numFinalBits; i++ {
   302  		if (lastElem & (uint64(1) << uint64(i))) > 0 {
   303  			trueIndices = append(trueIndices, curBit)
   304  		}
   305  		curBit++
   306  	}
   307  	return trueIndices
   308  }
   309  
   310  // String returns a string representation of BitArray: BA{<bit-string>},
   311  // where <bit-string> is a sequence of 'x' (1) and '_' (0).
   312  // The <bit-string> includes spaces and newlines to help people.
   313  // For a simple sequence of 'x' and '_' characters with no spaces or newlines,
   314  // see the MarshalJSON() method.
   315  // Example: "BA{_x_}" or "nil-BitArray" for nil.
   316  func (bA *BitArray) String() string {
   317  	return bA.StringIndented("")
   318  }
   319  
   320  // StringIndented returns the same thing as String(), but applies the indent
   321  // at every 10th bit, and twice at every 50th bit.
   322  func (bA *BitArray) StringIndented(indent string) string {
   323  	if bA == nil {
   324  		return "nil-BitArray"
   325  	}
   326  	bA.mtx.Lock()
   327  	defer bA.mtx.Unlock()
   328  	return bA.stringIndented(indent)
   329  }
   330  
   331  func (bA *BitArray) stringIndented(indent string) string {
   332  	lines := []string{}
   333  	bits := ""
   334  	for i := 0; i < bA.Bits; i++ {
   335  		if bA.getIndex(i) {
   336  			bits += "x"
   337  		} else {
   338  			bits += "_"
   339  		}
   340  		if i%100 == 99 {
   341  			lines = append(lines, bits)
   342  			bits = ""
   343  		}
   344  		if i%10 == 9 {
   345  			bits += indent
   346  		}
   347  		if i%50 == 49 {
   348  			bits += indent
   349  		}
   350  	}
   351  	if len(bits) > 0 {
   352  		lines = append(lines, bits)
   353  	}
   354  	return fmt.Sprintf("BA{%v:%v}", bA.Bits, strings.Join(lines, indent))
   355  }
   356  
   357  // Bytes returns the byte representation of the bits within the bitarray.
   358  func (bA *BitArray) Bytes() []byte {
   359  	bA.mtx.Lock()
   360  	defer bA.mtx.Unlock()
   361  
   362  	numBytes := (bA.Bits + 7) / 8
   363  	bytes := make([]byte, numBytes)
   364  	for i := 0; i < len(bA.Elems); i++ {
   365  		elemBytes := [8]byte{}
   366  		binary.LittleEndian.PutUint64(elemBytes[:], bA.Elems[i])
   367  		copy(bytes[i*8:], elemBytes[:])
   368  	}
   369  	return bytes
   370  }
   371  
   372  // Update sets the bA's bits to be that of the other bit array.
   373  // The copying begins from the begin of both bit arrays.
   374  func (bA *BitArray) Update(o *BitArray) {
   375  	if bA == nil || o == nil {
   376  		return
   377  	}
   378  
   379  	bA.mtx.Lock()
   380  	o.mtx.Lock()
   381  	copy(bA.Elems, o.Elems)
   382  	o.mtx.Unlock()
   383  	bA.mtx.Unlock()
   384  }
   385  
   386  // MarshalJSON implements json.Marshaler interface by marshaling bit array
   387  // using a custom format: a string of '-' or 'x' where 'x' denotes the 1 bit.
   388  func (bA *BitArray) MarshalJSON() ([]byte, error) {
   389  	if bA == nil {
   390  		return []byte("null"), nil
   391  	}
   392  
   393  	bA.mtx.Lock()
   394  	defer bA.mtx.Unlock()
   395  
   396  	bits := `"`
   397  	for i := 0; i < bA.Bits; i++ {
   398  		if bA.getIndex(i) {
   399  			bits += `x`
   400  		} else {
   401  			bits += `_`
   402  		}
   403  	}
   404  	bits += `"`
   405  	return []byte(bits), nil
   406  }
   407  
   408  var bitArrayJSONRegexp = regexp.MustCompile(`\A"([_x]*)"\z`)
   409  
   410  // UnmarshalJSON implements json.Unmarshaler interface by unmarshaling a custom
   411  // JSON description.
   412  func (bA *BitArray) UnmarshalJSON(bz []byte) error {
   413  	b := string(bz)
   414  	if b == "null" {
   415  		// This is required e.g. for encoding/json when decoding
   416  		// into a pointer with pre-allocated BitArray.
   417  		bA.reset(0)
   418  		return nil
   419  	}
   420  
   421  	// Validate 'b'.
   422  	match := bitArrayJSONRegexp.FindStringSubmatch(b)
   423  	if match == nil {
   424  		return fmt.Errorf("bitArray in JSON should be a string of format %q but got %s", bitArrayJSONRegexp.String(), b)
   425  	}
   426  	bits := match[1]
   427  	numBits := len(bits)
   428  
   429  	bA.reset(numBits)
   430  	for i := 0; i < numBits; i++ {
   431  		if bits[i] == 'x' {
   432  			bA.SetIndex(i, true)
   433  		}
   434  	}
   435  
   436  	return nil
   437  }
   438  
   439  // ToProto converts BitArray to protobuf. It returns nil if BitArray is
   440  // nil/empty.
   441  func (bA *BitArray) ToProto() *tmprotobits.BitArray {
   442  	if bA == nil ||
   443  		(len(bA.Elems) == 0 && bA.Bits == 0) { // empty
   444  		return nil
   445  	}
   446  
   447  	bA.mtx.Lock()
   448  	defer bA.mtx.Unlock()
   449  
   450  	bc := bA.copy()
   451  	return &tmprotobits.BitArray{Bits: int64(bc.Bits), Elems: bc.Elems}
   452  }
   453  
   454  // FromProto sets BitArray to the given protoBitArray. It returns an error if
   455  // protoBitArray is invalid.
   456  func (bA *BitArray) FromProto(protoBitArray *tmprotobits.BitArray) error {
   457  	if protoBitArray == nil {
   458  		return nil
   459  	}
   460  
   461  	// Validate protoBitArray.
   462  	if protoBitArray.Bits < 0 {
   463  		return errors.New("negative Bits")
   464  	}
   465  	// #[32bit]
   466  	if protoBitArray.Bits > math.MaxInt32 { // prevent overflow on 32bit systems
   467  		return errors.New("too many Bits")
   468  	}
   469  	if got, exp := len(protoBitArray.Elems), numElems(int(protoBitArray.Bits)); got != exp {
   470  		return fmt.Errorf("invalid number of Elems: got %d, but exp %d", got, exp)
   471  	}
   472  
   473  	bA.mtx.Lock()
   474  	defer bA.mtx.Unlock()
   475  
   476  	ec := make([]uint64, len(protoBitArray.Elems))
   477  	copy(ec, protoBitArray.Elems)
   478  
   479  	bA.Bits = int(protoBitArray.Bits)
   480  	bA.Elems = ec
   481  	return nil
   482  }
   483  
   484  func numElems(bits int) int {
   485  	return (bits + 63) / 64
   486  }