get.pme.sh/pnats@v0.0.0-20240304004023-26bb5a137ed0/server/avl/seqset.go (about) 1 // Copyright 2023 The NATS Authors 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package avl 15 16 import ( 17 "encoding/binary" 18 "errors" 19 "math/bits" 20 "sort" 21 ) 22 23 // SequenceSet is a memory and encoding optimized set for storing unsigned ints. 24 // 25 // SequenceSet is ~80-100 times more efficient memory wise than a map[uint64]struct{}. 26 // SequenceSet is ~1.75 times slower at inserts than the same map. 27 // SequenceSet is not thread safe. 28 // 29 // We use an AVL tree with nodes that hold bitmasks for set membership. 30 // 31 // Encoding will convert to a space optimized encoding using bitmasks. 32 type SequenceSet struct { 33 root *node // root node 34 size int // number of items 35 nodes int // number of nodes 36 // Having this here vs on the stack in Insert/Delete 37 // makes a difference in memory usage. 38 changed bool 39 } 40 41 // Insert will insert the sequence into the set. 42 // The tree will be balanced inline. 43 func (ss *SequenceSet) Insert(seq uint64) { 44 if ss.root = ss.root.insert(seq, &ss.changed, &ss.nodes); ss.changed { 45 ss.changed = false 46 ss.size++ 47 } 48 } 49 50 // Exists will return true iff the sequence is a member of this set. 51 func (ss *SequenceSet) Exists(seq uint64) bool { 52 for n := ss.root; n != nil; { 53 if seq < n.base { 54 n = n.l 55 continue 56 } else if seq >= n.base+numEntries { 57 n = n.r 58 continue 59 } 60 return n.exists(seq) 61 } 62 return false 63 } 64 65 // SetInitialMin should be used to set the initial minimum sequence when known. 66 // This will more effectively utilize space versus self selecting. 67 // The set should be empty. 68 func (ss *SequenceSet) SetInitialMin(min uint64) error { 69 if !ss.IsEmpty() { 70 return ErrSetNotEmpty 71 } 72 ss.root, ss.nodes = &node{base: min, h: 1}, 1 73 return nil 74 } 75 76 // Delete will remove the sequence from the set. 77 // Will optionally remove nodes and rebalance. 78 // Returns where the sequence was set. 79 func (ss *SequenceSet) Delete(seq uint64) bool { 80 if ss == nil || ss.root == nil { 81 return false 82 } 83 ss.root = ss.root.delete(seq, &ss.changed, &ss.nodes) 84 if ss.changed { 85 ss.changed = false 86 ss.size-- 87 if ss.size == 0 { 88 ss.Empty() 89 } 90 return true 91 } 92 return false 93 } 94 95 // Size returns the number of items in the set. 96 func (ss *SequenceSet) Size() int { 97 return ss.size 98 } 99 100 // Nodes returns the number of nodes in the tree. 101 func (ss *SequenceSet) Nodes() int { 102 return ss.nodes 103 } 104 105 // Empty will clear all items from a set. 106 func (ss *SequenceSet) Empty() { 107 ss.root = nil 108 ss.size = 0 109 ss.nodes = 0 110 } 111 112 // IsEmpty is a fast check of the set being empty. 113 func (ss *SequenceSet) IsEmpty() bool { 114 if ss == nil || ss.root == nil { 115 return true 116 } 117 return false 118 } 119 120 // Range will invoke the given function for each item in the set. 121 // They will range over the set in ascending order. 122 // If the callback returns false we terminate the iteration. 123 func (ss *SequenceSet) Range(f func(uint64) bool) { 124 ss.root.iter(f) 125 } 126 127 // Heights returns the left and right heights of the tree. 128 func (ss *SequenceSet) Heights() (l, r int) { 129 if ss.root == nil { 130 return 0, 0 131 } 132 if ss.root.l != nil { 133 l = ss.root.l.h 134 } 135 if ss.root.r != nil { 136 r = ss.root.r.h 137 } 138 return l, r 139 } 140 141 // Returns min, max and number of set items. 142 func (ss *SequenceSet) State() (min, max, num uint64) { 143 if ss == nil || ss.root == nil { 144 return 0, 0, 0 145 } 146 min, max = ss.MinMax() 147 return min, max, uint64(ss.Size()) 148 } 149 150 // MinMax will return the minunum and maximum values in the set. 151 func (ss *SequenceSet) MinMax() (min, max uint64) { 152 if ss.root == nil { 153 return 0, 0 154 } 155 for l := ss.root; l != nil; l = l.l { 156 if l.l == nil { 157 min = l.min() 158 } 159 } 160 for r := ss.root; r != nil; r = r.r { 161 if r.r == nil { 162 max = r.max() 163 } 164 } 165 return min, max 166 } 167 168 func clone(src *node, target **node) { 169 if src == nil { 170 return 171 } 172 n := &node{base: src.base, bits: src.bits, h: src.h} 173 *target = n 174 clone(src.l, &n.l) 175 clone(src.r, &n.r) 176 } 177 178 // Clone will return a clone of the given SequenceSet. 179 func (ss *SequenceSet) Clone() *SequenceSet { 180 if ss == nil { 181 return nil 182 } 183 css := &SequenceSet{nodes: ss.nodes, size: ss.size} 184 clone(ss.root, &css.root) 185 186 return css 187 } 188 189 // Union will union this SequenceSet with ssa. 190 func (ss *SequenceSet) Union(ssa ...*SequenceSet) { 191 for _, sa := range ssa { 192 sa.root.nodeIter(func(n *node) { 193 for nb, b := range n.bits { 194 for pos := uint64(0); b != 0; pos++ { 195 if b&1 == 1 { 196 seq := n.base + (uint64(nb) * uint64(bitsPerBucket)) + pos 197 ss.Insert(seq) 198 } 199 b >>= 1 200 } 201 } 202 }) 203 } 204 } 205 206 // Union will return a union of all sets. 207 func Union(ssa ...*SequenceSet) *SequenceSet { 208 if len(ssa) == 0 { 209 return nil 210 } 211 // Sort so we can clone largest. 212 sort.Slice(ssa, func(i, j int) bool { return ssa[i].Size() > ssa[j].Size() }) 213 ss := ssa[0].Clone() 214 215 // Insert the rest through range call. 216 for i := 1; i < len(ssa); i++ { 217 ssa[i].Range(func(n uint64) bool { 218 ss.Insert(n) 219 return true 220 }) 221 } 222 return ss 223 } 224 225 const ( 226 // Magic is used to identify the encode binary state.. 227 magic = uint8(22) 228 // Version 229 version = uint8(2) 230 // hdrLen 231 hdrLen = 2 232 // minimum length of an encoded SequenceSet. 233 minLen = 2 + 8 // magic + version + num nodes + num entries. 234 ) 235 236 // EncodeLen returns the bytes needed for encoding. 237 func (ss SequenceSet) EncodeLen() int { 238 return minLen + (ss.Nodes() * ((numBuckets+1)*8 + 2)) 239 } 240 241 func (ss SequenceSet) Encode(buf []byte) ([]byte, error) { 242 nn, encLen := ss.Nodes(), ss.EncodeLen() 243 244 if cap(buf) < encLen { 245 buf = make([]byte, encLen) 246 } else { 247 buf = buf[:encLen] 248 } 249 250 // TODO(dlc) - Go 1.19 introduced Append to not have to keep track. 251 // Once 1.20 is out we could change this over. 252 // Also binary.Write() is way slower, do not use. 253 254 var le = binary.LittleEndian 255 buf[0], buf[1] = magic, version 256 i := hdrLen 257 le.PutUint32(buf[i:], uint32(nn)) 258 le.PutUint32(buf[i+4:], uint32(ss.size)) 259 i += 8 260 ss.root.nodeIter(func(n *node) { 261 le.PutUint64(buf[i:], n.base) 262 i += 8 263 for _, b := range n.bits { 264 le.PutUint64(buf[i:], b) 265 i += 8 266 } 267 le.PutUint16(buf[i:], uint16(n.h)) 268 i += 2 269 }) 270 return buf[:i], nil 271 } 272 273 // ErrBadEncoding is returned when we can not decode properly. 274 var ( 275 ErrBadEncoding = errors.New("ss: bad encoding") 276 ErrBadVersion = errors.New("ss: bad version") 277 ErrSetNotEmpty = errors.New("ss: set not empty") 278 ) 279 280 // Decode returns the sequence set and number of bytes read from the buffer on success. 281 func Decode(buf []byte) (*SequenceSet, int, error) { 282 if len(buf) < minLen || buf[0] != magic { 283 return nil, -1, ErrBadEncoding 284 } 285 286 switch v := buf[1]; v { 287 case 1: 288 return decodev1(buf) 289 case 2: 290 return decodev2(buf) 291 default: 292 return nil, -1, ErrBadVersion 293 } 294 } 295 296 // Helper to decode v2. 297 func decodev2(buf []byte) (*SequenceSet, int, error) { 298 var le = binary.LittleEndian 299 index := 2 300 nn := int(le.Uint32(buf[index:])) 301 sz := int(le.Uint32(buf[index+4:])) 302 index += 8 303 304 expectedLen := minLen + (nn * ((numBuckets+1)*8 + 2)) 305 if len(buf) < expectedLen { 306 return nil, -1, ErrBadEncoding 307 } 308 309 ss, nodes := SequenceSet{size: sz}, make([]node, nn) 310 311 for i := 0; i < nn; i++ { 312 n := &nodes[i] 313 n.base = le.Uint64(buf[index:]) 314 index += 8 315 for bi := range n.bits { 316 n.bits[bi] = le.Uint64(buf[index:]) 317 index += 8 318 } 319 n.h = int(le.Uint16(buf[index:])) 320 index += 2 321 ss.insertNode(n) 322 } 323 324 return &ss, index, nil 325 } 326 327 // Helper to decode v1 into v2 which has fixed buckets of 32 vs 64 originally. 328 func decodev1(buf []byte) (*SequenceSet, int, error) { 329 var le = binary.LittleEndian 330 index := 2 331 nn := int(le.Uint32(buf[index:])) 332 sz := int(le.Uint32(buf[index+4:])) 333 index += 8 334 335 const v1NumBuckets = 64 336 337 expectedLen := minLen + (nn * ((v1NumBuckets+1)*8 + 2)) 338 if len(buf) < expectedLen { 339 return nil, -1, ErrBadEncoding 340 } 341 342 var ss SequenceSet 343 for i := 0; i < nn; i++ { 344 base := le.Uint64(buf[index:]) 345 index += 8 346 for nb := uint64(0); nb < v1NumBuckets; nb++ { 347 n := le.Uint64(buf[index:]) 348 // Walk all set bits and insert sequences manually for this decode from v1. 349 for pos := uint64(0); n != 0; pos++ { 350 if n&1 == 1 { 351 seq := base + (nb * uint64(bitsPerBucket)) + pos 352 ss.Insert(seq) 353 } 354 n >>= 1 355 } 356 index += 8 357 } 358 // Skip over encoded height. 359 index += 2 360 } 361 362 // Sanity check. 363 if ss.Size() != sz { 364 return nil, -1, ErrBadEncoding 365 } 366 367 return &ss, index, nil 368 369 } 370 371 // insertNode places a decoded node into the tree. 372 // These should be done in tree order as defined by Encode() 373 // This allows us to not have to calculate height or do rebalancing. 374 // So much better performance this way. 375 func (ss *SequenceSet) insertNode(n *node) { 376 ss.nodes++ 377 378 if ss.root == nil { 379 ss.root = n 380 return 381 } 382 // Walk our way to the insertion point. 383 for p := ss.root; p != nil; { 384 if n.base < p.base { 385 if p.l == nil { 386 p.l = n 387 return 388 } 389 p = p.l 390 } else { 391 if p.r == nil { 392 p.r = n 393 return 394 } 395 p = p.r 396 } 397 } 398 } 399 400 const ( 401 bitsPerBucket = 64 // bits in uint64 402 numBuckets = 32 403 numEntries = numBuckets * bitsPerBucket 404 ) 405 406 type node struct { 407 //v dvalue 408 base uint64 409 bits [numBuckets]uint64 410 l *node 411 r *node 412 h int 413 } 414 415 // Set the proper bit. 416 // seq should have already been qualified and inserted should be non nil. 417 func (n *node) set(seq uint64, inserted *bool) { 418 seq -= n.base 419 i := seq / bitsPerBucket 420 mask := uint64(1) << (seq % bitsPerBucket) 421 if (n.bits[i] & mask) == 0 { 422 n.bits[i] |= mask 423 *inserted = true 424 } 425 } 426 427 func (n *node) insert(seq uint64, inserted *bool, nodes *int) *node { 428 if n == nil { 429 base := (seq / numEntries) * numEntries 430 n := &node{base: base, h: 1} 431 n.set(seq, inserted) 432 *nodes++ 433 return n 434 } 435 436 if seq < n.base { 437 n.l = n.l.insert(seq, inserted, nodes) 438 } else if seq >= n.base+numEntries { 439 n.r = n.r.insert(seq, inserted, nodes) 440 } else { 441 n.set(seq, inserted) 442 } 443 444 n.h = maxH(n) + 1 445 446 // Don't make a function, impacts performance. 447 if bf := balanceF(n); bf > 1 { 448 // Left unbalanced. 449 if balanceF(n.l) < 0 { 450 n.l = n.l.rotateL() 451 } 452 return n.rotateR() 453 } else if bf < -1 { 454 // Right unbalanced. 455 if balanceF(n.r) > 0 { 456 n.r = n.r.rotateR() 457 } 458 return n.rotateL() 459 } 460 return n 461 } 462 463 func (n *node) rotateL() *node { 464 r := n.r 465 if r != nil { 466 n.r = r.l 467 r.l = n 468 n.h = maxH(n) + 1 469 r.h = maxH(r) + 1 470 } else { 471 n.r = nil 472 n.h = maxH(n) + 1 473 } 474 return r 475 } 476 477 func (n *node) rotateR() *node { 478 l := n.l 479 if l != nil { 480 n.l = l.r 481 l.r = n 482 n.h = maxH(n) + 1 483 l.h = maxH(l) + 1 484 } else { 485 n.l = nil 486 n.h = maxH(n) + 1 487 } 488 return l 489 } 490 491 func balanceF(n *node) int { 492 if n == nil { 493 return 0 494 } 495 var lh, rh int 496 if n.l != nil { 497 lh = n.l.h 498 } 499 if n.r != nil { 500 rh = n.r.h 501 } 502 return lh - rh 503 } 504 505 func maxH(n *node) int { 506 if n == nil { 507 return 0 508 } 509 var lh, rh int 510 if n.l != nil { 511 lh = n.l.h 512 } 513 if n.r != nil { 514 rh = n.r.h 515 } 516 if lh > rh { 517 return lh 518 } 519 return rh 520 } 521 522 // Clear the proper bit. 523 // seq should have already been qualified and deleted should be non nil. 524 // Will return true if this node is now empty. 525 func (n *node) clear(seq uint64, deleted *bool) bool { 526 seq -= n.base 527 i := seq / bitsPerBucket 528 mask := uint64(1) << (seq % bitsPerBucket) 529 if (n.bits[i] & mask) != 0 { 530 n.bits[i] &^= mask 531 *deleted = true 532 } 533 for _, b := range n.bits { 534 if b != 0 { 535 return false 536 } 537 } 538 return true 539 } 540 541 func (n *node) delete(seq uint64, deleted *bool, nodes *int) *node { 542 if n == nil { 543 return nil 544 } 545 546 if seq < n.base { 547 n.l = n.l.delete(seq, deleted, nodes) 548 } else if seq >= n.base+numEntries { 549 n.r = n.r.delete(seq, deleted, nodes) 550 } else if empty := n.clear(seq, deleted); empty { 551 *nodes-- 552 if n.l == nil { 553 n = n.r 554 } else if n.r == nil { 555 n = n.l 556 } else { 557 // We have both children. 558 n.r = n.r.insertNodePrev(n.l) 559 n = n.r 560 } 561 } 562 563 if n != nil { 564 n.h = maxH(n) + 1 565 } 566 567 // Check balance. 568 if bf := balanceF(n); bf > 1 { 569 // Left unbalanced. 570 if balanceF(n.l) < 0 { 571 n.l = n.l.rotateL() 572 } 573 return n.rotateR() 574 } else if bf < -1 { 575 // right unbalanced. 576 if balanceF(n.r) > 0 { 577 n.r = n.r.rotateR() 578 } 579 return n.rotateL() 580 } 581 582 return n 583 } 584 585 // Will insert nn into the node assuming it is less than all other nodes in n. 586 // Will re-calculate height and balance. 587 func (n *node) insertNodePrev(nn *node) *node { 588 if n.l == nil { 589 n.l = nn 590 } else { 591 n.l = n.l.insertNodePrev(nn) 592 } 593 n.h = maxH(n) + 1 594 595 // Check balance. 596 if bf := balanceF(n); bf > 1 { 597 // Left unbalanced. 598 if balanceF(n.l) < 0 { 599 n.l = n.l.rotateL() 600 } 601 return n.rotateR() 602 } else if bf < -1 { 603 // right unbalanced. 604 if balanceF(n.r) > 0 { 605 n.r = n.r.rotateR() 606 } 607 return n.rotateL() 608 } 609 return n 610 } 611 612 func (n *node) exists(seq uint64) bool { 613 seq -= n.base 614 i := seq / bitsPerBucket 615 mask := uint64(1) << (seq % bitsPerBucket) 616 return n.bits[i]&mask != 0 617 } 618 619 // Return minimum sequence in the set. 620 // This node can not be empty. 621 func (n *node) min() uint64 { 622 for i, b := range n.bits { 623 if b != 0 { 624 return n.base + 625 uint64(i*bitsPerBucket) + 626 uint64(bits.TrailingZeros64(b)) 627 } 628 } 629 return 0 630 } 631 632 // Return maximum sequence in the set. 633 // This node can not be empty. 634 func (n *node) max() uint64 { 635 for i := numBuckets - 1; i >= 0; i-- { 636 if b := n.bits[i]; b != 0 { 637 return n.base + 638 uint64(i*bitsPerBucket) + 639 uint64(bitsPerBucket-bits.LeadingZeros64(b>>1)) 640 } 641 } 642 return 0 643 } 644 645 // This is done in tree order. 646 func (n *node) nodeIter(f func(n *node)) { 647 if n == nil { 648 return 649 } 650 f(n) 651 n.l.nodeIter(f) 652 n.r.nodeIter(f) 653 } 654 655 // iter will iterate through the set's items in this node. 656 // If the supplied function returns false we terminate the iteration. 657 func (n *node) iter(f func(uint64) bool) bool { 658 if n == nil { 659 return true 660 } 661 662 if ok := n.l.iter(f); !ok { 663 return false 664 } 665 for num := n.base; num < n.base+numEntries; num++ { 666 if n.exists(num) { 667 if ok := f(num); !ok { 668 return false 669 } 670 } 671 } 672 if ok := n.r.iter(f); !ok { 673 return false 674 } 675 676 return true 677 }