github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/intsets/fast.go (about)

     1  // Copyright 2017 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  //go:build !fast_int_set_small && !fast_int_set_large
    12  // +build !fast_int_set_small,!fast_int_set_large
    13  
    14  package intsets
    15  
    16  import (
    17  	"bytes"
    18  	"encoding/binary"
    19  	"io"
    20  	"math/bits"
    21  
    22  	"github.com/cockroachdb/errors"
    23  )
    24  
    25  // Fast keeps track of a set of integers. It does not perform any
    26  // allocations when the values are in the range [0, smallCutoff). It is not
    27  // thread-safe.
    28  type Fast struct {
    29  	// small is a bitmap that stores values in the range [0, smallCutoff).
    30  	small bitmap
    31  	// large is only allocated if values are added to the set that are not in
    32  	// the range [0, smallCutoff).
    33  	large *Sparse
    34  }
    35  
    36  // MakeFast returns a set initialized with the given values.
    37  func MakeFast(vals ...int) Fast {
    38  	var res Fast
    39  	for _, v := range vals {
    40  		res.Add(v)
    41  	}
    42  	return res
    43  }
    44  
    45  // fitsInSmall returns whether all elements in this set are between 0 and
    46  // smallCutoff.
    47  //
    48  //gcassert:inline
    49  func (s *Fast) fitsInSmall() bool {
    50  	return s.large == nil || s.large.Empty()
    51  }
    52  
    53  // Add adds a value to the set. No-op if the value is already in the set. If the
    54  // large set is not nil and the value is within the range [0, 63], the value is
    55  // added to both the large and small sets.
    56  func (s *Fast) Add(i int) {
    57  	if i >= 0 && i < smallCutoff {
    58  		s.small.Set(i)
    59  		return
    60  	}
    61  	if s.large == nil {
    62  		s.large = new(Sparse)
    63  	}
    64  	s.large.Add(i)
    65  }
    66  
    67  // AddRange adds values 'from' up to 'to' (inclusively) to the set.
    68  // E.g. AddRange(1,5) adds the values 1, 2, 3, 4, 5 to the set.
    69  // 'to' must be >= 'from'.
    70  // AddRange is always more efficient than individual Adds.
    71  func (s *Fast) AddRange(from, to int) {
    72  	if to < from {
    73  		panic("invalid range when adding range to Fast")
    74  	}
    75  
    76  	if s.large == nil && from >= 0 && to < smallCutoff {
    77  		s.small.SetRange(from, to)
    78  	} else {
    79  		for i := from; i <= to; i++ {
    80  			s.Add(i)
    81  		}
    82  	}
    83  }
    84  
    85  // Remove removes a value from the set. No-op if the value is not in the set.
    86  func (s *Fast) Remove(i int) {
    87  	if i >= 0 && i < smallCutoff {
    88  		s.small.Unset(i)
    89  		return
    90  	}
    91  	if s.large != nil {
    92  		s.large.Remove(i)
    93  	}
    94  }
    95  
    96  // Contains returns true if the set contains the value.
    97  func (s Fast) Contains(i int) bool {
    98  	if i >= 0 && i < smallCutoff {
    99  		return s.small.IsSet(i)
   100  	}
   101  	if s.large != nil {
   102  		return s.large.Contains(i)
   103  	}
   104  	return false
   105  }
   106  
   107  // Empty returns true if the set is empty.
   108  func (s Fast) Empty() bool {
   109  	return s.small == bitmap{} && (s.large == nil || s.large.Empty())
   110  }
   111  
   112  // Len returns the number of the elements in the set.
   113  func (s Fast) Len() int {
   114  	l := s.small.OnesCount()
   115  	if s.large != nil {
   116  		l += s.large.Len()
   117  	}
   118  	return l
   119  }
   120  
   121  // Next returns the first value in the set which is >= startVal. If there is no
   122  // value, the second return value is false.
   123  func (s Fast) Next(startVal int) (int, bool) {
   124  	if startVal < 0 && s.large != nil {
   125  		if res := s.large.LowerBound(startVal); res < 0 {
   126  			return res, true
   127  		}
   128  	}
   129  	if startVal < 0 {
   130  		// Negative values are must be in s.large.
   131  		startVal = 0
   132  	}
   133  	if startVal < smallCutoff {
   134  		if nextVal, ok := s.small.Next(startVal); ok {
   135  			return nextVal, true
   136  		}
   137  	}
   138  	if s.large != nil {
   139  		res := s.large.LowerBound(startVal)
   140  		return res, res != MaxInt
   141  	}
   142  	return MaxInt, false
   143  }
   144  
   145  // ForEach calls a function for each value in the set (in increasing order).
   146  func (s Fast) ForEach(f func(i int)) {
   147  	if !s.fitsInSmall() {
   148  		for x := s.large.Min(); x < 0; x = s.large.LowerBound(x + 1) {
   149  			f(x)
   150  		}
   151  	}
   152  	for v := s.small.lo; v != 0; {
   153  		i := bits.TrailingZeros64(v)
   154  		f(i)
   155  		v &^= 1 << uint(i)
   156  	}
   157  	for v := s.small.hi; v != 0; {
   158  		i := bits.TrailingZeros64(v)
   159  		f(64 + i)
   160  		v &^= 1 << uint(i)
   161  	}
   162  	if !s.fitsInSmall() {
   163  		for x := s.large.LowerBound(0); x != MaxInt; x = s.large.LowerBound(x + 1) {
   164  			f(x)
   165  		}
   166  	}
   167  }
   168  
   169  // Ordered returns a slice with all the integers in the set, in increasing order.
   170  func (s Fast) Ordered() []int {
   171  	if s.Empty() {
   172  		return nil
   173  	}
   174  	result := make([]int, 0, s.Len())
   175  	s.ForEach(func(i int) {
   176  		result = append(result, i)
   177  	})
   178  	return result
   179  }
   180  
   181  // Copy returns a copy of s which can be modified independently.
   182  func (s Fast) Copy() Fast {
   183  	var c Fast
   184  	c.small = s.small
   185  	if s.large != nil && !s.large.Empty() {
   186  		c.large = new(Sparse)
   187  		c.large.Copy(s.large)
   188  	}
   189  	return c
   190  }
   191  
   192  // CopyFrom sets the receiver to a copy of other, which can then be modified
   193  // independently.
   194  func (s *Fast) CopyFrom(other Fast) {
   195  	s.small = other.small
   196  	if other.large != nil && !other.large.Empty() {
   197  		if s.large == nil {
   198  			s.large = new(Sparse)
   199  		}
   200  		s.large.Copy(other.large)
   201  	} else {
   202  		if s.large != nil {
   203  			s.large.Clear()
   204  		}
   205  	}
   206  }
   207  
   208  // UnionWith adds all the elements from rhs to this set.
   209  func (s *Fast) UnionWith(rhs Fast) {
   210  	s.small.UnionWith(rhs.small)
   211  	if rhs.large == nil || rhs.large.Empty() {
   212  		// Fast path.
   213  		return
   214  	}
   215  	if s.large == nil {
   216  		s.large = new(Sparse)
   217  	}
   218  	s.large.UnionWith(rhs.large)
   219  }
   220  
   221  // Union returns the union of s and rhs as a new set.
   222  func (s Fast) Union(rhs Fast) Fast {
   223  	r := s.Copy()
   224  	r.UnionWith(rhs)
   225  	return r
   226  }
   227  
   228  // IntersectionWith removes any elements not in rhs from this set.
   229  func (s *Fast) IntersectionWith(rhs Fast) {
   230  	s.small.IntersectionWith(rhs.small)
   231  	if rhs.large == nil {
   232  		s.large = nil
   233  	}
   234  	if s.large == nil {
   235  		// Fast path.
   236  		return
   237  	}
   238  	s.large.IntersectionWith(rhs.large)
   239  }
   240  
   241  // Intersection returns the intersection of s and rhs as a new set.
   242  func (s Fast) Intersection(rhs Fast) Fast {
   243  	r := s.Copy()
   244  	r.IntersectionWith(rhs)
   245  	return r
   246  }
   247  
   248  // Intersects returns true if s has any elements in common with rhs.
   249  func (s Fast) Intersects(rhs Fast) bool {
   250  	if s.small.Intersects(rhs.small) {
   251  		return true
   252  	}
   253  	if s.large == nil || rhs.large == nil {
   254  		return false
   255  	}
   256  	return s.large.Intersects(rhs.large)
   257  }
   258  
   259  // DifferenceWith removes any elements in rhs from this set.
   260  func (s *Fast) DifferenceWith(rhs Fast) {
   261  	s.small.DifferenceWith(rhs.small)
   262  	if s.large == nil || rhs.large == nil {
   263  		// Fast path
   264  		return
   265  	}
   266  	s.large.DifferenceWith(rhs.large)
   267  }
   268  
   269  // Difference returns the elements of s that are not in rhs as a new set.
   270  func (s Fast) Difference(rhs Fast) Fast {
   271  	r := s.Copy()
   272  	r.DifferenceWith(rhs)
   273  	return r
   274  }
   275  
   276  // Equals returns true if the two sets are identical.
   277  func (s Fast) Equals(rhs Fast) bool {
   278  	if s.small != rhs.small {
   279  		return false
   280  	}
   281  	if s.fitsInSmall() {
   282  		// We already know that the `small` fields are equal. We just have to make
   283  		// sure that the other set also has no large elements.
   284  		return rhs.fitsInSmall()
   285  	}
   286  	// We know that s has large elements.
   287  	return rhs.large != nil && s.large.Equals(rhs.large)
   288  }
   289  
   290  // SubsetOf returns true if rhs contains all the elements in s.
   291  func (s Fast) SubsetOf(rhs Fast) bool {
   292  	if s.fitsInSmall() {
   293  		return s.small.SubsetOf(rhs.small)
   294  	}
   295  	if rhs.fitsInSmall() {
   296  		// s doesn't fit in small and rhs does.
   297  		return false
   298  	}
   299  	return s.small.SubsetOf(rhs.small) && s.large.SubsetOf(rhs.large)
   300  }
   301  
