github.com/zeebo/mon@v0.0.0-20211012163247-13d39bdb54fa/internal/lfht/lfht.go (about)

     1  package lfht
     2  
     3  import (
     4  	"fmt"
     5  	"math/bits"
     6  	"sync/atomic"
     7  	"unsafe"
     8  
     9  	"github.com/zeebo/mon/internal/bitmap"
    10  	"github.com/zeebo/xxh3"
    11  )
    12  
    13  // https://repositorio.inesctec.pt/bitstream/123456789/5465/1/P-00F-YAG.pdf
    14  
    15  //
    16  // parameters for the table
    17  //
    18  
    19  const (
    20  	_width    = 3
    21  	_entries  = 1 << _width
    22  	_mask     = _entries - 1
    23  	_bits     = bits.UintSize
    24  	_depth    = 3
    25  	_maxLevel = _bits / _width
    26  )
    27  
    28  //
    29  // shorten some common phrases
    30  //
    31  
    32  type ptr = unsafe.Pointer
    33  
    34  func cas(addr *ptr, old, new ptr) bool { return atomic.CompareAndSwapPointer(addr, old, new) }
    35  func load(addr *ptr) ptr               { return atomic.LoadPointer(addr) }
    36  func store(addr *ptr, val ptr)         { atomic.StorePointer(addr, val) }
    37  
    38  func tag(b *Table) ptr   { return ptr(uintptr(ptr(b)) + 1) }
    39  func tagged(p ptr) bool  { return uintptr(p)&1 > 0 }
    40  func untag(p ptr) *Table { return (*Table)(ptr(uintptr(p) - 1)) }
    41  
    42  //
    43  // hashing support
    44  //
    45  
    46  func hash(x string) uintptr {
    47  	return uintptr(xxh3.HashString(x))
    48  }
    49  
    50  //
    51  // helper data types
    52  //
    53  
    54  type lazyValue struct {
    55  	value ptr
    56  	fn    func() ptr
    57  }
    58  
    59  func (lv *lazyValue) get() ptr {
    60  	if lv.value == nil {
    61  		lv.value = lv.fn()
    62  	}
    63  	return lv.value
    64  }
    65  
// hashedKey bundles a key with its hash so the hash is computed once per
// operation and reused at every table level.
type hashedKey struct {
	key  string  // the key being looked up or inserted
	hash uintptr // hash of key, as produced by hash()
}
    70  
    71  //
    72  // data structure
    73  //
    74  
// tableHeader holds the per-table metadata that precedes the bucket array.
type tableHeader struct {
	level  uint        // depth of the table; selects which hash bits index its buckets
	prev   *Table      // table this one was expanded from; nil when none was set
	bitmap bitmap.B128 // bit idx is set once bucket idx has had an entry installed
}
    80  
// Table is one level of the hash trie. Each bucket is either nil, a pointer
// to the first node of a chain, or a tagged reference (see tag) to a table
// one level deeper.
type Table struct {
	tableHeader
	_       [64 - unsafe.Sizeof(tableHeader{})]byte // pad to cache line
	buckets [_entries]ptr
}
    86  
    87  func (t *Table) getHashBucket(hash uintptr) (*ptr, uint) {
    88  	idx := uint(hash>>((t.level*_width)&(_bits-1))) & _mask
    89  	return &t.buckets[idx], idx
    90  }
    91  
    92  type node struct {
    93  	key   string
    94  	value ptr
    95  	next  ptr
    96  }
    97  
    98  func (n *node) getNextRef() *ptr { return &n.next }
    99  
   100  //
   101  // upsert
   102  //
   103  
   104  func (t *Table) Upsert(k string, vf func() unsafe.Pointer) unsafe.Pointer {
   105  	return t.upsert(hashedKey{key: k, hash: hash(k)}, lazyValue{fn: vf}).value
   106  }
   107  
