github.com/gnolang/gno@v0.0.0-20240520182011-228e9d0192ce/tm2/pkg/bitarray/bit_array.go (about)

     1  package bitarray
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"regexp"
     7  	"strings"
     8  	"sync"
     9  
    10  	"github.com/gnolang/gno/tm2/pkg/random"
    11  )
    12  
// BitArray is a thread-safe implementation of a bit array.
//
// Bit i is stored in Elems[i/64] at bit position i%64 (least significant bit
// first). NewBitArray allocates (Bits+63)/64 words for Elems.
// Most exported methods lock mtx around their access to Bits and Elems; the
// unexported helpers do not lock and are invoked with mtx already held.
type BitArray struct {
	mtx   sync.Mutex
	Bits  int      `json:"bits"`  // NOTE: persisted via reflect, must be exported
	Elems []uint64 `json:"elems"` // NOTE: persisted via reflect, must be exported
}
    19  
    20  // NewBitArray returns a new bit array.
    21  // It returns nil if the number of bits is zero.
    22  func NewBitArray(bits int) *BitArray {
    23  	if bits <= 0 {
    24  		return nil
    25  	}
    26  	return &BitArray{
    27  		Bits:  bits,
    28  		Elems: make([]uint64, (bits+63)/64),
    29  	}
    30  }
    31  
    32  // Size returns the number of bits in the bitarray
    33  func (bA *BitArray) Size() int {
    34  	if bA == nil {
    35  		return 0
    36  	}
    37  	return bA.Bits
    38  }
    39  
    40  // GetIndex returns the bit at index i within the bit array.
    41  // The behavior is undefined if i >= bA.Bits
    42  func (bA *BitArray) GetIndex(i int) bool {
    43  	if bA == nil {
    44  		return false
    45  	}
    46  	bA.mtx.Lock()
    47  	defer bA.mtx.Unlock()
    48  	return bA.getIndex(i)
    49  }
    50  
    51  func (bA *BitArray) getIndex(i int) bool {
    52  	if i >= bA.Bits {
    53  		return false
    54  	}
    55  	return bA.Elems[i/64]&(uint64(1)<<uint(i%64)) > 0
    56  }
    57  
    58  // SetIndex sets the bit at index i within the bit array.
    59  // The behavior is undefined if i >= bA.Bits
    60  func (bA *BitArray) SetIndex(i int, v bool) bool {
    61  	if bA == nil {
    62  		return false
    63  	}
    64  	bA.mtx.Lock()
    65  	defer bA.mtx.Unlock()
    66  	return bA.setIndex(i, v)
    67  }
    68  
    69  func (bA *BitArray) setIndex(i int, v bool) bool {
    70  	if i >= bA.Bits {
    71  		return false
    72  	}
    73  	if v {
    74  		bA.Elems[i/64] |= (uint64(1) << uint(i%64))
    75  	} else {
    76  		bA.Elems[i/64] &= ^(uint64(1) << uint(i%64))
    77  	}
    78  	return true
    79  }
    80  
    81  // Copy returns a copy of the provided bit array.
    82  func (bA *BitArray) Copy() *BitArray {
    83  	if bA == nil {
    84  		return nil
    85  	}
    86  	bA.mtx.Lock()
    87  	defer bA.mtx.Unlock()
    88  	return bA.copy()
    89  }
    90  
    91  func (bA *BitArray) copy() *BitArray {
    92  	c := make([]uint64, len(bA.Elems))
    93  	copy(c, bA.Elems)
    94  	return &BitArray{
    95  		Bits:  bA.Bits,
    96  		Elems: c,
    97  	}
    98  }
    99  
   100  func (bA *BitArray) copyBits(bits int) *BitArray {
   101  	c := make([]uint64, (bits+63)/64)
   102  	copy(c, bA.Elems)
   103  	return &BitArray{
   104  		Bits:  bits,
   105  		Elems: c,
   106  	}
   107  }
   108  
   109  // Or returns a bit array resulting from a bitwise OR of the two bit arrays.
   110  // If the two bit-arrays have different lengths, Or right-pads the smaller of the two bit-arrays with zeroes.
   111  // Thus the size of the return value is the maximum of the two provided bit arrays.
   112  func (bA *BitArray) Or(o *BitArray) *BitArray {
   113  	if bA == nil && o == nil {
   114  		return nil
   115  	}
   116  	if bA == nil && o != nil {
   117  		return o.Copy()
   118  	}
   119  	if o == nil {
   120  		return bA.Copy()
   121  	}
   122  	bA.mtx.Lock()
   123  	o.mtx.Lock()
   124  	c := bA.copyBits(max(bA.Bits, o.Bits))
   125  	smaller := min(len(bA.Elems), len(o.Elems))
   126  	for i := 0; i < smaller; i++ {
   127  		c.Elems[i] |= o.Elems[i]
   128  	}
   129  	bA.mtx.Unlock()
   130  	o.mtx.Unlock()
   131  	return c
   132  }
   133  
   134  // And returns a bit array resulting from a bitwise AND of the two bit arrays.
   135  // If the two bit-arrays have different lengths, this truncates the larger of the two bit-arrays from the right.
   136  // Thus the size of the return value is the minimum of the two provided bit arrays.
   137  func (bA *BitArray) And(o *BitArray) *BitArray {
   138  	if bA == nil || o == nil {
   139  		return nil
   140  	}
   141  	bA.mtx.Lock()
   142  	o.mtx.Lock()
   143  	defer func() {
   144  		bA.mtx.Unlock()
   145  		o.mtx.Unlock()
   146  	}()
   147  	return bA.and(o)
   148  }
   149  
   150  func (bA *BitArray) and(o *BitArray) *BitArray {
   151  	c := bA.copyBits(min(bA.Bits, o.Bits))
   152  	for i := 0; i < len(c.Elems); i++ {
   153  		c.Elems[i] &= o.Elems[i]
   154  	}
   155  	return c
   156  }
   157  
   158  // Not returns a bit array resulting from a bitwise Not of the provided bit array.
   159  func (bA *BitArray) Not() *BitArray {
   160  	if bA == nil {
   161  		return nil // Degenerate
   162  	}
   163  	bA.mtx.Lock()
   164  	defer bA.mtx.Unlock()
   165  	return bA.not()
   166  }
   167  
   168  func (bA *BitArray) not() *BitArray {
   169  	c := bA.copy()
   170  	for i := 0; i < len(c.Elems); i++ {
   171  		c.Elems[i] = ^c.Elems[i]
   172  	}
   173  	return c
   174  }
   175  
   176  // Sub subtracts the two bit-arrays bitwise, without carrying the bits.
   177  // Note that carryless subtraction of a - b is (a and not b).
   178  // The output is the same as bA, regardless of o's size.
   179  // If bA is longer than o, o is right padded with zeroes
   180  func (bA *BitArray) Sub(o *BitArray) *BitArray {
   181  	if bA == nil || o == nil {
   182  		// TODO: Decide if we should do 1's complement here?
   183  		return nil
   184  	}
   185  	bA.mtx.Lock()
   186  	o.mtx.Lock()
   187  	// output is the same size as bA
   188  	c := bA.copyBits(bA.Bits)
   189  	// Only iterate to the minimum size between the two.
   190  	// If o is longer, those bits are ignored.
   191  	// If bA is longer, then skipping those iterations is equivalent
   192  	// to right padding with 0's
   193  	smaller := min(len(bA.Elems), len(o.Elems))
   194  	for i := 0; i < smaller; i++ {
   195  		// &^ is and not in golang
   196  		c.Elems[i] &^= o.Elems[i]
   197  	}
   198  	bA.mtx.Unlock()
   199  	o.mtx.Unlock()
   200  	return c
   201  }
   202  
   203  // IsEmpty returns true iff all bits in the bit array are 0
   204  func (bA *BitArray) IsEmpty() bool {
   205  	if bA == nil {
   206  		return true // should this be opposite?
   207  	}
   208  	bA.mtx.Lock()
   209  	defer bA.mtx.Unlock()
   210  	for _, e := range bA.Elems {
   211  		if e > 0 {
   212  			return false
   213  		}
   214  	}
   215  	return true
   216  }
   217  
   218  // IsFull returns true iff all bits in the bit array are 1.
   219  func (bA *BitArray) IsFull() bool {
   220  	if bA == nil {
   221  		return true
   222  	}
   223  	bA.mtx.Lock()
   224  	defer bA.mtx.Unlock()
   225  
   226  	// Check all elements except the last
   227  	for _, elem := range bA.Elems[:len(bA.Elems)-1] {
   228  		if (^elem) != 0 {
   229  			return false
   230  		}
   231  	}
   232  
   233  	// Check that the last element has (lastElemBits) 1's
   234  	lastElemBits := (bA.Bits+63)%64 + 1
   235  	lastElem := bA.Elems[len(bA.Elems)-1]
   236  	return (lastElem+1)&((uint64(1)<<uint(lastElemBits))-1) == 0
   237  }
   238  
   239  // PickRandom returns a random index for a set bit in the bit array.
   240  // If there is no such value, it returns 0, false.
   241  // It uses the global randomness in `random.go` to get this index.
   242  func (bA *BitArray) PickRandom() (int, bool) {
   243  	if bA == nil {
   244  		return 0, false
   245  	}
   246  
   247  	bA.mtx.Lock()
   248  	trueIndices := bA.getTrueIndices()
   249  	bA.mtx.Unlock()
   250  
   251  	if len(trueIndices) == 0 { // no bits set to true
   252  		return 0, false
   253  	}
   254  
   255  	return trueIndices[random.RandIntn(len(trueIndices))], true
   256  }
   257  
   258  func (bA *BitArray) getTrueIndices() []int {
   259  	trueIndices := make([]int, 0, bA.Bits)
   260  	curBit := 0
   261  	numElems := len(bA.Elems)
   262  	// set all true indices
   263  	for i := 0; i < numElems-1; i++ {
   264  		elem := bA.Elems[i]
   265  		if elem == 0 {
   266  			curBit += 64
   267  			continue
   268  		}
   269  		for j := 0; j < 64; j++ {
   270  			if (elem & (uint64(1) << uint64(j))) > 0 {
   271  				trueIndices = append(trueIndices, curBit)
   272  			}
   273  			curBit++
   274  		}
   275  	}
   276  	// handle last element
   277  	lastElem := bA.Elems[numElems-1]
   278  	numFinalBits := bA.Bits - curBit
   279  	for i := 0; i < numFinalBits; i++ {
   280  		if (lastElem & (uint64(1) << uint64(i))) > 0 {
   281  			trueIndices = append(trueIndices, curBit)
   282  		}
   283  		curBit++
   284  	}
   285  	return trueIndices
   286  }
   287  
   288  // String returns a string representation of BitArray: BA{<bit-string>},
   289  // where <bit-string> is a sequence of 'x' (1) and '_' (0).
   290  // The <bit-string> includes spaces and newlines to help people.
   291  // For a simple sequence of 'x' and '_' characters with no spaces or newlines,
   292  // see the MarshalJSON() method.
   293  // Example: "BA{_x_}" or "nil-BitArray" for nil.
   294  func (bA *BitArray) String() string {
   295  	return bA.StringIndented("")
   296  }
   297  
   298  // StringIndented returns the same thing as String(), but applies the indent
   299  // at every 10th bit, and twice at every 50th bit.
   300  func (bA *BitArray) StringIndented(indent string) string {
   301  	if bA == nil {
   302  		return "nil-BitArray"
   303  	}
   304  	bA.mtx.Lock()
   305  	defer bA.mtx.Unlock()
   306  	return bA.stringIndented(indent)
   307  }
   308  
   309  func (bA *BitArray) stringIndented(indent string) string {
   310  	lines := []string{}
   311  	bits := ""
   312  	for i := 0; i < bA.Bits; i++ {
   313  		if bA.getIndex(i) {
   314  			bits += "x"
   315  		} else {
   316  			bits += "_"
   317  		}
   318  		if i%100 == 99 {
   319  			lines = append(lines, bits)
   320  			bits = ""
   321  		}
   322  		if i%10 == 9 {
   323  			bits += indent
   324  		}
   325  		if i%50 == 49 {
   326  			bits += indent
   327  		}
   328  	}
   329  	if len(bits) > 0 {
   330  		lines = append(lines, bits)
   331  	}
   332  	return fmt.Sprintf("BA{%v:%v}", bA.Bits, strings.Join(lines, indent))
   333  }
   334  
   335  // Bytes returns the byte representation of the bits within the bitarray.
   336  func (bA *BitArray) Bytes() []byte {
   337  	bA.mtx.Lock()
   338  	defer bA.mtx.Unlock()
   339  
   340  	numBytes := (bA.Bits + 7) / 8
   341  	bytes := make([]byte, numBytes)
   342  	for i := 0; i < len(bA.Elems); i++ {
   343  		elemBytes := [8]byte{}
   344  		binary.LittleEndian.PutUint64(elemBytes[:], bA.Elems[i])
   345  		copy(bytes[i*8:], elemBytes[:])
   346  	}
   347  	return bytes
   348  }
   349  
   350  // Update sets the bA's bits to be that of the other bit array.
   351  // The copying begins from the begin of both bit arrays.
   352  func (bA *BitArray) Update(o *BitArray) {
   353  	if bA == nil || o == nil {
   354  		return
   355  	}
   356  	bA.mtx.Lock()
   357  	o.mtx.Lock()
   358  	defer func() {
   359  		bA.mtx.Unlock()
   360  		o.mtx.Unlock()
   361  	}()
   362  
   363  	copy(bA.Elems, o.Elems)
   364  }
   365  
   366  // MarshalJSON implements json.Marshaler interface by marshaling bit array
   367  // using a custom format: a string of '-' or 'x' where 'x' denotes the 1 bit.
   368  func (bA *BitArray) MarshalJSON() ([]byte, error) {
   369  	if bA == nil {
   370  		return []byte("null"), nil
   371  	}
   372  
   373  	bA.mtx.Lock()
   374  	defer bA.mtx.Unlock()
   375  
   376  	bits := `"`
   377  	for i := 0; i < bA.Bits; i++ {
   378  		if bA.getIndex(i) {
   379  			bits += `x`
   380  		} else {
   381  			bits += `_`
   382  		}
   383  	}
   384  	bits += `"`
   385  	return []byte(bits), nil
   386  }
   387  
   388  var bitArrayJSONRegexp = regexp.MustCompile(`\A"([_x]*)"\z`)
   389  
   390  // UnmarshalJSON implements json.Unmarshaler interface by unmarshalling a custom
   391  // JSON description.
   392  func (bA *BitArray) UnmarshalJSON(bz []byte) error {
   393  	b := string(bz)
   394  	if b == "null" {
   395  		// This is required e.g. for encoding/json when decoding
   396  		// into a pointer with pre-allocated BitArray.
   397  		bA.Bits = 0
   398  		bA.Elems = nil
   399  		return nil
   400  	}
   401  
   402  	// Validate 'b'.
   403  	match := bitArrayJSONRegexp.FindStringSubmatch(b)
   404  	if match == nil {
   405  		return fmt.Errorf("BitArray in JSON should be a string of format %q but got %s", bitArrayJSONRegexp.String(), b)
   406  	}
   407  	bits := match[1]
   408  
   409  	// Construct new BitArray and copy over.
   410  	numBits := len(bits)
   411  	bA2 := NewBitArray(numBits)
   412  	for i := 0; i < numBits; i++ {
   413  		if bits[i] == 'x' {
   414  			bA2.SetIndex(i, true)
   415  		}
   416  	}
   417  	*bA = *bA2 //nolint:govet
   418  	return nil
   419  }