github.com/weedge/lib@v0.0.0-20230424045628-a36dcc1d90e4/container/set/bitset.go (about)

     1  package set
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"math/bits"
     7  )
     8  
     9  const (
    10  	shift = 6    // 2^6 = 64
    11  	mask  = 0x3f // 63
    12  )
    13  
    14  /*
    15  type IBitSet interface {
    16  	Set(pos int, value int) int
    17  	Get(pos int) int
    18  	Count() uint64
    19  	RightShift(n int)
    20  	LeftShift(n int)
    21  }
    22  */
    23  
// BitSet is a fixed-length sequence of bits packed into 64-bit words.
// Words are stored most-significant first: bit position 0 lives in
// data[size-1] and the highest positions live in data[0] (see _getPos).
type BitSet struct {
	data   []uint64 // packed bits, 64 per word
	upCeil uint64   // mask of the valid bits of data[0]; applied after left/right shifts
	len    uint64   // number of addressable bits
	size   int      // number of words in data
}
    30  
    31  // 创建BitSet
    32  func NewBitSet(len uint64) *BitSet {
    33  	size := int(len >> shift)
    34  	if len&mask > 0 {
    35  		size += 1
    36  	}
    37  	bt := &BitSet{
    38  		data: make([]uint64, size),
    39  		len:  len,
    40  		size: size,
    41  	}
    42  	firstSize := int(len & mask)
    43  	for i := 0; i < firstSize; i++ {
    44  		bt.upCeil |= 1 << i
    45  	}
    46  	if firstSize == 0 {
    47  		bt.upCeil = math.MaxUint64
    48  	}
    49  
    50  	return bt
    51  }
    52  
    53  func (set *BitSet) String() string {
    54  	//notice don't fmt print bitset obj
    55  	return set.StringAsc()
    56  }
    57  
    58  func (set *BitSet) StringDesc() string {
    59  	str := ""
    60  	for i := set.size - 1; i >= 0; i-- {
    61  		if i == set.size-1 {
    62  			str += "data:"
    63  			str += fmt.Sprintf("%b", set.data[i])
    64  		} else {
    65  			str += fmt.Sprintf("#%b", set.data[i])
    66  		}
    67  	}
    68  	str += fmt.Sprintf(" size:%d len:%d upCeilOnesCn:%d", set.size, set.len, bits.OnesCount64(set.upCeil))
    69  
    70  	return str
    71  }
    72  
    73  func (set *BitSet) StringAsc() string {
    74  	str := ""
    75  	for i := 0; i < set.size; i++ {
    76  		if i == 0 {
    77  			str += "data:"
    78  			str += fmt.Sprintf("%b", set.data[i])
    79  		} else {
    80  			str += fmt.Sprintf("#%b", set.data[i])
    81  		}
    82  	}
    83  	str += fmt.Sprintf(" size:%d len:%d upCeilOnesCn:%d", set.size, set.len, bits.OnesCount64(set.upCeil))
    84  
    85  	return str
    86  }
    87  
    88  func (set *BitSet) StringBit() string {
    89  	str := ""
    90  	for i := 0; i < set.size; i++ {
    91  		for j := mask; j >= 0; j-- {
    92  			if set.data[i]&(uint64(1)<<j) == 1 {
    93  				str += "1"
    94  			} else {
    95  				str += "0"
    96  			}
    97  		}
    98  	}
    99  	str += fmt.Sprintf(" len:%d", len(str))
   100  	str += fmt.Sprintf("\nsize:%d  upCeilOnesCn:%d", set.size, bits.OnesCount64(set.upCeil))
   101  	return str
   102  }
   103  
   104  // set in LittleEndian order
   105  // notice: 0<= pos < len
   106  func (set *BitSet) Set(pos uint64, value int) int {
   107  	if pos < 0 || pos >= set.len || !(value == 0 || value == 1) {
   108  		return -1
   109  	}
   110  	index, offset := set._getPos(pos)
   111  	oldVal := set._get(index, offset)
   112  
   113  	if value == 1 {
   114  		set.data[index] |= uint64(1) << offset
   115  	} else {
   116  		set.data[index] |= uint64(1) << offset
   117  		set.data[index] ^= uint64(1) << offset
   118  	}
   119  
   120  	return oldVal
   121  }
   122  
   123  // get in LittleEndian order (test)
   124  // notice: 0<= pos < len
   125  func (set *BitSet) Get(pos uint64) int {
   126  	if pos < 0 || pos >= set.len {
   127  		return -1
   128  	}
   129  
   130  	index, offset := set._getPos(pos)
   131  	return set._get(index, offset)
   132  }
   133  
   134  // get data index and offset like partition offset
   135  func (set *BitSet) _getPos(pos uint64) (index, offset int) {
   136  	index = set.size - int(pos>>shift) - 1
   137  	offset = int(pos & mask)
   138  
   139  	return
   140  }
   141  
   142  func (set *BitSet) _get(index int, offset int) int {
   143  	if set.data[index]&(uint64(1)<<offset) > 0 {
   144  		return 1
   145  	}
   146  	return 0
   147  }
   148  
   149  // get in BigEndian order
   150  func (set *BitSet) _getForBigEndian(pos int) int {
   151  	if pos < 0 {
   152  		return -1
   153  	}
   154  	index := pos >> shift
   155  	if index >= len(set.data) {
   156  		return -1
   157  	}
   158  	if set.data[index]&(1<<uint(pos&mask)) == 0 {
   159  		return 0
   160  	}
   161  	return 1
   162  }
   163  
   164  // set in BigEndian order
   165  func (set *BitSet) _setForBigEndian(pos int, value int) int {
   166  	if pos < 0 || !(value == 0 || value == 1) {
   167  		return -1
   168  	}
   169  	index := pos >> shift
   170  	if index >= len(set.data) { //溢出
   171  		return -1
   172  	}
   173  	oldValue := set._getForBigEndian(pos)
   174  	if oldValue == 0 && value == 1 {
   175  		set.data[index] |= 1 << uint(pos&mask) //对应的位设置为1,直接安位或操作即可
   176  	} else if oldValue == 1 && value == 0 {
   177  		set.data[index] &^= 1 << uint(pos&mask) //对应的位设置为0,先按位取反,然后进行与操作
   178  	}
   179  	return oldValue
   180  }
   181  
   182  // https://en.wikipedia.org/wiki/Hamming_weight
   183  // use variable-precision SWAR
   184  func (set *BitSet) Count() uint64 {
   185  	var count uint64
   186  	for _, b := range set.data {
   187  		count += swar(b)
   188  	}
   189  	return count
   190  }
   191  
   192  // variable-precision SWAR
   193  func swar(i uint64) uint64 {
   194  	// 将相邻2位的1的数量计算出来,结果存放在这2位
   195  	i = (i & 0x5555555555555555) + ((i >> 1) & 0x5555555555555555)
   196  	// 将相邻4位的结果相加,结果存放在这4位
   197  	i = (i & 0x3333333333333333) + ((i >> 2) & 0x3333333333333333)
   198  	// 将相邻8位的结果相加,结果存放在这8位
   199  	i = (i & 0x0F0F0F0F0F0F0F0F) + ((i >> 4) & 0x0F0F0F0F0F0F0F0F)
   200  	// 计算整体1的数量,记录在高8位,然后通过右移运算,将结果放到低8位,得到最终结果
   201  	i = (i * 0x0101010101010101) >> 56
   202  	return i
   203  }
   204  
