github.com/markusbkk/elvish@v0.0.0-20231204143114-91dc52438621/pkg/persistent/hashmap/hashmap.go (about)

// Package hashmap implements a persistent hash map.
     2  package hashmap
     3  
import (
	"bytes"
	"encoding"
	"encoding/json"
	"fmt"
	"math/bits"
	"reflect"
	"strconv"
)
    12  
// Parameters that determine the shape of the tree.
const (
	chunkBits = 5              // number of hash bits consumed per tree level
	nodeCap   = 1 << chunkBits // number of child slots addressed by one chunk
	chunkMask = nodeCap - 1    // mask selecting the lowest chunkBits bits
)
    18  
// Equal is the type of a function that reports whether two keys are equal.
// The map calls it to tell apart keys that were routed to the same place by
// their hash codes.
type Equal func(k1, k2 interface{}) bool
    21  
// Hash is the type of a function that returns the hash code of a key.
// Lookups navigate the tree by hash code before comparing keys, so keys that
// the accompanying Equal function considers equal should hash to the same
// value.
type Hash func(k interface{}) uint32
    24  
    25  // New takes an equality function and a hash function, and returns an empty
    26  // Map.
    27  func New(e Equal, h Hash) Map {
    28  	return &hashMap{0, emptyBitmapNode, nil, e, h}
    29  }
    30  
// hashMap is the Map implementation returned by New. The pair for the nil key
// is kept in nilV rather than in the tree, because tree nodes use a nil key to
// mark entries that hold child nodes.
type hashMap struct {
	count int          // number of pairs, including the nil-key pair if present
	root  node         // root of the tree holding the pairs with non-nil keys
	nilV  *interface{} // value of the nil key; a nil pointer when it is absent
	equal Equal        // key equality function
	hash  Hash         // key hash function
}
    38  
// Len returns the number of key-value pairs in the map.
func (m *hashMap) Len() int {
	return m.count
}
    42  
    43  func (m *hashMap) Index(k interface{}) (interface{}, bool) {
    44  	if k == nil {
    45  		if m.nilV == nil {
    46  			return nil, false
    47  		}
    48  		return *m.nilV, true
    49  	}
    50  	return m.root.find(0, m.hash(k), k, m.equal)
    51  }
    52  
    53  func (m *hashMap) Assoc(k, v interface{}) Map {
    54  	if k == nil {
    55  		newCount := m.count
    56  		if m.nilV == nil {
    57  			newCount++
    58  		}
    59  		return &hashMap{newCount, m.root, &v, m.equal, m.hash}
    60  	}
    61  	newRoot, added := m.root.assoc(0, m.hash(k), k, v, m.hash, m.equal)
    62  	newCount := m.count
    63  	if added {
    64  		newCount++
    65  	}
    66  	return &hashMap{newCount, newRoot, m.nilV, m.equal, m.hash}
    67  }
    68  
    69  func (m *hashMap) Dissoc(k interface{}) Map {
    70  	if k == nil {
    71  		newCount := m.count
    72  		if m.nilV != nil {
    73  			newCount--
    74  		}
    75  		return &hashMap{newCount, m.root, nil, m.equal, m.hash}
    76  	}
    77  	newRoot, deleted := m.root.without(0, m.hash(k), k, m.equal)
    78  	newCount := m.count
    79  	if deleted {
    80  		newCount--
    81  	}
    82  	return &hashMap{newCount, newRoot, m.nilV, m.equal, m.hash}
    83  }
    84  
    85  func (m *hashMap) Iterator() Iterator {
    86  	if m.nilV != nil {
    87  		return &nilVIterator{true, *m.nilV, m.root.iterator()}
    88  	}
    89  	return m.root.iterator()
    90  }
    91  
