github.com/nats-io/nats-server/v2@v2.11.0-preview.2/server/avl/seqset.go (about)

     1  // Copyright 2023 The NATS Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package avl
    15  
    16  import (
    17  	"encoding/binary"
    18  	"errors"
    19  	"math/bits"
    20  	"sort"
    21  )
    22  
// SequenceSet is a memory and encoding optimized set for storing unsigned ints.
//
// According to the original author's measurements, SequenceSet is ~80-100 times
// more efficient memory wise than a map[uint64]struct{}, and ~1.75 times slower
// at inserts than the same map.
//
// SequenceSet is not thread safe; concurrent use must be synchronized by the caller.
//
// Internally it is an AVL tree whose nodes each hold a fixed-size bitmask
// covering a contiguous range of sequences, one bit per sequence.
//
// Encode converts the set to a space optimized binary form built from those
// bitmasks; Decode reverses it.
type SequenceSet struct {
	root  *node // root node
	size  int   // number of items
	nodes int   // number of nodes
	// changed is scratch state that the node-level insert/delete set to report
	// whether a bit actually flipped. Having this here vs on the stack in
	// Insert/Delete makes a difference in memory usage.
	changed bool
}
    40  
    41  // Insert will insert the sequence into the set.
    42  // The tree will be balanced inline.
    43  func (ss *SequenceSet) Insert(seq uint64) {
    44  	if ss.root = ss.root.insert(seq, &ss.changed, &ss.nodes); ss.changed {
    45  		ss.changed = false
    46  		ss.size++
    47  	}
    48  }
    49  
    50  // Exists will return true iff the sequence is a member of this set.
    51  func (ss *SequenceSet) Exists(seq uint64) bool {
    52  	for n := ss.root; n != nil; {
    53  		if seq < n.base {
    54  			n = n.l
    55  			continue
    56  		} else if seq >= n.base+numEntries {
    57  			n = n.r
    58  			continue
    59  		}
    60  		return n.exists(seq)
    61  	}
    62  	return false
    63  }
    64  
    65  // SetInitialMin should be used to set the initial minimum sequence when known.
    66  // This will more effectively utilize space versus self selecting.
    67  // The set should be empty.
    68  func (ss *SequenceSet) SetInitialMin(min uint64) error {
    69  	if !ss.IsEmpty() {
    70  		return ErrSetNotEmpty
    71  	}
    72  	ss.root, ss.nodes = &node{base: min, h: 1}, 1
    73  	return nil
    74  }
    75  
    76  // Delete will remove the sequence from the set.
    77  // Will optionally remove nodes and rebalance.
    78  // Returns where the sequence was set.
    79  func (ss *SequenceSet) Delete(seq uint64) bool {
    80  	if ss == nil || ss.root == nil {
    81  		return false
    82  	}
    83  	ss.root = ss.root.delete(seq, &ss.changed, &ss.nodes)
    84  	if ss.changed {
    85  		ss.changed = false
    86  		ss.size--
    87  		if ss.size == 0 {
    88  			ss.Empty()
    89  		}
    90  		return true
    91  	}
    92  	return false
    93  }
    94  
    95  // Size returns the number of items in the set.
    96  func (ss *SequenceSet) Size() int {
    97  	return ss.size
    98  }
    99  
   100  // Nodes returns the number of nodes in the tree.
   101  func (ss *SequenceSet) Nodes() int {
   102  	return ss.nodes
   103  }
   104  
   105  // Empty will clear all items from a set.
   106  func (ss *SequenceSet) Empty() {
   107  	ss.root = nil
   108  	ss.size = 0
   109  	ss.nodes = 0
   110  }
   111  
   112  // IsEmpty is a fast check of the set being empty.
   113  func (ss *SequenceSet) IsEmpty() bool {
   114  	if ss == nil || ss.root == nil {
   115  		return true
   116  	}
   117  	return false
   118  }
   119  
   120  // Range will invoke the given function for each item in the set.
   121  // They will range over the set in ascending order.
   122  // If the callback returns false we terminate the iteration.
   123  func (ss *SequenceSet) Range(f func(uint64) bool) {
   124  	ss.root.iter(f)
   125  }
   126  
   127  // Heights returns the left and right heights of the tree.
   128  func (ss *SequenceSet) Heights() (l, r int) {
   129  	if ss.root == nil {
   130  		return 0, 0
   131  	}
   132  	if ss.root.l != nil {
   133  		l = ss.root.l.h
   134  	}
   135  	if ss.root.r != nil {
   136  		r = ss.root.r.h
   137  	}
   138  	return l, r
   139  }
   140  
   141  // Returns min, max and number of set items.
   142  func (ss *SequenceSet) State() (min, max, num uint64) {
   143  	if ss == nil || ss.root == nil {
   144  		return 0, 0, 0
   145  	}
   146  	min, max = ss.MinMax()
   147  	return min, max, uint64(ss.Size())
   148  }
   149  
   150  // MinMax will return the minunum and maximum values in the set.
   151  func (ss *SequenceSet) MinMax() (min, max uint64) {
   152  	if ss.root == nil {
   153  		return 0, 0
   154  	}
   155  	for l := ss.root; l != nil; l = l.l {
   156  		if l.l == nil {
   157  			min = l.min()
   158  		}
   159  	}
   160  	for r := ss.root; r != nil; r = r.r {
   161  		if r.r == nil {
   162  			max = r.max()
   163  		}
   164  	}
   165  	return min, max
   166  }
   167  
   168  func clone(src *node, target **node) {
   169  	if src == nil {
   170  		return
   171  	}
   172  	n := &node{base: src.base, bits: src.bits, h: src.h}
   173  	*target = n
   174  	clone(src.l, &n.l)
   175  	clone(src.r, &n.r)
   176  }
   177  
   178  // Clone will return a clone of the given SequenceSet.
   179  func (ss *SequenceSet) Clone() *SequenceSet {
   180  	if ss == nil {
   181  		return nil
   182  	}
   183  	css := &SequenceSet{nodes: ss.nodes, size: ss.size}
   184  	clone(ss.root, &css.root)
   185  
   186  	return css
   187  }
   188  
   189  // Union will union this SequenceSet with ssa.
   190  func (ss *SequenceSet) Union(ssa ...*SequenceSet) {
   191  	for _, sa := range ssa {
   192  		sa.root.nodeIter(func(n *node) {
   193  			for nb, b := range n.bits {
   194  				for pos := uint64(0); b != 0; pos++ {
   195  					if b&1 == 1 {
   196  						seq := n.base + (uint64(nb) * uint64(bitsPerBucket)) + pos
   197  						ss.Insert(seq)
   198  					}
   199  					b >>= 1
   200  				}
   201  			}
   202  		})
   203  	}
   204  }
   205  
   206  // Union will return a union of all sets.
   207  func Union(ssa ...*SequenceSet) *SequenceSet {
   208  	if len(ssa) == 0 {
   209  		return nil
   210  	}
   211  	// Sort so we can clone largest.
   212  	sort.Slice(ssa, func(i, j int) bool { return ssa[i].Size() > ssa[j].Size() })
   213  	ss := ssa[0].Clone()
   214  
   215  	// Insert the rest through range call.
   216  	for i := 1; i < len(ssa); i++ {
   217  		ssa[i].Range(func(n uint64) bool {
   218  			ss.Insert(n)
   219  			return true
   220  		})
   221  	}
   222  	return ss
   223  }
   224  
   225  const (
   226  	// Magic is used to identify the encode binary state..
   227  	magic = uint8(22)
   228  	// Version
   229  	version = uint8(2)
   230  	// hdrLen
   231  	hdrLen = 2
   232  	// minimum length of an encoded SequenceSet.
   233  	minLen = 2 + 8 // magic + version + num nodes + num entries.
   234  )
   235  
   236  // EncodeLen returns the bytes needed for encoding.
   237  func (ss SequenceSet) EncodeLen() int {
   238  	return minLen + (ss.Nodes() * ((numBuckets+1)*8 + 2))
   239  }
   240  
   241  func (ss SequenceSet) Encode(buf []byte) ([]byte, error) {
   242  	nn, encLen := ss.Nodes(), ss.EncodeLen()
   243  
   244  	if cap(buf) < encLen {
   245  		buf = make([]byte, encLen)
   246  	} else {
   247  		buf = buf[:encLen]
   248  	}
   249  
   250  	// TODO(dlc) - Go 1.19 introduced Append to not have to keep track.
   251  	// Once 1.20 is out we could change this over.
   252  	// Also binary.Write() is way slower, do not use.
   253  
   254  	var le = binary.LittleEndian
   255  	buf[0], buf[1] = magic, version
   256  	i := hdrLen
   257  	le.PutUint32(buf[i:], uint32(nn))
   258  	le.PutUint32(buf[i+4:], uint32(ss.size))
   259  	i += 8
   260  	ss.root.nodeIter(func(n *node) {
   261  		le.PutUint64(buf[i:], n.base)
   262  		i += 8
   263  		for _, b := range n.bits {
   264  			le.PutUint64(buf[i:], b)
   265  			i += 8
   266  		}
   267  		le.PutUint16(buf[i:], uint16(n.h))
   268  		i += 2
   269  	})
   270  	return buf[:i], nil
   271  }
   272  
   273  // ErrBadEncoding is returned when we can not decode properly.
   274  var (
   275  	ErrBadEncoding = errors.New("ss: bad encoding")
   276  	ErrBadVersion  = errors.New("ss: bad version")
   277  	ErrSetNotEmpty = errors.New("ss: set not empty")
   278  )
   279  
   280  // Decode returns the sequence set and number of bytes read from the buffer on success.
   281  func Decode(buf []byte) (*SequenceSet, int, error) {
   282  	if len(buf) < minLen || buf[0] != magic {
   283  		return nil, -1, ErrBadEncoding
   284  	}
   285  
   286  	switch v := buf[1]; v {
   287  	case 1:
   288  		return decodev1(buf)
   289  	case 2:
   290  		return decodev2(buf)
   291  	default:
   292  		return nil, -1, ErrBadVersion
   293  	}
   294  }
   295  
   296  // Helper to decode v2.
   297  func decodev2(buf []byte) (*SequenceSet, int, error) {
   298  	var le = binary.LittleEndian
   299  	index := 2
   300  	nn := int(le.Uint32(buf[index:]))
   301  	sz := int(le.Uint32(buf[index+4:]))
   302  	index += 8
   303  
   304  	expectedLen := minLen + (nn * ((numBuckets+1)*8 + 2))
   305  	if len(buf) < expectedLen {
   306  		return nil, -1, ErrBadEncoding
   307  	}
   308  
   309  	ss, nodes := SequenceSet{size: sz}, make([]node, nn)
   310  
   311  	for i := 0; i < nn; i++ {
   312  		n := &nodes[i]
   313  		n.base = le.Uint64(buf[index:])
   314  		index += 8
   315  		for bi := range n.bits {
   316  			n.bits[bi] = le.Uint64(buf[index:])
   317  			index += 8
   318  		}
   319  		n.h = int(le.Uint16(buf[index:]))
   320  		index += 2
   321  		ss.insertNode(n)
   322  	}
   323  
   324  	return &ss, index, nil
   325  }
   326  
   327  // Helper to decode v1 into v2 which has fixed buckets of 32 vs 64 originally.
   328  func decodev1(buf []byte) (*SequenceSet, int, error) {
   329  	var le = binary.LittleEndian
   330  	index := 2
   331  	nn := int(le.Uint32(buf[index:]))
   332  	sz := int(le.Uint32(buf[index+4:]))
   333  	index += 8
   334  
   335  	const v1NumBuckets = 64
   336  
   337  	expectedLen := minLen + (nn * ((v1NumBuckets+1)*8 + 2))
   338  	if len(buf) < expectedLen {
   339  		return nil, -1, ErrBadEncoding
   340  	}
   341  
   342  	var ss SequenceSet
   343  	for i := 0; i < nn; i++ {
   344  		base := le.Uint64(buf[index:])
   345  		index += 8
   346  		for nb := uint64(0); nb < v1NumBuckets; nb++ {
   347  			n := le.Uint64(buf[index:])
   348  			// Walk all set bits and insert sequences manually for this decode from v1.
   349  			for pos := uint64(0); n != 0; pos++ {
   350  				if n&1 == 1 {
   351  					seq := base + (nb * uint64(bitsPerBucket)) + pos
   352  					ss.Insert(seq)
   353  				}
   354  				n >>= 1
   355  			}
   356  			index += 8
   357  		}
   358  		// Skip over encoded height.
   359  		index += 2
   360  	}
   361  
   362  	// Sanity check.
   363  	if ss.Size() != sz {
   364  		return nil, -1, ErrBadEncoding
   365  	}
   366  
   367  	return &ss, index, nil
   368  
   369  }
   370  
   371  // insertNode places a decoded node into the tree.
   372  // These should be done in tree order as defined by Encode()
   373  // This allows us to not have to calculate height or do rebalancing.
   374  // So much better performance this way.
   375  func (ss *SequenceSet) insertNode(n *node) {
   376  	ss.nodes++
   377  
   378  	if ss.root == nil {
   379  		ss.root = n
   380  		return
   381  	}
   382  	// Walk our way to the insertion point.
   383  	for p := ss.root; p != nil; {
   384  		if n.base < p.base {
   385  			if p.l == nil {
   386  				p.l = n
   387  				return
   388  			}
   389  			p = p.l
   390  		} else {
   391  			if p.r == nil {
   392  				p.r = n
   393  				return
   394  			}
   395  			p = p.r
   396  		}
   397  	}
   398  }
   399  
   400  const (
   401  	bitsPerBucket = 64 // bits in uint64
   402  	numBuckets    = 32
   403  	numEntries    = numBuckets * bitsPerBucket
   404  )
   405  