// Encode the set and write it to a bytes.Buffer using binary.varint byte
// encoding.
//
// This method cannot be used if the set contains negative elements. In that
// case an assertion-failure error is returned and nothing is written to buf.
//
// If the set has only elements in the range [0, 63], we encode a 0 followed by
// a 64-bit bitmap. Otherwise, we encode a length followed by each element, in
// increasing order.
//
// WARNING: this is used by plan gists, so if this encoding changes,
// explain.gistVersion needs to be bumped.
func (s *Fast) Encode(buf *bytes.Buffer) error {
	// Negative elements can only be stored in s.large, so checking its minimum
	// is enough to reject them.
	if s.large != nil && s.large.Min() < 0 {
		return errors.AssertionFailedf("Encode used with negative elements")
	}

	// This slice should stay on stack. We only need enough bytes to encode a 0
	// and then an arbitrary 64-bit integer.
	//gcassert:noescape
	tmp := make([]byte, binary.MaxVarintLen64+1)

	// Compact form applies when the high bitmap word is empty and there are no
	// large elements, so the low word alone describes the set. This includes
	// the empty set. The general form is therefore only used for non-empty
	// sets, so its length prefix is never 0 and Decode can tell the forms apart.
	if s.small.hi == 0 && s.fitsInSmall() {
		n := binary.PutUvarint(tmp, 0)
		n += binary.PutUvarint(tmp[n:], s.small.lo)
		buf.Write(tmp[:n])
	} else {
		n := binary.PutUvarint(tmp, uint64(s.Len()))
		buf.Write(tmp[:n])
		// Next is used instead of ForEach so no closure captures tmp, which
		// would risk tmp escaping to the heap.
		for i, ok := s.Next(0); ok; i, ok = s.Next(i + 1) {
			n := binary.PutUvarint(tmp, uint64(i))
			buf.Write(tmp[:n])
		}
	}
	return nil
}
   336  
   337  // Decode does the opposite of Encode. The contents of the receiver are
   338  // overwritten.
   339  func (s *Fast) Decode(br io.ByteReader) error {
   340  	length, err := binary.ReadUvarint(br)
   341  	if err != nil {
   342  		return err
   343  	}
   344  	*s = Fast{}
   345  
   346  	if length == 0 {
   347  		// Special case: a 64-bit bitmap is encoded directly.
   348  		val, err := binary.ReadUvarint(br)
   349  		if err != nil {
   350  			return err
   351  		}
   352  		s.small.lo = val
   353  	} else {
   354  		for i := 0; i < int(length); i++ {
   355  			elem, err := binary.ReadUvarint(br)
   356  			if err != nil {
   357  				*s = Fast{}
   358  				return err
   359  			}
   360  			s.Add(int(elem))
   361  		}
   362  	}
   363  	return nil
   364  }