// nilVIterator is an Iterator that first yields the pair for the nil key and
// then delegates to another iterator.
type nilVIterator struct {
	atNil bool        // whether the iterator is still positioned at the nil-key pair
	nilV  interface{} // value associated with the nil key
	tail  Iterator    // iterator supplying the pairs after the nil-key pair
}
    97  
    98  func (it *nilVIterator) Elem() (interface{}, interface{}) {
    99  	if it.atNil {
   100  		return nil, it.nilV
   101  	}
   102  	return it.tail.Elem()
   103  }
   104  
   105  func (it *nilVIterator) HasElem() bool {
   106  	return it.atNil || it.tail.HasElem()
   107  }
   108  
   109  func (it *nilVIterator) Next() {
   110  	if it.atNil {
   111  		it.atNil = false
   112  	} else {
   113  		it.tail.Next()
   114  	}
   115  }
   116  
   117  func (m *hashMap) MarshalJSON() ([]byte, error) {
   118  	var buf bytes.Buffer
   119  	buf.WriteByte('{')
   120  	first := true
   121  	for it := m.Iterator(); it.HasElem(); it.Next() {
   122  		if first {
   123  			first = false
   124  		} else {
   125  			buf.WriteByte(',')
   126  		}
   127  		k, v := it.Elem()
   128  		kString, err := convertKey(k)
   129  		if err != nil {
   130  			return nil, err
   131  		}
   132  		kBytes, err := json.Marshal(kString)
   133  		if err != nil {
   134  			return nil, err
   135  		}
   136  		vBytes, err := json.Marshal(v)
   137  		if err != nil {
   138  			return nil, err
   139  		}
   140  		buf.Write(kBytes)
   141  		buf.WriteByte(':')
   142  		buf.Write(vBytes)
   143  	}
   144  	buf.WriteByte('}')
   145  	return buf.Bytes(), nil
   146  }
   147  
   148  // convertKey converts a map key to a string. The implementation matches the
   149  // behavior of how json.Marshal encodes keys of the builtin map type.
   150  func convertKey(k interface{}) (string, error) {
   151  	kref := reflect.ValueOf(k)
   152  	if kref.Kind() == reflect.String {
   153  		return kref.String(), nil
   154  	}
   155  	if t, ok := k.(encoding.TextMarshaler); ok {
   156  		b2, err := t.MarshalText()
   157  		if err != nil {
   158  			return "", err
   159  		}
   160  		return string(b2), nil
   161  	}
   162  	switch kref.Kind() {
   163  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   164  		return strconv.FormatInt(kref.Int(), 10), nil
   165  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   166  		return strconv.FormatUint(kref.Uint(), 10), nil
   167  	}
   168  	return "", fmt.Errorf("unsupported key type %T", k)
   169  }
   170  
// node is an interface for all nodes in the hash map tree. Nodes are never
// modified in place: assoc and without return a node reflecting the change
// and leave the receiver untouched.
type node interface {
	// assoc adds a new pair of key and value. It returns the new node, and
	// whether the key did not exist before (i.e. a new pair has been added,
	// instead of replaced). shift is the bit offset into hash that this node
	// dispatches on.
	assoc(shift, hash uint32, k, v interface{}, h Hash, eq Equal) (node, bool)
	// without removes a key. It returns the new node and whether the key
	// existed before (i.e. a key was indeed removed). When nothing is removed,
	// the receiver itself is returned.
	without(shift, hash uint32, k interface{}, eq Equal) (node, bool)
	// find finds the value for a key. It returns the found value (if any) and
	// whether such a pair exists.
	find(shift, hash uint32, k interface{}, eq Equal) (interface{}, bool)
	// iterator returns an iterator over all pairs stored under this node.
	iterator() Iterator
}
   186  
