github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/fast_int_map.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package util
    12  
    13  import (
    14  	"bytes"
    15  	"fmt"
    16  	"math/bits"
    17  	"sort"
    18  
    19  	"github.com/cockroachdb/cockroachdb-parser/pkg/util/intsets"
    20  )
    21  
// FastIntMap is a replacement for map[int]int which is more efficient when both
// keys and values are small. It can be passed by value (but Copy must be used
// for independent modification of copies).
//
// The zero value is an empty map that is ready to use.
type FastIntMap struct {
	// small is the packed representation, consulted only while large is nil.
	// It is read and written as numBits-wide groups, one per key; a group
	// holds 0 when its key is unmapped and value+1 otherwise.
	small [numWords]uint64
	// large is the fallback representation. Set allocates it the first time a
	// key or value does not fit the packed form, moves every existing entry
	// into it and zeroes small; from then on all operations use only large.
	large map[int]int
}
    29  
    30  // Empty returns true if the map is empty.
    31  func (m FastIntMap) Empty() bool {
    32  	return m.small == [numWords]uint64{} && len(m.large) == 0
    33  }
    34  
    35  // Copy returns a FastIntMap that can be independently modified.
    36  func (m FastIntMap) Copy() FastIntMap {
    37  	if m.large == nil {
    38  		return FastIntMap{small: m.small}
    39  	}
    40  	largeCopy := make(map[int]int, len(m.large))
    41  	for k, v := range m.large {
    42  		largeCopy[k] = v
    43  	}
    44  	return FastIntMap{large: largeCopy}
    45  }
    46  
    47  // Set maps a key to the given value.
    48  func (m *FastIntMap) Set(key, val int) {
    49  	if m.large == nil {
    50  		if key >= 0 && key < numVals && val >= 0 && val <= maxValue {
    51  			m.setSmallVal(uint32(key), int32(val))
    52  			return
    53  		}
    54  		m.large = m.toLarge()
    55  		m.small = [numWords]uint64{}
    56  	}
    57  	m.large[key] = val
    58  }
    59  
    60  // Unset unmaps the given key.
    61  func (m *FastIntMap) Unset(key int) {
    62  	if m.large == nil {
    63  		if key < 0 || key >= numVals {
    64  			return
    65  		}
    66  		m.setSmallVal(uint32(key), -1)
    67  	}
    68  	delete(m.large, key)
    69  }
    70  
    71  // Get returns the current value mapped to key, or (-1, false) if the
    72  // key is unmapped.
    73  func (m FastIntMap) Get(key int) (value int, ok bool) {
    74  	if m.large == nil {
    75  		if key < 0 || key >= numVals {
    76  			return -1, false
    77  		}
    78  		val := m.getSmallVal(uint32(key))
    79  		return int(val), (val != -1)
    80  	}
    81  	if value, ok = m.large[key]; ok {
    82  		return value, true
    83  	}
    84  	return -1, false
    85  }
    86  
    87  // GetDefault returns the current value mapped to key, or 0 if the key is
    88  // unmapped.
    89  func (m FastIntMap) GetDefault(key int) (value int) {
    90  	value, ok := m.Get(key)
    91  	if !ok {
    92  		return 0
    93  	}
    94  	return value
    95  }
    96  
    97  // Len returns the number of keys in the map.
    98  func (m FastIntMap) Len() int {
    99  	if m.large != nil {
   100  		return len(m.large)
   101  	}
   102  	res := 0
   103  	for w := 0; w < numWords; w++ {
   104  		v := m.small[w]
   105  		// We want to count the number of non-zero groups. To do this, we OR all
   106  		// the bits of each group into the low-bit of that group, apply a mask
   107  		// selecting just those low bits and count the number of 1s.
   108  		// To OR the bits efficiently, we first OR the high half of each group into
   109  		// the low half of each group, and repeat.
   110  		// Note: this code assumes that numBits is a power of two.
   111  		for i := uint32(numBits / 2); i > 0; i /= 2 {
   112  			v |= (v >> i)
   113  		}
   114  		res += bits.OnesCount64(v & groupLowBitMask)
   115  	}
   116  	return res
   117  }
   118  
   119  // MaxKey returns the maximum key that is in the map. If the map
   120  // is empty, returns ok=false.
   121  func (m FastIntMap) MaxKey() (_ int, ok bool) {
   122  	if m.large == nil {
   123  		for w := numWords - 1; w >= 0; w-- {
   124  			if val := m.small[w]; val != 0 {
   125  				// Example (with numBits = 4)
   126  				//   pos:   3    2    1    0
   127  				//   bits:  0000 0000 0010 0000
   128  				// To get the left-most non-zero group, we calculate how many groups are
   129  				// covered by the leading zeros.
   130  				pos := numValsPerWord - 1 - bits.LeadingZeros64(val)/numBits
   131  				return w*numValsPerWord + pos, true
   132  			}
   133  		}
   134  		return 0, false
   135  	}
   136  	if len(m.large) == 0 {
   137  		return 0, false
   138  	}
   139  	max := intsets.MinInt
   140  	for k := range m.large {
   141  		if max < k {
   142  			max = k
   143  		}
   144  	}
   145  	return max, true
   146  }
   147  
   148  // MaxValue returns the maximum value that is in the map. If the map
   149  // is empty, returns (0, false).
   150  func (m FastIntMap) MaxValue() (_ int, ok bool) {
   151  	if m.large == nil {
   152  		// In the small case, all values are positive.
   153  		max := -1
   154  		for w := 0; w < numWords; w++ {
   155  			if m.small[w] != 0 {
   156  				// To optimize for small maps, we stop when the rest of the values are
   157  				// unset. See the comment in MaxKey.
   158  				numVals := numValsPerWord - bits.LeadingZeros64(m.small[w])/numBits
   159  				for i := 0; i < numVals; i++ {
   160  					val := int(m.getSmallVal(uint32(w*numValsPerWord + i)))
   161  					// NB: val is -1 here if this key isn't in the map.
   162  					if max < val {
   163  						max = val
   164  					}
   165  				}
   166  			}
   167  		}
   168  		if max == -1 {
   169  			return 0, false
   170  		}
   171  		return max, true
   172  	}
   173  	if len(m.large) == 0 {
   174  		return 0, false
   175  	}
   176  	max := intsets.MinInt
   177  	for _, v := range m.large {
   178  		if max < v {
   179  			max = v
   180  		}
   181  	}
   182  	return max, true
   183  }
   184  
   185  // ForEach calls the given function for each key/value pair in the map (in
   186  // arbitrary order).
   187  func (m FastIntMap) ForEach(fn func(key, val int)) {
   188  	if m.large == nil {
   189  		for i := 0; i < numVals; i++ {
   190  			if val := m.getSmallVal(uint32(i)); val != -1 {
   191  				fn(i, int(val))
   192  			}
   193  		}
   194  	} else {
   195  		for k, v := range m.large {
   196  			fn(k, v)
   197  		}
   198  	}
   199  }
   200  
   201  // ContentsIntoBuffer writes the contents of the map into the provided buffer in
   202  // the following format:
   203  //
   204  //	key1:val1 key2:val2 ...
   205  //
   206  // The keys are in ascending order.
   207  func (m FastIntMap) ContentsIntoBuffer(buf *bytes.Buffer) {
   208  	first := true
   209  	if m.large != nil {
   210  		keys := make([]int, 0, len(m.large))
   211  		for k := range m.large {
   212  			keys = append(keys, k)
   213  		}
   214  		sort.Ints(keys)
   215  		for _, k := range keys {
   216  			if !first {
   217  				buf.WriteByte(' ')
   218  			}
   219  			first = false
   220  			fmt.Fprintf(buf, "%d:%d", k, m.large[k])
   221  		}
   222  	} else {
   223  		for i := 0; i < numVals; i++ {
   224  			if val := m.getSmallVal(uint32(i)); val != -1 {
   225  				if !first {
   226  					buf.WriteByte(' ')
   227  				}
   228  				first = false
   229  				fmt.Fprintf(buf, "%d:%d", i, val)
   230  			}
   231  		}
   232  	}
   233  }
   234  
   235  // String prints out the contents of the map in the following format:
   236  //
   237  //	map[key1:val1 key2:val2 ...]
   238  //
   239  // The keys are in ascending order.
   240  func (m FastIntMap) String() string {
   241  	var buf bytes.Buffer
   242  	buf.WriteString("map[")
   243  	m.ContentsIntoBuffer(&buf)
   244  	buf.WriteByte(']')
   245  	return buf.String()
   246  }
   247  
   248  // These constants determine the "small" representation: we pack <numVals>
   249  // values of <numBits> bits into <numWords> 64-bit words. Each value is 0 if the
   250  // corresponding key is not set, otherwise it is the value+1.
   251  //
   252  // It's desirable for efficiency that numBits, numValsPerWord are powers of two.
   253  //
   254  // The current settings support a map from keys in [0, 31] to values in [0, 14].
   255  // Note that one value is reserved to indicate an unmapped element.
   256  const (
   257  	numWords       = 2
   258  	numBits        = 4
   259  	numValsPerWord = 64 / numBits              // 16
   260  	numVals        = numWords * numValsPerWord // 32
   261  	mask           = (1 << numBits) - 1
   262  	maxValue       = mask - 1
   263  	// Mask for the low bits of each group: 0001 0001 0001 ...
   264  	groupLowBitMask = 0x1111111111111111
   265  )
   266  
   267  // Returns -1 if the value is unmapped.
   268  func (m FastIntMap) getSmallVal(idx uint32) int32 {
   269  	word := idx / numValsPerWord
   270  	pos := (idx % numValsPerWord) * numBits
   271  	return int32((m.small[word]>>pos)&mask) - 1
   272  }
   273  
   274  func (m *FastIntMap) setSmallVal(idx uint32, val int32) {
   275  	word := idx / numValsPerWord
   276  	pos := (idx % numValsPerWord) * numBits
   277  	// Clear out any previous value
   278  	m.small[word] &= ^(mask << pos)
   279  	m.small[word] |= uint64(val+1) << pos
   280  }
   281  
   282  func (m *FastIntMap) toLarge() map[int]int {
   283  	res := make(map[int]int, numVals)
   284  	for i := 0; i < numVals; i++ {
   285  		val := m.getSmallVal(uint32(i))
   286  		if val != -1 {
   287  			res[i] = int(val)
   288  		}
   289  	}
   290  	return res
   291  }