github.com/richardwilkes/toolbox@v1.121.0/xmath/bitset.go (about)

     1  // Copyright (c) 2016-2024 by Richard A. Wilkes. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, version 2.0. If a copy of the MPL was not distributed with
     5  // this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     6  //
     7  // This Source Code Form is "Incompatible With Secondary Licenses", as
     8  // defined by the Mozilla Public License, version 2.0.
     9  
    10  package xmath
    11  
    12  import (
    13  	"fmt"
    14  	"math"
    15  
    16  	"github.com/richardwilkes/toolbox/atexit"
    17  )
    18  
// Word-addressing constants for the uint64 storage used by BitSet. addressBitsPerWord is the shift that converts a
// bit index into a word index, dataBitsPerWord is the number of bits held by one word (1 << 6 == 64), and
// bitIndexMask extracts the position of a bit within its word (index & 63).
const (
	addressBitsPerWord = 6
	dataBitsPerWord    = 1 << addressBitsPerWord
	bitIndexMask       = dataBitsPerWord - 1
)
    24  
// BitSet contains a set of bits. The bits are kept in 'data', 64 per word: the bit at a given index is stored in
// word index>>addressBitsPerWord at position index&bitIndexMask. 'set' caches the number of bits currently set, which
// is the value Count returns. Indexes beyond the allocated words are treated as clear.
type BitSet struct {
	data []uint64
	set  int
}
    30  
    31  // Clone this BitSet.
    32  func (b *BitSet) Clone() *BitSet {
    33  	bs := &BitSet{data: make([]uint64, len(b.data)), set: b.set}
    34  	copy(bs.data, b.data)
    35  	return bs
    36  }
    37  
    38  // Copy the content of 'other' into this BitSet, making them equal.
    39  func (b *BitSet) Copy(other *BitSet) {
    40  	b.set = other.set
    41  	b.data = make([]uint64, len(other.data))
    42  	copy(b.data, other.data)
    43  }
    44  
    45  // Equal returns true if this BitSet is equal to 'other'.
    46  func (b *BitSet) Equal(other *BitSet) bool {
    47  	if other == nil {
    48  		return false
    49  	}
    50  	if b.set != other.set {
    51  		return false
    52  	}
    53  	if len(b.data) != len(other.data) {
    54  		return false
    55  	}
    56  	for i := range b.data {
    57  		if b.data[i] != other.data[i] {
    58  			return false
    59  		}
    60  	}
    61  	return true
    62  }
    63  
    64  // Count returns the number of set bits.
    65  func (b *BitSet) Count() int {
    66  	return b.set
    67  }
    68  
    69  // State returns the state of the bit at 'index'.
    70  func (b *BitSet) State(index int) bool {
    71  	validateBitSetIndex(index)
    72  	i := index >> addressBitsPerWord
    73  	if i >= len(b.data) {
    74  		return false
    75  	}
    76  	mask := wordMask(index)
    77  	return b.data[i]&mask == mask
    78  }
    79  
    80  // Set the bit at 'index'.
    81  func (b *BitSet) Set(index int) {
    82  	validateBitSetIndex(index)
    83  	i := index >> addressBitsPerWord
    84  	b.EnsureCapacity(i + 1)
    85  	mask := wordMask(index)
    86  	if b.data[i]&mask == 0 {
    87  		b.data[i] |= mask
    88  		b.set++
    89  	}
    90  }
    91  
    92  func countSetBits(x uint64) int {
    93  	x -= (x >> 1) & 0x5555555555555555
    94  	x = (x>>2)&0x3333333333333333 + x&0x3333333333333333
    95  	x += x >> 4
    96  	x &= 0x0f0f0f0f0f0f0f0f
    97  	x *= 0x0101010101010101
    98  	return int(x >> 56)
    99  }
   100  
   101  // SetRange sets the bits from 'start' to 'end', inclusive.
   102  func (b *BitSet) SetRange(start, end int) {
   103  	validateBitSetIndex(start)
   104  	validateBitSetIndex(end)
   105  	if start > end {
   106  		start, end = end, start
   107  	}
   108  	i1 := start >> addressBitsPerWord
   109  	i2 := end >> addressBitsPerWord
   110  	b.EnsureCapacity(i2 + 1)
   111  	j := bitIndexForMask(wordMask(start))
   112  	for i := i1; i <= i2; i++ {
   113  		if i != i1 && i != i2 {
   114  			b.set += dataBitsPerWord - countSetBits(b.data[i])
   115  			b.data[i] = math.MaxUint64
   116  		} else {
   117  			var last int
   118  			if i == i2 {
   119  				last = bitIndexForMask(wordMask(end)) + 1
   120  			} else {
   121  				last = dataBitsPerWord
   122  			}
   123  			for j < last {
   124  				mask := wordMask(j)
   125  				if b.data[i]&mask == 0 {
   126  					b.data[i] |= mask
   127  					b.set++
   128  				}
   129  				j++
   130  			}
   131  			j = 0
   132  		}
   133  	}
   134  }
   135  
   136  // Clear the bit at 'index'.
   137  func (b *BitSet) Clear(index int) {
   138  	validateBitSetIndex(index)
   139  	i := index >> addressBitsPerWord
   140  	if i < len(b.data) {
   141  		mask := wordMask(index)
   142  		if b.data[i]&mask == mask {
   143  			b.data[i] &= ^mask
   144  			b.set--
   145  		}
   146  	}
   147  }
   148  
   149  // ClearRange clears the bits from 'start' to 'end', inclusive.
   150  func (b *BitSet) ClearRange(start, end int) {
   151  	validateBitSetIndex(start)
   152  	validateBitSetIndex(end)
   153  	if start > end {
   154  		start, end = end, start
   155  	}
   156  	maximum := len(b.data) - 1
   157  	i1 := start >> addressBitsPerWord
   158  	if i1 > maximum {
   159  		return
   160  	}
   161  	i2 := end >> addressBitsPerWord
   162  	if i2 > maximum {
   163  		i2 = maximum
   164  	}
   165  	j := bitIndexForMask(wordMask(start))
   166  	for i := i1; i <= i2; i++ {
   167  		if i != i1 && i != i2 {
   168  			b.set -= countSetBits(b.data[i])
   169  			b.data[i] = 0
   170  		} else {
   171  			var last int
   172  			if i == i2 {
   173  				last = bitIndexForMask(wordMask(end)) + 1
   174  			} else {
   175  				last = dataBitsPerWord
   176  			}
   177  			for j < last {
   178  				mask := wordMask(j)
   179  				if b.data[i]&mask == mask {
   180  					b.data[i] &= ^mask
   181  					b.set--
   182  				}
   183  				j++
   184  			}
   185  			j = 0
   186  		}
   187  	}
   188  }
   189  
   190  // Flip the bit at 'index'.
   191  func (b *BitSet) Flip(index int) {
   192  	validateBitSetIndex(index)
   193  	i := index >> addressBitsPerWord
   194  	b.EnsureCapacity(i + 1)
   195  	mask := wordMask(index)
   196  	b.data[i] ^= mask
   197  	if b.data[i]&mask == mask {
   198  		b.set++
   199  	} else {
   200  		b.set--
   201  	}
   202  }
   203  
   204  // FlipRange flips the bits from 'start' to 'end', inclusive.
   205  func (b *BitSet) FlipRange(start, end int) {
   206  	validateBitSetIndex(start)
   207  	validateBitSetIndex(end)
   208  	if start > end {
   209  		start, end = end, start
   210  	}
   211  	i1 := start >> addressBitsPerWord
   212  	i2 := end >> addressBitsPerWord
   213  	b.EnsureCapacity(i2 + 1)
   214  	j := bitIndexForMask(wordMask(start))
   215  	for i := i1; i <= i2; i++ {
   216  		if i != i1 && i != i2 {
   217  			b.set += dataBitsPerWord - 2*countSetBits(b.data[i])
   218  			b.data[i] ^= math.MaxUint64
   219  		} else {
   220  			var last int
   221  			if i == i2 {
   222  				last = bitIndexForMask(wordMask(end)) + 1
   223  			} else {
   224  				last = dataBitsPerWord
   225  			}
   226  			for j < last {
   227  				mask := wordMask(j)
   228  				b.data[i] ^= mask
   229  				if b.data[i]&mask == mask {
   230  					b.set++
   231  				} else {
   232  					b.set--
   233  				}
   234  				j++
   235  			}
   236  			j = 0
   237  		}
   238  	}
   239  }
   240  
   241  // FirstSet returns the first set bit. If no bits are set, then -1 is returned.
   242  func (b *BitSet) FirstSet() int {
   243  	return b.NextSet(0)
   244  }
   245  
   246  // LastSet returns the last set bit. If no bits are set, then -1 is returned.
   247  func (b *BitSet) LastSet() int {
   248  	return b.PreviousSet(len(b.data) << addressBitsPerWord)
   249  }
   250  
   251  // PreviousSet returns the previous set bit starting from 'start'. If no bits are set at or before 'start', then -1 is
   252  // returned.
   253  func (b *BitSet) PreviousSet(start int) int {
   254  	validateBitSetIndex(start)
   255  	i := start >> addressBitsPerWord
   256  	var firstBit int
   257  	if maximum := len(b.data) - 1; i > maximum {
   258  		i = maximum
   259  		firstBit = 63
   260  	} else {
   261  		firstBit = bitIndexForMask(wordMask(start))
   262  	}
   263  	for i >= 0 {
   264  		word := b.data[i]
   265  		if word != 0 {
   266  			for j := firstBit; j >= 0; j-- {
   267  				mask := wordMask(j)
   268  				if word&mask == mask {
   269  					return i<<addressBitsPerWord + j
   270  				}
   271  			}
   272  		}
   273  		firstBit = 63
   274  		i--
   275  	}
   276  	return -1
   277  }
   278  
   279  // NextSet returns the next set bit starting from 'start'. If no bits are set at or beyond 'start', then -1 is returned.
   280  func (b *BitSet) NextSet(start int) int {
   281  	validateBitSetIndex(start)
   282  	i := start >> addressBitsPerWord
   283  	firstBit := bitIndexForMask(wordMask(start))
   284  	maximum := len(b.data)
   285  	for i < maximum {
   286  		word := b.data[i]
   287  		if word != 0 {
   288  			for j := firstBit; j < dataBitsPerWord; j++ {
   289  				mask := wordMask(j)
   290  				if word&mask == mask {
   291  					return i<<addressBitsPerWord + j
   292  				}
   293  			}
   294  		}
   295  		firstBit = 0
   296  		i++
   297  	}
   298  	return -1
   299  }
   300  
   301  // PreviousClear returns the previous clear bit starting from 'start'. If no bits are clear at or before 'start', then
   302  // -1 is returned.
   303  func (b *BitSet) PreviousClear(start int) int {
   304  	validateBitSetIndex(start)
   305  	i := start >> addressBitsPerWord
   306  	if i > len(b.data)-1 {
   307  		return start
   308  	}
   309  	firstBit := bitIndexForMask(wordMask(start))
   310  	for i >= 0 {
   311  		word := b.data[i]
   312  		if word != math.MaxUint64 {
   313  			for j := firstBit; j >= 0; j-- {
   314  				mask := wordMask(j)
   315  				if word&mask == 0 {
   316  					return i<<addressBitsPerWord + j
   317  				}
   318  			}
   319  		}
   320  		firstBit = 63
   321  		i--
   322  	}
   323  	return -1
   324  }
   325  
   326  // NextClear returns the next clear bit starting from 'start'.
   327  func (b *BitSet) NextClear(start int) int {
   328  	validateBitSetIndex(start)
   329  	i := start >> addressBitsPerWord
   330  	firstBit := bitIndexForMask(wordMask(start))
   331  	maximum := len(b.data)
   332  	for i < maximum {
   333  		word := b.data[i]
   334  		if word != math.MaxUint64 {
   335  			for j := firstBit; j < dataBitsPerWord; j++ {
   336  				mask := wordMask(j)
   337  				if word&mask == 0 {
   338  					return i<<addressBitsPerWord + j
   339  				}
   340  			}
   341  		}
   342  		firstBit = 0
   343  		i++
   344  	}
   345  	return max(maximum*dataBitsPerWord, start)
   346  }
   347  
   348  // Trim the BitSet down to the minimum required to store the set bits.
   349  func (b *BitSet) Trim() {
   350  	size := len(b.data)
   351  	for i := size - 1; i >= 0; i-- {
   352  		if b.data[i] != 0 {
   353  			i++
   354  			if i != size {
   355  				data := make([]uint64, i)
   356  				copy(data, b.data)
   357  				b.data = data
   358  			}
   359  			return
   360  		}
   361  		i--
   362  	}
   363  	b.data = nil
   364  }
   365  
   366  // EnsureCapacity ensures that the BitSet has enough underlying storage to accommodate setting a bit as high as index
   367  // position 'words' x 64 - 1 without needing to allocate more storage.
   368  func (b *BitSet) EnsureCapacity(words int) {
   369  	size := len(b.data)
   370  	if words > size {
   371  		size *= 2
   372  		if size < words {
   373  			size = words
   374  		}
   375  		data := make([]uint64, size)
   376  		copy(data, b.data)
   377  		b.data = data
   378  	}
   379  }
   380  
   381  // Data returns a copy of the underlying storage.
   382  func (b *BitSet) Data() []uint64 {
   383  	b.Trim()
   384  	data := make([]uint64, len(b.data))
   385  	copy(data, b.data)
   386  	return data
   387  }
   388  
   389  // Load replaces the current data with the bits set in 'data'.
   390  func (b *BitSet) Load(data []uint64) {
   391  	b.data = make([]uint64, len(data))
   392  	copy(b.data, data)
   393  	b.Trim()
   394  	b.set = 0
   395  	for i := len(b.data) - 1; i >= 0; i-- {
   396  		word := data[i]
   397  		if word != 0 {
   398  			for j := 0; j < dataBitsPerWord; j++ {
   399  				mask := wordMask(j)
   400  				if word&mask == mask {
   401  					b.set++
   402  				}
   403  			}
   404  		}
   405  	}
   406  }
   407  
   408  // Reset the BitSet back to an empty state.
   409  func (b *BitSet) Reset() {
   410  	b.data = nil
   411  	b.set = 0
   412  }
   413  
   414  func wordMask(index int) uint64 {
   415  	return uint64(1) << uint(index&bitIndexMask)
   416  }
   417  
   418  func bitIndexForMask(mask uint64) int {
   419  	for i := 0; i < dataBitsPerWord; i++ {
   420  		if mask == wordMask(i) {
   421  			return i
   422  		}
   423  	}
   424  	fmt.Printf("Unable to determine bit index for mask %064b\n", mask)
   425  	atexit.Exit(1)
   426  	return 0
   427  }
   428  
   429  func validateBitSetIndex(index int) {
   430  	if index < 0 {
   431  		fmt.Printf("Index must be positive (was %d)\n", index)
   432  		atexit.Exit(1)
   433  	}
   434  }