// arrayNode stores all of its children in an array indexed by hash chunk. The
// array is always at least 1/4 full, otherwise it will be packed into a
// bitmapNode.
type arrayNode struct {
	nChildren int           // number of non-nil elements in children
	children  [nodeCap]node // child for each hash chunk value; nil where absent
}
   193  
   194  func (n *arrayNode) withNewChild(i uint32, newChild node, d int) *arrayNode {
   195  	newChildren := n.children
   196  	newChildren[i] = newChild
   197  	return &arrayNode{n.nChildren + d, newChildren}
   198  }
   199  
   200  func (n *arrayNode) assoc(shift, hash uint32, k, v interface{}, h Hash, eq Equal) (node, bool) {
   201  	idx := chunk(shift, hash)
   202  	child := n.children[idx]
   203  	if child == nil {
   204  		newChild, _ := emptyBitmapNode.assoc(shift+chunkBits, hash, k, v, h, eq)
   205  		return n.withNewChild(idx, newChild, 1), true
   206  	}
   207  	newChild, added := child.assoc(shift+chunkBits, hash, k, v, h, eq)
   208  	return n.withNewChild(idx, newChild, 0), added
   209  }
   210  
   211  func (n *arrayNode) without(shift, hash uint32, k interface{}, eq Equal) (node, bool) {
   212  	idx := chunk(shift, hash)
   213  	child := n.children[idx]
   214  	if child == nil {
   215  		return n, false
   216  	}
   217  	newChild, _ := child.without(shift+chunkBits, hash, k, eq)
   218  	if newChild == child {
   219  		return n, false
   220  	}
   221  	if newChild == emptyBitmapNode {
   222  		if n.nChildren <= nodeCap/4 {
   223  			// less than 1/4 full; shrink
   224  			return n.pack(int(idx)), true
   225  		}
   226  		return n.withNewChild(idx, nil, -1), true
   227  	}
   228  	return n.withNewChild(idx, newChild, 0), true
   229  }
   230  
   231  func (n *arrayNode) pack(skip int) *bitmapNode {
   232  	newNode := bitmapNode{0, make([]mapEntry, n.nChildren-1)}
   233  	j := 0
   234  	for i, child := range n.children {
   235  		// TODO(xiaq): Benchmark performance difference after unrolling this
   236  		// into two loops without the if
   237  		if i != skip && child != nil {
   238  			newNode.bitmap |= 1 << uint(i)
   239  			newNode.entries[j].value = child
   240  			j++
   241  		}
   242  	}
   243  	return &newNode
   244  }
   245  
   246  func (n *arrayNode) find(shift, hash uint32, k interface{}, eq Equal) (interface{}, bool) {
   247  	idx := chunk(shift, hash)
   248  	child := n.children[idx]
   249  	if child == nil {
   250  		return nil, false
   251  	}
   252  	return child.find(shift+chunkBits, hash, k, eq)
   253  }
   254  
   255  func (n *arrayNode) iterator() Iterator {
   256  	it := &arrayNodeIterator{n, 0, nil}
   257  	it.fixCurrent()
   258  	return it
   259  }
   260  
// arrayNodeIterator iterates over the pairs stored under an arrayNode by
// walking the iterators of its non-nil children one after another.
type arrayNodeIterator struct {
	n       *arrayNode // node being iterated
	index   int        // index in n.children of the child being iterated
	current Iterator   // iterator over that child; nil once all children are done
}
   266  
   267  func (it *arrayNodeIterator) fixCurrent() {
   268  	for ; it.index < nodeCap && it.n.children[it.index] == nil; it.index++ {
   269  	}
   270  	if it.index < nodeCap {
   271  		it.current = it.n.children[it.index].iterator()
   272  	} else {
   273  		it.current = nil
   274  	}
   275  }
   276  
// Elem returns the current pair, taken from the iterator of the current
// child. It must only be called while HasElem returns true, since it.current
// is nil after the last child has been exhausted.
func (it *arrayNodeIterator) Elem() (interface{}, interface{}) {
	return it.current.Elem()
}
   280  
// HasElem reports whether the iterator is positioned at a pair. fixCurrent
// sets it.current to nil once there are no more children to visit.
func (it *arrayNodeIterator) HasElem() bool {
	return it.current != nil
}
   284  
   285  func (it *arrayNodeIterator) Next() {
   286  	it.current.Next()
   287  	if !it.current.HasElem() {
   288  		it.index++
   289  		it.fixCurrent()
   290  	}
   291  }
   292  
// emptyBitmapNode is the shared bitmapNode with no entries. It is the root of
// an empty map, and removal code returns it when a node loses its last entry,
// so that parents can detect an emptied child by comparing against it.
var emptyBitmapNode = &bitmapNode{}
   294  
// bitmapNode stores its entries compactly. Bit i of bitmap is set when there
// is an entry for hash chunk i, and entries holds the existing entries in
// ascending order of their chunk, so the position of an entry is the number of
// set bits below its bit (see index).
type bitmapNode struct {
	bitmap  uint32     // which of the nodeCap possible slots are occupied
	entries []mapEntry // occupied slots only; a leaf pair or a child each
}
   299  
// mapEntry is a key-value pair. In a collisionNode, every entry is a real pair
// with a non-nil key. In a bitmapNode, an entry is either a real pair with a
// non-nil key, or, when the key is nil, a reference to a child: in that case
// value holds the child node.
type mapEntry struct {
	key   interface{}
	value interface{}
}
   307  