// upsert finds or inserts key starting at table t and returns its node.
//
// The bucket selected for the hash at this level is handled according to
// its state: empty (nil), the head of a node chain, or a tagged reference
// to a deeper table.
func (t *Table) upsert(key hashedKey, value lazyValue) *node {
	bucket, idx := t.getHashBucket(key.hash)
	entryRef := load(bucket)
	if entryRef == nil {
		// Empty bucket: try to install a single node whose next link is
		// the tagged reference back to this table.
		newNode := &node{key: key.key, value: value.get(), next: tag(t)}
		if cas(bucket, nil, ptr(newNode)) {
			// The bitmap bit is set only after the node has been published.
			t.bitmap.Set(idx)
			return newNode
		}
		// The CAS failed, so the bucket is no longer nil; reload it and
		// handle whatever is there now.
		entryRef = load(bucket)
	}

	if tagged(entryRef) {
		// The bucket refers to another table: continue there.
		return untag(entryRef).upsert(key, value)
	}
	// The bucket holds a chain; walk it starting with the head at position 1.
	return (*node)(entryRef).upsert(key, value, t, 1)
}
   125  
// upsert continues an upsert along a bucket chain of table t. n is the node
// currently being examined and count is its 1-based position in the chain.
// It returns the node holding key, either an existing one or a newly
// inserted one.
func (n *node) upsert(key hashedKey, value lazyValue, t *Table, count int) *node {
	if n.key == key.key {
		return n
	}

	next := n.getNextRef()
	nextRef := load(next)
	if nextRef == tag(t) {
		// n is the last node of the chain in table t.
		if count == _depth && t.level+1 < _maxLevel {
			// The chain has reached _depth nodes and another level is
			// allowed: expand the bucket into a new child table.
			newTable := &Table{tableHeader: tableHeader{
				level: t.level + 1,
				prev:  t,
			}}
			// Swing the chain terminator to the new table first; only the
			// goroutine that wins this CAS performs the expansion.
			if cas(next, tag(t), tag(newTable)) {
				bucket, _ := t.getHashBucket(key.hash)
				// Move every node of the chain into the new table, then
				// replace the bucket with a reference to that table.
				adjustChainNodes((*node)(load(bucket)), newTable)
				store(bucket, tag(newTable))
				return newTable.upsert(key, value)
			}
		} else {
			// Otherwise append a new node at the end of the chain.
			newNode := &node{key: key.key, value: value.get(), next: tag(t)}
			if cas(next, tag(t), ptr(newNode)) {
				return newNode
			}
		}
		// The CAS failed, so the link changed; reload it.
		nextRef = load(next)
	}

	if tagged(nextRef) {
		// The link refers to a table other than t. Walk up through prev
		// until reaching the table whose prev is t (or one with no prev)
		// and restart the upsert from there.
		prevTable := untag(nextRef)
		for prevTable.prev != nil && prevTable.prev != t {
			prevTable = prevTable.prev
		}
		return prevTable.upsert(key, value)
	}
	// The link is a regular node: keep walking the chain.
	return (*node)(nextRef).upsert(key, value, t, count+1)
}
   163  
   164  //
   165  // adjust
   166  //
   167  
// adjustChainNodes moves the chain starting at r into table t. It recurses
// to the following node first, as long as r's next link is not tag(t), and
// only then calls t.adjustNode(r), so nodes are moved starting from the end
// of the chain.
func adjustChainNodes(r *node, t *Table) {
	next := r.getNextRef()
	nextRef := load(next)
	if nextRef != tag(t) {
		adjustChainNodes((*node)(nextRef), t)
	}
	t.adjustNode(r)
}
   176  