// LeftShift shifts the whole set left by n bits in place (the << operator).
// Whole 64-bit words are moved first (leftShiftData), then the remaining
// n&63 bits are shifted across word boundaries (leftShiftBit), which also
// masks data[0] with upCeil.
// NOTE(review): n is presumably expected to be non-negative — confirm with callers.
func (set *BitSet) LeftShift(n int) {
	set.leftShiftData(n)
	set.leftShiftBit(n)
}
   210  
   211  func (set *BitSet) leftShiftData(n int) {
   212  	index := n >> shift
   213  	for i := 0; i+index < set.size; i++ {
   214  		set.data[i] = set.data[i+index]
   215  	}
   216  	//fmt.Println(n, index, set.data)
   217  	for i := set.size - index; i < set.size; i++ {
   218  		set.data[i] = 0
   219  	}
   220  	//fmt.Println(n, index, set.data)
   221  }
   222  
   223  func (set *BitSet) leftShiftBit(n int) {
   224  	v := n & mask
   225  	tp := uint64(0)
   226  	lstv, pos := uint64(0), uint64(mask-v+1)
   227  	//fmt.Println(v, tp, lstv, pos)
   228  
   229  	for i := 1; i <= v; i++ {
   230  		tp |= uint64(1) << (mask + 1 - i)
   231  	}
   232  
   233  	for i := set.size - 1; i >= 0; i-- {
   234  		tpLstv := (set.data[i] & tp) >> pos
   235  		set.data[i] <<= v
   236  		set.data[i] |= lstv
   237  		lstv = tpLstv
   238  	}
   239  	set.data[0] &= set.upCeil
   240  }
   241  