// chunk returns the chunkBits-wide group of bits of hash that starts at bit
// offset shift. The result is the slot number used at the tree level
// corresponding to shift.
func chunk(shift, hash uint32) uint32 {
	return (hash >> shift) & chunkMask
}
   311  
// bitpos returns a value with a single bit set, at the position given by
// chunk(shift, hash). It is the bit that represents the slot in a bitmapNode's
// bitmap.
func bitpos(shift, hash uint32) uint32 {
	return 1 << chunk(shift, hash)
}
   315  
// index returns the number of bits set in bitmap below bit, where bit is a
// value with a single bit set (as returned by bitpos). This is the position
// in a bitmapNode's entries of the entry for that bit, or the position at
// which such an entry would be inserted.
func index(bitmap, bit uint32) uint32 {
	return popCount(bitmap & (bit - 1))
}
   319  
// Bit masks selecting alternating groups of 1, 2, 4, 8 and 16 bits of a
// uint32, as used by the divide-and-conquer bit counting technique.
const (
	m1  uint32 = 0x55555555
	m2  uint32 = 0x33333333
	m4  uint32 = 0x0f0f0f0f
	m8  uint32 = 0x00ff00ff
	m16 uint32 = 0x0000ffff
)
   327  
   328  // TODO(xiaq): Use an optimized implementation.
   329  func popCount(u uint32) uint32 {
   330  	u = (u & m1) + ((u >> 1) & m1)
   331  	u = (u & m2) + ((u >> 2) & m2)
   332  	u = (u & m4) + ((u >> 4) & m4)
   333  	u = (u & m8) + ((u >> 8) & m8)
   334  	u = (u & m16) + ((u >> 16) & m16)
   335  	return u
   336  }
   337  
   338  func createNode(shift uint32, k1 interface{}, v1 interface{}, h2 uint32, k2 interface{}, v2 interface{}, h Hash, eq Equal) node {
   339  	h1 := h(k1)
   340  	if h1 == h2 {
   341  		return &collisionNode{h1, []mapEntry{{k1, v1}, {k2, v2}}}
   342  	}
   343  	n, _ := emptyBitmapNode.assoc(shift, h1, k1, v1, h, eq)
   344  	n, _ = n.assoc(shift, h2, k2, v2, h, eq)
   345  	return n
   346  }
   347  
   348  func (n *bitmapNode) unpack(shift, idx uint32, newChild node, h Hash, eq Equal) *arrayNode {
   349  	var newNode arrayNode
   350  	newNode.nChildren = len(n.entries) + 1
   351  	newNode.children[idx] = newChild
   352  	j := 0
   353  	for i := uint(0); i < nodeCap; i++ {
   354  		if (n.bitmap>>i)&1 != 0 {
   355  			entry := n.entries[j]
   356  			j++
   357  			if entry.key == nil {
   358  				newNode.children[i] = entry.value.(node)
   359  			} else {
   360  				newNode.children[i], _ = emptyBitmapNode.assoc(
   361  					shift+chunkBits, h(entry.key), entry.key, entry.value, h, eq)
   362  			}
   363  		}
   364  	}
   365  	return &newNode
   366  }
   367  
   368  func (n *bitmapNode) withoutEntry(bit, idx uint32) *bitmapNode {
   369  	if n.bitmap == bit {
   370  		return emptyBitmapNode
   371  	}
   372  	return &bitmapNode{n.bitmap ^ bit, withoutEntry(n.entries, idx)}
   373  }
   374  
   375  func withoutEntry(entries []mapEntry, idx uint32) []mapEntry {
   376  	newEntries := make([]mapEntry, len(entries)-1)
   377  	copy(newEntries[:idx], entries[:idx])
   378  	copy(newEntries[idx:], entries[idx+1:])
   379  	return newEntries
   380  }
   381  
   382  func (n *bitmapNode) withReplacedEntry(i uint32, entry mapEntry) *bitmapNode {
   383  	return &bitmapNode{n.bitmap, replaceEntry(n.entries, i, entry.key, entry.value)}
   384  }
   385  
   386  func replaceEntry(entries []mapEntry, i uint32, k, v interface{}) []mapEntry {
   387  	newEntries := append([]mapEntry(nil), entries...)
   388  	newEntries[i] = mapEntry{k, v}
   389  	return newEntries
   390  }
   391  
   392  func (n *bitmapNode) assoc(shift, hash uint32, k, v interface{}, h Hash, eq Equal) (node, bool) {
   393  	bit := bitpos(shift, hash)
   394  	idx := index(n.bitmap, bit)
   395  	if n.bitmap&bit == 0 {
   396  		// Entry does not exist yet
   397  		nEntries := len(n.entries)
   398  		if nEntries >= nodeCap/2 {
   399  			// Unpack into an arrayNode
   400  			newNode, _ := emptyBitmapNode.assoc(shift+chunkBits, hash, k, v, h, eq)
   401  			return n.unpack(shift, chunk(shift, hash), newNode, h, eq), true
   402  		}
   403  		// Add a new entry
   404  		newEntries := make([]mapEntry, len(n.entries)+1)
   405  		copy(newEntries[:idx], n.entries[:idx])
   406  		newEntries[idx] = mapEntry{k, v}
   407  		copy(newEntries[idx+1:], n.entries[idx:])
   408  		return &bitmapNode{n.bitmap | bit, newEntries}, true
   409  	}
   410  	// Entry exists
   411  	entry := n.entries[idx]
   412  	if entry.key == nil {
   413  		// Non-leaf child
   414  		child := entry.value.(node)
   415  		newChild, added := child.assoc(shift+chunkBits, hash, k, v, h, eq)
   416  		return n.withReplacedEntry(idx, mapEntry{nil, newChild}), added
   417  	}
   418  	// Leaf
   419  	if eq(k, entry.key) {
   420  		// Identical key, replace
   421  		return n.withReplacedEntry(idx, mapEntry{k, v}), false
   422  	}
   423  	// Create and insert new inner node
   424  	newNode := createNode(shift+chunkBits, entry.key, entry.value, hash, k, v, h, eq)
   425  	return n.withReplacedEntry(idx, mapEntry{nil, newNode}), true
   426  }
   427  
   428  func (n *bitmapNode) without(shift, hash uint32, k interface{}, eq Equal) (node, bool) {
   429  	bit := bitpos(shift, hash)
   430  	if n.bitmap&bit == 0 {
   431  		return n, false
   432  	}
   433  	idx := index(n.bitmap, bit)
   434  	entry := n.entries[idx]
   435  	if entry.key == nil {
   436  		// Non-leaf child
   437  		child := entry.value.(node)
   438  		newChild, deleted := child.without(shift+chunkBits, hash, k, eq)
   439  		if newChild == child {
   440  			return n, false
   441  		}
   442  		if newChild == emptyBitmapNode {
   443  			return n.withoutEntry(bit, idx), true
   444  		}
   445  		return n.withReplacedEntry(idx, mapEntry{nil, newChild}), deleted
   446  	} else if eq(entry.key, k) {
   447  		// Leaf, and this is the entry to delete.
   448  		return n.withoutEntry(bit, idx), true
   449  	}
   450  	// Nothing to delete.
   451  	return n, false
   452  }
   453  
   454  func (n *bitmapNode) find(shift, hash uint32, k interface{}, eq Equal) (interface{}, bool) {
   455  	bit := bitpos(shift, hash)
   456  	if n.bitmap&bit == 0 {
   457  		return nil, false
   458  	}
   459  	idx := index(n.bitmap, bit)
   460  	entry := n.entries[idx]
   461  	if entry.key == nil {
   462  		child := entry.value.(node)
   463  		return child.find(shift+chunkBits, hash, k, eq)
   464  	} else if eq(entry.key, k) {
   465  		return entry.value, true
   466  	}
   467  	return nil, false
   468  }
   469  
   470  func (n *bitmapNode) iterator() Iterator {
   471  	it := &bitmapNodeIterator{n, 0, nil}
   472  	it.fixCurrent()
   473  	return it
   474  }
   475  
