github.com/richardwilkes/toolbox@v1.121.0/xmath/bitset.go (about) 1 // Copyright (c) 2016-2024 by Richard A. Wilkes. All rights reserved. 2 // 3 // This Source Code Form is subject to the terms of the Mozilla Public 4 // License, version 2.0. If a copy of the MPL was not distributed with 5 // this file, You can obtain one at http://mozilla.org/MPL/2.0/. 6 // 7 // This Source Code Form is "Incompatible With Secondary Licenses", as 8 // defined by the Mozilla Public License, version 2.0. 9 10 package xmath 11 12 import ( 13 "fmt" 14 "math" 15 16 "github.com/richardwilkes/toolbox/atexit" 17 ) 18 19 const ( 20 addressBitsPerWord = 6 21 dataBitsPerWord = 1 << addressBitsPerWord 22 bitIndexMask = dataBitsPerWord - 1 23 ) 24 25 // BitSet contains a set of bits. 26 type BitSet struct { 27 data []uint64 28 set int 29 } 30 31 // Clone this BitSet. 32 func (b *BitSet) Clone() *BitSet { 33 bs := &BitSet{data: make([]uint64, len(b.data)), set: b.set} 34 copy(bs.data, b.data) 35 return bs 36 } 37 38 // Copy the content of 'other' into this BitSet, making them equal. 39 func (b *BitSet) Copy(other *BitSet) { 40 b.set = other.set 41 b.data = make([]uint64, len(other.data)) 42 copy(b.data, other.data) 43 } 44 45 // Equal returns true if this BitSet is equal to 'other'. 46 func (b *BitSet) Equal(other *BitSet) bool { 47 if other == nil { 48 return false 49 } 50 if b.set != other.set { 51 return false 52 } 53 if len(b.data) != len(other.data) { 54 return false 55 } 56 for i := range b.data { 57 if b.data[i] != other.data[i] { 58 return false 59 } 60 } 61 return true 62 } 63 64 // Count returns the number of set bits. 65 func (b *BitSet) Count() int { 66 return b.set 67 } 68 69 // State returns the state of the bit at 'index'. 70 func (b *BitSet) State(index int) bool { 71 validateBitSetIndex(index) 72 i := index >> addressBitsPerWord 73 if i >= len(b.data) { 74 return false 75 } 76 mask := wordMask(index) 77 return b.data[i]&mask == mask 78 } 79 80 // Set the bit at 'index'. 81 func (b *BitSet) Set(index int) { 82 validateBitSetIndex(index) 83 i := index >> addressBitsPerWord 84 b.EnsureCapacity(i + 1) 85 mask := wordMask(index) 86 if b.data[i]&mask == 0 { 87 b.data[i] |= mask 88 b.set++ 89 } 90 } 91 92 func countSetBits(x uint64) int { 93 x -= (x >> 1) & 0x5555555555555555 94 x = (x>>2)&0x3333333333333333 + x&0x3333333333333333 95 x += x >> 4 96 x &= 0x0f0f0f0f0f0f0f0f 97 x *= 0x0101010101010101 98 return int(x >> 56) 99 } 100 101 // SetRange sets the bits from 'start' to 'end', inclusive. 102 func (b *BitSet) SetRange(start, end int) { 103 validateBitSetIndex(start) 104 validateBitSetIndex(end) 105 if start > end { 106 start, end = end, start 107 } 108 i1 := start >> addressBitsPerWord 109 i2 := end >> addressBitsPerWord 110 b.EnsureCapacity(i2 + 1) 111 j := bitIndexForMask(wordMask(start)) 112 for i := i1; i <= i2; i++ { 113 if i != i1 && i != i2 { 114 b.set += dataBitsPerWord - countSetBits(b.data[i]) 115 b.data[i] = math.MaxUint64 116 } else { 117 var last int 118 if i == i2 { 119 last = bitIndexForMask(wordMask(end)) + 1 120 } else { 121 last = dataBitsPerWord 122 } 123 for j < last { 124 mask := wordMask(j) 125 if b.data[i]&mask == 0 { 126 b.data[i] |= mask 127 b.set++ 128 } 129 j++ 130 } 131 j = 0 132 } 133 } 134 } 135 136 // Clear the bit at 'index'. 137 func (b *BitSet) Clear(index int) { 138 validateBitSetIndex(index) 139 i := index >> addressBitsPerWord 140 if i < len(b.data) { 141 mask := wordMask(index) 142 if b.data[i]&mask == mask { 143 b.data[i] &= ^mask 144 b.set-- 145 } 146 } 147 } 148 149 // ClearRange clears the bits from 'start' to 'end', inclusive. 150 func (b *BitSet) ClearRange(start, end int) { 151 validateBitSetIndex(start) 152 validateBitSetIndex(end) 153 if start > end { 154 start, end = end, start 155 } 156 maximum := len(b.data) - 1 157 i1 := start >> addressBitsPerWord 158 if i1 > maximum { 159 return 160 } 161 i2 := end >> addressBitsPerWord 162 if i2 > maximum { 163 i2 = maximum 164 } 165 j := bitIndexForMask(wordMask(start)) 166 for i := i1; i <= i2; i++ { 167 if i != i1 && i != i2 { 168 b.set -= countSetBits(b.data[i]) 169 b.data[i] = 0 170 } else { 171 var last int 172 if i == i2 { 173 last = bitIndexForMask(wordMask(end)) + 1 174 } else { 175 last = dataBitsPerWord 176 } 177 for j < last { 178 mask := wordMask(j) 179 if b.data[i]&mask == mask { 180 b.data[i] &= ^mask 181 b.set-- 182 } 183 j++ 184 } 185 j = 0 186 } 187 } 188 } 189 190 // Flip the bit at 'index'. 191 func (b *BitSet) Flip(index int) { 192 validateBitSetIndex(index) 193 i := index >> addressBitsPerWord 194 b.EnsureCapacity(i + 1) 195 mask := wordMask(index) 196 b.data[i] ^= mask 197 if b.data[i]&mask == mask { 198 b.set++ 199 } else { 200 b.set-- 201 } 202 } 203 204 // FlipRange flips the bits from 'start' to 'end', inclusive. 205 func (b *BitSet) FlipRange(start, end int) { 206 validateBitSetIndex(start) 207 validateBitSetIndex(end) 208 if start > end { 209 start, end = end, start 210 } 211 i1 := start >> addressBitsPerWord 212 i2 := end >> addressBitsPerWord 213 b.EnsureCapacity(i2 + 1) 214 j := bitIndexForMask(wordMask(start)) 215 for i := i1; i <= i2; i++ { 216 if i != i1 && i != i2 { 217 b.set += dataBitsPerWord - 2*countSetBits(b.data[i]) 218 b.data[i] ^= math.MaxUint64 219 } else { 220 var last int 221 if i == i2 { 222 last = bitIndexForMask(wordMask(end)) + 1 223 } else { 224 last = dataBitsPerWord 225 } 226 for j < last { 227 mask := wordMask(j) 228 b.data[i] ^= mask 229 if b.data[i]&mask == mask { 230 b.set++ 231 } else { 232 b.set-- 233 } 234 j++ 235 } 236 j = 0 237 } 238 } 239 } 240 241 // FirstSet returns the first set bit. If no bits are set, then -1 is returned. 242 func (b *BitSet) FirstSet() int { 243 return b.NextSet(0) 244 } 245 246 // LastSet returns the last set bit. If no bits are set, then -1 is returned. 247 func (b *BitSet) LastSet() int { 248 return b.PreviousSet(len(b.data) << addressBitsPerWord) 249 } 250 251 // PreviousSet returns the previous set bit starting from 'start'. If no bits are set at or before 'start', then -1 is 252 // returned. 253 func (b *BitSet) PreviousSet(start int) int { 254 validateBitSetIndex(start) 255 i := start >> addressBitsPerWord 256 var firstBit int 257 if maximum := len(b.data) - 1; i > maximum { 258 i = maximum 259 firstBit = 63 260 } else { 261 firstBit = bitIndexForMask(wordMask(start)) 262 } 263 for i >= 0 { 264 word := b.data[i] 265 if word != 0 { 266 for j := firstBit; j >= 0; j-- { 267 mask := wordMask(j) 268 if word&mask == mask { 269 return i<<addressBitsPerWord + j 270 } 271 } 272 } 273 firstBit = 63 274 i-- 275 } 276 return -1 277 } 278 279 // NextSet returns the next set bit starting from 'start'. If no bits are set at or beyond 'start', then -1 is returned. 280 func (b *BitSet) NextSet(start int) int { 281 validateBitSetIndex(start) 282 i := start >> addressBitsPerWord 283 firstBit := bitIndexForMask(wordMask(start)) 284 maximum := len(b.data) 285 for i < maximum { 286 word := b.data[i] 287 if word != 0 { 288 for j := firstBit; j < dataBitsPerWord; j++ { 289 mask := wordMask(j) 290 if word&mask == mask { 291 return i<<addressBitsPerWord + j 292 } 293 } 294 } 295 firstBit = 0 296 i++ 297 } 298 return -1 299 } 300 301 // PreviousClear returns the previous clear bit starting from 'start'. If no bits are clear at or before 'start', then 302 // -1 is returned. 303 func (b *BitSet) PreviousClear(start int) int { 304 validateBitSetIndex(start) 305 i := start >> addressBitsPerWord 306 if i > len(b.data)-1 { 307 return start 308 } 309 firstBit := bitIndexForMask(wordMask(start)) 310 for i >= 0 { 311 word := b.data[i] 312 if word != math.MaxUint64 { 313 for j := firstBit; j >= 0; j-- { 314 mask := wordMask(j) 315 if word&mask == 0 { 316 return i<<addressBitsPerWord + j 317 } 318 } 319 } 320 firstBit = 63 321 i-- 322 } 323 return -1 324 } 325 326 // NextClear returns the next clear bit starting from 'start'. 327 func (b *BitSet) NextClear(start int) int { 328 validateBitSetIndex(start) 329 i := start >> addressBitsPerWord 330 firstBit := bitIndexForMask(wordMask(start)) 331 maximum := len(b.data) 332 for i < maximum { 333 word := b.data[i] 334 if word != math.MaxUint64 { 335 for j := firstBit; j < dataBitsPerWord; j++ { 336 mask := wordMask(j) 337 if word&mask == 0 { 338 return i<<addressBitsPerWord + j 339 } 340 } 341 } 342 firstBit = 0 343 i++ 344 } 345 return max(maximum*dataBitsPerWord, start) 346 } 347 348 // Trim the BitSet down to the minimum required to store the set bits. 349 func (b *BitSet) Trim() { 350 size := len(b.data) 351 for i := size - 1; i >= 0; i-- { 352 if b.data[i] != 0 { 353 i++ 354 if i != size { 355 data := make([]uint64, i) 356 copy(data, b.data) 357 b.data = data 358 } 359 return 360 } 361 i-- 362 } 363 b.data = nil 364 } 365 366 // EnsureCapacity ensures that the BitSet has enough underlying storage to accommodate setting a bit as high as index 367 // position 'words' x 64 - 1 without needing to allocate more storage. 368 func (b *BitSet) EnsureCapacity(words int) { 369 size := len(b.data) 370 if words > size { 371 size *= 2 372 if size < words { 373 size = words 374 } 375 data := make([]uint64, size) 376 copy(data, b.data) 377 b.data = data 378 } 379 } 380 381 // Data returns a copy of the underlying storage. 382 func (b *BitSet) Data() []uint64 { 383 b.Trim() 384 data := make([]uint64, len(b.data)) 385 copy(data, b.data) 386 return data 387 } 388 389 // Load replaces the current data with the bits set in 'data'. 390 func (b *BitSet) Load(data []uint64) { 391 b.data = make([]uint64, len(data)) 392 copy(b.data, data) 393 b.Trim() 394 b.set = 0 395 for i := len(b.data) - 1; i >= 0; i-- { 396 word := data[i] 397 if word != 0 { 398 for j := 0; j < dataBitsPerWord; j++ { 399 mask := wordMask(j) 400 if word&mask == mask { 401 b.set++ 402 } 403 } 404 } 405 } 406 } 407 408 // Reset the BitSet back to an empty state. 409 func (b *BitSet) Reset() { 410 b.data = nil 411 b.set = 0 412 } 413 414 func wordMask(index int) uint64 { 415 return uint64(1) << uint(index&bitIndexMask) 416 } 417 418 func bitIndexForMask(mask uint64) int { 419 for i := 0; i < dataBitsPerWord; i++ { 420 if mask == wordMask(i) { 421 return i 422 } 423 } 424 fmt.Printf("Unable to determine bit index for mask %064b\n", mask) 425 atexit.Exit(1) 426 return 0 427 } 428 429 func validateBitSetIndex(index int) { 430 if index < 0 { 431 fmt.Printf("Index must be positive (was %d)\n", index) 432 atexit.Exit(1) 433 } 434 }