// adjustNode links the existing node n into table t. It mirrors
// (*Table).upsert but inserts n itself instead of allocating a new node.
func (t *Table) adjustNode(n *node) {
	// Make n a chain terminator for t before it is linked anywhere in t.
	next := n.getNextRef()
	store(next, tag(t))

	// The node does not keep its hash, so it is recomputed from the key.
	bucket, idx := t.getHashBucket(hash(n.key))
	entryRef := load(bucket)
	if entryRef == nil {
		// Empty bucket: try to install n as the head of a new chain.
		if cas(bucket, nil, ptr(n)) {
			t.bitmap.Set(idx)
			return
		}
		// The CAS failed; reload the bucket and handle its new content.
		entryRef = load(bucket)
	}

	if tagged(entryRef) {
		// The bucket refers to another table: place n there.
		untag(entryRef).adjustNode(n)
		return
	}
	// The bucket holds a chain: append n to it, starting at position 1.
	n.adjustNode(t, (*node)(entryRef), 1)
}
   197  
// adjustNode appends the receiver n to a bucket chain of table t. r is the
// chain node currently being examined and count is its 1-based position.
// It mirrors (*node).upsert, except that there is no key comparison and n
// itself is linked in rather than a newly allocated node.
func (n *node) adjustNode(t *Table, r *node, count int) {
	next := r.getNextRef()
	nextRef := load(next)
	if nextRef == tag(t) {
		// r is the last node of the chain in table t.
		if count == _depth && t.level+1 < _maxLevel {
			// The chain has reached _depth nodes and another level is
			// allowed: expand the bucket into a new child table.
			newTable := &Table{tableHeader: tableHeader{
				level: t.level + 1,
				prev:  t,
			}}
			if cas(next, tag(t), tag(newTable)) {
				bucket, _ := t.getHashBucket(hash(n.key))
				// Move the existing chain into the new table, publish the
				// table in the bucket, then place n into it.
				adjustChainNodes((*node)(load(bucket)), newTable)
				store(bucket, tag(newTable))
				newTable.adjustNode(n)
				return
			}
		} else if cas(next, tag(t), ptr(n)) {
			// n was appended at the end of the chain.
			return
		}
		// The CAS failed, so the link changed; reload it.
		nextRef = load(next)
	}

	if tagged(nextRef) {
		// The link refers to a table other than t. Walk up through prev
		// until reaching the table whose prev is t (or one with no prev)
		// and place n starting from there.
		prevTable := untag(nextRef)
		for prevTable.prev != nil && prevTable.prev != t {
			prevTable = prevTable.prev
		}
		prevTable.adjustNode(n)
		return
	}
	// The link is a regular node: keep walking the chain.
	n.adjustNode(t, (*node)(nextRef), count+1)
}
   230  
   231  //
   232  // lookup
   233  //
   234  
   235  func (t *Table) Lookup(k string) unsafe.Pointer {
   236  	return t.lookup(hashedKey{key: k, hash: hash(k)})
   237  }
   238  
   239  func (t *Table) lookup(key hashedKey) ptr {
   240  	// if lookup misses are frequent, it may be worthwhile to check
   241  	// the bitmap to avoid a cache miss loading the bucket.
   242  	bucket, _ := t.getHashBucket(key.hash)
   243  	entryRef := load(bucket)
   244  	if entryRef == nil {
   245  		return nil
   246  	}
   247  	if tagged(entryRef) {
   248  		return untag(entryRef).lookup(key)
   249  	}
   250  	return (*node)(entryRef).lookup(key, t)
   251  }
   252  
   253  func (n *node) lookup(key hashedKey, t *Table) ptr {
   254  	if n.key == key.key {
   255  		return n.value
   256  	}
   257  
   258  	next := n.getNextRef()
   259  	nextRef := load(next)
   260  	if tagged(nextRef) {
   261  		prevTable := untag(nextRef)
   262  		for prevTable.prev != nil && prevTable.prev != t {
   263  			prevTable = prevTable.prev
   264  		}
   265  		return prevTable.lookup(key)
   266  	}
   267  	return (*node)(nextRef).lookup(key, t)
   268  }
   269  
   270  //
   271  // iterator
   272  //
   273  
