github.com/evdatsion/aphelion-dpos-bft@v0.32.1/libs/common/bit_array.go (about)

     1  package common
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"regexp"
     7  	"strings"
     8  	"sync"
     9  )
    10  
// BitArray is a thread-safe implementation of a bit array.
//
// Bit i is stored in Elems[i/64] at bit position i%64 (least-significant bit
// first). Most exported methods lock mtx; the unexported helpers (getIndex,
// setIndex, copy, not, ...) do not lock and expect the caller to hold mtx.
// NOTE(review): bits at positions >= Bits in the last element are presumably
// meant to stay zero, since IsEmpty and Bytes read whole elements — confirm.
// Because it embeds a sync.Mutex, a BitArray must not be copied after use.
type BitArray struct {
	mtx   sync.Mutex
	Bits  int      `json:"bits"`  // NOTE: persisted via reflect, must be exported
	Elems []uint64 `json:"elems"` // NOTE: persisted via reflect, must be exported
}
    17  
    18  // NewBitArray returns a new bit array.
    19  // It returns nil if the number of bits is zero.
    20  func NewBitArray(bits int) *BitArray {
    21  	if bits <= 0 {
    22  		return nil
    23  	}
    24  	return &BitArray{
    25  		Bits:  bits,
    26  		Elems: make([]uint64, (bits+63)/64),
    27  	}
    28  }
    29  
    30  // Size returns the number of bits in the bitarray
    31  func (bA *BitArray) Size() int {
    32  	if bA == nil {
    33  		return 0
    34  	}
    35  	return bA.Bits
    36  }
    37  
    38  // GetIndex returns the bit at index i within the bit array.
    39  // The behavior is undefined if i >= bA.Bits
    40  func (bA *BitArray) GetIndex(i int) bool {
    41  	if bA == nil {
    42  		return false
    43  	}
    44  	bA.mtx.Lock()
    45  	defer bA.mtx.Unlock()
    46  	return bA.getIndex(i)
    47  }
    48  
    49  func (bA *BitArray) getIndex(i int) bool {
    50  	if i >= bA.Bits {
    51  		return false
    52  	}
    53  	return bA.Elems[i/64]&(uint64(1)<<uint(i%64)) > 0
    54  }
    55  
    56  // SetIndex sets the bit at index i within the bit array.
    57  // The behavior is undefined if i >= bA.Bits
    58  func (bA *BitArray) SetIndex(i int, v bool) bool {
    59  	if bA == nil {
    60  		return false
    61  	}
    62  	bA.mtx.Lock()
    63  	defer bA.mtx.Unlock()
    64  	return bA.setIndex(i, v)
    65  }
    66  
    67  func (bA *BitArray) setIndex(i int, v bool) bool {
    68  	if i >= bA.Bits {
    69  		return false
    70  	}
    71  	if v {
    72  		bA.Elems[i/64] |= (uint64(1) << uint(i%64))
    73  	} else {
    74  		bA.Elems[i/64] &= ^(uint64(1) << uint(i%64))
    75  	}
    76  	return true
    77  }
    78  
    79  // Copy returns a copy of the provided bit array.
    80  func (bA *BitArray) Copy() *BitArray {
    81  	if bA == nil {
    82  		return nil
    83  	}
    84  	bA.mtx.Lock()
    85  	defer bA.mtx.Unlock()
    86  	return bA.copy()
    87  }
    88  
    89  func (bA *BitArray) copy() *BitArray {
    90  	c := make([]uint64, len(bA.Elems))
    91  	copy(c, bA.Elems)
    92  	return &BitArray{
    93  		Bits:  bA.Bits,
    94  		Elems: c,
    95  	}
    96  }
    97  
    98  func (bA *BitArray) copyBits(bits int) *BitArray {
    99  	c := make([]uint64, (bits+63)/64)
   100  	copy(c, bA.Elems)
   101  	return &BitArray{
   102  		Bits:  bits,
   103  		Elems: c,
   104  	}
   105  }
   106  
   107  // Or returns a bit array resulting from a bitwise OR of the two bit arrays.
   108  // If the two bit-arrys have different lengths, Or right-pads the smaller of the two bit-arrays with zeroes.
   109  // Thus the size of the return value is the maximum of the two provided bit arrays.
   110  func (bA *BitArray) Or(o *BitArray) *BitArray {
   111  	if bA == nil && o == nil {
   112  		return nil
   113  	}
   114  	if bA == nil && o != nil {
   115  		return o.Copy()
   116  	}
   117  	if o == nil {
   118  		return bA.Copy()
   119  	}
   120  	bA.mtx.Lock()
   121  	o.mtx.Lock()
   122  	c := bA.copyBits(MaxInt(bA.Bits, o.Bits))
   123  	smaller := MinInt(len(bA.Elems), len(o.Elems))
   124  	for i := 0; i < smaller; i++ {
   125  		c.Elems[i] |= o.Elems[i]
   126  	}
   127  	bA.mtx.Unlock()
   128  	o.mtx.Unlock()
   129  	return c
   130  }
   131  
   132  // And returns a bit array resulting from a bitwise AND of the two bit arrays.
   133  // If the two bit-arrys have different lengths, this truncates the larger of the two bit-arrays from the right.
   134  // Thus the size of the return value is the minimum of the two provided bit arrays.
   135  func (bA *BitArray) And(o *BitArray) *BitArray {
   136  	if bA == nil || o == nil {
   137  		return nil
   138  	}
   139  	bA.mtx.Lock()
   140  	o.mtx.Lock()
   141  	defer func() {
   142  		bA.mtx.Unlock()
   143  		o.mtx.Unlock()
   144  	}()
   145  	return bA.and(o)
   146  }
   147  
   148  func (bA *BitArray) and(o *BitArray) *BitArray {
   149  	c := bA.copyBits(MinInt(bA.Bits, o.Bits))
   150  	for i := 0; i < len(c.Elems); i++ {
   151  		c.Elems[i] &= o.Elems[i]
   152  	}
   153  	return c
   154  }
   155  
   156  // Not returns a bit array resulting from a bitwise Not of the provided bit array.
   157  func (bA *BitArray) Not() *BitArray {
   158  	if bA == nil {
   159  		return nil // Degenerate
   160  	}
   161  	bA.mtx.Lock()
   162  	defer bA.mtx.Unlock()
   163  	return bA.not()
   164  }
   165  
   166  func (bA *BitArray) not() *BitArray {
   167  	c := bA.copy()
   168  	for i := 0; i < len(c.Elems); i++ {
   169  		c.Elems[i] = ^c.Elems[i]
   170  	}
   171  	return c
   172  }
   173  
   174  // Sub subtracts the two bit-arrays bitwise, without carrying the bits.
   175  // Note that carryless subtraction of a - b is (a and not b).
   176  // The output is the same as bA, regardless of o's size.
   177  // If bA is longer than o, o is right padded with zeroes
   178  func (bA *BitArray) Sub(o *BitArray) *BitArray {
   179  	if bA == nil || o == nil {
   180  		// TODO: Decide if we should do 1's complement here?
   181  		return nil
   182  	}
   183  	bA.mtx.Lock()
   184  	o.mtx.Lock()
   185  	// output is the same size as bA
   186  	c := bA.copyBits(bA.Bits)
   187  	// Only iterate to the minimum size between the two.
   188  	// If o is longer, those bits are ignored.
   189  	// If bA is longer, then skipping those iterations is equivalent
   190  	// to right padding with 0's
   191  	smaller := MinInt(len(bA.Elems), len(o.Elems))
   192  	for i := 0; i < smaller; i++ {
   193  		// &^ is and not in golang
   194  		c.Elems[i] &^= o.Elems[i]
   195  	}
   196  	bA.mtx.Unlock()
   197  	o.mtx.Unlock()
   198  	return c
   199  }
   200  
   201  // IsEmpty returns true iff all bits in the bit array are 0
   202  func (bA *BitArray) IsEmpty() bool {
   203  	if bA == nil {
   204  		return true // should this be opposite?
   205  	}
   206  	bA.mtx.Lock()
   207  	defer bA.mtx.Unlock()
   208  	for _, e := range bA.Elems {
   209  		if e > 0 {
   210  			return false
   211  		}
   212  	}
   213  	return true
   214  }
   215  
   216  // IsFull returns true iff all bits in the bit array are 1.
   217  func (bA *BitArray) IsFull() bool {
   218  	if bA == nil {
   219  		return true
   220  	}
   221  	bA.mtx.Lock()
   222  	defer bA.mtx.Unlock()
   223  
   224  	// Check all elements except the last
   225  	for _, elem := range bA.Elems[:len(bA.Elems)-1] {
   226  		if (^elem) != 0 {
   227  			return false
   228  		}
   229  	}
   230  
   231  	// Check that the last element has (lastElemBits) 1's
   232  	lastElemBits := (bA.Bits+63)%64 + 1
   233  	lastElem := bA.Elems[len(bA.Elems)-1]
   234  	return (lastElem+1)&((uint64(1)<<uint(lastElemBits))-1) == 0
   235  }
   236  
   237  // PickRandom returns a random index for a set bit in the bit array.
   238  // If there is no such value, it returns 0, false.
   239  // It uses the global randomness in `random.go` to get this index.
   240  func (bA *BitArray) PickRandom() (int, bool) {
   241  	if bA == nil {
   242  		return 0, false
   243  	}
   244  
   245  	bA.mtx.Lock()
   246  	trueIndices := bA.getTrueIndices()
   247  	bA.mtx.Unlock()
   248  
   249  	if len(trueIndices) == 0 { // no bits set to true
   250  		return 0, false
   251  	}
   252  
   253  	return trueIndices[RandIntn(len(trueIndices))], true
   254  }
   255  
   256  func (bA *BitArray) getTrueIndices() []int {
   257  	trueIndices := make([]int, 0, bA.Bits)
   258  	curBit := 0
   259  	numElems := len(bA.Elems)
   260  	// set all true indices
   261  	for i := 0; i < numElems-1; i++ {
   262  		elem := bA.Elems[i]
   263  		if elem == 0 {
   264  			curBit += 64
   265  			continue
   266  		}
   267  		for j := 0; j < 64; j++ {
   268  			if (elem & (uint64(1) << uint64(j))) > 0 {
   269  				trueIndices = append(trueIndices, curBit)
   270  			}
   271  			curBit++
   272  		}
   273  	}
   274  	// handle last element
   275  	lastElem := bA.Elems[numElems-1]
   276  	numFinalBits := bA.Bits - curBit
   277  	for i := 0; i < numFinalBits; i++ {
   278  		if (lastElem & (uint64(1) << uint64(i))) > 0 {
   279  			trueIndices = append(trueIndices, curBit)
   280  		}
   281  		curBit++
   282  	}
   283  	return trueIndices
   284  }
   285  
   286  // String returns a string representation of BitArray: BA{<bit-string>},
   287  // where <bit-string> is a sequence of 'x' (1) and '_' (0).
   288  // The <bit-string> includes spaces and newlines to help people.
   289  // For a simple sequence of 'x' and '_' characters with no spaces or newlines,
   290  // see the MarshalJSON() method.
   291  // Example: "BA{_x_}" or "nil-BitArray" for nil.
   292  func (bA *BitArray) String() string {
   293  	return bA.StringIndented("")
   294  }
   295  
   296  // StringIndented returns the same thing as String(), but applies the indent
   297  // at every 10th bit, and twice at every 50th bit.
   298  func (bA *BitArray) StringIndented(indent string) string {
   299  	if bA == nil {
   300  		return "nil-BitArray"
   301  	}
   302  	bA.mtx.Lock()
   303  	defer bA.mtx.Unlock()
   304  	return bA.stringIndented(indent)
   305  }
   306  
   307  func (bA *BitArray) stringIndented(indent string) string {
   308  	lines := []string{}
   309  	bits := ""
   310  	for i := 0; i < bA.Bits; i++ {
   311  		if bA.getIndex(i) {
   312  			bits += "x"
   313  		} else {
   314  			bits += "_"
   315  		}
   316  		if i%100 == 99 {
   317  			lines = append(lines, bits)
   318  			bits = ""
   319  		}
   320  		if i%10 == 9 {
   321  			bits += indent
   322  		}
   323  		if i%50 == 49 {
   324  			bits += indent
   325  		}
   326  	}
   327  	if len(bits) > 0 {
   328  		lines = append(lines, bits)
   329  	}
   330  	return fmt.Sprintf("BA{%v:%v}", bA.Bits, strings.Join(lines, indent))
   331  }
   332  
   333  // Bytes returns the byte representation of the bits within the bitarray.
   334  func (bA *BitArray) Bytes() []byte {
   335  	bA.mtx.Lock()
   336  	defer bA.mtx.Unlock()
   337  
   338  	numBytes := (bA.Bits + 7) / 8
   339  	bytes := make([]byte, numBytes)
   340  	for i := 0; i < len(bA.Elems); i++ {
   341  		elemBytes := [8]byte{}
   342  		binary.LittleEndian.PutUint64(elemBytes[:], bA.Elems[i])
   343  		copy(bytes[i*8:], elemBytes[:])
   344  	}
   345  	return bytes
   346  }
   347  
   348  // Update sets the bA's bits to be that of the other bit array.
   349  // The copying begins from the begin of both bit arrays.
   350  func (bA *BitArray) Update(o *BitArray) {
   351  	if bA == nil || o == nil {
   352  		return
   353  	}
   354  	bA.mtx.Lock()
   355  	o.mtx.Lock()
   356  	defer func() {
   357  		bA.mtx.Unlock()
   358  		o.mtx.Unlock()
   359  	}()
   360  
   361  	copy(bA.Elems, o.Elems)
   362  }
   363  
   364  // MarshalJSON implements json.Marshaler interface by marshaling bit array
   365  // using a custom format: a string of '-' or 'x' where 'x' denotes the 1 bit.
   366  func (bA *BitArray) MarshalJSON() ([]byte, error) {
   367  	if bA == nil {
   368  		return []byte("null"), nil
   369  	}
   370  
   371  	bA.mtx.Lock()
   372  	defer bA.mtx.Unlock()
   373  
   374  	bits := `"`
   375  	for i := 0; i < bA.Bits; i++ {
   376  		if bA.getIndex(i) {
   377  			bits += `x`
   378  		} else {
   379  			bits += `_`
   380  		}
   381  	}
   382  	bits += `"`
   383  	return []byte(bits), nil
   384  }
   385  
   386  var bitArrayJSONRegexp = regexp.MustCompile(`\A"([_x]*)"\z`)
   387  
   388  // UnmarshalJSON implements json.Unmarshaler interface by unmarshaling a custom
   389  // JSON description.
   390  func (bA *BitArray) UnmarshalJSON(bz []byte) error {
   391  	b := string(bz)
   392  	if b == "null" {
   393  		// This is required e.g. for encoding/json when decoding
   394  		// into a pointer with pre-allocated BitArray.
   395  		bA.Bits = 0
   396  		bA.Elems = nil
   397  		return nil
   398  	}
   399  
   400  	// Validate 'b'.
   401  	match := bitArrayJSONRegexp.FindStringSubmatch(b)
   402  	if match == nil {
   403  		return fmt.Errorf("BitArray in JSON should be a string of format %q but got %s", bitArrayJSONRegexp.String(), b)
   404  	}
   405  	bits := match[1]
   406  
   407  	// Construct new BitArray and copy over.
   408  	numBits := len(bits)
   409  	bA2 := NewBitArray(numBits)
   410  	for i := 0; i < numBits; i++ {
   411  		if bits[i] == 'x' {
   412  			bA2.SetIndex(i, true)
   413  		}
   414  	}
   415  	*bA = *bA2 //nolint:govet
   416  	return nil
   417  }