github.com/coyove/sdss@v0.0.0-20231129015646-c2ec58cca6a2/contrib/plru/rhmap.go (about)

     1  package plru
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"math"
     7  	"strconv"
     8  	"strings"
     9  	"unsafe"
    10  	_ "unsafe"
    11  )
    12  
    13  type Map[K comparable, V any] struct {
    14  	Fixed bool
    15  
    16  	count uint32
    17  	hash  func(K) uint64
    18  	items []hashItem[K, V]
    19  }
    20  
    // hashItem is a single slot of a Map's table. Its zero value is an
// empty slot.
type hashItem[K, V any] struct {
	// dist is how many slots this entry sits past its home slot
	// (hash(Key) % len(items)).
	dist uint32
	// occupied reports whether the slot currently holds a live entry.
	occupied bool

	Key   K
	Value V
}
    28  
    29  func NewMap[K comparable, V any](size int, hash func(K) uint64) *Map[K, V] {
    30  	if size < 1 {
    31  		size = 1
    32  	}
    33  	obj := &Map[K, V]{hash: hash}
    34  	obj.items = make([]hashItem[K, V], size*2)
    35  	return obj
    36  }
    37  
    38  // Cap returns the capacity of the map.
    39  // Cap * 0.75 is the expanding threshold for non-fixed map.
    40  // Fixed map panic when keys exceed the capacity.
    41  func (m *Map[K, V]) Cap() int {
    42  	if m == nil {
    43  		return 0
    44  	}
    45  	return len(m.items)
    46  }
    47  
    48  // Len returns the count of keys in the map.
    49  func (m *Map[K, V]) Len() int {
    50  	if m == nil {
    51  		return 0
    52  	}
    53  	return int(m.count)
    54  }
    55  
    56  // Clear clears all keys in the map, allocated memory will be reused.
    57  func (m *Map[K, V]) Clear() {
    58  	for i := range m.items {
    59  		m.items[i] = hashItem[K, V]{}
    60  	}
    61  	m.count = 0
    62  }
    63  
    64  // Find finds the value by 'k', returns false as the second argument if not found.
    65  func (m *Map[K, V]) Find(k K) (v V, exists bool) {
    66  	if m == nil {
    67  		return
    68  	}
    69  	if idx := m.findValue(k); idx >= 0 {
    70  		return m.items[idx].Value, true
    71  	}
    72  	return v, false
    73  }
    74  
    75  // Get gets the value by 'k'.
    76  func (m *Map[K, V]) Get(k K) (v V) {
    77  	if m == nil {
    78  		return
    79  	}
    80  	if idx := m.findValue(k); idx >= 0 {
    81  		return m.items[idx].Value
    82  	}
    83  	return v
    84  }
    85  
    86  // Ref retrieves the value pointer by 'k', it is legal to alter what it points to
    87  // as long as the map stays unchanged.
    88  func (m *Map[K, V]) Ref(k K) (v *V) {
    89  	if m == nil {
    90  		return nil
    91  	}
    92  	if idx := m.findValue(k); idx >= 0 {
    93  		return &m.items[idx].Value
    94  	}
    95  	return nil
    96  }
    97  
    98  func (m *Map[K, V]) findValue(k K) int {
    99  	num := len(m.items)
   100  	if num <= 0 {
   101  		return -1
   102  	}
   103  	idx := int(m.hash(k) % uint64(num))
   104  	idxStart := idx
   105  
   106  	for {
   107  		e := &m.items[idx]
   108  		if !e.occupied {
   109  			return -1
   110  		}
   111  
   112  		if e.Key == k {
   113  			return idx
   114  		}
   115  
   116  		idx = (idx + 1) % num
   117  		if idx == idxStart {
   118  			return -1
   119  		}
   120  	}
   121  }
   122  
   123  // Contains returns true if the map contains 'k'.
   124  func (m *Map[K, V]) Contains(k K) bool {
   125  	if m == nil {
   126  		return false
   127  	}
   128  	return m.findValue(k) >= 0
   129  }
   130  
   131  // Set upserts a key-value pair in the map and returns the previous value if updated.
   132  func (m *Map[K, V]) Set(k K, v V) (prev V, updated bool) {
   133  	if len(m.items) <= 0 {
   134  		m.items = make([]hashItem[K, V], 8)
   135  	}
   136  	if int(m.count) >= len(m.items)*3/4 {
   137  		m.resizeHash(len(m.items) * 2)
   138  	}
   139  	return m.setHash(hashItem[K, V]{Key: k, Value: v, occupied: true})
   140  }
   141  
   142  // Delete deletes a key from the map, returns deleted value if existed.
   143  func (m *Map[K, V]) Delete(k K) (prev V, ok bool) {
   144  	idx := m.findValue(k)
   145  	if idx < 0 {
   146  		return prev, false
   147  	}
   148  	prev = m.items[idx].Value
   149  
   150  	// Shift the following keys forward
   151  	num := len(m.items)
   152  	startIdx := idx
   153  	current := idx
   154  
   155  NEXT:
   156  	next := (current + 1) % num
   157  	if m.items[next].dist > 0 {
   158  		m.items[current] = m.items[next]
   159  		m.items[current].dist--
   160  		current = next
   161  		if current != startIdx {
   162  			goto NEXT
   163  		}
   164  	} else {
   165  		m.items[current] = hashItem[K, V]{}
   166  	}
   167  
   168  	m.count--
   169  	return prev, true
   170  }
   171  
   172  func (m *Map[K, V]) setHash(incoming hashItem[K, V]) (prev V, updated bool) {
   173  	num := len(m.items)
   174  	idx := int(m.hash(incoming.Key) % uint64(num))
   175  
   176  	for idxStart := idx; ; {
   177  		e := &m.items[idx]
   178  
   179  		if !e.occupied {
   180  			m.items[idx] = incoming
   181  			m.count++
   182  			return
   183  		}
   184  
   185  		if e.Key == incoming.Key {
   186  			prev = e.Value
   187  			e.Value, e.dist = incoming.Value, incoming.dist
   188  			return prev, true
   189  		}
   190  
   191  		// Swap if the incoming item is further from its best idx.
   192  		if e.dist < incoming.dist {
   193  			incoming, m.items[idx] = m.items[idx], incoming
   194  		}
   195  
   196  		incoming.dist++ // one step further away from best idx.
   197  		idx = (idx + 1) % num
   198  
   199  		if idx == idxStart {
   200  			if m.Fixed {
   201  				panic("fixed map is full")
   202  			} else {
   203  				panic("fatal: space not enough")
   204  			}
   205  		}
   206  	}
   207  }
   208  
   209  // Foreach iterates all keys in the map, for each of them, 'f(key, &value)' will be
   210  // called. Values are passed by pointers and it is legal to manipulate them directly in 'f'.
   211  func (m *Map[K, V]) Foreach(f func(K, *V) bool) {
   212  	if m == nil {
   213  		return
   214  	}
   215  	for i := 0; i < len(m.items); i++ {
   216  		ip := &m.items[i]
   217  		if ip.occupied {
   218  			if !f(ip.Key, &ip.Value) {
   219  				return
   220  			}
   221  		}
   222  	}
   223  }
   224  
   225  // Keys returns all keys in the map as list.
   226  func (m *Map[K, V]) Keys() (res []K) {
   227  	if m == nil {
   228  		return
   229  	}
   230  	for i := 0; i < len(m.items); i++ {
   231  		ip := &m.items[i]
   232  		if ip.occupied {
   233  			res = append(res, ip.Key)
   234  		}
   235  	}
   236  	return
   237  }
   238  
   239  // Values returns all values in the map as list.
   240  func (m *Map[K, V]) Values() (res []V) {
   241  	if m == nil {
   242  		return
   243  	}
   244  	for i := 0; i < len(m.items); i++ {
   245  		ip := &m.items[i]
   246  		if ip.occupied {
   247  			res = append(res, ip.Value)
   248  		}
   249  	}
   250  	return
   251  }
   252  
   253  func (m *Map[K, V]) nextItem(idx int) (int, *hashItem[K, V]) {
   254  	for i := idx; i < len(m.items); i++ {
   255  		if p := &m.items[i]; p.occupied {
   256  			return i, p
   257  		}
   258  	}
   259  	return 0, nil
   260  }
   261  
   262  func (m *Map[K, V]) First() *hashItem[K, V] {
   263  	if m == nil {
   264  		return nil
   265  	}
   266  	for i := range m.items {
   267  		if m.items[i].occupied {
   268  			return &m.items[i]
   269  		}
   270  	}
   271  	return nil
   272  }
   273  
   274  func (m *Map[K, V]) Next(el *hashItem[K, V]) *hashItem[K, V] {
   275  	if len(m.items) == 0 {
   276  		return nil
   277  	}
   278  	hashItemSize := unsafe.Sizeof(hashItem[K, V]{})
   279  	for el != &m.items[len(m.items)-1] {
   280  		ptr := uintptr(unsafe.Pointer(el)) + hashItemSize
   281  		el = (*hashItem[K, V])(unsafe.Pointer(ptr))
   282  		if el.occupied {
   283  			return el
   284  		}
   285  	}
   286  	return nil
   287  }
   288  
   289  func (m *Map[K, V]) Copy() *Map[K, V] {
   290  	m2 := *m
   291  	m2.items = append([]hashItem[K, V]{}, m.items...)
   292  	return &m2
   293  }
   294  
   295  func (m *Map[K, V]) Merge(src *Map[K, V]) *Map[K, V] {
   296  	if src.Len() > 0 {
   297  		m.resizeHash((m.Len() + src.Len()) * 2)
   298  		src.Foreach(func(k K, v *V) bool { m.Set(k, *v); return true })
   299  	}
   300  	return m
   301  }
   302  
   303  func (m *Map[K, V]) resizeHash(newSize int) {
   304  	if m.Fixed {
   305  		return
   306  	}
   307  	if newSize <= len(m.items) {
   308  		return
   309  	}
   310  	tmp := *m
   311  	tmp.items = make([]hashItem[K, V], newSize)
   312  	for _, e := range m.items {
   313  		if e.occupied {
   314  			e.dist = 0
   315  			tmp.setHash(e)
   316  		}
   317  	}
   318  	m.items = tmp.items
   319  }
   320  
   321  func (m *Map[K, V]) density() float64 {
   322  	num := len(m.items)
   323  	if num <= 0 || m.count <= 0 {
   324  		return math.NaN()
   325  	}
   326  
   327  	var maxRun int
   328  	for i := 0; i < num; {
   329  		if !m.items[i].occupied {
   330  			i++
   331  			continue
   332  		}
   333  		run := 1
   334  		for i++; i < num; i++ {
   335  			if m.items[i].occupied {
   336  				run++
   337  			} else {
   338  				break
   339  			}
   340  		}
   341  		if run > maxRun {
   342  			maxRun = run
   343  		}
   344  	}
   345  	return float64(maxRun) / (float64(num) / float64(m.count))
   346  }
   347  
   348  func (m *Map[K, V]) String() string {
   349  	if m == nil {
   350  		return "{}"
   351  	}
   352  	p, f := bytes.NewBufferString("{"), false
   353  	for _, i := range m.items {
   354  		if !i.occupied {
   355  			continue
   356  		}
   357  		fmt.Fprintf(p, "%v: %v, ", i.Key, i.Value)
   358  		f = true
   359  	}
   360  	if f {
   361  		p.Truncate(p.Len() - 2)
   362  	}
   363  	p.WriteString("}")
   364  	return p.String()
   365  }
   366  
   367  func (m *Map[K, V]) GoString() string {
   368  	if m == nil {
   369  		return "{}"
   370  	}
   371  	w := "                "[:int(math.Ceil(math.Log10(float64(len(m.items)))))]
   372  	itoa := func(i int) string {
   373  		s := strconv.Itoa(i)
   374  		return w[:len(w)-len(s)] + s
   375  	}
   376  	p := bytes.Buffer{}
   377  	var maxDist uint32
   378  	for idx, i := range m.items {
   379  		p.WriteString(itoa(idx) + ":")
   380  		if !i.occupied {
   381  			p.WriteString(w)
   382  			p.WriteString(" \t-\n")
   383  		} else {
   384  			at := m.hash(i.Key) % uint64(len(m.items))
   385  			if i.dist > 0 {
   386  				p.WriteString("^")
   387  				p.WriteString(itoa(int(at)))
   388  				if i.dist > uint32(maxDist) {
   389  					maxDist = i.dist
   390  				}
   391  			} else {
   392  				p.WriteString(w)
   393  				p.WriteString(" ")
   394  			}
   395  			p.WriteString("\t" + strings.Repeat(".", int(i.dist)) + fmt.Sprintf("%v\n", i.Key))
   396  		}
   397  	}
   398  	fmt.Fprintf(&p, "max distance: %d", maxDist)
   399  	return p.String()
   400  }