// Iterator walks the entries reachable from a Table. It keeps a stack of
// the tables currently being visited, each paired with a copy of that
// table's bitmap taken when the table was pushed.
type Iterator struct {
	n     *node // node the iterator is positioned on; nil when a new bucket must be loaded
	top   int   // index of the current stack entry; negative once iteration is finished
	stack [_maxLevel]struct {
		table *Table      // table being walked at this stack level
		pos   bitmap.B128 // bitmap copy whose Next yields the bucket indexes still to visit
	}
}
   282  
   283  func (t *Table) Iterator() (itr Iterator) {
   284  	itr.stack[0].table = t
   285  	itr.stack[0].pos = t.bitmap.Clone()
   286  	return itr
   287  }
   288  
// Next advances the iterator to the next entry and reports whether one was
// found. After it returns true, Key and Value describe the current entry;
// once it returns false the stack is exhausted.
//
// The body is a small state machine driven by `goto next`: each pass either
// returns a node, pops a finished table, or pushes a table discovered
// through a tagged reference.
func (i *Iterator) Next() bool {
next:
	// if the stack is empty, we're done
	if i.top < 0 {
		return false
	}
	is := &i.stack[i.top]

	// if we don't have a node, load it from the top of the stack
	var nextTable *Table
	if i.n == nil {
		idx, ok := is.pos.Next()
		if !ok {
			// if we've walked the whole table, pop it and try again
			i.top--
			goto next
		}

		// NOTE(review): the index is masked with 127 although buckets has
		// only _entries elements — confirm the mask is still intended.
		bucket := &is.table.buckets[idx&127]
		entryRef := load(bucket)

		// if it's a node, set it and continue
		if !tagged(entryRef) {
			i.n = (*node)(entryRef)
			return true
		}

		// otherwise, we need to walk to a new table.
		nextTable = untag(entryRef)
	} else {
		// if we have a node, try to walk to the next entry.
		nextRef := load(i.n.getNextRef())

		// if it's a node, set it and continue
		if !tagged(nextRef) {
			i.n = (*node)(nextRef)
			return true
		}

		// otherwise, we need to walk to a new table
		nextTable = untag(nextRef)
	}

	// if we're on the same table, just go to the next entry
	if nextTable == is.table {
		i.n = nil
		goto next
	}

	// walk nextTable backwards as much as possible.
	for nextTable.prev != nil && nextTable.prev != is.table {
		nextTable = nextTable.prev
	}

	// if it's a different table, push it on to the stack.
	if nextTable != is.table {
		i.top++
		i.stack[i.top].table = nextTable
		i.stack[i.top].pos = nextTable.bitmap.Clone()
	}

	// walk to the next entry in the top of the stack table
	i.n = nil
	goto next
}
   354  
   355  func (i *Iterator) Key() string           { return i.n.key }
   356  func (i *Iterator) Value() unsafe.Pointer { return i.n.value }
   357  
   358  //
   359  // dumping code
   360  //
   361  
   362  const dumpIndent = "|    "
   363  
   364  func dumpPointer(indent string, p ptr) {
   365  	if tagged(p) {
   366  		table := untag(p)
   367  		fmt.Printf("%stable[%p]:\n", indent, table)
   368  		for i := range &table.buckets {
   369  			dumpPointer(indent+dumpIndent, load(&table.buckets[i]))
   370  		}
   371  	} else if p != nil {
   372  		n := (*node)(p)
   373  		p := load(&n.next)
   374  		fmt.Printf("%snode[%p](key:%q, value:%p, next:%p):\n", indent, n, n.key, n.value, p)
   375  		if !tagged(p) {
   376  			dumpPointer(indent+dumpIndent, load(&n.next))
   377  		}
   378  	}
   379  }
   380  
   381  func (t *Table) dump() { dumpPointer("", tag(t)) }