// bitmapNodeIterator iterates over the pairs stored under a bitmapNode. Leaf
// entries are yielded directly, and child entries are traversed with the
// child's own iterator.
type bitmapNodeIterator struct {
	n       *bitmapNode // node being iterated
	index   int         // index in n.entries of the current entry
	current Iterator    // iterator over the current entry if it is a child; nil otherwise
}
   481  
   482  func (it *bitmapNodeIterator) fixCurrent() {
   483  	if it.index < len(it.n.entries) {
   484  		entry := it.n.entries[it.index]
   485  		if entry.key == nil {
   486  			it.current = entry.value.(node).iterator()
   487  		} else {
   488  			it.current = nil
   489  		}
   490  	} else {
   491  		it.current = nil
   492  	}
   493  }
   494  
   495  func (it *bitmapNodeIterator) Elem() (interface{}, interface{}) {
   496  	if it.current != nil {
   497  		return it.current.Elem()
   498  	}
   499  	entry := it.n.entries[it.index]
   500  	return entry.key, entry.value
   501  }
   502  
// HasElem reports whether the iterator is positioned at a pair, which is the
// case as long as it.index has not moved past the last entry.
func (it *bitmapNodeIterator) HasElem() bool {
	return it.index < len(it.n.entries)
}
   506  
   507  func (it *bitmapNodeIterator) Next() {
   508  	if it.current != nil {
   509  		it.current.Next()
   510  	}
   511  	if it.current == nil || !it.current.HasElem() {
   512  		it.index++
   513  		it.fixCurrent()
   514  	}
   515  }
   516  