// node is one node of the AVL tree. It records membership for the numEntries
// sequences starting at base, one bit per sequence.
type node struct {
	// base is the first sequence tracked by this node.
	base uint64
	// bits holds one bit per sequence in [base, base+numEntries).
	bits [numBuckets]uint64
	// l and r are the subtrees holding smaller and larger bases respectively.
	l *node
	r *node
	// h is the height of the subtree rooted here; a leaf has height 1.
	h int
}
   414  
   415  // Set the proper bit.
   416  // seq should have already been qualified and inserted should be non nil.
   417  func (n *node) set(seq uint64, inserted *bool) {
   418  	seq -= n.base
   419  	i := seq / bitsPerBucket
   420  	mask := uint64(1) << (seq % bitsPerBucket)
   421  	if (n.bits[i] & mask) == 0 {
   422  		n.bits[i] |= mask
   423  		*inserted = true
   424  	}
   425  }
   426  
   427  func (n *node) insert(seq uint64, inserted *bool, nodes *int) *node {
   428  	if n == nil {
   429  		base := (seq / numEntries) * numEntries
   430  		n := &node{base: base, h: 1}
   431  		n.set(seq, inserted)
   432  		*nodes++
   433  		return n
   434  	}
   435  
   436  	if seq < n.base {
   437  		n.l = n.l.insert(seq, inserted, nodes)
   438  	} else if seq >= n.base+numEntries {
   439  		n.r = n.r.insert(seq, inserted, nodes)
   440  	} else {
   441  		n.set(seq, inserted)
   442  	}
   443  
   444  	n.h = maxH(n) + 1
   445  
   446  	// Don't make a function, impacts performance.
   447  	if bf := balanceF(n); bf > 1 {
   448  		// Left unbalanced.
   449  		if balanceF(n.l) < 0 {
   450  			n.l = n.l.rotateL()
   451  		}
   452  		return n.rotateR()
   453  	} else if bf < -1 {
   454  		// Right unbalanced.
   455  		if balanceF(n.r) > 0 {
   456  			n.r = n.r.rotateR()
   457  		}
   458  		return n.rotateL()
   459  	}
   460  	return n
   461  }
   462  
   463  func (n *node) rotateL() *node {
   464  	r := n.r
   465  	if r != nil {
   466  		n.r = r.l
   467  		r.l = n
   468  		n.h = maxH(n) + 1
   469  		r.h = maxH(r) + 1
   470  	} else {
   471  		n.r = nil
   472  		n.h = maxH(n) + 1
   473  	}
   474  	return r
   475  }
   476  
   477  func (n *node) rotateR() *node {
   478  	l := n.l
   479  	if l != nil {
   480  		n.l = l.r
   481  		l.r = n
   482  		n.h = maxH(n) + 1
   483  		l.h = maxH(l) + 1
   484  	} else {
   485  		n.l = nil
   486  		n.h = maxH(n) + 1
   487  	}
   488  	return l
   489  }
   490  
   491  func balanceF(n *node) int {
   492  	if n == nil {
   493  		return 0
   494  	}
   495  	var lh, rh int
   496  	if n.l != nil {
   497  		lh = n.l.h
   498  	}
   499  	if n.r != nil {
   500  		rh = n.r.h
   501  	}
   502  	return lh - rh
   503  }
   504  
   505  func maxH(n *node) int {
   506  	if n == nil {
   507  		return 0
   508  	}
   509  	var lh, rh int
   510  	if n.l != nil {
   511  		lh = n.l.h
   512  	}
   513  	if n.r != nil {
   514  		rh = n.r.h
   515  	}
   516  	if lh > rh {
   517  		return lh
   518  	}
   519  	return rh
   520  }
   521  
   522  // Clear the proper bit.
   523  // seq should have already been qualified and deleted should be non nil.
   524  // Will return true if this node is now empty.
   525  func (n *node) clear(seq uint64, deleted *bool) bool {
   526  	seq -= n.base
   527  	i := seq / bitsPerBucket
   528  	mask := uint64(1) << (seq % bitsPerBucket)
   529  	if (n.bits[i] & mask) != 0 {
   530  		n.bits[i] &^= mask
   531  		*deleted = true
   532  	}
   533  	for _, b := range n.bits {
   534  		if b != 0 {
   535  			return false
   536  		}
   537  	}
   538  	return true
   539  }
   540  
   541  func (n *node) delete(seq uint64, deleted *bool, nodes *int) *node {
   542  	if n == nil {
   543  		return nil
   544  	}
   545  
   546  	if seq < n.base {
   547  		n.l = n.l.delete(seq, deleted, nodes)
   548  	} else if seq >= n.base+numEntries {
   549  		n.r = n.r.delete(seq, deleted, nodes)
   550  	} else if empty := n.clear(seq, deleted); empty {
   551  		*nodes--
   552  		if n.l == nil {
   553  			n = n.r
   554  		} else if n.r == nil {
   555  			n = n.l
   556  		} else {
   557  			// We have both children.
   558  			n.r = n.r.insertNodePrev(n.l)
   559  			n = n.r
   560  		}
   561  	}
   562  
   563  	if n != nil {
   564  		n.h = maxH(n) + 1
   565  	}
   566  
   567  	// Check balance.
   568  	if bf := balanceF(n); bf > 1 {
   569  		// Left unbalanced.
   570  		if balanceF(n.l) < 0 {
   571  			n.l = n.l.rotateL()
   572  		}
   573  		return n.rotateR()
   574  	} else if bf < -1 {
   575  		// right unbalanced.
   576  		if balanceF(n.r) > 0 {
   577  			n.r = n.r.rotateR()
   578  		}
   579  		return n.rotateL()
   580  	}
   581  
   582  	return n
   583  }
   584  
   585  // Will insert nn into the node assuming it is less than all other nodes in n.
   586  // Will re-calculate height and balance.
   587  func (n *node) insertNodePrev(nn *node) *node {
   588  	if n.l == nil {
   589  		n.l = nn
   590  	} else {
   591  		n.l = n.l.insertNodePrev(nn)
   592  	}
   593  	n.h = maxH(n) + 1
   594  
   595  	// Check balance.
   596  	if bf := balanceF(n); bf > 1 {
   597  		// Left unbalanced.
   598  		if balanceF(n.l) < 0 {
   599  			n.l = n.l.rotateL()
   600  		}
   601  		return n.rotateR()
   602  	} else if bf < -1 {
   603  		// right unbalanced.
   604  		if balanceF(n.r) > 0 {
   605  			n.r = n.r.rotateR()
   606  		}
   607  		return n.rotateL()
   608  	}
   609  	return n
   610  }
   611  
   612  func (n *node) exists(seq uint64) bool {
   613  	seq -= n.base
   614  	i := seq / bitsPerBucket
   615  	mask := uint64(1) << (seq % bitsPerBucket)
   616  	return n.bits[i]&mask != 0
   617  }
   618  
   619  // Return minimum sequence in the set.
   620  // This node can not be empty.
   621  func (n *node) min() uint64 {
   622  	for i, b := range n.bits {
   623  		if b != 0 {
   624  			return n.base +
   625  				uint64(i*bitsPerBucket) +
   626  				uint64(bits.TrailingZeros64(b))
   627  		}
   628  	}
   629  	return 0
   630  }
   631  
   632  // Return maximum sequence in the set.
   633  // This node can not be empty.
   634  func (n *node) max() uint64 {
   635  	for i := numBuckets - 1; i >= 0; i-- {
   636  		if b := n.bits[i]; b != 0 {
   637  			return n.base +
   638  				uint64(i*bitsPerBucket) +
   639  				uint64(bitsPerBucket-bits.LeadingZeros64(b>>1))
   640  		}
   641  	}
   642  	return 0
   643  }
   644  
   645  // This is done in tree order.
   646  func (n *node) nodeIter(f func(n *node)) {
   647  	if n == nil {
   648  		return
   649  	}
   650  	f(n)
   651  	n.l.nodeIter(f)
   652  	n.r.nodeIter(f)
   653  }
   654  
   655  // iter will iterate through the set's items in this node.
   656  // If the supplied function returns false we terminate the iteration.
   657  func (n *node) iter(f func(uint64) bool) bool {
   658  	if n == nil {
   659  		return true
   660  	}
   661  
   662  	if ok := n.l.iter(f); !ok {
   663  		return false
   664  	}
   665  	for num := n.base; num < n.base+numEntries; num++ {
   666  		if n.exists(num) {
   667  			if ok := f(num); !ok {
   668  				return false
   669  			}
   670  		}
   671  	}
   672  	if ok := n.r.iter(f); !ok {
   673  		return false
   674  	}
   675  
   676  	return true
   677  }