github.com/weedge/lib@v0.0.0-20230424045628-a36dcc1d90e4/container/set/bitset.go (about) 1 package set 2 3 import ( 4 "fmt" 5 "math" 6 "math/bits" 7 ) 8 9 const ( 10 shift = 6 // 2^6 = 64 11 mask = 0x3f // 63 12 ) 13 14 /* 15 type IBitSet interface { 16 Set(pos int, value int) int 17 Get(pos int) int 18 Count() uint64 19 RightShift(n int) 20 LeftShift(n int) 21 } 22 */ 23 24 type BitSet struct { 25 data []uint64 //64位 26 upCeil uint64 //for left/right shift 27 len uint64 28 size int 29 } 30 31 // 创建BitSet 32 func NewBitSet(len uint64) *BitSet { 33 size := int(len >> shift) 34 if len&mask > 0 { 35 size += 1 36 } 37 bt := &BitSet{ 38 data: make([]uint64, size), 39 len: len, 40 size: size, 41 } 42 firstSize := int(len & mask) 43 for i := 0; i < firstSize; i++ { 44 bt.upCeil |= 1 << i 45 } 46 if firstSize == 0 { 47 bt.upCeil = math.MaxUint64 48 } 49 50 return bt 51 } 52 53 func (set *BitSet) String() string { 54 //notice don't fmt print bitset obj 55 return set.StringAsc() 56 } 57 58 func (set *BitSet) StringDesc() string { 59 str := "" 60 for i := set.size - 1; i >= 0; i-- { 61 if i == set.size-1 { 62 str += "data:" 63 str += fmt.Sprintf("%b", set.data[i]) 64 } else { 65 str += fmt.Sprintf("#%b", set.data[i]) 66 } 67 } 68 str += fmt.Sprintf(" size:%d len:%d upCeilOnesCn:%d", set.size, set.len, bits.OnesCount64(set.upCeil)) 69 70 return str 71 } 72 73 func (set *BitSet) StringAsc() string { 74 str := "" 75 for i := 0; i < set.size; i++ { 76 if i == 0 { 77 str += "data:" 78 str += fmt.Sprintf("%b", set.data[i]) 79 } else { 80 str += fmt.Sprintf("#%b", set.data[i]) 81 } 82 } 83 str += fmt.Sprintf(" size:%d len:%d upCeilOnesCn:%d", set.size, set.len, bits.OnesCount64(set.upCeil)) 84 85 return str 86 } 87 88 func (set *BitSet) StringBit() string { 89 str := "" 90 for i := 0; i < set.size; i++ { 91 for j := mask; j >= 0; j-- { 92 if set.data[i]&(uint64(1)<<j) == 1 { 93 str += "1" 94 } else { 95 str += "0" 96 } 97 } 98 } 99 str += fmt.Sprintf(" len:%d", len(str)) 100 str += fmt.Sprintf("\nsize:%d upCeilOnesCn:%d", set.size, bits.OnesCount64(set.upCeil)) 101 return str 102 } 103 104 // set in LittleEndian order 105 // notice: 0<= pos < len 106 func (set *BitSet) Set(pos uint64, value int) int { 107 if pos < 0 || pos >= set.len || !(value == 0 || value == 1) { 108 return -1 109 } 110 index, offset := set._getPos(pos) 111 oldVal := set._get(index, offset) 112 113 if value == 1 { 114 set.data[index] |= uint64(1) << offset 115 } else { 116 set.data[index] |= uint64(1) << offset 117 set.data[index] ^= uint64(1) << offset 118 } 119 120 return oldVal 121 } 122 123 // get in LittleEndian order (test) 124 // notice: 0<= pos < len 125 func (set *BitSet) Get(pos uint64) int { 126 if pos < 0 || pos >= set.len { 127 return -1 128 } 129 130 index, offset := set._getPos(pos) 131 return set._get(index, offset) 132 } 133 134 // get data index and offset like partition offset 135 func (set *BitSet) _getPos(pos uint64) (index, offset int) { 136 index = set.size - int(pos>>shift) - 1 137 offset = int(pos & mask) 138 139 return 140 } 141 142 func (set *BitSet) _get(index int, offset int) int { 143 if set.data[index]&(uint64(1)<<offset) > 0 { 144 return 1 145 } 146 return 0 147 } 148 149 // get in BigEndian order 150 func (set *BitSet) _getForBigEndian(pos int) int { 151 if pos < 0 { 152 return -1 153 } 154 index := pos >> shift 155 if index >= len(set.data) { 156 return -1 157 } 158 if set.data[index]&(1<<uint(pos&mask)) == 0 { 159 return 0 160 } 161 return 1 162 } 163 164 // set in BigEndian order 165 func (set *BitSet) _setForBigEndian(pos int, value int) int { 166 if pos < 0 || !(value == 0 || value == 1) { 167 return -1 168 } 169 index := pos >> shift 170 if index >= len(set.data) { //溢出 171 return -1 172 } 173 oldValue := set._getForBigEndian(pos) 174 if oldValue == 0 && value == 1 { 175 set.data[index] |= 1 << uint(pos&mask) //对应的位设置为1,直接安位或操作即可 176 } else if oldValue == 1 && value == 0 { 177 set.data[index] &^= 1 << uint(pos&mask) //对应的位设置为0,先按位取反,然后进行与操作 178 } 179 return oldValue 180 } 181 182 // https://en.wikipedia.org/wiki/Hamming_weight 183 // use variable-precision SWAR 184 func (set *BitSet) Count() uint64 { 185 var count uint64 186 for _, b := range set.data { 187 count += swar(b) 188 } 189 return count 190 } 191 192 // variable-precision SWAR 193 func swar(i uint64) uint64 { 194 // 将相邻2位的1的数量计算出来,结果存放在这2位 195 i = (i & 0x5555555555555555) + ((i >> 1) & 0x5555555555555555) 196 // 将相邻4位的结果相加,结果存放在这4位 197 i = (i & 0x3333333333333333) + ((i >> 2) & 0x3333333333333333) 198 // 将相邻8位的结果相加,结果存放在这8位 199 i = (i & 0x0F0F0F0F0F0F0F0F) + ((i >> 4) & 0x0F0F0F0F0F0F0F0F) 200 // 计算整体1的数量,记录在高8位,然后通过右移运算,将结果放到低8位,得到最终结果 201 i = (i * 0x0101010101010101) >> 56 202 return i 203 } 204 205 // << operator 206 func (set *BitSet) LeftShift(n int) { 207 set.leftShiftData(n) 208 set.leftShiftBit(n) 209 } 210 211 func (set *BitSet) leftShiftData(n int) { 212 index := n >> shift 213 for i := 0; i+index < set.size; i++ { 214 set.data[i] = set.data[i+index] 215 } 216 //fmt.Println(n, index, set.data) 217 for i := set.size - index; i < set.size; i++ { 218 set.data[i] = 0 219 } 220 //fmt.Println(n, index, set.data) 221 } 222 223 func (set *BitSet) leftShiftBit(n int) { 224 v := n & mask 225 tp := uint64(0) 226 lstv, pos := uint64(0), uint64(mask-v+1) 227 //fmt.Println(v, tp, lstv, pos) 228 229 for i := 1; i <= v; i++ { 230 tp |= uint64(1) << (mask + 1 - i) 231 } 232 233 for i := set.size - 1; i >= 0; i-- { 234 tpLstv := (set.data[i] & tp) >> pos 235 set.data[i] <<= v 236 set.data[i] |= lstv 237 lstv = tpLstv 238 } 239 set.data[0] &= set.upCeil 240 } 241 242 // >> operator 243 func (set *BitSet) RightShift(n int) { 244 set.rightShiftData(n) 245 set.rightShiftBit(n) 246 } 247 248 func (set *BitSet) rightShiftData(n int) { 249 index := n >> shift 250 for i := set.size - 1; i >= index; i-- { 251 set.data[i] = set.data[i-index] 252 } 253 //fmt.Println(n, index, set.data) 254 for i := index - 1; i >= 0; i-- { 255 set.data[i] = 0 256 } 257 //fmt.Println(n, index, set.data) 258 } 259 260 func (set *BitSet) rightShiftBit(n int) { 261 v := n & mask 262 tp := uint64(1)<<v - 1 263 lstv, pos := uint64(0), mask-v+1 264 //fmt.Println(v, tp, lstv, pos) 265 266 for i := 0; i < set.size; i++ { 267 tpLstv := (set.data[i] & tp) << pos 268 set.data[i] >>= v 269 set.data[i] |= lstv 270 lstv = tpLstv 271 } 272 set.data[0] &= set.upCeil 273 } 274 275 // & operator (set&compare -> res) 276 func (set *BitSet) And(compare *BitSet) (res *BitSet) { 277 panicIfNull(set) 278 panicIfNull(compare) 279 280 s, c := sortByLength(set, compare) 281 res = NewBitSet(c.len) 282 for i, word := range s.data { 283 res.data[c.size-s.size+i] = word & c.data[c.size-s.size+i] 284 } 285 286 return 287 } 288 289 // | operator (set|compare -> res) 290 func (set *BitSet) Or(compare *BitSet) (res *BitSet) { 291 panicIfNull(set) 292 panicIfNull(compare) 293 294 s, c := sortByLength(set, compare) 295 res = c.Clone() 296 for i, word := range s.data { 297 res.data[c.size-s.size+i] = word | c.data[c.size-s.size+i] 298 } 299 300 return 301 } 302 303 // ^ operator (set^compare -> res) 304 func (set *BitSet) Xor(compare *BitSet) (res *BitSet) { 305 panicIfNull(set) 306 panicIfNull(compare) 307 308 s, c := sortByLength(set, compare) 309 res = c.Clone() 310 for i, word := range s.data { 311 res.data[c.size-s.size+i] = word ^ c.data[c.size-s.size+i] 312 } 313 314 return 315 } 316 317 // ~ operator(golang option ^self) 318 func (set *BitSet) Not() (res *BitSet) { 319 panicIfNull(set) 320 321 res = set.Clone() 322 for i, word := range set.data { 323 res.data[i] = ^word 324 } 325 326 return 327 } 328 329 // diff operator (&^) return new bitset (diff(set,compare) set not compare) 330 func (set *BitSet) Diff(compare *BitSet) (res *BitSet) { 331 panicIfNull(set) 332 panicIfNull(compare) 333 334 // clone set (in case set is bigger than compare) 335 res = set.Clone() 336 if set.size > compare.size { 337 for i := 0; i < compare.size; i++ { 338 res.data[set.size-compare.size+i] = set.data[set.size-compare.size+i] &^ compare.data[i] 339 } 340 } else { 341 for i := 0; i < set.size; i++ { 342 res.data[i] = set.data[i] &^ compare.data[compare.size-set.size+i] 343 } 344 } 345 346 return 347 } 348 349 // self diff operator (&^) return set diff compare(set~compare) 350 func (set *BitSet) InPlaceDiff(compare *BitSet) { 351 panicIfNull(set) 352 panicIfNull(compare) 353 354 if set.size > compare.size { 355 for i := 0; i < compare.size; i++ { 356 set.data[set.size-compare.size+i] = set.data[set.size-compare.size+i] &^ compare.data[i] 357 } 358 } else { 359 for i := 0; i < set.size; i++ { 360 set.data[i] = set.data[i] &^ compare.data[compare.size-set.size+i] 361 } 362 } 363 364 return 365 } 366 367 // Clone this BitSet 368 func (set *BitSet) Clone() *BitSet { 369 c := NewBitSet(set.len) 370 if set.data != nil { // Clone should not modify current object 371 copy(c.data, set.data) 372 } 373 return c 374 } 375 376 // Convenience function: return two bitsets ordered by asc 377 // increasing length. Note: neither can be nil 378 func sortByLength(a *BitSet, b *BitSet) (ap *BitSet, bp *BitSet) { 379 if a.len <= b.len { 380 ap, bp = a, b 381 } else { 382 ap, bp = b, a 383 } 384 return 385 } 386 387 // Error is used to distinguish errors (panics) generated in this package. 388 type Error string 389 390 func panicIfNull(b *BitSet) { 391 if b == nil { 392 panic(Error("BitSet must not be null")) 393 } 394 }