// collisionNode stores pairs whose keys all have the same hash code. The
// entries are searched linearly by key.
type collisionNode struct {
	hash    uint32     // hash code shared by the keys of all entries
	entries []mapEntry // the pairs, in insertion order
}
   521  
   522  func (n *collisionNode) assoc(shift, hash uint32, k, v interface{}, h Hash, eq Equal) (node, bool) {
   523  	if hash == n.hash {
   524  		idx := n.findIndex(k, eq)
   525  		if idx != -1 {
   526  			return &collisionNode{
   527  				n.hash, replaceEntry(n.entries, uint32(idx), k, v)}, false
   528  		}
   529  		newEntries := make([]mapEntry, len(n.entries)+1)
   530  		copy(newEntries[:len(n.entries)], n.entries[:])
   531  		newEntries[len(n.entries)] = mapEntry{k, v}
   532  		return &collisionNode{n.hash, newEntries}, true
   533  	}
   534  	// Wrap in a bitmapNode and add the entry
   535  	wrap := bitmapNode{bitpos(shift, n.hash), []mapEntry{{nil, n}}}
   536  	return wrap.assoc(shift, hash, k, v, h, eq)
   537  }
   538  
   539  func (n *collisionNode) without(shift, hash uint32, k interface{}, eq Equal) (node, bool) {
   540  	idx := n.findIndex(k, eq)
   541  	if idx == -1 {
   542  		return n, false
   543  	}
   544  	if len(n.entries) == 1 {
   545  		return emptyBitmapNode, true
   546  	}
   547  	return &collisionNode{n.hash, withoutEntry(n.entries, uint32(idx))}, true
   548  }
   549  
   550  func (n *collisionNode) find(shift, hash uint32, k interface{}, eq Equal) (interface{}, bool) {
   551  	idx := n.findIndex(k, eq)
   552  	if idx == -1 {
   553  		return nil, false
   554  	}
   555  	return n.entries[idx].value, true
   556  }
   557  
   558  func (n *collisionNode) findIndex(k interface{}, eq Equal) int {
   559  	for i, entry := range n.entries {
   560  		if eq(k, entry.key) {
   561  			return i
   562  		}
   563  	}
   564  	return -1
   565  }
   566  
   567  func (n *collisionNode) iterator() Iterator {
   568  	return &collisionNodeIterator{n, 0}
   569  }
   570  
// collisionNodeIterator iterates over the entries of a collisionNode in slice
// order.
type collisionNodeIterator struct {
	n     *collisionNode // node being iterated
	index int            // index in n.entries of the current entry
}
   575  
   576  func (it *collisionNodeIterator) Elem() (interface{}, interface{}) {
   577  	entry := it.n.entries[it.index]
   578  	return entry.key, entry.value
   579  }
   580  
// HasElem reports whether the iterator is positioned at an entry, which is the
// case as long as it.index has not moved past the last entry.
func (it *collisionNodeIterator) HasElem() bool {
	return it.index < len(it.n.entries)
}
   584  
// Next advances the iterator to the following entry.
func (it *collisionNodeIterator) Next() {
	it.index++
}