github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/util/fast_int_map.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package util
    12  
    13  import (
    14  	"bytes"
    15  	"fmt"
    16  	"math/bits"
    17  	"sort"
    18  
    19  	"golang.org/x/tools/container/intsets"
    20  )
    21  
// FastIntMap is a replacement for map[int]int which is more efficient when both
// keys and values are small. It can be passed by value (but Copy must be used
// for independent modification of copies).
//
// Two representations are used:
//   - small: while large is nil, keys in [0, numVals) with values in
//     [0, maxValue] are packed as numBits-wide groups into the small words.
//   - large: once Set receives a key or value outside that range, every entry
//     is moved into a Go map and small is cleared. Nothing in this type ever
//     switches back to the small form.
//
// The zero value is an empty map ready for use.
type FastIntMap struct {
	// small holds the packed entries; it is all zero whenever large != nil.
	small [numWords]uint64
	// large holds every entry once the map has outgrown the small form; nil
	// while the small form is in use.
	large map[int]int
}
    29  
    30  // Empty returns true if the map is empty.
    31  func (m FastIntMap) Empty() bool {
    32  	return m.small == [numWords]uint64{} && len(m.large) == 0
    33  }
    34  
    35  // Copy returns a FastIntMap that can be independently modified.
    36  func (m FastIntMap) Copy() FastIntMap {
    37  	if m.large == nil {
    38  		return FastIntMap{small: m.small}
    39  	}
    40  	largeCopy := make(map[int]int, len(m.large))
    41  	for k, v := range m.large {
    42  		largeCopy[k] = v
    43  	}
    44  	return FastIntMap{large: largeCopy}
    45  }
    46  
    47  // Set maps a key to the given value.
    48  func (m *FastIntMap) Set(key, val int) {
    49  	if m.large == nil {
    50  		if key >= 0 && key < numVals && val >= 0 && val <= maxValue {
    51  			m.setSmallVal(uint32(key), int32(val))
    52  			return
    53  		}
    54  		m.large = m.toLarge()
    55  		m.small = [numWords]uint64{}
    56  	}
    57  	m.large[key] = val
    58  }
    59  
    60  // Unset unmaps the given key.
    61  func (m *FastIntMap) Unset(key int) {
    62  	if m.large == nil {
    63  		if key < 0 || key >= numVals {
    64  			return
    65  		}
    66  		m.setSmallVal(uint32(key), -1)
    67  	}
    68  	delete(m.large, key)
    69  }
    70  
    71  // Get returns the current value mapped to key, or ok=false if the
    72  // key is unmapped.
    73  func (m FastIntMap) Get(key int) (value int, ok bool) {
    74  	if m.large == nil {
    75  		if key < 0 || key >= numVals {
    76  			return -1, false
    77  		}
    78  		val := m.getSmallVal(uint32(key))
    79  		return int(val), (val != -1)
    80  	}
    81  	value, ok = m.large[key]
    82  	return value, ok
    83  }
    84  
    85  // Len returns the number of keys in the map.
    86  func (m FastIntMap) Len() int {
    87  	if m.large != nil {
    88  		return len(m.large)
    89  	}
    90  	res := 0
    91  	for w := 0; w < numWords; w++ {
    92  		v := m.small[w]
    93  		// We want to count the number of non-zero groups. To do this, we OR all
    94  		// the bits of each group into the low-bit of that group, apply a mask
    95  		// selecting just those low bits and count the number of 1s.
    96  		// To OR the bits efficiently, we first OR the high half of each group into
    97  		// the low half of each group, and repeat.
    98  		// Note: this code assumes that numBits is a power of two.
    99  		for i := uint32(numBits / 2); i > 0; i /= 2 {
   100  			v |= (v >> i)
   101  		}
   102  		res += bits.OnesCount64(v & groupLowBitMask)
   103  	}
   104  	return res
   105  }
   106  
   107  // MaxKey returns the maximum key that is in the map. If the map
   108  // is empty, returns ok=false.
   109  func (m FastIntMap) MaxKey() (_ int, ok bool) {
   110  	if m.large == nil {
   111  		for w := numWords - 1; w >= 0; w-- {
   112  			if val := m.small[w]; val != 0 {
   113  				// Example (with numBits = 4)
   114  				//   pos:   3    2    1    0
   115  				//   bits:  0000 0000 0010 0000
   116  				// To get the left-most non-zero group, we calculate how many groups are
   117  				// covered by the leading zeros.
   118  				pos := numValsPerWord - 1 - bits.LeadingZeros64(val)/numBits
   119  				return w*numValsPerWord + pos, true
   120  			}
   121  		}
   122  		return 0, false
   123  	}
   124  	if len(m.large) == 0 {
   125  		return 0, false
   126  	}
   127  	max := intsets.MinInt
   128  	for k := range m.large {
   129  		if max < k {
   130  			max = k
   131  		}
   132  	}
   133  	return max, true
   134  }
   135  
   136  // MaxValue returns the maximum value that is in the map. If the map
   137  // is empty, returns ok=false.
   138  func (m FastIntMap) MaxValue() (_ int, ok bool) {
   139  	if m.large == nil {
   140  		// In the small case, all values are positive.
   141  		max := -1
   142  		for w := 0; w < numWords; w++ {
   143  			if m.small[w] != 0 {
   144  				// To optimize for small maps, we stop when the rest of the values are
   145  				// unset. See the comment in MaxKey.
   146  				numVals := numValsPerWord - bits.LeadingZeros64(m.small[w])/numBits
   147  				for i := 0; i < numVals; i++ {
   148  					val := int(m.getSmallVal(uint32(w*numValsPerWord + i)))
   149  					// NB: val is -1 here if this key isn't in the map.
   150  					if max < val {
   151  						max = val
   152  					}
   153  				}
   154  			}
   155  		}
   156  		if max == -1 {
   157  			return 0, false
   158  		}
   159  		return max, true
   160  	}
   161  	if len(m.large) == 0 {
   162  		return 0, false
   163  	}
   164  	max := intsets.MinInt
   165  	for _, v := range m.large {
   166  		if max < v {
   167  			max = v
   168  		}
   169  	}
   170  	return max, true
   171  }
   172  
   173  // ForEach calls the given function for each key/value pair in the map (in
   174  // arbitrary order).
   175  func (m FastIntMap) ForEach(fn func(key, val int)) {
   176  	if m.large == nil {
   177  		for i := 0; i < numVals; i++ {
   178  			if val := m.getSmallVal(uint32(i)); val != -1 {
   179  				fn(i, int(val))
   180  			}
   181  		}
   182  	} else {
   183  		for k, v := range m.large {
   184  			fn(k, v)
   185  		}
   186  	}
   187  }
   188  
   189  // String prints out the contents of the map in the following format:
   190  //   map[key1:val1 key2:val2 ...]
   191  // The keys are in ascending order.
   192  func (m FastIntMap) String() string {
   193  	var buf bytes.Buffer
   194  	buf.WriteString("map[")
   195  	first := true
   196  
   197  	if m.large != nil {
   198  		keys := make([]int, 0, len(m.large))
   199  		for k := range m.large {
   200  			keys = append(keys, k)
   201  		}
   202  		sort.Ints(keys)
   203  		for _, k := range keys {
   204  			if !first {
   205  				buf.WriteByte(' ')
   206  			}
   207  			first = false
   208  			fmt.Fprintf(&buf, "%d:%d", k, m.large[k])
   209  		}
   210  	} else {
   211  		for i := 0; i < numVals; i++ {
   212  			if val := m.getSmallVal(uint32(i)); val != -1 {
   213  				if !first {
   214  					buf.WriteByte(' ')
   215  				}
   216  				first = false
   217  				fmt.Fprintf(&buf, "%d:%d", i, val)
   218  			}
   219  		}
   220  	}
   221  	buf.WriteByte(']')
   222  	return buf.String()
   223  }
   224  
   225  // These constants determine the "small" representation: we pack <numVals>
   226  // values of <numBits> bits into <numWords> 64-bit words. Each value is 0 if the
   227  // corresponding key is not set, otherwise it is the value+1.
   228  //
   229  // It's desirable for efficiency that numBits, numValsPerWord are powers of two.
   230  //
   231  // The current settings support a map from keys in [0, 31] to values in [0, 14].
   232  // Note that one value is reserved to indicate an unmapped element.
   233  const (
   234  	numWords       = 2
   235  	numBits        = 4
   236  	numValsPerWord = 64 / numBits              // 16
   237  	numVals        = numWords * numValsPerWord // 32
   238  	mask           = (1 << numBits) - 1
   239  	maxValue       = mask - 1
   240  	// Mask for the low bits of each group: 0001 0001 0001 ...
   241  	groupLowBitMask = 0x1111111111111111
   242  )
   243  
   244  // Returns -1 if the value is unmapped.
   245  func (m FastIntMap) getSmallVal(idx uint32) int32 {
   246  	word := idx / numValsPerWord
   247  	pos := (idx % numValsPerWord) * numBits
   248  	return int32((m.small[word]>>pos)&mask) - 1
   249  }
   250  
   251  func (m *FastIntMap) setSmallVal(idx uint32, val int32) {
   252  	word := idx / numValsPerWord
   253  	pos := (idx % numValsPerWord) * numBits
   254  	// Clear out any previous value
   255  	m.small[word] &= ^(mask << pos)
   256  	m.small[word] |= uint64(val+1) << pos
   257  }
   258  
   259  func (m *FastIntMap) toLarge() map[int]int {
   260  	res := make(map[int]int, numVals)
   261  	for i := 0; i < numVals; i++ {
   262  		val := m.getSmallVal(uint32(i))
   263  		if val != -1 {
   264  			res[i] = int(val)
   265  		}
   266  	}
   267  	return res
   268  }