// RightShift shifts the whole set right by n bits in place (the >> operator).
// Whole 64-bit words are moved first (rightShiftData), then the remaining
// n&63 bits are shifted across word boundaries (rightShiftBit), which also
// masks data[0] with upCeil.
// NOTE(review): n is presumably expected to be non-negative — confirm with callers.
func (set *BitSet) RightShift(n int) {
	set.rightShiftData(n)
	set.rightShiftBit(n)
}
   247  
   248  func (set *BitSet) rightShiftData(n int) {
   249  	index := n >> shift
   250  	for i := set.size - 1; i >= index; i-- {
   251  		set.data[i] = set.data[i-index]
   252  	}
   253  	//fmt.Println(n, index, set.data)
   254  	for i := index - 1; i >= 0; i-- {
   255  		set.data[i] = 0
   256  	}
   257  	//fmt.Println(n, index, set.data)
   258  }
   259  
   260  func (set *BitSet) rightShiftBit(n int) {
   261  	v := n & mask
   262  	tp := uint64(1)<<v - 1
   263  	lstv, pos := uint64(0), mask-v+1
   264  	//fmt.Println(v, tp, lstv, pos)
   265  
   266  	for i := 0; i < set.size; i++ {
   267  		tpLstv := (set.data[i] & tp) << pos
   268  		set.data[i] >>= v
   269  		set.data[i] |= lstv
   270  		lstv = tpLstv
   271  	}
   272  	set.data[0] &= set.upCeil
   273  }
   274  
   275  // & operator (set&compare -> res)
   276  func (set *BitSet) And(compare *BitSet) (res *BitSet) {
   277  	panicIfNull(set)
   278  	panicIfNull(compare)
   279  
   280  	s, c := sortByLength(set, compare)
   281  	res = NewBitSet(c.len)
   282  	for i, word := range s.data {
   283  		res.data[c.size-s.size+i] = word & c.data[c.size-s.size+i]
   284  	}
   285  
   286  	return
   287  }
   288  
   289  // | operator (set|compare -> res)
   290  func (set *BitSet) Or(compare *BitSet) (res *BitSet) {
   291  	panicIfNull(set)
   292  	panicIfNull(compare)
   293  
   294  	s, c := sortByLength(set, compare)
   295  	res = c.Clone()
   296  	for i, word := range s.data {
   297  		res.data[c.size-s.size+i] = word | c.data[c.size-s.size+i]
   298  	}
   299  
   300  	return
   301  }
   302  
   303  // ^ operator (set^compare -> res)
   304  func (set *BitSet) Xor(compare *BitSet) (res *BitSet) {
   305  	panicIfNull(set)
   306  	panicIfNull(compare)
   307  
   308  	s, c := sortByLength(set, compare)
   309  	res = c.Clone()
   310  	for i, word := range s.data {
   311  		res.data[c.size-s.size+i] = word ^ c.data[c.size-s.size+i]
   312  	}
   313  
   314  	return
   315  }
   316  
   317  // ~ operator(golang option ^self)
   318  func (set *BitSet) Not() (res *BitSet) {
   319  	panicIfNull(set)
   320  
   321  	res = set.Clone()
   322  	for i, word := range set.data {
   323  		res.data[i] = ^word
   324  	}
   325  
   326  	return
   327  }
   328  
   329  // diff operator (&^) return new bitset (diff(set,compare) set not compare)
   330  func (set *BitSet) Diff(compare *BitSet) (res *BitSet) {
   331  	panicIfNull(set)
   332  	panicIfNull(compare)
   333  
   334  	// clone set (in case set is bigger than compare)
   335  	res = set.Clone()
   336  	if set.size > compare.size {
   337  		for i := 0; i < compare.size; i++ {
   338  			res.data[set.size-compare.size+i] = set.data[set.size-compare.size+i] &^ compare.data[i]
   339  		}
   340  	} else {
   341  		for i := 0; i < set.size; i++ {
   342  			res.data[i] = set.data[i] &^ compare.data[compare.size-set.size+i]
   343  		}
   344  	}
   345  
   346  	return
   347  }
   348  
   349  // self diff operator (&^) return set diff compare(set~compare)
   350  func (set *BitSet) InPlaceDiff(compare *BitSet) {
   351  	panicIfNull(set)
   352  	panicIfNull(compare)
   353  
   354  	if set.size > compare.size {
   355  		for i := 0; i < compare.size; i++ {
   356  			set.data[set.size-compare.size+i] = set.data[set.size-compare.size+i] &^ compare.data[i]
   357  		}
   358  	} else {
   359  		for i := 0; i < set.size; i++ {
   360  			set.data[i] = set.data[i] &^ compare.data[compare.size-set.size+i]
   361  		}
   362  	}
   363  
   364  	return
   365  }
   366  
   367  // Clone this BitSet
   368  func (set *BitSet) Clone() *BitSet {
   369  	c := NewBitSet(set.len)
   370  	if set.data != nil { // Clone should not modify current object
   371  		copy(c.data, set.data)
   372  	}
   373  	return c
   374  }
   375  
   376  // Convenience function: return two bitsets ordered by asc
   377  // increasing length. Note: neither can be nil
   378  func sortByLength(a *BitSet, b *BitSet) (ap *BitSet, bp *BitSet) {
   379  	if a.len <= b.len {
   380  		ap, bp = a, b
   381  	} else {
   382  		ap, bp = b, a
   383  	}
   384  	return
   385  }
   386  
// Error is the type of the values this package passes to panic (see
// panicIfNull), so a recover handler can distinguish them from other panics.
// NOTE(review): no Error() method is visible in this file, so Error does not
// satisfy the built-in error interface unless one is declared elsewhere.
type Error string
   389  
   390  func panicIfNull(b *BitSet) {
   391  	if b == nil {
   392  		panic(Error("BitSet must not be null"))
   393  	}
   394  }